# Getting Started with AI/ML Training

This notebook demonstrates how to use the AI_ML_Learning repository to train and evaluate a simple CNN image classifier, using MNIST as the example dataset.

## Contents
1. Environment Setup
2. Data Loading
3. Model Creation
4. Training Setup and Training
5. Visualize Training History
6. Model Evaluation
7. Next Steps

## 1. Environment Setup

In [None]:
import os
import random
import sys

# Make the repository root importable so `from src...` works. '..' is resolved
# against the kernel's current working directory, which is assumed to be the
# folder containing this notebook (one level below the repo root) — confirm
# for your checkout. Storing an absolute path keeps the entry valid even if the
# working directory changes later, and the membership check stops re-runs of
# this cell from appending duplicate entries.
REPO_ROOT = os.path.abspath('..')
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

from src.models import SimpleCNN, get_model
from src.training import Trainer, get_optimizer, get_scheduler
from src.data_utils import get_image_transforms, create_data_loaders

# Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs so weight
# initialization and DataLoader shuffling repeat across Restart & Run All.
# Some GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")

## 2. Data Loading

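Load the MNIST handwritten-digit dataset for demonstration. Note that the official MNIST test split (`train=False`) is used as the validation set throughout this notebook.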

In [None]:
# Download MNIST on first run (cached under DATA_ROOT afterwards) and load both
# splits. Each image is turned into a tensor by ToTensor and then standardized
# with the constants below, which are the commonly quoted MNIST mean/std.
DATA_ROOT = '../data/raw'
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)]
)

# NOTE(review): `val_dataset` is the official MNIST *test* split (train=False).
# It is later passed to the Trainer as the validation loader, so it is
# presumably used for model selection / early stopping and is therefore not an
# untouched hold-out set.
train_dataset, val_dataset = (
    datasets.MNIST(DATA_ROOT, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

### Visualize some samples

In [None]:
# Preview the first ten training images in a 2x5 grid, each titled with its label.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for sample_idx, panel in enumerate(axes.flat):
    digit, digit_label = train_dataset[sample_idx]
    # squeeze() drops size-1 dims (the single grayscale channel) so imshow gets a 2-D array
    panel.imshow(digit.squeeze(), cmap='gray')
    panel.set_title(f'Label: {digit_label}')
    panel.axis('off')
fig.tight_layout()
plt.show()

### Create data loaders

In [None]:
from torch.utils.data import DataLoader

batch_size = 64

# Options shared by both loaders. Pinned (page-locked) host memory speeds up
# host-to-GPU copies, so it is requested only when CUDA is present; newer
# PyTorch versions warn if pinning is requested without an accelerator.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

# Reshuffle the training data every epoch. Keep validation order fixed so the
# "first validation batch" used in the evaluation section is the same each run.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

## 3. Model Creation

Create a simple CNN for MNIST classification

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# 10 output classes (digits 0-9) and 1 input channel (grayscale images).
model = SimpleCNN(num_classes=10, input_channels=1).to(device)

# Count parameters in a single pass: every tensor counts toward the total,
# and only those with requires_grad=True count as trainable.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elems = param.numel()
    total_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems

print("\nModel Architecture:")
print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 4. Training Setup

In [None]:
# CrossEntropyLoss expects unnormalized logits, so SimpleCNN presumably ends
# without a softmax — confirm in src.models.
criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, 'adam', lr=0.001)
# By its name, this scheduler cuts the learning rate (factor 0.5) after
# `patience` epochs without improvement — confirm in src.training.get_scheduler.
scheduler = get_scheduler(optimizer, 'reduce_on_plateau', patience=3, factor=0.5)

# Everything the Trainer needs in one place; save_dir / experiment_name
# presumably determine where checkpoints are written.
trainer_config = {
    'model': model,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'criterion': criterion,
    'optimizer': optimizer,
    'device': device,
    'scheduler': scheduler,
    'save_dir': '../models/checkpoints',
    'experiment_name': 'mnist_cnn',
}
trainer = Trainer(**trainer_config)

print("Trainer initialized successfully!")

### Train the model

In [None]:
# Train for at most NUM_EPOCHS epochs. `early_stopping_patience` presumably
# ends training early after that many epochs without validation improvement —
# confirm in src.training.Trainer.fit.
# NOTE(review): this patience equals the LR scheduler's patience (3), so
# training may stop right when the learning rate would have been reduced.
# Consider a larger early-stopping patience for longer runs.
NUM_EPOCHS = 5
EARLY_STOPPING_PATIENCE = 3
history = trainer.fit(num_epochs=NUM_EPOCHS, early_stopping_patience=EARLY_STOPPING_PATIENCE)

## 5. Visualize Training History

In [None]:
# Plot the per-epoch learning curves returned by `trainer.fit`. Each history
# list is assumed to hold one value per completed epoch, with accuracies in
# percent to match the axis label — confirm against src.training.Trainer.fit.
from matplotlib.ticker import MaxNLocator


def _plot_split_curves(ax, train_values, val_values, metric_name, ylabel):
    """Draw train vs. validation curves for one metric on ``ax``.

    Epochs are numbered from 1, so the first point sits at "Epoch 1", and
    x-axis ticks are forced to whole epochs.
    """
    ax.plot(range(1, len(train_values) + 1), train_values,
            label=f'Train {metric_name}', marker='o')
    ax.plot(range(1, len(val_values) + 1), val_values,
            label=f'Val {metric_name}', marker='s')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(f'Training and Validation {metric_name}')
    ax.legend()
    ax.grid(True)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
_plot_split_curves(ax1, history['train_losses'], history['val_losses'], 'Loss', 'Loss')
_plot_split_curves(ax2, history['train_accuracies'], history['val_accuracies'],
                   'Accuracy', 'Accuracy (%)')
plt.tight_layout()
plt.show()

# The lowest loss and highest accuracy may come from different epochs, so report both (1-based).
best_loss_epoch = int(np.argmin(history['val_losses'])) + 1
best_acc_epoch = int(np.argmax(history['val_accuracies'])) + 1
print(f"\nBest Validation Loss: {min(history['val_losses']):.4f} (epoch {best_loss_epoch})")
print(f"Best Validation Accuracy: {max(history['val_accuracies']):.2f}% (epoch {best_acc_epoch})")

## 6. Model Evaluation

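Run the trained model on one batch from the validation loader (the MNIST test split) and compare its predictions with the true labels. Correct predictions are titled in green, mistakes in red.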

In [None]:
# Qualitative check: run the trained model on the first validation batch
# (val_loader does not shuffle, so this is the same batch on every run) and
# display up to ten predictions.
model.eval()  # put layers such as dropout / batch-norm (if any) in inference mode

test_images, test_labels = next(iter(val_loader))
test_images = test_images.to(device)

with torch.no_grad():
    outputs = model(test_images)
    # Predicted class = index of the highest score. Move to CPU once so the
    # result can be compared directly with test_labels, which was never moved.
    predictions = outputs.argmax(dim=1).cpu()

batch_accuracy = (predictions == test_labels).float().mean().item() * 100
print(f"Accuracy on this batch: {batch_accuracy:.2f}% ({len(test_labels)} images)")

images_cpu = test_images.cpu()
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= len(test_labels):  # batch smaller than the 2x5 grid: leave the panel blank
        continue
    pred = predictions[i].item()
    true = test_labels[i].item()
    ax.imshow(images_cpu[i].squeeze(), cmap='gray')
    color = 'green' if pred == true else 'red'
    ax.set_title(f'Pred: {pred}, True: {true}', color=color)

plt.tight_layout()
plt.show()

## Next Steps

1. Try different model architectures (ResNet, custom models)
2. Experiment with hyperparameters
3. Add data augmentation
4. Tune early stopping (patience, monitored metric). It is already enabled via `early_stopping_patience` in `trainer.fit`.
5. Use TensorBoard or Weights & Biases for experiment tracking
6. Try transfer learning with pretrained models