In [None]:
# Environment setup: each line first tries to import the package and only falls
# back to `pip install` when that import fails; the last line selects the
# inline matplotlib backend for this notebook.
!python -c "import monai" || pip install -q "monai-weekly[all]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [None]:
# einops is installed unconditionally and imported right away; nothing else in
# this notebook references the name directly.
# NOTE(review): presumably required by MONAI's transformer-based UNETR — confirm.
!pip install einops
import einops



In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    CropForeground,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandCropByPosNegLabel,
    RandAffined,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    LabelFilterd,
    Resized
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, UNETR
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import numpy as np
import random

In [None]:
# Directory searched for the *_CT.nii.gz volumes and *_segmented.nii.gz masks;
# the per-fold "best_metric_model<k>.pth" checkpoints are written here as well.
# NOTE(review): looks like a Google Drive folder mounted in Colab — confirm the
# drive is mounted before running.
path_to_images='/content/drive/My Drive/MIMRTL Lab/Andy-labels'

In [None]:
# Collect the CT volumes and their segmentation masks. Both lists are sorted so
# that the k-th image is paired with the k-th label below.
train_images = sorted(
    glob.glob(os.path.join(path_to_images, '*_CT.nii.gz')))
train_labels = sorted(
    glob.glob(os.path.join(path_to_images, '*_segmented.nii.gz')))
random.seed(1)
print(train_images)
print(train_labels)

# zip() stops at the shorter list, so a missing image or label would silently
# drop cases and shift every later image onto the wrong mask. Fail loudly instead.
if len(train_images) != len(train_labels):
    raise ValueError(
        f"found {len(train_images)} '*_CT.nii.gz' images but "
        f"{len(train_labels)} '*_segmented.nii.gz' labels in {path_to_images!r}; "
        "every image needs exactly one label")

# One {"image": path, "label": path} record per case, as expected by the
# dictionary-based transforms.
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]

['/content/drive/My Drive/MIMRTL Lab/Andy-labels/51_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/52_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/53_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/54_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/55_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/56_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/57_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/58_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/59_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/60_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/61_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/62_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/63_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/64_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/65_CT.nii.gz', '/content/drive/My Drive/MIMRTL Lab/Andy-labels/66_CT.

In [None]:
# Ask MONAI for deterministic behaviour with a fixed seed so repeated runs are
# comparable.
set_determinism(seed=0)

In [None]:
# Training pipeline applied to each {"image", "label"} record: deterministic
# preprocessing first, then random cropping and affine augmentation.
train_transforms = Compose(
    [
        # Read both volumes and put the channel axis first.
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample to pixdim (1.5, 1.5, 2.0): bilinear for the image, nearest for
        # the label.
        # NOTE(review): resampling runs before the RAS reorientation below, so the
        # pixdim is presumably applied along the file's native axes — confirm
        # this order is intended.
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Map intensities from [-1000, 3000] to [0, 1], clipping outside values.
        ScaleIntensityRanged(
            keys=["image"], a_min=-1000, a_max=3000,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        # Crop both volumes to the foreground region found in the image.
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # Draw 4 random 96x96x96 patches per volume, with pos=1 / neg=1 weighting
        # between label-centred and background-centred patches.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # user can also add other random transforms
        # Always-on (prob=1.0) random affine: rotation range only on the third
        # axis, scale range 0.1 on every axis.
        RandAffined(
             keys=['image', 'label'],
             mode=('bilinear', 'nearest'),
             prob=1.0, spatial_size=(96, 96, 96),
             rotate_range=(0, 0, np.pi/15),
             scale_range=(0.1, 0.1, 0.1)),
        # Keep only label value 1. `(1)` is the plain integer 1, not a tuple.
        # NOTE(review): this runs after the pos/neg crop, so patch centres may
        # presumably be sampled from labels that are then filtered out —
        # val_transforms filters before cropping; confirm which order is wanted.
        LabelFilterd(keys=["label"], applied_labels=(1)),
        EnsureTyped(keys=["image", "label"]),
    ]
)
# Validation pipeline: the same deterministic preprocessing as training, with
# no random cropping or augmentation, so whole volumes are returned.
val_transforms = Compose(
    [
        # Read both volumes and put the channel axis first.
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample to pixdim (1.5, 1.5, 2.0): bilinear for the image, nearest for
        # the label, then reorient to RAS.
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Map intensities from [-1000, 3000] to [0, 1], clipping outside values.
        ScaleIntensityRanged(
            keys=["image"], a_min=-1000, a_max=3000,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        # Disabled experiment: in-plane resize to 128x128.
        #Resized(keys=["image", "label"], spatial_size=[128, 128, -1], 
        #        mode = ["area", "nearest"]),
        # Keep only label value 1. `(1)` is the plain integer 1, not a tuple.
        LabelFilterd(keys=["label"], applied_labels=(1)),
        # Crop both volumes to the foreground region found in the image.
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)


In [None]:
def load_data(trainIndex, testIndex):
  """Build the training and validation loaders for one cross-validation fold.

  Every record of ``data_dicts`` whose position appears in ``trainIndex`` goes
  to the training set; every other record goes to the validation set.
  ``testIndex`` is accepted for symmetry with the caller but is not consulted.

  Returns:
    ``(train_loader, val_loader)``.
  """
  fold_train, fold_val = [], []
  for position, record in enumerate(data_dicts):
    target = fold_train if position in trainIndex else fold_val
    target.append(record)

  # Fully cached training set. batch_size=2 volumes, each expanded by
  # RandCropByPosNegLabeld into 4 patches, gives 2 x 4 patches per step.
  cached_train = CacheDataset(
      data=fold_train, transform=train_transforms,
      cache_rate=1.0, num_workers=4)
  loader_train = DataLoader(
      cached_train, batch_size=2, shuffle=True, num_workers=4)

  # Fully cached validation set, one whole volume per batch.
  cached_val = CacheDataset(
      data=fold_val, transform=val_transforms, cache_rate=1.0, num_workers=4)
  loader_val = DataLoader(cached_val, batch_size=1, num_workers=4)

  return (loader_train, loader_val)

In [None]:
# standard PyTorch program style: create UNETR, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def create_model():
  """Create a fresh network, loss, optimizer and metric for one fold.

  Returns:
    ``(model, loss_function, optimizer, dice_metric)`` where ``model`` is a 3D
    UNETR (1 input channel, 2 output channels, 96x96x96 input patches) already
    moved to ``device``.
  """
  model = UNETR(
      spatial_dims=3,
      in_channels=1,
      # This argument was cut off in the notebook ("featu"), which made the cell
      # a syntax error. NOTE(review): assumed to be feature_size; 16 is the value
      # used in MONAI's UNETR tutorial — confirm the intended value.
      feature_size=16,
      out_channels=2,
      # Fixed input patch size; sliding-window inference must use the same size.
      img_size=(96, 96, 96),
      norm_name='batch',
  ).to(device)
  # Labels arrive as a single-channel index map, hence to_onehot_y; the network
  # emits raw logits, hence softmax.
  loss_function = DiceLoss(to_onehot_y=True, softmax=True)
  optimizer = torch.optim.Adam(model.parameters(), 1e-4)
  # Foreground-only Dice, averaged over the accumulated cases.
  dice_metric = DiceMetric(include_background=False, reduction="mean")
  return (model, loss_function, optimizer, dice_metric)

In [None]:
# Training schedule: total number of epochs, and how often (in epochs)
# validation runs.
max_epochs = 200
val_interval = 2
# NOTE(review): fit() assigns to best_metric / best_metric_epoch without a
# `global` statement, so these module-level values are never updated by it.
best_metric = -1
best_metric_epoch = -1
# Histories appended to by fit(): mean training loss per epoch and validation
# Dice per validation round. Being module-level, they accumulate across folds.
epoch_loss_values = []
metric_values = []
# Per-sample post-processing applied before the Dice metric: predictions are
# arg-maxed then one-hot encoded to 2 channels; labels are one-hot encoded.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

def fit(model, loss_function, optimizer, dice_metric, train_loader, val_loader, i):
  """Train ``model`` for ``max_epochs`` epochs and checkpoint the best validation Dice.

  Args:
    model: network to optimise; called on batches moved to ``device``.
    loss_function: callable ``loss_function(outputs, labels)`` returning the loss.
    optimizer: optimizer stepped once per training batch.
    dice_metric: metric accumulated over ``val_loader`` and reset after every
      validation round.
    train_loader: loader yielding batches with ``"image"`` and ``"label"`` entries.
    val_loader: loader yielding validation batches with the same keys.
    i: zero-based fold index; the best weights are saved as
      ``best_metric_model<i+1>.pth`` inside ``path_to_images``.

  Side effects:
    Appends to the module-level ``epoch_loss_values`` and ``metric_values``
    lists and writes the checkpoint file described above.
  """
  # Tracked per call: the best score has to restart for every fold so that each
  # fold writes its own checkpoint (run_stats() loads one file per fold).
  best_metric = -1
  best_metric_epoch = -1

  # The validation window must equal the img_size the UNETR in create_model()
  # is built with; the network only accepts that fixed patch size.
  roi_size = (96, 96, 96)
  sw_batch_size = 4

  # The dataset object is not visible here, so reach it through the loader.
  steps_per_epoch = len(train_loader.dataset) // train_loader.batch_size
  checkpoint_path = os.path.join(
      path_to_images, "best_metric_model" + str(i+1) + ".pth")

  for epoch in range(max_epochs):
      print("-" * 10)
      print(f"epoch {epoch + 1}/{max_epochs}")
      model.train()
      epoch_loss = 0
      step = 0
      for batch_data in train_loader:
          step += 1
          inputs, labels = (
              batch_data["image"].to(device),
              batch_data["label"].to(device),
          )
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = loss_function(outputs, labels)
          loss.backward()
          optimizer.step()
          epoch_loss += loss.item()
          print(
              f"{step}/{steps_per_epoch}, "
              f"train_loss: {loss.item():.4f}")
      epoch_loss /= step
      epoch_loss_values.append(epoch_loss)
      print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

      if (epoch + 1) % val_interval == 0:
          model.eval()
          with torch.no_grad():
              for val_data in val_loader:
                  val_inputs, val_labels = (
                      val_data["image"].to(device),
                      val_data["label"].to(device),
                  )
                  val_outputs = sliding_window_inference(
                      val_inputs, roi_size, sw_batch_size, model)
                  # Post-process each sample of the batch separately.
                  val_outputs = [
                      post_pred(sample) for sample in decollate_batch(val_outputs)]
                  val_labels = [
                      post_label(sample) for sample in decollate_batch(val_labels)]
                  # compute metric for current iteration
                  dice_metric(y_pred=val_outputs, y=val_labels)

              # aggregate the final mean dice result
              metric = dice_metric.aggregate().item()
              # reset the status for next validation round
              dice_metric.reset()

              metric_values.append(metric)
              if metric > best_metric:
                  best_metric = metric
                  best_metric_epoch = epoch + 1
                  torch.save(model.state_dict(), checkpoint_path)
                  print("saved new best metric model")
              print(
                  f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                  f"\nbest mean dice: {best_metric:.4f} "
                  f"at epoch: {best_metric_epoch}"
              )
  print(
      f"train completed, best_metric: {best_metric:.4f} "
      f"at epoch: {best_metric_epoch}")

In [None]:
# Evaluation pipeline used by run_stats(): only the image is resampled,
# reoriented, rescaled and cropped. The label is loaded, made channel-first and
# filtered, but receives no spatial transform here.
val_org_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Image only: resample to pixdim (1.5, 1.5, 2.0) with bilinear
        # interpolation, then reorient to RAS.
        Spacingd(keys=["image"], pixdim=(
            1.5, 1.5, 2.0), mode="bilinear"),
        Orientationd(keys=["image"], axcodes="RAS"),
        # Map intensities from [-1000, 3000] to [0, 1], clipping outside values.
        ScaleIntensityRanged(
            keys=["image"], a_min=-1000, a_max=3000,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        # Keep only label value 1. `(1)` is the plain integer 1, not a tuple.
        LabelFilterd(keys=["label"], applied_labels=(1)),
        # Image only: crop to its own foreground region.
        CropForegroundd(keys=["image"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

# Post-processing applied per decollated sample in run_stats(): undo the image
# transforms on the prediction, rescale the image intensities back, and
# discretise the prediction and label.
post_transforms = Compose([
    EnsureTyped(keys="pred"),
    # Invert, on "pred", the transforms val_org_transforms recorded for "image".
    # nearest_interp=False because "pred" still holds per-class scores here; the
    # argmax happens afterwards.
    Invertd(
        keys="pred",
        transform=val_org_transforms,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=True,
    ),
    # Map image intensities from [0, 1] back to [-1000, 3000] (the reverse of
    # the forward scaling; values clipped in the forward pass stay clipped).
    # NOTE(review): only "pred" is inverted spatially; "image" is rescaled but
    # presumably still on the resampled/cropped grid, so it may not line up with
    # "pred" voxel-for-voxel — confirm before masking the image with the prediction.
    ScaleIntensityRanged(
            keys=["image"], a_min=0.0, a_max=1.0,
            b_min=-1000.0, b_max=3000.0, clip=True,
        ),
    # Prediction: argmax then one-hot to 2 channels; label: one-hot to 2 channels.
    AsDiscreted(keys="pred", argmax=True, to_onehot=2),
    AsDiscreted(keys="label", to_onehot=2),
])

In [None]:
def run_stats(testIndex, i):
  """Evaluate fold ``i``'s best checkpoint on its held-out cases.

  Loads ``best_metric_model<i+1>.pth`` from ``path_to_images`` into the
  module-level ``model``, runs sliding-window inference on every record of
  ``data_dicts`` whose position is in ``testIndex``, and appends one
  ``(subject, mean, std, score)`` tuple per case to the module-level
  ``allStats`` list.

  Args:
    testIndex: positions (into ``data_dicts``) of the held-out cases.
    i: zero-based fold index, used to pick the checkpoint file.
  """
  model.load_state_dict(torch.load(
      os.path.join(path_to_images, "best_metric_model" + str(i+1) + ".pth"),
      map_location=device))
  model.eval()
  # A dedicated loop name keeps the fold index `i` intact.
  dataTest = [
      data_dicts[idx] for idx in range(len(data_dicts)) if idx in testIndex]
  val_org_ds = Dataset(
      data=dataTest, transform=val_org_transforms)
  val_org_loader = DataLoader(val_org_ds, batch_size=1, num_workers=4)

  # The inference window must equal the img_size the UNETR in create_model()
  # is built with.
  roi_size = (96, 96, 96)
  sw_batch_size = 4

  with torch.no_grad():
      for val_data in val_org_loader:
        val_inputs = val_data["image"].to(device)
        val_data["pred"] = sliding_window_inference(
                  val_inputs, roi_size, sw_batch_size, model)
        val_data = [post_transforms(item) for item in decollate_batch(val_data)]
        # Subject id = file name without the "_CT.nii.gz" suffix, independent of
        # how deep the data directory sits.
        subject = val_data[0]["pred_meta_dict"]["filename_or_obj"]
        subject = os.path.basename(subject).replace("_CT.nii.gz", "")
        image_list, pred_list, label_list = from_engine(["image", "pred", "label"])(val_data)
        val_images = image_list[0][0]
        val_outputs = pred_list[0]
        val_labels = label_list[0]
        # Intensity statistics over the voxels predicted as foreground (channel 1).
        # NOTE(review): presumably requires the image and the inverted prediction
        # to share the same grid — confirm against post_transforms.
        foreground = val_outputs[1] == 1
        mean = torch.mean(val_images[foreground])
        std = torch.std(val_images[foreground])
        # compute metric for current iteration; pass the per-sample lists so the
        # leading axis of each tensor is read as channels, not as a batch
        dice_metric(y_pred=pred_list, y=label_list)
        score = dice_metric.aggregate().item()
        allStats.append((subject, mean, std, score))
        dice_metric.reset()

In [None]:
# Initialize variables
n = 10          # number of cross-validation folds
allStats = []   # (subject, mean, std, dice) tuples appended by run_stats()

# Implement KFold with n splits
from sklearn.model_selection import KFold
# A fixed random_state keeps the fold assignment identical on every run of this
# cell, independent of whatever consumed the global RNG beforehand.
kf = KFold(n_splits=n, shuffle=True, random_state=0)

# NOTE(review): load_data() validates on every case outside trainIndex, i.e. the
# same held-out cases run_stats() reports on, so the checkpoint is selected on
# the test fold — confirm this is intended.
for i, (trainIndex, testIndex) in enumerate(kf.split(list(range(len(data_dicts))))):
  print ("Running Fold", i+1, "/", n)
  (train_loader, val_loader) = load_data(trainIndex, testIndex)
  # Drop the previous fold's network before building a fresh one.
  model = None
  (model, loss_function, optimizer, dice_metric) = create_model()
  fit(model, loss_function, optimizer, dice_metric, train_loader, val_loader, i)
  run_stats(testIndex, i)

Running Fold 1 / 10


Loading dataset:   0%|          | 0/15 [00:00<?, ?it/s]