In [1]:
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch import nn
from tqdm.notebook import tqdm

In [2]:
import torch


device = torch.device("mps")

In [3]:
altmodel = torch.nn.Sequential(
    nn.Linear(3, 8),
    nn.ReLU(),
    nn.Linear(8, 8),
    nn.ReLU(),
    nn.Linear(8, 2)
)

In [4]:
from miexp.bfuncs import MajDataset


dataset = MajDataset(3, num_samples=8)

In [5]:
list(dataset)

[(tensor([0., 0., 0.]), tensor(0, dtype=torch.int32)),
 (tensor([0., 0., 1.]), tensor(0, dtype=torch.int32)),
 (tensor([1., 1., 0.]), tensor(1, dtype=torch.int32)),
 (tensor([1., 1., 1.]), tensor(1, dtype=torch.int32)),
 (tensor([0., 0., 0.]), tensor(0, dtype=torch.int32)),
 (tensor([0., 0., 0.]), tensor(0, dtype=torch.int32)),
 (tensor([1., 0., 1.]), tensor(1, dtype=torch.int32)),
 (tensor([1., 1., 1.]), tensor(1, dtype=torch.int32))]

In [6]:
from miexp.models.btransformer import BooleanTransformer


model = BooleanTransformer(max_seq_len=3, hidden_dim=4, n_heads=2, num_classifier_hidden_layers=3)

In [7]:
sum(p.numel() for p in model.parameters())

230

In [8]:
def train_epoch(model: nn.Module, optimizer: Optimizer, dataloader: DataLoader, device: torch.device, criterion: nn.Module) -> dict[str, float | None]:
    model = model.to(device)
    total_train_loss = 0
    total_train_acc = 0
    total_items = 0
    for input, labels in dataloader:
        input = input.to(device)
        labels = labels.to(device)
        output = model(input)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += (loss.item()) * len(input)
        total_train_acc += torch.sum(torch.argmax(output, dim=1) == labels).item()
        total_items += len(input)
    return {
        "acc": total_train_acc / total_items,
        "loss": total_train_loss / total_items,
        **{f"norm/{name}": torch.norm(param.grad).item() for name, param in model.named_parameters() if param.grad is not None}
    }

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

In [10]:
model

BooleanTransformer(
  (embedding): Embedding(3, 4)
  (pos_embedding): Embedding(3, 4)
  (transformer_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=4, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=4, bias=True)
    (3): ReLU()
    (4): Linear(in_features=4, out_features=4, bias=True)
    (5): ReLU()
    (6): Linear(in_features=4, out_features=2, bias=True)
  )
)

In [11]:
results = []
for epoch in tqdm(range(500)):
    results.append(train_epoch(model, optimizer, dataloader, device, criterion))

  0%|          | 0/500 [00:00<?, ?it/s]

[[2, 1, 1, 0], [2, 0, 0, 0], [2, 1, 1, 1], [2, 1, 0, 1], [2, 0, 0, 0], [2, 0, 0, 1], [2, 0, 0, 0], [2, 1, 1, 1]]
[[2, 0, 0, 0], [2, 0, 0, 0], [2, 0, 0, 0], [2, 1, 1, 1], [2, 1, 0, 1], [2, 0, 0, 1], [2, 1, 1, 1], [2, 1, 1, 0]]
[[2, 1, 1, 1], [2, 1, 1, 0], [2, 0, 0, 0], [2, 0, 0, 0], [2, 0, 0, 1], [2, 1, 0, 1], [2, 0, 0, 0], [2, 1, 1, 1]]
[[2, 1, 0, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 0, 0, 0], [2, 1, 1, 0], [2, 0, 0, 1], [2, 0, 0, 0], [2, 0, 0, 0]]
[[2, 1, 1, 0], [2, 0, 0, 0], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 0, 1], [2, 0, 0, 0], [2, 0, 0, 0], [2, 0, 0, 1]]
[[2, 1, 1, 1], [2, 0, 0, 0], [2, 0, 0, 1], [2, 0, 0, 0], [2, 0, 0, 0], [2, 1, 1, 0], [2, 1, 1, 1], [2, 1, 0, 1]]
[[2, 1, 1, 1], [2, 1, 1, 0], [2, 0, 0, 0], [2, 1, 1, 1], [2, 0, 0, 0], [2, 1, 0, 1], [2, 0, 0, 1], [2, 0, 0, 0]]
[[2, 0, 0, 0], [2, 0, 0, 0], [2, 1, 0, 1], [2, 0, 0, 1], [2, 0, 0, 0], [2, 1, 1, 1], [2, 1, 1, 0], [2, 1, 1, 1]]
[[2, 0, 0, 0], [2, 1, 1, 1], [2, 1, 1, 1], [2, 0, 0, 1], [2, 1, 0, 1], [2, 1, 1, 0], [2, 0, 0, 0

In [12]:
def eval_epoch(model: nn.Module, dataloader: DataLoader, device: torch.device) -> dict[str, list[float]]:
    model = model.to(device)
    inputs = []
    correct_outputs = []
    probabilities = []
    for input, labels in dataloader:
        input = input.to(device)
        labels = labels.to(device)
        output = model(input)
        inputs += input.tolist()
        correct_outputs += labels.tolist()
        probabilities += torch.softmax(output, dim=1)[:, 1].tolist()
    return {
        "inputs": inputs,
        "correct_outputs": correct_outputs,
        "probabilities": probabilities
    }

In [13]:
eval_res = eval_epoch(model, dataloader, device)

[[2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 0], [2, 0, 0, 0], [2, 0, 0, 0], [2, 0, 0, 1], [2, 0, 0, 0], [2, 1, 0, 1]]


In [14]:
eval_res["probabilities"]

[0.49928057193756104,
 0.4980148673057556,
 0.4956323802471161,
 0.5003236532211304,
 0.4979563355445862,
 0.49729418754577637,
 0.4969536066055298,
 0.4990168511867523]

In [15]:
import pandas as pd


# pd.DataFrame(eval_res, )
pd.DataFrame.from_dict(eval_res, orient="columns")

Unnamed: 0,inputs,correct_outputs,probabilities
0,"[1.0, 1.0, 1.0]",1,0.499281
1,"[1.0, 1.0, 1.0]",1,0.498015
2,"[1.0, 1.0, 0.0]",1,0.495632
3,"[0.0, 0.0, 0.0]",0,0.500324
4,"[0.0, 0.0, 0.0]",0,0.497956
5,"[0.0, 0.0, 1.0]",0,0.497294
6,"[0.0, 0.0, 0.0]",0,0.496954
7,"[1.0, 0.0, 1.0]",1,0.499017


In [16]:
eval_res["inputs"]

[[1.0, 1.0, 1.0],
 [1.0, 1.0, 1.0],
 [1.0, 1.0, 0.0],
 [0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0],
 [0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0],
 [1.0, 0.0, 1.0]]

In [17]:

res = pd.DataFrame.from_records(results)
res

Unnamed: 0,acc,loss,norm/embedding.weight,norm/transformer_layer.self_attn.in_proj_weight,norm/transformer_layer.self_attn.in_proj_bias,norm/transformer_layer.self_attn.out_proj.weight,norm/transformer_layer.self_attn.out_proj.bias,norm/transformer_layer.linear1.weight,norm/transformer_layer.linear1.bias,norm/transformer_layer.linear2.weight,...,norm/transformer_layer.norm2.weight,norm/transformer_layer.norm2.bias,norm/classifier.0.weight,norm/classifier.0.bias,norm/classifier.2.weight,norm/classifier.2.bias,norm/classifier.4.weight,norm/classifier.4.bias,norm/classifier.6.weight,norm/classifier.6.bias
0,0.500,0.758050,0.001380,0.002661,0.001119,0.003392,0.001616,0.000720,0.000350,0.000823,...,0.002861,0.002407,0.009242,0.004724,0.023165,0.018490,0.063566,0.186110,0.070225,0.245490
1,0.500,0.756717,0.001816,0.002452,0.001031,0.003587,0.002049,0.000628,0.000278,0.000652,...,0.002676,0.002519,0.009592,0.004871,0.020493,0.018021,0.058428,0.184144,0.070433,0.242935
2,0.500,0.755389,0.001676,0.002015,0.000848,0.003721,0.002040,0.001136,0.000574,0.001463,...,0.002454,0.002252,0.008775,0.004457,0.019055,0.017631,0.052032,0.182873,0.068062,0.241302
3,0.500,0.754617,0.001563,0.002214,0.000931,0.003386,0.001922,0.001221,0.000604,0.000691,...,0.002496,0.002189,0.008785,0.004352,0.019163,0.017305,0.051781,0.182152,0.066667,0.240388
4,0.500,0.753470,0.002033,0.002282,0.000959,0.003622,0.002300,0.000215,0.000102,0.000872,...,0.002277,0.002117,0.008313,0.004230,0.017522,0.016908,0.048002,0.180564,0.066862,0.238327
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,0.250,0.694898,0.002464,0.005137,0.002159,0.004823,0.002924,0.000917,0.000340,0.002378,...,0.002400,0.000001,0.008416,0.000002,0.007416,0.000006,0.007133,0.000024,0.003358,0.000028
496,0.375,0.694400,0.000656,0.002054,0.000863,0.003917,0.002077,0.001464,0.000679,0.001144,...,0.000963,0.000119,0.003145,0.000206,0.004832,0.000604,0.006627,0.002424,0.003746,0.002807
497,0.625,0.692000,0.001922,0.003168,0.001331,0.006007,0.002595,0.002654,0.001265,0.002512,...,0.001158,0.000096,0.003637,0.000166,0.002771,0.000487,0.003303,0.001956,0.001063,0.002265
498,0.625,0.691168,0.000286,0.002783,0.001170,0.005679,0.003131,0.003061,0.001576,0.001763,...,0.001581,0.000115,0.004934,0.000199,0.005135,0.000584,0.006363,0.002345,0.002395,0.002716


In [18]:
import plotly.express as px

fig = px.scatter(res)
fig.update_traces(mode='lines')
fig.show()

In [19]:
list(dataset)

[(tensor([0., 0., 0.]), tensor(0, dtype=torch.int32)),
 (tensor([0., 0., 1.]), tensor(0, dtype=torch.int32)),
 (tensor([1., 1., 0.]), tensor(1, dtype=torch.int32)),
 (tensor([1., 1., 1.]), tensor(1, dtype=torch.int32)),
 (tensor([0., 0., 0.]), tensor(0, dtype=torch.int32)),
 (tensor([0., 0., 0.]), tensor(0, dtype=torch.int32)),
 (tensor([1., 0., 1.]), tensor(1, dtype=torch.int32)),
 (tensor([1., 1., 1.]), tensor(1, dtype=torch.int32))]

In [20]:
model.to(device)

BooleanTransformer(
  (embedding): Embedding(3, 4)
  (pos_embedding): Embedding(3, 4)
  (transformer_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
    (linear1): Linear(in_features=4, out_features=4, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=4, out_features=4, bias=True)
    (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=4, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=4, bias=True)
    (3): ReLU()
    (4): Linear(in_features=4, out_features=4, bias=True)
    (5): ReLU()
    (6): Linear(in_features=4, out_features=2, bias=True)
  )
)

In [21]:
model(torch.tensor([[1, 0, 0, 1, 0]]).to(device))

[[2, 1, 0, 0, 1, 0]]


tensor([[-0.0041,  0.0173]], device='mps:0', grad_fn=<LinearBackward0>)

In [22]:
model(torch.tensor([[0, 0, 1, 1, 0]]).to(device))

[[2, 0, 0, 1, 1, 0]]


tensor([[-0.0041,  0.0173]], device='mps:0', grad_fn=<LinearBackward0>)

In [23]:
model(torch.tensor([[1, 0, 0, 1, 1]]).to(device))

[[2, 1, 0, 0, 1, 1]]


tensor([[ 0.0132, -0.0067]], device='mps:0', grad_fn=<LinearBackward0>)

In [24]:
model.save_to_checkpoint("../checkpoints/example_transformer.ckpt")