# HW01: Transfer Learning and Ensemble Methods for Medical Image Classification

**Course**: CSYE 7374 - Deep Learning and Generative AI in Healthcare

---

## Objectives

In this homework, you will apply the transfer learning and ensemble techniques covered in class to a **different medical imaging dataset**. You will:

1. Train **two** pretrained models on **BloodMNIST** (blood cell microscopy images)
2. Implement and compare **ensemble methods**
3. Analyze when and why ensemble learning outperforms individual models

---

## Dataset: BloodMNIST

**BloodMNIST** contains 17,092 microscopic images of blood cells categorized into **8 classes**:
- 0: basophil
- 1: eosinophil
- 2: erythroblast
- 3: immature granulocyte (ig)
- 4: lymphocyte
- 5: monocyte
- 6: neutrophil
- 7: platelet

---

## Instructions

- Complete all cells marked with **`# TODO`**
- Do not modify the provided helper functions unless instructed
- Run all cells in order
- Answer the analysis questions at the end

---

## Grading Rubric

| Task | Points |
|------|--------|
| Data loading and visualization | 10 |
| Data preprocessing (val_transform) | 5 |
| Model definition (TransferLearningModel class) | 15 |
| Loss function, optimizer, and scheduler | 10 |
| Training and validation functions | 15 |
| Model training and learning curves | 10 |
| Ensemble implementation (average + voting) | 10 |
| Ensemble evaluation on test set | 5 |
| Analysis questions (4 x 5 points) | 20 |
| **Total** | **100** |

---
## 1. Setup and Imports

In [None]:
# Install required packages (run once)
!pip install -q torch torchvision medmnist matplotlib seaborn scikit-learn tqdm pandas

In [None]:
import os
import random
import copy
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from collections import defaultdict

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Torchvision - transforms and pretrained models
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
# TODO: Import additional models as needed for your second model choice
# from torchvision.models import vgg19, VGG19_Weights
# from torchvision.models import googlenet, GoogLeNet_Weights
# from torchvision.models import inception_v3, Inception_V3_Weights

# MedMNIST dataset
import medmnist
from medmnist import INFO

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Metrics
from sklearn.metrics import (
    confusion_matrix, classification_report,
    accuracy_score, precision_recall_fscore_support
)

# Progress bar
from tqdm.notebook import tqdm

plt.style.use('seaborn-v0_8-whitegrid')
print(f"PyTorch: {torch.__version__}")
print(f"MedMNIST: {medmnist.__version__}")

---
## 2. Configuration

In [None]:
class Config:
    DATA_FLAG = 'bloodmnist'  # BloodMNIST dataset
    DOWNLOAD = True
    BATCH_SIZE = 32
    NUM_EPOCHS = 15  # Reduced for homework
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    IMG_SIZE = 224
    SEED = 42
    CHECKPOINT_DIR = './checkpoints_hw'

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
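    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # disable autotuning, which can pick non-deterministic kernels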

set_seed(Config.SEED)
os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---
## 3. Load and Explore Data

In [None]:
# Load dataset information
info = INFO[Config.DATA_FLAG]
n_channels = info['n_channels']
n_classes = len(info['label'])
class_names = list(info['label'].values())

print(f"Dataset: {Config.DATA_FLAG.upper()}")
print(f"Classes: {n_classes}")
print(f"Class names: {class_names}")

In [None]:
# Load raw datasets
DataClass = getattr(medmnist, info['python_class'])

train_dataset_raw = DataClass(split='train', download=Config.DOWNLOAD)

# ============================================================
# TODO: Load the validation and test datasets
# ============================================================
val_dataset_raw = None  # TODO
test_dataset_raw = None  # TODO
# ============================================================

print(f"\nTrain: {len(train_dataset_raw)} | Val: {len(val_dataset_raw)} | Test: {len(test_dataset_raw)}")

### Visualize Sample Images from Each Class

Complete the function below to display **3 sample images from each of the 8 classes** in a grid layout.
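
One way to approach the grid is to decide first which dataset indices to show. The sketch below is a plain-Python illustration rather than the required solution: `first_indices_per_class` is a hypothetical helper, and it assumes the labels have already been flattened into a list of integers.

```python
def first_indices_per_class(labels, n_classes, n_samples=3):
    """Return {class_id: [indices]} with the first n_samples indices of each class."""
    picked = {c: [] for c in range(n_classes)}
    for idx, label in enumerate(labels):
        if len(picked[label]) < n_samples:
            picked[label].append(idx)
        # Stop early once every class has enough samples.
        if all(len(v) == n_samples for v in picked.values()):
            break
    return picked


toy_labels = [2, 0, 1, 0, 2, 2, 1, 0, 1, 0]  # made-up labels for illustration
picked = first_indices_per_class(toy_labels, n_classes=3, n_samples=2)
print(picked)
```

Row `c`, column `j` of the grid would then display the image at index `picked[c][j]`.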

In [None]:
# ============================================================
# TODO: Visualize sample images from each class
# ============================================================

def visualize_samples_per_class(dataset, class_names, n_samples=3):
    """Visualize n_samples from each class."""
    n_classes = len(class_names)
    fig, axes = plt.subplots(n_classes, n_samples, figsize=(n_samples * 2, n_classes * 2))
    
    # TODO: Collect samples for each class and plot them
    
    pass  # Remove this line after completing TODO
    
    plt.suptitle('Sample Images from Each Class', fontsize=14)
    plt.tight_layout()
    plt.show()

# Call your function
# visualize_samples_per_class(train_dataset_raw, class_names, n_samples=3)

# ============================================================

In [None]:
# Class distribution
def get_class_distribution(dataset):
    labels = [label[0] for _, label in dataset]
    return np.bincount(labels, minlength=n_classes)

train_dist = get_class_distribution(train_dataset_raw)

plt.figure(figsize=(10, 4))
plt.bar(range(n_classes), train_dist, color=sns.color_palette("husl", n_classes))
plt.xlabel('Class')
plt.ylabel('Count')
plt.title('Training Set Class Distribution')
plt.xticks(range(n_classes), [f'{i}' for i in range(n_classes)])
plt.show()

print("Class distribution:")
for i, name in enumerate(class_names):
    print(f"  {i}: {name}: {train_dist[i]} ({train_dist[i]/train_dist.sum()*100:.1f}%)")

---
## 4. Data Preprocessing

In [None]:
# ImageNet normalization
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms (with data augmentation)
train_transform = transforms.Compose([
    transforms.Resize((Config.IMG_SIZE, Config.IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# ============================================================
# TODO: Complete the validation/test transform (no augmentation)
# ============================================================
val_transform = None  # TODO
# ============================================================

# Create datasets
train_dataset = DataClass(split='train', transform=train_transform, download=Config.DOWNLOAD)
val_dataset = DataClass(split='val', transform=val_transform, download=Config.DOWNLOAD)
test_dataset = DataClass(split='test', transform=val_transform, download=Config.DOWNLOAD)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Data loaders created: {len(train_loader)} train batches, {len(test_loader)} test batches")
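
To make the normalization step concrete, the following sketch reproduces in plain Python what `ToTensor` followed by `Normalize` does to a single RGB pixel: scale to [0, 1], then subtract the channel mean and divide by the channel standard deviation. `normalize_pixel` is a hypothetical helper written only for this illustration.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor (scale to [0, 1]) followed by Normalize for one RGB pixel."""
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(x - m) / s for x, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


for name, pixel in [("black", (0, 0, 0)), ("white", (255, 255, 255))]:
    print(name, [round(v, 3) for v in normalize_pixel(pixel)])
```

Both training and evaluation images need the same normalization, because the pretrained weights were fitted to inputs on this scale.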

---
## 5. Model Definitions

### Complete the TransferLearningModel class

You need to implement **two models**:
1. **ResNet50** (required)
2. **One model of your choice**

Complete the model initialization code below by modifying the final classification layer to output `num_classes` instead of 1000.
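
The idea of swapping the classification head can be sketched on a toy network. `TinyBackbone` and `replace_head` below are hypothetical stand-ins, not part of the assignment code; the only assumption carried over from torchvision is that ResNet-style models expose their final linear layer as `fc`. Other architectures name their head differently, so check the model definition before adapting this.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a pretrained network with a 1000-way `fc` head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        return self.fc(self.features(x))


def replace_head(model, num_classes):
    """Swap the final linear layer for a freshly initialized one with num_classes outputs."""
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model


model = replace_head(TinyBackbone(), num_classes=8)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

Only the new head starts from random weights; the feature extractor keeps whatever weights it had before the swap.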

In [None]:
class TransferLearningModel(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True):
        super(TransferLearningModel, self).__init__()
        self.model_name = model_name
        
        # ============================================================
        # TODO: Load pretrained model and modify the final layer
        # ============================================================
        
        if model_name == 'resnet50':
            weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            self.model = resnet50(weights=weights)
            # TODO: Modify the final layer
            
        # TODO: Add elif blocks for your second model choice
        # elif model_name == '...':
        #     ...
            
        else:
            raise ValueError(f"Unknown model: {model_name}")
        
        # ============================================================
    
    def forward(self, x):
        return self.model(x)

In [None]:
# ============================================================
# TODO: Create your two model instances
# Model 1 must be ResNet50
# Model 2 is your choice
# ============================================================

# Model 1: ResNet50 (required)
model_1_name = 'ResNet50'
model_1 = None  # TODO

# Model 2: Your choice
model_2_name = 'YourChoice'  # TODO: Change name
model_2 = None  # TODO

# ============================================================

# Store models in a dictionary
models_dict = {
    model_1_name: model_1,
    model_2_name: model_2
}

# Print model parameters
for name, model in models_dict.items():
    if model is not None:
        total = sum(p.numel() for p in model.parameters())
        print(f"{name}: {total:,} parameters")

---
## 6. Loss Function and Optimizers

In [None]:
# Class weights for imbalanced data
def calculate_class_weights(distribution):
    total = distribution.sum()
    weights = total / (len(distribution) * distribution)
    weights = weights / weights.sum() * len(weights)
    return torch.tensor(weights, dtype=torch.float32)  # preferred over the legacy FloatTensor constructor

class_weights = calculate_class_weights(train_dist)

# ============================================================
# TODO: Define the loss function with class weights
# ============================================================
criterion = None  # TODO
# ============================================================

# Create optimizers and schedulers for each model
optimizers = {}
schedulers = {}

for name, model in models_dict.items():
    if model is not None:
        # ============================================================
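        # TODO: Create optimizer and learning rate scheduler.
        # Note: train_model() calls scheduler.step(val_loss), so choose a
        # scheduler that accepts a metric (e.g. ReduceLROnPlateau).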
        # ============================================================
        optimizers[name] = None  # TODO
        schedulers[name] = None  # TODO
        # ============================================================

print("Loss and optimizers configured.")
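
The weighting formula in `calculate_class_weights` is easier to follow on small numbers. The sketch below repeats the same arithmetic in plain Python on made-up class counts; `inverse_frequency_weights` is a hypothetical name used only here.

```python
def inverse_frequency_weights(counts):
    """Inverse-frequency weights, rescaled so they sum to the number of classes."""
    total = sum(counts)
    n = len(counts)
    raw = [total / (n * c) for c in counts]  # rarer class -> larger weight
    scale = n / sum(raw)                     # rescale so the weights sum to n
    return [w * scale for w in raw]


toy_counts = [10, 30, 60]  # made-up class counts
weights = inverse_frequency_weights(toy_counts)
print([round(w, 3) for w in weights])
```

The rarest class receives the largest weight, so its errors contribute more to the loss than errors on the common classes.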

---
## 7. Training Functions

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, model_name):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    
    for images, labels in tqdm(train_loader, desc='Training', leave=False):
        images = images.to(device)
        labels = labels.view(-1).long().to(device)  # (B, 1) -> (B,); stays 1-D even for a batch of one
        
        # ============================================================
        # TODO: Complete the training step
        # (zero the gradients, forward pass, compute loss, backward pass, optimizer step)
        # ============================================================
        
        outputs = None  # TODO
        loss = None  # TODO
        
        # ============================================================
        
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss / total, 100. * correct / total

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels, all_probs = [], [], []
    
    # ============================================================
    # TODO: Complete the validation loop
    # ============================================================
    
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='Validating', leave=False):
            images = images.to(device)
            labels = labels.view(-1).long().to(device)  # (B, 1) -> (B,); stays 1-D even for a batch of one
            
            outputs = None  # TODO
            loss = None  # TODO
            probs = None  # TODO
            predicted = None  # TODO
            
            # Track statistics (do not modify)
            running_loss += loss.item() * images.size(0)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # ============================================================
    
    return running_loss / total, 100. * correct / total, np.array(all_preds), np.array(all_labels), np.array(all_probs)
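
Both functions accumulate `loss.item() * images.size(0)` and divide by the total sample count, rather than averaging the per-batch losses directly. The sketch below shows why on made-up numbers, assuming the criterion returns a per-batch mean; `epoch_mean_loss` is a hypothetical helper. With class weights the batch mean is itself a weighted mean, so treat this as an illustration of the bookkeeping, not an exact identity.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Size-weighted mean of per-batch mean losses."""
    running = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return running / sum(batch_sizes)


batch_losses = [0.5, 0.7, 1.0]  # made-up per-batch mean losses
batch_sizes = [32, 32, 8]       # the last batch is smaller
weighted = epoch_mean_loss(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)
print(round(weighted, 4), round(naive, 4))
```

The naive average lets the small final batch count as much as a full one, which skews the reported epoch loss.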

In [None]:
def train_model(model, model_name, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0.0
    best_weights = copy.deepcopy(model.state_dict())
    
    print(f"\n{'='*50}\nTraining {model_name}\n{'='*50}")
    
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, model_name)
        val_loss, val_acc, _, _, _ = validate(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f"Epoch {epoch+1}/{num_epochs} | Train: {train_acc:.1f}% | Val: {val_acc:.1f}%")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_weights = copy.deepcopy(model.state_dict())
    
    model.load_state_dict(best_weights)
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    return model, history
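
`train_model` passes the validation loss to `scheduler.step(val_loss)`, which is the calling convention of plateau-based schedulers such as PyTorch's `ReduceLROnPlateau`. The sketch below is a simplified plain-Python imitation of that idea (no threshold, cooldown, or minimum learning rate); `PlateauLR` and its default values are hypothetical and chosen only for illustration.

```python
class PlateauLR:
    """Simplified plateau rule: cut the learning rate after too many non-improving epochs."""

    def __init__(self, lr, factor=0.5, patience=2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only once the number of bad epochs exceeds the patience.
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


scheduler = PlateauLR(lr=1e-4)
val_losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]  # made-up validation losses
lrs = [scheduler.step(v) for v in val_losses]
print(lrs)
```

In this toy run the learning rate is halved at the fifth epoch, after three consecutive epochs fail to beat the best loss of 0.8.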

---
## 8. Train Models

Run training for both models. This may take 10-20 minutes per model depending on your hardware.

In [None]:
# ============================================================
# TODO: Train both models and store histories
# ============================================================

all_histories = {}

# TODO: Train each model in models_dict

# ============================================================

print("\nTraining complete!")

---
## 9. Learning Curves

In [None]:
# ============================================================
# TODO: Plot learning curves (loss and accuracy)
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# TODO: Plot validation loss and accuracy for each model

# ============================================================

plt.tight_layout()
plt.show()

---
## 10. Ensemble Methods

### Implement the Ensemble Class

Complete the `ensemble_average` and `ensemble_voting` methods.
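
The difference between the two strategies can be seen on a single toy sample. The sketch below uses plain Python lists instead of tensors, and every function name in it is hypothetical. The tie-breaking rule (fall back to the averaged probabilities) is one possible choice, not a prescribed one.

```python
from collections import Counter


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def average_probs(per_model_probs):
    """Average class probabilities across models for one sample."""
    n_models = len(per_model_probs)
    return [sum(col) / n_models for col in zip(*per_model_probs)]


def soft_vote(per_model_probs):
    """Probability averaging: argmax of the mean distribution."""
    return argmax(average_probs(per_model_probs))


def hard_vote(per_model_probs):
    """Majority vote over per-model argmax; ties go to the higher averaged probability."""
    votes = Counter(argmax(p) for p in per_model_probs)
    top = max(votes.values())
    tied = [c for c, v in votes.items() if v == top]
    avg = average_probs(per_model_probs)
    return max(tied, key=lambda c: avg[c])


# Three made-up models scoring one sample over three classes.
sample = [
    [0.90, 0.10, 0.00],  # very confident in class 0
    [0.40, 0.60, 0.00],  # mildly prefers class 1
    [0.45, 0.55, 0.00],  # mildly prefers class 1
]
print(soft_vote(sample), hard_vote(sample))
```

Averaging lets one confident model outweigh two hesitant ones, while voting counts each model equally. With only two models, any disagreement is a tie, so the tie-breaking rule decides the outcome.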

In [None]:
class EnsembleModel:
    def __init__(self, models_dict, device):
        self.models = models_dict
        self.device = device
    
    def predict_proba(self, images, model_name):
        """Get probability predictions from a single model."""
        model = self.models[model_name]
        model.eval()
        with torch.no_grad():
            outputs = model(images.to(self.device))  # make sure inputs are on the model's device
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            probs = F.softmax(outputs, dim=1)
        return probs
    
    def ensemble_average(self, images):
        """
        Ensemble by averaging probabilities from all models.
        
        Returns:
            predictions: argmax of averaged probabilities
            avg_probs: averaged probability distribution
        """
        # ============================================================
        # TODO: Implement probability averaging ensemble
        # ============================================================
        
        # TODO: Return predictions and averaged probabilities
        
        pass  # Remove this line after completing TODO
        # ============================================================
    
    def ensemble_voting(self, images):
        """
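        Ensemble by majority voting.

        With only two models, every disagreement is a tie, so define a
        tie-breaking rule (for example, fall back to the averaged probabilities).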
        
        Returns:
            votes: majority vote predictions
            avg_probs: averaged probabilities (for confidence)
        """
        # ============================================================
        # TODO: Implement majority voting ensemble
        # ============================================================
        
        # TODO: Return majority vote predictions and probabilities
        
        pass  # Remove this line after completing TODO
        # ============================================================

# Create ensemble
ensemble = EnsembleModel(models_dict, device)
print("Ensemble model created.")

---
## 11. Evaluate All Models on Test Set

### Complete the evaluation loop

In [None]:
# Evaluate individual models
individual_results = {}

for name, model in models_dict.items():
    if model is not None:
        _, test_acc, preds, labels, probs = validate(model, test_loader, criterion, device)
        individual_results[name] = {
            'accuracy': test_acc,
            'predictions': preds,
            'labels': labels,
            'probabilities': probs
        }
        print(f"{name} Test Accuracy: {test_acc:.2f}%")

In [None]:
# ============================================================
# TODO: Evaluate ensemble methods on test set
# ============================================================

ensemble_results = {}

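# TODO: Evaluate both 'average' and 'voting' ensemble methods.
# Store each result in ensemble_results under a key that contains 'Ensemble'
# (e.g. 'Ensemble (Average)'), as a dict with the same keys used in
# individual_results: 'accuracy' (in %), 'predictions', and 'labels'.
# The comparison cells below rely on these names.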

# ============================================================

---
## 12. Results Visualization

In [None]:
# Confusion matrix for best individual model
best_individual = max(individual_results.items(), key=lambda x: x[1]['accuracy'])
best_name, best_results = best_individual

cm = confusion_matrix(best_results['labels'], best_results['predictions'])
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title(f'Confusion Matrix - {best_name}')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Print classification reports
print("="*60)
print(f"Classification Report: {best_name}")
print("="*60)
print(classification_report(best_results['labels'], best_results['predictions'], 
                            target_names=class_names, digits=3))

In [None]:
# Compare all results
print("\n" + "="*60)
print("FINAL RESULTS COMPARISON")
print("="*60)

all_results = {**individual_results, **ensemble_results}

results_data = []
for name, res in all_results.items():
    model_type = 'Ensemble' if 'Ensemble' in name else 'Individual'
    precision, recall, f1, _ = precision_recall_fscore_support(
        res['labels'], res['predictions'], average='macro'
    )
    results_data.append({
        'Model': name,
        'Type': model_type,
        'Accuracy (%)': res['accuracy'],
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1
    })

results_df = pd.DataFrame(results_data).sort_values('Accuracy (%)', ascending=False)
print(results_df.to_string(index=False))
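
The table reports macro-averaged precision, recall, and F1 alongside accuracy. The sketch below recomputes macro averages in plain Python on made-up labels to show how they can differ from accuracy when classes are imbalanced. `macro_scores` is a hypothetical re-implementation for illustration, not the scikit-learn function.

```python
def macro_scores(y_true, y_pred, n_classes):
    """Unweighted mean of per-class precision, recall, and F1."""
    precisions, recalls, f1s = [], [], []
    for c in range(n_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = float(n_classes)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


y_true = [0, 0, 0, 0, 1, 1]  # made-up, imbalanced labels
y_pred = [0, 0, 0, 0, 0, 1]  # one minority-class sample is missed
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
macro_p, macro_r, macro_f1 = macro_scores(y_true, y_pred, n_classes=2)
print(round(accuracy, 3), round(macro_p, 3), round(macro_r, 3), round(macro_f1, 3))
```

A single error on the minority class barely moves accuracy but halves that class's recall, which the macro average exposes.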

In [None]:
# Bar chart comparison
from matplotlib.patches import Patch

plt.figure(figsize=(10, 6))
colors = ['#2196F3' if 'Ensemble' not in m else '#4CAF50' for m in results_df['Model']]
bars = plt.barh(results_df['Model'], results_df['Accuracy (%)'], color=colors)
plt.xlabel('Test Accuracy (%)')
plt.title('Model Performance Comparison')

for bar, acc in zip(bars, results_df['Accuracy (%)']):
    plt.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height()/2,
             f'{acc:.1f}%', va='center')

plt.legend(handles=[Patch(color='#2196F3', label='Individual'),
                    Patch(color='#4CAF50', label='Ensemble')])
plt.tight_layout()
plt.show()

---
## 13. Analysis Questions (20 points)

Answer the following questions based on your results. Write your answers in the markdown cells below each question.

### Question 1 (5 points)

Compare the test accuracy of your two individual models. Which model performed better? Provide at least **two possible reasons** why one model outperformed the other, considering factors like model architecture, number of parameters, or training dynamics.

**Your Answer:**

*[Write your answer here]*

### Question 2 (5 points)

Did the ensemble methods outperform the individual models? Compare the **average probability** ensemble vs **majority voting** ensemble. Which performed better and why might this be the case?

**Your Answer:**

*[Write your answer here]*

### Question 3 (5 points)

Looking at the confusion matrix and per-class metrics, which blood cell types were **most difficult to classify**? Propose one reason why these classes might be challenging and one potential solution to improve classification for these classes.

**Your Answer:**

*[Write your answer here]*

### Question 4 (5 points)

In this homework, we used only **two models** for the ensemble. Based on the class material (which used four models), do you think adding more models would improve ensemble performance? What are the **trade-offs** of using more models in an ensemble?

**Your Answer:**

*[Write your answer here]*

---
## Submission Instructions

1. Run all cells and ensure there are no errors
2. Rename this notebook as: **`FirstName_LastName_HW01.ipynb`**
3. Download the notebook and submit it to the course portal by the deadline

**Example filename:** `John_Smith_HW01.ipynb`

**Make sure your notebook shows:**
- All TODOs completed
- Training output for both models
- All visualizations rendered
- Analysis questions answered