In [None]:
from google.colab import drive
drive.mount('/content/drive')
!ls "/content/drive/My Drive/colab"

In [None]:
%pip install monai torch pyimage

In [None]:
import os
from glob import glob

import torch
from monai.transforms import (
    Compose,
    LoadImaged,
    ToTensord,
    AddChanneld,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,
    EnsureChannelFirstd,
    RandCropByPosNegLabeld,
    Rand3DElasticd,
    RandShiftIntensityd,
    RandGaussianNoised,
    EnsureTyped,
    RandFlipd,
    RandRotate90d,
    #GammaTransformd,
    RandZoomd,
    Orientationd,
)

from monai.data import Dataset, DataLoader
from monai.utils import first
import matplotlib.pyplot as plt

import numpy as np

# Build the {"image": ..., "label": ...} dictionaries that the MONAI datasets consume.
data_dir = "/content/drive/My Drive/colab"
root_dir = data_dir


def make_file_dicts(images, labels):
    """Pair each image path with the label path from the same case directory.

    ``images`` and ``labels`` are sorted glob results with one image and one
    segmentation per ``case_*`` folder. A plain ``zip`` would silently shift
    every pair after a case that is missing one of the two files, so each pair
    is checked to come from the same directory.

    Returns:
        list of {"image": path, "label": path} dicts, in input order.
    Raises:
        ValueError: if the counts differ or a pair comes from different folders.
    """
    if len(images) != len(labels):
        raise ValueError(f"{len(images)} images but {len(labels)} labels")
    pairs = []
    for image_name, label_name in zip(images, labels):
        if os.path.dirname(image_name) != os.path.dirname(label_name):
            raise ValueError(f"mismatched image/label pair: {image_name} / {label_name}")
        pairs.append({"image": image_name, "label": label_name})
    return pairs


train_images = sorted(glob(os.path.join(data_dir, "KITStrain/case_*/imaging.nii.gz")))
train_labels = sorted(glob(os.path.join(data_dir, "KITStrain/case_*/aggregated_MAJ_seg.nii.gz")))

val_images = sorted(glob(os.path.join(data_dir, "KITSval/case_*/imaging.nii.gz")))
val_labels = sorted(glob(os.path.join(data_dir, "KITSval/case_*/aggregated_MAJ_seg.nii.gz")))

train_files = make_file_dicts(train_images, train_labels)
val_files = make_file_dicts(val_images, val_labels)


# Minimal pipeline used only to look at the raw volumes: load, move the channel
# axis first, convert to tensors. No resampling or intensity scaling.
orig_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        # AddChanneld is deprecated (removed in MONAI 1.3); EnsureChannelFirstd
        # produces the same channel-first layout from the loaded metadata.
        EnsureChannelFirstd(keys=['image', 'label']),
        ToTensord(keys=['image', 'label']),
    ]
)

# Training pipeline: deterministic preprocessing (load, channel-first, reorient,
# resample, intensity window, foreground crop) followed by random patch sampling
# and augmentation. Every spatial transform applied to "label" uses "nearest"
# interpolation so the segmentation keeps integer class values.
train_transforms = Compose(
    [
        # load image and label
        LoadImaged(keys=['image', 'label']),
        # put the channel axis first: (spatial_dim_1[, spatial_dim_2, ...]) ->
        # (num_channels, spatial_dim_1[, spatial_dim_2, ...]); replaces the deprecated AddChanneld
        EnsureChannelFirstd(keys=["image", "label"]),
        # reorient to the standard RAS (right, anterior, superior) orientation before
        # any anisotropic spatial transform such as the resampling below
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # resample to (1.5, 1.5, 2) voxel spacing; bilinear for the image, nearest for
        # the label (bilinear would blend neighbouring class ids into fractional values)
        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        # clip intensities to [-80, 305] and rescale them linearly to [0, 1]
        ScaleIntensityRanged(keys=["image"], a_min=-80, a_max=305, b_min=0.0, b_max=1.0, clip=True),
        # crop both volumes to the foreground bounding box of the image
        CropForegroundd(keys=['image', 'label'], source_key='image'),
        # (resizing to a fixed size is an alternative to random cropping)
        # Resized(keys=['image', 'label'], spatial_size=[128,128,128]),
        # draw 4 random 160x160x64 patches per volume; pos=1/neg=1 gives equal odds of
        # centring a patch on a foreground or a background label voxel
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(160, 160, 64),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # random elastic deformation combined with a random rotation/scale/translation
        Rand3DElasticd(
            keys=["image", "label"],
            mode=("bilinear", "nearest"),
            prob=0.5,
            sigma_range=(5, 8),
            magnitude_range=(50, 150),
            spatial_size=(160, 160, 64),
            translate_range=(10, 10, 5),
            rotate_range=(np.pi / 36, np.pi / 36, np.pi),
            scale_range=(0.1, 0.1, 0.1),
            padding_mode="zeros",
        ),
        # random intensity offset in [-0.10, 0.10]
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.25,
        ),
        # random additive gaussian noise
        RandGaussianNoised(keys=["image"], prob=0.25, mean=0.0, std=0.1),
        # random flips along each spatial axis
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        # random rotation by 90, 180 or 270 degrees
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        # TODO: MONAI has no GammaTransformd, so there is no random gamma transform here.
        # RandAdjustContrastd performs random gamma correction and could be used, e.g.
        # RandAdjustContrastd(keys=["image"], prob=0.1, gamma=(0.7, 1.3)) (needs an extra import).
        # random zoom by a factor in [0.9, 1.1]. RandZoomd takes min_zoom/max_zoom (it has no
        # zoom_range argument); the label is zoomed with "nearest" so it stays integer-valued.
        RandZoomd(
            keys=["image", "label"],
            prob=0.1,
            min_zoom=0.9,
            max_zoom=1.1,
            mode=("area", "nearest"),
        ),
        # make sure the outputs are torch tensors
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=['image', 'label'])
    ]
)

# Validation pipeline: the deterministic part of the training pipeline, without
# random cropping or augmentation, so each case is evaluated as a whole volume.
val_transforms = Compose(
    [
        # load image and label
        LoadImaged(keys=['image', 'label']),
        # put the channel axis first; replaces the deprecated AddChanneld
        EnsureChannelFirstd(keys=["image", "label"]),
        # reorient to RAS before resampling
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # resample; nearest interpolation keeps the label's integer class values
        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        # clip intensities to [-80, 305] and rescale them linearly to [0, 1]
        ScaleIntensityRanged(keys=["image"], a_min=-80, a_max=305, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=['image', 'label'], source_key='image'),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=['image', 'label'])
    ]
)


# Uncached datasets/loaders used to inspect a first batch.
# The raw volumes (orig_ds) and the foreground-cropped validation volumes differ
# in spatial size from case to case, and the default collate cannot stack tensors
# of different shapes, so those loaders use batch_size=1. The training samples are
# fixed-size random crops, so they can be batched.

orig_ds = Dataset(data=train_files, transform=orig_transforms)
orig_loader = DataLoader(orig_ds, batch_size=1)

train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2)

val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1)


# Fetch one batch from each loader and compare a raw slice with the
# preprocessed/augmented training patch and its label at the same slice index.
test_patient = first(train_loader)
orig_patient = first(orig_loader)

image_batch = test_patient['image']
print(torch.min(image_batch))
print(torch.max(image_batch))

slice_idx = 50
panels = [
    ('Orig patient', orig_patient['image'], "gray"),
    ('Slice of a patient', test_patient['image'], "gray"),
    ('Label of a patient', test_patient['label'], None),
]

plt.figure('test', (12, 6))
for position, (panel_title, volume, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(volume[0, 0, :, :, slice_idx], cmap=colormap)
plt.show()

In [None]:
# Same raw vs. preprocessed comparison as before, at a different slice.
# n selects which slice (last axis) of the 3D volumes is shown.
n = 30

views = (
    ('Orig patient', orig_patient['image'], "gray"),
    ('Slice of a patient', test_patient['image'], "gray"),
    ('Label of a patient', test_patient['label'], None),
)

plt.figure('test', (12, 6))
for idx, (caption, vol, cm) in enumerate(views):
    plt.subplot(1, 3, idx + 1)
    plt.title(caption)
    plt.imshow(vol[0, 0, :, :, n], cmap=cm)
plt.show()

In [None]:
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch,SmartCacheDataset

# Cached datasets for training: CacheDataset stores the output of the
# deterministic transforms, so only the random transforms run again each epoch.
print("CREATING TRAIN DS", flush=True)
train_ds = CacheDataset(data=train_files, transform=train_transforms)
print(len(train_ds))
print("CREATED TRAIN DS", flush=True)
# batch_size=2 volumes x num_samples=4 crops from RandCropByPosNegLabeld
# gives 2 x 4 = 8 patches per training step
train_loader = DataLoader(train_ds, batch_size=2,
                          shuffle=True,
                          num_workers=2)
print("CREATED TRAIN DATALOADER", flush=True)

val_ds = CacheDataset(data=val_files, transform=val_transforms)
print("CREATED VAL DS", flush=True)
# validation volumes keep their full (foreground-cropped) size, which differs
# between cases; the default collate cannot stack them, so use batch_size=1
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)
print("CREATED VAL DATALOADER", flush=True)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Display whether PyTorch can see a CUDA GPU in this runtime.
torch.cuda.is_available()

In [None]:
from monai.networks.nets import UNet, UNETR, DynUNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch,SmartCacheDataset
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import numpy as np




from monai.losses import  DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureType,
)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# number of label classes in the segmentation masks (background included)
NUM_CLASSES = 4

model = UNet(
    spatial_dims=3,  # 3D volumes ("dimensions" is the deprecated name of this argument)
    in_channels=1,  # one input channel: the CT intensity
    out_channels=NUM_CLASSES,  # one output (logit) channel per label class
    channels=(64, 128, 256, 512),  # feature channels of the 4 resolution levels
    # one stride per down-sampling step between levels (len(channels) - 1 values);
    # UNet ignores any extra strides. Larger strides shrink the feature maps faster,
    # trading spatial detail for speed and memory.
    strides=(2, 2, 2),
    num_res_units=2,  # residual units per block
    norm="INSTANCE",
).to(device)  # move the model to the training device

print("CREATED MODEL", flush=True)
# Dice + cross-entropy loss: the Dice term rewards overlap between the predicted and
# ground-truth regions, the CE term compares the per-voxel class probabilities with
# the ground truth. Labels are one-hot encoded and a softmax is applied to the logits.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
# AdamW: Adam with decoupled weight decay
optimizer = torch.optim.AdamW(model.parameters(), 1e-4)
# mean Dice over the non-background classes, used for model selection
dice_metric = DiceMetric(include_background=False, reduction="mean")

max_epochs = 10
# number of epochs between two validation runs
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# validation post-processing: prediction -> argmax -> one-hot, label -> one-hot.
# to_onehot alone sets the class count (the old n_classes argument is deprecated).
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=NUM_CLASSES)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=NUM_CLASSES)])
# sliding-window size for validation: the training patch size
roi_size = (160, 160, 64)
sw_batch_size = 2

for epoch in range(max_epochs):
    print("-" * 10, flush=True)
    print(f"epoch {epoch + 1}/{max_epochs}", flush=True)
    model.train()
    epoch_loss = 0
    step = 0
    for step, batch_data in enumerate(train_loader, start=1):
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # len(train_loader) is the real number of batches (rounded up)
        print(
            f"{step}/{len(train_loader)}, "
            f"train_loss: {loss.item():.4f}", flush=True)
    # max() guards against an empty loader
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}", flush=True)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # accumulate the metric for this batch
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for the next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # actually write the checkpoint that the message below announces
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model", flush=True)
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}", flush=True
            )

print(epoch_loss_values, flush=True)
print(metric_values, flush=True)
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}", flush=True)

In [None]:
# Plot the training loss (recorded once per epoch) and the validation Dice
# (recorded once every `val_interval` epochs) against the epoch number.
eval_num = val_interval
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
# one loss value per epoch, so it is not scaled by the validation interval
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("Epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
# one Dice value every eval_num epochs
x = [eval_num * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("Epoch")
plt.plot(x, y)
plt.show()

In [None]:
# Save the model's current weights into the data folder.
# NOTE(review): this writes the final weights under the name best_metric_model.pth,
# replacing any checkpoint with that name in root_dir.
root_dir = data_dir
checkpoint_path = os.path.join(root_dir, "best_metric_model.pth")
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Run sliding-window inference on one validation case and show the input,
# the ground-truth label and the predicted label map at slice n (last axis).
case_num = 1
# map_location lets a GPU-trained checkpoint load on a CPU-only runtime
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device))
model.eval()
with torch.no_grad():
    n = 60
    sample = val_ds[case_num]
    img_name = os.path.split(sample["image"].meta["filename_or_obj"])[1]
    img = sample["image"]
    label = sample["label"]
    # add the batch axis in front of the channel-first volume
    # (unsqueeze(1) only happened to work because there is a single channel)
    val_inputs = torch.unsqueeze(img, 0).to(device)
    val_labels = torch.unsqueeze(label, 0).to(device)
    val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model, overlap=0.8)
    # keep the slice index inside the volume
    n = min(n, val_inputs.shape[-1] - 1)
    plt.figure("check", (18, 6))
    plt.subplot(1, 3, 1)
    plt.title("image")
    plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, n], cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title("label")
    plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, n])
    plt.subplot(1, 3, 3)
    plt.title("output")
    plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, n])
    plt.show()

In [None]:
    n = 50
    plt.figure("check", (18, 6))
    plt.subplot(1, 3, 1)
    plt.title("image")
    plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, n], cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title("label")
    plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, n])
    plt.subplot(1, 3, 3)
    plt.title("output")
    plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, n])
    plt.show()

In [None]:
# Inference on validation case `case_num`, displayed at slice n. The second figure
# repeats the raw vs. preprocessed training comparison at the same slice.
case_num = 1
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device))
model.eval()
with torch.no_grad():
    n = 30
    sample = val_ds[case_num]
    img_name = os.path.split(sample["image"].meta["filename_or_obj"])[1]
    img = sample["image"]
    label = sample["label"]
    # add the batch axis in front of the channel-first volume
    val_inputs = torch.unsqueeze(img, 0).to(device)
    val_labels = torch.unsqueeze(label, 0).to(device)
    val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model, overlap=0.8)
    # slice_map is not defined anywhere in this notebook; show slice n instead
    plt.figure("check", (18, 6))
    plt.subplot(1, 3, 1)
    plt.title("image")
    plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, n], cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title("label")
    plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, n])
    plt.subplot(1, 3, 3)
    plt.title("output")
    plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, n])
    plt.show()

plt.figure('test', (12, 6))

plt.subplot(1, 3, 1)
plt.title('Orig patient')
plt.imshow(orig_patient['image'][0, 0, :, :, n], cmap="gray")

plt.subplot(1, 3, 2)
plt.title('Slice of a patient')
plt.imshow(test_patient['image'][0, 0, :, :, n], cmap="gray")

plt.subplot(1, 3, 3)
plt.title('Label of a patient')
plt.imshow(test_patient['label'][0, 0, :, :, n])
plt.show()

In [None]:
from tqdm import tqdm  # needed here: tqdm is otherwise imported only further down

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_loader yields RandCropByPosNegLabeld patches of this size, and UNETR only
# accepts inputs whose spatial size equals img_size, so the same size is used for
# the network and for the sliding-window validation.
ROI_SIZE = (160, 160, 64)
# label classes in the segmentation masks (background included); the model output
# and the one-hot post-processing must agree on this
NUM_CLASSES = 4

model = UNETR(
    in_channels=1,
    out_channels=NUM_CLASSES,
    img_size=ROI_SIZE,
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)


def validation(epoch_iterator_val):
    """Return the mean Dice of the global `model` over `epoch_iterator_val`.

    `epoch_iterator_val` is a tqdm wrapper around the validation loader; its
    description is updated with progress. Uses the global `post_label`,
    `post_pred` and `dice_metric`, and resets the metric afterwards.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(epoch_iterator_val, start=1):
            # inputs must live on the same device as the model
            val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
            val_outputs = sliding_window_inference(val_inputs, ROI_SIZE, 4, model)
            val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
            val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (batch_idx, len(epoch_iterator_val)))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train for one pass over `train_loader`, stopping after `max_iterations`.

    Validates every `eval_num` global steps and at `max_iterations`, saving the
    weights to root_dir/best_metric_model.pth whenever the Dice improves.

    Returns:
        (global_step, dice_val_best, global_step_best) updated after this pass.
    """
    model.train()
    epoch_loss = 0.0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description("Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss.item()))
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # record the running average without touching the accumulator,
            # which keeps summing for the rest of the epoch
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
        if global_step > max_iterations:
            break
    return global_step, dice_val_best, global_step_best


max_iterations = 5
eval_num = 500
post_label = AsDiscrete(to_onehot=NUM_CLASSES)
post_pred = AsDiscrete(argmax=True, to_onehot=NUM_CLASSES)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
# start below any reachable Dice so the first evaluation always writes the
# checkpoint that is loaded after training
dice_val_best = -1.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device))

In [None]:
from tqdm import tqdm  # needed here: tqdm is otherwise imported only further down

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_loader yields RandCropByPosNegLabeld patches of this size, and UNETR only
# accepts inputs whose spatial size equals img_size, so the same size is used for
# the network and for the sliding-window validation.
ROI_SIZE = (160, 160, 64)
# label classes in the segmentation masks (background included); 14 was the
# class count of a different dataset
NUM_CLASSES = 4

model = UNETR(
    in_channels=1,
    out_channels=NUM_CLASSES,
    img_size=ROI_SIZE,
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)


def validation(epoch_iterator_val):
    """Return the mean Dice of the global `model` over `epoch_iterator_val`.

    `epoch_iterator_val` is a tqdm wrapper around the validation loader; its
    description is updated with progress. Uses the global `post_label`,
    `post_pred` and `dice_metric`, and resets the metric afterwards.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(epoch_iterator_val, start=1):
            # use the model's device rather than assuming CUDA
            val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
            val_outputs = sliding_window_inference(val_inputs, ROI_SIZE, 4, model)
            val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
            val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (batch_idx, len(epoch_iterator_val)))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train for one pass over `train_loader`, stopping after `max_iterations`.

    Validates every `eval_num` global steps and at `max_iterations`, saving the
    weights to root_dir/best_metric_model.pth whenever the Dice improves.

    Returns:
        (global_step, dice_val_best, global_step_best) updated after this pass.
    """
    model.train()
    epoch_loss = 0.0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description("Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss.item()))
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # record the running average without touching the accumulator,
            # which keeps summing for the rest of the epoch
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
        if global_step > max_iterations:
            break
    return global_step, dice_val_best, global_step_best


max_iterations = 5
eval_num = 500
post_label = AsDiscrete(to_onehot=NUM_CLASSES)
post_pred = AsDiscrete(argmax=True, to_onehot=NUM_CLASSES)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
# start below any reachable Dice so the first evaluation always writes the
# checkpoint that is loaded after training
dice_val_best = -1.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device))

In [None]:
from monai.losses import  DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureType,
)


# UNETR for the segmentation data: one CT input channel and 4 label classes,
# matching the 4-class one-hot post-processing below. UNETR only accepts inputs
# whose spatial size equals img_size, so img_size is the (160, 160, 64) size of
# the training patches and of the sliding-window ROI used during validation.
model = UNETR(
    in_channels=1,
    out_channels=4,
    img_size=(160, 160, 64),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
# mean Dice over the non-background classes
dice_metric = DiceMetric(include_background=False, reduction="mean")


max_epochs = 250
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# to_onehot alone sets the class count (the old n_classes argument is deprecated)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=4)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=4)])


In [None]:
# Epoch-based training loop: one pass over train_loader per epoch, validation with
# sliding-window inference every val_interval epochs, and a checkpoint (named after
# the epoch and its Dice) whenever the validation Dice improves.
for epoch in range(max_epochs):
    print("-" * 10, flush=True)
    print(f"epoch {epoch + 1}/{max_epochs}", flush=True)
    model.train()
    epoch_loss = 0
    step = 0
    for step, batch_data in enumerate(train_loader, start=1):
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # len(train_loader) is the real number of batches (rounded up)
        print(
            f"{step}/{len(train_loader)}, "
            f"train_loss: {loss.item():.4f}", flush=True)
    # max() guards against an empty loader
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}", flush=True)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 64)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # accumulate the metric for this batch
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for the next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    root_dir, "best_metric_model_" + str(epoch) + "_" + str(f"{metric:.4f}") + ".pth"))
                print("saved new best metric model", flush=True)
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}", flush=True
            )

print(epoch_loss_values, flush=True)
print(metric_values, flush=True)
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}", flush=True)

In [None]:
from tqdm import tqdm

# Sliding-window ROI for validation. train() feeds the model 160x160x64 training
# patches, so the same size is used here (UNETR requires inputs of exactly img_size).
ROI_SIZE = (160, 160, 64)


def validation(epoch_iterator_val):
    """Return the mean Dice of the global `model` over `epoch_iterator_val`.

    `epoch_iterator_val` is a tqdm wrapper around the validation loader; its
    description is updated with progress. Predictions and labels are one-hot
    encoded with the number of channels the model outputs, so post-processing
    always matches whichever model is loaded. The global `dice_metric` is reset
    after aggregation.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(epoch_iterator_val, start=1):
            # use the model's device rather than assuming CUDA
            val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
            val_outputs = sliding_window_inference(val_inputs, ROI_SIZE, 4, model)
            num_classes = val_outputs.shape[1]
            to_onehot_label = AsDiscrete(to_onehot=num_classes)
            to_onehot_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
            val_labels_convert = [to_onehot_label(t) for t in decollate_batch(val_labels)]
            val_output_convert = [to_onehot_pred(t) for t in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (batch_idx, len(epoch_iterator_val)))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train for one pass over `train_loader`, stopping after `max_iterations`.

    Validates every `eval_num` global steps and at `max_iterations`, saving the
    weights to root_dir/best_metric_model.pth whenever the Dice improves.

    Returns:
        (global_step, dice_val_best, global_step_best) updated after this pass.
    """
    model.train()
    epoch_loss = 0.0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description("Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss.item()))
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # record the running average without touching the accumulator,
            # which keeps summing for the rest of the epoch
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
        if global_step > max_iterations:
            break
    return global_step, dice_val_best, global_step_best


max_iterations = 50
eval_num = 500
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
# start below any reachable Dice so the first evaluation always writes the
# checkpoint that is loaded after training
dice_val_best = -1.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device))

In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    AddChanneld,
    ToTensord,

    LoadImage,
    ToTensor,
    AddChannel,
    Resized,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)


import torch

# Print MONAI's configuration and dependency versions (monai.config.print_config)
# so the environment this notebook ran in is recorded in the output.
print_config()

In [None]:
# Directory for checkpoints and other artifacts.
# os.environ.get takes an environment-variable *name*, not a path: set
# MONAI_DATA_DIRECTORY (e.g. to C:\Users\jm\Desktop\monai) to persist outputs;
# otherwise a fresh temporary directory is used.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    # Make sure the configured directory exists before anything is saved into it.
    os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

In [None]:
# Minimal pipeline: load the volumes and put the channel axis first, so the raw data
# can be compared against the preprocessed output. EnsureChannelFirstd replaces the
# deprecated AddChanneld and matches the train/val pipelines below.
orig_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ]
)

# Training pipeline: reorient, resample, window intensities, crop to foreground,
# sample 96^3 patches, then apply light random augmentation.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Nearest-neighbour for the label so resampling never creates fractional classes.
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        # Clip intensities to [-175, 250] and rescale to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # Four 96^3 patches per volume, balanced between label-positive and negative centres.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
    ]
)
# Validation pipeline: same deterministic preprocessing as training, without patch
# sampling or augmentation.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]
)

In [None]:
import os
from glob import glob
# Root folder holding the KITStrain/ and KITSval/ case directories.
# Set the KITS_DATA_DIR environment variable instead of editing this path.
data_dir = os.environ.get("KITS_DATA_DIR", "C:\\Users\\jm\\Desktop\\folder")


def find_case_files(root, split, file_name):
    """Return the sorted paths of `file_name` inside every `root/split/case_*` folder.

    Components are joined with os.path.join rather than hard-coded backslashes,
    so the glob pattern is valid on both Windows and POSIX (e.g. Colab).
    """
    return sorted(glob(os.path.join(root, split, "case_*", file_name)))


def pair_case_files(images, labels):
    """Pair image and label paths into a list of {"image": ..., "label": ...} dicts.

    Raises:
        ValueError: if the two lists do not come from exactly the same case folders.
            Plain zip() would silently drop or misalign cases instead.
    """
    image_cases = [os.path.dirname(path) for path in images]
    label_cases = [os.path.dirname(path) for path in labels]
    if image_cases != label_cases:
        unmatched = sorted(set(image_cases) ^ set(label_cases))
        raise ValueError("Image/label case folders do not match: {}".format(unmatched))
    return [{"image": image_name, 'label': label_name} for image_name, label_name in zip(images, labels)]


train_images = find_case_files(data_dir, "KITStrain", "imaging.nii.gz")
train_labels = find_case_files(data_dir, "KITStrain", "aggregated_MAJ_seg.nii.gz")

val_images = find_case_files(data_dir, "KITSval", "imaging.nii.gz")
val_labels = find_case_files(data_dir, "KITSval", "aggregated_MAJ_seg.nii.gz")

train_files = pair_case_files(train_images, train_labels)
val_files = pair_case_files(val_images, val_labels)


In [None]:
from monai.data import Dataset, DataLoader
from monai.utils import first
import matplotlib.pyplot as plt

# One Dataset per pipeline: the raw-view and training datasets share train_files,
# and validation uses val_files. Each Dataset applies its transform when an item is fetched.
orig_ds = Dataset(data=train_files, transform=orig_transforms)
train_ds = Dataset(data=train_files, transform=train_transforms)
val_ds = Dataset(data=val_files, transform=val_transforms)

# One volume per batch, default (unshuffled) order, so first(orig_loader) and
# first(train_loader) refer to the same case in the preview cells.
orig_loader, train_loader, val_loader = (
    DataLoader(dataset, batch_size=1) for dataset in (orig_ds, train_ds, val_ds)
)


In [None]:
import copy
import time
import pprint

import torch
import torchio as tio
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

sns.set_style("whitegrid", {'axes.grid' : False})
%config InlineBackend.figure_format = 'retina'
torch.manual_seed(14041931)

print('TorchIO version:', tio.__version__)
print('Last run on', time.ctime())

fpg = tio.datasets.FPG()
print('Sample subject:', fpg)
show_fpg(fpg)

In [None]:
# Pull the first sample through the preprocessed and the raw-view pipelines.
# Assuming the loaders are not shuffled, both come from the first case in train_files.
test_patient = first(train_loader)
orig_patient = first(orig_loader)

# Index along the last spatial axis of the [batch, channel, ...] tensors.
SLICE_INDEX = 30

plt.figure("test", (12, 6))

plt.subplot(1, 3, 1)
plt.title("orig patient image")
plt.imshow(orig_patient["image"][0, 0, :, :, SLICE_INDEX], cmap="gray")

# The original cell titled this panel but never drew into it.
plt.subplot(1, 3, 2)
plt.title("slice of a patient")
plt.imshow(test_patient["image"][0, 0, :, :, SLICE_INDEX], cmap="gray")

plt.subplot(1, 3, 3)
plt.title("label of a patient")
plt.imshow(test_patient["label"][0, 0, :, :, SLICE_INDEX])
plt.show()

In [None]:
# Side-by-side view of slice 30: raw image, preprocessed image, and its label
# (the label panel uses matplotlib's default colormap).
panels = (
    ('Orig patient', orig_patient['image'], "gray"),
    ('Slice of a patient', test_patient['image'], "gray"),
    ('Label of a patient', test_patient['label'], None),
)

plt.figure('test', (12, 6))
for position, (panel_title, volume, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(volume[0, 0, :, :, 30], cmap=colormap)
plt.show()

In [None]:
import SimpleITK as sitk
import matplotlib.pyplot as plt

# Preview one plane of the first training volume with SimpleITK.
# (The original overwrote the training path with an unrelated hard-coded one,
# passed a numpy array to sitk.ReadImage and handed a SimpleITK image to imshow.)
if not train_files:
    raise RuntimeError("train_files is empty; check data_dir before previewing a volume.")
m = train_files[0]["image"]
image = sitk.ReadImage(m)
# GetArrayFromImage returns voxels in (z, y, x) order, so axis 0 selects a plane.
image_array = sitk.GetArrayFromImage(image)
# Clamp the requested plane so shorter scans do not raise IndexError.
slice_index = min(300, image_array.shape[0] - 1)
n = image_array[slice_index, :, :]
plt.subplot(121)
plt.imshow(n, cmap="gray")
plt.title('Imagen')

plt.tight_layout()
plt.show()

In [None]:
# Load the images,
# apply the preprocessing transforms,
# and convert them into torch tensors.

# Raw-view pipeline: load and add a channel axis only. The plotting cells index
# samples as [0, 0, :, :, k], which needs both batch and channel axes. (The cell
# previously defined orig_transforms twice; the surviving definition dropped the channel.)
orig_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        ToTensord(keys=['image', 'label'])
    ]
)

# Training pipeline: resample, window intensities to [-200, 200] -> [0, 1],
# crop to foreground and resize to a fixed 128^3 grid.
train_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        # Labels are categorical: nearest-neighbour keeps them integer-valued
        # (the default bilinear mode would blend class ids at boundaries).
        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys='image', a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=['image', 'label'], source_key='image'),
        # Same reasoning for the resize: keep the image default ("area"), nearest for the label.
        Resized(keys=['image', 'label'], spatial_size=[128,128,128], mode=("area", "nearest")),
        ToTensord(keys=['image', 'label'])
    ]
)


# Validation pipeline: resample and window only; no crop/resize.
val_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys='image', a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        ToTensord(keys=['image', 'label'])
    ]
)

In [None]:
import monai
import os

import os
import shutil
import tempfile

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    AddChanneld,
    ToTensord,

    LoadImage,
    ToTensor,
    AddChannel,
    Resized,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR

from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

import numpy as np

import torch

# Sanity check of LoadImaged + Resized on a 2-D PNG (the MONAI logo).
# Download into the system temp directory: the previous hard-coded
# "C:\\Users\\Desktop" has no user-name component and most likely does not exist
# (and is meaningless on Colab/Linux).
filename = os.path.join(tempfile.gettempdir(), "MONAI-logo_color.png")
monai.apps.download_url("https://monai.io/assets/img/MONAI-logo_color.png", filepath=filename)

transform = Compose(
    [
        LoadImaged(keys="image", image_only=True, ensure_channel_first=True, dtype=np.uint8),
        Resized(keys="image", spatial_size=[60, 64]),
    ]
)
test_data = {"image": filename}
result = transform(test_data)
print(result["image"].shape)

In [None]:
import os
from glob import glob

import torch
from monai.transforms import (
    Compose,
    LoadImaged,
    ToTensord,
    AddChanneld,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,

)

from monai.data import Dataset, DataLoader
from monai.utils import first
import matplotlib.pyplot as plt

# Root folder holding the KITStrain/ and KITSval/ case directories; override with
# the KITS_DATA_DIR environment variable instead of editing the path.
data_dir = os.environ.get("KITS_DATA_DIR", "C:\\Users\\jm\\Desktop\\folder")

IMAGE_FILE = "imaging.nii.gz"
LABEL_FILE = "aggregated_MAJ_seg.nii.gz"
# Index along the last spatial axis shown in the preview figure.
PREVIEW_SLICE = 50

# Join components with os.path.join: hard-coded "\\" separators make the glob
# match nothing on POSIX systems.
train_images = sorted(glob(os.path.join(data_dir, "KITStrain", "case_*", IMAGE_FILE)))
train_labels = sorted(glob(os.path.join(data_dir, "KITStrain", "case_*", LABEL_FILE)))

val_images = sorted(glob(os.path.join(data_dir, "KITSval", "case_*", IMAGE_FILE)))
val_labels = sorted(glob(os.path.join(data_dir, "KITSval", "case_*", LABEL_FILE)))

# zip() silently truncates; a case missing one file would shift every later pair.
for split_images, split_labels in ((train_images, train_labels), (val_images, val_labels)):
    if [os.path.dirname(p) for p in split_images] != [os.path.dirname(p) for p in split_labels]:
        raise ValueError("Image and label files do not come from the same case folders.")

train_files = [{"image": image_name, 'label': label_name} for image_name, label_name in zip(train_images, train_labels)]
val_files = [{"image": image_name, 'label': label_name} for image_name, label_name in zip(val_images, val_labels)]


# Raw view: load and add a channel axis only.
orig_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        ToTensord(keys=['image', 'label'])
    ]
)

train_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        # Nearest-neighbour for the categorical label; the default bilinear mode
        # would create fractional class values.
        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys='image', a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=['image', 'label'], source_key='image'),
        # Keep the image default ("area"); nearest for the label.
        Resized(keys=['image', 'label'], spatial_size=[128,128,128], mode=("area", "nearest")),
        ToTensord(keys=['image', 'label'])
    ]
)


val_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys='image', a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        ToTensord(keys=['image', 'label'])
    ]
)


# Unshuffled loaders, so first(orig_loader) and first(train_loader) are the same case.
orig_ds = Dataset(data=train_files, transform=orig_transforms)
orig_loader = DataLoader(orig_ds, batch_size=1)

train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1)

val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1)


test_patient = first(train_loader)
orig_patient = first(orig_loader)

# Intensities should lie in [0, 1] after ScaleIntensityRanged(clip=True).
print(torch.min(test_patient['image']))
print(torch.max(test_patient['image']))


plt.figure('test', (12, 6))

plt.subplot(1, 3, 1)
plt.title('Orig patient')
plt.imshow(orig_patient['image'][0, 0, :, :, PREVIEW_SLICE], cmap="gray")

plt.subplot(1, 3, 2)
plt.title('Slice of a patient')
plt.imshow(test_patient['image'][0, 0, :, :, PREVIEW_SLICE], cmap="gray")

plt.subplot(1, 3, 3)
plt.title('Label of a patient')
plt.imshow(test_patient['label'][0, 0, :, :, PREVIEW_SLICE])
plt.show()
