# Handwriting Synthesis Model Fine-Tuning on Google Colab

This notebook fine-tunes a pre-trained handwriting synthesis model using your personal handwriting data from HuggingFace.

## Features:
- Loads pre-trained model from existing checkpoints
- Fine-tunes on your HuggingFace dataset (finnbusse/v3testing format)
- tqdm progress bars for training visualization
- Sample generation every `sample_interval` epochs (default: 5), rendered with PIL and a matplotlib fallback
- ONNX model export
- Automatic upload to HuggingFace Hub (finnbusse/handwriting-synthesis-models)
- CUDA/T4 GPU optimized

**Note:** Select GPU runtime: Runtime > Change runtime type > GPU (T4)

## 1. Setup and Installation

In [None]:
# Clone the repository and make it the working directory.
# Idempotent: re-running the cell neither re-clones nor descends into a nested
# directory of the same name.
import os
import subprocess

REPO_URL = "https://github.com/finnbusse/pytorch-handwriting-synthesis-toolkit.git"
REPO_DIR = "pytorch-handwriting-synthesis-toolkit"

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Install dependencies (compatible with Google Colab Python 3.11+ and PyTorch 2.x)
!pip install -q Pillow h5py svgwrite datasets tqdm huggingface_hub onnx onnxruntime

In [None]:
import os
import sys
import json
import math
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from IPython.display import display, Image, clear_output, HTML
from datasets import load_dataset
from datetime import datetime
import h5py
import io
import base64
from PIL import Image as PILImage

# Add the repository to the path
sys.path.insert(0, os.getcwd())

from handwriting_synthesis import data, models, utils, losses, metrics
from handwriting_synthesis.data import Tokenizer
from handwriting_synthesis.sampling import HandwritingSynthesizer
from handwriting_synthesis.tasks import HandwritingSynthesisTask
from handwriting_synthesis.optimizers import CustomRMSprop

# Report the runtime so it is obvious whether the GPU runtime was selected
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1024**3:.1f} GB")

## 2. Configuration

In [None]:
# =============================================================================
# CONFIGURATION - Fine-tuning settings
# =============================================================================

# HuggingFace Settings
HUGGINGFACE_DATASET = "finnbusse/v3testing"  # Your handwriting dataset
HUGGINGFACE_MODEL_REPO = "finnbusse/handwriting-synthesis-models"  # Where to upload models


def _resolve_hf_token():
    """Return a HuggingFace token without hardcoding it in the notebook.

    Lookup order: the ``HF_TOKEN`` environment variable, then the Colab secret
    named ``HF_TOKEN`` (key icon in the Colab sidebar). Returns "" when neither
    is available, which the rest of the notebook treats as "no token".
    """
    token = os.environ.get("HF_TOKEN", "")
    if token:
        return token
    try:
        from google.colab import userdata
        return userdata.get("HF_TOKEN") or ""
    except Exception:
        # Not on Colab, secret not defined, or notebook access not granted.
        return ""


HF_TOKEN = _resolve_hf_token()  # Needed for private repos and uploads; never paste it here

# Pre-trained model source (from this repo's checkpoints or HuggingFace)
USE_PRETRAINED = True  # Set False to train from scratch
PRETRAINED_CHECKPOINT = "checkpoints/Epoch_50"  # Local checkpoint path

# Fine-tuning hyperparameters (optimized for small personal datasets)
config = {
    'batch_size': 8,              # Smaller batch for fine-tuning
    'epochs': 100,                # Fine-tuning epochs
    'learning_rate': 0.00005,     # Lower LR for fine-tuning (half of training LR)
    'hidden_size': 400,           # Must match pre-trained model
    'gradient_clip_output': 100,
    'gradient_clip_lstm': 10,
    'validation_split': 0.15,
    'max_seq_length': 500,
    'sample_interval': 5,         # Show samples every N epochs
}

# Seed every RNG used during training so runs are reproducible.
# (The train/val split cell additionally re-seeds `random` with a fixed value.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Directories
MODEL_SAVE_DIR = 'finetuned_checkpoints'
SAMPLES_DIR = 'finetuning_samples'
ONNX_DIR = 'onnx_export'

os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(ONNX_DIR, exist_ok=True)

# Generate unique model ID for this training run
TRAINING_TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
MODEL_ID = f"finetuned_{TRAINING_TIMESTAMP}"

print(f"Model ID: {MODEL_ID}")
print(f"HF token provided: {bool(HF_TOKEN)}")  # never print the token itself
print(f"Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

## 3. Load Dataset from HuggingFace

In [None]:
# Load the dataset from HuggingFace (an empty HF_TOKEN means anonymous access)
print(f"Loading dataset: {HUGGINGFACE_DATASET}")
try:
    dataset = load_dataset(HUGGINGFACE_DATASET, token=HF_TOKEN or None)
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise

# Prefer the 'train' split; otherwise fall back to the first split listed
split_name = 'train' if 'train' in dataset else next(iter(dataset))
raw_data = dataset[split_name]

print(f"\nTotal samples: {len(raw_data)}")
print(f"\nSample columns: {raw_data.column_names}")
if len(raw_data):
    first_sample = raw_data[0]
    print(f"First sample text: '{first_sample.get('text', 'N/A')}'")
    print(f"Stroke points: {len(first_sample.get('dx', []))}")

## 4. Data Preprocessing

In [None]:
def prepare_training_data(raw_data, max_seq_length=500):
    """
    Convert HuggingFace rows into ``(offsets, text)`` training pairs.

    Each row carries ``text`` plus parallel ``dx`` / ``dy`` / ``eos`` lists that
    are already in offset form. A row is dropped when any of those fields is
    empty, or when its usable length (the shortest of the three lists) is 0 or
    exceeds ``max_seq_length``. Longer lists are truncated to the shortest one,
    and the last point is always given ``eos=1``.

    Returns:
        list of ``([(dx: float, dy: float, eos: int), ...], text)`` tuples.
    """
    kept = []
    dropped = 0

    for row in tqdm(raw_data, desc="Converting data"):
        text = row.get('text', '')
        dx, dy, eos = (row.get(key, []) for key in ('dx', 'dy', 'eos'))

        if not (text and dx and dy and eos):
            dropped += 1
            continue

        length = min(map(len, (dx, dy, eos)))
        if not 0 < length <= max_seq_length:
            dropped += 1
            continue

        # zip() stops at the shortest list, i.e. after `length` points
        points = [(float(x), float(y), int(e)) for x, y, e in zip(dx, dy, eos)]

        # The sequence must end a stroke, so force eos on the final point
        x_last, y_last, _ = points[-1]
        points[-1] = (x_last, y_last, 1)

        kept.append((points, text))

    if dropped:
        print(f"Skipped {dropped} samples")

    return kept

# Convert every row of the raw split into (offsets, text) pairs
print("Preparing training data...")
all_data = prepare_training_data(raw_data, max_seq_length=config['max_seq_length'])
print(f"Prepared {len(all_data)} valid samples")

In [None]:
# Deterministic train/validation split (fixed seed: re-runs give the same split)
random.seed(42)
shuffled_data = list(all_data)
random.shuffle(shuffled_data)

val_size = int(len(shuffled_data) * config['validation_split'])
split_at = len(shuffled_data) - val_size
train_data = shuffled_data[:split_at]
val_data = shuffled_data[split_at:]

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

## 5. Build Charset and Create H5 Datasets

In [None]:
# Build charset from training data
def build_charset(data_list):
    """Return the sorted, de-duplicated characters of all transcriptions.

    ``data_list`` holds ``(points, text)`` pairs; only ``text`` is read.
    """
    return ''.join(sorted({ch for _, text in data_list for ch in text}))

# Build the vocabulary from ALL samples (train + val). Otherwise a character that
# occurs only in a validation transcription would be unknown to the tokenizer
# when validation batches and preview samples are encoded.
charset = build_charset(all_data)
print(f"Charset ({len(charset)} characters): '{charset}'")

# Add newline as sentinel if not present (generate_sample appends it to prompts)
if '\n' not in charset:
    charset = charset + '\n'

tokenizer = Tokenizer(charset)
print(f"Tokenizer size: {tokenizer.size}")

In [None]:
# Save data to H5 format for compatibility with the training pipeline
def _offset_stats(arrays):
    """Per-channel mean/std over the unpadded points of ``arrays``.

    ``arrays`` is a list of ``(length, 3)`` arrays of ``(dx, dy, eos)``.
    Statistics are computed in float64 with a Python-int point count, so large
    datasets neither overflow the count nor lose precision summing float16
    values. The eos channel is left unnormalized (``mu[2] = 0``,
    ``std[2] = 1``), and a zero std is replaced by 1 so normalizing by it
    cannot divide by zero.
    """
    total_points = sum(len(a) for a in arrays)
    if total_points == 0:
        return np.zeros(3), np.ones(3)
    points = np.concatenate(arrays).astype(np.float64)
    mu = points.mean(axis=0)
    mu[2] = 0.
    std = np.sqrt(((points - mu) ** 2).mean(axis=0))
    std[2] = 1.
    std[std == 0] = 1.
    return mu, std


def save_data_to_h5(data_list, save_path, max_length):
    """Save prepared data to H5 format and return its normalization statistics.

    File layout:
      - ``sequences``: (N, max_length, 3) zero-padded offset sequences
      - ``lengths``:   (N,) unpadded lengths (int16)
      - ``texts``:     (N,) UTF-8 transcriptions
    ``sequences.attrs`` receives ``max_length``, ``mu`` and ``std``.

    Args:
        data_list: list of ``(points, text)`` where ``points`` is a sequence of
            ``(dx, dy, eos)`` triples with at most ``max_length`` entries.
        save_path: destination file (overwritten).
        max_length: padded sequence length.

    Returns:
        ``(mu, std)`` as tuples of three Python floats (JSON-serializable).

    Raises:
        ValueError: if a sequence is longer than ``max_length``.
    """
    str_dtype = h5py.string_dtype(encoding='utf-8')
    num_examples = len(data_list)
    # float16-rounded arrays, i.e. exactly the values that land on disk, so the
    # statistics describe the stored data.
    stored = []

    with h5py.File(save_path, 'w') as f:
        ds_sequences = f.create_dataset('sequences', (num_examples, max_length, 3),
                                        maxshape=(None, max_length, 3))
        ds_lengths = f.create_dataset('lengths', (num_examples,), maxshape=(None,), dtype='i2')
        ds_texts = f.create_dataset('texts', (num_examples,), maxshape=(None,), dtype=str_dtype)
        ds_sequences.attrs['max_length'] = max_length

        for i, (points, text) in enumerate(tqdm(data_list, desc="Saving to H5")):
            a = np.asarray(points, dtype=np.float16).reshape(-1, 3)
            if len(a) > max_length:
                raise ValueError(
                    f"Sequence {i} has {len(a)} points, more than max_length={max_length}")
            stored.append(a)
            ds_sequences[i] = np.pad(a, pad_width=[(0, max_length - len(a)), (0, 0)])
            ds_lengths[i] = len(a)
            ds_texts[i] = text

        mu, std = _offset_stats(stored)
        ds_sequences.attrs['mu'] = mu.astype(np.float32)
        ds_sequences.attrs['std'] = std.astype(np.float32)

    # Plain Python floats: numpy float32 scalars are rejected by json.dump,
    # and mu/std are written to JSON later in the notebook.
    return tuple(float(v) for v in mu), tuple(float(v) for v in std)

# Longest prepared sequence, capped at the configured maximum
longest = max((len(pts) for pts, _ in all_data), default=config['max_seq_length'])
dataset_max_length = min(longest, config['max_seq_length'])
print(f"Max sequence length: {dataset_max_length}")

# Save train and val data; normalization stats come from the training file only
train_path = 'finetune_train.h5'
val_path = 'finetune_val.h5'

mu, std = save_data_to_h5(train_data, train_path, dataset_max_length)
# Without a validation split, reuse the first ~10% (at least one) training samples
val_source = val_data if val_data else train_data[:max(1, len(train_data) // 10)]
save_data_to_h5(val_source, val_path, dataset_max_length)

print("\nData statistics:")
print(f"  mu: {mu}")
print(f"  std: {std}")

## 6. Load Pre-trained Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

use_pretrained = USE_PRETRAINED and os.path.exists(PRETRAINED_CHECKPOINT)
if USE_PRETRAINED and not use_pretrained:
    # Make the fallback to random initialization visible instead of silent.
    print(f"\nWARNING: pre-trained checkpoint '{PRETRAINED_CHECKPOINT}' not found; "
          "training from scratch instead.")

# Load or create model
if use_pretrained:
    print(f"\nLoading pre-trained model from: {PRETRAINED_CHECKPOINT}")
    synthesizer = HandwritingSynthesizer.load(PRETRAINED_CHECKPOINT, device, bias=0)
    model = synthesizer.model
    pretrained_charset = synthesizer.tokenizer.charset

    # Check charset compatibility
    print(f"Pre-trained charset: '{pretrained_charset}'")
    print(f"Dataset charset: '{charset}'")

    # Union of both charsets. Keep the pre-trained order and append any new
    # characters at the end. Re-sorting the union would shift the position
    # (and presumably the token index) of characters the model already knows.
    # It would also force a rebuild when the sets are equal but the
    # pre-trained charset is not stored in sorted order.
    new_chars = ''.join(sorted(set(charset) - set(pretrained_charset)))
    combined_charset = pretrained_charset + new_chars
    if '\n' not in combined_charset:
        combined_charset += '\n'
    tokenizer = Tokenizer(combined_charset)

    # If charsets differ, we need to create a new model with the combined charset
    if combined_charset != pretrained_charset:
        print(f"\nExpanding charset to: '{combined_charset}'")
        # Size the alphabet from the tokenizer, exactly like the from-scratch branch
        new_model = models.SynthesisNetwork.get_default_model(tokenizer.size, device).to(device)

        # Transfer every tensor whose shape is unchanged. Tensors whose shape
        # depends on the alphabet size keep their fresh initialization.
        # NOTE(review): those layers lose their pre-trained weights entirely; a
        # partial row/column copy may be possible if the alphabet axis layout is
        # known, which this notebook cannot verify.
        old_state = model.state_dict()
        new_state = new_model.state_dict()
        for name, param in old_state.items():
            if name not in new_state:
                continue
            if param.shape == new_state[name].shape:
                new_state[name] = param
            else:
                print(f"  Skipping {name} due to shape mismatch: {param.shape} vs {new_state[name].shape}")

        new_model.load_state_dict(new_state)
        model = new_model

    print(f"\nModel loaded successfully!")
    print(f"Final alphabet size: {tokenizer.size}")
else:
    print("\nCreating new model (no pre-trained weights)")
    alphabet_size = tokenizer.size
    model = models.SynthesisNetwork.get_default_model(alphabet_size, device)
    model = model.to(device)

# Wrap the (possibly rebuilt) model with the fine-tuning data's normalization
# stats; this object is what the training loop saves as checkpoints.
mu_tensor = torch.tensor(mu, dtype=torch.float32)
std_tensor = torch.tensor(std, dtype=torch.float32)
synthesizer = HandwritingSynthesizer(model, mu_tensor, std_tensor, tokenizer.charset, num_steps=dataset_max_length)

## 7. Training Utilities

In [None]:
def collate_fn(batch):
    """Transpose ``[(points, text), ...]`` into ``(points_list, texts)``.

    Sequences differ in length, so they stay plain Python lists here; padding
    is presumably done later by ``utils.PaddedSequencesBatch`` in ``compute_loss``.
    """
    points_list, texts = [], []
    for sample in batch:
        points_list.append(sample[0])
        texts.append(sample[1])
    return points_list, texts


def compute_loss(model, batch, tokenizer, device):
    """Teacher-forced NLL loss for one ``(points, transcriptions)`` batch.

    The network input is the ground-truth sequence shifted right by one step:
    a zero vector is prepended along dim 1 and the last step is dropped, so
    step ``t`` is predicted from the steps before it.

    Returns:
        ``((mixtures, eos_hat), loss)``.
    """
    points, transcriptions = batch
    targets = utils.PaddedSequencesBatch(points, device=device)
    target_tensor = targets.tensor

    n_seqs, _, n_features = target_tensor.shape
    start_step = torch.zeros(n_seqs, 1, n_features, device=device)
    shifted_input = torch.cat([start_step, target_tensor[:, :-1]], dim=1)

    context = data.transcriptions_to_tensor(tokenizer, transcriptions).to(device)

    mixtures, eos_hat = model(shifted_input, context)
    nll = losses.nll_loss(mixtures, eos_hat, targets)
    return (mixtures, eos_hat), nll


def generate_sample(model, tokenizer, mu, std, text, device, max_steps=500):
    """Generate a handwriting sample for ``text`` and de-normalize it.

    Args:
        model: synthesis network exposing ``sample_means``.
        tokenizer: tokenizer used to encode the prompt.
        mu, std: per-channel normalization statistics (3 values each).
        text: prompt; a ``'\\n'`` sentinel is appended before encoding.
        device: device the model lives on.
        max_steps: number of sampling steps.

    Returns:
        CPU tensor equal to ``sample * std + mu``.

    The model's previous train/eval mode is restored afterwards (also on
    error), so sampling from a model in eval mode leaves it in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prompt = text + '\n'  # newline sentinel, as added to the charset
            c = data.transcriptions_to_tensor(tokenizer, [prompt]).to(device)
            sample = model.sample_means(context=c, steps=max_steps, stochastic=True)
            sample = sample.cpu() * torch.tensor(std) + torch.tensor(mu)
    finally:
        model.train(was_training)
    return sample

In [None]:
def visualize_sample_multiple(sample, text, epoch, sample_idx, save_dir):
    """
    Render a generated sample to PNG and display it inline.

    Rendering first tries ``utils.visualize_strokes`` (PIL). If that raises or
    writes no file, a matplotlib plot is drawn instead. Some samples are
    reported with a one-line message and not drawn:
      - empty samples;
      - samples containing NaN/inf;
      - samples whose dx or dy values span more than 5000 units.

    Args:
        sample: generated ``(dx, dy, eos)`` rows as a tensor or array-like.
        text: transcription the sample was generated for (used in messages).
        epoch: label used in the output filename (an int, or e.g. 'final').
        sample_idx: index used in the output filename.
        save_dir: directory for the PNG files (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    if sample is None or len(sample) == 0:
        print(f"  '{text}' - [Empty sample]")
        return

    sample_np = sample.cpu().numpy() if hasattr(sample, 'cpu') else np.array(sample)

    # A diverging model can emit NaN/inf. NaN makes every comparison False, so
    # such samples would slip past the range check below.
    if not np.isfinite(sample_np).all():
        print(f"  '{text}' - [Non-finite coordinates, model still learning]")
        return

    # Heuristic divergence check on the raw offset columns
    x_range = sample_np[:, 0].max() - sample_np[:, 0].min()
    y_range = sample_np[:, 1].max() - sample_np[:, 1].min()
    if abs(x_range) > 5000 or abs(y_range) > 5000:
        print(f"  '{text}' - [Wild coordinates, model still learning]")
        return

    # Method 1: PIL rendering via the repo helper (primary)
    try:
        png_path = os.path.join(save_dir, f'epoch_{epoch}_sample_{sample_idx}.png')
        utils.visualize_strokes(sample, png_path, lines=True, thickness=8)
        if os.path.exists(png_path) and os.path.getsize(png_path) > 0:
            display(Image(filename=png_path))
            print(f"  '{text}'")
            return
    except Exception as e:
        # Not fatal: say why, then fall through to matplotlib.
        print(f"  [PIL rendering failed: {str(e)[:50]}; using matplotlib]")

    # Method 2: Matplotlib fallback
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(12, 4))
        ax.set_facecolor('white')

        # Offsets -> absolute pen positions
        xs = np.cumsum(sample_np[:, 0])
        ys = np.cumsum(sample_np[:, 1])

        # A point with eos > 0.5 closes its stroke (and is drawn as part of it)
        start = 0
        for end in np.flatnonzero(sample_np[:, 2] > 0.5):
            ax.plot(xs[start:end + 1], ys[start:end + 1], 'k-', linewidth=2)
            start = end + 1
        if start < len(sample_np):
            ax.plot(xs[start:], ys[start:], 'k-', linewidth=2)

        ax.invert_yaxis()
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_title(f"'{text}'")

        plt_path = os.path.join(save_dir, f'epoch_{epoch}_sample_{sample_idx}_mpl.png')
        fig.savefig(plt_path, dpi=100, bbox_inches='tight', facecolor='white')
        plt.show()
        print(f"  '{text}' (matplotlib)")
    except Exception as e:
        print(f"  '{text}' - [Visualization failed: {str(e)[:50]}]")
    finally:
        # Close even on failure so repeated calls do not accumulate figures
        if fig is not None:
            plt.close(fig)

## 8. Fine-Tuning Training Loop

In [None]:
# RMSprop variant from the repo, configured with the lower fine-tuning LR
optimizer_kwargs = dict(
    lr=config['learning_rate'],
    alpha=0.95,
    eps=1e-4,
    momentum=0.9,
    centered=True,
)
optimizer = CustomRMSprop(model.parameters(), **optimizer_kwargs)

clip_output, clip_lstm = config['gradient_clip_output'], config['gradient_clip_lstm']

print(f"Optimizer: CustomRMSprop, LR: {optimizer_kwargs['lr']}")
print(f"Gradient clipping: output={clip_output}, lstm={clip_lstm}")

In [None]:
# Training loop
# Fine-tunes `model` in place. Each epoch runs:
#   - a training pass with gradient clipping;
#   - a no-grad validation pass;
#   - a save of the best-so-far model (by validation loss) to `best_model_dir`;
#   - preview samples every SAMPLE_INTERVAL epochs;
#   - a checkpoint every 10 epochs.
# `train_losses` / `val_losses` collect per-epoch mean losses for the plotting
# cell below.
print("\n" + "="*60)
print("Starting Fine-Tuning")
print("="*60 + "\n")

EPOCHS = config['epochs']
BATCH_SIZE = config['batch_size']
SAMPLE_INTERVAL = config['sample_interval']

train_losses = []
val_losses = []
best_val_loss = float('inf')
best_model_dir = os.path.join(MODEL_SAVE_DIR, 'best_model')

# Open datasets, both normalized with the training-set mu/std. They are closed
# in the `finally` below. NOTE(review): anything raising before the `try`
# (e.g. DataLoader construction) leaves these H5 handles open.
train_dataset = data.NormalizedDataset(train_path, mu, std)
val_dataset = data.NormalizedDataset(val_path, mu, std)

# Sample texts for visualization: the first (up to) 3 validation transcriptions
sample_texts = []
for i in range(min(3, len(val_dataset))):
    _, text = val_dataset[i]
    sample_texts.append(text)
print(f"Sample texts: {sample_texts}")

# Create data loaders; collate_fn keeps variable-length sequences as lists
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

num_train_batches = len(train_loader)
num_val_batches = len(val_loader)

print(f"Training batches: {num_train_batches}")
print(f"Validation batches: {num_val_batches}")

try:
    epoch_pbar = tqdm(range(1, EPOCHS + 1), desc="Fine-tuning", unit="epoch")
    
    for epoch in epoch_pbar:
        # Training phase
        model.train()
        epoch_train_loss = 0.0
        
        batch_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False, unit="batch")
        for batch_idx, batch in enumerate(batch_pbar):
            optimizer.zero_grad()
            # y_hat (mixtures, eos_hat) is unused here; only the loss matters
            y_hat, loss = compute_loss(model, batch, tokenizer, device)
            loss.backward()
            # Clip before stepping; per the parameter names, output-layer and
            # LSTM gradients get separate limits
            model.clip_gradients(output_clip_value=clip_output, lstm_clip_value=clip_lstm)
            optimizer.step()
            epoch_train_loss += loss.item()
            batch_pbar.set_postfix({'loss': f'{loss.item():.2f}'})
        
        # Mean of per-batch losses; max(1, ...) guards against an empty loader
        avg_train_loss = epoch_train_loss / max(1, num_train_batches)
        train_losses.append(avg_train_loss)
        
        # Validation phase
        model.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                _, loss = compute_loss(model, batch, tokenizer, device)
                epoch_val_loss += loss.item()
        
        avg_val_loss = epoch_val_loss / max(1, num_val_batches)
        val_losses.append(avg_val_loss)
        
        epoch_pbar.set_postfix({'train': f'{avg_train_loss:.2f}', 'val': f'{avg_val_loss:.2f}'})
        
        # Save best model. `synthesizer` was constructed around this same `model`
        # object, so saving it should capture the current weights (assumes it
        # holds a reference, not a copy; confirm).
        # NOTE(review): a NaN validation loss never compares < best_val_loss.
        # If every epoch is NaN, nothing is saved here and loading
        # `best_model_dir` later will fail.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            synthesizer.save(best_model_dir)
        
        # Generate samples every SAMPLE_INTERVAL epochs (first 2 sample texts)
        if epoch % SAMPLE_INTERVAL == 0:
            print(f"\n\n{'='*60}")
            print(f"Epoch {epoch} - Samples (Train: {avg_train_loss:.2f}, Val: {avg_val_loss:.2f})")
            print(f"{'='*60}")
            
            for i, text in enumerate(sample_texts[:2]):
                sample = generate_sample(model, tokenizer, mu, std, text, device, dataset_max_length)
                visualize_sample_multiple(sample, text, epoch, i, SAMPLES_DIR)
            print()
        
        # Save checkpoint every 10 epochs (hardcoded; independent of SAMPLE_INTERVAL)
        if epoch % 10 == 0:
            checkpoint_dir = os.path.join(MODEL_SAVE_DIR, f'Epoch_{epoch}')
            synthesizer.save(checkpoint_dir)
    
    print("\n" + "="*60)
    print("Fine-Tuning Complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print("="*60)

finally:
    # Release the H5 file handles even if training is interrupted or fails
    train_dataset.close()
    val_dataset.close()

## 9. Training Results

In [None]:
# Plot training history: raw curves (left) and moving-average curves (right)
fig, (ax_raw, ax_smooth) = plt.subplots(1, 2, figsize=(12, 5))

ax_raw.plot(train_losses, label='Training Loss', color='blue', alpha=0.8)
ax_raw.plot(val_losses, label='Validation Loss', color='orange', alpha=0.8)
ax_raw.set_xlabel('Epoch')
ax_raw.set_ylabel('Loss (nats)')
ax_raw.set_title('Training and Validation Loss')
ax_raw.legend()
ax_raw.grid(True, alpha=0.3)

# Moving-average window: a quarter of the run, capped at 10 epochs
n_epochs = len(train_losses)
window = min(10, n_epochs // 4) if n_epochs > 4 else 1
if window > 1:
    kernel = np.ones(window) / window
    ax_smooth.plot(range(window - 1, n_epochs),
                   np.convolve(train_losses, kernel, mode='valid'),
                   label='Training (smoothed)')
    ax_smooth.plot(range(window - 1, len(val_losses)),
                   np.convolve(val_losses, kernel, mode='valid'),
                   label='Validation (smoothed)')
else:
    ax_smooth.plot(train_losses, label='Training')
    ax_smooth.plot(val_losses, label='Validation')
ax_smooth.set_xlabel('Epoch')
ax_smooth.set_ylabel('Loss (nats)')
ax_smooth.set_title('Smoothed Loss')
ax_smooth.legend()
ax_smooth.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('finetuning_history.png', dpi=150)
plt.show()

print(f"Final Train Loss: {train_losses[-1]:.4f}")
print(f"Final Val Loss: {val_losses[-1]:.4f}")
print(f"Best Val Loss: {best_val_loss:.4f}")

## 10. Generate Final Samples

In [None]:
# Load best model (saved by the training loop whenever validation loss improved)
if not os.path.isdir(best_model_dir):
    raise FileNotFoundError(
        f"No best model found at '{best_model_dir}': validation loss never improved "
        "(e.g. it was NaN) or training did not run.")
best_synthesizer = HandwritingSynthesizer.load(best_model_dir, device, bias=1.0)

print("\n" + "="*60)
print("Final Generated Samples (Best Model)")
print("="*60 + "\n")

val_dataset = data.NormalizedDataset(val_path, mu, std)
try:
    for i in range(min(5, len(val_dataset))):
        _, text = val_dataset[i]
        sample = generate_sample(best_synthesizer.model, tokenizer, mu, std, text, device, dataset_max_length)
        visualize_sample_multiple(sample, text, 'final', i, SAMPLES_DIR)
finally:
    # Release the H5 handle even if generation or visualization raises
    val_dataset.close()

## 11. Export to ONNX

In [None]:
print("Exporting model to ONNX format...")

import onnx
# ONNX-friendly variant of the network, shipped with the repository
import onnx_models

# Create ONNX-compatible model
onnx_model = onnx_models.SynthesisNetwork.get_default_model(tokenizer.size, torch.device('cpu'), bias=0)

# Load weights from best model
best_model_path = os.path.join(best_model_dir, 'model.pt')
state_dict = torch.load(best_model_path, map_location='cpu')
onnx_model.load_state_dict(state_dict)
onnx_model.eval()

# Dummy inputs are only used to trace the graph; their values are irrelevant,
# but their shapes must match the trained network.
# NOTE(review): get_default_model() does not receive config['hidden_size'];
# the config value must agree with that default (400 when this was written).
alphabet_size = tokenizer.size
hidden_size = config['hidden_size']
num_window_components = 10  # width of `k`; presumably the attention-window mixture count, confirm in onnx_models
f32 = torch.float32

input_names = ['x', 'c', 'w', 'k', 'h1', 'c1', 'h2', 'c2', 'h3', 'c3', 'bias']
dummy_inputs = (
    torch.randn(1, 1, 3, dtype=f32),                     # x
    torch.randn(1, 1, alphabet_size, dtype=f32),         # c
    torch.randn(1, 1, alphabet_size, dtype=f32),         # w
    torch.randn(1, num_window_components, dtype=f32),    # k
    *(torch.randn(1, hidden_size, dtype=f32) for _ in range(6)),  # h1, c1, h2, c2, h3, c3
    torch.randn(1, dtype=f32),                           # bias
)
assert len(dummy_inputs) == len(input_names)

onnx_path = os.path.join(ONNX_DIR, f'{MODEL_ID}.onnx')

torch.onnx.export(
    onnx_model,
    dummy_inputs,
    onnx_path,
    verbose=False,
    opset_version=11,
    input_names=input_names,
    output_names=['pi', 'mu', 'sd', 'ro', 'eos', 'w_out', 'k_out', 'h1_out', 'c1_out', 'h2_out', 'c2_out', 'h3_out', 'c3_out', 'phi'],
    dynamic_axes={'c': {1: 'sequence'}, 'phi': {2: 'string_length'}}
)

print(f"ONNX model exported to: {onnx_path}")

# Verify ONNX model
onnx_model_check = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model_check)
print("ONNX model verification: PASSED")

## 12. Upload to HuggingFace Hub

In [None]:
from huggingface_hub import HfApi, login
import shutil

# Authenticate once so later Hub calls in this session can upload
if not HF_TOKEN:
    print("No HF_TOKEN provided. Please set HF_TOKEN to upload models.")
    print("You can still download the model files manually.")
else:
    login(token=HF_TOKEN)
    print("Logged in to HuggingFace Hub")

In [None]:
def upload_to_huggingface(model_dir, onnx_path, repo_id, model_id, epoch_num='best'):
    """
    Upload model files to HuggingFace Hub.

    Files go to ``<repo_id>/<TRAINING_TIMESTAMP>/<model_id>/epoch_<epoch_num>/``:
      - ``model.pt`` and ``meta.json`` from ``model_dir`` (each only if present);
      - the file at ``onnx_path`` (if present);
      - a generated ``training_config.json``.

    Returns:
        ``"<repo_id>/<path>"`` on success. ``None`` when no token is set or a
        Hub call fails; the error is printed, not raised.
    """
    if not HF_TOKEN:
        print("Skipping upload - no HF_TOKEN provided")
        return None

    # Pass the token explicitly so this works even if the login cell was skipped
    api = HfApi(token=HF_TOKEN)

    # Create path in repo
    timestamp = TRAINING_TIMESTAMP
    repo_path = f"{timestamp}/{model_id}/epoch_{epoch_num}"

    def _upload(source, filename):
        """Upload a local path or raw bytes to ``repo_path/filename``."""
        api.upload_file(
            path_or_fileobj=source,
            path_in_repo=f"{repo_path}/{filename}",
            repo_id=repo_id,
            repo_type="model"
        )
        print(f"  Uploaded {filename}")

    try:
        # Check if repo exists, create if not
        try:
            api.repo_info(repo_id=repo_id, repo_type="model")
        except Exception:
            print(f"Creating repository: {repo_id}")
            api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

        print(f"Uploading to {repo_id}/{repo_path}...")

        for filename in ('model.pt', 'meta.json'):
            local_path = os.path.join(model_dir, filename)
            if os.path.exists(local_path):
                _upload(local_path, filename)

        if os.path.exists(onnx_path):
            _upload(onnx_path, os.path.basename(onnx_path))

        config_data = {
            'model_id': model_id,
            'timestamp': timestamp,
            'epoch': epoch_num,
            'charset': tokenizer.charset,
            # float(): mu/std may hold numpy float32 scalars, which json rejects
            'mu': [float(v) for v in mu],
            'std': [float(v) for v in std],
            'config': config
        }
        # Upload straight from memory instead of a shared /tmp file
        _upload(json.dumps(config_data, indent=2).encode('utf-8'), 'training_config.json')

        print(f"\nUpload complete! View at: https://huggingface.co/{repo_id}/tree/main/{repo_path}")
        return f"{repo_id}/{repo_path}"

    except Exception as e:
        print(f"Upload failed: {e}")
        return None

# Upload best model
upload_to_huggingface(best_model_dir, onnx_path, HUGGINGFACE_MODEL_REPO, MODEL_ID, 'best')

## 13. Download Model Files

In [None]:
# Create zip archive for download: best checkpoint, ONNX file and a training summary
import shutil

# Create export directory with all files
export_dir = f'export_{MODEL_ID}'
os.makedirs(export_dir, exist_ok=True)

# Copy files
if os.path.exists(best_model_dir):
    shutil.copytree(best_model_dir, os.path.join(export_dir, 'pytorch_model'), dirs_exist_ok=True)
if os.path.exists(onnx_path):
    shutil.copy(onnx_path, export_dir)

# Save training info
training_info = {
    'model_id': MODEL_ID,
    'timestamp': TRAINING_TIMESTAMP,
    'dataset': HUGGINGFACE_DATASET,
    'config': config,
    'final_train_loss': train_losses[-1] if train_losses else None,
    'final_val_loss': val_losses[-1] if val_losses else None,
    # inf (loss never improved) would be written as the non-standard JSON token Infinity
    'best_val_loss': best_val_loss if math.isfinite(best_val_loss) else None,
    'charset': tokenizer.charset,
    # float(): mu/std may hold numpy float32 scalars, which json.dump rejects
    'mu': [float(v) for v in mu],
    'std': [float(v) for v in std]
}
with open(os.path.join(export_dir, 'training_info.json'), 'w', encoding='utf-8') as f:
    json.dump(training_info, f, indent=2)

# Create zip (make_archive returns the path of the written archive)
archive_path = shutil.make_archive(f'handwriting_model_{MODEL_ID}', 'zip', export_dir)

print(f"\nModel files ready for download:")
print(f"  {os.path.basename(archive_path)}")
print(f"\nContents:")
for root, dirs, files in os.walk(export_dir):
    level = root.replace(export_dir, '').count(os.sep)
    indent = '  ' * level
    print(f"{indent}{os.path.basename(root)}/")
    for file in files:
        print(f"{indent}  {file}")

In [None]:
# Trigger a browser download on Colab; elsewhere the zip simply stays on disk
archive_name = f'handwriting_model_{MODEL_ID}.zip'
try:
    from google.colab import files as colab_files
    colab_files.download(archive_name)
except ImportError:
    print("Not running in Colab - file saved locally")

## 14. Test Your Fine-Tuned Model

Edit the `test_texts` list in the cell below to generate handwriting samples of your own text with the fine-tuned model.

In [None]:
# Interactive testing: edit this list to try your own text
test_texts = [
    "Hallo Welt",
    "Das ist meine Handschrift",
    "Test"
]

# The model never saw characters outside its charset during training; how the
# Tokenizer encodes them is not visible here, so flag them up front.
known_chars = set(tokenizer.charset)

print("Generating samples with fine-tuned model:\n")

for i, text in enumerate(test_texts):
    unknown = sorted(set(text) - known_chars)
    if unknown:
        print(f"  WARNING: '{text}' contains characters not in the model charset: {unknown}")
    sample = generate_sample(best_synthesizer.model, tokenizer, mu, std, text, device, dataset_max_length)
    visualize_sample_multiple(sample, text, 'test', i, SAMPLES_DIR)

---

## Summary

This notebook:
1. ✅ Loaded your handwriting data from HuggingFace
2. ✅ Fine-tuned a pre-trained model on your personal handwriting
3. ✅ Generated preview samples every 5 epochs
4. ✅ Exported the model to ONNX format
5. ✅ Uploaded to HuggingFace Hub (if token provided)

Your model files are available for download above!