In [1]:
import torch
import torch.nn as nn
from bio_LLM import BioTinyTransformer
from dataset import get_batch_split, vocab_size, stoi, decode

In [2]:
# Choose the compute backend in priority order: PyTorch's MPS backend first,
# then CUDA, and finally the CPU when neither accelerator is reported.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

Using device: mps


In [3]:
# --- Training-loop hyperparameters ------------------------------------------
batch_size = 32       # batch_size passed to get_batch_split
steps = 1000          # number of training iterations
eval_interval = 500   # run estimate_loss() every this many steps
eval_steps = 50       # batches averaged per split inside estimate_loss()

# --- Model, optimizer and loss -----------------------------------------------
# Architecture settings are collected in one mapping and splatted into the
# constructor; vocab_size stays positional exactly as before.
model_kwargs = dict(n_embd=128, block_size=32, n_heads=4, n_layers=4)
model = BioTinyTransformer(vocab_size, **model_kwargs)
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def estimate_loss(block_size=32):
    """Estimate the mean loss on the 'train' and 'val' splits.

    For each split, draws ``eval_steps`` batches with ``get_batch_split``,
    runs the module-level ``model`` on them with gradient tracking disabled,
    and averages the per-batch ``loss_fn`` values.

    Args:
        block_size: Forwarded as ``block_size`` to ``get_batch_split``.
            Defaults to 32, the value that used to be hard-coded here.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean of the ``eval_steps`` per-batch losses for that split.

    Side effects:
        Switches ``model`` to eval mode for the measurement and always
        switches it back to train mode on exit, including when evaluation
        raises or the cell is interrupted.
    """
    model.eval()
    try:
        out = {}
        for split in ['train', 'val']:
            losses = torch.zeros(eval_steps)
            for k in range(eval_steps):
                x, y = get_batch_split(split, block_size=block_size, batch_size=batch_size, device=device)
                logits = model(x)
                # Flatten (B, T, C) logits and (B, T) targets so the loss
                # sees one row per token position.
                B, T, C = logits.shape
                loss = loss_fn(logits.view(B*T, C), y.view(B*T))
                losses[k] = loss.item()
            out[split] = losses.mean()
        return out
    finally:
        # Restore train mode even on failure; otherwise a crashed or
        # interrupted evaluation would leave the model in eval mode and any
        # training that follows would silently run with eval-time behaviour.
        model.train()

checkpoint_path = 'bio_tiny_shakespeare_model.pth'
final_step = steps - 1

for step in range(steps):
    # Fetch the next training batch.
    x, y = get_batch_split('train', block_size=32, batch_size=batch_size, device=device)

    # Clear gradients left over from the previous update before computing
    # new ones; nothing in the forward pass reads them.
    optimizer.zero_grad()

    # Forward: flatten (B, T, C) logits and (B, T) targets to one row per
    # token position for the loss.
    logits = model(x)
    B, T, C = logits.shape
    n_tokens = B * T
    loss = loss_fn(logits.view(n_tokens, C), y.view(n_tokens))

    # Backward, clip the global gradient norm to 1.0, then update weights.
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Report averaged train/val loss periodically and once more at the end.
    is_report_step = step % eval_interval == 0 or step == final_step
    if is_report_step:
        losses = estimate_loss()
        print(f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

# Persist only the learned weights (the state dict), not the whole module.
torch.save(model.state_dict(), checkpoint_path)
print("Model saved to bio_tiny_shakespeare_model.pth")

Step 0: train loss 3.9491, val loss 3.9523
Step 500: train loss 2.5205, val loss 2.5116
Step 999: train loss 2.4913, val loss 2.5001
Model saved to bio_tiny_shakespeare_model.pth


In [4]:
def generate(model, start, length=100, temperature=0.8, top_k=40):
    """Sample text from ``model``, using ``start`` as the prompt.

    Args:
        model: Module exposing ``eval()``, ``parameters()`` and
            ``generate(idx, max_new_tokens=..., temperature=..., top_k=...)``.
        start: Prompt string. Each character is looked up as ``stoi[ch]``;
            a failed lookup propagates to the caller.
        length: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Forwarded unchanged to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.

    Returns:
        ``decode`` applied to the first row of the index tensor returned by
        ``model.generate``.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()

    # Build the prompt on the device the model's own weights live on, so a
    # model passed in from another device does not hit a device mismatch.
    # The module-level `device` is only a fallback for a parameterless model.
    first_param = next(iter(model.parameters()), None)
    target_device = first_param.device if first_param is not None else device

    prompt_ids = [stoi[ch] for ch in start]
    with torch.no_grad():
        # Leading batch dimension of 1: model.generate receives a (1, len) tensor.
        x = torch.tensor(prompt_ids, dtype=torch.long, device=target_device).unsqueeze(0)
        idx = model.generate(x, max_new_tokens=length, temperature=temperature, top_k=top_k)
        return decode(idx[0].tolist())

# Sample from the trained model: prompt "ROM", requesting 100 new tokens.
sample_text = generate(model, "ROM", length=100)
print(sample_text)

ROMy
Mas I sthe o'sewin o t. y ave murea halis y ch be ceeinathateat,

Goulld myoucareayo her n,
I alie
