In [None]:
import sys
FOLDERNAME = 'BrainLat_skullstrip'
DEFAULT_ROOT = './'
# Only add the project root once: re-running this cell in the same kernel
# would otherwise keep appending duplicate entries to sys.path.
if DEFAULT_ROOT not in sys.path:
    sys.path.append(DEFAULT_ROOT)

from src.dataset import MRIDataset, CLASS_TO_IDX

In [None]:
from pytorch_grad_cam import GradCAM, ShapleyCAM, ScoreCAM, GradCAMPlusPlus, XGradCAM
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import (
    show_cam_on_image, deprocess_image, preprocess_image
)

In [None]:
import datetime
import glob
import os
import pickle
import warnings
import numpy as np
import random
import pandas as pd
import cv2

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from captum.influence import TracInCP, TracInCPFast, TracInCPFastRandProj
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from src.dataset import MRIDataset, CLASS_TO_IDX, IDX_TO_CLASS

def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: value used to seed every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device. PyTorch silently
    # ignores this when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only affects subprocesses started after this point; the current
    # interpreter's str-hash randomisation is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

In [24]:
save_dir = os.path.join(DEFAULT_ROOT, "results/resnet")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
run_id = "axial_brainlat_ss_transformed"  # set this to use an existing checkpoint

# set_seed() is defined above; call it before any model/data work so that
# reruns are reproducible.
set_seed()

print(f"Loading model from {os.path.join(save_dir, run_id)}...")
# map_location lets a GPU-saved model load on a CPU-only runtime.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load model files you produced yourself.
model = torch.load(
    os.path.join(save_dir, run_id, 'best_model.pth'),
    map_location=device,
    weights_only=False,
)

Loading model from /content/drive/MyDrive/CSE 599 - Deep Learning for Computer Vision/Final_Project/results/resnet/axial_brainlat_ss_transformed...


In [25]:
def load_checkpoints(net, path):
    """Load the weights stored in a training checkpoint into ``net``.

    Called directly below and also passed to TracInCPFast as its
    ``checkpoints_load_func``.

    Args:
        net: module whose parameters are overwritten in place.
        path: checkpoint file holding a ``"model_state_dict"`` entry.

    Returns:
        1.0. Captum treats this return value as the checkpoint's learning
        rate, so every checkpoint is weighted equally.
    """
    # Stage on CPU: load_state_dict copies into the existing parameters
    # wherever they live. This works whether the checkpoint was saved on GPU
    # or CPU, and whether a GPU is present now.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # trusted checkpoints.
    weights = torch.load(path, map_location="cpu", weights_only=False)
    net.load_state_dict(weights["model_state_dict"])
    return 1.

# NOTE(review): the model above was loaded from run_id
# 'axial_brainlat_ss_transformed', but the checkpoints come from
# 'axial_brainlat_transformed' (no "_ss"). load_checkpoints() below
# overwrites the best-model weights with checkpoint-16 from this run, and
# TracIn is pointed at this same directory. Confirm this is intended.
checkpoint_run_id = 'axial_brainlat_transformed'
checkpoints_dir = os.path.join(save_dir, checkpoint_run_id, 'checkpoints')
final_checkpoint = os.path.join(checkpoints_dir, 'checkpoint-16.pt')
load_checkpoints(model, final_checkpoint)
# Switch to inference mode. Otherwise BatchNorm normalises with per-batch
# statistics, so predictions and TracIn gradients depend on which examples
# share a batch, and every forward pass mutates the running stats.
model = model.to(device).eval()

In [41]:
orientation = 'axial'
batch_size = 32

train_path = os.path.join(DEFAULT_ROOT, f"{FOLDERNAME}/train_index.csv")
val_path = os.path.join(DEFAULT_ROOT, f"{FOLDERNAME}/val_index.csv")
test_path = os.path.join(DEFAULT_ROOT, f"{FOLDERNAME}/test_index.csv")

# Reuse the device picked when loading the model instead of hard-coding
# 'cuda', so the notebook also runs on CPU-only runtimes. str(device) is
# 'cuda' on a GPU machine, i.e. the same value as before.
# NOTE(review): assumes MRIDataset accepts a device string ('cuda'/'cpu').
dataset_device = str(device)
train_dataset = MRIDataset(train_path, DEFAULT_ROOT, orient=orientation, device=dataset_device)
val_dataset = MRIDataset(val_path, DEFAULT_ROOT, orient=orientation, device=dataset_device)
test_dataset = MRIDataset(test_path, DEFAULT_ROOT, orient=orientation, device=dataset_device)

# No shuffling (DataLoader default): the test batches later used for figure
# naming must have a stable order.
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

Train size: 308
Val size: 38
Test size: 40


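### TracIn

For each test example, find the training examples with the largest positive (proponents) and most negative (opponents) TracIn influence scores.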

In [42]:
# TracInCPFast is given the model's last child module as its final
# fully-connected layer (assumed to be the classifier head; TODO confirm).
child_modules = list(model.children())
final_fc_layer = child_modules[-1]
tracin_loss_fn = nn.CrossEntropyLoss(reduction="sum")

tracin_cp_fast = TracInCPFast(
    model=model,
    final_fc_layer=final_fc_layer,
    train_dataset=train_dataset,
    checkpoints=checkpoints_dir,
    checkpoints_load_func=load_checkpoints,
    loss_fn=tracin_loss_fn,
    batch_size=16,
    vectorize=False,
)

In [43]:
# Analyse a single test batch. The previous loop overwrote these variables on
# every iteration, so only the LAST batch survived (test indices 32-39 with
# batch_size=32 and 40 test examples). The `idx + 32` figure naming further
# down relies on that, so the last batch is selected explicitly here.
test_batch_index = len(test_loader) - 1
for batch_idx, (images, labels) in enumerate(test_loader):
    if batch_idx == test_batch_index:
        test_examples_features = images.to(device)
        test_examples_true_labels = labels.to(device)
        break

# Predictions only; no autograd graph is needed.
with torch.no_grad():
    test_examples_predicted_probs, test_examples_predicted_labels = torch.max(
        F.softmax(model(test_examples_features), dim=1), dim=1
    )

In [44]:
k = 10
# Both queries score the same test batch against the TracIn training set.
influence_inputs = (test_examples_features, test_examples_true_labels)

start_time = datetime.datetime.now()
proponents_indices, proponents_influence_scores = tracin_cp_fast.influence(
    influence_inputs, k=k, proponents=True
)
opponents_indices, opponents_influence_scores = tracin_cp_fast.influence(
    influence_inputs, k=k, proponents=False
)
elapsed = datetime.datetime.now() - start_time
total_minutes = elapsed.total_seconds() / 60.0

message_template = "Computed proponents / opponents over a dataset of %d examples in %.2f minutes"
print(message_template % (len(train_dataset), total_minutes))

Computed proponents / opponents over a dataset of 308 examples in 3.86 minutes


In [45]:
# Cast both top-k index tensors to int32 before they are used to index train_dataset.
proponents_indices, opponents_indices = (
    index_tensor.int() for index_tensor in (proponents_indices, opponents_indices)
)

In [46]:
# Display the architecture; the Grad-CAM target layer `model.layer4[-1]` below is chosen from it.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Grad-CAM on the last block of layer4.
target_layers = [model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
# Heatmaps are always computed for this class, regardless of the true or
# predicted label of the example being shown.
cam_target_class = "AD"
targets = [ClassifierOutputTarget(CLASS_TO_IDX[cam_target_class])]

# UNCOMMENT FOR GUIDED BACKPROP
# gb_model = GuidedBackpropReLUModel(model=model, device=device)

label_to_class = IDX_TO_CLASS


def imshow_transform(x):
    """Convert a 3-D tensor to a channels-last NumPy array for ``plt.imshow``.

    Assumes ``x`` is channels-first (C, H, W); TODO confirm for MRIDataset.
    ``detach`` keeps ``.numpy()`` working even if ``x`` tracks gradients.
    """
    return torch.permute(x, (1, 2, 0)).detach().cpu().numpy()

def display_test_example(example, true_label, predicted_label, predicted_prob, label_to_class, idx):
    """Show a test image overlaid with its Grad-CAM heatmap and save the figure.

    Args:
        example: image tensor of one test example (assumed (C, H, W); TODO confirm).
        true_label: label tensor; read with ``.item()``.
        predicted_label: label tensor; read with ``.item()``.
        predicted_prob: single-element tensor with the predicted class probability.
        label_to_class: mapping from class index to class name.
        idx: suffix of the saved file ``<DEFAULT_ROOT>/figures/test_<idx>.png``.
    """
    true_class = label_to_class[true_label.item()]
    predicted_class = label_to_class[predicted_label.item()]
    print('true_class:', true_class)
    print('predicted_class:', predicted_class)
    # float() prints a plain number instead of a tensor repr.
    print('predicted_prob', float(predicted_prob))

    # Uses the module-level `cam` and `targets`, i.e. always the same target class.
    grayscale_cam = cam(input_tensor=example[None, :, :, :], targets=targets)
    vis = show_cam_on_image(imshow_transform(example), grayscale_cam[0, :])

    _, ax = plt.subplots()
    # Only the overlay is drawn. Drawing the raw image first is unnecessary:
    # the same-sized overlay on the same axes hides it completely.
    ax.imshow(np.rot90(vis))
    ax.set_title('True: ' + str(true_class) + ", Pred: " + predicted_class)

    # UNCOMMENT FOR GUIDED BACKPROP
    # gb = gb_model(example[None, :, :, :], target_category=None)
    # cam_gb = deprocess_image(grayscale_cam.transpose(1, 2, 0) * gb)
    # result = deprocess_image(gb)
    # ax.imshow(result)

    figures_dir = os.path.join(DEFAULT_ROOT, 'figures')
    os.makedirs(figures_dir, exist_ok=True)
    plt.savefig(os.path.join(figures_dir, 'test_' + str(idx) + '.png'))
    plt.show()

def display_training_examples(examples, true_labels, label_to_class, idx, figsize=(10,4)):
    """Plot training examples in one row, each overlaid with its Grad-CAM heatmap.

    Args:
        examples: sequence of image tensors (assumed (C, H, W); TODO confirm).
        true_labels: sequence of label tensors; read with ``.item()``.
        label_to_class: mapping from class index to class name.
        idx: file stem; saved as ``<DEFAULT_ROOT>/figures/<idx>.png``.
        figsize: matplotlib figure size.
    """
    num_examples = len(examples)
    # squeeze=False always gives a 2-D axes array; max(..., 1) avoids an error
    # from subplots() when there is nothing to show.
    fig, axes = plt.subplots(1, max(num_examples, 1), figsize=figsize, squeeze=False)
    for ax, example, label in zip(axes[0], examples, true_labels):
        grayscale_cam = cam(input_tensor=example[None, :, :, :], targets=targets)
        vis = show_cam_on_image(imshow_transform(example), grayscale_cam[0, :])
        # The overlay covers the raw image entirely, so only the overlay is drawn.
        ax.imshow(np.rot90(vis))
        ax.set_title(label_to_class[label.item()])

    figures_dir = os.path.join(DEFAULT_ROOT, 'figures')
    os.makedirs(figures_dir, exist_ok=True)
    fig.savefig(os.path.join(figures_dir, idx + '.png'))
    # Render now so the figure appears under its "proponents:"/"opponents:"
    # heading. Left open, it would only be flushed by the next plt.show()
    # call, next to the following test example.
    plt.show()

def _display_training_group(heading, train_indices, name_prefix, figure_idx):
    """Print ``heading`` and plot the ``train_dataset`` examples at ``train_indices``."""
    print(heading)
    group_tensors, group_labels = zip(*[train_dataset[int(i)] for i in train_indices])
    display_training_examples(
        group_tensors, group_labels, label_to_class, name_prefix + str(figure_idx), figsize=(20, 8),
    )


def display_proponents_and_opponents(test_examples_batch, proponents_indices, opponents_indices, test_examples_true_labels, test_examples_predicted_labels, test_examples_predicted_probs, index_offset=32):
    """For each test example, show it and its TracIn proponents and opponents.

    Args:
        test_examples_batch: batch of test image tensors.
        proponents_indices: per test example, the train_dataset indices of its proponents.
        opponents_indices: per test example, the train_dataset indices of its opponents.
        test_examples_true_labels: true label per test example.
        test_examples_predicted_labels: predicted label per test example.
        test_examples_predicted_probs: predicted probability per test example.
        index_offset: added to the 1-based position inside the batch when naming
            saved figures. The default 32 was previously hard-coded; it matches
            analysing the last test batch with batch_size=32 and 40 test examples.
    """
    per_example = zip(
        test_examples_batch,
        proponents_indices,
        opponents_indices,
        test_examples_true_labels,
        test_examples_predicted_labels,
        test_examples_predicted_probs,
    )
    for position, (test_example, proponent_ids, opponent_ids, true_label, predicted_label, predicted_prob) in enumerate(per_example, start=1):
        figure_idx = position + index_offset

        print("test example:")
        display_test_example(
            test_example,
            true_label,
            predicted_label,
            predicted_prob,
            label_to_class,
            figure_idx,
        )

        _display_training_group("proponents:", proponent_ids, 'pro_', figure_idx)
        _display_training_group("opponents:", opponent_ids, 'opp_', figure_idx)


In [None]:
# Render every analysed test example together with its top-k proponents and opponents.
display_proponents_and_opponents(
    test_examples_batch=test_examples_features,
    proponents_indices=proponents_indices,
    opponents_indices=opponents_indices,
    test_examples_true_labels=test_examples_true_labels,
    test_examples_predicted_labels=test_examples_predicted_labels,
    test_examples_predicted_probs=test_examples_predicted_probs,
)