In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import math
from pymodules.transformations import Compose, ToTensor, RandomHorizontalFlip, RandomCrop, Resize, RandomCrop,RandomRotation,RandomGaussianNoise
from pymodules.TeslaSiemensDataset import TeslaSiemensDataset
from pymodules.model import SegNet
from pymodules.LossFunctions import dice
from pymodules.trainloop import train
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from pathlib import Path
# Use the first CUDA GPU when PyTorch reports one as available; otherwise fall
# back to the CPU. Everything that is moved with `.to(device)` follows this choice.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [2]:
# One [R, G, B] colour per class. The position inside `rgb_map` is the class
# index used for the lookup in vizualize_labels, so the order must not change.
# NOTE(review): cap/cg/pz presumably match the CaP/CG/PZ scores printed during
# training -- confirm the class order against the dataset's label encoding.
_CAP_RGB = [255, 0, 0]          # cap
_CG_RGB = [181, 70, 174]        # cg
_PZ_RGB = [61, 184, 102]        # pz
_BACKGROUND_RGB = [0, 0, 0]     # background

rgb_map = [_CAP_RGB, _CG_RGB, _PZ_RGB, _BACKGROUND_RGB]

import matplotlib.pyplot as plt
import numpy as np

def _scores_to_rgb(scores):
    """Turn a per-class score tensor into an RGB image array.

    Args:
        scores: 3-D tensor whose first dimension enumerates the classes
            (one-hot labels or network outputs); the remaining two
            dimensions are the image grid.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with shape ``scores.shape[1:] + (3,)``
        where every pixel holds the ``rgb_map`` colour of its winning class.

    Raises:
        IndexError: if a winning class index has no entry in ``rgb_map``.
    """
    # argmax over dim 0 yields the winning class index for every pixel.
    # detach()/cpu() make the conversion work for tensors that live on the GPU.
    class_idx = torch.argmax(scores, dim=0).detach().cpu().numpy()
    # Indexing the palette with the whole index array colours all pixels at
    # once and keeps the tensor's own spatial shape.
    palette = np.asarray(rgb_map, dtype=np.uint8)
    return palette[class_idx]


def vizualize_labels(true,pred):
    """Show ground-truth and predicted segmentation maps side by side.

    Each input is colourised independently (using its own spatial shape) via
    ``rgb_map`` and drawn in a 1x2 matplotlib figure titled 'True' / 'Pred'.

    Args:
        true: 3-D tensor of per-class scores / one-hot labels, classes first.
        pred: 3-D tensor of per-class scores, classes first. May live on a
            different device than ``true``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    f, axarr = plt.subplots(1, 2)
    panels = (('True', true), ('Pred', pred))
    for ax, (title, scores) in zip(axarr, panels):
        ax.set_title(title)
        ax.imshow(_scores_to_rgb(scores))
    plt.show()

In [3]:
# Size passed to Resize in both pipelines so train and test samples end up
# with the same dimensions.
# NOTE(review): presumably (height, width) -- confirm against pymodules.transformations.Resize.
_RESIZE_TO = (368, 448)

# Training pipeline: tensor conversion and resize followed by random
# augmentations. The numeric arguments (crop 10, rotation 2, noise 0.008) are
# interpreted by the project's own transform classes -- check their units there.
_train_steps = [
    ToTensor(),
    Resize(_RESIZE_TO),
    RandomHorizontalFlip(),
    RandomCrop(10),
    RandomRotation(2),
    RandomGaussianNoise(0.008),
]
transform = Compose(_train_steps)

# Test pipeline: only the deterministic steps, no augmentation.
_test_steps = [
    ToTensor(),
    Resize(_RESIZE_TO),
]
transform_test = Compose(_test_steps)

# Number of samples per batch for both loaders.
BATCH_SIZE = 5


# The dataset is not part of the repository: it needs to be downloaded and
# placed under data/siemens_reduced/ (train/ and test/) before running this cell.

# Training data: augmented via `transform`, reshuffled every epoch.
trainset = TeslaSiemensDataset(
    root_dir='data/siemens_reduced/train/',
    transform=transform,
    include_cap=True,
    num_of_surrouding_imgs=1,  # keyword spelling is dictated by the dataset class
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

# Held-out data: deterministic `transform_test`, used as the validation loader.
testset = TeslaSiemensDataset(
    root_dir='data/siemens_reduced/test/',
    transform=transform_test,
    include_cap=True,
    num_of_surrouding_imgs=1,
)
# Evaluation data is iterated in a fixed order: shuffling brings no benefit for
# validation and changes batch composition from epoch to epoch, which makes any
# per-batch aggregated score (and the "best validation loss" checkpoint choice)
# noisier than necessary.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:

# Output directory for this run ("durchlauf" is German for "run"); it is handed
# to train() as `best_model_dir`.
train_name = 'training'
durchlauf = '1_images_with_cap_full'
durchlauf_path = "/".join(("train", train_name, durchlauf))
Path(durchlauf_path).mkdir(parents=True, exist_ok=True)

# Training hyper-parameters.
EPOCHS = 80
loss_fn = dice

# NOTE(review): SegNet(3, 4) -- presumably 3 input channels and 4 output
# classes (rgb_map has 4 entries); confirm against pymodules.model.SegNet.
net = SegNet(3,4).to(device)

optimizer = optim.Adam(net.parameters(), lr=0.0001)
# Multiplies the learning rate by 0.1 after every 30 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# NOTE(review): early_stopping_patience (100) is larger than EPOCHS (80); if
# patience is counted in epochs, early stopping can never trigger -- confirm
# against pymodules.trainloop.train.
train(
    model=net,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=trainloader,
    val_loader=testloader,
    epochs=EPOCHS,
    device=device,
    best_model_dir=durchlauf_path,
    early_stopping_patience=100,
    step_size_decay=scheduler,
)

# Drop the model reference and release PyTorch's cached GPU memory so a
# following run starts from a clean allocator state.
del net
torch.cuda.empty_cache()

-------Epoch 1-------


  "Argument interpolation should be of type InterpolationMode instead of int. "


 CaP: 0.0 | CG 0.0 | PZ 0.0 | BG 0.9998224763309255
Validation Loss: 0.75 (Took 79.83020091056824 seconds)
Saving best validation loss 0.75
-------Epoch 2-------
 CaP: 0.6094084735302363 | CG 0.27800464238081235 | PZ 0.7414779584197437 | BG 0.9051427175016964
Validation Loss: 0.62 (Took 79.32795238494873 seconds)
Saving best validation loss 0.62
-------Epoch 3-------
 CaP: 0.6668221288575575 | CG 0.58220165398191 | PZ 0.6513841345029718 | BG 0.9608750904307646
Validation Loss: 0.57 (Took 79.36186146736145 seconds)
Saving best validation loss 0.57
-------Epoch 4-------
 CaP: 0.513890150014092 | CG 0.4224753695375779 | PZ 0.7750912887208602 | BG 0.958408979808583
Validation Loss: 0.61 (Took 79.45461320877075 seconds)
-------Epoch 5-------
 CaP: 0.6426750597827894 | CG 0.274440943975659 | PZ 0.5501725918010754 | BG 0.9929771563586067
Validation Loss: 0.52 (Took 79.50049114227295 seconds)
Saving best validation loss 0.52
-------Epoch 6-------
 CaP: 0.42928423629525836 | CG 0.13108591293441

 CaP: 0.5416384584763471 | CG 0.4809585288167 | PZ 0.6867643489557154 | BG 0.996891481034896
Validation Loss: 0.47 (Took 79.40075707435608 seconds)
-------Epoch 49-------
 CaP: 0.5105161465266171 | CG 0.5032339867423562 | PZ 0.8095775106373955 | BG 0.9975351340630475
Validation Loss: 0.46 (Took 79.2960376739502 seconds)
-------Epoch 50-------
Training Loss: 0.347 (Progress: 72.917)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 CaP: 0.4790484988713329 | CG 0.43969983914319205 | PZ 0.7564096959198222 | BG 0.9971227470566245
Validation Loss: 0.46 (Took 83.4324278831482 seconds)
-------Epoch 70-------
 CaP: 0.5100867382504725 | CG 0.5140670048182502 | PZ 0.7083568643121159 | BG 0.9971821904182434
Validation Loss: 0.46 (Took 79.43666195869446 seconds)
-------Epoch 71-------
 CaP: 0.6471912723058968 | CG 0.48161176597589955 | PZ 0.7201061108533073 | BG 0.9969140431460213
Validation Loss: 0.46 (Took 79.4984962940216 seconds)
-------Epoch 72-------
 CaP: 0.4818006757108923 | CG 0.444989084978314 | PZ 0.7060076138552498 | BG 0.9973242598421433
Validation Loss: 0.46 (Took 79.44962692260742 seconds)
-------Epoch 73-------
 CaP: 0.554799413300073 | CG 0.4149344291757135 | PZ 0.7904795399483513 | BG 0.9969748539083144
Validation Loss: 0.47 (Took 79.16039991378784 seconds)
-------Epoch 74-------
 CaP: 0.5755164019700859 | CG 0.46522941993659034 | PZ 0.7126721406684202 | BG 0.9967697192640865
Validation Loss: 0.47 (Took 7