<a href="https://colab.research.google.com/github/joshpack3/EE4745_Final_Proj/blob/main/project_code_A.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import os
import time
import copy
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import scipy.ndimage # Added for mask resizing


# =========================
# Config
# =========================
class Config:
    """Hyper-parameters and paths for the Problem A experiments.

    Every setting is a class attribute, so ``Config.lr`` and ``Config().lr``
    refer to the same value.
    """

    # ---- dataset ----------------------------------------------------------
    data_root: str = "data/Sports"  # must contain data_root/train and data_root/valid
    image_size: int = 64            # square input resolution (32 or 64)
    batch_size: int = 64
    num_workers: int = 4
    num_classes: int = 10

    # ---- optimisation -----------------------------------------------------
    num_epochs: int = 10
    lr: float = 1e-3
    weight_decay: float = 1e-4

    # Prefer the GPU whenever PyTorch can see one.
    device: torch.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    # ---- outputs (checkpoints, TensorBoard runs, images) ------------------
    out_dir: str = "outputs_problem_a"
    model_a_name: str = "simple_cnn"
    model_b_name: str = "small_resnet"
    run_name_a: str = "run_model_A_simple_cnn"
    run_name_b: str = "run_model_B_small_resnet"


# Global configuration instance used by main(); the output directory is
# created eagerly so checkpoints, logs and images always have a destination.
cfg = Config()
Path(cfg.out_dir).mkdir(parents=True, exist_ok=True)


# =========================
# Data loaders
# =========================
def get_dataloaders(cfg: Config):
    """Build the training and validation DataLoaders from an ImageFolder tree.

    ``cfg.data_root`` must contain a ``train`` directory and a validation
    directory named ``valid`` (``val`` is used as a fallback when ``valid`` is
    absent), each holding one sub-directory per class.

    Args:
        cfg: run configuration; reads ``data_root``, ``image_size``,
            ``batch_size``, ``num_workers`` and ``device``.

    Returns:
        ``(train_loader, val_loader, class_names)`` where ``class_names`` is
        the class list of the training split.

    Raises:
        ValueError: if the training and validation splits contain different
            class folders. ImageFolder derives label indices from the folder
            list, so a mismatch would silently misalign the labels.
    """
    # ImageNet channel statistics, shared by both pipelines.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    train_dir = os.path.join(cfg.data_root, "train")
    val_dir = os.path.join(cfg.data_root, "valid")
    if not os.path.isdir(val_dir):
        # Some copies of the dataset name the validation split "val".
        val_dir = os.path.join(cfg.data_root, "val")

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

    if val_dataset.classes != train_dataset.classes:
        raise ValueError(
            f"Class folders differ between {train_dir} and {val_dir}: "
            f"{train_dataset.classes} vs {val_dataset.classes}"
        )

    # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
    pin_memory = cfg.device.type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
    )

    class_names = train_dataset.classes
    return train_loader, val_loader, class_names


# =========================
# Model A - Simple CNN
# =========================
class SimpleCNN(nn.Module):
    """Plain four-convolution CNN classifier (Model A).

    Three 2x2 max-pools shrink each spatial side by a factor of 8, so the
    first Linear layer of the classifier takes ``256 * (image_size // 8) ** 2``
    features. ``image_size`` must therefore match the square resolution of
    the images the model is fed.

    Args:
        num_classes: number of output logits.
        image_size: side length of the square input images. The default of 64
            reproduces the original ``256 * 8 * 8`` classifier input.

    Raises:
        ValueError: if ``image_size`` is smaller than 8, because the feature
            map would vanish after the three poolings.
    """

    def __init__(self, num_classes=10, image_size=64):
        super().__init__()
        if image_size < 8:
            raise ValueError(f"image_size must be at least 8, got {image_size}")
        # Side length after three MaxPool2d(2); floor division matches pooling.
        feat_side = image_size // 8

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.MaxPool2d(2),  # H -> H/2

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.MaxPool2d(2),  # H/2 -> H/4

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.MaxPool2d(2),  # H/4 -> H/8
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * feat_side * feat_side, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images *x*."""
        x = self.features(x)
        x = self.classifier(x)
        return x


# =========================
# Model B - Small ResNet style
# =========================
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv + BN layers plus an additive shortcut.

    If the block changes resolution (``stride != 1``) or channel count, the
    shortcut is a strided 1x1 conv + BN projection so both branches have the
    same shape. Otherwise the input is added back unchanged
    (``self.downsample`` is ``None``).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main branch: conv(stride) -> BN -> ReLU -> conv -> BN
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(residual + shortcut)


class SmallResNet(nn.Module):
    """Compact ResNet-style classifier (Model B).

    A stride-1 3x3 stem is followed by three stride-2 ``BasicBlock`` stages
    (32 -> 64 -> 128 -> 256 channels), global average pooling and one Linear
    layer. Because of the adaptive pooling, any input resolution is accepted.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        # Each stage halves the spatial size and doubles the channel count.
        self.layer1 = BasicBlock(32, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 256, stride=2)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images *x*."""
        feats = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        pooled = torch.flatten(self.pool(feats), 1)
        return self.fc(pooled)


# =========================
# Utility
# =========================
def count_parameters(model):
    """Return the number of scalar weights in *model* that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# =========================
# Train and eval loops
# =========================
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over *dataloader* with *model* in train mode.

    The loss is re-weighted by batch size before averaging, which assumes
    *criterion* returns a per-batch mean.

    Returns:
        ``(mean_loss, accuracy, seconds)``: the per-sample mean loss, the
        fraction of correctly classified samples, and the wall-clock
        duration of the pass.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    tic = time.time()

    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        predicted = logits.max(1).indices
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_correct += predicted.eq(batch_y).sum().item()
        n_seen += batch_y.size(0)

    elapsed = time.time() - tic
    return loss_sum / n_seen, n_correct / n_seen, elapsed


def evaluate(model, dataloader, criterion, device, num_classes, class_names=None):
    """Evaluate *model* on *dataloader* in eval mode with gradients disabled.

    Args:
        model: classifier whose outputs are ``(batch, num_logits)`` scores.
        dataloader: iterable of ``(inputs, targets)`` batches, where targets
            are integer class indices.
        criterion: loss function. Its value is re-weighted by batch size,
            which assumes a per-batch mean reduction.
        device: device the batches are moved to.
        num_classes: length of the per-class accuracy vector.
        class_names: if given, per-class accuracy is printed under these names.

    Returns:
        ``(mean_loss, accuracy, per_class_acc)``, where ``per_class_acc`` is
        a NumPy array of length ``num_classes``. Classes with no samples
        report 0.

    Raises:
        ValueError: if *dataloader* yields no samples.
        IndexError: if a target label is ``>= num_classes``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    per_class_correct = np.zeros(num_classes, dtype=np.int64)
    per_class_total = np.zeros(num_classes, dtype=np.int64)

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            hits = preds.eq(targets)

            correct += hits.sum().item()
            total += targets.size(0)

            # Vectorised per-class tallies: one host copy per batch instead of
            # a Python loop that forces a device sync on every sample.
            batch_total = torch.bincount(targets, minlength=num_classes)
            if batch_total.numel() > num_classes:
                raise IndexError(
                    f"target label {int(targets.max())} out of range for "
                    f"num_classes={num_classes}"
                )
            batch_correct = torch.bincount(targets[hits], minlength=num_classes)
            per_class_total += batch_total.cpu().numpy()
            per_class_correct += batch_correct.cpu().numpy()

    if total == 0:
        raise ValueError("dataloader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    # np.maximum(..., 1) avoids 0/0 for classes absent from this split.
    per_class_acc = per_class_correct / np.maximum(per_class_total, 1)

    if class_names is not None:
        print("Per class accuracy:")
        for idx, name in enumerate(class_names):
            print(f"{name}: {per_class_acc[idx] * 100:.2f}% "
                  f"({per_class_correct[idx]}/{per_class_total[idx]})")

    return epoch_loss, epoch_acc, per_class_acc


# =========================
# Saliency maps
# =========================
def compute_saliency_map(model, image_tensor, label, device):
    """Compute a vanilla-gradient saliency map for a single image.

    The map is ``|d output[label] / d input|``, max-reduced over the channel
    axis and min-max scaled to ``[0, 1]``.

    Args:
        model: classifier returning ``(1, num_classes)`` scores for a batch of
            one image.
        image_tensor: one un-batched image, ``(C, H, W)``.
        label: index of the class score to differentiate.
        device: device on which the forward/backward pass runs.

    Returns:
        NumPy array of shape ``(H, W)`` with values in ``[0, 1]``.
    """
    model.eval()
    # detach() + clone() yields a fresh leaf tensor, so .grad is populated even
    # when the caller's tensor is itself part of an autograd graph.
    image = image_tensor.detach().unsqueeze(0).to(device).clone().requires_grad_(True)

    # Gradients are required here, even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        output = model(image)
        target_score = output[0, label]
        model.zero_grad()
        target_score.backward()

    saliency = image.grad.abs().max(dim=1).values  # (1, H, W)
    saliency = saliency.squeeze(0).cpu().numpy()   # (H, W)
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    return saliency


# =========================
# Grad CAM
# =========================
class GradCAM:
    """Grad-CAM heat-map generator bound to one layer of *model*.

    On construction it registers a forward hook, which stores the layer's
    output, and a full backward hook, which stores the gradient w.r.t. that
    output. Call ``close()``, or use the instance as a context manager, to
    remove the hooks.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

        self.activations = None  # last forward output of target_layer (detached)
        self.gradients = None    # last gradient w.r.t. that output (detached)

        self.hook_a = target_layer.register_forward_hook(self.save_activation)
        self.hook_g = target_layer.register_full_backward_hook(self.save_gradient)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient flowing into the layer output."""
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, target_class=None):
        """Return a Grad-CAM map for one un-batched image.

        Args:
            input_tensor: single image, ``(C, H, W)``.
            target_class: class index to explain. Defaults to the model's
                top-scoring class.

        Returns:
            NumPy array ``(h, w)`` at the target layer's spatial resolution,
            ReLU-clipped and min-max scaled to ``[0, 1]``.

        Raises:
            RuntimeError: if the hooks did not fire, i.e. ``target_layer``
                was not used in the forward pass.
        """
        self.model.eval()
        # Clear results from any previous call so stale tensors are never reused.
        self.activations = None
        self.gradients = None

        device = next(self.model.parameters()).device
        input_tensor = input_tensor.detach().unsqueeze(0).to(device).clone().requires_grad_(True)

        # Gradients are required here, even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            if target_class is None:
                target_class = output.argmax(dim=1).item()

            score = output[0, target_class]
            self.model.zero_grad()
            score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; is target_layer used in the model's forward pass?"
            )

        # Channel weights: gradients averaged over the spatial axes -> (1, C, 1, 1).
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        # Weighted sum over channels, then drop the batch axis -> (H, W).
        cam = (weights * self.activations).sum(dim=1).squeeze(0)

        cam = np.maximum(cam.cpu().numpy(), 0)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam

    def close(self):
        """Remove both hooks from the target layer."""
        self.hook_a.remove()
        self.hook_g.remove()


def show_and_save_heatmap(image_tensor, mask, out_path, alpha=0.4):
    """Blend a ``[0, 1]`` heat-map over an image and save the result.

    Args:
        image_tensor: ``(3, H, W)`` image normalised with the ImageNet
            mean/std used by the data pipeline. It is un-normalised here for
            display.
        mask: 2-D array of values in ``[0, 1]``. It is bilinearly resized to
            ``(H, W)`` when its shape differs.
        out_path: destination file; missing parent directories are created.
        alpha: weight of the heat-map in the blend (the image gets
            ``1 - alpha``).
    """
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    # detach() so tensors that still carry autograd history can be converted.
    img = image_tensor.detach().cpu().numpy()
    img = np.clip(std * img + mean, 0, 1)
    img = np.transpose(img, (1, 2, 0))  # (H, W, C)

    mask = np.asarray(mask)
    img_h, img_w = img.shape[:2]
    mask_h, mask_w = mask.shape
    if (mask_h, mask_w) != (img_h, img_w):
        # order=1 -> bilinear interpolation
        mask = scipy.ndimage.zoom(mask, (img_h / mask_h, img_w / mask_w), order=1)

    heatmap = plt.cm.jet(mask)[..., :3]
    overlay = np.clip(alpha * heatmap + (1 - alpha) * img, 0, 1)

    out_dir = os.path.dirname(out_path)
    if out_dir:  # os.makedirs("") raises, e.g. for a bare filename in the CWD
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.axis("off")
        ax.imshow(overlay)
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure so long loops do not accumulate memory.
        plt.close(fig)


# =========================
# Training wrapper
# =========================
def train_and_eval_model(model, model_name, run_name, cfg, train_loader, val_loader, class_names):
    """Train *model* for ``cfg.num_epochs`` epochs and restore its best weights.

    Per-epoch train/val loss, accuracy and training time are logged to
    TensorBoard under ``cfg.out_dir/run_name``. Whenever validation accuracy
    improves, the weights are saved to ``{model_name}-best-original.pt``.
    After training, the best weights are loaded back and saved as
    ``{model_name}-final-original.pt``. They are then evaluated once more with
    per-class accuracy printed.

    Returns:
        ``(model, per_class_acc)``: the model holding the best weights, and
        its per-class validation accuracy from ``evaluate``.
    """
    device = cfg.device
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    writer = SummaryWriter(log_dir=os.path.join(cfg.out_dir, run_name))
    try:
        best_val_acc = 0.0
        # deepcopy: state_dict() tensors alias the live parameters.
        best_model_wts = copy.deepcopy(model.state_dict())
        best_ckpt_path = os.path.join(cfg.out_dir, f"{model_name}-best-original.pt")

        num_params = count_parameters(model)
        print(f"Model {model_name} param count: {num_params}")

        for epoch in range(cfg.num_epochs):
            print(f"Epoch {epoch + 1}/{cfg.num_epochs}")

            train_loss, train_acc, train_time = train_one_epoch(
                model, train_loader, criterion, optimizer, device
            )
            val_loss, val_acc, _ = evaluate(
                model, val_loader, criterion, device,
                cfg.num_classes, class_names=None
            )

            print(f"Train loss {train_loss:.4f} acc {train_acc:.4f} "
                  f"time {train_time:.2f}s")
            print(f"Val   loss {val_loss:.4f} acc {val_acc:.4f}")

            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Acc/train", train_acc, epoch)
            writer.add_scalar("Acc/val", val_acc, epoch)
            writer.add_scalar("Time/train_epoch_sec", train_time, epoch)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, best_ckpt_path)
                print(f"Saved best checkpoint to {best_ckpt_path}")
    finally:
        # Flush and close the TensorBoard writer even if training is interrupted.
        writer.close()

    model.load_state_dict(best_model_wts)
    final_ckpt_path = os.path.join(
        cfg.out_dir, f"{model_name}-final-original.pt"
    )
    torch.save(model.state_dict(), final_ckpt_path)
    print(f"Saved final model to {final_ckpt_path}")

    # Re-evaluate the restored best weights, this time with per-class output.
    _, val_acc, per_class_acc = evaluate(
        model, val_loader, criterion, device,
        cfg.num_classes, class_names=class_names
    )

    print(f"Best val accuracy for {model_name}: {best_val_acc:.4f}")
    print(f"Final val accuracy for {model_name}: {val_acc:.4f}")

    return model, per_class_acc


# =========================
# Run interpretability
# =========================
def run_interpretability(model, cfg, class_names, val_loader, model_name, max_images=10):
    """Save saliency and Grad-CAM images for the first *max_images* validation samples.

    Grad-CAM targets the last ``nn.Conv2d`` in ``SimpleCNN.features``, or the
    final residual block (``layer3``) of a ``SmallResNet``. Both maps explain
    the ground-truth label. Images are written to
    ``cfg.out_dir/{model_name}_interpretability`` as
    ``saliency_{i}_{label}.png`` and ``gradcam_{i}_{label}.png``.

    Raises:
        ValueError: if *model* is neither a SimpleCNN nor a SmallResNet.
    """
    from itertools import islice

    device = cfg.device
    model.eval()
    model.to(device)

    if isinstance(model, SimpleCNN):
        # Locate the layer by type rather than position: features[-3] is the
        # final BatchNorm2d, not the last convolution.
        target_layer = [m for m in model.features if isinstance(m, nn.Conv2d)][-1]
    elif isinstance(model, SmallResNet):
        target_layer = model.layer3  # last residual block
    else:
        raise ValueError("Unknown model type for Grad CAM")

    out_dir = Path(cfg.out_dir) / f"{model_name}_interpretability"
    out_dir.mkdir(parents=True, exist_ok=True)

    images_done = 0
    grad_cam = GradCAM(model, target_layer)
    try:
        # Both methods need gradients, so this loop must not run under torch.no_grad().
        for img, target in islice(_iter_samples(val_loader), max_images):
            label = target.item()
            label_name = class_names[label]

            sal_map = compute_saliency_map(model, img, label, device)
            cam = grad_cam(img, target_class=label)

            sal_out_path = out_dir / f"saliency_{images_done}_{label_name}.png"
            fig, ax = plt.subplots(figsize=(4, 4))
            try:
                ax.axis("off")
                ax.imshow(sal_map, cmap="hot")
                fig.tight_layout()
                fig.savefig(sal_out_path, bbox_inches="tight", pad_inches=0)
            finally:
                plt.close(fig)

            cam_out_path = out_dir / f"gradcam_{images_done}_{label_name}.png"
            show_and_save_heatmap(img, cam, str(cam_out_path))

            images_done += 1
    finally:
        # Always remove the hooks so they do not outlive this function.
        grad_cam.close()

    print(f"Saved {images_done} interpretability images for {model_name}")


def _iter_samples(loader):
    """Yield ``(image, target)`` pairs one sample at a time from a batched loader."""
    for inputs, targets in loader:
        yield from zip(inputs, targets)


# =========================
# Main
# =========================
def main():
    """Train, evaluate and visualise Model A and then Model B on the same data."""
    train_loader, val_loader, class_names = get_dataloaders(cfg)

    # (banner, model class, checkpoint name, TensorBoard run name)
    experiments = (
        ("Training Model A - SimpleCNN", SimpleCNN, cfg.model_a_name, cfg.run_name_a),
        ("\nTraining Model B - SmallResNet", SmallResNet, cfg.model_b_name, cfg.run_name_b),
    )

    for banner, model_cls, model_name, run_name in experiments:
        print(banner)
        net = model_cls(num_classes=cfg.num_classes)
        net, _per_class_acc = train_and_eval_model(
            net, model_name, run_name, cfg,
            train_loader, val_loader, class_names,
        )
        run_interpretability(net, cfg, class_names, val_loader, model_name)

    print("Done")


if __name__ == "__main__":
    main()

Training Model A - SimpleCNN
Model simple_cnn param count: 4586506
Epoch 1/10




Train loss 4.0491 acc 0.2241 time 59.64s
Val   loss 1.9434 acc 0.3800
Saved best checkpoint to outputs_problem_a/simple_cnn-best-original.pt
Epoch 2/10
Train loss 1.8904 acc 0.3315 time 59.16s
Val   loss 1.5749 acc 0.4600
Saved best checkpoint to outputs_problem_a/simple_cnn-best-original.pt
Epoch 3/10
Train loss 1.7181 acc 0.3810 time 60.55s
Val   loss 1.4824 acc 0.4800
Saved best checkpoint to outputs_problem_a/simple_cnn-best-original.pt
Epoch 4/10
Train loss 1.6580 acc 0.4244 time 60.56s
Val   loss 1.5847 acc 0.4200
Epoch 5/10
Train loss 1.5660 acc 0.4338 time 67.51s
Val   loss 1.3349 acc 0.4800
Epoch 6/10
Train loss 1.5080 acc 0.4721 time 60.28s
Val   loss 1.3385 acc 0.5400
Saved best checkpoint to outputs_problem_a/simple_cnn-best-original.pt
Epoch 7/10
Train loss 1.4211 acc 0.4765 time 61.55s
Val   loss 1.3122 acc 0.4600
Epoch 8/10
Train loss 1.4037 acc 0.5116 time 61.01s
Val   loss 1.2654 acc 0.5800
Saved best checkpoint to outputs_problem_a/simple_cnn-best-original.pt
Epoch 9/

In [3]:
# Colab-only setup cell. Lines starting with "!" are IPython shell escapes,
# not Python. It uploads the dataset archive from the local machine and
# extracts it under data/, so that Config.data_root ("data/Sports") exists.
from google.colab import files
uploaded = files.upload()  # select Sports.zip (the unzip below expects this exact, case-sensitive name)
!unzip Sports.zip -d data/
!ls data

Saving Sports.zip to Sports.zip
Archive:  Sports.zip
   creating: data/Sports/
   creating: data/Sports/valid/
  inflating: data/__MACOSX/Sports/._valid  
  inflating: data/Sports/.DS_Store   
  inflating: data/__MACOSX/Sports/._.DS_Store  
   creating: data/Sports/train/
  inflating: data/__MACOSX/Sports/._train  
   creating: data/Sports/valid/golf/
  inflating: data/__MACOSX/Sports/valid/._golf  
   creating: data/Sports/valid/weightlifting/
  inflating: data/__MACOSX/Sports/valid/._weightlifting  
  inflating: data/Sports/valid/.DS_Store  
  inflating: data/__MACOSX/Sports/valid/._.DS_Store  
   creating: data/Sports/valid/football/
  inflating: data/__MACOSX/Sports/valid/._football  
   creating: data/Sports/valid/baseball/
  inflating: data/__MACOSX/Sports/valid/._baseball  
   creating: data/Sports/valid/basketball/
  inflating: data/__MACOSX/Sports/valid/._basketball  
   creating: data/Sports/valid/volleyball/
  inflating: data/__MACOSX/Sports/valid/._volleyball  
   creating: