-
Notifications
You must be signed in to change notification settings - Fork 0
/
ParallelNets_visualize.py
94 lines (75 loc) · 3.07 KB
/
ParallelNets_visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import sys
import torch
import matplotlib.pyplot as plt
from src.dataprocessing import preprocess
from src.visualization import setup, seggradcam
# Settings
# Positional command-line arguments: MODEL_NAME EXPERIMENT SIDE.
# Fail with a usage message (stderr, exit status 1) instead of a bare IndexError.
if len(sys.argv) < 4:
    sys.exit(f'Usage: python {sys.argv[0]} MODEL_NAME EXPERIMENT SIDE')
MODEL_NAME = sys.argv[1]
EXPERIMENT = sys.argv[2]
SIDE = sys.argv[3]
# Root folder of the trained models (passed to setup.set_model below).
MODEL_PATH = 'models'
# Output folder for the Grad-CAM plots: plots/<MODEL_NAME>/<EXPERIMENT>/<SIDE>/gradcam
PLOT_PATH = os.path.join('plots', MODEL_NAME, EXPERIMENT, SIDE, 'gradcam')
# Layer names handed to setup.set_visu_layers; used as `feature_modules` for the
# layer-wise Grad-CAM.
VISU_LAYERS = ['down1', 'down2', 'down3', 'down4', 'base', 'up1', 'up2', 'up3', 'up4']
# NOTE: this rebinds the name `setup` from the imported module to a Setup instance;
# everything after this line uses the instance.
setup = setup.Setup(data_path=os.path.join('data', EXPERIMENT, 'raw'),
                    experiment=EXPERIMENT,
                    side=SIDE)
# setup.set_stages([7]) # uncomment to only plot specific examples
setup.set_model(model_path=MODEL_PATH, model_name=MODEL_NAME)
setup.set_output_path(PLOT_PATH)
setup.set_visu_layers(VISU_LAYERS)
# Load the model
print('Loading model...')
model = seggradcam.ParallelNetsWithHooks()
# Checkpoint location: <model_path>/<model_name>/<model_name>.pt
model_path = os.path.join(setup.model_path, setup.model_name, setup.model_name + '.pt')
# map_location='cpu' lets a checkpoint that was saved with CUDA tensors load on a
# CPU-only machine; load_state_dict copies the values into the freshly constructed
# model's parameters (presumably on CPU) either way.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Only the `unet` sub-module is visualized from here on.
model = model.unet
# NOTE(review): the model is not switched to eval mode in this script; confirm that
# SegGradCAM does so, otherwise dropout/batch-norm layers run in training mode.
# Load Data
print('Loading data...')
# `inputs` is iterated as a mapping (key -> input tensor) below; `targets` is unused
# in this script.
inputs, targets = setup.load_data()
################################################################################################
# Segmentation Grad-CAM: overall network attention
sgc = seggradcam.SegGradCAM(setup, model)
print('\nPlotting Segmentation-Grad-CAM...')
# iterate over nodemap input_t samples
for key, input_t in inputs.items():
    # Progress indicator: overwrite the current console line with the sample key.
    # flush=True is needed because no newline is written, so the buffer would
    # otherwise not be flushed and the indicator would not update.
    print(f'\r{key}', end='', flush=True)
    # normalize, then forward pass returning the prediction and its heatmap
    input_t = preprocess.normalize(input_t)
    output, heatmap = sgc(input_t)
    # plot and save heatmap
    stage_num = setup.nodemaps_to_stages[key]
    fig = sgc.plot(output, heatmap, stage_num)
    sgc.save(key, fig, subfolder='network')
    # Release the figure so open figures do not pile up across iterations
    # (closing a figure that sgc.save may already have closed is a no-op).
    plt.close(fig)
################################################################################################
# Segmentation Grad-CAM: layer-wise attention
# One SegGradCAM per visualized layer, keyed by layer name.
seg_grad_cams = {name: seggradcam.SegGradCAM(setup, model, feature_modules=name)
                 for name in setup.visu_layers}
print('\nPlotting Segmentation-Grad-CAM...')
# Loop-invariant values: output folder (created once, race-free) and specimen name.
save_folder = os.path.join(setup.output_path, 'layers')
os.makedirs(save_folder, exist_ok=True)
specimen = setup.experiment
# iterate over nodemap input_t samples
for key, input_t in inputs.items():
    # progress indicator; flushed because no newline is written
    print(f'\r{key}', end='', flush=True)
    input_t = preprocess.normalize(input_t)
    # plain forward pass for the prediction shown in the overview
    output = model(input_t)
    # calculate heatmaps: each SegGradCAM call returns (output, heatmap); keep the heatmap
    heatmaps = {}
    for name, seg_grad_cam in seg_grad_cams.items():
        _, heatmap = seg_grad_cam(input_t)
        heatmaps[name] = heatmap
    # plot heatmap
    stage_num = setup.nodemaps_to_stages[key]
    plot_title = f'Specimen: {specimen} - Side: {setup.side} - Image: {stage_num}'
    fig = seggradcam.plot_overview(output=output,
                                   maps=heatmaps,
                                   side=setup.side,
                                   title=plot_title,
                                   scale='QUALITATIVE')
    # Save the figure that was just built (not whatever pyplot considers "current")
    # and release it explicitly.
    fig.savefig(os.path.join(save_folder, f'{stage_num:04d}.png'), dpi=100)
    plt.close(fig)
# terminate the '\r' progress line
print()