- Create the datasets for training
- Instantiate a 2D/3D U-Net
- Write the training loop


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import monai
import glob
import SimpleITK as sitk
import os

In [None]:
# Root folder of the dataset; it must contain a "training" and a "testing" sub-folder.
data_path = "database"

# isdir (not exists) so a path pointing at a regular file is reported here
# instead of crashing in os.listdir below.
if not os.path.isdir(data_path):
    print("Please update your data path to an existing folder.")
elif not {"training", "testing"} <= set(os.listdir(data_path)):
    print("Please update your data path to the correct folder (should contain training and testing folders).")
else:
    print("Congrats! You selected the correct folder :)")

In [None]:
class ACDCDataset(monai.data.Dataset):
    """Dataset of 2D slices extracted from the ``*.gz`` volumes under ``rootpath/mode``.

    Every sub-folder of ``rootpath/mode`` is treated as one patient. Its
    ``*.gz`` files are sorted by name, and the files at sorted positions 1/2
    and 3/4 are used as (image, mask) volume pairs. Every slice along axis 0
    of the arrays returned by ``sitk.GetArrayFromImage`` becomes one sample
    ``{'img': 2D array, 'mask': 2D array}``.

    NOTE(review): the fixed positions assume every patient folder follows the
    same file layout (one extra volume sorted first, then image/mask pairs),
    presumably the ACDC naming scheme — confirm against the data on disk.
    """

    def __init__(self, rootpath, mode, transform=None):
        """
        Parameters
        ----------
        rootpath : str
            Folder containing the ``training`` and ``testing`` sub-folders.
        mode : str
            Either ``"training"`` or ``"testing"``.
        transform : callable, optional
            Applied to each sample dict in ``__getitem__``.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the two supported splits.
        """
        if mode not in ["training", "testing"]:
            raise ValueError("must be either training or testing for the dataset to be loaded")

        self.path = os.path.join(rootpath, mode)
        self.transform = transform
        self.data = []  # list of {'img': 2D array, 'mask': 2D array} dicts
        self.load_data()

    def load_data(self):
        """Populate ``self.data`` with the slices of every patient folder.

        Patient folders are visited in sorted order so sample indices are
        reproducible across runs and file systems (``os.walk`` order is not).
        """
        for patient in sorted(next(os.walk(self.path))[1]):
            patient_paths = sorted(glob.glob(os.path.join(self.path, patient, '*.gz')))
            self.load_patient(patient_paths)

    def load_patient(self, patient_paths):
        """Append every slice of one patient's (image, mask) volume pairs to ``self.data``.

        Parameters
        ----------
        patient_paths : list of str
            Name-sorted ``*.gz`` paths of one patient; positions 1/2 and 3/4
            are used as (image, mask) pairs.

        Raises
        ------
        ValueError
            If fewer than five files are present, so the pairs cannot be formed.
        """
        if len(patient_paths) < 5:
            raise ValueError(f"expected at least 5 '*.gz' files per patient, got {patient_paths}")

        for image_idx, mask_idx in [(1, 2), (3, 4)]:
            image_array = sitk.GetArrayFromImage(sitk.ReadImage(patient_paths[image_idx]))
            mask_array = sitk.GetArrayFromImage(sitk.ReadImage(patient_paths[mask_idx]))

            for i in range(image_array.shape[0]):
                self.data.append({'img': image_array[i, :, :], 'mask': mask_array[i, :, :]})

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict with keys ``'img'`` and ``'mask'``.

        A shallow copy of the stored dict is passed to ``self.transform`` (when
        set), so a transform that assigns keys in place cannot replace the
        cached slices seen in later epochs.
        """
        item = dict(self.data[index])
        if self.transform:
            item = self.transform(item)
        return item

    def get_total_meansd(self):
        """Return ``(mean, std)`` of the image intensities over all stored slices.

        Slices are flattened and concatenated rather than stacked, so slices
        with different in-plane sizes can be combined.

        Raises
        ------
        ValueError
            If the dataset holds no slices.
        """
        if not self.data:
            raise ValueError("cannot compute intensity statistics of an empty dataset")
        pixels = np.concatenate([x["img"].ravel() for x in self.data])
        return np.mean(pixels), np.std(pixels)

    def __len__(self):
        """Number of 2D slices in the dataset."""
        return len(self.data)

In [None]:
# Per-sample dictionary pipeline applied in ACDCDataset.__getitem__.
# NOTE(review): subtrahend/divisor are hard-coded intensity statistics, presumably
# obtained from train_dataset.get_total_meansd() — recompute if the data changes.
transforms = monai.transforms.Compose([
    # Add a leading channel axis to both the image and the mask.
    monai.transforms.AddChanneld(keys=['img', 'mask']),
    monai.transforms.NormalizeIntensityd(keys='img', subtrahend=67.27, divisor=84.66),
    # One interpolation mode per key: the image keeps the default "area" mode, but
    # the mask uses nearest-neighbour so label values stay discrete instead of
    # being averaged into fractional values at structure borders.
    monai.transforms.Resized(keys=['img', 'mask'], spatial_size=(200, 200), mode=('area', 'nearest')),
])

In [None]:
# One dataset per split folder; both share the same transform pipeline.
train_dataset, test_dataset = (
    ACDCDataset(data_path, split, transforms) for split in ("training", "testing")
)

In [None]:
# Shuffle only the training data. The evaluation loader keeps a fixed order so
# validation passes and the "first test batch" used for visualisation are reproducible.
train_loader = monai.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = monai.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
def mask_overlay(mask, label=1):
    """Build the masked boolean array that ``visualize_sample`` draws over the image.

    Parameters
    ----------
    mask : array-like
        Label map (integer labels or booleans); converted with ``np.asarray``.
    label : int or None, default 1
        Mask value to highlight. ``None`` highlights every non-zero pixel.

    Returns
    -------
    numpy.ma.MaskedArray
        True on the selected pixels. Every other pixel is masked, so it is not
        drawn at all (instead of being drawn as 0).
    """
    mask = np.asarray(mask)
    selected = mask != 0 if label is None else mask == label
    return np.ma.masked_where(~selected, selected)


def visualize_sample(sample, title=None, label=1):
    """Show ``sample['img']`` in grayscale with one mask label overlaid in green.

    Parameters
    ----------
    sample : dict
        Must contain ``'img'`` and ``'mask'`` (numpy arrays or CPU tensors).
        Singleton axes are squeezed away, so channel-first ``(1, H, W)`` inputs
        are accepted.
    title : str, optional
        Figure title.
    label : int or None, default 1
        Mask value to overlay; ``None`` overlays all non-zero labels.
    """
    # np.asarray first so torch tensors are handled by numpy's squeeze.
    image = np.squeeze(np.asarray(sample['img']))
    mask = np.squeeze(np.asarray(sample['mask']))
    plt.figure(figsize=[10, 7])
    plt.imshow(image, 'gray')
    plt.imshow(mask_overlay(mask, label), 'Greens', alpha=0.7, clim=[0, 1], interpolation='nearest')
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
# Display the (transformed) dictionary of the first training sample.
train_dataset[0]

In [None]:
# Inspect the first training slice together with its mask overlay.
first_training_sample = train_dataset[0]
visualize_sample(first_training_sample)

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'The used device is {device}')

In [None]:
# 2D U-Net with a single output channel. The training cell pairs it with
# DiceLoss(sigmoid=True), i.e. the output is treated as a binary logit map.
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128),
    # MONAI's UNet expects len(channels) - 1 strides (one per down-sampling step).
    # A 4th stride would be ignored with a warning. Three stride-2 steps take the
    # 200x200 input down to 25x25.
    strides=(2, 2, 2),
    num_res_units=2,
).to(device)

In [None]:
from tqdm import tqdm
import numpy as np
import monai
import torch

# NOTE(review): DiceLoss(sigmoid=True) on a single output channel treats the mask
# as binary. If the masks contain several label values, this likely needs one-hot
# targets and more output channels — confirm the intended task.
loss_function = monai.losses.DiceLoss(sigmoid=True, batch=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 40

dataloaders = {'train': train_loader, 'val': test_loader}

for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}/{num_epochs}")
    epoch_losses = {'train': [], 'val': []}

    for mode in ['train', 'val']:
        print(f"Current mode: {mode}")
        is_training = mode == 'train'
        # Switch layers with mode-dependent behaviour (e.g. dropout, batch norm).
        model.train(is_training)
        # Track gradients only while training; validation must never update weights.
        with torch.set_grad_enabled(is_training):
            for batch in tqdm(dataloaders[mode]):
                x_batch = batch['img'].to(device)
                y_batch = batch['mask'].to(device)

                output = model(x_batch)
                loss = loss_function(output, y_batch)

                if is_training:
                    # Backpropagation
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                epoch_losses[mode].append(loss.item())
        print(f"Mean loss in {mode} mode: {np.mean(epoch_losses[mode])}")

    # TODO: optionally log epoch_losses['train'] / epoch_losses['val'] to an experiment tracker.


# Store the network parameters
torch.save(model.state_dict(), r'trainedUNet.pt')

In [None]:
# Predict one test batch for inspection. Iterators have no .next() method in
# Python 3; the builtin next() is the portable form.
batch = next(iter(test_loader))

# Inference only: evaluation mode and no gradient tracking.
model.eval()
with torch.no_grad():
    outputs = model(batch['img'].to(device))

In [None]:
# Threshold the first prediction's output at 0 and overlay it on its input slice.
predicted_mask = outputs.detach().cpu().numpy()[0, 0] > 0
visualize_sample({'img': batch['img'][0, 0], 'mask': predicted_mask})

In [None]:
# Blend the raw prediction of the first sample with its ground-truth foreground.
# To draw the input image underneath, add first: plt.imshow(batch['img'][0, 0], cmap='gray')
raw_prediction = outputs.detach().cpu().numpy()[0, 0]
true_foreground = batch['mask'][0, 0] > 0
plt.imshow(raw_prediction, alpha=.5)
plt.imshow(true_foreground, alpha=.5)

In [None]:
# Boolean map of the first prediction where the output is > 0. Assuming outputs
# are logits (the loss applies sigmoid=True), this equals probability > 0.5.
outputs.detach().cpu().numpy()[0] > 0

In [None]:
# Raw output values of the first sample's only channel (presumably logits) — TODO confirm.
outputs.detach().cpu().numpy()[0, 0, :, :]