In [13]:
import pandas as pd
import numpy as np
import os
import numpy
import matplotlib.pyplot as plt
import SimpleITK
import itertools
import sys
import torch
from torchvision import transforms
from PIL import Image
from matplotlib import cm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision import models

from pathlib import Path

from src.survival.survival_custom import SurvivalNet, get_survival_dataloaders, train_survivalnet

SOURCE_PATH = Path(os.getcwd()) / 'src'

if SOURCE_PATH not in sys.path:
    sys.path.append(SOURCE_PATH)

from src.extraction import (
    extract_images_in_survival_order,
    export_images_list_jpeg,
)

from src.plots import (
    plot_observation,
    plot_deeplab_mobile_predictions,
    plot_mobile_prediction_from_path
)

from src.deeplab_mobile.modelling import(
    select_images_input,
    get_deeplab_mobile_model,
    train_deeplab_mobile
)

from src.deeplab_mobile.segdataset import(
    get_mobile_dataloaders
)

from src.utils import(
    LOGS_FILE_PATH,
    TYPE_NAMES
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
data_path = Path(os.getcwd()) / 'data' / 'HGG'
survival = pd.read_csv(data_path.parent /'survival_data.csv')

dir_ids = survival['BraTS19ID']

In [15]:
t2, t1ce, t1, flair, seg = extract_images_in_survival_order(data_path, dir_ids)
images = [t2, t1ce, t1, flair, seg]

for i in range(len(TYPE_NAMES)):
    survival[TYPE_NAMES[i]] = images[i] 

In [16]:
modelname = 'flair_totalpipe_nighttrain.pt'
model = torch.load(Path(os.getcwd()) / 'models' / modelname)
model.eval()

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (0): ConvNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1),

In [17]:
flair_segout = []
for i in range(len(survival)):
    input_tensor = torch.tensor(survival['flair'][i]).expand(3, -1, -1).type(torch.ShortTensor).float().unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)['out'][0][0]
    flair_segout.append(np.array(output))

survival['flair_seg'] = flair_segout

In [18]:
survival = survival.dropna().reset_index(drop=True)
survival['Survival'] = survival['Survival'].apply(lambda x: str(x))

to_drop = []
for i in range(len(survival)):
    if 'ALIVE' in survival['Survival'].loc[i]:
        to_drop.append(i)
        
        
survival.drop(to_drop, inplace=True)
print(len(survival))

103


In [19]:
survival['Survival'] = survival['Survival'].apply(lambda x: int(x))

In [20]:
net = SurvivalNet()

In [21]:
dataloaders = get_survival_dataloaders(survival['flair_seg'].values, survival['Survival'].values)

In [22]:
train_survivalnet(net, dataloaders, num_epochs=50)

Epoch 1/50
----------


100%|██████████| 8/8 [00:10<00:00,  1.33s/it]


Train Loss: 379369.4375


100%|██████████| 8/8 [00:01<00:00,  4.95it/s]


Test Loss: 354976.7500
{'epoch': 1, 'Train_loss': 379369.4375, 'Test_loss': 354976.75, 'Train_f1_score': 0.7777777777777778, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 1, 'Train_loss': 379369.4375, 'Test_loss': 354976.75, 'Train_f1_score': 0.7777777777777778, 'Test_f1_score': 0.8888888888888888}
Epoch 2/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.22s/it]


Train Loss: 83554.4375


100%|██████████| 8/8 [00:01<00:00,  5.32it/s]


Test Loss: 77209.0391
{'epoch': 2, 'Train_loss': 83554.4375, 'Test_loss': 77209.0390625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 2, 'Train_loss': 83554.4375, 'Test_loss': 77209.0390625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 3/50
----------


100%|██████████| 8/8 [00:10<00:00,  1.29s/it]


Train Loss: 146721.5938


100%|██████████| 8/8 [00:02<00:00,  3.39it/s]


Test Loss: 144991.3906
{'epoch': 3, 'Train_loss': 146721.59375, 'Test_loss': 144991.390625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 3, 'Train_loss': 146721.59375, 'Test_loss': 144991.390625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 4/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.09s/it]


Train Loss: 78513.4219


100%|██████████| 8/8 [00:01<00:00,  4.40it/s]


Test Loss: 69699.2969
{'epoch': 4, 'Train_loss': 78513.421875, 'Test_loss': 69699.296875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 4, 'Train_loss': 78513.421875, 'Test_loss': 69699.296875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 5/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.03it/s]


Train Loss: 95542.2578


100%|██████████| 8/8 [00:01<00:00,  5.85it/s]


Test Loss: 99789.3438
{'epoch': 5, 'Train_loss': 95542.2578125, 'Test_loss': 99789.34375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 5, 'Train_loss': 95542.2578125, 'Test_loss': 99789.34375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 6/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


Train Loss: 92874.4219


100%|██████████| 8/8 [00:01<00:00,  5.07it/s]


Test Loss: 84858.7578
{'epoch': 6, 'Train_loss': 92874.421875, 'Test_loss': 84858.7578125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 6, 'Train_loss': 92874.421875, 'Test_loss': 84858.7578125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 7/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.02it/s]


Train Loss: 82308.9766


100%|██████████| 8/8 [00:01<00:00,  4.96it/s]


Test Loss: 83076.7500
{'epoch': 7, 'Train_loss': 82308.9765625, 'Test_loss': 83076.75, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 7, 'Train_loss': 82308.9765625, 'Test_loss': 83076.75, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 8/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


Train Loss: 91646.1562


100%|██████████| 8/8 [00:01<00:00,  5.85it/s]


Test Loss: 87777.8203
{'epoch': 8, 'Train_loss': 91646.15625, 'Test_loss': 87777.8203125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 8, 'Train_loss': 91646.15625, 'Test_loss': 87777.8203125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 9/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.14it/s]


Train Loss: 83789.7656


100%|██████████| 8/8 [00:01<00:00,  5.76it/s]


Test Loss: 82997.8359
{'epoch': 9, 'Train_loss': 83789.765625, 'Test_loss': 82997.8359375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 9, 'Train_loss': 83789.765625, 'Test_loss': 82997.8359375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 10/50
----------


100%|██████████| 8/8 [00:06<00:00,  1.22it/s]


Train Loss: 88390.1328


100%|██████████| 8/8 [00:01<00:00,  6.42it/s]


Test Loss: 86838.0078
{'epoch': 10, 'Train_loss': 88390.1328125, 'Test_loss': 86838.0078125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 10, 'Train_loss': 88390.1328125, 'Test_loss': 86838.0078125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 11/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.14it/s]


Train Loss: 84444.2109


100%|██████████| 8/8 [00:01<00:00,  5.91it/s]


Test Loss: 83280.9844
{'epoch': 11, 'Train_loss': 84444.2109375, 'Test_loss': 83280.984375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 11, 'Train_loss': 84444.2109375, 'Test_loss': 83280.984375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 12/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.02s/it]


Train Loss: 84763.0469


100%|██████████| 8/8 [00:01<00:00,  4.51it/s]


Test Loss: 83387.9531
{'epoch': 12, 'Train_loss': 84763.046875, 'Test_loss': 83387.953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 12, 'Train_loss': 84763.046875, 'Test_loss': 83387.953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 13/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.17s/it]


Train Loss: 82448.1094


100%|██████████| 8/8 [00:01<00:00,  4.64it/s]


Test Loss: 81017.8281
{'epoch': 13, 'Train_loss': 82448.109375, 'Test_loss': 81017.828125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 13, 'Train_loss': 82448.109375, 'Test_loss': 81017.828125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 14/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.14s/it]


Train Loss: 81806.2578


100%|██████████| 8/8 [00:01<00:00,  5.33it/s]


Test Loss: 80279.3594
{'epoch': 14, 'Train_loss': 81806.2578125, 'Test_loss': 80279.359375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 14, 'Train_loss': 81806.2578125, 'Test_loss': 80279.359375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 15/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.24s/it]


Train Loss: 81383.2266


100%|██████████| 8/8 [00:01<00:00,  4.14it/s]


Test Loss: 80205.6250
{'epoch': 15, 'Train_loss': 81383.2265625, 'Test_loss': 80205.625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 15, 'Train_loss': 81383.2265625, 'Test_loss': 80205.625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 16/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.21s/it]


Train Loss: 79925.2891


100%|██████████| 8/8 [00:01<00:00,  5.27it/s]


Test Loss: 78628.9141
{'epoch': 16, 'Train_loss': 79925.2890625, 'Test_loss': 78628.9140625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 16, 'Train_loss': 79925.2890625, 'Test_loss': 78628.9140625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 17/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.06s/it]


Train Loss: 79453.6875


100%|██████████| 8/8 [00:01<00:00,  5.41it/s]


Test Loss: 78337.8047
{'epoch': 17, 'Train_loss': 79453.6875, 'Test_loss': 78337.8046875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 17, 'Train_loss': 79453.6875, 'Test_loss': 78337.8046875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 18/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.08s/it]


Train Loss: 77835.8906


100%|██████████| 8/8 [00:01<00:00,  4.26it/s]


Test Loss: 76740.0781
{'epoch': 18, 'Train_loss': 77835.890625, 'Test_loss': 76740.078125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 18, 'Train_loss': 77835.890625, 'Test_loss': 76740.078125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 19/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.03s/it]


Train Loss: 76633.4531


100%|██████████| 8/8 [00:01<00:00,  5.24it/s]


Test Loss: 75705.9531
{'epoch': 19, 'Train_loss': 76633.453125, 'Test_loss': 75705.953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 19, 'Train_loss': 76633.453125, 'Test_loss': 75705.953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 20/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.02it/s]


Train Loss: 75242.2031


100%|██████████| 8/8 [00:02<00:00,  3.46it/s]


Test Loss: 74808.5156
{'epoch': 20, 'Train_loss': 75242.203125, 'Test_loss': 74808.515625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 20, 'Train_loss': 75242.203125, 'Test_loss': 74808.515625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 21/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.12s/it]


Train Loss: 73300.7656


100%|██████████| 8/8 [00:01<00:00,  5.34it/s]


Test Loss: 72787.8281
{'epoch': 21, 'Train_loss': 73300.765625, 'Test_loss': 72787.828125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 21, 'Train_loss': 73300.765625, 'Test_loss': 72787.828125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 22/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.13it/s]


Train Loss: 73080.3672


100%|██████████| 8/8 [00:01<00:00,  6.00it/s]


Test Loss: 72523.8516
{'epoch': 22, 'Train_loss': 73080.3671875, 'Test_loss': 72523.8515625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 22, 'Train_loss': 73080.3671875, 'Test_loss': 72523.8515625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 23/50
----------


100%|██████████| 8/8 [00:06<00:00,  1.16it/s]


Train Loss: 71156.9609


100%|██████████| 8/8 [00:01<00:00,  6.24it/s]


Test Loss: 71158.1953
{'epoch': 23, 'Train_loss': 71156.9609375, 'Test_loss': 71158.1953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 23, 'Train_loss': 71156.9609375, 'Test_loss': 71158.1953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 24/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.13it/s]


Train Loss: 71319.5781


100%|██████████| 8/8 [00:01<00:00,  5.59it/s]


Test Loss: 70818.4453
{'epoch': 24, 'Train_loss': 71319.578125, 'Test_loss': 70818.4453125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 24, 'Train_loss': 71319.578125, 'Test_loss': 70818.4453125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 25/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.05s/it]


Train Loss: 71864.7031


100%|██████████| 8/8 [00:01<00:00,  4.52it/s]


Test Loss: 71333.9531
{'epoch': 25, 'Train_loss': 71864.703125, 'Test_loss': 71333.953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 25, 'Train_loss': 71864.703125, 'Test_loss': 71333.953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 26/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.02s/it]


Train Loss: 71644.5469


100%|██████████| 8/8 [00:01<00:00,  4.30it/s]


Test Loss: 70918.0391
{'epoch': 26, 'Train_loss': 71644.546875, 'Test_loss': 70918.0390625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 26, 'Train_loss': 71644.546875, 'Test_loss': 70918.0390625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 27/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.13it/s]


Train Loss: 72473.2969


100%|██████████| 8/8 [00:01<00:00,  4.50it/s]


Test Loss: 71001.2812
{'epoch': 27, 'Train_loss': 72473.296875, 'Test_loss': 71001.28125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 27, 'Train_loss': 72473.296875, 'Test_loss': 71001.28125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 28/50
----------


100%|██████████| 8/8 [00:06<00:00,  1.22it/s]


Train Loss: 73580.4922


100%|██████████| 8/8 [00:01<00:00,  6.26it/s]


Test Loss: 72680.2656
{'epoch': 28, 'Train_loss': 73580.4921875, 'Test_loss': 72680.265625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 28, 'Train_loss': 73580.4921875, 'Test_loss': 72680.265625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 29/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


Train Loss: 73598.6719


100%|██████████| 8/8 [00:01<00:00,  4.57it/s]


Test Loss: 72120.2578
{'epoch': 29, 'Train_loss': 73598.671875, 'Test_loss': 72120.2578125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 29, 'Train_loss': 73598.671875, 'Test_loss': 72120.2578125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 30/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.12s/it]


Train Loss: 74800.5625


100%|██████████| 8/8 [00:01<00:00,  4.23it/s]


Test Loss: 71892.6875
{'epoch': 30, 'Train_loss': 74800.5625, 'Test_loss': 71892.6875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 30, 'Train_loss': 74800.5625, 'Test_loss': 71892.6875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 31/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Train Loss: 79638.2500


100%|██████████| 8/8 [00:01<00:00,  5.03it/s]


Test Loss: 76758.8281
{'epoch': 31, 'Train_loss': 79638.25, 'Test_loss': 76758.828125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 31, 'Train_loss': 79638.25, 'Test_loss': 76758.828125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 32/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.05it/s]


Train Loss: 76504.3516


100%|██████████| 8/8 [00:01<00:00,  5.47it/s]


Test Loss: 75009.4141
{'epoch': 32, 'Train_loss': 76504.3515625, 'Test_loss': 75009.4140625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 32, 'Train_loss': 76504.3515625, 'Test_loss': 75009.4140625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 33/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.04s/it]


Train Loss: 70068.4531


100%|██████████| 8/8 [00:01<00:00,  5.22it/s]


Test Loss: 68908.0469
{'epoch': 33, 'Train_loss': 70068.453125, 'Test_loss': 68908.046875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 33, 'Train_loss': 70068.453125, 'Test_loss': 68908.046875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 34/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.02it/s]


Train Loss: 90034.6953


100%|██████████| 8/8 [00:01<00:00,  4.18it/s]


Test Loss: 83091.5781
{'epoch': 34, 'Train_loss': 90034.6953125, 'Test_loss': 83091.578125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 34, 'Train_loss': 90034.6953125, 'Test_loss': 83091.578125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 35/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.13s/it]


Train Loss: 75429.4297


100%|██████████| 8/8 [00:01<00:00,  4.50it/s]


Test Loss: 76829.4453
{'epoch': 35, 'Train_loss': 75429.4296875, 'Test_loss': 76829.4453125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 35, 'Train_loss': 75429.4296875, 'Test_loss': 76829.4453125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 36/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.06s/it]


Train Loss: 74265.8906


100%|██████████| 8/8 [00:01<00:00,  4.55it/s]


Test Loss: 70951.1719
{'epoch': 36, 'Train_loss': 74265.890625, 'Test_loss': 70951.171875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 36, 'Train_loss': 74265.890625, 'Test_loss': 70951.171875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 37/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.01s/it]


Train Loss: 81654.4844


100%|██████████| 8/8 [00:01<00:00,  5.40it/s]


Test Loss: 80542.3125
{'epoch': 37, 'Train_loss': 81654.484375, 'Test_loss': 80542.3125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 37, 'Train_loss': 81654.484375, 'Test_loss': 80542.3125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 38/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.02s/it]


Train Loss: 70195.0938


100%|██████████| 8/8 [00:01<00:00,  5.50it/s]


Test Loss: 68677.3516
{'epoch': 38, 'Train_loss': 70195.09375, 'Test_loss': 68677.3515625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 38, 'Train_loss': 70195.09375, 'Test_loss': 68677.3515625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 39/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.01s/it]


Train Loss: 82313.2656


100%|██████████| 8/8 [00:01<00:00,  4.82it/s]


Test Loss: 76091.7969
{'epoch': 39, 'Train_loss': 82313.265625, 'Test_loss': 76091.796875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 39, 'Train_loss': 82313.265625, 'Test_loss': 76091.796875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 40/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.05s/it]


Train Loss: 77669.4922


100%|██████████| 8/8 [00:01<00:00,  5.01it/s]


Test Loss: 77419.8438
{'epoch': 40, 'Train_loss': 77669.4921875, 'Test_loss': 77419.84375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 40, 'Train_loss': 77669.4921875, 'Test_loss': 77419.84375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 41/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.01it/s]


Train Loss: 70922.7812


100%|██████████| 8/8 [00:02<00:00,  3.69it/s]


Test Loss: 68152.7812
{'epoch': 41, 'Train_loss': 70922.78125, 'Test_loss': 68152.78125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 41, 'Train_loss': 70922.78125, 'Test_loss': 68152.78125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 42/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.18s/it]


Train Loss: 83561.7656


100%|██████████| 8/8 [00:01<00:00,  5.18it/s]


Test Loss: 80873.8984
{'epoch': 42, 'Train_loss': 83561.765625, 'Test_loss': 80873.8984375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 42, 'Train_loss': 83561.765625, 'Test_loss': 80873.8984375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 43/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.12it/s]


Train Loss: 73920.4297


100%|██████████| 8/8 [00:01<00:00,  5.96it/s]


Test Loss: 72203.6406
{'epoch': 43, 'Train_loss': 73920.4296875, 'Test_loss': 72203.640625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 43, 'Train_loss': 73920.4296875, 'Test_loss': 72203.640625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 44/50
----------


100%|██████████| 8/8 [00:08<00:00,  1.00s/it]


Train Loss: 77054.0234


100%|██████████| 8/8 [00:01<00:00,  5.87it/s]


Test Loss: 71196.1641
{'epoch': 44, 'Train_loss': 77054.0234375, 'Test_loss': 71196.1640625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 44, 'Train_loss': 77054.0234375, 'Test_loss': 71196.1640625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 45/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


Train Loss: 80710.7188


100%|██████████| 8/8 [00:01<00:00,  5.96it/s]


Test Loss: 79917.0469
{'epoch': 45, 'Train_loss': 80710.71875, 'Test_loss': 79917.046875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 45, 'Train_loss': 80710.71875, 'Test_loss': 79917.046875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 46/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.10it/s]


Train Loss: 71854.0625


100%|██████████| 8/8 [00:02<00:00,  3.52it/s]


Test Loss: 69736.6406
{'epoch': 46, 'Train_loss': 71854.0625, 'Test_loss': 69736.640625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 46, 'Train_loss': 71854.0625, 'Test_loss': 69736.640625, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 47/50
----------


100%|██████████| 8/8 [00:11<00:00,  1.42s/it]


Train Loss: 80914.9688


100%|██████████| 8/8 [00:01<00:00,  4.42it/s]


Test Loss: 76233.1953
{'epoch': 47, 'Train_loss': 80914.96875, 'Test_loss': 76233.1953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 47, 'Train_loss': 80914.96875, 'Test_loss': 76233.1953125, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 48/50
----------


100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Train Loss: 77245.8281


100%|██████████| 8/8 [00:01<00:00,  4.69it/s]


Test Loss: 76098.4922
{'epoch': 48, 'Train_loss': 77245.828125, 'Test_loss': 76098.4921875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 48, 'Train_loss': 77245.828125, 'Test_loss': 76098.4921875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 49/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.07it/s]


Train Loss: 73207.8281


100%|██████████| 8/8 [00:01<00:00,  5.16it/s]


Test Loss: 69943.6875
{'epoch': 49, 'Train_loss': 73207.828125, 'Test_loss': 69943.6875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 49, 'Train_loss': 73207.828125, 'Test_loss': 69943.6875, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Epoch 50/50
----------


100%|██████████| 8/8 [00:07<00:00,  1.09it/s]


Train Loss: 80614.4062


100%|██████████| 8/8 [00:01<00:00,  5.12it/s]


Test Loss: 78451.2734
{'epoch': 50, 'Train_loss': 80614.40625, 'Test_loss': 78451.2734375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': [0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}
{'epoch': 50, 'Train_loss': 80614.40625, 'Test_loss': 78451.2734375, 'Train_f1_score': 0.8888888888888888, 'Test_f1_score': 0.8888888888888888}
Training complete in 8m 17s
Lowest Loss: 68152.781250


In [23]:
model_exp_name = "custom_made_survivalnet_final"
torch.save(net, str(Path(os.getcwd()) / "models" / (model_exp_name + ".pt")))

In [24]:
print('prediction', net(dataloaders['Test'].dataset[30]['image'].unsqueeze(0).float()))
print('target', dataloaders['Test'].dataset[30]['target'])

prediction tensor([[670.8815]], grad_fn=<AddmmBackward0>)
target 597
