# Grad-CAM (Jupyter Notebook)

### Prerequisites
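- `gradcam_input_data.info` file (INI format, read with `configparser`) containing:
    - `[input_data]`: `input_prefix`, `data_dir`
    - `[model]`: `saved_weights`
    - `[output_data]`: `output_prefix`
- model definition module (imported below as `model_localcopy`, providing `Subvolume3DcnnEncoder` and `LinearInkDecoder`)
- saved weights (.pt file containing a `model_state_dict` entry)
- input subvolume organized as *.tif images or a 3D numpy array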

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Union

In [None]:
# Read from the dataset info file.
import configparser
data_info = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. Fail here with a clear message instead of a
# confusing KeyError when the sections are accessed in the next cell.
info_file = 'gradcam_input_data.info'
if not data_info.read(info_file):
    raise FileNotFoundError(f"Could not read dataset info file: {info_file}")

In [None]:
# Pull every setting this notebook needs out of the parsed info file.
# A missing section or key raises KeyError.

# Input data: base prefix plus the directory that holds the slices
input_section = data_info['input_data']
input_prefix, data_dir = input_section['input_prefix'], input_section['data_dir']

# Model: path to the saved weights
weights = data_info['model']['saved_weights']

# Output data
output_prefix = data_info['output_data']['output_prefix']

In [None]:
 # Sanity check: echo each setting read from the info file
for label, setting in (
    ("input_prefix: ", input_prefix),
    ("data_dir: ", data_dir),
    ("weights: ", weights),
    ("output_prefix: ", output_prefix),
):
    print(label, setting)

## Load input subvolume data (can be skipped if starting with a numpy 3D array)

In [None]:
import numpy as np
from pathlib import Path
from PIL import Image
import re

In [None]:
dataset_dir = Path(f'{input_prefix}/{data_dir}/')

# One .tif file per slice of the subvolume.
files = list(dataset_dir.glob('*.tif'))
if not files:
    raise FileNotFoundError(f"No *.tif slices found in {dataset_dir}")


def _slice_index(path):
    """Return the integer formed by the digits in *path*'s file stem.

    Only the file stem is used, so digits in parent directory names cannot
    leak into the sort key.

    Raises:
        ValueError: if the file name contains no digits.
    """
    digits = re.sub(r'[^0-9]', '', path.stem)
    if not digits:
        raise ValueError(f"Slice file name has no numeric index: {path.name}")
    return int(digits)


# Sort numerically (slice 10 after slice 9), not lexicographically.
files.sort(key=_slice_index)

print(files)

## Load data and convert to a numpy 3D array

In [None]:
# Read every slice into a float32 array and stack them along a new first axis
# (assumes single-channel 2D slices of equal size — TODO confirm for new data).
# Each file is opened exactly once and closed immediately. `images` keeps
# in-memory copies, so no file handles remain open after this cell.
subvolume = []
images = []
for f in files:
    with Image.open(f) as img:
        subvolume.append(np.array(img, dtype=np.float32))
        images.append(img.copy())  # copy() loads the pixels, so it outlives the file

# convert to numpy
subvolume = np.array(subvolume)
print(np.shape(subvolume))

## Grad-CAM wrapper class

In [None]:
class Gradcam3D:
  """Grad-CAM helper for a 3D CNN split into an encoder and a decoder.

  The encoder and decoder are chained into ``self.model``
  (``torch.nn.Sequential(encoder, decoder)``). A forward hook and a backward
  hook are registered on ``target_layer``, which defaults to
  ``encoder.conv4``. The hooks capture:

  * ``activations``: the *input* tensor of that layer.
  * ``gradients``: the gradient of the backpropagated output with respect
    to that same input.

  Args:
    encoder: encoder module. It must expose ``conv4`` unless
      ``target_layer`` is given.
    decoder: module applied to the encoder output.
    target_layer: optional module whose input is used for Grad-CAM.
      Defaults to ``encoder.conv4``, the previously hard-coded choice.
  """

  def __init__(self, encoder, decoder, target_layer=None):
    # Create a model
    self.encoder = encoder
    self.decoder = decoder
    self.model = torch.nn.Sequential(encoder, decoder)

    # placeholder for the gradients
    self.gradients = None

    # placeholder for the activations
    self.activations = None

    if target_layer is None:
      target_layer = self.model[0].conv4
    self.target_layer = target_layer

    # Register hooks and keep their handles so they can be removed later.
    # register_full_backward_hook reports gradients with respect to the
    # module's actual inputs. The deprecated register_backward_hook only sees
    # the last autograd op inside the module.
    # NOTE(review): full backward hooks raise if the layer's inputs/outputs
    # are modified in place (e.g. an inplace ReLU) — confirm for new encoders.
    self._hook_handles = [
      target_layer.register_full_backward_hook(self.printgradnorm),
      target_layer.register_forward_hook(self.printnorm),
    ]

  def remove_hooks(self):
    """Detach the forward/backward hooks registered in ``__init__``."""
    for handle in self._hook_handles:
      handle.remove()
    self._hook_handles = []

  def printnorm(self, module, input, output):
    """Forward hook: store the target layer's input as ``activations``.

    Despite its name, this prints the input *size*, not a norm. The tensor is
    detached and cloned, so later in-place edits (such as weighting the
    channels) cannot write into storage still referenced by the autograd graph.
    """
    print('input size:', input[0].size(), '    <---- activations')
    self.activations = input[0].detach().clone()

  def printgradnorm(self, module, grad_input, grad_output):
    """Backward hook: store the gradient w.r.t. the target layer's input.

    ``grad_input[0]`` is ``None`` when that input does not require grad. In
    that case ``gradients`` is reset to ``None`` instead of crashing.
    """
    grad = grad_input[0]
    if grad is None:
      self.gradients = None
      return
    print('grad_input size:', grad.size(), '    <---- gradients')
    self.gradients = grad.detach()

  def print_model(self):
    """Print each state-dict entry (parameters and buffers) with its size."""
    # Build the state dict once instead of once per entry.
    for name, tensor in self.model.state_dict().items():
      print(name, "\t", tensor.size())

  def load_weights(self, path, strict=False, map_location='cpu'):
    """Load ``checkpoint['model_state_dict']`` from *path* and switch to eval mode.

    Args:
      path: checkpoint written with ``torch.save``. It must contain a
        ``'model_state_dict'`` entry.
      strict: forwarded to ``load_state_dict``. It defaults to ``False`` as
        before, but missing/unexpected keys are now reported, so a key
        mismatch cannot silently leave layers at their initial weights.
      map_location: device the stored tensors are mapped onto (CPU by default).
    """
    checkpoint = torch.load(path, map_location=map_location)
    result = self.model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    if result.missing_keys:
      print('WARNING: missing keys when loading weights:', result.missing_keys)
    if result.unexpected_keys:
      print('WARNING: unexpected keys when loading weights:', result.unexpected_keys)
    self.model.eval()

## Inspect the saved weights (optional)

In [None]:
# Sanity check

# Load the raw checkpoint on the CPU and show its top-level entries, then the
# names of the tensors stored in its model state dict.
cpu = torch.device('cpu')
checkpoint = torch.load(weights, map_location=cpu)
print(checkpoint.keys()) # --> dict_keys(['epoch', 'batch', 'model_state_dict', 'optimizer_state_dict'])
state_dict_keys = checkpoint['model_state_dict'].keys()
print(state_dict_keys)

## Load encoder and decoder to the Gradcam wrapper class

In [None]:
import model_localcopy as saved_model

In [None]:
# Encoder hyper-parameters. Presumably these must match the values used when
# the saved weights were trained, otherwise the state dict will not fit.
subvolume_shape = [48, 48, 48]
batch_norm_momentum = 0.9
no_batch_norm = False
filters = [32, 16, 8, 4]
in_channels = 1

encoder_kwargs = dict(
    subvolume_shape=subvolume_shape,
    batch_norm_momentum=batch_norm_momentum,
    no_batch_norm=no_batch_norm,
    filters=filters,
    in_channels=in_channels,
)
encoder = saved_model.Subvolume3DcnnEncoder(**encoder_kwargs)

In [None]:
# The decoder consumes whatever shape the encoder reports and has 2 output
# neurons (one per class).
decoder_input_shape = encoder.output_shape
decoder = saved_model.LinearInkDecoder(
    drop_rate=0.5,
    input_shape=decoder_input_shape,
    output_neurons=2,
)

In [None]:
# Chain encoder + decoder and attach the Grad-CAM hooks.
inkid_gradcam = Gradcam3D(encoder=encoder, decoder=decoder)

In [None]:
# Sanity check: list every state-dict entry of the wrapped model with its size.
inkid_gradcam.print_model()

## Load the saved weights

In [None]:
# Load checkpoint['model_state_dict'] into the wrapped model and switch it to eval mode.
# NOTE(review): load_weights uses strict=False, so key mismatches are not an error —
# compare the checkpoint keys printed earlier with the print_model() output.
inkid_gradcam.load_weights(weights)

## Prepare the input subvolume by adding 2 extra axes

In [None]:
# Prepend two length-1 axes: (D, H, W) -> (1, 1, D, H, W).
# Indexing with None is the same as np.newaxis, and the trailing axes are kept.
subvolume = subvolume[None, None]
print("final subvolume input shape:", subvolume.shape)

## Convert subvolume to Torch tensor

In [None]:
# Wrap the numpy array as a torch tensor. torch.from_numpy shares memory with the array, and the dtype stays float32.
subvolume = torch.from_numpy(subvolume)

## Push the subvolume through and obtain prediction

In [None]:
# Inference only: no autograd graph is needed to pick the predicted class.
# The Grad-CAM backward pass in the next cell runs its own forward pass with
# gradients enabled. The forward hook still records the activations here.
with torch.no_grad():
    prediction = inkid_gradcam.model(subvolume).argmax(dim=1).item()
print("prediction:", prediction)   # 1 is ink; 0 is no ink

## Calculate activation "hotness" (GradCam)

In [None]:
# Forward again with autograd enabled, then backpropagate from the model output
# for the predicted class. The hooks capture the activations and gradients.
class_score = inkid_gradcam.model(subvolume)[:, prediction, :, :]
class_score.backward()

In [None]:
# Pull the tensors captured by the hooks and report their sizes.
activations, gradients = inkid_gradcam.activations, inkid_gradcam.gradients
for label, captured in (("activations size: ", activations), ("gradients size: ", gradients)):
    print(label, captured.size())

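#### Expression (1) in https://arxiv.org/pdf/1610.02391.pdf
\begin{equation}
\alpha^c_k = \frac{1}{Z}\sum_{i}\sum_{j} \frac{\partial y^c}{\partial A^k_{ij}}
\end{equation}
Here the feature maps are 3D, so the average is taken over all three spatial indices (and over the batch axis, which has size 1).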


In [None]:
# alpha_k: average the gradients over the batch axis and the three spatial
# axes, leaving one weight per channel of the hooked activations.
reduce_dims = [0, 2, 3, 4]
pooled_gradients = gradients.mean(dim=reduce_dims)
# Sanity Check
print(pooled_gradients.size())  # ---> [8]
print(pooled_gradients[0])

#### Expression (2) in https://arxiv.org/pdf/1610.02391.pdf
\begin{equation}
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left(\sum_{k}\alpha^c_k A^k\right)
\end{equation}

In [None]:
# weight the channels by corresponding gradients
# Weight the channels by their pooled gradients (alpha_k * A^k).
# Broadcast one weight per channel instead of looping over a hard-coded 8
# channels, so this still works if the encoder filters change. Like the
# original loop, this modifies `activations` in place.
activations *= pooled_gradients.reshape(1, -1, 1, 1, 1)

In [None]:
# average the channels of the activations
# Collapse the channel axis. The paper sums over k, but the mean differs only
# by a constant factor that the later max-normalisation removes. squeeze()
# drops the remaining length-1 axes (e.g. the batch axis).
heatmap = activations.mean(dim=1).squeeze()
# NOTE(review): an older note said heatmap.size() == torch.Size([9, 9, 9]), while
# the plotting cell assumed 12x12x12 — check heatmap.size() for this encoder.

In [None]:
# relu on top of the heatmap (we are not insterested in negative values)
# ReLU on top of the heatmap (we are not interested in negative values).
# torch.clamp keeps this a pure torch operation. Applying the NumPy ufunc
# np.maximum to a tensor relies on tensor<->ndarray conversion hooks.
heatmap = torch.clamp(heatmap.detach(), min=0)
heatmap

## Visualization

In [None]:
# normalize the heatmap
# Normalize the heatmap so its maximum is 1. Skip the division when every
# value is zero; otherwise 0/0 would turn the whole heatmap into NaN.
heatmap_max = torch.max(heatmap)
if heatmap_max > 0:
    heatmap /= heatmap_max
heatmap[0][11]

In [None]:
import plotly.graph_objects as go

# Plotly gets plain numpy values. The voxel grid is derived from the
# heatmap's own shape instead of being hard-coded to 12x12x12, so it always
# lines up with values.flatten() (both use C order).
values = np.asarray(heatmap)
nx, ny, nz = values.shape
X, Y, Z = np.mgrid[0:nx, 0:ny, 0:nz]

gradient_map = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=values.flatten(),
    isomin=0.1,
    isomax=1.0,
    opacity=0.2, # needs to be small to see through all surfaces
    surface_count=10, # more surfaces give a smoother volume rendering
    colorscale = 'rainbow'
    ))

gradient_map.update_layout(showlegend=False)

# Export under the file name the overlay cell at the end of the notebook reads.
#gradient_map.write_image("images/gradient_map.png")
gradient_map.show()

### Original Subvolume

In [None]:
# Drop the batch and channel axes again to get back the raw 3D volume.
subvolume_orig = subvolume.squeeze()

In [None]:
# Plotly gets plain numpy values. The voxel grid follows the subvolume's own
# shape instead of a hard-coded 48x48x48, so any subvolume size can be shown.
subvolume_values = np.asarray(subvolume_orig)
sx, sy, sz = subvolume_values.shape
X, Y, Z = np.mgrid[0:sx, 0:sy, 0:sz]

subvolume_map = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=subvolume_values.flatten(),
    isomin=0.1,
    isomax=torch.max(subvolume_orig).item()*0.9,
    opacity=0.3, # needs to be small to see through all surfaces
    surface_count=10, # more surfaces give a smoother volume rendering
    colorscale = 'Greys'
    ))

subvolume_map.update_layout(showlegend=False)
#subvolume_map.write_image("images/subvolume_map.png")
subvolume_map.show()

### If an overlay of the two renderings is desired

In [None]:
# Optional: blend the exported heat-map and subvolume renderings into one image.
# Kept disabled by wrapping it in a string literal. To use it, first export both
# figures with the commented-out write_image calls in the plotting cells, and
# check that those file names match the ones read below. cv2.addWeighted needs
# both images to have the same size and number of channels.
# NOTE(review): as an unassigned final expression, Jupyter echoes this string as cell output.
'''
import cv2
gradient_img = cv2.imread('images/gradient_map.png')
subvolume_img = cv2.imread('images/subvolume_map.png')

superimposed_img = cv2.addWeighted(gradient_img, 0.35, subvolume_img, 0.65, 0)
cv2.imwrite('images/superimposed.png', superimposed_img)
'''