# 🎨 FONTe AI - B200 Training (Modal.com)

**Run this notebook directly on Modal.com**

GPU: B200 (192GB) @ $6.25/hr | 50 epochs in ~1.3 hours | ~$8 total

## 1️⃣ Setup

In [None]:
# Install Git LFS and enable its hooks; `git lfs pull` below depends on it.
!apt-get install git-lfs -qq
!git lfs install
# Clone the project, enter the repository and fetch the LFS-tracked files.
!git clone https://github.com/nityam2007/fonte-ai.git
%cd fonte-ai
!git lfs pull
# Report the GPU name and its total memory.
!nvidia-smi --query-gpu=name,memory.total --format=csv
# List TOKENIZED/ (the training cell reads train.bin and val.bin from here).
!ls -lh TOKENIZED/

## 2️⃣ Train

In [None]:
import math, json, struct, time, os
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.auto import tqdm

# Run configuration: edit these to change the training budget.
DEVICE = 'cuda'    # device the model and every batch are moved to
EPOCHS = 50        # number of full passes over the training set
BATCH_SIZE = 512   # sequences per DataLoader batch (train and val)
LR = 3e-4          # initial AdamW learning rate (cosine-annealed afterwards)

@dataclass
class ModelConfig:
    """Architecture hyperparameters consumed by ``FonteModel``.

    The field order is part of the interface (positional construction), and
    the instance ``__dict__`` is what gets stored in checkpoints.
    """

    # --- token space / sequence geometry ---
    vocab_size: int = 1105        # embedding rows and output-logit width
    max_seq_length: int = 512     # size of the positional table and causal mask
    # --- transformer dimensions ---
    d_model: int = 256            # residual-stream width
    n_heads: int = 4              # attention heads per block
    n_layers: int = 6             # number of stacked transformer blocks
    d_ff: int = 1024              # hidden width of the feed-forward sub-layer
    # --- regularisation / special tokens ---
    dropout: float = 0.1          # dropout probability used throughout the model
    pad_token_id: int = 0         # embedding padding_idx; ignored by the loss

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored in the
    non-trainable buffer ``pe`` with shape ``(1, max_len, d_model)``. Even
    feature columns hold sines and odd columns hold cosines of the same
    angles.

    Args:
        d_model: Feature width of the embeddings (even or odd).
        max_len: Number of positions to precompute.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim the angle table to fit. For even d_model the slice
        # is a no-op and the table is identical to the untrimmed version.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` where ``seq_len = x.size(1)``."""
        return self.dropout(x + self.pe[:, :x.size(1)])

class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention split across ``n_heads`` heads.

    Each head works on ``d_model // n_heads`` features. Queries, keys and
    values are separate linear projections of the same input, and the
    concatenated head outputs go through a final linear projection.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Projections are created in q, k, v, o order so that seeded
        # initialisation matches existing checkpoints/experiments.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(batch, seq, d_model)``.

        ``mask``, when given, is compared against 0; positions where it is 0
        are excluded from attention (their score is set to ``-inf``).
        """
        batch, seq_len, _ = x.shape

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return projected.view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        queries = split_heads(self.wq(x))
        keys = split_heads(self.wk(x))
        values = split_heads(self.wv(x))

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # (batch, heads, seq, d_k) -> (batch, seq, heads * d_k)
        merged = torch.matmul(weights, values).transpose(1, 2).reshape(batch, seq_len, -1)
        return self.wo(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual.

    LayerNorm is applied to the input of each sub-layer (``n1`` before
    attention, ``n2`` before the feed-forward network), and the sub-layer
    output passes through dropout before being added back to the stream.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        # Position-wise MLP: expand to d_ff, GELU, dropout, project back.
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is forwarded to attention."""
        attended = self.attn(self.n1(x), mask)
        x = x + self.drop(attended)
        transformed = self.ff(self.n2(x))
        return x + self.drop(transformed)

class FonteModel(nn.Module):
    """Decoder-only transformer language model over token ids.

    Token embeddings plus sinusoidal positions feed ``cfg.n_layers`` causal
    transformer blocks, a final LayerNorm, and an output projection whose
    weight matrix is shared with the embedding table.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
        self.pos = PositionalEncoding(cfg.d_model, cfg.max_seq_length, cfg.dropout)
        layers = [
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.n_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.emb.weight
        # Lower-triangular matrix: position i may attend to positions <= i.
        causal = torch.tril(torch.ones(cfg.max_seq_length, cfg.max_seq_length))
        self.register_buffer('mask', causal)

    def forward(self, x, labels=None):
        """Compute logits for ``x`` and, when ``labels`` is given, the LM loss.

        Returns a dict with ``'logits'`` and, if ``labels`` is not ``None``,
        ``'loss'``: cross-entropy of the logits at position t against the
        label at position t + 1, ignoring ``cfg.pad_token_id`` targets.
        """
        hidden = self.pos(self.emb(x))
        seq_len = hidden.size(1)
        causal = self.mask[:seq_len, :seq_len]
        for block in self.blocks:
            hidden = block(hidden, causal)
        logits = self.head(self.norm(hidden))

        if labels is None:
            return {'logits': logits}

        # Shift by one so each position predicts the following token.
        predictions = logits[:, :-1].reshape(-1, self.cfg.vocab_size)
        targets = labels[:, 1:].reshape(-1)
        loss = F.cross_entropy(predictions, targets, ignore_index=self.cfg.pad_token_id)
        return {'logits': logits, 'loss': loss}

class FonteDataset(Dataset):
    """Fixed-length token sequences loaded from a binary file.

    File layout as read by this loader (native byte order):

    * 12-byte header of three unsigned 32-bit ints: the number of sequences,
      the number of tokens stored per sequence, and a third value that is
      not used here.
    * One record per sequence: a 2-byte field that is skipped (a length
      field, going by the original comment), followed by that many unsigned
      16-bit token ids.

    Bytes after the last declared record are ignored.

    Args:
        path: Path of the binary file to load.

    Raises:
        struct.error: If the header or the record area is shorter than the
            header declares. The message names the file and the byte counts.
    """

    def __init__(self, path):
        with open(path, 'rb') as f:
            header = f.read(12)
            if len(header) != 12:
                raise struct.error(
                    f"{path}: truncated header (expected 12 bytes, got {len(header)})")
            n, ml, _ = struct.unpack('III', header)
            # One precompiled format per record: skipped length field + ml tokens.
            record = struct.Struct(f'H{ml}H')
            expected = n * record.size
            # Single bulk read instead of two small reads per record.
            body = f.read(expected)
        if len(body) != expected:
            raise struct.error(
                f"{path}: truncated record data (header declares {n} records of "
                f"{record.size} bytes = {expected} bytes, got {len(body)})")
        # Drop the leading length field of each record; keep the tokens.
        self.data = [list(fields[1:]) for fields in record.iter_unpack(body)]
        print(f"  Loaded {len(self.data)} sequences from {path}")

    def __len__(self):
        """Return the number of sequences loaded."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``{'input_ids': LongTensor}`` for sequence ``i``."""
        return {'input_ids': torch.tensor(self.data[i], dtype=torch.long)}

# Load data: read both splits into memory and wrap them in batch loaders.
print("Loading data...")
train_ds = FonteDataset('TOKENIZED/train.bin')
val_ds = FonteDataset('TOKENIZED/val.bin')

# Training batches are reshuffled every epoch and served from pinned memory.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
# Validation batches keep the file order.
val_dl = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    num_workers=0,
)
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Batches/epoch: {len(train_dl)}")

# Model, optimizer and learning-rate schedule.
model = FonteModel(ModelConfig()).to(DEVICE)
opt = AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# The scheduler is stepped once per batch, so the cosine period is the total
# number of optimisation steps in the whole run.
total_steps = len(train_dl) * EPOCHS
sched = CosineAnnealingLR(opt, T_max=total_steps)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {trainable_params:,} params on {DEVICE}")

# Training
# Each epoch: one pass over train_dl (with gradient clipping and a per-batch
# scheduler step), one pass over val_dl, then checkpointing and history dump.
os.makedirs('TRAINED', exist_ok=True)
history, best_loss = [], float('inf')

print(f"\nüöÄ Training {EPOCHS} epochs...")
for epoch in range(1, EPOCHS + 1):
    # ---- train ----
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dl, desc=f"Epoch {epoch}/{EPOCHS}"):
        # non_blocking lets the copy overlap with compute when the source
        # batch lives in pinned memory; it is a plain copy otherwise.
        x = batch['input_ids'].to(DEVICE, non_blocking=True)
        # The model shifts labels internally, so inputs double as targets.
        loss = model(x, x)['loss']
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sched.step()  # schedule advances per batch, not per epoch
        train_loss += loss.item()
    train_loss /= len(train_dl)

    # ---- validate ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_dl:
            # Move each batch to the device once and reuse it as the labels.
            x = batch['input_ids'].to(DEVICE)
            val_loss += model(x, x)['loss'].item()
    val_loss /= len(val_dl)  # mean of per-batch losses

    print(f"Epoch {epoch} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")
    history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})

    # Save checkpoint every epoch. Optimizer and scheduler state are stored
    # too so a run can be resumed from any epoch; the keys that were already
    # present ('config', 'state_dict', 'epoch', losses) are unchanged.
    config = dict(vars(model.cfg))
    torch.save({'config': config, 'state_dict': model.state_dict(), 'epoch': epoch,
                'train_loss': train_loss, 'val_loss': val_loss,
                'optimizer': opt.state_dict(), 'scheduler': sched.state_dict()},
               f'TRAINED/checkpoint_epoch_{epoch}.pt')

    # Keep a separate weights-only copy of the best model by validation loss.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({'config': config, 'state_dict': model.state_dict()}, 'TRAINED/best_model.pt')
        print("  üíæ Best model saved!")

    # Rewrite the full history each epoch so it survives an interrupted run.
    with open('TRAINED/training_history.json', 'w') as f:
        json.dump(history, f)

print(f"\n‚úÖ Done! Best val_loss: {best_loss:.4f}")

## 3️⃣ Download Models

In [None]:
# List everything the training cell wrote to TRAINED/ (per-epoch checkpoints,
# best_model.pt and training_history.json).
!ls -lh TRAINED/

# To download from Modal, either use a Modal volume or zip the directory:
# !zip -r trained_models.zip TRAINED/
# and then download trained_models.zip from the Modal UI.