# Evaluate model on selected case

## Import modules

In [1]:
%load_ext autoreload
%autoreload 2

In [102]:
from pathlib import Path

import ipywidgets as widgets
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from monai.data import list_data_collate
from pytorch_lightning import Trainer
from captum.attr import IntegratedGradients
import matplotlib.pyplot as plt

from lung_cancer_detection.data.nodule import ClassificationDataModule
from lung_cancer_detection.models.classification import NoduleClassificationModule
from lung_cancer_detection.utils import preview_dataset, load_json, preview_explanations

## Setup data module

In [3]:
rp = Path("/Volumes/LaCie/data/lung-cancer-detection/lidc-idri/").absolute()
dp = rp/"processed"
cp = (Path()/"../data/cache/").absolute()
sp = rp/"splits"
cp.mkdir(exist_ok=True)
mp = (Path()/"../models/classification/nodule_classification_baseline.ckpt").absolute()
print(dp.exists(), cp.exists(), sp.exists(), mp.exists())

True True True True


In [4]:
splits = [load_json(f) for f in sp.iterdir()]

In [5]:
dm = ClassificationDataModule(dp, 
                              cp, 
                              splits, 
                              min_anns=1, 
                              exclude_labels=[], 
                              label_mapping=([1,2,3,4,5], [0,0,0,1,1]))

In [6]:
dm.setup()

## Show case nodules

In [7]:
case = "LIDC-IDRI-0186"

In [8]:
nods = dm.query_by_case(case)

In [9]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_dataset(nods, z=z)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

In [10]:
data = DataLoader(nods, batch_size=len(nods), shuffle=False, num_workers=4, collate_fn=list_data_collate)

## Evaluate model predictions

In [11]:
model = NoduleClassificationModule.load_from_checkpoint(mp)

In [12]:
trainer = Trainer()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [15]:
preds = trainer.predict(model, dataloaders=data)
print(type(preds))

Predicting: 0it [00:00, ?it/s]

<class 'list'>


In [21]:
preds = preds[0].numpy()[:,1]

In [28]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_dataset(nods, z=z, preds=preds)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))

## Explain model predictions

In [32]:
batch = None
for idx, sample in enumerate(data):
    batch = sample

In [35]:
x = batch["image"]

In [39]:
F.softmax(model.forward(x), dim=-1)

tensor([[0.8745, 0.1255],
        [0.9346, 0.0654],
        [0.7460, 0.2540],
        [0.7485, 0.2515],
        [0.1518, 0.8482]], grad_fn=<SoftmaxBackward>)

In [48]:
ig = IntegratedGradients(model)

In [54]:
baseline = torch.zeros(x.shape)

In [55]:
attributions, delta = ig.attribute(x, baseline, target=1, return_convergence_delta=True)

In [105]:
z = widgets.IntSlider(value=14, min=0, max=29, step=1)

def print_images(z):
    preview_explanations(x, attributions, z=z)
    
out = widgets.interactive_output(print_images, {'z': z})
widgets.HBox([widgets.HBox([z]), out])

HBox(children=(HBox(children=(IntSlider(value=14, max=29),)), Output()))