In [None]:
import itertools
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas as pd
import SimpleITK
import torch
from matplotlib import cm
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import models
from torchvision import transforms

from src.loading import get_dataloader_single_folder

# Register the project's `src` directory on the import path.
# sys.path entries must be plain strings: the import machinery ignores any
# other type, and a Path object never compares equal to a str entry. Both the
# membership test and the append therefore use the string form.
SOURCE_PATH = Path(os.getcwd()) / 'src'

if str(SOURCE_PATH) not in sys.path:
    sys.path.append(str(SOURCE_PATH))

from src.extraction import (
    export_all_images_jpeg,
    export_images_list_jpeg,
    get_images_lists_from_path,
    get_images_lists_from_more_paths
)

from src.plots import (
    plot_observation
)

from src.modelling import (
    train_model
)

%load_ext autoreload
%autoreload 2

In [None]:
# Locations of the HGG and LGG data folders, resolved against the current
# working directory of the notebook.
hgg, lgg = (Path.cwd() / 'data' / folder for folder in ('HGG', 'LGG'))

In [None]:
# Labels for the positions of `images` and `imagesl` below. Note that these
# lists place t1 before t1ce, whereas the loader result is unpacked with t1ce
# before t1.
# NOTE(review): the unpack order assumes get_images_lists_from_path returns
# (t2, t1ce, t1, flair, seg) -- confirm against src.extraction.
type_names = ['t2', 't1', 't1ce', 'flair', 'seg']

# Image lists loaded from the HGG folder.
t2, t1ce, t1, flair, seg = get_images_lists_from_path(hgg)
images = [t2, t1, t1ce, flair, seg]

# Image lists loaded from the LGG folder (same layout, `l` suffix).
t2l, t1cel, t1l, flairl, segl = get_images_lists_from_path(lgg)
imagesl = [t2l, t1l, t1cel, flairl, segl]

# Result of loading both folders in a single call.
all_images = get_images_lists_from_more_paths([hgg, lgg])

In [None]:
# Plot the HGG image lists; the second argument is presumably the index of the
# observation to show -- confirm against src.plots.plot_observation.
plot_observation(images, 0)

In [None]:
# Same plot for the LGG image lists (second argument presumably the
# observation index -- confirm against src.plots.plot_observation).
plot_observation(imagesl, 0)

In [None]:
# Same plot for the lists loaded from both folders at once.
# NOTE(review): `all_images` is passed exactly as returned by
# get_images_lists_from_more_paths, without the t1/t1ce reordering applied to
# `images`/`imagesl` -- confirm the panel labels still match.
plot_observation(all_images, 0)

In [None]:
# Image type fed to the model; must be one of the entries of `type_names`.
type_to_use = 'flair'

# Stack every entry of the selected image type, and of the 'seg' lists, from
# the HGG lists in `images` into one array each.
images_chosen = np.array(list(images[type_names.index(type_to_use)]))
images_seg = np.array(list(images[type_names.index('seg')]))

In [None]:
# Build torchvision's deeplabv3_mobilenet_v3_large segmentation model with
# three output classes and `pretrained=False`.
model = models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, num_classes=3)

In [None]:
class Custom_Dataset(Dataset):
    """In-memory image/mask dataset with a deterministic Train/Test split.

    The last ``fraction`` of the supplied sequences is reserved for the
    ``"Test"`` subset and the remainder forms the ``"Train"`` subset, so the
    two subsets never share a sample. No shuffling is performed.

    Args:
        images_list: Sliceable sequence of images (list or array).
        seg_list: Sliceable sequence of masks, aligned index-by-index with
            ``images_list``.
        fraction: Share of the samples (between 0 and 1) assigned to the
            ``"Test"`` subset.
        subset: Either ``"Train"`` or ``"Test"``.
        transforms: Stored on the instance; not applied by ``__getitem__``.
        image_color_mode: ``"rgb"`` or ``"grayscale"``; validated and stored.
        mask_color_mode: ``"rgb"`` or ``"grayscale"``; validated and stored.

    Raises:
        ValueError: On an invalid color mode, subset or fraction, or when the
            image and mask sequences differ in length.
    """

    def __init__(self,
                 images_list,
                 seg_list,
                 fraction = 0.1,
                 subset = None,
                 transforms = None,
                 image_color_mode = "rgb",
                 mask_color_mode = "rgb") -> None:

        if image_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if mask_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if subset not in ["Train", "Test"]:
            raise ValueError(
                f"{subset} is not a valid input. Acceptable values are Train and Test."
            )
        if not 0 <= fraction <= 1:
            raise ValueError(
                f"fraction must be between 0 and 1, got {fraction}."
            )
        if len(images_list) != len(seg_list):
            raise ValueError(
                f"images_list and seg_list must have the same length, "
                f"got {len(images_list)} and {len(seg_list)}."
            )

        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode
        self.transforms = transforms
        self.fraction = fraction

        # Without this split both subsets would hold every sample and the
        # "Test" loss would be measured on training data.
        n_total = len(images_list)
        split = n_total - int(n_total * fraction)
        if subset == "Train":
            self.image_list = images_list[:split]
            self.mask_list = seg_list[:split]
        else:
            self.image_list = images_list[split:]
            self.mask_list = seg_list[split:]

    def __len__(self) -> int:
        """Return the number of samples in this subset."""
        return len(self.image_list)

    @staticmethod
    def _to_three_channel(array):
        """Convert ``array`` to an int16 tensor broadcast to 3 leading channels."""
        return torch.tensor(array).expand(3, -1, -1).type(torch.ShortTensor)

    def __getitem__(self, index: int):
        """Return ``{"image": tensor, "mask": tensor}`` for sample ``index``."""
        return {
            "image": self._to_three_channel(self.image_list[index]),
            "mask": self._to_three_channel(self.mask_list[index]),
        }

In [None]:
# data_transforms = transforms.Compose(
#     [transforms.ToTensor(),
#     ])

# One dataset per phase, built from the same image/mask arrays; only the
# `subset` argument differs.
image_datasets = {
    'Train': Custom_Dataset(images_chosen, images_seg, transforms=None, subset='Train'),
    'Test': Custom_Dataset(images_chosen, images_seg, transforms=None, subset='Test'),
}

# Batches of 5 samples per phase; shuffling and worker processes are left at
# the DataLoader defaults.
dataloaders = {
    subset_name: DataLoader(dataset, batch_size=5)
    for subset_name, dataset in image_datasets.items()
}

In [None]:
import time
import copy
from tqdm import tqdm
from sklearn.metrics import f1_score
import csv

# Train `model` for `num_epochs` epochs, log per-epoch loss and metrics to
# models/log.csv, and keep the weights with the lowest mean Test loss.
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 1e10
metrics = {'f1_score': f1_score}
num_epochs = 2
bpath = Path(os.getcwd()) / 'models'
# Create the output folder up front so writing the log cannot fail on a
# fresh checkout.
bpath.mkdir(parents=True, exist_ok=True)
log_path = os.path.join(bpath, 'log.csv')

#criterion = torch.nn.CrossEntropyLoss(reduction='mean')
# NOTE(review): MSE compares the raw model output with the mask values
# channel by channel -- confirm this is the intended objective.
criterion = torch.nn.MSELoss(reduction='mean')

# Specify the optimizer with a lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Columns of the log file: epoch, one loss per phase, then one column per
# (phase, metric) pair.
fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \
    [f'Train_{m}' for m in metrics.keys()] + \
    [f'Test_{m}' for m in metrics.keys()]
metric_fields = fieldnames[3:]

with open(log_path, 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

for epoch in range(1, num_epochs + 1):
    print('Epoch {}/{}'.format(epoch, num_epochs))
    print('-' * 10)

    # Metric columns collect one value per batch. They start empty so the
    # epoch mean is not biased by a placeholder entry.
    batchsummary = {field: [] for field in metric_fields}
    batchsummary['epoch'] = epoch

    # Each epoch has a training and a test phase.
    for phase in ['Train', 'Test']:
        if phase == 'Train':
            model.train()  # Set model to training mode
        else:
            model.eval()  # Set model to evaluate mode

        running_loss = 0.0
        n_samples = 0

        # Iterate the loader directly so tqdm can read its length.
        for sample in tqdm(dataloaders[phase]):
            inputs = sample['image'].float()
            masks = sample['mask'].float()
            # zero the parameter gradients
            optimizer.zero_grad()

            # track gradient history only in the training phase
            with torch.set_grad_enabled(phase == 'Train'):
                outputs = model(inputs)
                loss = criterion(outputs['out'], masks)

                # backward + optimize only if in training phase
                if phase == 'Train':
                    loss.backward()
                    optimizer.step()

            y_pred = outputs['out'].detach().cpu().numpy().ravel()
            y_true = masks.detach().cpu().numpy().ravel()
            for name, metric in metrics.items():
                if name == 'f1_score':
                    # Use a classification threshold of 0.1
                    batchsummary[f'{phase}_{name}'].append(
                        metric(y_true > 0, y_pred > 0.1))
                else:
                    batchsummary[f'{phase}_{name}'].append(
                        metric(y_true.astype('uint8'), y_pred))

            # The criterion already averages within the batch; weighting by
            # batch size yields a per-sample mean over the whole epoch.
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size

        # Mean loss over the phase (NaN if the loader produced no batches),
        # rather than the loss of the final batch only.
        epoch_loss = running_loss / n_samples if n_samples else float('nan')
        batchsummary[f'{phase}_loss'] = epoch_loss
        print('{} Loss: {:.4f}'.format(phase, epoch_loss))

    for field in metric_fields:
        values = batchsummary[field]
        batchsummary[field] = np.mean(values) if values else float('nan')
    # Print the finished summary once per epoch.
    print(batchsummary)

    with open(log_path, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writerow(batchsummary)

    # deep copy the model when the mean Test loss improves
    if batchsummary['Test_loss'] < best_loss:
        best_loss = batchsummary['Test_loss']
        best_model_wts = copy.deepcopy(model.state_dict())

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
print('Lowest Loss: {:.4f}'.format(best_loss))

# load best model weights
model.load_state_dict(best_model_wts)

In [None]:
# Serialise the whole model object (not only its state_dict) to
# models/<model_exp_name>.pt under the current working directory.
model_exp_name = 'flair_totalpipe_final'
checkpoint_path = Path.cwd() / 'models' / f'{model_exp_name}.pt'
torch.save(model, str(checkpoint_path))

In [None]:
# Reload the saved model object from the models folder and switch it to
# evaluation mode.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. Recent PyTorch releases may default to weights_only=True, which
# rejects whole-model pickles -- confirm against the installed version.
modelname = 'flair_totalpipe_final.pt'
model = torch.load(Path.cwd() / 'models' / modelname)
model.eval()

In [None]:
# Visual check: for the first observations, show the input image, channel 0
# of the model output, and the ground-truth mask side by side.
inputs_for_type = images[type_names.index(type_to_use)]
truths_for_type = images[type_names.index('seg')]

# Never index past the available observations (at most 20 are shown).
indexes_predict = np.arange(0, min(20, len(inputs_for_type), len(truths_for_type)))

# Make the cell self-contained: inference must not run in training mode.
model.eval()

for i in indexes_predict:
    # Same preprocessing as Custom_Dataset: int16 tensor broadcast to 3
    # channels, then cast to float for the model.
    input_tensor = torch.tensor(inputs_for_type[i]).expand(3, -1, -1).type(torch.ShortTensor).float()
    truth = torch.tensor(truths_for_type[i]).expand(3, -1, -1).type(torch.ShortTensor).float()

    with torch.no_grad():
        output = model(input_tensor.unsqueeze(0))['out'][0]
    # Raw per-channel output; no argmax is applied.
    output_predictions = output

    f, ax = plt.subplots(1, 3, figsize=(15, 4))
    panels = (
        ('input image', input_tensor[0]),
        ('segmented output', output_predictions[0]),
        ('ground truth', truth[0]),
    )
    for axis, (title, panel) in zip(ax, panels):
        axis.set_title(title)
        axis.axis('off')
        axis.imshow(panel)
    plt.show()
    # Release the figure so repeated iterations do not accumulate memory.
    plt.close(f)

In [None]:
# Read back the training log written to the `bpath` folder.
df = pd.read_csv(bpath / 'log.csv')

In [None]:
# Draw every logged column as a curve against the epoch number.
df.plot(figsize=(15, 8), x='epoch')