In [1]:
!pip install monai



In [2]:
import sys
import json
import time
import torch
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from medpy.io import load, save

import monai
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.networks.nets import SegResNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    LoadImage,
    LoadImaged,
    AddChanneld,
    EnsureChannelFirstd,
    AsChannelFirstd,
    AsChannelLastd,
    Compose,
    RandRotate90d,
    Resized,
    ScaleIntensityd,
    ConcatItemsd,
    ToTensord
)

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [3]:
fold = 0
path = '/media/mpmri/mpmri/fluke/workdir/nnUNet_raw_data/'
dataset = f"{path}Task2213_picai_fluke"
dir_image = f"{dataset}/imagesTr/"
dir_label = f"{dataset}/labelsTr/"
dataset_js = json.load(open(f"{dataset}/dataset.json"))
split_js = json.load(open(f"{dataset}/splits.json"))

batch_size = 2
image_size = (256,256,32)

In [4]:
def get_list(img_list):
    file_list = []
    for img in img_list:
        one = {'t2w':f"{dir_image}{img}_0000.nii.gz",
               'adc':f"{dir_image}{img}_0001.nii.gz",
               'hbv':f"{dir_image}{img}_0002.nii.gz",
               'label':f"{dir_label}{img}.nii.gz"}
        file_list.append(one)
    return file_list

In [5]:
def load_image(path_nii):
    image_data, image_header = load(path_nii)
    image_sitk = image_header.get_sitkimage()
    image_direction = image_header.get_direction()
    image_offset = image_header.get_offset()
    image_voxel_spacing = image_header.get_voxel_spacing()
    return image_data, image_sitk, image_header

In [6]:
def render_file(path):
    image_data, image_sitk, image_header = load_image(path)
    
    # print(image_sitk)
    print(image_data.shape)
    
    plt.figure(figsize=(16,16))
    for i in range(image_data.shape[2]):
        plt.subplot(6,6,i+1)
        plt.imshow(image_data[:,:,i], cmap='gray')
        plt.tick_params(axis='both',
                        which='both',
                        bottom=False,
                        left=False,
                        top=False,
                        labelbottom=False,
                        labelleft=False)

In [7]:
current_split = split_js[fold]
train_list = get_list(current_split['train'])
valid_list = get_list(current_split['val'])

In [8]:
# configure image transform
train_transforms = Compose([
    LoadImaged(keys=["t2w", "adc", "hbv", "label"], image_only=True),
    ScaleIntensityd(keys=["t2w", "adc", "hbv"]),
    EnsureChannelFirstd(keys=["t2w", "adc", "hbv", "label"]),
    Resized(keys=["t2w", "adc", "hbv", "label"], spatial_size=image_size),
    ConcatItemsd(keys=["t2w", "adc", "hbv"], name="input"),
    RandRotate90d(keys=["input", "label"]),
    ToTensord(keys=["input", "label"])
])
test_transforms = Compose([
    LoadImaged(keys=["t2w", "adc", "hbv", "label"], image_only=True),
    ScaleIntensityd(keys=["t2w", "adc", "hbv"]),
    EnsureChannelFirstd(keys=["t2w", "adc", "hbv", "label"]),
    Resized(keys=["t2w", "adc", "hbv", "label"], spatial_size=image_size),
    ConcatItemsd(keys=["t2w", "adc", "hbv"], name="input"),
    ToTensord(keys=["input", "label"])
])
check_ds = Dataset(data=train_list, transform=train_transforms)
check_loader = DataLoader(check_ds, 
                          batch_size=4, 
                          num_workers=4, 
                          pin_memory=pin_memory)

one_batch = next(iter(check_loader))

In [9]:
# debug data loader
print(f"Feature batch shape: {one_batch['input'].size()}")
print(f"Labels batch shape:  {one_batch['label'].size()}")
print(type(one_batch['input']), one_batch['input'].shape, one_batch['label'].shape)

Feature batch shape: torch.Size([4, 3, 256, 256, 32])
Labels batch shape:  torch.Size([4, 1, 256, 256, 32])
<class 'monai.data.meta_tensor.MetaTensor'> torch.Size([4, 3, 256, 256, 32]) torch.Size([4, 1, 256, 256, 32])


In [10]:
# create a training data loader (balanced)
train_ds = Dataset(data=train_list, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=4, pin_memory=pin_memory, shuffle=True)

# create a validation data loader
val_ds = Dataset(data=valid_list, transform=test_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=pin_memory)

In [11]:
# debug data loader 
sample_batch = next(iter(train_loader))
print(f"Feature batch shape: {sample_batch['input'].size()}")
print(f"Labels batch shape: {sample_batch['label'].size()}")

Feature batch shape: torch.Size([2, 3, 256, 256, 32])
Labels batch shape: torch.Size([2, 1, 256, 256, 32])


In [14]:
max_epochs = 300
val_interval = 1
VAL_AMP = True

# standard PyTorch program style: create SegResNet, DiceLoss and Adam optimizer
device = torch.device("cuda")
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=3,
    out_channels=3,
    dropout_prob=0.2,
).to(device)
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

# use amp to accelerate training
scaler = torch.cuda.amp.GradScaler()
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

In [13]:
est_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        label = torch.cat([batch_data["label"].round()==0,batch_data["label"].round()==1,batch_data["label"].round()==2],dim=1)
        inputs, labels = (
            batch_data["input"].to(device),
            label.to(device),
        )
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}"
            f", train_loss: {loss.item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                label = torch.cat([val_data["label"].round()==0,val_data["label"].round()==1,val_data["label"].round()==2],dim=1)
                val_inputs, val_labels = (
                    val_data["input"].to(device),
                    label,
                )
                val_outputs = inference(val_inputs)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            metric_batch = dice_metric_batch.aggregate()
            metric_tc = metric_batch[0].item()
            metric_values_tc.append(metric_tc)
            metric_wt = metric_batch[1].item()
            metric_values_wt.append(metric_wt)
            metric_et = metric_batch[2].item()
            metric_values_et.append(metric_et)
            dice_metric.reset()
            dice_metric_batch.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(root_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

----------
epoch 1/300
1/517, train_loss: 0.7140, step time: 5.7146
2/517, train_loss: 0.6836, step time: 0.2788
3/517, train_loss: 0.6772, step time: 0.2672
4/517, train_loss: 0.6681, step time: 0.2685
5/517, train_loss: 0.6647, step time: 0.2749
6/517, train_loss: 0.6522, step time: 0.2698
7/517, train_loss: 0.6633, step time: 0.2694
8/517, train_loss: 0.6334, step time: 0.2739
9/517, train_loss: 0.6662, step time: 0.2736
10/517, train_loss: 0.6537, step time: 0.2721
11/517, train_loss: 0.6546, step time: 0.2706
12/517, train_loss: 0.6567, step time: 0.2963
13/517, train_loss: 0.6557, step time: 0.2743
14/517, train_loss: 0.6345, step time: 0.2701
15/517, train_loss: 0.6645, step time: 0.2850
16/517, train_loss: 0.6417, step time: 0.2784
17/517, train_loss: 0.6363, step time: 0.2753
18/517, train_loss: 0.6586, step time: 0.2771
19/517, train_loss: 0.6333, step time: 0.2781
20/517, train_loss: 0.6500, step time: 0.2716
21/517, train_loss: 0.6705, step time: 0.2926
22/517, train_loss: 

NameError: name 'inference' is not defined

In [33]:
inputs, labels = (
    batch_data["input"].to(device),
    batch_data["label"].to(device),
)
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    outputs = model(inputs)
    loss = loss_function(outputs, labels)

In [None]:
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")