In [40]:
# https://gist.github.com/BenjaminPhillips22/834e7434e1a4960d851f068f9c09a9c8

import pandas as pd
import numpy as np
import os
import numpy
import matplotlib.pyplot as plt
import SimpleITK
import itertools
import sys
import torch
from torchvision import transforms
from PIL import Image
from matplotlib import cm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision import models

from pathlib import Path

# Make the project's `src` directory importable.
SOURCE_PATH = Path(os.getcwd()) / 'src'

# sys.path must contain plain strings: non-str entries are ignored by the
# import system, and a Path never compares equal to the str entries already
# present, so the membership test has to be done on the string form too
# (otherwise a new entry is appended on every re-run of this cell).
if str(SOURCE_PATH) not in sys.path:
    sys.path.append(str(SOURCE_PATH))

from src.extraction import (
    extract_images_in_survival_order,
)

from src.plots import (
    plot_observation,
    plot_deeplab_mobile_predictions,
    plot_mobile_prediction_from_path
)

from src.deeplab_mobile.modelling import(
    select_images_input,
    get_deeplab_mobile_model,
    train_deeplab_mobile
)

from src.deeplab_mobile.segdataset import(
    get_mobile_dataloaders
)

from src.utils import(
    LOGS_FILE_PATH,
    TYPE_NAMES
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
# Image folder is <cwd>/data/HGG; the survival table sits next to it in <cwd>/data.
data_path = Path(os.getcwd()).joinpath('data', 'HGG')
survival = pd.read_csv(data_path.with_name('survival_data.csv'))

In [42]:
# Patient identifiers from the survival table; passed to the image extraction
# helper so images are loaded in the same row order as `survival`.
dir_ids = survival['BraTS19ID']

In [43]:
# Load the image data for every id in `dir_ids` from `data_path`.
# NOTE(review): extract_images_in_survival_order is a project helper; it is
# assumed to return five per-patient collections in exactly this order
# (t2, t1ce, t1, flair, seg) — confirm against src/extraction.
t2, t1ce, t1, flair, seg = extract_images_in_survival_order(data_path, dir_ids)
# Kept as a list so the collections can be attached to `survival` by position.
images = [t2, t1ce, t1, flair, seg]

In [45]:
# Attach one column per entry of TYPE_NAMES, pairing names and `images` by position.
# NOTE(review): `images` is built as [t2, t1ce, t1, flair, seg], while the
# displayed frame lists its columns as t2, t1, t1ce, flair, seg. If that is the
# order of TYPE_NAMES, the t1 and t1ce columns are swapped — confirm that
# TYPE_NAMES and `images` use the same order.
for position, type_name in enumerate(TYPE_NAMES):
    survival[type_name] = images[position]

In [47]:
# Load the saved segmentation model from <cwd>/models and switch it to
# inference mode.
# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only load model files from a trusted source.
modelname = 'flair_totalpipe_old.pt'
model = torch.load(Path.cwd().joinpath('models', modelname))
model.eval()

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (0): ConvNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1),

In [102]:
# Run the loaded model over every flair image and keep channel 0 of the 'out'
# head as a new `flair_seg` column.
# Iterating the column visits rows in order, which matches the positional
# labels 0..n-1 of the frame's default index.
flair_segout = []

with torch.no_grad():
    for flair_image in survival['flair']:
        # Repeat the single image over 3 channels and add a batch axis of 1.
        batch = (
            torch.tensor(flair_image)
            .expand(3, -1, -1)
            .type(torch.ShortTensor)
            .float()
            .unsqueeze(0)
        )
        prediction = model(batch)['out'][0][0]
        # np.array copies, so only the selected channel is retained.
        flair_segout.append(np.array(prediction))

survival['flair_seg'] = flair_segout

In [122]:
# Keep only rows that have no missing values and whose 'Survival' entry is not
# an "ALIVE (...)" marker.
# This is done with a boolean mask rather than dropping rows inside a loop:
# dropping and re-indexing while iterating over the original row count skips
# the row that follows each dropped one and eventually looks up labels that no
# longer exist.
survival = survival.dropna().reset_index(drop=True)

# astype(str) keeps the test valid even if the column is already numeric
# (e.g. when this cell is re-run after the markers were removed).
still_alive = survival['Survival'].astype(str).str.contains('ALIVE', regex=False)
survival = survival.loc[~still_alive].reset_index(drop=True)
print(len(survival))

103


In [132]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Args:
        in_channels: number of channels in the input images (default 3).
        num_outputs: width of the final linear layer (default 10).
            NOTE(review): for a single-value regression trained with MSELoss
            this presumably should be 1 — confirm against the target shape.

    The feature map produced by the second pooling stage is adaptively
    average-pooled to 5x5 before flattening, so inputs are not restricted to
    32x32 pixels. For 32x32 inputs the map is already 5x5 and the adaptive
    pooling is an identity, so results for that size are unchanged. The
    pooling is parameter-free, so the set of learnable parameters is the same
    as without it.
    """

    def __init__(self, in_channels=3, num_outputs=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_outputs)

    def forward(self, x):
        """Map a batch of shape (N, in_channels, H, W) to (N, num_outputs)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Bring any spatial size to the 5x5 grid that fc1 was sized for.
        x = F.adaptive_avg_pool2d(x, (5, 5))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Model, loss and optimiser used by the training loop.
net = Net()
criterion = nn.MSELoss()  # mean squared error between outputs and labels
optimizer = optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)

In [None]:
from torch.utils.data import Dataset

class Survival_Dataset(Dataset):
    """Dataset pairing images with targets, with an optional Train/Test split.

    Args:
        images: sequence of per-sample image arrays. Each image is converted
            to a tensor and broadcast to 3 channels, so it must be 2-D (H, W)
            or have a leading channel axis of size 1 or 3.
        targets: sequence of per-sample targets, same length as `images`.
            Targets with two or more dimensions are treated like images.
            Anything else (numbers, numeric strings, 1-D arrays) is returned
            as a float32 tensor.
        fraction: share of the samples (taken from the end) reserved for the
            "Test" subset. Defaults to 0.1.
        subset: "Train" keeps the leading samples, "Test" keeps the trailing
            `fraction`, and None (the default) keeps everything.
        image_color_mode: "rgb" or "grayscale". The value is validated and
            stored but does not currently change how samples are built.

    Raises:
        ValueError: on an unknown `image_color_mode` or `subset`, a `fraction`
            outside [0, 1], or `images` and `targets` of different lengths.
    """

    def __init__(self,
                 images,
                 targets,
                 fraction = 0.1,
                 subset = None,
                 image_color_mode = "rgb"
                 ) -> None:

        if image_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )

        # None is the declared default and means "no split".
        if subset not in [None, "Train", "Test"]:
            raise (ValueError(
                f"{subset} is not a valid input. Acceptable values are Train and Test."
            ))

        if not 0 <= fraction <= 1:
            raise ValueError(f"fraction must be between 0 and 1, got {fraction}.")

        # Materialise as lists so indexing in __getitem__ is positional even
        # when a pandas Series with a non-zero-based index is passed in.
        images = list(images)
        targets = list(targets)
        if len(images) != len(targets):
            raise ValueError(
                f"images and targets differ in length: {len(images)} vs {len(targets)}."
            )

        self.image_color_mode = image_color_mode
        self.fraction = fraction
        self.subset = subset

        # Deterministic, non-overlapping split: the last `fraction` of the
        # samples is the test set, everything before it is the training set.
        split = len(images) - int(round(len(images) * fraction))
        if subset == "Train":
            images, targets = images[:split], targets[:split]
        elif subset == "Test":
            images, targets = images[split:], targets[split:]

        self.images = images
        self.targets = targets

    def __len__(self) -> int:
        """Number of samples in the selected subset."""
        return len(self.images)

    def __getitem__(self, index: int):
        """Return ``{"image": int16 tensor (3, H, W), "mask": target tensor}``."""
        image = torch.tensor(self.images[index]).expand(3, -1, -1).type(torch.ShortTensor)

        target = self.targets[index]
        if isinstance(target, str):
            # Survival times read from CSV can be strings such as "289".
            target = float(target)
        target = torch.tensor(target)
        if target.dim() >= 2:
            # Image-like target: same 3-channel int16 layout as the image.
            target = target.expand(3, -1, -1).type(torch.ShortTensor)
        else:
            # Scalar / vector target: cannot be expanded to (3, H, W).
            target = target.float()

        return {"image": image, "mask": target}

In [None]:
def get_survival_dataloaders(images, target, batch_size=14):
    """Build one DataLoader per subset name.

    Args:
        images: per-sample images, forwarded to Survival_Dataset.
        target: per-sample targets, forwarded to Survival_Dataset.
        batch_size: batch size used by both loaders (default 14).

    Returns:
        dict with keys 'Train' and 'Test', each mapping to a DataLoader over a
        Survival_Dataset created with that subset name. Which samples each
        subset contains is decided by Survival_Dataset.
    """
    subset_names = ['Train', 'Test']

    # Survival_Dataset has no `transforms` parameter, so none is passed.
    image_datasets = {
        x: Survival_Dataset(images, target, subset=x)
        for x in subset_names
    }

    dataloaders = {
        x: DataLoader(image_datasets[x], batch_size=batch_size)
        for x in subset_names
    }
    return dataloaders

In [None]:
# Shuffled training batches of 5 samples, loaded by 2 worker processes.
# NOTE(review): `trainset` is not defined in any visible cell — it must be
# created (e.g. from Survival_Dataset) before this cell can run on a fresh kernel.
trainloader = DataLoader(trainset, batch_size=5, shuffle=True, num_workers=2)

In [None]:
# Train `net` on `trainloader` and report the mean loss of every epoch.
# The loss is reported per epoch rather than every 2000 mini-batches because a
# loader with fewer than 2000 batches would otherwise never print anything.
net.train()
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    n_batches = 0
    for data in trainloader:
        # A batch is either an (inputs, labels) pair or a dict with "image" and
        # "mask" entries (the sample format produced by Survival_Dataset).
        if isinstance(data, dict):
            inputs, labels = data["image"], data["mask"]
        else:
            inputs, labels = data

        # Convolutions and MSELoss need floating-point tensors.
        inputs, labels = inputs.float(), labels.float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # NOTE(review): `criterion` receives outputs and labels as-is, so their
        # shapes have to agree — confirm against the model's output size.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # print statistics (skipped for an empty loader to avoid dividing by zero)
    if n_batches:
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, n_batches, running_loss / n_batches))

print('Finished Training')

In [105]:
# Show the frame with the image columns and `flair_seg` attached (pandas
# truncates the display to the first and last rows).
# NOTE(review): the stored output still contains "ALIVE" and NaN survival
# entries, i.e. it appears to predate the filtering cell — re-run in order.
survival

Unnamed: 0,BraTS19ID,Age,Survival,ResectionStatus,t2,t1,t1ce,flair,seg,flair_seg
0,BraTS19_CBICA_AAB_1,60.463014,289,GTR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.023373518, -0.023373518, -0.023373518, -0..."
1,BraTS19_CBICA_AAG_1,52.263014,616,GTR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.008784294, -0.008784294, -0.008784294, -0..."
2,BraTS19_CBICA_AAL_1,54.301370,464,GTR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.006356202, -0.006356202, -0.006356202, -0..."
3,BraTS19_CBICA_AAP_1,39.068493,788,GTR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.021736968, -0.021736968, -0.021736968, -0..."
4,BraTS19_CBICA_ABB_1,68.493151,465,GTR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.021570053, -0.021570053, -0.021570053, -0..."
...,...,...,...,...,...,...,...,...,...,...
254,BraTS19_TCIA08_436_1,44.180822,,,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.010999404, -0.010999404, -0.010999404, -0..."
255,BraTS19_TCIA08_113_1,49.463014,,,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0.0055095404, 0.0055095404, 0.0055095404, 0...."
256,BraTS19_CBICA_BFP_1,70.652055,ALIVE (1029 days later),STR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[-0.021447022, -0.021447022, -0.021447022, -0..."
257,BraTS19_CBICA_BHK_1,73.827397,1020,STR,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[[0.010417864, 0.010417864, 0.010417864, 0.010..."
