# 2D Glioma classification

This notebook demonstrates pixel-wise classification (i.e., binary segmentation) of brain tumors with MONAI, using a 2D UNet trained with a Dice loss. To accelerate training, we generate a 2D dataset from a 3D one.

## Dataset

The dataset used here is derived from the Decathlon 3D brain tumor dataset: for each 3D volume, we take the 2D slice containing the most labelled voxels (label > 0) and save the resulting 2D dataset to disk. The pre-computed 2D dataset can be downloaded from Google Drive, and the script used to create it is also available in case you're interested.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rijobro/GliMR_MONAI_workshop/blob/main/2D_Glioma_classification.ipynb)

## Setup environment

This checks if MONAI is installed, and if not installs it (plus any optional extras that might be needed for this notebook).

In [None]:
# Install MONAI (with the nibabel and tqdm extras) only if it cannot already be imported.
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
# Likewise, install matplotlib only when the import check fails.
!python -c "import matplotlib" || pip install -q matplotlib
# Select the inline matplotlib backend for the cells that follow.
%matplotlib inline

## Setup imports

In [None]:
# Copyright 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from glob import glob
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import random
import tempfile
import torch
from tqdm import tqdm

import monai
from monai.apps import download_and_extract
from monai.data import (
    CacheDataset,
    DataLoader,
    Dataset,
    pad_list_data_collate,
    TestTimeAugmentation,
    decollate_batch,
)
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks import eval_mode
from monai.networks.nets import UNet
import monai.transforms as mt
from monai.transforms.utils import allow_missing_keys_mode
from monai.utils import first, set_determinism

# Print MONAI's configuration report (presumably package versions — useful
# when comparing runs across environments).
monai.config.print_config()

# Set deterministic training for reproducibility: a fixed seed (0) is passed
# to MONAI's set_determinism before any data shuffling or model creation.
set_determinism(seed=0)

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [None]:
# Use MONAI_DATA_DIRECTORY when it is set (expanding a leading "~");
# otherwise fall back to a freshly created temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = os.path.expanduser(directory)
print(root_dir)

## Get 2D data

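The pre-computed dataset can be downloaded from Google Drive by setting `download_from_gdrive = True`. With the default (`False`), the dataset is instead generated locally by running the slice-creation script, which is available in case you're interested.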

In [None]:
download_from_gdrive = False
task = "Task01_BrainTumour"
output_dir = os.path.join(root_dir, task + "2D")

if download_from_gdrive:
    resource = "https://drive.google.com/file/d/1BB0S2PcY6yUR7TK-AeyCFoh6PyoJiH0E"
    compressed_file = os.path.join(root_dir, task + "2D.tar")
    md5 = "a2482cf48b7c72b09b4b647820e61c8e"
    download_and_extract(resource, compressed_file, root_dir, hash_val=md5)
else:
    %run -i ../utils/2d_slice_creator.py --path {output_dir} --download_path {root_dir} --task {task}
    
images = sorted(glob(os.path.join(output_dir, "image", "*.nii.gz")))
labels = sorted(glob(os.path.join(output_dir, "label", "*.nii.gz")))
assert len(images) == len(labels)
data_dicts = [{"image": image, "label": label}
              for image, label in zip(images, labels)]

In [None]:
# Shuffle the samples in place, then use the first 80% for training and the
# remainder for validation.
random.shuffle(data_dicts)
num_files = len(data_dicts)
num_train_files = round(0.8 * num_files)
train_files, val_files = data_dicts[:num_train_files], data_dicts[num_train_files:]
for caption, count in (
    ("total num files:", num_files),
    ("num training files:", len(train_files)),
    ("num validation files:", len(val_files)),
):
    print(caption, count)

In [None]:
keys = ["image", "label"]

# Settings of the random affine augmentation; `mode` is given per key
# (bilinear for the image, nearest for the label).
affine_kwargs = dict(
    prob=1.0,
    spatial_size=(300, 300),
    rotate_range=(np.pi / 3, np.pi / 3),
    translate_range=(3, 3),
    scale_range=((0.8, 1), (0.8, 1)),
    padding_mode="zeros",
    mode=("bilinear", "nearest"),
)

# Pipeline: load -> binarise label -> random affine -> crop to the image
# foreground -> pad to a multiple of 16 -> rescale image intensity -> ensure type.
train_transforms = mt.Compose(
    transforms=[
        mt.LoadImaged(keys),
        # make label binary: every non-zero label value becomes 1.0
        mt.Lambdad("label", lambda x: (x > 0).astype(np.float32)),
        mt.RandAffined(keys, **affine_kwargs),
        mt.CropForegroundd(keys, source_key="image"),
        mt.DivisiblePadd(keys, 16),
        mt.ScaleIntensityd("image"),
        mt.EnsureTyped(keys),
    ]
)

# NOTE(review): validation reuses the training pipeline object, so validation
# samples also go through the random affine step — confirm this is intended.
val_transforms = train_transforms

In [None]:
# Training data: cache_rate=1.0 requests caching for the whole dataset.
train_ds = CacheDataset(
    data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
# shuffle=True re-orders the samples every epoch so mini-batches differ
# between epochs; pad_list_data_collate is used to collate each batch.
train_loader = DataLoader(train_ds, batch_size=10, shuffle=True,
                          num_workers=10, collate_fn=pad_list_data_collate)

# Validation data: use `val_transforms` (not `train_transforms`) so that the
# validation set follows the validation pipeline wherever it is defined.
val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1,
                        num_workers=10, collate_fn=pad_list_data_collate)

## Display some examples

In [None]:
%matplotlib inline
def imshows(ims):
    """Show a grid of images: one row per entry of ``ims``, one column per item.

    Args:
        ims: non-empty sequence of dicts mapping a title to a channel-first
            image (``torch.Tensor`` or array-like). The number of columns is
            taken from the first dict. Channels are averaged before display.
    """
    nrow = len(ims)
    ncol = len(ims[0])
    # squeeze=False keeps `axes` two-dimensional for every grid shape, so a
    # single row, a single column or a 1x1 grid can all be indexed as [i, j].
    fig, axes = plt.subplots(
        nrow, ncol, figsize=(ncol * 3, nrow * 3), facecolor='white',
        squeeze=False)
    for i, im_dict in enumerate(ims):
        for j, (title, im) in enumerate(im_dict.items()):
            if isinstance(im, torch.Tensor):
                # matplotlib needs a CPU numpy array without autograd history
                im = im.detach().cpu().numpy()
            im = np.mean(im, axis=0)  # average across channels
            ax = axes[i, j]
            ax.set_title(f"{title}\n{im.shape}")
            im_show = ax.imshow(im)
            ax.axis("off")
            fig.colorbar(im_show, ax=ax)


# Draw five distinct training files, run each through the training pipeline
# and show the resulting image/label pairs side by side.
sampled_files = np.random.choice(train_files, size=5, replace=False)
to_imshow = [
    {key: transformed[key] for key in ("image", "label")}
    for transformed in map(train_transforms, sampled_files)
]
imshows(to_imshow)

In [None]:
# Function for live plotting whilst running training
def plot_range(data, wrapped_generator):
    """Yield from ``wrapped_generator`` while live-updating the plotted lines.

    Args:
        data: non-empty dict of series; each value is a dict holding a
            matplotlib line under ``"line"`` and its coordinates under
            ``"x"`` and ``"y"``.
        wrapped_generator: iterable whose items are yielded unchanged.

    Yields:
        Each item of ``wrapped_generator``. After the consumer has finished
        with an item (i.e. when the next one is requested), every line is
        refreshed from its ``"x"``/``"y"`` lists and the figure is redrawn.

    Raises:
        ValueError: if ``data`` is empty (raised on the first iteration,
            because this is a generator).
    """
    if not data:
        raise ValueError("plot_range requires at least one series in `data`")

    # Get ax, show plot, etc.
    plt.ion()
    # The axes of the last series is used for legend/rescaling; all series
    # are presumably drawn on the same axes.
    ax = list(data.values())[-1]["line"].axes
    fig = ax.get_figure()
    fig.show()

    for item in wrapped_generator:
        yield item
        # update plots, legend, view
        for series in data.values():
            series["line"].set_data(series["x"], series["y"])
        ax.legend()
        ax.relim()
        ax.autoscale_view()
        fig.canvas.draw()

In [None]:
# Post-processing for a raw network output: sigmoid activation, threshold at
# 0.5, then keep only the largest connected component of label 1.
post_trans = mt.Compose(
    transforms=[
        mt.Activations(sigmoid=True),
        mt.AsDiscrete(threshold=0.5),
        mt.KeepLargestConnectedComponent(applied_labels=1),
    ]
)


def infer_seg(images, model):
    """Run ``model`` on a batch and post-process every item of the output.

    Args:
        images: batched input passed straight to ``model``.
        model: callable network producing a batched output.

    Returns:
        The post-processed items (via ``post_trans``) stacked back into a
        single tensor.
    """
    raw_outputs = model(images)
    processed = []
    for item in decollate_batch(raw_outputs):
        processed.append(post_trans(item))
    return torch.stack(processed)


# Validation metric: Dice, background included, reduced with "mean".
dice_metric = DiceMetric(include_background=True, reduction="mean")

# The number of input channels is read from the first training sample.
in_channels = train_ds[0]["image"].shape[0]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D UNet with a single output channel.
model = UNet(
    spatial_dims=2,
    in_channels=in_channels,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

# Dice loss applied with a sigmoid, optimised with Adam (learning rate 1e-3).
loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
%matplotlib notebook

# File (relative to the current working directory) that receives the weights
# of the best-scoring epoch.
best_model_path = "best_model_2d_glioma_classification.pth"

# Live plot: one figure with a line per tracked series.
fig, ax = plt.subplots(1, 1, figsize=(10, 10), facecolor="white")
ax.set_xlabel("Epoch")
ax.set_ylabel("Metric")

# `data` maps a series name to its x/y history and the matplotlib line that
# draws it; plot_range reads these lists to refresh the plot.
data = {}
for i in ["train", "val dice"]:
    data[i] = {"x": [], "y": []}
    (data[i]["line"],) = ax.plot(data[i]["x"], data[i]["y"], label=i)

# start a typical PyTorch training
max_epochs = 100
val_interval = 1  # validate every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1

# plot_range yields each epoch index and refreshes the plot once the loop body
# for that epoch has finished.
for epoch in plot_range(data, range(max_epochs)):
    model.train()
    epoch_loss = 0

    for batch_data in train_loader:
        # NOTE: `labels` here rebinds the notebook-level `labels` name (the
        # list of label file paths created when the data was gathered).
        inputs, labels = batch_data["image"].to(
            device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Mean training loss over the batches of this epoch (plotted as "train").
    epoch_loss /= len(train_loader)
    data["train"]["x"].append(epoch + 1)
    data["train"]["y"].append(epoch_loss)

    if (epoch + 1) % val_interval == 0:
        # Validation runs inside MONAI's eval_mode context manager.
        with eval_mode(model):
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(
                    device), val_data["label"].to(device)
                # Post-processed predictions are fed to the Dice metric,
                # which accumulates results across calls until reset.
                val_outputs = infer_seg(val_images, model)
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            data["val dice"]["x"].append(epoch + 1)
            data["val dice"]["y"].append(metric)
            # Keep only the weights of the best validation score seen so far.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), best_model_path)

print(
    f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

In [None]:
# Restore the weights saved at the best validation epoch. map_location remaps
# the stored tensors onto the current `device`, so a checkpoint written on a
# GPU can still be loaded in a CPU-only session.
model.load_state_dict(torch.load(best_model_path, map_location=device))
# Switch to evaluation mode; the returned module is discarded to keep the
# cell output clean.
_ = model.eval()

## Check segmentations

Load validation files, apply validation transforms and display (no inverses yet!).

In [None]:
%matplotlib inline
to_imshow = []

# Run inference inside MONAI's eval_mode context (as the validation loop
# does): the network is evaluated in eval mode and without gradient tracking,
# so no autograd graph is built for these display-only predictions.
with eval_mode(model):
    for file in np.random.choice(val_files, size=5, replace=False):
        data = val_transforms(file)
        # [None] adds the batch dimension expected by the network; [0] removes
        # it again before post-processing the single prediction.
        inferred = post_trans(model(data["image"][None].to(device))[0])
        to_imshow.append({
            "image": data["image"],
            "GT label": data["label"],
            "inferred label": inferred,
        })
imshows(to_imshow)

In [None]:
from monai.visualize.utils import blend_images
import numpy as np
import matplotlib.pyplot as plt

# Deterministic pipeline for a single overlay: load, rotate by 90 degrees and
# crop to the image foreground with a 5-pixel margin.
transforms = mt.Compose([
        mt.LoadImaged(keys),
        mt.Rotate90d(keys),
        mt.CropForegroundd(keys, source_key="image", margin=5),
])

idx = 0
data = train_files[idx]
data = transforms(data)
# Keep only the first image channel for the overlay; the label is used as-is.
img, lbl = data["image"][:1], data["label"]

# Move the leading (channel) axis last, which is the layout imshow expects
# for colour images — assumes blend_images returns channel-first data.
blended = np.moveaxis(blend_images(img, lbl), 0, -1)

# figsize is in inches: (10, 10) matches the training figure above, whereas
# the previous (100, 100) produced an enormous canvas for a single 2D slice.
plt.figure(figsize=(10, 10))
plt.imshow(blended)
_ = plt.axis(False)