In [1]:
import os
from glob import glob
import shutil
from tqdm import tqdm
import dicom2nifti
import numpy as np
import nibabel as nib
from monai.transforms import(
    Compose,
    AddChanneld,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
)
from monai.data import DataLoader, Dataset, CacheDataset
from monai.utils import set_determinism

In [2]:

"""
This file is for preporcessing only, it contains all the functions that you need
to make your data ready for training.

You need to install the required libraries if you do not already have them.

pip install os, ...
"""


def create_groups(in_dir, out_dir, Number_slices):
    '''
    This function is to get the last part of the path so that we can use it to name the folder.
    `in_dir`: the path to your folders that contain dicom files
    `out_dir`: the path where you want to put the converted nifti files
    `Number_slices`: here you put the number of slices that you need for your project and it will 
    create groups with this number.
    '''

    for patient in glob(in_dir + '/*'):
        patient_name = os.path.basename(os.path.normpath(patient))

        # Here we need to calculate the number of folders which mean into how many groups we will divide the number of slices
        number_folders = int(len(glob(patient + '/*')) / Number_slices)

        for i in range(number_folders):
            output_path = os.path.join(out_dir, patient_name + '_' + str(i))
            os.mkdir(output_path)

            # Move the slices into a specific folder so that you will save memory in your desk
            for i, file in enumerate(glob(patient + '/*')):
                if i == Number_slices + 1:
                    break

                shutil.move(file, output_path)


def dcm2nifti(in_dir, out_dir):
    '''
    This function will be used to convert dicoms into nifti files after creating the groups with 
    the number of slices that you want.
    `in_dir`: the path to the folder where you have all the patients (folder of all the groups).
    `out_dir`: the path to the output, which means where you want to save the converted nifties.
    '''

    for folder in tqdm(glob(in_dir + '/*')):
        patient_name = os.path.basename(os.path.normpath(folder))
        dicom2nifti.dicom_series_to_nifti(
            folder, os.path.join(out_dir, patient_name + '.nii.gz'))


def find_empy(in_dir):
    '''
    This function will help you to find the empty volumes that you may not need for your training
    so instead of opening all the files and search for the empty ones, them use this function to make it quick.
    '''

    list_patients = []
    for patient in glob(os.path.join(in_dir, '*')):
        img = nib.load(patient)

        if len(np.unique(img.get_fdata())) > 2:
            print(os.path.basename(os.path.normpath(patient)))
            list_patients.append(os.path.basename(os.path.normpath(patient)))

    return list_patients


def prepare(in_dir, pixdim=(1.5, 1.5, 1.0), a_min=-200, a_max=200, spatial_size=[128, 128, 64], cache=False):
    """
    This function is for preprocessing, it contains only the basic transforms, but you can add more operations that you 
    find in the Monai documentation.
    https://monai.io/docs.html
    """

    set_determinism(seed=0)

    path_train_volumes = sorted(
        glob(os.path.join(in_dir, "images", "*.nii.gz")))
    path_train_segmentation = sorted(
        glob(os.path.join(in_dir, "labels", "*.nii.gz")))

    path_test_volumes = sorted(
        glob(os.path.join(in_dir, "images", "*.nii.gz")))
    path_test_segmentation = sorted(
        glob(os.path.join(in_dir, "labels", "*.nii.gz")))

    train_files = [{"vol": image_name, "seg": label_name} for image_name,
                   label_name in zip(path_train_volumes, path_train_segmentation)]
    test_files = [{"vol": image_name, "seg": label_name} for image_name,
                  label_name in zip(path_test_volumes, path_test_segmentation)]

    train_transforms = Compose(
        [
            LoadImaged(keys=["vol", "seg"]),
            AddChanneld(keys=["vol", "seg"]),
            Spacingd(keys=["vol", "seg"], pixdim=pixdim,
                     mode=("bilinear", "nearest")),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["vol", "seg"], source_key="vol"),
            Resized(keys=["vol", "seg"], spatial_size=spatial_size),
            ToTensord(keys=["vol", "seg"]),

        ]
    )

    test_transforms = Compose(
        [
            LoadImaged(keys=["vol", "seg"]),
            AddChanneld(keys=["vol", "seg"]),
            Spacingd(keys=["vol", "seg"], pixdim=pixdim,
                     mode=("bilinear", "nearest")),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=['vol', 'seg'], source_key='vol'),
            Resized(keys=["vol", "seg"], spatial_size=spatial_size),
            ToTensord(keys=["vol", "seg"]),


        ]
    )

    if cache:
        train_ds = CacheDataset(
            data=train_files, transform=train_transforms, cache_rate=1.0)
        train_loader = DataLoader(train_ds, batch_size=1)

        test_ds = CacheDataset(
            data=test_files, transform=test_transforms, cache_rate=1.0)
        test_loader = DataLoader(test_ds, batch_size=1)

        return train_loader, test_loader

    else:
        train_ds = Dataset(data=train_files, transform=train_transforms)
        train_loader = DataLoader(train_ds, batch_size=1)

        test_ds = Dataset(data=test_files, transform=test_transforms)
        test_loader = DataLoader(test_ds, batch_size=1)

        return train_loader, test_loader


In [3]:
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss

import torch
from utilities import train


data_dir = "F:\\research_datasets\\Task01_BrainTumour\\Task01_BrainTumour\\nifti_files"
model_dir = "F:\\research_datasets\\Task01_BrainTumour\\Task01_BrainTumour\\models" 
data_in = prepare(data_dir, cache=False)

device = torch.device("cuda:0")
train_loader, test_loader = data_in
print(train_loader)
print(test_loader)


<monai.data.dataloader.DataLoader object at 0x00000252771505E0>
<monai.data.dataloader.DataLoader object at 0x0000025277150CD0>




In [4]:
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256), 
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)


#loss_function = DiceCELoss(to_onehot_y=True, sigmoid=True, squared_pred=True, ce_weight=calculate_weights(1792651250,2510860).to(device))
loss_function = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=1e-5, amsgrad=True)

if __name__ == '__main__':
    
    train(model, data_in, loss_function, optimizer, 20, model_dir)


----------
epoch 1/600
2/4, Train_loss: 0.5902
Train_dice: 0.4098
3/4, Train_loss: 0.6014
Train_dice: 0.3986
4/4, Train_loss: 0.5896
Train_dice: 0.4104
5/4, Train_loss: 0.5864
Train_dice: 0.4136
--------------------
train_step:  5
Epoch_loss: 0.4735
Epoch_metric: 0.3265
test_loss_epoch: 0.4734
save_loss_test:  [0.4734485030174255]
test_dice_epoch: 0.3266
save_metric_test:  [0.32655149698257446]
current epoch: 1 current mean dice: 0.4142
best mean dice: 0.3266 at epoch: 1
----------
epoch 2/600
2/4, Train_loss: 0.5893
Train_dice: 0.4107
3/4, Train_loss: 0.6007
Train_dice: 0.3993
4/4, Train_loss: 0.5889
Train_dice: 0.4111
5/4, Train_loss: 0.5857
Train_dice: 0.4143
--------------------
train_step:  5
Epoch_loss: 0.4729
Epoch_metric: 0.3271
test_loss_epoch: 0.4742
save_loss_test:  [0.4734485030174255, 0.47415995597839355]
test_dice_epoch: 0.3258
save_metric_test:  [0.32655149698257446, 0.32584004402160643]
current epoch: 2 current mean dice: 0.4132
best mean dice: 0.3266 at epoch: 1
------

KeyboardInterrupt: 