11.11.2022

# Test Grad-CAM on UNet model

Trying out Grad-CAM on a saved UNet model, using movie 34 (where the end of the wave is detected as a puff).

In [27]:
import configparser
import logging
import os
import sys

import numpy as np
import torch
from architectures import TempRedUNet
from datasets import SparkDataset
from in_out_tools import write_videos_on_disk
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from training_inference_tools import run_samples_in_model, sampler
from visualization_tools import get_discrete_cmap

from medcam import medcam
import unet
import napari

### Set parameters

In [2]:
# Configuration file describing the training run, and the run's name
# (the name is also used below to locate the saved model and the output folder).
config_file = 'config_temp_new_annotated_peaks_physio.ini'
training_name = 'TEMP_new_annotated_peaks_physio'

print(f"Processing training '{training_name}'...")

Processing training 'TEMP_new_annotated_peaks_physio'...


### Configure output folder

In [3]:
# Root folder shared by the predictions of the training and test sets
output_folder = "trainings_validation"
os.makedirs(output_folder, exist_ok=True)

# Sub-folder of `output_folder` for this model; pick a different name to keep
# the results of several inference approaches for the same model side by side
output_name = training_name

# Grad-CAM results go in their own "gradCAM" sub-directory
save_folder = os.path.join(
    output_folder,
    output_name,
    "gradCAM",
)
os.makedirs(save_folder, exist_ok=True)

print(f"Output files will be saved on '{save_folder}'")

Output files will be saved on 'trainings_validation\TEMP_new_annotated_peaks_physio\gradCAM'


### Detect GPU, if available

In [4]:
# Count the visible GPUs and run on CUDA whenever it is available, else on CPU
n_gpus = torch.cuda.device_count()
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device '{device}' with {n_gpus} GPUs")

Using device 'cuda' with 1 GPUs


### Load config file

In [5]:
# Location of the .ini file describing the training run
config_folder = "config_files"
CONFIG_FILE = os.path.join(config_folder, config_file)
c = configparser.ConfigParser()

print(f"Loading {CONFIG_FILE}")
# `ConfigParser.read` silently skips files it cannot open and returns the list
# of files that were actually parsed, so fail right here on a wrong path
# instead of getting a NoSectionError in a later cell.
loaded_config_files = c.read(CONFIG_FILE)
assert loaded_config_files, f"Config file \"{CONFIG_FILE}\" could not be read"
loaded_config_files

Loading config_files\config_temp_new_annotated_peaks_physio.ini


['config_files\\config_temp_new_annotated_peaks_physio.ini']

### Config UNet model

In [6]:
### Params ###
# Epoch of the saved checkpoint to load
load_epoch = c.getint("testing", "load_epoch")

# NOTE: the fallback must already be an int; configparser returns the fallback
# as-is (no conversion), so `fallback="1"` would yield the string "1".
batch_size = c.getint("testing", "batch_size", fallback=1)
ignore_frames = c.getint("training", "ignore_frames_loss")

# The number of input channels is read from the config only when temporal
# reduction is enabled; otherwise the network takes a single input channel.
temporal_reduction = c.getboolean("network", "temporal_reduction", fallback=False)
num_channels = c.getint("network", "num_channels", fallback=1) if temporal_reduction else 1

In [7]:
### Configure UNet ###

# Translate the config's batch-normalization keyword into the boolean flag
# handed to UNetConfig
batch_norm = {'batch': True, 'none': False}

unet_config = unet.UNetConfig(
    steps=c.getint("network", "unet_steps"),
    first_layer_channels=c.getint("network", "first_layer_channels"),
    num_classes=4,
    ndims=3,
    dilation=c.getint("network", "dilation", fallback=1),
    border_mode=c.get("network", "border_mode"),
    batch_normalization=batch_norm[c.get("network", "batch_normalization")],
    num_input_channels=num_channels,
)

# Pick the architecture: the temporal-reduction variant requires the chunk
# duration to be divisible by the number of channels.
if temporal_reduction:
    assert c.getint("dataset", "data_duration") % num_channels == 0, (
        "using temporal reduction chunks_duration must be a multiple of num_channels"
    )
    network = TempRedUNet(unet_config)
else:
    network = unet.UNetClassifier(unet_config)

# Wrap in DataParallel (sub-module names therefore start with "module.") and
# move the model to the selected device
network = nn.DataParallel(network).to(device)

In [8]:
### Load UNet model ###
# Folder holding the saved checkpoints of the selected training run
models_relative_path = 'runs/'
model_path = os.path.join(models_relative_path, training_name)

# NOTE: `purge_step` is deliberately not set. This notebook only loads a
# checkpoint; opening a writer with `purge_step=0` on the training's own
# "summary" folder would flag every event with global_step >= 0 (i.e. the whole
# training history) as purged and hide it from TensorBoard.
summary_writer = SummaryWriter(os.path.join(model_path, "summary"))

# The manager is only used to restore the network weights, hence no training step
trainer = unet.TrainingManager(
    training_step=None,
    save_path=model_path,
    managed_objects=unet.managed_objects({'network': network}),
    summary_writer=summary_writer,
)

print(f"Loading trained model '{training_name}' at epoch {load_epoch}...")
trainer.load(load_epoch)

Loading trained model 'TEMP_new_annotated_peaks_physio' at epoch 100000...


In [9]:
# Put the network in evaluation mode; the trailing ';' keeps the notebook from
# echoing the returned module.
network.eval();

#### Print summary of network architecture

In [46]:
# Print the layer-by-layer summary for a single-channel 256x64x512 input.
# The trailing ';' stops the notebook from echoing the returned object, which
# would otherwise display the same table a second time.
summary(network, (1,256,64,512));

Layer (type:depth-idx)                   Output Shape              Param #
├─UNetClassifier: 1-1                    [-1, 4, 256, 64, 512]     --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-1               [-1, 8, 256, 64, 512]     1,960
|    └─MaxPool3d: 2-1                    [-1, 8, 128, 32, 256]     --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-2               [-1, 16, 128, 32, 256]    10,400
|    └─MaxPool3d: 2-2                    [-1, 16, 64, 16, 128]     --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-3               [-1, 32, 64, 16, 128]     41,536
|    └─MaxPool3d: 2-3                    [-1, 32, 32, 8, 64]       --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-4               [-1, 64, 32, 8, 64]       166,016
|    └─MaxPool3d: 2-4                    [-1, 64, 16, 4, 32]       --

Layer (type:depth-idx)                   Output Shape              Param #
├─UNetClassifier: 1-1                    [-1, 4, 256, 64, 512]     --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-1               [-1, 8, 256, 64, 512]     1,960
|    └─MaxPool3d: 2-1                    [-1, 8, 128, 32, 256]     --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-2               [-1, 16, 128, 32, 256]    10,400
|    └─MaxPool3d: 2-2                    [-1, 16, 64, 16, 128]     --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-3               [-1, 32, 64, 16, 128]     41,536
|    └─MaxPool3d: 2-3                    [-1, 32, 32, 8, 64]       --
|    └─Sequential: 2                     []                        --
|    |    └─UNetLayer: 3-4               [-1, 64, 32, 8, 64]       166,016
|    └─MaxPool3d: 2-4                    [-1, 64, 16, 4, 32]       --

In [53]:
# List every 3D convolution together with its qualified module name; these
# names are what a Grad-CAM target layer is referred to by.
for name, layer in network.named_modules():
    if not isinstance(layer, torch.nn.Conv3d):
        continue
    print(name, layer)

module.down_path.0.layers.0 Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.0.layers.3 Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.1.layers.0 Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.1.layers.3 Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.2.layers.0 Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.2.layers.3 Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.3.layers.0 Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.3.layers.3 Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.4.layers.0 Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
module.down_path.4.layers.3 Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1

### Load input sample

In [11]:
### Configure sample input ###

# Resolve the dataset location given in the config file to an absolute path
# and make sure it exists before going any further
dataset_path = os.path.realpath(c.get("dataset", "relative_path"))
assert os.path.isdir(dataset_path), f"\"{dataset_path}\" is not a directory"

# ID(s) of the movie(s) to run through the network
sample_ids = ["34"]

print(f"Using {dataset_path} as dataset root path")


Using C:\Users\dotti\sparks_project\data\sparks_dataset as dataset root path


In [12]:
### Configure inference method and parameters ###

# All three values come from the [testing] section of the config file.
# `data_step` and `data_duration` are parsed as integers and are later passed
# to SparkDataset as `step` / `duration` and used to compute the chunk overlap;
# `inference` is kept as a string and forwarded to SparkDataset unchanged.
data_step = c.getint("testing", "data_step")
data_duration = c.getint("testing", "data_duration")
inference = c.get("testing", "inference")

In [13]:
# Build the dataset that splits the selected movie into chunks of
# `data_duration` frames (step and duration as configured for testing).
testing_dataset = SparkDataset(
    base_path=dataset_path,
    sample_ids=sample_ids,
    testing=False, # we just do inference, without metrics computation
    smoothing=c.get("dataset", "data_smoothing"),
    step=data_step,
    duration=data_duration,
    remove_background=c.get("dataset", "remove_background"),
    # reuse the value already parsed in the params cell instead of re-reading it
    temporal_reduction=temporal_reduction,
    num_channels=num_channels,
    normalize_video=c.get("dataset", "norm_video"),
    only_sparks=c.getboolean("dataset", "only_sparks", fallback=False),
    sparks_type=c.get("dataset", "sparks_type"),
    # pass the integer parsed with `getint` in the params cell; `c.get` would
    # hand over the raw string from the .ini file
    ignore_frames=ignore_frames,
    ignore_index=4,
    gt_available=True,
    inference=inference,
)

print(f"\tTesting dataset of movie {testing_dataset.video_name} "
      f"\tcontains {len(testing_dataset)} samples.")

TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'


	Testing dataset of movie 34 	contains 22 samples.


In [14]:
# Serve the chunks one at a time and in their original order (shuffle=False).
# The batch size is fixed to 1 rather than taken from the config because the
# re-assembly loop indexes the single element of each batch (`x[0, ...]`).
data_loader = DataLoader(testing_dataset, batch_size=1, shuffle=False)

### Configure Grad-CAM

In [61]:
# Options handed to medcam when wrapping the trained network.
gradcam_options = dict(
    label=3,                    # class index to explain - presumably the puff class, TODO confirm
    replace=True,
    #backend="gcam",
    layer='module.final_layer', # "module." prefix comes from the DataParallel wrapper
    output_dir=save_folder,
    save_maps=True,
)
cam_network = medcam.inject(network, **gradcam_options)

### Run the sample's chunks through the Grad-CAM-wrapped network and re-assemble its output

In [62]:
# Number of items in the dataset, i.e. of chunks to run through the network
n_chunks = len(testing_dataset)
# Half of (chunk duration - step); assuming consecutive chunks start `data_step`
# frames apart, this is half of the frames they share - TODO confirm against SparkDataset
half_overlap = (data_duration-data_step)//2

In [68]:
# Run every chunk through the Grad-CAM-wrapped network and stitch the chunks
# back together, dropping half of the overlapping frames on each side.
out_chunks = []
x_chunks = []
for i, (x, _) in enumerate(data_loader):
    # define start and end of used frames in chunks:
    # the first chunk keeps its leading frames, the last one its trailing frames
    start = 0 if i == 0 else half_overlap
    # `-half_overlap` is 0 when the chunks do not overlap and `[start:0]` would
    # select no frame at all, so use `None` (keep up to the end) in that case too
    end = None if (i + 1 == n_chunks or half_overlap == 0) else -half_overlap

    x_chunks.append(x[0, start:end])

    x = x.to(device)
    # prepend one dimension to the batch before the forward pass, then strip the
    # two leading dimensions of the result
    out = cam_network(x[None, :])[0, 0]
    # detach so that the final `.numpy()` also works if the map tracks gradients
    out_chunks.append(out[start:end].detach().cpu())

x_concat = torch.cat(x_chunks, dim=0).numpy()
out_concat = torch.cat(out_chunks, dim=0).numpy()

In [69]:
# Sanity check: shape of the re-assembled input movie
x_concat.shape

(928, 64, 512)

### Visualise result with Napari

In [28]:
# Configure Napari cmap
# (presumably a 'gray' colormap discretised into 16 levels - see get_discrete_cmap)
cmap = get_discrete_cmap(name='gray', lut=16)

In [70]:
# Open a Napari viewer and show the re-assembled input movie and the
# re-assembled network result as two image layers sharing the same colormap.
viewer = napari.Viewer()

shared_layer_options = dict(colormap=('colors', cmap))

viewer.add_image(x_concat, name='input movie', **shared_layer_options)

viewer.add_image(out_concat, name='network output', **shared_layer_options)


<Image layer 'network output' at 0x14486fb6130>