In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

/storage/ducpm/lung-segmentation


In [2]:
import os
import glob

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from tqdm import tqdm
from sklearn.metrics import f1_score
from src.metrics import *
from src.data.data_modules import Covid19DataModule, PlethoraDataModule

In [3]:
def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)

def compute_per_channel_dice(input, target, epsilon=1e-6, weights=None):
    """
    Implemented by https://github.com/wolny/pytorch-3dunet
    
    Computes DiceCoefficient as defined in V-net paperL https://arxiv.org/abs/1606.04797,
    given a multi channel input and target.
    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.
    
    Args:
         input (torch.Tensor): NxCxSpatial input tensor
         target (torch.Tensor): NxCxSpatial target tensor
         epsilon (float): prevents division by zero
         weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """
    assert input.shape == target.shape
    
    input = flatten(input).float()
    target = flatten(target).float()
    
    inter = (input * target).sum(-1)
    if weights is not None:
        inter *= weights
    
    # extension proposed in V-net paper 
    denom = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (inter / denom.clamp(min=epsilon))

In [3]:
data_module_args = { 
   "batch_size": 8,
   "img_size": 512,
   "clip_low": -1000,
   "clip_high": 1000,
   "pin_memory": True,
   "num_workers": 2
}                                                              
#dm = Covid19DataModule(**data_module_args)        
dm = PlethoraDataModule(**data_module_args)
dm.setup()
print("No. test samples:", len(dm.test_ds))
test_loader = dm.test_dataloader()

No. test samples: 5199


In [4]:
device = "cuda:0"
from src.models.unet import UNet
net = UNet.load_from_checkpoint("logs/unet-plethora-512/version_0/ckpts/epoch=8-dice_coeff_val=0.942.ckpt")
net.to(device).eval();

In [7]:
# sanity check model predictions
it = iter(test_loader)
#for _ in range(8):
#    batch = next(it)
# batch['img'] = batch['img'][:4]
# batch['mask'] = batch['mask'][:4]
with torch.no_grad():
    logits = net(batch['img'].to(device))
    pred_masks = torch.argmax(logits, dim=1)

NameError: name 'batch' is not defined

In [24]:
# plot_true_vs_pred(batch['img'], 
#                   batch['mask'], 
#                   pred_masks.cpu(), mask_alpha=0.3)

In [5]:
dice_scores = []
#pbar = tqdm(dm.test_dataloader())
pbar = tqdm(test_loader)
true_buffer = []
pred_buffer = []
# evaluate on test set
for i, batch in enumerate(pbar):
    X, y = batch["img"].to(device), batch["mask"].to(device)
    slice_idxs = batch["slice_idx"]
    with torch.no_grad():
        logits = net(X.to(device))
        pred_masks = torch.argmax(logits, dim=1)
    pred_masks = pred_masks.detach()
    
    split_idx = torch.where(slice_idxs == 0)[0]
    if len(split_idx) > 1:
        raise RuntimeError(f"there are multiple zeros in slice_idxs: {slice_idxs}")
    split_idx = split_idx.item() if len(split_idx) > 0 \
            else None
    true_buffer.append(y[:split_idx])
    pred_buffer.append(pred_masks[:split_idx])
    
    # check if we have started to process a new CT scan
    if (split_idx is not None and i > 0) or i == len(test_loader) - 1: 
        pbar.set_description(f"calculating 3D dice")
        true_v_mask = torch.cat(true_buffer).reshape(1, -1)
        pred_v_mask = torch.cat(pred_buffer).reshape(1, -1)
    
        dsc_v = dice_coeff_vectorized(pred_v_mask, true_v_mask, reduce_fn=None).item()
        dice_scores.append(dsc_v)
        pbar.set_description(f"dsc={dsc_v:.3f}")
        # empty the buffers and collect slices from new CT scan
        true_buffer = [y[split_idx:]]
        pred_buffer = [pred_masks[split_idx:]]
        
        del true_v_mask
        del pred_v_mask
        torch.cuda.empty_cache()

dsc=0.964: 100%|██████████| 650/650 [20:13<00:00,  1.87s/it]          


In [8]:
np.mean(dice_scores), len(dice_scores)

(0.970837050821723, 41)