# PyTorch Transfer Learning Tutorial

This notebook demonstrates transfer learning using PyTorch with pre-trained models.

## 1. Setup and Imports

In [None]:
import sys
sys.path.append('..')

import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from src.pytorch_transfer import PyTorchTransferModel
from src.gradcam import GradCAM
from src.utils import prepare_pytorch_dataset, plot_training_history

# Report the installed PyTorch version and whether torch can see a CUDA device.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Load Pre-trained Model

In [None]:
# Build the ResNet-50 wrapper with pretrained=True and a 1000-way output
# (the class count of the ImageNet label set).
model = PyTorchTransferModel(model_name='resnet50', num_classes=1000, pretrained=True)

# Confirm which architecture was loaded and which device the wrapper reports.
print(f"Model loaded: {model.model_name}")
print(f"Device: {model.device}")

## 3. Test Prediction on Sample Image

In [None]:
# Load a sample image
image_path = '../examples/cat.jpg'  # Replace with your image

# Image.open() is lazy and keeps the underlying file open; the context manager
# closes it, while convert('RGB') reads the pixels into a new, independent image.
# Forcing RGB also lets grayscale, palette, and RGBA files work with a model
# that expects 3-channel input.
with Image.open(image_path) as image_file:
    image = image_file.convert('RGB')

# Display image
plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.axis('off')
plt.title('Input Image')
plt.show()

# Make prediction
# NOTE(review): assumes predict(..., top_k=1) returns a single
# (label, confidence) pair, while top_k > 1 returns a sequence of such
# pairs -- confirm against src.pytorch_transfer.
prediction, confidence = model.predict(image, top_k=1)
print(f"\nPrediction: {prediction}")
print(f"Confidence: {confidence:.2%}")

# Get top-5 predictions
top5 = model.predict(image, top_k=5)
print("\nTop-5 Predictions:")
for i, (pred, conf) in enumerate(top5, 1):
    print(f"{i}. {pred}: {conf:.2%}")

## 4. Grad-CAM Visualization

In [None]:
# Wrap the underlying network in the Grad-CAM helper (PyTorch backend), then
# compute the heatmap for the sample image and blend it over that image.
gradcam = GradCAM(model.model, framework='pytorch')
heatmap = gradcam.generate_heatmap(image)
overlay = gradcam.overlay_heatmap(image, heatmap)

# One entry per panel: (pixels, title, colormap).
# A colormap of None leaves imshow on its default behavior.
panels = [
    (image, 'Original Image', None),
    (heatmap, 'Grad-CAM Heatmap', 'jet'),
    (overlay, f'Overlay: {prediction}', None),
]

# Draw the three panels side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for panel_ax, (panel_data, panel_title, panel_cmap) in zip(axes, panels):
    panel_ax.imshow(panel_data, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Transfer Learning on Custom Dataset

### 5.1 Prepare Data

In [None]:
# Dataset location and mini-batch size used to build the loaders.
data_dir, batch_size = '../data', 32  # Point data_dir at your dataset directory

# prepare_pytorch_dataset hands back the training and validation loaders;
# images are requested at img_size=224.
train_loader, val_loader = prepare_pytorch_dataset(data_dir=data_dir, batch_size=batch_size, img_size=224)

# Summarize how many batches each loader yields and which classes were found.
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Number of classes: {len(train_loader.dataset.classes)}")
print(f"Classes: {train_loader.dataset.classes}")

### 5.2 Visualize Training Data

In [None]:
# Get a batch of training data
images, labels = next(iter(train_loader))

# Denormalize images for visualization
# NOTE(review): assumes prepare_pytorch_dataset normalizes with the ImageNet
# mean/std below -- confirm against src.utils.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
images_denorm = images * std + mean

class_names = train_loader.dataset.classes

# Show images in a 2x4 grid
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

# A batch can hold fewer images than the grid has panels (small batch_size or
# a tiny dataset), so never index past the end of the batch.
num_shown = min(len(axes), images_denorm.size(0))

for i in range(num_shown):
    # Move channels last (C, H, W -> H, W, C) for imshow and clamp to [0, 1].
    img = images_denorm[i].permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    axes[i].imshow(img)
    axes[i].set_title(class_names[int(labels[i])])
    axes[i].axis('off')

# Blank out any panels that did not receive an image.
for unused_ax in axes[num_shown:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

### 5.3 Initialize Model for Custom Dataset

In [None]:
# One output per class reported by the training dataset.
num_classes = len(train_loader.dataset.classes)

# ResNet-50 wrapper with pretrained=True, sized for the custom label set.
custom_model = PyTorchTransferModel(pretrained=True, model_name='resnet50', num_classes=num_classes)

# Ask the wrapper to freeze its layers (freeze_all=True) before training.
custom_model.freeze_layers(freeze_all=True)

# Report the setup, including how many parameters still require gradients.
print(f"Model architecture: {custom_model.model_name}")
print(f"Number of classes: {num_classes}")
print(f"Trainable parameters: {sum(p.numel() for p in custom_model.model.parameters() if p.requires_grad):,}")

### 5.4 Train Model

In [None]:
# First training run: 10 epochs at lr=0.001, with the checkpoint path handed
# to train_model as save_path.
stage1_settings = dict(
    epochs=10,
    lr=0.001,
    save_path='../models/custom_model.pth',
)
custom_model.train_model(train_loader=train_loader, val_loader=val_loader, **stage1_settings)

### 5.5 Fine-tuning

In [None]:
# Re-enable gradients on every parameter so the whole network can be updated.
for weight in custom_model.model.parameters():
    weight.requires_grad_(True)

# Second training run: 5 epochs at lr=0.0001, checkpointed to its own file.
custom_model.train_model(
    epochs=5,
    lr=0.0001,
    train_loader=train_loader,
    val_loader=val_loader,
    save_path='../models/custom_model_finetuned.pth',
)

## 6. Evaluate Model

In [None]:
# Load the checkpoint written by the fine-tuning run.
# Loading '../models/custom_model.pth' here would silently roll the in-memory
# model back to the pre-fine-tuning weights, and those weights would then be
# the ones evaluated below and saved/exported later in the notebook.
custom_model.load('../models/custom_model_finetuned.pth')

# Evaluate on validation set
custom_model.model.eval()
correct = 0
total = 0

with torch.no_grad():
    # Batch-local names keep `images` / `labels` from the visualization cell intact.
    for batch_images, batch_labels in val_loader:
        batch_images = batch_images.to(custom_model.device)
        batch_labels = batch_labels.to(custom_model.device)

        outputs = custom_model.model(batch_images)
        # Index of the highest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)

        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

# Guard against an empty validation loader instead of dividing by zero.
if total == 0:
    accuracy = float('nan')
    print("Validation set is empty; accuracy is undefined.")
else:
    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

## 7. Compare Different Architectures

In [None]:
# Test different architectures
# NOTE(review): these names must match what PyTorchTransferModel accepts -- confirm.
architectures = ['resnet50', 'vgg16', 'efficientnetb0', 'mobilenetv2']
results = {}

for arch in architectures:
    print(f"\nTesting {arch}...")

    model_test = PyTorchTransferModel(
        model_name=arch,
        num_classes=1000,
        pretrained=True
    )

    # Use loop-local names: `prediction` / `confidence` belong to the sample
    # prediction cell, and the Grad-CAM cell reads `prediction` for its title.
    # Overwriting them here would mislabel that figure on a re-run.
    # top_k=1 matches the call in the sample prediction cell, which unpacks a single pair.
    arch_prediction, arch_confidence = model_test.predict(image, top_k=1)
    results[arch] = (arch_prediction, arch_confidence)
    print(f"{arch}: {arch_prediction} ({arch_confidence:.2%})")

    # Drop this model before building the next one so the networks do not
    # all stay resident in memory at once.
    del model_test
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Visualize results: one horizontal bar per architecture, sized by confidence.
fig, ax = plt.subplots(figsize=(10, 6))
arch_names = list(results.keys())
confidences = [results[name][1] for name in arch_names]

ax.barh(arch_names, confidences)
ax.set_xlabel('Confidence')
ax.set_title('Model Comparison')
plt.tight_layout()
plt.show()

## 8. Save and Export Model

In [None]:
# Save model
custom_model.save('../models/final_model.pth')

# Export to TorchScript for production.
# Tracing records the operations executed for the example input in the model's
# current mode, so switch to eval() first; otherwise a model left in training
# mode would be traced with training-mode behavior.
custom_model.model.eval()
example_input = torch.randn(1, 3, 224, 224).to(custom_model.device)

# No gradients are needed to record the forward pass.
with torch.no_grad():
    traced_script = torch.jit.trace(custom_model.model, example_input)
traced_script.save('../models/final_model_scripted.pt')

print("Model saved successfully!")

## Conclusion

In this tutorial, you learned how to:
- Load pre-trained PyTorch models
- Make predictions with a pre-trained model
- Generate Grad-CAM visualizations
- Fine-tune models on custom datasets
- Compare different architectures
- Export models for production

For more examples, check out the other notebooks!