# 🎨 FONTe AI - B200 Training (Modal.com)

**Run this notebook directly on Modal.com**

GPU: B200 (192GB) @ $6.25/hr | ~50 epochs in ~1.3 hours | ~$8 total

## 1️⃣ Setup

In [None]:
# Install Git LFS quietly (-qq) and run its one-time setup.
!apt-get install git-lfs -qq
!git lfs install
# Clone the project repository and make it the working directory.
!git clone https://github.com/nityam2007/fonte-ai.git
%cd fonte-ai
# Fetch the LFS-tracked files for this checkout.
!git lfs pull
# Report the GPU name and total memory in CSV form.
!nvidia-smi --query-gpu=name,memory.total --format=csv
# List the contents of TOKENIZED/ with human-readable sizes.
!ls -lh TOKENIZED/

## 2️⃣ Train

In [None]:
import modal
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.auto import tqdm
from dataclasses import dataclass
import json
import struct
import math
import time
from pathlib import Path

# Config: module-level training settings read by train().
# Number of sequences per batch for both the training and validation loaders.
BATCH_SIZE = 512
# Number of passes over the training set; also scales the LR scheduler's T_max.
EPOCHS = 50
# Learning rate handed to AdamW.
LR = 3e-4
# Device the model and every batch are moved to.
DEVICE = 'cuda'

@dataclass
class ModelConfig:
    """Hyperparameters consumed by FonteModel and stored in every checkpoint."""

    # Rows of the token embedding and width of the output logits.
    vocab_size: int = 1106  # FIXED: was 1105, now includes <NEG> token at ID 24
    # Size of the positional-encoding table and of the causal mask.
    max_seq_length: int = 512
    # Width of token embeddings and of every hidden state.
    d_model: int = 256
    # Attention heads per transformer block.
    n_heads: int = 4
    # Number of stacked transformer blocks.
    n_layers: int = 6
    # Hidden width of each block's feed-forward network.
    d_ff: int = 1024
    # Dropout probability passed to the positional encoding and to every block.
    dropout: float = 0.1
    # Token ID used as the embedding padding_idx and ignored by the loss.
    pad_token_id: int = 0

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The table is built once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``. Even feature columns
    hold sines and odd feature columns hold cosines.

    Args:
        d_model: Feature width of the inputs; may be even or odd.
        max_len: Number of positions in the precomputed table.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half must drop the last frequency to match its slice.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions, with dropout applied."""
        return self.dropout(x + self.pe[:, :x.size(1)])

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: Feature width of the input and output.
        n_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        # Fail here with a clear message; otherwise the mismatch only shows
        # up as a shape error inside view() on the first forward pass.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Input of shape ``(batch, seq, d_model)``.
            mask: Optional tensor broadcastable to the
                ``(batch, heads, seq, seq)`` score tensor. Positions where
                it equals 0 are excluded from attention.
        """
        B, L, D = x.shape
        # Split the feature axis into heads: (B, L, D) -> (B, heads, L, d_k).
        q = self.wq(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        k = self.wk(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        v = self.wv(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf scores become zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        # Merge the heads back: (B, heads, L, d_k) -> (B, L, D).
        out = (attn @ v).transpose(1, 2).reshape(B, L, -1)
        return self.wo(out)

class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: self-attention, then a feed-forward net.

    Each sub-layer normalises its input first, and its output passes through
    dropout before being added back to the residual stream.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        feed_forward_layers = (
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.ff = nn.Sequential(*feed_forward_layers)
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both residual sub-layers on ``x``; ``mask`` is passed to attention."""
        attended = self.attn(self.n1(x), mask)
        x = x + self.drop(attended)
        transformed = self.ff(self.n2(x))
        return x + self.drop(transformed)

class FonteModel(nn.Module):
    """Decoder-only transformer language model configured by ``cfg``.

    The output projection shares its weight matrix with the token embedding,
    and a lower-triangular mask restricts attention to earlier positions.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
        self.pos = PositionalEncoding(cfg.d_model, cfg.max_seq_length, cfg.dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output head reuses the embedding matrix.
        self.head.weight = self.emb.weight
        causal = torch.tril(torch.ones(cfg.max_seq_length, cfg.max_seq_length))
        self.register_buffer('mask', causal)

    def forward(self, input_ids, labels=None):
        """Return ``{"loss", "logits"}`` for a batch of token IDs.

        ``loss`` is ``None`` unless ``labels`` is given. When it is, the
        logits at each position are scored against the label one step
        later, and padding labels are ignored.
        """
        hidden = self.pos(self.emb(input_ids))
        seq_len = hidden.size(1)
        causal = self.mask[:seq_len, :seq_len]
        for block in self.blocks:
            hidden = block(hidden, causal)
        logits = self.head(self.norm(hidden))

        if labels is None:
            return {"loss": None, "logits": logits}

        # Shift by one so position t predicts token t + 1.
        predictions = logits[:, :-1].reshape(-1, self.cfg.vocab_size)
        targets = labels[:, 1:].reshape(-1)
        loss = F.cross_entropy(predictions, targets, ignore_index=self.cfg.pad_token_id)
        return {"loss": loss, "logits": logits}

    def count_params(self):
        """Return the total number of elements across trainable parameters."""
        trainable = (p for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in trainable)

class BinaryDataset(Dataset):
    """Fixed-width token sequences loaded from a binary file into memory.

    File layout (native byte order, as read by ``struct``):
    - Header: 12 bytes (num_sequences, max_len, vocab_size as uint32)
    - Per sequence: 2 bytes (length as uint16) + max_len*2 bytes (tokens as uint16)

    Args:
        bin_path: Path of the binary file to load.
        json_path: Accepted for compatibility; not used.
        max_len: Accepted for compatibility; not used. The sequence width
            always comes from the file header.

    Raises:
        struct.error: If the file is shorter than the 12-byte header.
        ValueError: If the file holds fewer bytes than the header's
            sequence count requires.
    """

    def __init__(self, bin_path, json_path=None, max_len=512):
        with open(bin_path, 'rb') as f:
            header = f.read(12)
            self.count, self.max_len, self.vocab_size = struct.unpack('III', header)
            # The whole payload is held in memory as one bytes object.
            self.data = f.read()

        # Each record is a 2-byte length field followed by max_len uint16 tokens.
        self.bytes_per_seq = 2 + self.max_len * 2

        # Reject a truncated file at load time rather than part-way through an epoch.
        expected_bytes = self.count * self.bytes_per_seq
        if len(self.data) < expected_bytes:
            raise ValueError(
                f"{bin_path}: expected {expected_bytes} data bytes for "
                f"{self.count} sequences, found {len(self.data)}"
            )

        print(f"Loaded {self.count:,} sequences, max_len={self.max_len}, vocab_size={self.vocab_size}")

    def __len__(self):
        """Return the sequence count declared in the file header."""
        return self.count

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a ``torch.long`` tensor of ``max_len`` tokens.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.count
        if not 0 <= idx < self.count:
            # IndexError (not struct.error) lets plain iteration stop cleanly.
            raise IndexError(f"sequence index out of range for {self.count} sequences")

        # Skip the record's 2-byte length field; the padded tokens are returned as stored.
        token_offset = idx * self.bytes_per_seq + 2
        tokens = struct.unpack_from(f'{self.max_len}H', self.data, token_offset)

        return torch.tensor(tokens, dtype=torch.long)

# Modal application that owns the remote train() function and the local entrypoint.
app = modal.App("fonte-ai-training")
# Named volume "fonte-data", created on first use; train() mounts it at /data
# to read train.bin / val.bin and to write checkpoints.
vol = modal.Volume.from_name("fonte-data", create_if_missing=True)

# Container image for the remote function: Debian slim, Python 3.11, with torch and tqdm installed via pip.
image = modal.Image.debian_slim(python_version="3.11").pip_install("torch", "tqdm")

@app.function(image=image, gpu="B200", timeout=7200, volumes={"/data": vol})
def train():
    """Train FonteModel on the data in /data and checkpoint it after every epoch.

    Reads ``/data/train.bin`` and ``/data/val.bin``, runs ``EPOCHS`` epochs
    with AdamW and a per-step cosine learning-rate schedule, and writes
    ``/data/epoch_<n>.pt`` after each epoch. ``/data/best_model.pt`` is
    rewritten whenever the validation loss improves. The volume is committed
    once, after the last epoch.

    Returns:
        A dict with ``final_val_loss`` (last epoch) and ``best_val_loss``.
    """
    train_set = BinaryDataset("/data/train.bin")
    val_set = BinaryDataset("/data/val.bin")
    train_batches = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_batches = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=4)

    cfg = ModelConfig()
    model = FonteModel(cfg).to(DEVICE)
    print(f"Model: {model.count_params()/1e6:.1f}M params, vocab_size={cfg.vocab_size}")

    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    steps_per_epoch = len(train_batches)
    # The scheduler is stepped once per batch, so its horizon is the total step count.
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS * steps_per_epoch)

    def save_checkpoint(path):
        # Store the config alongside the weights in each checkpoint file.
        torch.save({'config': cfg.__dict__, 'state_dict': model.state_dict()}, path)

    best_loss = float('inf')
    for epoch_number in range(1, EPOCHS + 1):
        # Training pass.
        model.train()
        running_loss = 0
        for tokens in tqdm(train_batches, desc=f"E{epoch_number}"):
            tokens = tokens.to(DEVICE)
            # The same tensor is both input and labels; the model shifts it internally.
            loss = model(tokens, tokens)['loss']
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()

        # Validation pass: mean of the per-batch losses.
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for tokens in val_batches:
                tokens = tokens.to(DEVICE)
                val_loss += model(tokens, tokens)['loss'].item()
        val_loss /= len(val_batches)

        train_loss = running_loss / steps_per_epoch
        print(f"Epoch {epoch_number}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            save_checkpoint("/data/best_model.pt")
        save_checkpoint(f"/data/epoch_{epoch_number}.pt")

    vol.commit()
    return {"final_val_loss": val_loss, "best_val_loss": best_loss}

@app.local_entrypoint()
def main():
    """Run train() remotely and print the loss summary it returns."""
    summary = train.remote()
    print(f"Training complete: {summary}")

## 3️⃣ Download Models

In [None]:
# List the checkpoint files in TRAINED/ with human-readable sizes.
# NOTE(review): train() above saves checkpoints under /data (the "fonte-data"
# Modal volume), not TRAINED/ - confirm which path applies before relying on this listing.
!ls -lh TRAINED/

# If on Modal, download the checkpoints from the Modal volume, or zip them first:
# !zip -r trained_models.zip TRAINED/
# and then download trained_models.zip from the Modal UI.