# Visualizing Intermediate Activations
Based on https://github.com/himanshurawlani/convnet-interpretability-keras/blob/master/Visualizing%20intermediate%20activations/visualizing_intermediate_activations.ipynb

https://towardsdatascience.com/visual-interpretability-for-convolutional-neural-networks-2453856210ce



# Sample Config file (dataset.info)
[input_data]  
dataset_name = CarbonPhantom  
subvol_type = NearestNeighbor  
col = Col6  
ground_truth = CarbonInk  
dataset_num = 1  
path_prefix = /home/mhaya2/3d-utilities/SubvolumeVisualization/Data/labeled_subvolume_sampler/  

[output]  
output_prefix = /home/mhaya2/3d-utilities/SubvolumeVisualization/Results/IntermediateActivations2/  

[saved_weights]  
weight_path = /home/mhaya2/3d-utilities/SubvolumeVisualization/SavedWeights/  
weight_file = m1252444_200000.pt  

[hook_layers]  
conv1 = y  
batch_norm1 = y  
conv2 = y  
batch_norm2 = y  
conv3 = y  
batch_norm3 = y  
conv4 = y  
batch_norm4 = y  

# Setup: Read the Configuration

In [None]:
# Read the notebook configuration (see the sample dataset.info above).
import configparser

data_info = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did parse, so fail fast here instead of hitting an obscure
# KeyError in the next cell.
if not data_info.read('dataset.info'):
    raise FileNotFoundError("Could not read config file 'dataset.info'")

In [None]:
# Input-data settings: which labeled subvolume to load.
_input_cfg = data_info['input_data']
dataset_name, subvol_type, col, ground_truth, dataset_num = (
    _input_cfg[key]
    for key in ('dataset_name', 'subvol_type', 'col', 'ground_truth', 'dataset_num')
)

# Output settings: root directory for all rendered results.
output_prefix = data_info['output']['output_prefix']

# Saved-weights settings: checkpoint directory and checkpoint file name.
_weights_cfg = data_info['saved_weights']
weight_path, weight_file = _weights_cfg['weight_path'], _weights_cfg['weight_file']

In [None]:
# Sanity check: echo every configuration value that drives this run.
for _label, _value in (
    ("dataset name: ", dataset_name),
    ("subvol_type: ", subvol_type),
    ("col: ", col),
    ("ground_truth: ", ground_truth),
    ("dataset_num: ", dataset_num),
    ("output_prefix: ", output_prefix),
    ("weight_path: ", weight_path),
    ("weight_file: ", weight_file),
):
    print(_label, _value)

In [None]:
# Run metadata: filled in by later cells and finally written to metadata.json.
metadata = dict()

In [None]:
import os

# Create the output directory:
#   <output_prefix>/<dataset>/<subvol_type>/<col>/<ground_truth>/<dataset_num>/<weight id>
# where <weight id> is the first 8 characters of the checkpoint file name.
output_subdir = f"{dataset_name}/{subvol_type}/{col}/{ground_truth}/{dataset_num}/{weight_file[:8]}"

# os.path.join works whether or not output_prefix ends with a slash
# (plain concatenation silently glues the two path components together).
output_dir = os.path.join(output_prefix, output_subdir)

metadata['output_dir'] = output_dir
print("output_dir: ", output_dir)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(output_dir, exist_ok=True)

# 3D Volume Rendering Function

In [None]:
import plotly.graph_objects as go

def _axis_ticks(size, unit, step=8):
    """Return ``(tickvals, ticktext)`` for an axis spanning ``size`` voxels.

    Ticks sit every ``step`` voxels and are labelled ``index * unit``.
    """
    vals = list(range(0, size, step))
    return vals, [str(v * unit) for v in vals]


def render_3d(cube, output_file, unit=None, tick_step=8):
    """Render a 3-D array as a plotly volume, save it as an image, and show it.

    The trace is named after the global ``dataset_name``.

    Args:
      cube: 3-D array of voxel values.
      output_file: image path passed to ``fig.write_image``.
      unit: optional scale factor. When truthy, each axis gets ticks every
        ``tick_step`` voxels labelled ``index * unit``; otherwise plotly's
        default voxel-index ticks are kept.
      tick_step: spacing, in voxels, between labelled ticks (default 8).
    """
    X, Y, Z = np.mgrid[0:cube.shape[0], 0:cube.shape[1], 0:cube.shape[2]]
    vol = go.Volume(
        name=dataset_name,
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=cube.flatten(),
        opacity=0.3,
        # NOTE(review): plotly documents opacityscale as an array or one of
        # 'min'/'max'/'extremes'/'uniform' -- confirm 0.3 is intended.
        opacityscale=0.3,
        surface_count=10,
        colorscale='rainbow',
        slices_z=dict(show=True, locations=[10]),
    )
    fig = go.Figure(data=vol)

    if unit:
        # Ticks are computed per axis so that non-cubic volumes are labelled
        # with each axis's own extent rather than the first axis's.
        scene = {}
        for axis_key, size in zip(('xaxis', 'yaxis', 'zaxis'), cube.shape):
            vals, texts = _axis_ticks(size, unit, tick_step)
            scene[axis_key] = dict(ticktext=texts, tickvals=vals)
        fig.update_layout(scene=scene)

    fig.write_image(output_file)
    fig.show()

# Load data

In [None]:
import numpy as np
from pathlib import Path
from PIL import Image
import re
import json

In [None]:
# dataset_name is a directory that contains a series of .tif files.
# Use path_prefix from dataset.info; fall back to the previously hard-coded default.
path_prefix = data_info['input_data'].get(
    'path_prefix',
    fallback="/home/mhaya2/3d-utilities/SubvolumeVisualization/Data/labeled_subvolume_sampler/",
)
data_path = f"{dataset_name}/{subvol_type}/{col}/{ground_truth}Individual/{dataset_num}"
dataset = Path(path_prefix) / data_path


def _slice_number(tif_path):
    """Sort key: the number formed by the digits in the file name (e.g. '012.tif' -> 12).

    Only the file stem is parsed, so digits in the directory path (which are
    identical for every file) cannot distort the ordering. Names without
    digits sort first, by name.
    """
    digits = re.sub(r'\D', '', tif_path.stem)
    return (int(digits) if digits else -1, tif_path.name)


files = sorted(dataset.glob('*.tif'), key=_slice_number)

metadata['dataset'] = str(dataset)

print(files)

In [None]:
# Load every slice once: keep an in-memory PIL copy for display and a float32
# array for the model input.
subvolume = []
images = []
for tif_path in files:
    # The context manager closes the file handle; copy() pulls the pixel data
    # into memory first so the image remains usable after the file is closed.
    with Image.open(tif_path) as opened:
        img = opened.copy()
    subvolume.append(np.array(img, dtype=np.float32))
    images.append(img)

# Stack the 2-D slices into one 3-D array (slice index on the first axis).
subvolume = np.array(subvolume)
print(np.shape(subvolume))

# Visualize the Input Data


In [None]:
# Directory for renderings of the raw input subvolume.
input_renderings = f'{output_dir}/InputImages'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(input_renderings, exist_ok=True)

In [None]:
import matplotlib.pyplot as plt

## See the original TIFF files

In [None]:

# Show every TIFF slice in a grid with IMGs_IN_ROW columns.
IMGs_IN_ROW = 6
# Size the grid from the number of slices instead of assuming exactly 48
# (8 x 6): fewer slices used to raise IndexError, more were silently dropped.
NUM_ROWS = max(1, -(-len(images) // IMGs_IN_ROW))  # ceiling division

# squeeze=False keeps ax_arr 2-D even for a single row.
f, ax_arr = plt.subplots(NUM_ROWS, IMGs_IN_ROW, figsize=(18, 3 * NUM_ROWS), squeeze=False)
for j, row in enumerate(ax_arr):
    for i, ax in enumerate(row):
        idx = j * IMGs_IN_ROW + i
        if idx < len(images):
            ax.imshow(images[idx])
            ax.set_title(f'image {idx}')
        else:
            ax.axis('off')  # unused cell in the last grid row

title = output_subdir
f.suptitle(title, fontsize=16)
plt.savefig(f'{input_renderings}/tiff_slices.png')
plt.show()


## See slices in all 3 directions

In [None]:
# Slice the subvolume along each axis. The grid holds NUM_ROWS rows per axis
# (x rows first, then y, then z) with IMGs_IN_ROW slices per row.
NUM_ROWS = 8
IMGs_IN_ROW = 6

f, ax_arr = plt.subplots(NUM_ROWS*3, IMGs_IN_ROW, figsize=(18,60))
for j, row in enumerate(ax_arr):
    axis, axis_row = divmod(j, NUM_ROWS)
    axis_label = 'xyz'[axis]
    for i, ax in enumerate(row):
        idx = axis_row * IMGs_IN_ROW + i
        ax.imshow(np.take(subvolume, idx, axis=axis))
        ax.set_title(f'{axis_label}-slice {idx}')

title = output_subdir
f.suptitle(title, fontsize=16)
plt.savefig(f'{input_renderings}/slices.png')
plt.show()

## Render the Subvolume in 3D

In [None]:
# 3-D volume rendering of the raw subvolume; tick labels are voxel index * 12.
render_3d(cube=subvolume, output_file=f'{input_renderings}/plotly.png', unit=12)

# Load the Inkid CNN model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Union

In [None]:
### Taken from model.py

def conv_output_shape(input_shape, kernel_size: Union[int, tuple], stride: Union[int, tuple],
                      padding: Union[int, tuple], dilation: Union[int, tuple] = 1):
    """Return the spatial output shape of an N-D convolution as a tuple of ints.

    Integer ``kernel_size``/``stride``/``padding``/``dilation`` values are
    broadcast to every dimension of ``input_shape``; tuples are read per
    dimension. Uses the output-size formula from the torch.nn.Conv3d "Shape"
    documentation:

        out = floor((in + 2*pad - dil*(k - 1) - 1) / stride + 1)
    """
    n_dims = len(input_shape)

    def per_dim(value):
        # Broadcast a scalar hyper-parameter to all dimensions.
        return (value,) * n_dims if isinstance(value, int) else value

    kernels = per_dim(kernel_size)
    strides = per_dim(stride)
    pads = per_dim(padding)
    dilations = per_dim(dilation)

    out_shape = []
    for axis in range(n_dims):
        numerator = input_shape[axis] + 2 * pads[axis] - dilations[axis] * (kernels[axis] - 1) - 1
        out_shape.append(math.floor(numerator / strides[axis] + 1))
    return tuple(out_shape)


In [None]:
### Slightly modified version of what is in model.py

class Test3DCNN(torch.nn.Module):
    """3-D CNN encoder: four blocks of Conv3d -> ReLU -> BatchNorm3d.

    A slightly modified copy of the encoder in model.py. The attribute names
    conv1..conv4 and batch_norm1..batch_norm4 form the state_dict keys used
    when loading saved weights, and are the layers this notebook hooks, so
    they should not be renamed.

    Attributes:
      output_shape: (channels, D, H, W) of the encoder output for inputs of
        ``input_shape``; used to size the decoder's linear layer.
    """

    def __init__(self, input_shape=(48, 48, 48)):
        """Build the encoder layers.

        Args:
          input_shape: spatial size (D, H, W) of one input subvolume. It only
            affects ``self.output_shape``; the default 48^3 matches the
            subvolumes visualized in this notebook.
        """
        super().__init__()

        in_channels = 1
        batch_norm_momentum = 0.9

        self.relu = torch.nn.ReLU()
        self._in_channels = in_channels
        self._batch_norm = True

        filters = [32, 16, 8, 4]
        paddings = [1, 1, 1, 1]
        kernel_sizes = [3, 3, 3, 3]
        strides = [1, 2, 2, 2]

        self.conv1 = torch.nn.Conv3d(in_channels=in_channels, out_channels=filters[0],
                                     kernel_size=kernel_sizes[0], stride=strides[0], padding=paddings[0])
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        torch.nn.init.zeros_(self.conv1.bias)
        self.batch_norm1 = torch.nn.BatchNorm3d(num_features=filters[0], momentum=batch_norm_momentum)
        shape = conv_output_shape(input_shape, kernel_sizes[0], strides[0], paddings[0])

        self.conv2 = torch.nn.Conv3d(in_channels=filters[0], out_channels=filters[1],
                                     kernel_size=kernel_sizes[1], stride=strides[1], padding=paddings[1])
        torch.nn.init.xavier_uniform_(self.conv2.weight)
        torch.nn.init.zeros_(self.conv2.bias)
        self.batch_norm2 = torch.nn.BatchNorm3d(num_features=filters[1], momentum=batch_norm_momentum)
        shape = conv_output_shape(shape, kernel_sizes[1], strides[1], paddings[1])

        self.conv3 = torch.nn.Conv3d(in_channels=filters[1], out_channels=filters[2],
                                     kernel_size=kernel_sizes[2], stride=strides[2], padding=paddings[2])
        torch.nn.init.xavier_uniform_(self.conv3.weight)
        torch.nn.init.zeros_(self.conv3.bias)
        self.batch_norm3 = torch.nn.BatchNorm3d(num_features=filters[2], momentum=batch_norm_momentum)
        shape = conv_output_shape(shape, kernel_sizes[2], strides[2], paddings[2])

        self.conv4 = torch.nn.Conv3d(in_channels=filters[2], out_channels=filters[3],
                                     kernel_size=kernel_sizes[3], stride=strides[3], padding=paddings[3])
        torch.nn.init.xavier_uniform_(self.conv4.weight)
        torch.nn.init.zeros_(self.conv4.bias)
        self.batch_norm4 = torch.nn.BatchNorm3d(num_features=filters[3], momentum=batch_norm_momentum)
        shape = conv_output_shape(shape, kernel_sizes[3], strides[3], paddings[3])
        self.output_shape = (filters[3],) + shape

    def forward(self, x):
        """Encode ``x``.

        Args:
          x: input tensor; this notebook passes shape (1, 1, D, H, W).

        Returns:
          The output of the fourth block (batch_norm4 output when batch norm
          is enabled, otherwise the ReLU of conv4).
        """
        if self._in_channels > 1:
            x = torch.squeeze(x)
        y = x
        blocks = ((self.conv1, self.batch_norm1), (self.conv2, self.batch_norm2),
                  (self.conv3, self.batch_norm3), (self.conv4, self.batch_norm4))
        for conv, batch_norm in blocks:
            y = self.relu(conv(y))
            if self._batch_norm:
                y = batch_norm(y)
        return y

class LinearInkDecoder(torch.nn.Module):
    """Flatten encoder features and map them linearly to ``output_neurons`` values.

    The result gets two trailing singleton dimensions,
    (batch, output_neurons) -> (batch, output_neurons, 1, 1), to match the
    dimensionality of the label, which is always 2-D even when its shape is (1, 1).
    """

    def __init__(self, drop_rate, input_shape, output_neurons):
        super().__init__()

        n_inputs = int(np.prod(input_shape))
        self.fc = torch.nn.Linear(n_inputs, output_neurons)
        self.dropout = torch.nn.Dropout(p=drop_rate)

        self.relu = torch.nn.ReLU()  # registered but not used in forward()
        self.flatten = torch.nn.Flatten()

    def forward(self, x):
        features = self.flatten(x)
        logits = self.dropout(self.fc(features))
        return logits.unsqueeze(2).unsqueeze(3)


# Helper Functions

In [None]:
def print_model(model):
    """Print the name and tensor size of every entry in ``model.state_dict()``.

    state_dict() builds a fresh dict on every call, so it is fetched once and
    iterated with .items() rather than rebuilt twice per entry.
    """
    for param_name, tensor in model.state_dict().items():
        print(param_name, "\t", tensor.size())

# Put Models Together

In [None]:
# Assemble the full model: the CNN encoder followed by a linear decoder with 2 outputs.
encoder = Test3DCNN()
decoder = LinearInkDecoder(drop_rate=0.6, input_shape=encoder.output_shape, output_neurons=2)
model = torch.nn.Sequential(encoder, decoder)

model_summary = str(model)
metadata['model'] = model_summary
print(model_summary)

# Load the weights

In [None]:
# Load the trained weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
checkpoint = torch.load(os.path.join(weight_path, weight_file), map_location=torch.device('cpu'))

# strict=False tolerates key mismatches, but a mismatch would otherwise leave
# layers at their random initialisation without any warning, so report it.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("WARNING: checkpoint/model key mismatch")
    print("  missing keys:   ", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)
metadata['missing_keys'] = list(load_result.missing_keys)
metadata['unexpected_keys'] = list(load_result.unexpected_keys)

metadata['saved_weights'] = weight_file
# Evaluation mode: batch norm uses its running statistics and dropout is disabled.
model.eval()

# Sanity check
print_model(model)

# Hook function for recording activations



In [None]:
# Layer name -> most recent output recorded by a forward hook.
activation = {}


def get_activation(name):
    """Build a forward hook that stores its module's output under ``name``.

    The output is detached from the autograd graph before being stored in
    the module-level ``activation`` dict.
    """
    def _record(module, inputs, output):
        activation[name] = output.detach()

    return _record

# Create desired hooks

In [None]:
# Register a forward hook on every encoder layer enabled ('y') in [hook_layers].
# Each hook stores its layer's output in `activation` under the layer name.
# The module is looked up by that same name via getattr, so an activation key
# cannot end up attached to a different layer (batch_norm3 was previously
# hooked on conv3 by mistake, duplicating conv3's output).
_hook_cfg = data_info['hook_layers']
_encoder = model[0]

if _hook_cfg['conv1'] == 'y':
    h1c_name = 'conv1'
    h1c = getattr(_encoder, h1c_name).register_forward_hook(get_activation(h1c_name))

if _hook_cfg['batch_norm1'] == 'y':
    h1b_name = 'batch_norm1'
    h1b = getattr(_encoder, h1b_name).register_forward_hook(get_activation(h1b_name))

if _hook_cfg['conv2'] == 'y':
    h2c_name = 'conv2'
    h2c = getattr(_encoder, h2c_name).register_forward_hook(get_activation(h2c_name))

if _hook_cfg['batch_norm2'] == 'y':
    h2b_name = 'batch_norm2'
    h2b = getattr(_encoder, h2b_name).register_forward_hook(get_activation(h2b_name))

if _hook_cfg['conv3'] == 'y':
    h3c_name = 'conv3'
    h3c = getattr(_encoder, h3c_name).register_forward_hook(get_activation(h3c_name))

if _hook_cfg['batch_norm3'] == 'y':
    h3b_name = 'batch_norm3'
    h3b = getattr(_encoder, h3b_name).register_forward_hook(get_activation(h3b_name))

if _hook_cfg['conv4'] == 'y':
    h4c_name = 'conv4'
    h4c = getattr(_encoder, h4c_name).register_forward_hook(get_activation(h4c_name))

if _hook_cfg['batch_norm4'] == 'y':
    h4b_name = 'batch_norm4'
    h4b = getattr(_encoder, h4b_name).register_forward_hook(get_activation(h4b_name))



# Prepare the Model Input

In [None]:
# Model input: prepend batch and channel axes -> (1, 1, D, H, W), then wrap
# the array as a torch tensor (torch.from_numpy shares the numpy buffer).
subvolume = np.expand_dims(np.array(subvolume), axis=(0, 1))
print("final subvolume input shape:", subvolume.shape)

subvolume = torch.from_numpy(subvolume)

# Run the Model Forward

In [None]:
# Forward pass. Only the outputs and the hooked activations (already detached)
# are needed, so skip building the autograd graph.
with torch.no_grad():
    output = model(subvolume)

metadata['prediction_output'] = str(output)

print(output)

# output is (1, 2, 1, 1) (batch, output neurons, 1, 1); argmax over dim 1
# gives the index of the larger of the two outputs.
prediction = output.argmax(dim=1).item()
metadata['prediction'] = str(prediction)

In [None]:
# Detach the forward hooks so later forward passes do not overwrite `activation`.
# A handle exists only if its layer was enabled in [hook_layers], so look each
# one up instead of assuming all eight were registered (which raised NameError).
for _handle_name in ('h1c', 'h1b', 'h2c', 'h2b', 'h3c', 'h3b', 'h4c', 'h4b'):
    _handle = globals().get(_handle_name)
    if _handle is not None:
        _handle.remove()


In [None]:
def _as_numpy(layer_name, label):
    """Fetch one recorded activation as a numpy array and print its shape."""
    array = activation[layer_name].numpy()
    print(label + ": ", array.shape)
    return array


conv1_activations = _as_numpy(h1c_name, "conv1_activations")
bnorm1_activations = _as_numpy(h1b_name, "bnorm1_activations")

conv2_activations = _as_numpy(h2c_name, "conv2_activations")
bnorm2_activations = _as_numpy(h2b_name, "bnorm2_activations")

conv3_activations = _as_numpy(h3c_name, "conv3_activations")
bnorm3_activations = _as_numpy(h3b_name, "bnorm3_activations")

conv4_activations = _as_numpy(h4c_name, "conv4_activations")
bnorm4_activations = _as_numpy(h4b_name, "bnorm4_activations")

del _as_numpy

# Visualize the Activations on Plotly

In [None]:
# Directory for the per-layer activation renderings.
activations_path = f'{output_dir}/IntermediateActivations'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(activations_path, exist_ok=True)

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
import os


def plot3d_layer_activations(activations, name, colorscale='viridis', dirname=None):
    """Save one plotly volume rendering per filter of a layer's activations.

    All filters share one colour range (the layer-wide min/max), which is
    also printed and recorded in the global ``metadata[name]``. Images are
    written to ``<dirname>/<name>/<name>_<filter index>.png``.

    Args:
      activations: array indexed as ``activations[0, filter, i, j, k]`` --
        presumably (1, n_filters, D, H, W) as recorded by the forward hooks.
      name: sub-directory name and file-name prefix for the images.
      colorscale: plotly colorscale name.
      dirname: parent directory for the images; the current directory when
        None (previously None crashed on ``None + '/'``).
    """
    # TODO: Add the dataset name etc.
    dpath = os.path.join(dirname if dirname is not None else '.', name)
    os.makedirs(dpath, exist_ok=True)

    min_val = np.min(activations)
    max_val = np.max(activations)

    print("layer: ", name)
    print("Max value: ", max_val)
    print("Min value: ", min_val)

    metadata[name] = {"max": str(max_val), "min": str(min_val)}
    n_filters = activations.shape[1]

    # Every filter of a layer has the same spatial grid, so build it once.
    X, Y, Z = np.mgrid[0:activations.shape[2], 0:activations.shape[3], 0:activations.shape[4]]
    x_coords, y_coords, z_coords = X.flatten(), Y.flatten(), Z.flatten()

    for i in range(n_filters):
        signal = activations[0, i, :, :, :]  # 3-D volume of filter i

        fig = go.Figure(data=go.Volume(
            name=dataset_name,
            x=x_coords,
            y=y_coords,
            z=z_coords,
            value=signal.flatten(),
            cmin=min_val,
            cmax=max_val,
            opacity=0.3,
            opacityscale=0.3,
            surface_count=8,
            colorscale=colorscale,
        ))

        fig.update_layout(scene=dict(
            xaxis=dict(showticklabels=False),
            yaxis=dict(showticklabels=False),
            zaxis=dict(showticklabels=False)))

        filename = name + '_' + str(i) + '.png'
        out_path = os.path.join(dpath, filename)
        print("writing in : ", out_path)

        fig.write_image(out_path)

  

In [None]:

# 3-D renderings of every conv1 filter.
plot3d_layer_activations(activations=conv1_activations, name=f"{h1c_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every conv2 filter.
plot3d_layer_activations(activations=conv2_activations, name=f"{h2c_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every conv3 filter.
plot3d_layer_activations(activations=conv3_activations, name=f"{h3c_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every conv4 filter.
plot3d_layer_activations(activations=conv4_activations, name=f"{h4c_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every batch_norm1 channel.
plot3d_layer_activations(activations=bnorm1_activations, name=f"{h1b_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every batch_norm2 channel.
plot3d_layer_activations(activations=bnorm2_activations, name=f"{h2b_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every batch_norm3 channel.
plot3d_layer_activations(activations=bnorm3_activations, name=f"{h3b_name}_activations", colorscale='viridis', dirname=activations_path)

In [None]:
# 3-D renderings of every batch_norm4 channel.
plot3d_layer_activations(activations=bnorm4_activations, name=f"{h4b_name}_activations", colorscale='viridis', dirname=activations_path)

## Slices of the Last Layer's Activations

In [None]:
# Display the shape of the conv4 activations.
np.shape(conv4_activations)

In [None]:
def print_slices(layer_activations, name, output_dir):
    """Save and show a figure of 2-D slices for every filter of a layer.

    For each filter, figure row 0 shows slices along the first spatial axis
    ("x"), row 1 along the second ("y") and row 2 along the third ("z"). Each
    axis shows as many slices as it has; when axes differ in length the
    surplus cells of shorter axes are left blank. Figures are titled with the
    global ``title`` string and written to
    ``<output_dir>/<name>/filter_<n>_slices.png``.

    Args:
      layer_activations: array indexed as ``[0, filter, x, y, z]`` --
        presumably (1, n_filters, D, H, W) as recorded by the forward hooks.
      name: layer label used in the title and as the output sub-directory.
      output_dir: parent directory for the images.
    """
    save_dir = os.path.join(output_dir, name)
    # Create the sub-directory so this works even if it was never rendered to.
    os.makedirs(save_dir, exist_ok=True)

    n_filters = layer_activations.shape[1]
    spatial_shape = layer_activations.shape[2:5]
    n_cols = max(spatial_shape)

    for n in range(n_filters):
        filter_data = layer_activations[0, n, :, :, :]

        # squeeze=False keeps ax_arr 2-D even when there is a single column.
        f, ax_arr = plt.subplots(3, n_cols, figsize=(18, 12), squeeze=False)
        for axis, (axis_label, row) in enumerate(zip('xyz', ax_arr)):
            for i, ax in enumerate(row):
                if i < spatial_shape[axis]:
                    ax.imshow(np.take(filter_data, i, axis=axis))
                    ax.set_title(f'{axis_label}-slices {i}')
                else:
                    ax.axis('off')

        slices_title = f'{title}_{name}_filter{str(n)}'
        f.suptitle(slices_title, fontsize=16)

        plt.savefig(os.path.join(save_dir, f'filter_{str(n)}_slices.png'))
        plt.show()
        # Release the figure so one figure per filter does not accumulate.
        plt.close(f)


In [None]:
# Per-filter slice grids for the conv4 activations.
print_slices(layer_activations=conv4_activations, name=f'{h4c_name}_activations', output_dir=activations_path)

In [None]:
# Per-filter slice grids for the batch_norm4 activations.
print_slices(layer_activations=bnorm4_activations, name=f'{h4b_name}_activations', output_dir=activations_path)

In [None]:
# Persist the collected run metadata next to the renderings.
Path(f"{output_dir}/metadata.json").write_text(json.dumps(metadata, indent=2))

In [None]:
# Display the collected run metadata as the cell output.
metadata