## Visualize image-specific class saliency with backpropagation

---

The gradients obtained can be used to visualise an image-specific class saliency map, which can gives some intuition on regions within the input image that contribute the most (and least) to the corresponding output.

More details on saliency maps: [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/pdf/1312.6034.pdf).

### 0. Set up (Colab only)

In [1]:
# # Install flashtorch if you don't have it

# !pip install webdataset


# # This mounts your Google Drive to the Colab VM.
# from google.colab import drive
# drive.mount('/content/drive')

# # TODO: Enter the foldername in your Drive where you have saved the unzipped
# # assignment folder, e.g. 'cs231n/assignments/assignment3/'
# FOLDERNAME = 'Courses/Spring 2021/CS 231N/cs231n_final_proj/netviz'
# assert FOLDERNAME is not None, "[!] Enter the foldername."

# # Now that we've mounted your Drive, this ensures that
# # the Python interpreter of the Colab VM can load
# # python files from within it.
# import sys
# sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

# os.chdir('/content/drive/My Drive/{}'.format(FOLDERNAME+'/notebooks'))

In [2]:
# %matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
import gc
import sys
import json
import webdataset as wds

sys.path.append('..')

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm

from model.baseline_3d_cnn import *
from utils.saliency_map import *
%load_ext autoreload
%autoreload 2

In [3]:
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards_new')

# Opening JSON file
with open('../parameters.json') as json_file:
    parameters = json.load(json_file)

batch_size = parameters['batch_size']
shard_size = parameters['shard_size']
parameters

{'batch_size': 4, 'shard_size': 16}

In [4]:
urls = [os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar')]
print(urls)
wds_len = len(urls)*shard_size//batch_size

image_dataset = (
    wds
    .WebDataset(urls, length=wds_len)
    .shuffle(shard_size)
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
    .batched(batch_size)
#     .map_tuple(pre_transforms, identity, identity)
)


image_loader = torch.utils.data.DataLoader(image_dataset, num_workers=0, batch_size=None)

['../data/shards_new/shard-000076.tar', '../data/shards_new/shard-000003.tar', '../data/shards_new/shard-000081.tar', '../data/shards_new/shard-000065.tar', '../data/shards_new/shard-000051.tar', '../data/shards_new/shard-000044.tar', '../data/shards_new/shard-000037.tar', '../data/shards_new/shard-000095.tar', '../data/shards_new/shard-000084.tar', '../data/shards_new/shard-000090.tar', '../data/shards_new/shard-000040.tar', '../data/shards_new/shard-000062.tar', '../data/shards_new/shard-000022.tar', '../data/shards_new/shard-000078.tar', '../data/shards_new/shard-000045.tar', '../data/shards_new/shard-000023.tar', '../data/shards_new/shard-000058.tar', '../data/shards_new/shard-000063.tar', '../data/shards_new/shard-000086.tar', '../data/shards_new/shard-000071.tar', '../data/shards_new/shard-000067.tar', '../data/shards_new/shard-000013.tar', '../data/shards_new/shard-000034.tar', '../data/shards_new/shard-000072.tar', '../data/shards_new/shard-000082.tar', '../data/shards_new/shar

### 1. Load an image 

In [5]:
# patient_num = 1

# for t, (x, y, z) in enumerate(image_loader):
#     if t > 0:
#         break
#     img_df = x[patient_num, :, :, :].detach().numpy()
#     img_y = y[patient_num, :]
# del x # For now for memory reasons

In [6]:
# fig, axs = plt.subplots(5,8, figsize=(15, 6))
# axs = axs.ravel()
# for i in range(40):
#     axs[i].imshow(img_df[i,:,:])

### 2. Load a pre-trained Model

In [7]:
ckpt_dir = os.path.join('..', 'runs', 'baseline')
# ckpt_path = os.path.join(ckpt_dir, 'Checkpoints', 'ep_5_iter_65_ckpt.pt')
ckpt_path = os.path.join(ckpt_dir, 'baseline_final_model.pt') # COLAB
ckpt = torch.load(ckpt_path)

ckpt_model = baseline_3DCNN(in_num_ch=1)
ckpt_model.load_state_dict(ckpt)


<All keys matched successfully>

### 3. Saliency map

In [8]:
USE_GPU = True

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
#     dtype = torch.cuda.FloatTensor
else:
    device = torch.device('cpu')
print(device)

cuda


In [10]:
plot_dir = os.path.join('..', 'net_visualization')
full_sal_df = pd.DataFrame()

In [11]:
with tqdm(image_loader, unit="batch") as tepoch:
    for t, (x, y, z) in enumerate(tepoch): # For each batch
        if t>0:
            break
        tepoch.set_description("Batch %d" % t)
        
        # Some default constants
        N = y.shape[0]
        ICHD = y.shape[1]

        # Compute saliency and rank of slices
        saliency = compute_saliency_maps(x, y, ckpt_model, device) # (ICH_types, N, D, H, W)
        saliency_rank = rank_saliency_slices(saliency) #(ICH_types, N, D)

        # Compute whether prediction was correct or not
        ckpt_model = ckpt_model.to(device)
        # x = x.to(device, dtype=torch.float)
        # y = y.to(device, dtype=torch.float)
        scores = ckpt_model(x.to(device, dtype=torch.float)) # N X 6
        probs = torch.sigmoid(scores) # N X 6

        preds = (probs >= 0.5).long()
        corr_bool = (y == preds.cpu()).long() # N X 6

        # Save plot of saliency map in each patient directory
        for p_index in range(N):
            patient_id = z[p_index].item()
            patient_ichs_ind = y[p_index, :].nonzero(as_tuple=True)[0] # Only ICH nums which patients have
            patient_corr = corr_bool[p_index, :]

            # Make plot directory for patient
            dir_nm = '_'.join([patient_id, 
                               'numICHs', str(len(patient_ichs_ind)), 
                               'numCorr', str(patient_corr[patient_ichs_ind].sum().item())
                              ])
            patient_dir = os.path.join(plot_dir, dir_nm)
            os.mkdir(patient_dir)

            for ich_num in patient_ichs_ind.cpu().numpy():
                sal_plot = plot_saliency_maps(x, saliency, ich_num=ich_num, patient_id=p_index, d_range=np.arange(0,40))
                sal_plot.savefig(os.path.join(patient_dir, '_'.join(['ICHNum', str(ich_num), 'Corr', str(patient_corr[ich_num].item())])))
                plt.close(sal_plot)

        # Convert to pandas df to save general information (patient ID, num of ICH class, pred corr or not (per num ICH class), sal rank per ICH class)
        save_df = pd.DataFrame(saliency_rank.cpu().numpy().reshape(ICHD*N, -1)) # (ICH_types*N, D)
        save_df.columns = ['rank_' + str(rk) for rk in save_df.columns]

        save_df['patient_id'] = z.repeat(ICHD)
        save_df['ICH_num'] = np.tile(np.arange(0, ICHD), N)
        save_df['ICH_true'] = y.numpy().astype(int).reshape(-1)
        save_df['ICH_pred'] = preds.cpu().numpy().astype(int).reshape(-1)
        save_df['corr_pred'] = corr_bool.numpy().reshape(-1)
        save_df = save_df[save_df.columns.tolist()[-5:] + save_df.columns.tolist()[:-5]]

        full_sal_df = pd.concat([full_sal_df, save_df], axis=0)





Batch 0:   0%|          | 0/384 [00:13<?, ?batch/s]


NameError: name 'model' is not defined

In [None]:
# Getting which patients have non ICH types
patient_keys, nonzero_dims = y.nonzero(as_tuple=True)
patient_keys = list(patient_keys.numpy())
nonzero_dims = list(nonzero_dims.numpy())

nonzero_dict = {}
for i in range(len(patient_keys)):
    nonzero_dict[patient_keys[i]] = nonzero_dict.get(patient_keys[i], []) + [nonzero_dims[i]]
nonzero_dict