## Visualize image-specific class saliency with backpropagation

---

The gradients obtained can be used to visualise an image-specific class saliency map, which can gives some intuition on regions within the input image that contribute the most (and least) to the corresponding output.

More details on saliency maps: [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/pdf/1312.6034.pdf).

### 0. Set up (Colab only)

In [1]:
# # Install flashtorch if you don't have it

# !pip install webdataset


# # This mounts your Google Drive to th?e Colab VM.
# from google.colab import drive
# drive.mount('/content/drive')

# # TODO: Enter the foldername in your Drive where you have saved the unzipped
# # assignment folder, e.g. 'cs231n/assignments/assignment3/'
# FOLDERNAME = 'Courses/Spring 2021/CS 231N/cs231n_final_proj/netviz'
# assert FOLDERNAME is not None, "[!] Enter the foldername."

# # Now that we've mounted your Drive, this ensures that
# # the Python interpreter of the Colab VM can load
# # python files from within it.
# import sys
# sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

# os.chdir('/content/drive/My Drive/{}'.format(FOLDERNAME+'/notebooks'))

In [31]:
# %matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
import gc
import sys
import json
import webdataset as wds

sys.path.append('..')

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm
import psutil


from model.baseline_3d_cnn import *
from model.selfattn_3d_cnn import *
from model.resattn_3d_cnn import *
from utils.saliency_map import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards_new')

# Opening JSON file
with open('../parameters.json') as json_file:
    parameters = json.load(json_file)

batch_size = parameters['batch_size']
shard_size = parameters['shard_size']
parameters

{'batch_size': 4, 'shard_size': 16}

In [4]:
urls = [os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar')]
wds_len = len(urls)*shard_size//batch_size

image_dataset = (
    wds
    .WebDataset(urls, length=wds_len)
    .shuffle(shard_size)
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
    .batched(batch_size)
#     .map_tuple(pre_transforms, identity, identity)
)


image_loader = torch.utils.data.DataLoader(image_dataset, num_workers=0, batch_size=None)

['../data/shards_new/shard-000076.tar', '../data/shards_new/shard-000003.tar', '../data/shards_new/shard-000081.tar', '../data/shards_new/shard-000065.tar', '../data/shards_new/shard-000051.tar', '../data/shards_new/shard-000044.tar', '../data/shards_new/shard-000037.tar', '../data/shards_new/shard-000095.tar', '../data/shards_new/shard-000084.tar', '../data/shards_new/shard-000090.tar', '../data/shards_new/shard-000040.tar', '../data/shards_new/shard-000062.tar', '../data/shards_new/shard-000022.tar', '../data/shards_new/shard-000078.tar', '../data/shards_new/shard-000045.tar', '../data/shards_new/shard-000023.tar', '../data/shards_new/shard-000058.tar', '../data/shards_new/shard-000063.tar', '../data/shards_new/shard-000086.tar', '../data/shards_new/shard-000071.tar', '../data/shards_new/shard-000067.tar', '../data/shards_new/shard-000013.tar', '../data/shards_new/shard-000034.tar', '../data/shards_new/shard-000072.tar', '../data/shards_new/shard-000082.tar', '../data/shards_new/shar

### 1. Load an image 

In [5]:
# patient_num = 1

# for t, (x, y, z) in enumerate(image_loader):
#     if t > 0:
#         break
#     img_df = x[patient_num, :, :, :].detach().numpy()
#     img_y = y[patient_num, :]
# del x # For now for memory reasons

In [6]:
# fig, axs = plt.subplots(5,8, figsize=(15, 6))
# axs = axs.ravel()
# for i in range(40):
#     axs[i].imshow(img_df[i,:,:])

### 2. Load a pre-trained Model

In [61]:
ckpt_dir = os.path.join('..', 'runs', 'baseline')
ckpt_path = os.path.join(ckpt_dir, 'baseline_final_model.pt') # COLAB
ckpt = torch.load(ckpt_path)

ckpt_model = baseline_3DCNN(in_num_ch=1)
ckpt_model.load_state_dict(ckpt)

plot_dir = os.path.join('..', 'net_visualization', 'baseline')


In [58]:
# ckpt_dir = os.path.join('..', 'runs', 'experiment_att')
# ckpt_path = os.path.join(ckpt_dir, 'experiment_final_model.pt') 
# ckpt = torch.load(ckpt_path)

# ckpt_model = selfattn_3DCNN(in_num_ch=1)
# ckpt_model.load_state_dict(ckpt)

# plot_dir = os.path.join('..', 'net_visualization', 'experiment_att')


In [32]:
# ckpt_dir = os.path.join('..', 'runs', 'experiment_res')
# ckpt_path = os.path.join(ckpt_dir, 'experimentres_final_model.pt') 
# ckpt = torch.load(ckpt_path)

# ckpt_model = resattn_3DCNN(in_num_ch=1)
# ckpt_model.load_state_dict(ckpt)

# plot_dir = os.path.join('..', 'net_visualization', 'experiment_res')


<All keys matched successfully>

### 3. Generate saliency map statistics

In [62]:
USE_GPU = True

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
#     dtype = torch.cuda.FloatTensor
else:
    device = torch.device('cpu')
print(device)

cuda


In [63]:
gc.collect()

61206

In [None]:
full_sal_df = pd.DataFrame()

with tqdm(image_loader, unit="batch") as tepoch:
    for t, (x, y, z) in enumerate(tepoch): # For each batch
        gc.collect()
            
        tepoch.set_description("Batch %d" % t)
        
        # Some default constants
        N = y.shape[0]
        ICHD = y.shape[1]

        # Compute saliency and rank of slices
        saliency = compute_saliency_maps(x, y, ckpt_model, device) # (ICH_types, N, D, H, W)
        saliency_rank = rank_saliency_slices(saliency) #(ICH_types, N, D)
        
        # Compute whether prediction was correct or not
        ckpt_model = ckpt_model.to(device)
        ckpt_model.eval()
        
        preds = (torch.sigmoid(ckpt_model(x.to(device, dtype=torch.float))) >= 0.5).long()
        corr_bool = (y == preds.cpu()).long() # N X 6

        # Save plot of saliency map in each patient directory
        # Only do up to 10 batches for plotting (40 patients)
#         if t < 10:
#             for p_index in range(N):
#                 patient_id = z[p_index].item()
#                 patient_ichs_ind = y[p_index, :].nonzero(as_tuple=True)[0] # Only ICH nums which patients have
#                 if len(patient_ichs_ind) == 0: # Skip patient if patient doesn't have any ICH
#                     continue 
#                 patient_corr = corr_bool[p_index, :]

#                 # Make plot directory for patient
#                 dir_nm = '_'.join([patient_id, 
#                                    'numICHs', str(len(patient_ichs_ind)), 
#                                    'numCorr', str(patient_corr[patient_ichs_ind].sum().item())
#                                   ])
#                 patient_dir = os.path.join(plot_dir, dir_nm)
#                 if not os.path.exists(patient_dir):
#                     os.mkdir(patient_dir)

#                 for ich_num in patient_ichs_ind.cpu().numpy():
#                     sal_plot = plot_saliency_maps(x, saliency, ich_num=ich_num, patient_id=p_index, 
#                                                   d_range=np.arange(0,40), num_rows = 2*8, num_cols = 5)
#                     sal_plot.savefig(os.path.join(patient_dir, '_'.join(['ICHNum', str(ich_num), 'Corr', str(patient_corr[ich_num].item())])))
#                     # Clear the current axes.
#                     plt.cla() 
#                     # Clear the current figure.
#                     plt.clf() 
#                     sal_plot.clear()
#                     plt.close('all')
#                     del sal_plot
                
        # Convert to pandas df to save general information (patient ID, num of ICH class, pred corr or not (per num ICH class), sal rank per ICH class)
        save_df = pd.DataFrame(saliency_rank.detach().cpu().numpy().reshape(ICHD*N, -1)) # (ICH_types*N, D)
        save_df.columns = ['rank_' + str(rk) for rk in save_df.columns]

        save_df['patient_id'] = z.repeat(ICHD)
        save_df['ICH_num'] = np.tile(np.arange(0, ICHD), N)
        save_df['ICH_true'] = y.numpy().astype(int).reshape(-1)
        save_df['ICH_pred'] = preds.detach().cpu().numpy().astype(int).reshape(-1)
        save_df['corr_pred'] = corr_bool.detach().numpy().reshape(-1)
        save_df = save_df[save_df.columns.tolist()[-5:] + save_df.columns.tolist()[:-5]]

        full_sal_df = pd.concat([full_sal_df, save_df], axis=0)
        
        # Delete to free up memory?
        del saliency
        del saliency_rank
        del x
        
        print('CUDA:', torch.cuda.memory_allocated())
        print('RAM:', psutil.virtual_memory().percent)

full_sal_df.shape


In [40]:
full_sal_df.to_csv(os.path.join(plot_dir, 'saliency_statistics.csv'), index=False)

In [41]:
full_sal_df.shape

(9216, 45)

### Plots for first 2 batches' patients, across different models, and top 10 slices:

In [113]:
urls = [os.path.join(shards_dir, it) for it in os.listdir(shards_dir) if it.endswith('.tar')]
urls = urls[:2]
wds_len = len(urls)*shard_size//batch_size
print(urls)
image_dataset = (
    wds
    .WebDataset(urls, length=wds_len)
#     .shuffle(shard_size)
    .decode('torch')
    .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
    .batched(batch_size)
#     .map_tuple(pre_transforms, identity, identity)
)


image_loader = torch.utils.data.DataLoader(image_dataset, num_workers=0, batch_size=None)

['../data/shards_new/shard-000076.tar', '../data/shards_new/shard-000003.tar']


In [114]:
models = []

In [115]:

ckpt_dir = os.path.join('..', 'runs', 'baseline')
ckpt_path = os.path.join(ckpt_dir, 'baseline_final_model.pt') # COLAB
ckpt = torch.load(ckpt_path)

ckpt_model = baseline_3DCNN(in_num_ch=1)
ckpt_model.load_state_dict(ckpt)

models.append(ckpt_model)

In [116]:
ckpt_dir = os.path.join('..', 'runs', 'experiment_att')
ckpt_path = os.path.join(ckpt_dir, 'experiment_final_model.pt') 
ckpt = torch.load(ckpt_path)

ckpt_model = selfattn_3DCNN(in_num_ch=1)
ckpt_model.load_state_dict(ckpt)

models.append(ckpt_model)

In [117]:
ckpt_dir = os.path.join('..', 'runs', 'experiment_res')
ckpt_path = os.path.join(ckpt_dir, 'experimentres_final_model.pt') 
ckpt = torch.load(ckpt_path)

ckpt_model = resattn_3DCNN(in_num_ch=1)
ckpt_model.load_state_dict(ckpt)

models.append(ckpt_model)


In [118]:
model_names = ['baseline', 'experiment_att', 'experiment_res']
plot_dir = os.path.join('..', 'net_visualization', 'select_patients')
gc.collect()

52

In [124]:
for i, mod in enumerate(zip(model_names, models)):
    with tqdm(image_loader, unit="batch") as tmod:
        for t, (x, y, z) in enumerate(tmod): # For each batch
            gc.collect()
            tmod.set_description("Model %d" % i)

            # Some default constants
            N = y.shape[0]
            ICHD = y.shape[1]
            
            modname = mod[0]
            ckpt_model = mod[1]
                        
            # Compute saliency and rank of slices
            saliency = compute_saliency_maps(x, y, ckpt_model, device) # (ICH_types, N, D, H, W)
            saliency_rank = rank_saliency_slices(saliency) #(ICH_types, N, D)

            # Compute whether prediction was correct or not
            ckpt_model = ckpt_model.to(device)
            ckpt_model.eval()

            preds = (torch.sigmoid(ckpt_model(x.to(device, dtype=torch.float))) >= 0.5).long()
            corr_bool = (y == preds.cpu()).long() # N X 6

            # Save plot of saliency map in each patient directory
            for p_index in range(N):
                patient_id = z[p_index].item()
                patient_ichs_ind = y[p_index, :].nonzero(as_tuple=True)[0] # Only ICH nums which patients have
                if len(patient_ichs_ind) == 0: # Skip patient if patient doesn't have any ICH
                    continue 
                patient_corr = corr_bool[p_index, :]

                # Make plot directory for patient
                patient_dir = os.path.join(plot_dir, patient_id)
                if not os.path.exists(patient_dir):
                    os.mkdir(patient_dir)
                
                # Make directory for model within that patient
                dir_nm = '_'.join([modname, 
                                   'numICHs', str(len(patient_ichs_ind)), 
                                   'numCorr', str(patient_corr[patient_ichs_ind].sum().item())
                                  ])
                mod_dir = os.path.join(patient_dir, dir_nm)
                if not os.path.exists(mod_dir):
                    os.mkdir(mod_dir)

                # For each different ICH subtype that the patient has
                for ich_num in patient_ichs_ind.cpu().numpy():
                    # Generate top 10 slices for that ICH type, patient
                    sal_plot = plot_saliency_maps(x, saliency, ich_num=ich_num, patient_id=p_index, 
                                                  d_range=saliency_rank[ich_num, p_index, :10].detach().cpu().numpy(),
                                                  num_rows=4, num_cols=5, figsize=(14,9)
                                                 ) 
                    top10_dir = os.path.join(
                        mod_dir,
                        '_'.join(['ICHNum', str(ich_num), 'Corr', str(patient_corr[ich_num].item()), 'top10'])
                    )
                    sal_plot.savefig(top10_dir)
                    # Clear the current axes.
                    plt.cla() 
                    # Clear the current figure.
                    plt.clf() 
                    sal_plot.clear()
                    plt.close('all')
                    del sal_plot
                    
                    # Generate full slices for each patient
                    sal_plot = plot_saliency_maps(x, saliency, ich_num=ich_num, patient_id=p_index, 
                                                  d_range=np.arange(0,40),
                                                  num_rows=16, num_cols=5, figsize=(20,40)
                                                 ) 
                    full_dir = os.path.join(
                        mod_dir,
                        '_'.join(['ICHNum', str(ich_num), 'Corr', str(patient_corr[ich_num].item()), 'full'])
                    )
                    sal_plot.savefig(full_dir)
                    # Clear the current axes.
                    plt.cla() 
                    # Clear the current figure.
                    plt.clf() 
                    sal_plot.clear()
                    plt.close('all')
                    del sal_plot

            # Delete to free up memory?
            del saliency
            del saliency_rank
            del x

#             print('CUDA:', torch.cuda.memory_allocated())
#             print('RAM:', psutil.virtual_memory().percent)


 12%|█▎        | 1/8 [01:33<10:55, 93.71s/batch]

CUDA: 108878336
RAM: 36.0


 25%|██▌       | 2/8 [03:59<12:27, 124.60s/batch]

CUDA: 108878336
RAM: 38.5


 38%|███▊      | 3/8 [05:02<08:01, 96.20s/batch] 

CUDA: 108878336
RAM: 40.2


 50%|█████     | 4/8 [05:03<03:55, 58.86s/batch]

CUDA: 108878336
RAM: 40.2


 62%|██████▎   | 5/8 [06:32<03:28, 69.56s/batch]

CUDA: 108878336
RAM: 42.7


 75%|███████▌  | 6/8 [07:47<02:22, 71.30s/batch]

CUDA: 108878336
RAM: 44.7


 88%|████████▊ | 7/8 [09:20<01:18, 78.40s/batch]

CUDA: 108878336
RAM: 47.1


100%|██████████| 8/8 [10:16<00:00, 77.05s/batch]
  0%|          | 0/8 [00:00<?, ?batch/s]

CUDA: 108878336
RAM: 47.5


 12%|█▎        | 1/8 [01:32<10:44, 92.04s/batch]

CUDA: 216417280
RAM: 50.6


 25%|██▌       | 2/8 [02:46<08:10, 81.73s/batch]

CUDA: 216417280
RAM: 52.8


 38%|███▊      | 3/8 [04:15<07:04, 84.90s/batch]

CUDA: 216417280
RAM: 55.5


 50%|█████     | 4/8 [05:09<04:51, 72.88s/batch]

CUDA: 216417280
RAM: 57.1


 62%|██████▎   | 5/8 [06:40<03:58, 79.40s/batch]

CUDA: 216417280
RAM: 59.8


 75%|███████▌  | 6/8 [09:02<03:21, 100.73s/batch]

CUDA: 216417280
RAM: 64.1


 88%|████████▊ | 7/8 [10:07<01:29, 89.00s/batch] 

CUDA: 216417280
RAM: 66.0


100%|██████████| 8/8 [10:10<00:00, 76.31s/batch]
  0%|          | 0/8 [00:00<?, ?batch/s]

CUDA: 216417280
RAM: 65.4


 12%|█▎        | 1/8 [01:29<10:25, 89.36s/batch]

CUDA: 428837376
RAM: 68.0


 25%|██▌       | 2/8 [02:42<07:57, 79.58s/batch]

CUDA: 428837376
RAM: 70.2


 38%|███▊      | 3/8 [04:11<07:01, 84.28s/batch]

CUDA: 428837376
RAM: 72.9


 50%|█████     | 4/8 [05:07<04:51, 72.91s/batch]

CUDA: 428837376
RAM: 74.6


 62%|██████▎   | 5/8 [06:38<03:58, 79.46s/batch]

CUDA: 428837376
RAM: 77.4


 75%|███████▌  | 6/8 [09:00<03:21, 100.79s/batch]

CUDA: 428837376
RAM: 81.8


 88%|████████▊ | 7/8 [10:05<01:28, 88.91s/batch] 

CUDA: 428837376
RAM: 83.7


100%|██████████| 8/8 [10:08<00:00, 76.08s/batch]

CUDA: 428837376
RAM: 83.1





In [None]:
# # Getting which patients have non ICH types
# patient_keys, nonzero_dims = y.nonzero(as_tuple=True)
# patient_keys = list(patient_keys.numpy())
# nonzero_dims = list(nonzero_dims.numpy())

# nonzero_dict = {}
# for i in range(len(patient_keys)):
#     nonzero_dict[patient_keys[i]] = nonzero_dict.get(patient_keys[i], []) + [nonzero_dims[i]]
# nonzero_dict

### Analysis of important CT slices over all patients

In [251]:
sal_stat_dir = os.path.join('..', 'net_visualization')
model_names = ['baseline', 'experiment_att', 'experiment_res']

In [250]:
for modnm in model_names:
    sal_stat = pd.read_csv(os.path.join(sal_stat_dir, modnm, 'saliency_statistics.csv'))
#     print(sal_stat.shape)

    # Weird duplicates, shouldn't exist..
    sal_stat = sal_stat.drop_duplicates(['patient_id', 'ICH_num'])
#     print(sal_stat.shape)

    # Restrict to only cases when correct prediction was made
    sal_stat = sal_stat.loc[sal_stat['corr_pred']==1]
#     print(sal_stat.shape)

    # Also restrict to only cases when there actually was an ICH (bc of how we defined saliency)
    sal_stat = sal_stat.loc[sal_stat['ICH_true']==1]
#     print(sal_stat.shape)

    sal_stat_l = sal_stat.melt(id_vars=sal_stat.columns[:5], var_name='Rank', value_name='slice')
    sal_stat_l['Rank'] = sal_stat_l['Rank'].str.replace('rank_', '').astype(int) + 1
    sal_avgrank = (
        sal_stat_l
        .groupby(['ICH_num', 'slice'])['Rank']
        .mean()
        .reset_index(name="avg_rank")
        .sort_values(['ICH_num', 'avg_rank'])
    )
    sal_avgrank['ICH_num_rank'] = sal_avgrank.groupby(['ICH_num'])['avg_rank'].rank()

    sal_avg_w = sal_avgrank.pivot(index='slice', columns=['ICH_num'], values='ICH_num_rank')
    sal_avg_w.columns.name = ''
    sal_avg_w.reset_index(inplace=True)
    sal_avg_w.columns = ['ICH_' + str(s) if s != 'slice' else s for s in sal_avg_w.columns]
    sal_avg_w['avg_rank'] = sal_avg_w.iloc[:, 1:].mean(axis=1)
    sal_avg_w.sort_values('avg_rank', ascending=True, inplace=True)
    sal_avg_w.to_csv(os.path.join(sal_stat_dir, modnm, 'saliency_ranks.csv'), index=False)

In [254]:
sal_stat = pd.read_csv(os.path.join(sal_stat_dir, modnm, 'saliency_statistics.csv'))


### TEMP, just to generate slices 31-25

In [261]:
ckpt_model = models[1]

In [263]:
plot_dir = '../net_visualization/tmp'

In [265]:

with tqdm(image_loader, unit="batch") as tepoch:
    for t, (x, y, z) in enumerate(tepoch): # For each batch
        if t > 0:
            break
        gc.collect()
            
        tepoch.set_description("Batch %d" % t)
        
        # Some default constants
        N = y.shape[0]
        ICHD = y.shape[1]

        # Compute saliency and rank of slices
        saliency = compute_saliency_maps(x, y, ckpt_model, device) # (ICH_types, N, D, H, W)
        saliency_rank = rank_saliency_slices(saliency) #(ICH_types, N, D)
        
        # Compute whether prediction was correct or not
        ckpt_model = ckpt_model.to(device)
        ckpt_model.eval()
        
        preds = (torch.sigmoid(ckpt_model(x.to(device, dtype=torch.float))) >= 0.5).long()
        corr_bool = (y == preds.cpu()).long() # N X 6

        # Save plot of saliency map in each patient directory
        # Only do up to 10 batches for plotting (40 patients)
        if t < 10:
            for p_index in range(N):
                patient_id = z[p_index].item()
                patient_ichs_ind = y[p_index, :].nonzero(as_tuple=True)[0] # Only ICH nums which patients have
                if len(patient_ichs_ind) == 0: # Skip patient if patient doesn't have any ICH
                    continue 
                patient_corr = corr_bool[p_index, :]

                # Make plot directory for patient
                dir_nm = '_'.join([patient_id, 
                                   'numICHs', str(len(patient_ichs_ind)), 
                                   'numCorr', str(patient_corr[patient_ichs_ind].sum().item())
                                  ])
                patient_dir = os.path.join(plot_dir, dir_nm)
                if not os.path.exists(patient_dir):
                    os.mkdir(patient_dir)

                for ich_num in patient_ichs_ind.cpu().numpy():
                    sal_plot = plot_saliency_maps(x, saliency, ich_num=ich_num, patient_id=p_index, 
                                                  d_range=[31,30,29,28,27,25], num_rows = 2, num_cols = 6, figsize=(12,5))
                    sal_plot.savefig(os.path.join(patient_dir, '_'.join(['ICHNum', str(ich_num), 'Corr', str(patient_corr[ich_num].item())])))
                    # Clear the current axes.
                    plt.cla() 
                    # Clear the current figure.
                    plt.clf() 
                    sal_plot.clear()
                    plt.close('all')
                    del sal_plot
        
        print('CUDA:', torch.cuda.memory_allocated())
        print('RAM:', psutil.virtual_memory().percent)


Batch 0:  12%|█▎        | 1/8 [00:13<01:33, 13.34s/batch]

CUDA: 1435478016
RAM: 76.9


Batch 0:  12%|█▎        | 1/8 [00:15<01:46, 15.16s/batch]
