In [1]:
import torch
import torch.nn as nn
# from bio_LLM import BioTinyTransformer
from bio_LLM_compute_firing_rates import TinyTransformer as BioTinyTransformer
from dataset import get_batch_split, vocab_size, stoi, decode



In [2]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [5]:
# --- Reproducibility -------------------------------------------------------
# Seed torch's global RNG so weight initialisation (and any torch-based
# sampling) is repeatable across "Restart & Run All".
# NOTE(review): get_batch_split lives in dataset.py; if it draws batches with a
# non-torch RNG (random / numpy), that RNG needs seeding too — confirm.
seed = 1337
torch.manual_seed(seed)

# --- Model / batch shape ---------------------------------------------------
block_size = 16   # passed to both the model and get_batch_split as block_size
batch_size = 16   # passed to get_batch_split as batch_size
n_embd = 64
n_heads = 4
n_layers = 4

# --- Training / evaluation schedule ----------------------------------------
train_steps = 1000    # number of optimizer updates in the training loop
eval_interval = 500   # evaluate every this many steps (and on the last step)
eval_steps = 50       # batches averaged per split inside estimate_loss()

# --- Spiking-neuron / winner-take-all settings -----------------------------
# NOTE(review): none of these five values is referenced by any cell visible in
# this notebook (they are not passed to the model constructor below).
# Presumably they mirror defaults inside bio_LLM_compute_firing_rates or are
# leftovers from an earlier version — confirm, then wire them in or remove.
LIF_model_dt = 1e-1
LIF_model_steps = 100
wta_inhibition = -0.9
wta_excitation = 1.1
wta_steps = 20

# --- Optimisation ----------------------------------------------------------
learning_rate = 1e-3

model = BioTinyTransformer(
    vocab_size,
    n_embd=n_embd,
    block_size=block_size,
    n_heads=n_heads,
    n_layers=n_layers,
    dot_mode="NEURON_DOT",
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def estimate_loss():
    """Estimate the mean cross-entropy loss on the 'train' and 'val' splits.

    For each split, draws ``eval_steps`` batches via ``get_batch_split``, runs
    the module-level ``model`` on them with gradients disabled, and averages
    the per-batch losses computed by ``loss_fn``.

    The model is switched to eval mode for the duration of the call and is
    then returned to whatever mode it was in beforehand (previously it was
    unconditionally forced into training mode, which silently re-enabled
    train-only behaviour if the caller had the model in eval mode). The mode
    is restored even if evaluation raises.

    Returns:
        dict: ``{'train': loss, 'val': loss}`` where each value is a 0-dim
        ``torch.Tensor`` holding the mean loss for that split.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_steps)
            for k in range(eval_steps):
                x, y = get_batch_split(split, block_size=block_size, batch_size=batch_size, device=device)
                logits = model(x)
                # logits unpacks into three dims; flatten the first two so every
                # position is scored as an independent classification example.
                B, T, C = logits.shape
                loss = loss_fn(logits.view(B*T, C), y.view(B*T))
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode instead of assuming training mode.
        model.train(was_training)
    return out

# Index of the final iteration; an evaluation is always reported there.
last_step = train_steps - 1

for step in range(train_steps):
    # Draw a fresh mini-batch from the training split.
    xb, yb = get_batch_split('train', block_size=block_size, batch_size=batch_size, device=device)

    # Forward pass: flatten the first two logits dims so the loss treats
    # every position as one classification example.
    logits = model(xb)
    n_batch, n_time, n_classes = logits.shape
    n_positions = n_batch * n_time
    loss = loss_fn(logits.view(n_positions, n_classes), yb.view(n_positions))

    # Backward pass: clear old gradients, backprop, clip the global gradient
    # norm to 1.0, then apply the update.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Periodic (and final) evaluation on both splits.
    is_eval_step = (step % eval_interval == 0) or (step == last_step)
    if is_eval_step:
        losses = estimate_loss()
        print(f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

# Persist only the learned parameters (state_dict), not the whole module.
torch.save(model.state_dict(), 'bio_tiny_shakespeare_model.pth')
print("Model saved to bio_tiny_shakespeare_model.pth")

Step 0: train loss 4.1030, val loss 4.1029
Step 500: train loss 2.5072, val loss 2.5032
Step 999: train loss 2.3766, val loss 2.3815
Model saved to bio_tiny_shakespeare_model.pth


In [None]:
def generate(model, start, length=100, temperature=0.8, top_k=40):
    """Sample a continuation of ``start`` from ``model`` and return it as text.

    Args:
        model: object exposing ``eval()`` and
            ``generate(idx, max_new_tokens=..., temperature=..., top_k=...)``.
        start: prompt; each element is looked up in the module-level ``stoi``
            mapping (a ``KeyError`` is raised for elements missing from it).
        length: forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: forwarded unchanged to ``model.generate``.
        top_k: forwarded unchanged to ``model.generate``.

    Returns:
        The result of ``decode`` applied to the first row of the tensor
        returned by ``model.generate``.

    Note:
        Leaves ``model`` in eval mode and places the prompt on the
        module-level ``device``.
    """
    model.eval()
    prompt_ids = [stoi[ch] for ch in start]
    with torch.no_grad():
        # Nested list -> shape (1, len(start)): a single-sequence batch.
        context = torch.tensor([prompt_ids], dtype=torch.long).to(device)
        sampled = model.generate(
            context,
            max_new_tokens=length,
            temperature=temperature,
            top_k=top_k,
        )
        first_sequence = sampled[0].tolist()
        return decode(first_sequence)

# Sample 200 new tokens from the model, seeded with the prompt "ROMEO".
print(generate(model=model, start="ROMEO", length=200))