# 🎨 FONTe AI - Font Generation Training

**AI-powered unique font generation using SVG Path Transformers**

Repository: [github.com/nityam2007/fonte-ai](https://github.com/nityam2007/fonte-ai)

---

## What this notebook does:
1. ✅ Clone repository with training data (Git LFS)
2. ✅ Setup environment
3. ✅ Load pre-tokenized dataset (248K sequences)
4. ✅ **Auto-resume from checkpoint** (for 4-hour limit!)
5. ✅ Train SVG Path Transformer model
6. ✅ Generate sample fonts
7. ✅ Save model to Drive

**Requirements:** Google Colab (Free T4 GPU works!)

---

### ⚡ 4-Hour Session Strategy:
- **Batch size 128** → ~14 min/epoch (faster!)
- **Saves every epoch** → Never lose progress
- **Auto-resume** → Just run notebook again to continue
- **~15 epochs/session** → 4 sessions for 50 epochs

---
## 1️⃣ Setup & Clone Repository

In [None]:
# Check GPU: ask the driver for the name and total memory of the attached GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv

import torch
# Report the PyTorch build and whether it can see a CUDA device
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Index 0 is the first CUDA device visible to PyTorch
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Mount Google Drive (for saving checkpoints)
# The checkpoint directory used later in this notebook lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

# Create project folder in Drive
# (-p: create parent folders as needed and do not fail if it already exists)
!mkdir -p /content/drive/MyDrive/fonte_ai

In [None]:
# Install Git LFS and clone repository
# (-qq keeps apt output quiet; `git lfs install` registers the LFS hooks for git)
!apt-get install git-lfs -qq
!git lfs install

# Clone the repository (includes LFS files)
# %cd is an IPython magic: it changes the working directory for the cells that follow,
# which is why later cells can use relative paths such as TOKENIZED/train.bin.
%cd /content
!git clone https://github.com/nityam2007/fonte-ai.git
%cd fonte-ai

# Pull LFS files
# NOTE(review): presumably a safety net in case the clone left LFS pointer files
# instead of the real binaries - confirm whether it is still needed.
!git lfs pull

# Check data files
# List the tokenized dataset files with human-readable sizes
!ls -lh TOKENIZED/

---
## 2️⃣ Model Architecture

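The SVG Path Transformer treats each font glyph as a sequence of path-command tokens and is trained as a decoder-only (causally masked) transformer to predict the next token.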

In [None]:
import math
import json
import struct
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR


@dataclass
class ModelConfig:
    """Hyperparameters and special-token ids consumed by ``FonteModel``.

    The instance's ``__dict__`` is stored inside every checkpoint written by
    ``FonteModel.save`` and passed back to this constructor by
    ``FonteModel.load``, so field names must stay stable across versions.
    """

    # Number of rows in the token embedding (and of output logits)
    vocab_size: int = 1106
    # Size of the positional-encoding table and of the causal mask
    max_seq_length: int = 512
    # Width of token embeddings and of every transformer layer
    d_model: int = 256
    # Attention heads per layer; each head gets d_model // n_heads channels
    n_heads: int = 4
    # Number of stacked TransformerBlock layers
    n_layers: int = 6
    # Hidden width of the feed-forward sub-layer
    d_ff: int = 1024
    # Dropout probability shared by all sub-modules
    dropout: float = 0.1
    # Padding id: used as the embedding padding_idx and ignored by the loss
    pad_token_id: int = 0
    # First token of every prompt built by FonteModel.generate
    sos_token_id: int = 1
    # Sampling this token ends FonteModel.generate
    eos_token_id: int = 2


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    A ``(1, max_len, d_model)`` table is precomputed once and registered as the
    buffer ``pe``; even channels hold sines and odd channels hold cosines.

    Args:
        d_model: Embedding width. Odd widths are supported (the last sine
            channel then has no cosine partner).
        max_len: Longest sequence the table can serve.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd channel than even channel,
        # so only the first d_model // 2 frequencies get a cosine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading batch axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional table size {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: Input and output width.
        n_heads: Number of heads; must divide ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not a positive multiple of ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Fail here with a clear message instead of inside .view() on the
        # first forward pass.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq_len, d_model).

        Args:
            x: Input activations.
            mask: Optional tensor broadcastable to (batch, heads, seq_len, seq_len);
                positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return t.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))
        v = split_heads(self.w_v(x))
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf becomes exactly zero probability after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        # Merge heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(out)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width of the expansion layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP independently at every position of ``x``."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then feed-forward, each residual.

    Both sub-layers see a LayerNorm-ed copy of the stream, and their
    dropout-regularised output is added back onto the un-normalised input.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run one layer over ``x``; ``mask`` is forwarded to the attention module."""
        attended = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attended)
        transformed = self.ff(self.norm2(x))
        return x + self.dropout(transformed)


class FonteModel(nn.Module):
    """Decoder-only transformer language model over glyph token sequences.

    Tokens are embedded, given sinusoidal positions, passed through
    ``config.n_layers`` causally masked ``TransformerBlock`` layers and a final
    LayerNorm, then projected to vocabulary logits. The output projection
    shares its weight matrix with the token embedding.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.pos_encoding = PositionalEncoding(config.d_model, config.max_seq_length, config.dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight
        self.apply(self._init_weights)
        # Lower-triangular mask: position i may attend to positions <= i only.
        self.register_buffer('causal_mask', torch.tril(torch.ones(config.max_seq_length, config.max_seq_length)))

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, optionally, the language-model loss.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).
            labels: Optional tensor shaped like ``input_ids``. When given, the
                logits at position t are scored against ``labels[:, t + 1]``;
                targets equal to ``config.pad_token_id`` are ignored.

        Returns:
            Dict with ``'logits'`` of shape (batch, seq_len, vocab_size) and,
            when ``labels`` is given, a scalar ``'loss'``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.max_seq_length``.
        """
        seq_len = input_ids.size(1)
        if seq_len > self.config.max_seq_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_length {self.config.max_seq_length}"
            )
        x = self.token_embedding(input_ids)
        x = self.pos_encoding(x)
        mask = self.causal_mask[:seq_len, :seq_len]
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        result = {'logits': logits}
        if labels is not None:
            # Shift by one so every position predicts the token that follows it.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=self.config.pad_token_id)
            result['loss'] = loss
        return result

    @staticmethod
    def _filter_logits(logits, top_k=0, top_p=1.0):
        """Return ``logits`` with entries outside the top-k / nucleus set at -inf.

        Args:
            logits: Float tensor of shape (batch, vocab).
            top_k: Keep only the ``top_k`` largest logits per row; values <= 0
                or ``None`` disable the filter. Clamped to the vocabulary size.
            top_p: Keep the smallest set of highest-probability tokens whose
                cumulative probability reaches ``top_p``; values outside
                (0, 1) or ``None`` disable the filter.
        """
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_value, float('-inf'))
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            drop_sorted = cumulative > top_p
            # Shift right so the token that crosses the threshold is kept and
            # the most likely token can never be dropped.
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = False
            # Map the decision from sorted order back to vocabulary order.
            drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
            logits = logits.masked_fill(drop, float('-inf'))
        return logits

    @torch.no_grad()
    def generate(self, style_id, char_id, max_length=256, temperature=1.0, top_k=50, top_p=0.9):
        """Sample one token sequence for the prompt ``[SOS, style_id, char_id]``.

        Switches the module to eval mode (it is not switched back) and samples
        until EOS is drawn or the length limit is reached.

        Args:
            style_id: Token id of the style conditioning token.
            char_id: Token id of the character conditioning token.
            max_length: Maximum total length including the 3 prompt tokens;
                capped at ``config.max_seq_length``.
            temperature: Positive softmax temperature.
            top_k: Top-k cutoff (<= 0 disables it).
            top_p: Nucleus cutoff (values outside (0, 1) disable it).

        Returns:
            Integer tensor of shape (1, generated_length).

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.eval()
        device = next(self.parameters()).device
        # The positional table and causal mask only cover max_seq_length tokens.
        max_length = min(max_length, self.config.max_seq_length)
        tokens = torch.tensor([[self.config.sos_token_id, style_id, char_id]], device=device)
        for _ in range(max_length - 3):
            outputs = self.forward(tokens)
            logits = outputs['logits'][:, -1, :] / temperature
            logits = self._filter_logits(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)
            if next_token.item() == self.config.eos_token_id:
                break
        return tokens

    def count_parameters(self):
        """Return the number of trainable parameters (the tied matrix counts once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save(self, path):
        """Write the config dict and ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save({'config': self.config.__dict__, 'state_dict': self.state_dict()}, path)

    @classmethod
    def load(cls, path, device='cpu'):
        """Rebuild a model from a file written by :meth:`save`.

        Args:
            path: Checkpoint file.
            device: Device the checkpoint tensors are mapped to and on which
                the returned model is placed.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(path, map_location=device)
        config = ModelConfig(**checkpoint['config'])
        model = cls(config)
        model.load_state_dict(checkpoint['state_dict'])
        return model.to(device)


# Status line so the notebook user can see that the definition cell finished
print("✅ Model classes defined!")

---
## 3️⃣ Dataset Loading

In [None]:
class FonteDataset(Dataset):
    """Fixed-width token sequences read from a ``.bin`` file.

    File layout (native byte order): a 12-byte header of three unsigned 32-bit
    ints ``(num_sequences, stored_length, vocab_size)``, followed by
    ``num_sequences`` records. Each record is one unsigned 16-bit length and
    then ``stored_length`` unsigned 16-bit token ids.

    The records are kept in one compact uint16 array rather than nested Python
    lists, which is far smaller in memory for a dataset of this size.

    Args:
        data_path: Path of the ``.bin`` file.
        max_length: Stored on ``self.max_length`` only; sequences keep the
            width recorded in the file header.

    Raises:
        ValueError: If the file is shorter than its header announces.
    """

    def __init__(self, data_path, max_length=512):
        # Local import: numpy ships with PyTorch environments but is not
        # imported at the top of this notebook.
        import numpy as np

        self.max_length = max_length
        with open(data_path, 'rb') as f:
            header = f.read(12)
            if len(header) != 12:
                raise ValueError(f"{data_path}: missing 12-byte header")
            num_sequences, stored_length, _vocab_size = struct.unpack('III', header)
            payload = f.read()

        # One record = 1 length field + stored_length tokens, all uint16.
        record_width = stored_length + 1
        expected_items = num_sequences * record_width
        if len(payload) < expected_items * 2:
            raise ValueError(
                f"{data_path}: expected {expected_items * 2} bytes of records, "
                f"found {len(payload)}"
            )
        records = np.frombuffer(payload, dtype=np.uint16, count=expected_items)
        records = records.reshape(num_sequences, record_width)
        self.lengths = records[:, 0].tolist()
        self.token_ids = records[:, 1:]
        print(f"Loaded {len(self.token_ids)} sequences")

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        """Return ``{'input_ids': LongTensor[stored_length], 'length': int}``."""
        # Widen to int64 first: torch cannot ingest uint16 arrays on every
        # version, and the model indexes embeddings with long tensors.
        row = self.token_ids[idx].astype('int64')
        return {
            'input_ids': torch.from_numpy(row),
            'length': self.lengths[idx],
        }


# Load datasets
# Paths are relative to the repository root, which the clone cell made the
# working directory via `%cd fonte-ai`.
train_dataset = FonteDataset('TOKENIZED/train.bin')
val_dataset = FonteDataset('TOKENIZED/val.bin')

# Quick sanity check of the split sizes
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

---
## 4️⃣ Training Configuration

In [None]:
# ═══════════════════════════════════════════════════════════════════════════
# TRAINING CONFIG - Optimized for 4-hour Colab sessions
# ═══════════════════════════════════════════════════════════════════════════

BATCH_SIZE = 128         # ⚡ Increased for faster epochs (~14 min instead of 28)
EPOCHS = 50              # Total epochs (will resume from checkpoint)
LEARNING_RATE = 3e-4
SAVE_EVERY = 1           # Save EVERY epoch for session continuity
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model config (roughly 5M trainable parameters; the exact count is printed below)
config = ModelConfig(
    vocab_size=1106,
    max_seq_length=512,
    d_model=256,
    n_heads=4,
    n_layers=6,
    d_ff=1024,
    dropout=0.1,
)

# ═══════════════════════════════════════════════════════════════════════════
# RESUME FROM CHECKPOINT (if exists)
# ═══════════════════════════════════════════════════════════════════════════
import glob
import os

checkpoint_dir = '/content/drive/MyDrive/fonte_ai'
os.makedirs(checkpoint_dir, exist_ok=True)


def _checkpoint_epoch(path):
    """Return the epoch encoded in ``checkpoint_epoch_<N>.pt``, or None if N is not a number."""
    stem = os.path.basename(path)[len('checkpoint_epoch_'):-len('.pt')]
    return int(stem) if stem.isdigit() else None


# Ignore files that match the glob but do not carry a numeric epoch suffix.
checkpoints = [
    path for path in glob.glob(f'{checkpoint_dir}/checkpoint_epoch_*.pt')
    if _checkpoint_epoch(path) is not None
]
START_EPOCH = 1

if checkpoints:
    # Find latest checkpoint
    latest = max(checkpoints, key=_checkpoint_epoch)
    START_EPOCH = _checkpoint_epoch(latest) + 1

    print(f"📂 Found checkpoint: {latest}")
    # Restore the trained weights in both cases, so that `model` never refers
    # to a randomly initialised network once any checkpoint exists.
    model = FonteModel.load(latest, device=DEVICE).to(DEVICE)
    if START_EPOCH <= EPOCHS:
        print(f"✅ Resuming from epoch {START_EPOCH}")
    else:
        print(f"✅ Training already complete! (epoch {START_EPOCH-1})")
else:
    print("🆕 Starting fresh training")
    model = FonteModel(config).to(DEVICE)

print(f"\n📊 Model parameters: {model.count_parameters():,}")
print(f"🖥️ Device: {DEVICE}")
print(f"📦 Batch size: {BATCH_SIZE}")
print(f"🎯 Epochs: {START_EPOCH} → {EPOCHS}")

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Optimizer & Scheduler (recreate for resumed training)
# NOTE: checkpoints hold model weights only, so AdamW's moment estimates start
# from zero again in every resumed session.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
total_steps = len(train_loader) * EPOCHS
# Clamp so a lowered EPOCHS cannot push the schedule past its end.
completed_steps = min(len(train_loader) * (START_EPOCH - 1), total_steps)
# PyTorch schedulers raise KeyError when built with last_epoch != -1 unless
# every param group already carries 'initial_lr', so provide it before resuming.
for group in optimizer.param_groups:
    group.setdefault('initial_lr', LEARNING_RATE)
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, last_epoch=completed_steps-1 if completed_steps > 0 else -1)

# Estimate time
est_epoch_time = len(train_loader) / 2.5 / 60  # ~2.5 it/s with batch 128
remaining_epochs = max(0, EPOCHS - START_EPOCH + 1)
print(f"\n⏱️ Est. {est_epoch_time:.0f} min/epoch × {remaining_epochs} epochs = {est_epoch_time * remaining_epochs:.0f} min total")

---
## 5️⃣ Training Loop

In [None]:
import time
from tqdm.auto import tqdm

def train_epoch(model, dataloader, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch does a forward pass with the inputs as their own labels, then
    backward, gradient-norm clipping at 1.0, an optimizer step and a scheduler
    step (the scheduler therefore advances once per batch, not per epoch).
    """
    model.train()
    running_loss = 0
    progress = tqdm(dataloader, desc="Training")
    for batch in progress:
        token_batch = batch['input_ids'].to(device)
        # Clear stale gradients before this batch's backward pass.
        optimizer.zero_grad()
        loss = model(token_batch, labels=token_batch)['loss']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        running_loss += loss.item()
        progress.set_postfix({'loss': f"{loss.item():.4f}"})
    return running_loss / len(dataloader)

@torch.no_grad()
def validate(model, dataloader, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Puts the model in eval mode and runs without gradient tracking; each batch
    is scored with its inputs as labels.
    """
    model.eval()
    loss_sum = 0
    for batch in dataloader:
        token_batch = batch['input_ids'].to(device)
        batch_result = model(token_batch, labels=token_batch)
        loss_sum += batch_result['loss'].item()
    return loss_sum / len(dataloader)

# ═══════════════════════════════════════════════════════════════════════════
# TRAINING LOOP - With checkpoint resuming
# ═══════════════════════════════════════════════════════════════════════════

# Load history if resuming
history_path = f'{checkpoint_dir}/training_history.json'
history = []
if os.path.exists(history_path) and START_EPOCH > 1:
    try:
        with open(history_path, 'r') as f:
            history = json.load(f)
    except json.JSONDecodeError:
        # A session killed mid-write can leave a truncated file; losing the
        # loss curve is better than blocking the resume.
        print(f"⚠️ Could not parse {history_path}; starting a new history")
        history = []
    # Drop entries for epochs that are about to be retrained so the history
    # never contains the same epoch twice.
    history = [h for h in history if h['epoch'] < START_EPOCH]

if history:
    best_val_loss = min(h['val_loss'] for h in history)
    print(f"📈 Loaded history: {len(history)} epochs, best val_loss: {best_val_loss:.4f}")
else:
    best_val_loss = float('inf')

print(f"\n🚀 Starting training from epoch {START_EPOCH}...")
start_time = time.time()

# Defined up front so the summary below works even when no epoch is left to run.
epoch = START_EPOCH - 1

for epoch in range(START_EPOCH, EPOCHS + 1):
    epoch_start = time.time()
    
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, DEVICE)
    val_loss = validate(model, val_loader, DEVICE)
    
    epoch_time = time.time() - epoch_start
    history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})
    
    print(f"Epoch {epoch}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | Time: {epoch_time:.1f}s")
    
    # Save best model (locally for this session and to Drive for later ones)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model.save('best_model.pt')
        model.save(f'{checkpoint_dir}/best_model.pt')
        print(f"  💾 New best model! (val_loss: {val_loss:.4f})")
    
    # Save checkpoint every SAVE_EVERY epochs (1 = every epoch, for the
    # 4-hour session limit) and always after the final epoch
    if epoch % SAVE_EVERY == 0 or epoch == EPOCHS:
        model.save(f'{checkpoint_dir}/checkpoint_epoch_{epoch}.pt')
    
    # Save history after each epoch
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    
    # Time estimate: average seconds per epoch this session × epochs left
    elapsed = time.time() - start_time
    epochs_done = epoch - START_EPOCH + 1
    epochs_left = EPOCHS - epoch
    eta = (elapsed / epochs_done) * epochs_left
    print(f"  ⏱️ ETA: {eta/60:.0f} min remaining")

total_time = time.time() - start_time
if epoch >= START_EPOCH:
    print(f"\n✅ Session complete! Trained epochs {START_EPOCH}-{epoch} in {total_time/60:.1f} minutes")
else:
    print(f"\n✅ Nothing to train: all {EPOCHS} epochs were already completed")
print(f"🏆 Best validation loss: {best_val_loss:.4f}")
print(f"💾 Checkpoints saved to: {checkpoint_dir}")

# Show how to resume
if epoch < EPOCHS:
    print(f"\n📌 To continue: Just run this notebook again from the start!")
    print(f"   It will auto-resume from epoch {epoch + 1}")

---
## 6️⃣ Test Generation

In [None]:
# Load best model
# Prefer the copy written in this session; a resumed session that found no new
# best has only the Drive copy, so fall back to that.
best_model_path = Path('best_model.pt')
if not best_model_path.exists():
    best_model_path = Path('/content/drive/MyDrive/fonte_ai/best_model.pt')
model = FonteModel.load(str(best_model_path), device=DEVICE)
model = model.to(DEVICE)

# Style and character token mappings
# NOTE(review): these ids are hard-coded here and must match the vocabulary
# used to build TOKENIZED/*.bin - verify against the tokenizer.
STYLE_IDS = {
    'serif': 28,
    'sans-serif': 29,
    'monospace': 30,
    'handwriting': 31,
    'display': 32,
}

# Character tokens are assumed to be numbered consecutively from 33 in this order.
CHAR_IDS = {char: 33 + i for i, char in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%&*()-+=[]")}

# Generate a sample
style = 'serif'
char = 'A'

tokens = model.generate(
    style_id=STYLE_IDS[style],
    char_id=CHAR_IDS[char],
    max_length=256,
    temperature=0.8,
    top_k=50,
)

# Show only the first 30 token ids to keep the output short
print(f"Generated {style} '{char}':")
print(f"Tokens: {tokens[0].tolist()[:30]}...")
print(f"Total tokens: {tokens.shape[1]}")

---
## 7️⃣ Save Training History & Export

In [None]:
# Save to Drive
!mkdir -p /content/drive/MyDrive/fonte_ai

model.save('/content/drive/MyDrive/fonte_ai/final_model.pt')
print("✅ Model saved to Google Drive!")

# Save training history
with open('/content/drive/MyDrive/fonte_ai/training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

# Download to local
from google.colab import files
files.download('best_model.pt')

## 🎉 Done!

Your model is trained! Next steps:
1. Download `best_model.pt`
2. Use the generation script to create fonts
3. Convert to TTF using svgtofont