In [1]:
import numpy as np
import torch
from models.multitask_cnn import SoftXOR, MultitaskCNNClassifier

# Seed every RNG the notebook draws from (numpy in the XOR check, torch for
# weight init and the random example batch) so Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Soft (differentiable) XOR module from the project; exercised in the next cell.
xor_fn = SoftXOR()

In [2]:
# Sanity-check SoftXOR on one-hot inputs: for one-hot x and y, the argmax of
# xor_fn(x, y) should be the index x_idx ^ y_idx. Print each case and fail loudly
# on a mismatch instead of relying on a reader to eyeball the hex output.
N_VALUES = 256  # width of the one-hot vectors fed to SoftXOR here
N_CHECKS = 10   # number of random (x, y) pairs to test

for _ in range(N_CHECKS):
    x_idx, y_idx = (np.random.randint(N_VALUES) for _ in range(2))
    x = torch.zeros(1, N_VALUES, dtype=torch.float)
    y = torch.zeros(1, N_VALUES, dtype=torch.float)
    x[:, x_idx] = 1.0
    y[:, y_idx] = 1.0
    out_idx = xor_fn(x, y).argmax(dim=-1).item()
    print(f'{hex(x_idx)} XOR {hex(y_idx)} = {hex(out_idx)}')
    assert out_idx == x_idx ^ y_idx, (
        f'SoftXOR mismatch: {hex(x_idx)} XOR {hex(y_idx)} gave {hex(out_idx)}, '
        f'expected {hex(x_idx ^ y_idx)}'
    )

0x56 XOR 0x51 = 0x7
0xfc XOR 0x99 = 0x65
0x9a XOR 0x5d = 0xc7
0xf2 XOR 0xc7 = 0x35
0x3b XOR 0x8f = 0xb4
0x8 XOR 0x38 = 0x30
0x30 XOR 0x3 = 0x33
0xd XOR 0xdf = 0xd2
0x33 XOR 0x81 = 0xb2
0xc0 XOR 0x5b = 0x9b


In [5]:
# Per-example input shape (channels, length); the example batch below is
# (BATCH_SIZE, *INPUT_SHAPE).
INPUT_SHAPE = (1, 100_000)
# Second constructor argument; presumably the number of output classes
# (256, matching SoftXOR's width) — TODO confirm against MultitaskCNNClassifier.
N_CLASSES = 256
BATCH_SIZE = 16

model = MultitaskCNNClassifier(INPUT_SHAPE, N_CLASSES)
print(model)

# Shape smoke test on a random batch. Run it in eval mode under no_grad so the
# noise batch neither updates the BatchNorm running statistics nor builds an
# autograd graph. Restore the previous mode afterwards so training code sees
# the model as it was.
eg_x = torch.randn(BATCH_SIZE, *INPUT_SHAPE)
was_training = model.training
model.eval()
with torch.no_grad():
    eg_out = model(eg_x)
model.train(was_training)
print(f'{eg_x.shape} -> {eg_out.shape}')
print(f'Parameter count: {sum(p.numel() for p in model.parameters())}')

MultitaskCNNClassifier(
  (xor_fn): SoftXOR()
  (shared_feature_extractor): Sequential(
    (0): Sequential(
      (0): Conv1d(1, 16, kernel_size=(11,), stride=(4,), padding=(5,), bias=False)
      (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
    (1): Sequential(
      (0): Conv1d(16, 32, kernel_size=(11,), stride=(4,), padding=(5,), bias=False)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
    (2): Sequential(
      (0): Conv1d(32, 64, kernel_size=(11,), stride=(4,), padding=(5,), bias=False)
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SELU()
    )
  )
  (split_feature_extractors): ModuleList(
    (0-16): 17 x Sequential(
      (0): Sequential(
        (0): Conv1d(64, 256, kernel_size=(11,), stride=(4,), padding=(5,), bias=False)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, 