# Tutorial: Training & Fine-Tuning TinyDiT

Learn how to train or fine-tune the TinyDiT model on your own cat breed dataset.

## What You'll Learn

- Understand training configuration
- Prepare custom datasets
- Run local and Modal GPU training
- Export models to ONNX and upload them to HuggingFace

## Prerequisites

**For Local Training:**
- Python 3.10+
- PyTorch with CUDA (GPU recommended)
- 8GB+ VRAM for full training

**For Modal GPU Training:**
- Modal account (free tier available)
- Modal CLI installed and authenticated (`modal token new`)

In [None]:
# Install dependencies
!pip install torch torchvision modal huggingface_hub pillow numpy matplotlib onnx onnxruntime -q

## Part 1: Dataset Preparation

### Step 1.1: Organize Your Dataset

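The training script expects one subfolder per class, named after the breed, containing that class's images. The cell below creates the empty folders and prints the expected layout: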

In [None]:
from pathlib import Path

# Create example dataset structure
dataset_dir = Path("my_custom_cats")
breeds = ["abyssinian", "bengal", "persian", "other"]

for breed in breeds:
    (dataset_dir / breed).mkdir(parents=True, exist_ok=True)

print(f"Dataset structure created at: {dataset_dir.absolute()}")
print("\nExpected structure:")
print(f"""
{dataset_dir}/
├── abyssinian/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
├── bengal/
│   └── ...
├── persian/
│   └── ...
└── other/
    └── (non-cat images or unknown breeds)
""")

### Step 1.2: Dataset Requirements

| Requirement | Minimum | Recommended |
|-------------|---------|-------------|
| Images per breed | 50 | 200+ |
| Image size | 128x128 | 512x512+ |
| Formats | JPG, PNG | JPG |
| Total dataset | 500 images | 2000+ images |
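The count thresholds in this table can be checked mechanically. A small sketch that rates each breed's image count, and the total, as `recommended`, `minimum`, or `below minimum`; `rate_dataset` is a hypothetical helper that expects a `{breed: count}` dictionary like the `breed_counts` entry returned by `check_dataset_quality` below:

```python
# Thresholds copied from the requirements table above
MIN_PER_BREED, RECOMMENDED_PER_BREED = 50, 200
MIN_TOTAL, RECOMMENDED_TOTAL = 500, 2000

def tier(count: int, minimum: int, recommended: int) -> str:
    """Classify a count as 'recommended', 'minimum', or 'below minimum'."""
    if count >= recommended:
        return "recommended"
    if count >= minimum:
        return "minimum"
    return "below minimum"

def rate_dataset(breed_counts: dict[str, int]) -> dict[str, str]:
    """Rate each breed folder and the dataset as a whole against the table."""
    ratings = {
        breed: tier(n, MIN_PER_BREED, RECOMMENDED_PER_BREED)
        for breed, n in breed_counts.items()
    }
    ratings["(total)"] = tier(sum(breed_counts.values()), MIN_TOTAL, RECOMMENDED_TOTAL)
    return ratings
```

Once the checker is defined, `rate_dataset(check_dataset_quality('my_custom_cats')['breed_counts'])` gives a per-breed verdict. Note that every breed can pass its own minimum while the dataset still falls short of the 500-image total.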

In [None]:
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def check_dataset_quality(dataset_path: str) -> dict:
    """
    Count images per breed folder and flag breeds (or the whole dataset)
    below the minimums in the requirements table: 50 per breed, 500 total.
    """
    dataset = Path(dataset_path)
    breeds = sorted(d for d in dataset.iterdir() if d.is_dir())

    stats = {
        "num_breeds": len(breeds),
        "breed_counts": {},
        "total_images": 0,
        "warnings": []
    }

    for breed_dir in breeds:
        # Compare extensions case-insensitively so .JPG and .jpeg files are counted too
        images = [p for p in breed_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS]
        count = len(images)
        stats["breed_counts"][breed_dir.name] = count
        stats["total_images"] += count

        if count < 50:
            stats["warnings"].append(
                f"⚠️  {breed_dir.name}: Only {count} images (need 50+)"
            )

    if stats["total_images"] < 500:
        stats["warnings"].append(
            f"⚠️  Only {stats['total_images']} images in total (need 500+)"
        )

    return stats

print("Dataset quality checker defined!")
print("Usage: check_dataset_quality('my_custom_cats')")

## Part 2: Training Configuration

### Step 2.1: Define Training Parameters

In [None]:
# Training configuration
training_config = {
    # Dataset
    "data_dir": "my_custom_cats",
    
    # Training
    "steps": 50000,  # Fine-tuning (100k for full training)
    "batch_size": 32,  # Reduce if OOM
    "learning_rate": 5e-5,
    "gradient_accumulation_steps": 2,  # Effective batch = 64
    "warmup_steps": 5000,
    
    # Model
    "image_size": 128,
    "augmentations": "full",  # full, basic, or none
    
    # Output
    "checkpoint_interval": 10000,
    "output_dir": "checkpoints",
}

print("Training Configuration:")
print("=" * 40)
for key, value in training_config.items():
    print(f"{key:30s}: {value}")
print("=" * 40)

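# Rough training-time estimate. IMAGES_PER_GPU_HOUR is a placeholder assumption,
# not a benchmark: throughput differs a lot between GPUs (an H100 is far faster
# than an A10G), so measure it on your hardware and substitute the real value.
IMAGES_PER_GPU_HOUR = 50_000
effective_batch = training_config["batch_size"] * training_config["gradient_accumulation_steps"]
images_seen = training_config["steps"] * effective_batch
estimated_gpu_hours = images_seen / IMAGES_PER_GPU_HOUR
print(f"\nEffective batch size: {effective_batch}")
print(f"Images processed: {images_seen:,}")
print(f"Estimated GPU hours: {estimated_gpu_hours:.1f} (assuming {IMAGES_PER_GPU_HOUR:,} images/hour)")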

## Part 3: Local Training (GPU/CPU)

### Step 3.1: Check GPU Availability

In [None]:
import torch

if torch.cuda.is_available():
    print(f"✅ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"   CUDA: {torch.version.cuda}")
else:
    print("⚠️  No GPU detected - training will be slow on CPU")
    print("   Consider using Modal GPU training (Part 4)")

### Step 3.2: Run Local Training

In [None]:
import subprocess
import sys

# Build training command
cmd = [
    sys.executable, "src/train_dit.py",
    training_config["data_dir"],
    "--steps", str(training_config["steps"]),
    "--batch-size", str(training_config["batch_size"]),
    "--lr", str(training_config["learning_rate"]),
    "--gradient-accumulation-steps", str(training_config["gradient_accumulation_steps"]),
    "--warmup-steps", str(training_config["warmup_steps"]),
    "--augmentation-level", training_config["augmentations"],
]

print("Training command:")
print(" ".join(cmd))
print("\n⏳ Starting training... (this will take a while)")
print("=" * 60)

# Run training
# Note: Commented out to avoid running in notebook
# subprocess.run(cmd)

print("\n✅ Training would start with the above command")
print("To run training, uncomment the subprocess.run(cmd) line")

## Part 4: Modal GPU Training (Recommended)

### Step 4.1: Setup Modal

In [None]:
# Install Modal if not already installed
!pip install modal -q

# Authenticate once (opens a browser to create an API token):
# !modal token new

print("Modal setup complete!")
print("Note: run 'modal token new' to authenticate if you haven't already")

### Step 4.2: Define Modal Training Job

In [None]:
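import modal

# Define Modal app
app = modal.App("tiny-cats-training")

# Persistent volume holding the dataset (and, after training, the checkpoints)
volume = modal.Volume.from_name("cats-data", create_if_missing=True)

# Container image: Python dependencies plus the local src/ directory,
# so that src/train_dit.py actually exists inside the container
image = (
    modal.Image.debian_slim()
    .pip_install("torch", "torchvision", "modal", "huggingface_hub", "pillow", "numpy")
    .add_local_dir("src", remote_path="/root/src")
)

# Define training function
@app.function(
    image=image,  # without this, Modal uses a default image lacking torch
    gpu="T4",  # or "A10G", "H100"
    timeout=7200,  # 2 hours; raise it if 100k steps need longer
    volumes={"/data": volume},
)
def train_on_modal():
    import shutil
    import subprocess
    from pathlib import Path

    cmd = [
        "python", "src/train_dit.py",
        "/data/cats",  # Dataset in volume
        "--steps", "100000",
        "--batch-size", "256",
        "--gradient-accumulation-steps", "2",
        "--lr", "1e-4",
        "--augmentation-level", "full",
    ]

    # Run from /root so the relative script path resolves
    subprocess.run(cmd, check=True, cwd="/root")

    # Container storage is ephemeral: copy checkpoints into the volume.
    # Assumes the script writes them to ./checkpoints (the output_dir used above).
    if Path("/root/checkpoints").exists():
        shutil.copytree("/root/checkpoints", "/data/checkpoints", dirs_exist_ok=True)
        volume.commit()

print("Modal training function defined!")
print("Usage: train_on_modal.remote()")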

### Step 4.3: Run Modal Training

In [None]:
# Run training (commented out to avoid actual execution)
# with app.run():
#     train_on_modal.remote()

print("Modal training ready!")
print("\nTo run training:")
print("1. Uncomment the code above")
print("2. Ensure dataset is in Modal volume")
print("3. Run the cell")
print("\nOr use CLI:")
print("  modal run src/train_dit.py data/cats --steps 100000")

## Part 5: Export Model

### Step 5.1: Find the Latest Checkpoint

In [None]:
import torch
from pathlib import Path

# Find latest checkpoint
checkpoint_dir = Path("checkpoints")
checkpoints = list(checkpoint_dir.glob("dit_step_*.pt"))

if checkpoints:
    latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime)
    print(f"Latest checkpoint: {latest_checkpoint}")
else:
    print("No checkpoints found - run training first")
    latest_checkpoint = None
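Picking the newest file by modification time works for a fresh local run, but `st_mtime` changes when checkpoints are copied or downloaded (for example from a Modal volume), which can make an older checkpoint look newest. A sketch that selects by the step number in the filename instead, assuming the `dit_step_<N>.pt` naming used above; `latest_by_step` is a hypothetical helper. Comparing numbers rather than strings also avoids `dit_step_9000.pt` sorting after `dit_step_10000.pt`:

```python
import re
from pathlib import Path
from typing import Iterable, Optional

STEP_PATTERN = re.compile(r"dit_step_(\d+)\.pt$")

def latest_by_step(paths: Iterable[Path]) -> Optional[Path]:
    """Return the checkpoint with the highest step number, or None if none match."""
    best, best_step = None, -1
    for path in paths:
        match = STEP_PATTERN.search(path.name)
        if match and int(match.group(1)) > best_step:
            best, best_step = path, int(match.group(1))
    return best

print(latest_by_step(Path("checkpoints").glob("dit_step_*.pt")))
```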

### Step 5.2: Export to ONNX

In [None]:
if latest_checkpoint:
    from src.export_dit_onnx import export_generator_onnx
    
    print("Exporting model to ONNX...")
    export_generator_onnx(
        checkpoint_path=str(latest_checkpoint),
        output_path="custom_generator.onnx",
        opset=17,
    )
    
    import os
    size_mb = os.path.getsize("custom_generator.onnx") / 1e6
    print(f"✅ Exported to custom_generator.onnx ({size_mb:.2f} MB)")
else:
    print("Skipping export - no checkpoint available")

### Step 5.3: Quantization (Optional)

In [None]:
if Path("custom_generator.onnx").exists():
    from src.optimize_onnx import quantize_model_dynamic
    
    print("Quantizing model...")
    quantize_model_dynamic(
        input_path="custom_generator.onnx",
        output_path="custom_generator_quantized.onnx",
    )
    
    import os
    original_size = os.path.getsize("custom_generator.onnx") / 1e6
    quantized_size = os.path.getsize("custom_generator_quantized.onnx") / 1e6
    reduction = (1 - quantized_size / original_size) * 100
    
    print(f"✅ Quantized: {original_size:.2f}MB → {quantized_size:.2f}MB ({reduction:.1f}% reduction)")
else:
    print("Skipping quantization - export model first")

## Part 6: Upload to HuggingFace

### Step 6.1: Setup Repository

In [None]:
from huggingface_hub import HfApi, create_repo, login

# Log in with a token that has write access (skip if HF_TOKEN is already set)
login()

# Get HuggingFace username
username = input("Enter your HuggingFace username: ").strip()
repo_name = input("Enter repository name (e.g., my-cats-model): ").strip()
repo_id = f"{username}/{repo_name}"

print(f"\nRepository ID: {repo_id}")

# Create the repository (exist_ok=True reuses an existing repo instead of failing)
print("\nCreating repository...")
try:
    create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    print(f"✅ Repository ready: https://huggingface.co/{repo_id}")
except Exception as e:
    print(f"❌ Could not create repository: {e}")

### Step 6.2: Upload Files

In [None]:
from pathlib import Path

api = HfApi()

# (local file, path inside the repo)
files_to_upload = [
    ("custom_generator.onnx", "generator/model.onnx"),
    ("custom_generator_quantized.onnx", "generator/model_quantized.onnx"),
]

print(f"\nUploading files to {repo_id}...")

failed = []
for local_path, repo_path in files_to_upload:
    if Path(local_path).exists():
        print(f"  Uploading {local_path} → {repo_path}")
        try:
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=repo_path,
                repo_id=repo_id,
            )
            print("    ✅ Success")
        except Exception as e:
            failed.append(local_path)
            print(f"    ❌ Error: {e}")
    else:
        print(f"  ⚠️  Skipping {local_path} (not found)")

if failed:
    print(f"\n❌ Upload finished with errors: {', '.join(failed)}")
else:
    print("\n✅ Upload complete!")
print(f"View your model: https://huggingface.co/{repo_id}")

### Step 6.3: Create Model Card

In [None]:
model_card = f"""---
license: mit
tags:
- cat
- diffusion
- generative
- pytorch
---

# {repo_name}

A custom-trained TinyDiT model for generating cat images.

## Model Details

- **Architecture:** TinyDiT (Diffusion Transformer)
- **Training Steps:** {training_config['steps']:,}
- **Dataset:** Custom cat breeds
- **License:** MIT

## Usage

```python
from huggingface_hub import hf_hub_download
import torch

# Load model
model_path = hf_hub_download(
    repo_id="{repo_id}",
    filename="generator/model.onnx"
)
```

## Training Configuration

```python
{training_config}
```

## Generated Samples

![Samples](samples.png)

## Citation

```bibtex
@misc{{{repo_name.replace('-', '_')}},
  title = {{{repo_name}}},
  author = {{Your Name}},
  year = {{2026}},
  url = {{https://huggingface.co/{repo_id}}}
}
```
"""

# Save model card
with open("README.md", "w") as f:
    f.write(model_card)

print("Model card created: README.md")

# Upload model card
api.upload_file(
    path_or_fileobj="README.md",
    path_in_repo="README.md",
    repo_id=repo_id,
)
print(f"✅ Model card uploaded!")

## Summary

✅ You've learned how to:
- Prepare a custom dataset
- Configure training parameters
- Run local and Modal GPU training
- Export models to ONNX
- Upload to HuggingFace

## Next Steps

- Share your model with the community
- Read [ADR-036](../plans/ADR-036-high-accuracy-training-configuration.md) for advanced training
- Try [Notebook 01](01_quickstart_classification.ipynb) for classification
- Try [Notebook 02](02_conditional_generation.ipynb) for generation

## Troubleshooting

### Issue: Out of Memory
**Solution:** Reduce `batch_size` and increase `gradient_accumulation_steps` by the same factor, so the effective batch size stays the same.
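Lowering `batch_size` on its own also shrinks the effective batch (batch_size × gradient_accumulation_steps), which changes training dynamics. A sketch for picking the accumulation steps that keep the effective batch at its original value, 64 in the Part 2 configuration; `accumulation_for` is a hypothetical helper:

```python
def accumulation_for(target_effective_batch: int, batch_size: int) -> int:
    """Accumulation steps so that batch_size * steps == target_effective_batch."""
    if target_effective_batch % batch_size != 0:
        raise ValueError(
            f"{target_effective_batch} is not a multiple of {batch_size}; "
            "choose a batch size that divides the effective batch"
        )
    return target_effective_batch // batch_size

# Halving the per-step batch doubles the accumulation steps
for bs in (32, 16, 8):
    print(f"--batch-size {bs} --gradient-accumulation-steps {accumulation_for(64, bs)}")
```

Smaller per-step batches lower peak memory, at the cost of more forward/backward passes per optimizer update.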

### Issue: Slow Training
**Solution:** Use Modal GPU training (Part 4), or increase the data loader's `num_workers` if `src/train_dit.py` exposes that setting

### Issue: Poor Sample Quality
**Solution:** Train for more steps, or re-check the dataset with `check_dataset_quality` (Step 1.2)

### Issue: Upload Fails
**Solution:** Check that your token (saved by `login()` or set in `HF_TOKEN`) has write permission and that you are allowed to push to the repository