Evaluate pretrained U-net from https://github.com/JoHof/lungmask on Plethora dataset.

In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

/storage/ducpm/lung-segmentation


In [2]:
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import skimage

import torch
from torchvision import transforms
from tqdm.notebook import tqdm

from lungmask.mask import get_model
from lungmask.utils import preprocess, postrocessing, reshape_mask

from src.metrics import dice_coeff_vectorized
from src.data.preprocess import ToTensor
from src.data.data_modules import PlethoraDataModule, Covid19DataModule
from src.visualization.plotting import plot_batch

In [3]:
class LungMaskPreprocess:
    def __call__(self, sample):
        img = sample["img"]
        new_img, new_box = preprocess(np.expand_dims(img, axis=0),
                                               resolution=[256,256])
        new_img[new_img > 600] = 600
        new_img = np.divide((new_img + 1024), 1624)
        new_img = np.squeeze(new_img)
        new_box = np.squeeze(new_box)
        return dict(sample, img=new_img, box=new_box, original_img=img)

In [4]:
transform = transforms.Compose([
    LungMaskPreprocess(),
    ToTensor(),
    transforms.Lambda(lambda sample: dict(sample, img=sample["img"].unsqueeze(0))),
])

#dm = PlethoraDataModule(batch_size=16, transform=transform)
dm = Covid19DataModule(batch_size=32, transform=transform)
dm.setup('test')
test_loader = dm.test_dataloader()

In [5]:
device = "cuda:0"
model = get_model('unet', 'R231').to(device)

In [6]:
dice_2d_list = []
dice_3d_list = []
true_buffer = []
pred_buffer = []

res = {}
pbar = tqdm(test_loader)
for batch_idx, batch in enumerate(pbar):
    with torch.no_grad():
        X, y = batch["img"].to(device), batch["mask"].to(device)
        logits = model(X)
        pred_masks = torch.max(logits, 1)[1].detach().cpu().numpy().astype(np.uint8)
    
    boxes = batch["box"].cpu().numpy()
    # post-process masks
    pred_masks = postrocessing(pred_masks)
    pred_masks = np.asarray(
            [reshape_mask(pred_masks[i], batch["box"][i], batch["original_img"].shape[1:]) for i in range(pred_masks.shape[0])],
            dtype=np.uint8)
    pred_masks[pred_masks > 0] = 1
    pred_masks = torch.from_numpy(pred_masks).to(device)
    
    # calculate metrics
    dsc = dice_coeff_vectorized(pred_masks, y, reduce_fn=torch.mean)
    dice_2d_list.append(dsc.item())
    res["dice_2d"] = dsc.item()

    # 3D dice coeff
    slice_idxs = batch["slice_idx"]
    split_idx = torch.where(slice_idxs == 0)[0]
    if len(split_idx) > 1:
        raise RuntimeError(f"there are multiple zeros in slice_idxs: {slice_idxs}")
    split_idx = split_idx.item() if len(split_idx) > 0 else None
    true_buffer.append(y[:split_idx])
    pred_buffer.append(pred_masks[:split_idx])

    # if we have finish iterating over a CT scan, calculate 3D dice
    if (split_idx is not None and batch_idx > 0) or \
            batch_idx == len(test_loader) - 1:
        true_v_mask = torch.cat(true_buffer).reshape(1, -1)
        pred_v_mask = torch.cat(pred_buffer).reshape(1, -1)

        dsc_v = dice_coeff_vectorized(pred_v_mask, true_v_mask, reduce_fn=None)
        dice_3d_list.append(dsc_v.item())
        res["dice_3d"] = dsc_v.item()

        # clear the buffers
        true_buffer = [y[split_idx:]]
        pred_buffer = [pred_masks[split_idx:]]
        del true_v_mask
        del pred_v_mask
    pbar.set_postfix(res)

HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




In [7]:
print(len(dice_3d_list))
print(dice_3d_list)

10
[0.9836381673812866, 0.9823898077011108, 0.9759968519210815, 0.9851373434066772, 0.9864538311958313, 0.986423909664154, 0.9833903312683105, 0.9856266975402832, 0.9808560609817505, 0.9734964966773987]


In [30]:
print(f"2d dice: {np.mean(dice_2d_list):.4f}")
print(f"3d dice: {np.mean(dice_3d_list):.4f}")

2d dice: 0.9391
3d dice: 0.9822
