In [None]:
import pandas as pd
import numpy as np
import os
import numpy
import matplotlib.pyplot as plt
import SimpleITK
import itertools
import sys
import torch
from torchvision import transforms
from PIL import Image
from matplotlib import cm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision import models

from pathlib import Path

# Directory holding the project's local source package.
SOURCE_PATH = Path(os.getcwd()) / 'src'

# sys.path entries must be plain strings: the import system ignores any other
# type, and a Path never compares equal to a str, so the membership test has to
# be done on the string form as well (otherwise every re-run appends again).
if str(SOURCE_PATH) not in sys.path:
    sys.path.append(str(SOURCE_PATH))

from src.extraction import (
    extract_images_in_survival_order,
    export_images_list_jpeg,
)

from src.plots import (
    plot_observation,
    plot_deeplab_mobile_predictions,
    plot_mobile_prediction_from_path
)

from src.deeplab_mobile.modelling import(
    select_images_input,
    get_deeplab_mobile_model,
    train_deeplab_mobile
)

from src.deeplab_mobile.segdataset import(
    get_mobile_dataloaders
)

from src.utils import(
    LOGS_FILE_PATH,
    TYPE_NAMES
)

%load_ext autoreload
%autoreload 2

In [None]:
# Image folder lives under ./data/HGG; the survival table sits one level up in ./data.
data_path = Path.cwd().joinpath('data', 'HGG')
survival = pd.read_csv(data_path.parent.joinpath('survival_data.csv'))

In [None]:
# Identifier column of the survival table; passed to the extraction helper below.
dir_ids = survival.loc[:, 'BraTS19ID']

In [None]:
# The helper returns five collections; keep them both as a list (indexed later
# alongside TYPE_NAMES) and as individually named variables.
# NOTE(review): the order is assumed to match TYPE_NAMES — confirm in src.utils.
images = list(extract_images_in_survival_order(data_path, dir_ids))
t2, t1ce, t1, flair, seg = images

In [None]:
# Store each extracted image collection as a column named after its TYPE_NAMES entry.
for position, column_name in enumerate(TYPE_NAMES):
    survival[column_name] = images[position]

In [None]:
# Load the serialized FLAIR segmentation network.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled models; pass weights_only=False there if loading fails.
modelname = 'flair_totalpipe_nighttrain.pt'
# map_location='cpu': the inference cells below feed CPU tensors, so the model
# must live on the CPU even if the checkpoint was saved from a GPU.
model = torch.load(Path(os.getcwd()) / 'models' / modelname, map_location='cpu')
model.eval()

In [None]:
# Run the segmentation model on every row's FLAIR image and keep the result
# as a new 'flair_seg' column.
flair_segout = []

with torch.no_grad():
    for row in range(len(survival)):
        flair_image = torch.tensor(survival['flair'][row])
        # Replicate to 3 channels, cast to int16 and back to float, then add a
        # leading batch dimension.
        three_channel = flair_image.expand(3, -1, -1).type(torch.ShortTensor)
        batch = three_channel.float().unsqueeze(0)
        # First batch element, first channel of the model's 'out' entry.
        segmentation = model(batch)['out'][0][0]
        flair_segout.append(np.array(segmentation))

survival['flair_seg'] = flair_segout

In [None]:
# Discard rows with any missing value and renumber the index from 0, then make
# every 'Survival' entry a string so it can be searched for text markers.
survival = (
    survival
    .dropna()
    .reset_index(drop=True)
)
survival['Survival'] = survival['Survival'].apply(str)

In [None]:
# Remove every row whose 'Survival' text contains 'ALIVE'.
# A boolean mask replaces the row-by-row .loc[i] loop: the loop raised KeyError
# when the cell was re-run (the index has gaps after the first drop), whereas
# the mask works for any index. astype(str) keeps the cell re-runnable even
# after the column has been converted to integers.
is_alive = survival['Survival'].astype(str).str.contains('ALIVE', regex=False)
to_drop = survival.index[is_alive].tolist()

# Rebind instead of inplace=True; the surviving rows keep their index labels.
survival = survival.drop(index=to_drop)
print(len(survival))

In [None]:
# Convert each 'Survival' entry to an integer.
survival['Survival'] = survival['Survival'].apply(int)

In [None]:
# Load the serialized survival-prediction network.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled models; pass weights_only=False there if loading fails.
modelname = "custom_made_survivalnet.pt"
# map_location='cpu': the notebook only ever builds CPU tensors, so keep the
# network on the CPU even if the checkpoint was saved from a GPU.
net = torch.load(Path(os.getcwd()) / "models" / modelname, map_location='cpu')
net.eval()

In [None]:
from torch.utils.data import Dataset

class Survival_Dataset(Dataset):
    """Map-style dataset that pairs one image with one target per index.

    Args:
        images: Indexable collection of images; each item must be convertible
            with ``torch.tensor``.
        targets: Indexable collection of targets, aligned with ``images``.
        fraction: Stored on the instance; not read anywhere in this class.
        subset: Must be ``"Train"`` or ``"Test"``. The default ``None`` is
            rejected, so callers always have to pass it explicitly.
        image_color_mode: Must be ``"rgb"`` or ``"grayscale"``. Stored on the
            instance; not read anywhere in this class.

    Raises:
        ValueError: If ``image_color_mode`` or ``subset`` is not an accepted value.

    NOTE(review): ``subset`` is only validated and ``fraction`` only stored, so
    a "Train" and a "Test" instance built from the same inputs expose exactly
    the same samples — confirm whether a split was intended.
    """

    def __init__(self,
                 images,
                 targets,
                 fraction = 0.1,
                 subset = None,
                 image_color_mode = "rgb"
                 ) -> None:
        accepted_color_modes = ("rgb", "grayscale")
        if image_color_mode not in accepted_color_modes:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        self.image_color_mode = image_color_mode

        accepted_subsets = ("Train", "Test")
        if subset not in accepted_subsets:
            raise ValueError(
                f"{subset} is not a valid input. Acceptable values are Train and Test."
            )

        self.fraction = fraction
        self.images = images
        self.targets = targets

    def __len__(self) -> int:
        """Return the number of images held by the dataset."""
        return len(self.images)

    def __getitem__(self, index: int):
        """Return ``{"image": tensor, "target": value}`` for position ``index``.

        The image is converted to a tensor, expanded to 3 leading channels and
        cast to int16 (``torch.ShortTensor``), which truncates fractional values.
        The target is returned unchanged.
        """
        as_tensor = torch.tensor(self.images[index])
        three_channel = as_tensor.expand(3, -1, -1)
        return {
            "image": three_channel.type(torch.ShortTensor),
            "target": self.targets[index],
        }

In [None]:
def get_survival_dataloaders(images, target, batch_size=14):
    """Build one DataLoader per subset name ('Train' and 'Test').

    Args:
        images: Image collection handed to ``Survival_Dataset``.
        target: Target collection handed to ``Survival_Dataset``.
        batch_size: Batch size used for both loaders.

    Returns:
        dict mapping 'Train' and 'Test' to a ``DataLoader``. Both loaders wrap
        a dataset built from the same ``images`` and ``target``; no shuffling
        is requested.
    """
    loaders = {}
    for subset_name in ('Train', 'Test'):
        subset_dataset = Survival_Dataset(images, target, subset=subset_name)
        loaders[subset_name] = DataLoader(subset_dataset, batch_size=batch_size)
    return loaders

In [None]:
# Inputs are the segmentation outputs, targets are the integer survival values;
# both are passed as raw arrays so the datasets index them by position.
dataloaders = get_survival_dataloaders(
    images=survival['flair_seg'].values,
    target=survival['Survival'].values,
)

In [None]:
# Spot-check one sample: compare the survival network's output with its target.
# The original cell called `model` (the segmentation network) here, leaving the
# survival network `net` unused.
# NOTE(review): assumes `net` accepts a (1, 3, H, W) float tensor — confirm.
sample = dataloaders['Test'].dataset[30]
with torch.no_grad():  # inference only; no autograd graph needed
    prediction = net(sample['image'].unsqueeze(0).float())
print('prediction', prediction)
print('target', sample['target'])