<a href="https://colab.research.google.com/github/felixpeters/lung-cancer-detection/blob/master/nbs/colab/01_MSD_Lung_Nodule_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MSD Lung Nodule Segmentation

## Setup

In [1]:
!pip install "monai[nibabel,skimage,pillow,tqdm]" pytorch_lightning



In [2]:
from typing import Optional
from pathlib import Path

import torch
from monai.config import print_config
import pytorch_lightning as pl

In [3]:
# Print the installed MONAI version together with the versions of its
# required and optional dependencies.
print_config()

MONAI version: 0.5.2
Numpy version: 1.19.5
Pytorch version: 1.8.1+cu101
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: feb3a334b7bbf302b13a6da80e0b022a4cf75a4e

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.0.2
scikit-image version: 0.16.2
Pillow version: 7.1.2
Tensorboard version: 2.4.1
gdown version: 3.6.4
TorchVision version: 0.9.1+cu101
ITK version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.61.0
lmdb version: 0.99
psutil version: 5.4.8

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [4]:
# Report whether PyTorch can use CUDA and how many CUDA devices it sees.
print(torch.cuda.is_available())
print(torch.cuda.device_count())

True
1


In [5]:
# Working directory for the dataset files; created (including parents) if it
# does not exist yet.
root_dir = Path("data/msd")
root_dir.mkdir(parents=True, exist_ok=True)
print(root_dir)

data/msd


## Define data module

In [6]:
import os
from typing import Sequence
from torch.utils.data import DataLoader
from monai.apps import DecathlonDataset
from monai.transforms import (
    Compose,
    LoadImaged,
    AddChanneld,
    ScaleIntensityd,
    ToTensord,
)
from monai.data.utils import list_data_collate

class MSDLungDataModule(pl.LightningDataModule):
  """Lightning data module for the MSD lung task (``Task06_Lung``).

  Builds MONAI ``DecathlonDataset`` objects for the "training" and
  "validation" sections and serves them through PyTorch ``DataLoader``s
  that collate with ``list_data_collate``.

  Args:
    root_dir: Directory passed to ``DecathlonDataset`` as ``root_dir``.
    batch_size: Batch size used by both data loaders.
    spacing: Stored on the instance; the current transforms do not use it.
    roi_size: Stored on the instance; the current transforms do not use it.
    random_seed: Seed forwarded to ``DecathlonDataset``.
    num_workers: Worker count used for dataset loading and for the data
      loaders. ``None`` (default) resolves to ``os.cpu_count()``.
  """

  def __init__(self, 
               root_dir: Path, 
               batch_size: int = 16, 
               spacing: Sequence[float] = (1.5, 1.5, 2.0), 
               roi_size: Sequence[int] = (100, 100, 75),
               random_seed: int = 47,
               num_workers: Optional[int] = None,
               ):
    super().__init__()
    self.root_dir = root_dir
    self.spacing = spacing
    self.roi_size = roi_size
    self.batch_size = batch_size
    self.random_seed = random_seed
    # os.cpu_count() may return None, which DataLoader does not accept as a
    # worker count; fall back to 0 in that case.
    if num_workers is None:
      num_workers = os.cpu_count() or 0
    self.num_workers = num_workers
    # NOTE(review): the pipeline has no crop/resize step — confirm that the
    # volumes can actually be collated at batch_size > 1.
    self.train_transforms = self._build_transforms()
    self.valid_transforms = self._build_transforms()

  @staticmethod
  def _build_transforms() -> Compose:
    """Return the load -> add channel -> scale intensity -> tensor pipeline."""
    return Compose([
      LoadImaged(keys=["image", "label"]),
      AddChanneld(keys=["image", "label"]),
      ScaleIntensityd(keys=["image"]),
      ToTensord(keys=["image", "label"]),
    ])

  def _load_section(self, section: str, transform: Compose) -> DecathlonDataset:
    """Create the ``Task06_Lung`` dataset for one section with ``transform``."""
    return DecathlonDataset(
        root_dir=self.root_dir,
        task="Task06_Lung",
        transform=transform,
        section=section,
        seed=self.random_seed,
        download=True,
        num_workers=self.num_workers,
    )

  def prepare_data(self):
    """No-op; the datasets are created with ``download=True`` in ``setup``."""
    return

  def setup(self, stage: Optional[str] = None):
    """Create the datasets needed for ``stage``.

    Args:
      stage: ``None`` or ``"fit"`` builds the training and validation
        datasets; ``"validate"`` builds only the validation dataset. Any
        other value builds nothing.
    """
    if stage in (None, "fit"):
      self.train_data = self._load_section("training", self.train_transforms)
    if stage in (None, "fit", "validate"):
      self.valid_data = self._load_section("validation", self.valid_transforms)

  def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` with the module's settings."""
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=self.num_workers,
        collate_fn=list_data_collate,
    )

  def train_dataloader(self):
    """Return the shuffled loader over the training dataset."""
    return self._make_loader(self.train_data, shuffle=True)

  def val_dataloader(self):
    """Return the loader over the validation dataset.

    Validation data is not shuffled so that batches are visited in the same
    order on every validation run.
    """
    return self._make_loader(self.valid_data, shuffle=False)

  def test_dataloader(self):
    """No test loader is provided; returns ``None``."""
    return None

In [None]:
# Build the data module with its default settings and create the datasets
# (setup() without a stage prepares both the training and validation data).
dm = MSDLungDataModule(root_dir)
dm.setup()

Loading dataset:   0%|          | 0/51 [00:00<?, ?it/s]

Verified 'Task06_Lung.tar', md5: 8afd997733c7fc0432f71255ba4e52dc.
file data/msd/Task06_Lung.tar exists, skip downloading.
extracted file data/msd/Task06_Lung exists, skip extracting.


Loading dataset:  37%|███▋      | 19/51 [01:22<01:58,  3.71s/it]

In [None]:
import matplotlib.pyplot as plt

def print_shapes(dataset):
  """Print the image and label shape of every item in ``dataset``.

  Args:
    dataset: Iterable of mappings with "image" and "label" entries, each
      exposing a ``.shape`` attribute.
  """
  for item in dataset:
    # Single quotes inside the replacement fields: reusing the enclosing
    # double quotes is a SyntaxError on Python versions before 3.12.
    print(f"image: {item['image'].shape}", f"label: {item['label'].shape}")

def preview(item, z=None):
  """Show one slice of an image next to the same slice of its label.

  Args:
    item: Mapping with "image" and "label" entries. Each entry must provide
      ``.numpy()``; the result is indexed as ``[0]`` and then ``[:, :, z]``.
    z: Index along the last of those three axes to display. Defaults to the
      middle index of the image along that axis.
  """
  plt.figure("Chest CTs with labels", (12, 6))
  img = item["image"].numpy()[0]
  label = item["label"].numpy()[0]
  # The previous `z if not None else ...` always selected z (``not None`` is
  # constantly True), so the default None was used as the slice index. Test z
  # itself so the middle-slice fallback actually applies.
  if z is None:
    z = img.shape[2] // 2
  plt.subplot(1, 2, 1)
  plt.imshow(img[:,:,z], cmap="gray")
  plt.title("Image")
  plt.subplot(1, 2, 2)
  plt.imshow(label[:,:,z], cmap="gray")
  plt.title("Label")
  plt.show()

In [None]:
# Print the image/label shapes of every item in the training split, then in
# the validation split.
print_shapes(dm.train_data)
print_shapes(dm.valid_data)

In [None]:
# Preview up to the first ten training items. Items are fetched by integer
# index so the loop does not depend on the dataset accepting slice objects,
# and the bound keeps it valid for datasets with fewer than ten items.
for idx in range(min(10, len(dm.train_data))):
  preview(dm.train_data[idx])

## Define model module

In [None]:
from typing import Dict
import torch
import torch.nn as nn
from torch.optim import Adam
from monai.losses import DiceLoss

from monai.networks.nets import BasicUNet

class NoduleSegmentationNet(pl.LightningModule):
  """Lightning module that trains a segmentation network with Dice loss.

  Batches are expected to be mappings with "image" and "label" entries.

  Args:
    model: Network to train. When ``None``, a ``BasicUNet`` created with its
      default arguments is used.
    lr: Learning rate handed to the Adam optimizer.
  """

  def __init__(self, model: Optional[nn.Module] = None, lr: float = 1e-4):
    super().__init__()
    # Compare against None explicitly: truth-testing a module uses __len__
    # where it is defined, so an empty container module passed by the caller
    # would otherwise be silently replaced by the default network.
    self.model = model if model is not None else BasicUNet()
    self.lr = lr
    self.loss = DiceLoss(to_onehot_y=True, softmax=True)
    # NOTE(review): called without arguments, so the `model` argument is
    # presumably recorded alongside `lr` — confirm that is intended.
    self.save_hyperparameters()

  def forward(self, x: torch.Tensor):
    """Run the wrapped network on ``x`` and return its output."""
    return self.model(x)

  def _step_loss(self, batch: Dict, log_name: str):
    """Compute the Dice loss for ``batch`` and log it under ``log_name``."""
    x, y = batch["image"], batch["label"]
    output = self.forward(x)
    loss = self.loss(output, y)
    self.log(log_name, loss)
    return loss

  def training_step(self, batch: Dict, batch_idx: int):
    """Return the loss for one training batch, logged as "train_loss"."""
    return self._step_loss(batch, "train_loss")

  def validation_step(self, batch: Dict, batch_idx: int):
    """Return the loss for one validation batch, logged as "val_loss"."""
    return self._step_loss(batch, "val_loss")

  def configure_optimizers(self):
    """Return an Adam optimizer over the wrapped network's parameters."""
    return Adam(self.model.parameters(), lr=self.lr)

## Configure experiment