In [1]:
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch import nn
import torch
from tqdm.notebook import tqdm
device = torch.device("cpu")

In [2]:
from miexp.models.transformer import Transformer
N = 20
dropout = 0
hidden_dim, heads, layers, feed_forward_dim = 2, 1, 1, 2
model = Transformer(dropout, N, hidden_dim, heads, layers, feed_forward_dim, "cpu")
sum(p.numel() for p in model.parameters())


from miexp.bfuncs import MajDataset
num_samples = 1000
train_dataset = MajDataset(N, num_samples=1000)
test_dataset = MajDataset(N, num_samples=1000)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

print(model)

Transformer(
  (embeddings): Embedding(3, 2)
  (transformer): Sequential(
    (0): AttentionBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=2, out_features=2, bias=False)
      )
      (norm1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (linear): Sequential(
        (0): Linear(in_features=2, out_features=2, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2, out_features=2, bias=True)
      )
    )
  )
  (mlp_head): Sequential(
    (0): Linear(in_features=2, out_features=2, bias=True)
  )
)


In [3]:
def train_epoch(model: nn.Module, optimizer: Optimizer, dataloader: DataLoader, device: torch.device, criterion: nn.Module) -> dict[str, float | None]:
    model = model.to(device)
    total_train_loss = 0
    total_train_acc = 0
    total_items = 0
    for input, labels in dataloader:
        input = input.to(device).to(torch.int32)
        labels = labels.to(device).to(torch.int64)
        output = model(input)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += (loss.item()) * len(input)
        total_train_acc += torch.sum(torch.argmax(output, dim=1) == labels).item()
        total_items += len(input)
    return {
        "acc": total_train_acc / total_items,
        "loss": total_train_loss / total_items,
        **{f"norm/{name}": torch.norm(param.grad).item() for name, param in model.named_parameters() if param.grad is not None}
    }

In [4]:
results = []
for epoch in tqdm(range(200)):
    results.append(train_epoch(model, optimizer, test_dataloader, device, criterion))

  0%|          | 0/200 [00:00<?, ?it/s]

In [5]:
import pandas as pd
import plotly.graph_objects as go

# Convert to pandas DataFrame
df = pd.DataFrame(results)

fig = go.Figure()

fig.add_trace(
    go.Scatter(x=df.index, y=df['loss'], name='Loss', yaxis='y1')
)
fig.add_trace(
    go.Scatter(x=df.index, y=df['acc'], name='Accuracy', yaxis='y2')
)

fig.update_layout(
    title='Loss and Accuracy',
    xaxis=dict(title='Index'),
    yaxis=dict(title='Loss', side='left'),
    yaxis2=dict(title='Accuracy', overlaying='y', side='right')
)

fig.show()

In [6]:
def eval_epoch(model: nn.Module, dataloader: DataLoader, device: torch.device) -> dict[str, list[float]]:
    model = model.to(device)
    inputs = []
    correct_outputs = []
    probabilities = []
    for input, labels in dataloader:
        input = input.to(device).to(torch.int32)
        labels = labels.to(device).to(torch.float)
        output = model(input)
        inputs += input.tolist()
        correct_outputs += labels.tolist()
        probabilities += torch.softmax(output, dim=1)[:, 1].tolist()
    return {
        "inputs": inputs,
        "correct_outputs": correct_outputs,
        "probabilities": probabilities
    }

In [7]:
eval_res = eval_epoch(model, test_dataloader, device)
pd.DataFrame.from_dict(eval_res, orient="columns")

Unnamed: 0,inputs,correct_outputs,probabilities
0,"[1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, ...",1.0,0.413695
1,"[1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, ...",1.0,0.413695
2,"[1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, ...",0.0,0.413695
3,"[1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, ...",0.0,0.413695
4,"[0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, ...",0.0,0.413695
...,...,...,...
995,"[1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...",0.0,0.413695
996,"[0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, ...",0.0,0.413695
997,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, ...",0.0,0.413695
998,"[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, ...",1.0,0.413695


In [8]:
torch.save(model.state_dict(), "checkpoints/example_transformer.ckpt")

In [9]:
import plotly.express as px
import plotly.subplots as sp

print(model)

qkv = model.transformer[0].attn.in_proj_weight.detach()
q, k, v = qkv.reshape(3, hidden_dim, hidden_dim)
qkT = q @ k.T

# Create subplots
fig = sp.make_subplots(rows=2, cols=2, subplot_titles=("Q", "K", "V", "Q @ K.T"))

# Add q plot
fig.add_trace(px.imshow(q).data[0], row=1, col=1)

# Add k plot
fig.add_trace(px.imshow(k).data[0], row=1, col=2)

# Add v plot
fig.add_trace(px.imshow(v).data[0], row=2, col=1)

# Add q @ k.T plot
fig.add_trace(px.imshow(qkT).data[0], row=2, col=2)

# Update layout
fig.update_layout(height=800, width=800, title_text="Subplots of q, k, v, and q @ k.T")

fig.show()



Transformer(
  (embeddings): Embedding(3, 2)
  (transformer): Sequential(
    (0): AttentionBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=2, out_features=2, bias=False)
      )
      (norm1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      (linear): Sequential(
        (0): Linear(in_features=2, out_features=2, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2, out_features=2, bias=True)
      )
    )
  )
  (mlp_head): Sequential(
    (0): Linear(in_features=2, out_features=2, bias=True)
  )
)


In [10]:
import plotly.express as px
import plotly.subplots as sp

X = torch.tensor([[1 if i < set_bits else 0 for i in range(N)] for set_bits in range(1, 15, 2)])
batch_size = X.shape[0]

X = torch.cat(
    [
        2 * torch.ones((batch_size, 1), dtype=torch.int).to(X.device),
        X.type(torch.int),
    ],
    dim=1)
# pos = (
#     torch.eye(N + 1, N)
#     .to(device)
#     .unsqueeze(0)
#     .repeat(batch_size, 1, 1)
# )
X = model.embeddings(X).detach()
# X = torch.cat([pos, dat], dim=2)
qkv = model.transformer[0].attn.in_proj_weight.detach()
W   = model.transformer[0].attn.out_proj.weight.detach()
q, k, v = qkv.reshape(3, hidden_dim, hidden_dim)
Q, K, V = X @ q, X @ k, X @ v
QKT = Q @ K.transpose(1, 2) / torch.sqrt(torch.tensor(hidden_dim).to(device))
A = torch.softmax(QKT, dim=1) @ V
OUT = A @ W
postMLP = model.transformer[0].norm1(X+OUT)
postMLP = model.transformer[0].norm2(postMLP + model.transformer[0].linear(postMLP))[:, -1, :].detach()
FINAL   = model.mlp_head(postMLP).detach()
# Create subplots
X_mean, QKT_mean, A_mean, OUT_mean, postMLP_mean = X.mean(dim=0), QKT.mean(dim=0), A.mean(dim=0), OUT.mean(dim=0), postMLP.mean(dim=0)

fig = sp.make_subplots(rows=X.shape[0], cols=3, subplot_titles=("QK.T", "A", "OUT"))

for i in range(X.shape[0]):
    # fig.add_trace(px.imshow(X[i]-X_mean).data[0], row=i+1, col=1)
    fig.add_trace(px.imshow(QKT[i]-QKT_mean).data[0], row=i+1, col=1)
    fig.add_trace(px.imshow(A[i]-A_mean).data[0], row=i+1, col=2)
    fig.add_trace(px.imshow(OUT[i]-OUT_mean).data[0], row=i+1, col=3)
    # fig.add_trace(px.imshow((postMLP[i]-postMLP_mean).unsqueeze(0)).data[0], row=i+1, col=5)
    # fig.add_trace(px.imshow(FINAL[i].unsqueeze(0)).data[0], row=i+1, col=6)

fig.update_layout(height=800, width=800, title_text="Plots for 1, 3, 5, 7, 9, 11, 13 set bits")
fig.show()



