# Evaluate model on selected case

## Import modules

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

from PIL import Image
import ipywidgets as widgets
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from monai.data import list_data_collate
from pytorch_lightning import Trainer
from captum.attr import IntegratedGradients, NoiseTunnel
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm

from lung_cancer_detection.data.nodule import ClassificationDataModule
from lung_cancer_detection.models.classification import NoduleClassificationModule
from lung_cancer_detection.utils import preview_dataset, load_json, preview_explanations

## Setup data module

In [3]:
rp = Path("/Volumes/LaCie/data/lung-cancer-detection/lidc-idri/").absolute()
dp = rp/"processed"
cp = (Path()/"../data/cache/").absolute()
sp = rp/"splits"
cp.mkdir(exist_ok=True)
mp = (Path()/"../models/classification/nodule_classification_baseline.ckpt").absolute()
nod_path = Path("/Users/felix/Downloads/nodules/").absolute()
nod_path.mkdir(exist_ok=True, parents=True)
print(dp.exists(), cp.exists(), sp.exists(), mp.exists(), nod_path.exists())

True True True True True


In [4]:
splits = [load_json(f) for f in sp.iterdir()]

In [5]:
dm = ClassificationDataModule(dp, 
                              cp, 
                              splits, 
                              min_anns=1, 
                              exclude_labels=[], 
                              label_mapping=([1,2,3,4,5], [0,0,0,1,1]))

In [6]:
dm.setup()

## Show case nodules

### Load nodules

In [7]:
case = "LIDC-IDRI-0186"

In [8]:
nods = dm.query_by_case(case)

In [9]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_dataset(nods, z=z)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

### Save nodules as series of 2D images

In [10]:
def save_nodule_slices(nodule, target_dir, target_size=(512,512)):
    for z in range(nodule.shape[2]):
        arr = nodule[:,:,z]
        img = Image.fromarray(np.uint8(cm.gray(arr)*255))
        img = img.resize(target_size)
        img.save(target_dir/f"slice_{z:02d}.png")

#### Nodule 1

In [13]:
nod1 = nods[0]["image"].squeeze().numpy()
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def show_img(z):
    plt.imshow(nod1[:,:,z], cmap="gray")
    
out = widgets.interactive_output(show_img, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

In [14]:
nod1_dir = nod_path/"nodule_1"
nod1_dir.mkdir(exist_ok=True)

In [15]:
save_nodule_slices(nod1, nod1_dir)

#### Nodule 2

In [16]:
nod = nods[1]["image"].squeeze().numpy()
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def show_img(z):
    plt.imshow(nod[:,:,z], cmap="gray")
    
out = widgets.interactive_output(show_img, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

In [17]:
nod_dir = nod_path/"nodule_2"
nod_dir.mkdir(exist_ok=True)
save_nodule_slices(nod, nod_dir)

#### Nodule 3

In [18]:
nod = nods[2]["image"].squeeze().numpy()
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def show_img(z):
    plt.imshow(nod[:,:,z], cmap="gray")
    
out = widgets.interactive_output(show_img, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

In [19]:
nod_dir = nod_path/"nodule_3"
nod_dir.mkdir(exist_ok=True)
save_nodule_slices(nod, nod_dir)

#### Nodule 4

In [20]:
nod = nods[3]["image"].squeeze().numpy()
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def show_img(z):
    plt.imshow(nod[:,:,z], cmap="gray")
    
out = widgets.interactive_output(show_img, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

Slice 20 should be the center slice. Thus, we remove slices 0-11.

In [21]:
print(nod.shape)
nod = nod[:,:,12:]
print(nod.shape)

(40, 40, 30)
(40, 40, 18)


In [22]:
nod_dir = nod_path/"nodule_4"
nod_dir.mkdir(exist_ok=True)
save_nodule_slices(nod, nod_dir)

#### Nodule 5

In [23]:
nod = nods[4]["image"].squeeze().numpy()
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def show_img(z):
    plt.imshow(nod[:,:,z], cmap="gray")
    
out = widgets.interactive_output(show_img, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

Slice 21 should be the center slice. Thus, we remove slices 0-13.

In [24]:
print(nod.shape)
nod = nod[:,:,14:]
print(nod.shape)

(40, 40, 30)
(40, 40, 16)


In [25]:
nod_dir = nod_path/"nodule_5"
nod_dir.mkdir(exist_ok=True)
save_nodule_slices(nod, nod_dir)

### Create data loader

In [26]:
data = DataLoader(nods, batch_size=len(nods), shuffle=False, num_workers=4, collate_fn=list_data_collate)

## Evaluate model predictions

In [27]:
model = NoduleClassificationModule.load_from_checkpoint(mp)

In [28]:
trainer = Trainer()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [29]:
preds = trainer.predict(model, dataloaders=data)
print(type(preds))

Predicting: 0it [00:00, ?it/s]

<class 'list'>


In [30]:
preds = preds[0].numpy()[:,1]

In [31]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_dataset(nods, z=z, preds=preds)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

## Explain model predictions

### Get predictions for batch

In [32]:
batch = None
for idx, sample in enumerate(data):
    batch = sample

In [33]:
x = batch["image"]

In [34]:
F.softmax(model.forward(x), dim=-1)

tensor([[0.8745, 0.1255],
        [0.9346, 0.0654],
        [0.7460, 0.2540],
        [0.7485, 0.2515],
        [0.1518, 0.8482]], grad_fn=<SoftmaxBackward>)

### Get raw attributions using Integrated Gradients

In [35]:
ig = IntegratedGradients(model)

In [36]:
baseline = torch.zeros(x.shape)

In [37]:
raw_attrs, _ = ig.attribute(x, baseline, target=1, return_convergence_delta=True)

In [38]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_explanations(x, raw_attrs, z=z)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

### Get smoothed attributions using NoiseTunnel wrapper

In [39]:
nig = NoiseTunnel(ig)

In [40]:
smooth_attrs = nig.attribute(x, nt_samples=5, target=1)

In [41]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_explanations(x, smooth_attrs, z=z)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

### Save heatmaps as series of 2D images

In [69]:
def preprocess_attribution(attribution):
    arr = F.relu(attribution).mean(dim=0).detach()
    arr /= arr.quantile(0.98)
    arr = torch.clamp(arr, 0, 1).numpy()
    return arr

def save_heatmap_slices(nodule, attribution, target_dir, target_size=(512,512)):
    for z in range(nodule.shape[2]):
        arr, attr = nodule[:,:,z], attribution[:,:,z]
        arr = arr * attr
        cm = plt.get_cmap("copper")
        arr = cm(arr)
        img = Image.fromarray((arr[:,:,:3]*255).astype(np.uint8))
        img = img.resize(target_size)
        img.save(target_dir/f"slice_{z:02d}.png")

#### Nodule 1

In [73]:
attr = raw_attrs[0]
nod = nods[0]["image"].squeeze().numpy()
attr = preprocess_attribution(attr)
target_dir = nod_path/"heatmap_1"
target_dir.mkdir(exist_ok=True)
save_heatmap_slices(nod, attr, target_dir)

#### Nodule 2

In [74]:
attr = raw_attrs[1]
nod = nods[1]["image"].squeeze().numpy()
attr = preprocess_attribution(attr)
target_dir = nod_path/"heatmap_2"
target_dir.mkdir(exist_ok=True)
save_heatmap_slices(nod, attr, target_dir)

#### Nodule 3

In [75]:
attr = raw_attrs[2]
nod = nods[2]["image"].squeeze().numpy()
attr = preprocess_attribution(attr)
target_dir = nod_path/"heatmap_3"
target_dir.mkdir(exist_ok=True)
save_heatmap_slices(nod, attr, target_dir)

#### Nodule 4

In [76]:
attr = raw_attrs[3]
nod = nods[3]["image"].squeeze().numpy()
attr = preprocess_attribution(attr)

nod = nod[:,:,12:]
attr = attr[:,:,12:]

target_dir = nod_path/"heatmap_4"
target_dir.mkdir(exist_ok=True)
save_heatmap_slices(nod, attr, target_dir)

#### Nodule 5

In [77]:
attr = raw_attrs[4]
nod = nods[4]["image"].squeeze().numpy()
attr = preprocess_attribution(attr)

nod = nod[:,:,14:]
attr = attr[:,:,14:]

target_dir = nod_path/"heatmap_5"
target_dir.mkdir(exist_ok=True)
save_heatmap_slices(nod, attr, target_dir)