In [1]:
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from bio_LLM import TinyTransformer
from dataset import get_batch_split, vocab_size, stoi, decode

In [2]:
# Choose the compute backend in priority order: Apple MPS, then CUDA,
# falling back to the CPU when neither accelerator reports as available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [3]:
# Global switch read by every attention hook at call time: statistics are
# recorded only while this is True, so ordinary forward passes stay cheap.
PRINT_ATTENTION_VALUES = False

# Recorded values, keyed first by statistic name and then by layer index.
# Each entry is a list of flattened 1-D CPU tensors, one per recorded forward.
ATTENTION_STATS = {
    "scores": defaultdict(list),
    "weights": defaultdict(list),
}

def make_attention_hook(layer_idx):
    """Build a forward hook that records attention statistics for one layer.

    Args:
        layer_idx: Key under which this layer's values are stored in
            ``ATTENTION_STATS["scores"]`` and ``ATTENTION_STATS["weights"]``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
        When ``PRINT_ATTENTION_VALUES`` is true it reads
        ``module._last_attn_scores`` and ``module._last_attn_weights``,
        detaches them, moves them to the CPU and appends the flattened
        tensors to ``ATTENTION_STATS``. Otherwise it does nothing. It always
        returns ``None``, so the module's output is left untouched.
    """
    def inspect_attention(module, input, output):
        if not PRINT_ATTENTION_VALUES:
            return
        scores = module._last_attn_scores.detach().cpu()
        weights = module._last_attn_weights.detach().cpu()

        # Flatten to 1D for the histogram. reshape(-1) is used rather than
        # view(-1) because view raises on non-contiguous tensors (e.g. a
        # transposed score matrix); for contiguous input the two are identical.
        ATTENTION_STATS["scores"][layer_idx].append(scores.reshape(-1))
        ATTENTION_STATS["weights"][layer_idx].append(weights.reshape(-1))
    return inspect_attention

In [4]:
# --- Attention variant selectors, forwarded unchanged to TinyTransformer ---
dot_mode = "NEURON_DOT"
softmax_mode = "NEURON_SOFTMAX"

# --- Model size and training schedule ---
block_size = 32        # passed to the model and to get_batch_split
batch_size = 16
n_embd = 64
n_heads = 4
n_layers = 4
train_steps = 1000     # number of optimizer updates in the training loop
eval_interval = 500    # evaluate every this many steps (and on the last step)
eval_steps = 50        # batches averaged per split inside estimate_loss

# NOTE(review): the five values below are not referenced by any visible cell
# of this notebook and are not passed to TinyTransformer here. Presumably they
# are leaky-integrate-and-fire / winner-take-all settings meant for the
# NEURON_* modes - TODO confirm whether they should be wired in or removed.
LIF_model_dt = 1e-1
LIF_model_steps = 100
wta_inhibition = -0.9
wta_excitation = 1.1
wta_steps = 20

learning_rate = 1e-3

# Seed torch's global RNG before the model is constructed so that weight
# initialisation is repeatable under Restart & Run All.
# NOTE(review): get_batch_split may draw from a different RNG (numpy/random);
# if so, seed that as well to make the loss curve fully reproducible.
seed = 1337
torch.manual_seed(seed)

model = TinyTransformer(
    vocab_size, 
    n_embd=n_embd,
    block_size=block_size, 
    n_heads=n_heads, 
    n_layers=n_layers, 
    dot_mode=dot_mode,
    softmax_mode=softmax_mode).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# for inspecting attention scores & weights during training:
# one hook per transformer block, keyed by the block's position in model.blocks
for i, block in enumerate(model.blocks):
    block.attn.register_forward_hook(make_attention_hook(i))

@torch.no_grad()
def estimate_loss():
    """Average the cross-entropy loss over ``eval_steps`` batches per split.

    Puts the global ``model`` in eval mode, draws ``eval_steps`` batches from
    each of the ``'train'`` and ``'val'`` splits via ``get_batch_split``,
    and switches the model back to train mode before returning.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean loss for that split.
    """
    model.eval()
    mean_losses = {}
    for split_name in ('train', 'val'):
        batch_losses = []
        for _ in range(eval_steps):
            xb, yb = get_batch_split(
                split_name, block_size=block_size, batch_size=batch_size, device=device
            )
            preds = model(xb)
            n_batch, n_time, n_classes = preds.shape
            # Flatten the first two dimensions so the loss sees one row per position.
            batch_loss = loss_fn(
                preds.view(n_batch * n_time, n_classes),
                yb.view(n_batch * n_time),
            )
            batch_losses.append(batch_loss.item())
        mean_losses[split_name] = torch.tensor(batch_losses).mean()
    model.train()
    return mean_losses

# Run `train_steps` optimizer updates, periodically reporting losses and
# recording attention statistics, then save the trained weights.
start_time = time.time()
for step in range(train_steps):
    x, y = get_batch_split('train', block_size=block_size, batch_size=batch_size, device=device)

    # Forward pass: the 3-D logits are flattened to (B*T, C) for the loss.
    logits = model(x)
    B, T, C = logits.shape
    loss = loss_fn(logits.view(B*T, C), y.view(B*T))

    # Backward pass, clipping the global gradient norm to 1.0 before the update.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    if step % eval_interval == 0 or step == train_steps - 1:
        losses = estimate_loss()
        print(f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # Probe pass: one validation sequence is pushed through the model with
        # the attention hooks switched on. It runs under no_grad because no
        # gradients are needed, and the flag is reset in `finally` so a failure
        # here cannot leave the hooks recording during later training steps.
        x_probe, _ = get_batch_split('val', block_size=block_size, batch_size=1, device=device)
        PRINT_ATTENTION_VALUES = True
        try:
            with torch.no_grad():
                model(x_probe)
        finally:
            PRINT_ATTENTION_VALUES = False

end_time = time.time()
print(f"Training time: {end_time - start_time:.2f} seconds")

torch.save(model.state_dict(), 'bio_tiny_shakespeare_model.pth')
print("Model saved to bio_tiny_shakespeare_model.pth")

Step 0: train loss 4.0397, val loss 4.0383
Step 500: train loss 2.7140, val loss 2.7362
Step 999: train loss 2.6631, val loss 2.6766
Training time: 369.25 seconds
Model saved to bio_tiny_shakespeare_model.pth


In [5]:
# Write one density-normalised histogram per (statistic, layer) pair recorded
# in ATTENTION_STATS. The output directory is created first so the cell also
# works on a fresh checkout, where savefig would otherwise fail.
plot_dir = Path("plots") / f"{dot_mode}_{softmax_mode}"
plot_dir.mkdir(parents=True, exist_ok=True)

for stat_type in ['scores', 'weights']:
    for layer_idx, values in ATTENTION_STATS[stat_type].items():
        flat = torch.cat(values)
        # Drop inf/nan entries: hist cannot derive a bin range from
        # non-finite values and would raise instead of plotting.
        data = flat[torch.isfinite(flat)].numpy()

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.hist(data, bins=100, density=True)
        ax.set_title(f"{dot_mode} {softmax_mode} — Transformer block {layer_idx} {stat_type}")
        ax.set_xlabel(stat_type)
        ax.set_ylabel("density")
        ax.grid(True)
        fig.savefig(plot_dir / f"hist_layer{layer_idx}_{stat_type}.png")
        # Close explicitly so figures do not accumulate across iterations.
        plt.close(fig)

In [5]:
def generate(model, start, length=100, temperature=0.8, top_k=40):
    """Sample a continuation of ``start`` from ``model`` and return it as text.

    Args:
        model: Object exposing ``eval()`` and
            ``generate(x, max_new_tokens, temperature, top_k)``.
        start: Prompt string; every character is looked up in ``stoi``.
        length: Forwarded to the model as ``max_new_tokens``.
        temperature: Forwarded unchanged to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.

    Returns:
        The result of ``decode`` applied to the first row of the generated
        index tensor.

    Note:
        The model is switched to eval mode and is not switched back.
    """
    model.eval()
    prompt_ids = [stoi[ch] for ch in start]
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension (a batch of one prompt).
        prompt = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(device)
        sampled = model.generate(
            prompt,
            max_new_tokens=length,
            temperature=temperature,
            top_k=top_k,
        )
    first_row = sampled[0].tolist()
    return decode(first_row)

# Sample 200 new tokens from the prompt "ROMEO" and time the call with the
# wall clock, then show the text followed by the elapsed time.
gen_start = time.time()
output = generate(model, start="ROMEO", length=200)
gen_end = time.time()

print(output)
print(f"\nGeneration time: {gen_end - gen_start:.2f} seconds")

ROMEOY:
Ale at t o shoudel machecebe B!
Wod sumat spe. Endeereend monkld hav Meundthere ororule,
Thin hadon,
CInd rorer thelelirdcit opris avor mand 'ciseng Irsitheve toncog man I piung heandeak henneng
yc

Generation time: 66.20 seconds
