# Image Classification Demo

This notebook demonstrates the `image_classification` package, covering:
- Data loading with train/val/test splits
- Training CNN and Vision Transformer (ViT) models
- Experiment tracking with Weights & Biases (wandb)
- Visualization of results
- Hyperparameter tuning
- Registering custom models

In [None]:
# Standard library
import random

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt

# Import our package
from image_classification import (
    # Data
    create_dataloaders,
    # Models
    get_model,
    list_models,
    # Training
    Trainer,
    evaluate_model,
    # Experiments
    HyperparameterTuner,
    # Visualization
    plot_training_curves,
    plot_confusion_matrices,
    print_model_summary,
)

# Seed every RNG the notebook may draw from (Python's `random`, NumPy and
# PyTorch) so reruns are as reproducible as possible. Some GPU kernels can
# still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"PyTorch version: {torch.__version__}")

## 1. Set Up the Device

In [None]:
# Prefer an accelerator when one is present: Apple-silicon MPS first, then
# CUDA, and fall back to the CPU otherwise.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print(f"Using device: {device}")

## 2. Load Data

The data module handles:
- Downloading FashionMNIST
- Train/val/test splitting (stratified)
- Data augmentation (only for training)
- Optional class imbalance handling

In [None]:
# Build the train/val/test loaders. `train_ratio` and `val_ratio` sum to 1.0,
# so presumably they partition the training data and the test loader comes
# from the dataset's own test split (confirm against create_dataloaders).
dataloaders, info = create_dataloaders(
    dataset_name="fashion_mnist",
    data_dir="./data",
    batch_size=64,
    train_ratio=0.85,
    val_ratio=0.15,
    augmentation_level="standard",  # one of: none, light, standard, heavy
    use_imbalanced_sampler=False,    # switch on for class-imbalanced datasets
)

# Summarise the metadata returned alongside the loaders.
print("\n".join([
    "Dataset Info:",
    f"  - Num classes: {info['num_classes']}",
    f"  - Image size: {info['img_size']}x{info['img_size']}",
    f"  - Channels: {info['in_channels']}",
    f"  - Train samples: {info['train_size']}",
    f"  - Val samples: {info['val_size']}",
    f"  - Test samples: {info['test_size']}",
    f"\nClass names: {info['class_names']}",
]))

## 3. Available Models

In [None]:
# Show every model name currently registered with the package.
print("Available models:")
for model_name in list_models():
    print("  -", model_name)

## 4. Train Models

Train a CNN and a Vision Transformer (ViT), logging each run to Weights & Biases (wandb). The login cell only needs to be run once per environment.

In [None]:
# Authenticate with Weights & Biases. This is needed once per environment,
# before the training cells below log their runs.
import wandb

wandb.login()

In [None]:
# Train the CNN baseline and log the run to wandb.
print("=" * 60, "Training CNN Model", "=" * 60, sep="\n")

cnn_model = get_model(
    "cnn",
    in_channels=info['in_channels'],
    num_classes=info['num_classes'],
    base_filters=32,
    num_conv_layers=3,
    dropout=0.25,
)

cnn_trainer = Trainer(cnn_model, device)
cnn_history = cnn_trainer.fit(
    dataloaders['train'],
    dataloaders['val'],
    num_epochs=10,
    lr=1e-3,
    use_wandb=True,
    wandb_project="image-classification-demo",
    wandb_run_name="CNN",
)

# `fit` returns a history keyed by metric name; report the best validation
# accuracy reached across epochs.
print(f"\nCNN Best Val Accuracy: {max(cnn_history['val_acc']):.2f}%")

In [None]:
# Train the Vision Transformer and log the run to wandb.
print("=" * 60)
print("Training Vision Transformer")
print("=" * 60)

# The patch size is a tunable constant, but img_size comes from the dataset.
# NOTE(review): ViT patch embeddings normally require the image side to be a
# multiple of the patch size (28 = 4 x 7 for FashionMNIST). Fail early with a
# clear message rather than deep inside the model; confirm against the ViT code.
PATCH_SIZE = 7
assert info['img_size'] % PATCH_SIZE == 0, (
    f"img_size={info['img_size']} is not divisible by PATCH_SIZE={PATCH_SIZE}; "
    "pick a patch size that divides the image side"
)

vit_model = get_model(
    "vit",
    img_size=info['img_size'],
    patch_size=PATCH_SIZE,
    in_channels=info['in_channels'],
    num_classes=info['num_classes'],
    embed_dim=128,
    num_heads=4,
    num_layers=4,
    dropout=0.1,
)

vit_trainer = Trainer(vit_model, device)
vit_history = vit_trainer.fit(
    dataloaders['train'],
    dataloaders['val'],
    num_epochs=10,
    lr=3e-4,  # ViT typically needs a lower learning rate than the CNN
    use_wandb=True,
    wandb_project="image-classification-demo",
    wandb_run_name="ViT",
)

print(f"\nViT Best Val Accuracy: {max(vit_history['val_acc']):.2f}%")

## 5. Visualize Training Results

In [None]:
# Overlay the learning curves of both models, each paired with its label.
results = list(zip((cnn_history, vit_history), ('CNN', 'ViT')))
plot_training_curves(results)

In [None]:
# Tabulate each model next to its training history and display name.
models_info = list(zip(
    (cnn_model, vit_model),
    (cnn_history, vit_history),
    ('CNN', 'ViT'),
))
print_model_summary(models_info)

## 6. Evaluate on Test Set

In [None]:
# Score the trained CNN on the test loader, which fit() above did not use.
print("CNN Test Results:")
cnn_results = evaluate_model(cnn_model, dataloaders['test'], device, class_names=info['class_names'])

In [None]:
# Score the trained ViT on the same test loader, for a like-for-like comparison.
print("ViT Test Results:")
vit_results = evaluate_model(vit_model, dataloaders['test'], device, class_names=info['class_names'])

In [None]:
# Build one (y_true, y_pred, label) triple per model from its test results.
cm_results = [
    (model_results['y_true'], model_results['y_pred'], label)
    for model_results, label in ((cnn_results, 'CNN'), (vit_results, 'ViT'))
]
plot_confusion_matrices(cm_results, class_names=info['class_names'])

## 7. Hyperparameter Tuning (Optional)

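Example of a grid search over CNN hyperparameters (the tuner also supports random search). This step trains one model per configuration and is therefore slow, which is why it is optional.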

In [None]:
# Example: grid search over CNN hyperparameters.
# The grid below has 2 x 2 x 2 = 8 combinations, so this is expensive and
# disabled by default. Set RUN_TUNING = True to run it.
RUN_TUNING = False

if RUN_TUNING:
    tuner = HyperparameterTuner(
        model_name="cnn",
        device=device,
        # Take shapes from the dataset instead of hard-coding FashionMNIST values.
        base_model_config={
            'in_channels': info['in_channels'],
            'num_classes': info['num_classes'],
        },
    )

    # Use a distinct name so the `results` list plotted in Section 5 is not
    # overwritten when this cell runs.
    tuning_results = tuner.grid_search(
        dataloaders['train'],
        dataloaders['val'],
        param_grid={
            'lr': [1e-3, 1e-4],
            'dropout': [0.2, 0.3],
            'base_filters': [16, 32],
        },
        num_epochs=5,
        use_wandb=True,
    )

    print(f"Best config: {tuner.get_best_config()}")

## 8. Adding Custom Models

Example of how to register a custom model so that it appears in the model registry alongside the built-in models.

In [None]:
from image_classification.models import register_model, BaseClassifier
import torch.nn as nn

# NOTE(review): if the registry rejects duplicate names, re-running this cell
# may fail; check register_model's behaviour.
@register_model("simple_mlp")
class SimpleMLP(BaseClassifier):
    """A simple MLP classifier for comparison with the CNN and ViT models.

    The input is flattened and passed through two hidden fully-connected layers
    (``hidden_dim`` and ``hidden_dim // 2`` units), each followed by ReLU and
    dropout. A final linear layer produces ``num_classes`` raw scores (logits).

    Args:
        in_channels: Number of input channels.
        img_size: Side length of the square input image. The first layer
            expects ``in_channels * img_size * img_size`` features.
        num_classes: Number of output scores.
        hidden_dim: Width of the first hidden layer; the second is half as wide.
        dropout: Dropout probability after each hidden layer. Defaults to 0.3,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        in_channels: int = 1,
        img_size: int = 28,
        num_classes: int = 10,
        hidden_dim: int = 256,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(in_channels * img_size * img_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, x):
        """Return class logits for the batch ``x``.

        ``x`` is flattened from dim 1 onward, so presumably it has shape
        (batch, in_channels, img_size, img_size). Confirm against the loaders.
        """
        x = self.flatten(x)
        return self.fc(x)

# Verify that the custom model was added to the registry. Fail loudly instead
# of relying on someone reading the printed list.
registered_models = list_models()
print(f"Available models: {registered_models}")
assert "simple_mlp" in registered_models, "simple_mlp was not registered"

## Summary

This demo covered:
1. **Data Loading**: Using `create_dataloaders` for automatic train/val/test splitting
2. **Model Selection**: Using `get_model` to instantiate models from the registry
3. **Training**: Using `Trainer` class with wandb integration
4. **Evaluation**: Using `evaluate_model` for comprehensive metrics
5. **Visualization**: Using `plot_training_curves` and `plot_confusion_matrices`
6. **Hyperparameter Tuning**: Using `HyperparameterTuner` for grid/random search
7. **Custom Models**: Using `@register_model` decorator to add new models

Check your wandb dashboard to see the logged experiments!