# Gemma Fine-Tuning on Apple Silicon - Quick Start

This notebook demonstrates how to fine-tune Google Gemma models locally on Apple Silicon.

## 1. Setup and Imports

```python
import os
import sys

# Make the project's src/ directory importable
# (the path is relative to the notebook's working directory)
sys.path.append(os.path.abspath("../src"))

import torch
from config import TrainingConfig
from train import train

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
```
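If MPS is reported as unavailable, training can still fall back to the CPU. The following is a minimal sketch of that selection logic; `pick_device` is a hypothetical helper and is not part of this project's `src` package, and how `TrainingConfig` chooses its device internally is not shown in this notebook.

```python
def pick_device(prefer_mps: bool = True) -> str:
    """Return "mps" when the Metal backend is usable, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is no MPS backend to use.
        return "cpu"

    mps = getattr(torch.backends, "mps", None)
    if prefer_mps and mps is not None and mps.is_built() and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(f"Selected device: {device}")
```

`is_built()` reports whether this PyTorch build includes MPS support, while `is_available()` also requires a compatible macOS version and GPU, so checking both gives a clearer picture of why MPS might be missing.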

## 2. Configure Training

Adjust these parameters based on your needs:

```python
config = TrainingConfig(
    model_name="google/gemma-3-270m",  # Start with the smallest model
    dataset_name="bebechien/MobileGameNPC",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    max_seq_length=512,
    output_dir="../outputs",
    logging_dir="../logs",
    use_mps=True,
    gradient_checkpointing=True,
    # hf_token="your_token_here"  # Uncomment and add your HF token if needed
)

print(f"Training device: {config.get_device_info()}")
print(f"Model: {config.model_name}")
print(f"Dataset: {config.dataset_name}")
```
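Gemma checkpoints on the Hugging Face Hub are gated, so a token is usually needed for the first download. Rather than pasting the token into the notebook, one option is to read it from the `HF_TOKEN` environment variable. This is a sketch only: `resolve_hf_token` is a hypothetical helper, and it assumes the `hf_token` field of `TrainingConfig` accepts `None` when no token is set.

```python
import os
from typing import Optional


def resolve_hf_token(explicit: Optional[str] = None) -> Optional[str]:
    """Prefer an explicitly passed token, then the HF_TOKEN environment variable."""
    token = explicit or os.environ.get("HF_TOKEN")
    # Treat empty or whitespace-only values as "no token".
    if token and token.strip():
        return token.strip()
    return None


hf_token = resolve_hf_token()
print("Hugging Face token found" if hf_token else "No token set; gated models may fail to download")
```

The result could then be passed as `hf_token=hf_token` in the configuration above, which keeps the secret out of the saved notebook.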

## 3. Start Training

Calling `train(config)` downloads the model and dataset, then starts fine-tuning:

```python
# Start training
train(config)
```
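Once `train(config)` returns, it can help to confirm that a model was actually written before moving on. The sketch below assumes the final checkpoint lands in a `final_model` folder under `output_dir`, as the path used in the next section suggests. `saved_model_dir` is a hypothetical helper, not part of this project's `src` package.

```python
from pathlib import Path
from typing import Optional


def saved_model_dir(output_dir: str, name: str = "final_model") -> Optional[Path]:
    """Return the saved-model directory if it exists and is non-empty, else None."""
    candidate = Path(output_dir) / name
    if candidate.is_dir() and any(candidate.iterdir()):
        return candidate
    return None


# "../outputs" mirrors output_dir in the configuration cell.
found = saved_model_dir("../outputs")
print(f"Saved model: {found}" if found else "No saved model found yet")
```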

## 4. Test the Fine-tuned Model

After training completes, test the model:

```python
from inference import GemmaInference

# Load the fine-tuned model
model_path = "../outputs/final_model"
inference = GemmaInference(model_path)

# Test with a prompt
prompt = "Hello! Tell me about yourself."
response = inference.generate(prompt, max_new_tokens=100)

print(f"Prompt: {prompt}")
print(f"Response: {response}")
```

## 5. View Training Logs

To view the TensorBoard logs, run this in a terminal from the notebook's directory:
```bash
tensorboard --logdir ../logs
```

## Tips for Apple Silicon

1. **Start small**: Begin with `gemma-3-270m` before trying larger models
2. **Monitor memory**: Use Activity Monitor to watch memory usage
3. **Batch size**: Reduce if you run out of memory
4. **Gradient checkpointing**: Keep enabled to save memory
5. **Use float32**: Stick with float32 on MPS unless you have verified that reduced precision works with your PyTorch and macOS versions
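The memory and batch-size tips above can be made concrete with some rough arithmetic. The sketch below is an estimate under stated assumptions: it takes the nominal 270 million parameters from the model name, assumes an Adam-style optimizer that keeps two extra float32 buffers per parameter, and ignores activations, which gradient checkpointing reduces. Whether `TrainingConfig` exposes a gradient accumulation setting is not shown in this notebook, so `accumulation_steps` only illustrates the calculation.

```python
import math

BYTES_PER_FLOAT32 = 4


def estimate_training_gib(num_params: int) -> float:
    """Rough float32 footprint: weights + gradients + two Adam moment buffers."""
    copies = 4  # weights, gradients, Adam first moment, Adam second moment
    return num_params * BYTES_PER_FLOAT32 * copies / 2**30


def accumulation_steps(target_batch: int, per_device_batch: int) -> int:
    """Micro-batches to accumulate so the effective batch size reaches the target."""
    return math.ceil(target_batch / per_device_batch)


print(f"~{estimate_training_gib(270_000_000):.1f} GiB before activations")  # ~4.0 GiB
print(f"Accumulate {accumulation_steps(4, 1)} steps at batch size 1")  # 4 steps
```

If a batch size of 4 runs out of memory, dropping to 1 and accumulating gradients over 4 steps keeps the effective batch size unchanged while lowering peak activation memory.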