<a href="https://colab.research.google.com/github/jban28/MPhys-Radiotherapy-49/blob/main/Tensorboard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Pre-requisites
This block installs and imports everything the later code blocks need, selects the GPU if one is available, and sets the location of the project data folder. That folder should contain a `metadata.csv` file and a `crop/Images/` sub-folder holding one NIfTI (`.nii`) image per patient.

In [None]:
# Mount Google Drive at /content/drive so the project folder configured later
# in the notebook can be read (this only works inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install torch torchvision
!pip install opencv-contrib-python
!pip install scikit-learn
!pip install SimpleITK
!pip install kornia
!pip install utils
!pip install torchio

import numpy as np
import random
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import SimpleITK as sitk
import torch
import torchio as tio
import kornia.augmentation as K
import torch.nn.functional as F


from mpl_toolkits.mplot3d import Axes3D
from torch.nn import Module
from torch.nn import Conv3d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch.nn import LeakyReLU
from torch import flatten
from torch import nn
from torch import reshape
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torch.optim import Adam
import torchvision
import torchvision.models as models
from torchvision.io import read_image
from torchsummary import summary
from scipy.ndimage import zoom, rotate
from torch.utils.tensorboard import SummaryWriter
#from torch.utils.data import windowLevelNormalize

# Seed every random number generator used in this notebook so that the patient
# split, the augmentations and the weight initialisation are reproducible after
# "Restart kernel -> Run all".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# First integer tag handed to the TensorBoard writer when plotting image slices
tag = 0

# Use the GPU if one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# Specify project folder location
#project_folder = "/content/drive/My Drive/Degree/MPhys/Data/"
project_folder = "/content/drive/My Drive/Data/"

Using cuda device


## Define arrays of patient and outcome data
This block lets you specify the criteria that define whether a patient's outcome is True or False. It then loops through all the patients in the `metadata.csv` file, checks that each patient has a corresponding image in the image folder, and splits the patients and their outcomes between the training, validation and testing arrays.

In [None]:
# Read metadata.csv into an array of strings. The `with` block guarantees the
# file handle is closed again (it was previously left open).
with open(project_folder + "metadata.csv") as metadata_file:
  metadata = np.loadtxt(metadata_file, dtype="str", delimiter=",")
# Drop the header row.
# NOTE(review): this slice also keeps only the first nine patients, which looks
# like a debugging subset - use metadata[1:] to process every patient.
metadata = metadata[1:10]

# Set the values which are used to define the outcome for each patient
outcome_type = 1 #int(input("Select which outcome you are aiming to predict \n(1=Locoregional, 2=Distant Metastasis, 3=Death):"))
check_day = 3000 #int(input("Select the number of days at which to check for event:"))
which_patients = 1 #int(input("Do you want to include patients whose last follow up is before the check day? (no = 0, yes = 1):"))

# Create empty arrays to store patient names and outcomes in
patient_with_event = []
patient_no_event = []
outcomes_train = []
outcomes_test = []
images = []

# Loop through each patient and identify whether they are true or false for the specified outcome from above.
# Column 0 is the patient name, column 5 the day of last follow up and column
# 5+outcome_type the day of the selected event (an empty string if it never occurred).
for patient in metadata:
  event_day = patient[5 + outcome_type]
  if event_day == "":
    # No event recorded. If the last follow up was before the check day the
    # true outcome is unknown: such patients are skipped when which_patients
    # is 0 and otherwise counted as "no event".
    if int(patient[5]) < check_day and which_patients == 0:
      continue
    outcome = 0
  elif int(event_day) <= check_day:
    # Event occurred before or on check day
    outcome = 1
  else:
    # Event occurred after check day
    outcome = 0

  # No Image file found for patient
  if not os.path.exists(project_folder + "crop/Images/" + patient[0] + ".nii"):
    print("No image found for patient " + patient[0])
    continue

  if outcome == 1:
    patient_with_event.append([patient[0], outcome])
  else:
    patient_no_event.append([patient[0], outcome])

# Class balancing by random under-sampling is currently disabled; the original
# code is kept below for reference.
# # Make arrays the same length
# if len(patient_with_event) < len(patient_no_event):
#   new_patient_no_event = random.sample(patient_no_event,len(patient_with_event))
#   new_patient_with_event = patient_with_event
# elif len(patient_with_event) > len(patient_no_event):
#   new_patient_with_event = random.sample(patient_with_event, len(patient_no_event))
#   new_patient_no_event = patient_no_event
# elif len(patient_with_event) == len(patient_no_event):
# These are aliases, not copies: removing entries from the new_* lists later on
# also removes them from patient_no_event / patient_with_event.
new_patient_no_event = patient_no_event
new_patient_with_event = patient_with_event

# The ratio below divides by the number of event patients, so fail with a
# clear message instead of a bare ZeroDivisionError when there are none.
if len(new_patient_with_event) == 0:
  raise ValueError("No patients with the selected event were found; cannot compute the class weight.")
# Ratio of no-event to event patients, used later to weight the loss
pos_weights = len(new_patient_no_event)/len(new_patient_with_event)

# Number of patients from each class that go into the training set (70%, rounded down)
seventy_percent_event = int(0.7*len(new_patient_with_event))
seventy_percent_no_event = int(0.7*len(new_patient_no_event))

print('NO event')
print(len(new_patient_no_event))
print('WITH event')
print(len(new_patient_with_event))
# Draw the training patients at random, separately for each class
train_patients_event = random.sample(new_patient_with_event, seventy_percent_event)
train_patients_no_event = random.sample(new_patient_no_event, seventy_percent_no_event)

def remove(small_array, original_array):
  """Delete every entry of `small_array` from `original_array`, in place.

  For each entry the first equal element of `original_array` is deleted, so
  duplicates are removed one at a time. The list passed in is modified and
  that same list object is returned.

  Raises:
    ValueError: if an entry of `small_array` is not present in `original_array`.
  """
  for entry in small_array:
    del original_array[original_array.index(entry)]
  return original_array

# Take the training patients out of the pools. remove() edits the lists in
# place and returns them. (The first assignment used to target a misspelled
# name, `new_patients_with_event`, which only worked because of that in-place
# edit.)
new_patient_with_event = remove(train_patients_event, new_patient_with_event)
new_patient_no_event = remove(train_patients_no_event, new_patient_no_event)

print('NO event')
print(len(new_patient_no_event))
print('WITH event')
print(len(new_patient_with_event))


# Half of the remaining patients of each class (rounded down) form the validation set
fifty_percent_event = int(0.5*len(new_patient_with_event))
fifty_percent_no_event = int(0.5*len(new_patient_no_event))

validate_patients_event = random.sample(new_patient_with_event, fifty_percent_event)
validate_patients_no_event = random.sample(new_patient_no_event, fifty_percent_no_event)

new_patient_with_event = remove(validate_patients_event, new_patient_with_event)
new_patient_no_event = remove(validate_patients_no_event, new_patient_no_event)

print('NO event')
print(len(new_patient_no_event))
print('WITH event')
print(len(new_patient_with_event))

# Whatever is left over becomes the test set
test_patients_event = new_patient_with_event
test_patients_no_event = new_patient_no_event

# Final [patient name, outcome] lists for each split, event patients first
outcomes_train = train_patients_event + train_patients_no_event
outcomes_validate = validate_patients_event + validate_patients_no_event
outcomes_test = test_patients_event + test_patients_no_event

print(outcomes_train)
print(outcomes_validate)
print(outcomes_test)



No image found for patient HN-CHUM-005
NO event
7
WITH event
1
NO event
3
WITH event
1
NO event
2
WITH event
1
[['HN-CHUM-008', 0], ['HN-CHUM-007', 0], ['HN-CHUM-006', 0], ['HN-CHUM-003', 0]]
[['HN-CHUM-001', 0]]
[['HN-CHUM-002', 1], ['HN-CHUM-004', 0], ['HN-CHUM-009', 0]]


## Define dataset class
This block defines the class on which to build dataset objects

In [None]:
# class Normalize(Dataset):
#     def __init__(self):
#       pass
#     def __call__(self, vol):
#         vol = (vol-vol.mean())/vol.std()
#         return(vol) 


# Default transform that ImageDataset applies to every volume (numpy array -> torch tensor).
# NOTE(review): torchvision's ToTensor documents ndarray input as (H x W x C) and returns
# (C x H x W), so it presumably permutes the axes of a 3-D volume - confirm this is intended.
data_transform = transforms.Compose([
        transforms.ToTensor(),
        #Normalize()
    ])

def windowLevelNormalize(image, level, window):
    """Apply window/level clipping to `image` and rescale the result to [0, 1].

    Values are clipped to [level - window/2, level + window/2]; the lower
    bound is then subtracted and the result multiplied by 1/window, so the
    window maps onto the range 0..1. The input itself is not modified.

    Args:
        image: array of intensities.
        level: centre of the intensity window.
        window: full width of the intensity window.

    Returns:
        The clipped and rescaled copy of `image`.
    """
    half_width = window/2
    lower, upper = level - half_width, level + half_width
    normalised = np.clip(image, lower, upper)
    normalised -= lower
    normalised *= (1 / window)
    return normalised





class ImageDataset(Dataset):
    """Dataset of 3-D ``.nii`` volumes paired with outcome labels.

    Each entry of ``annotations`` is a ``[patient_name, label]`` pair and the
    volume is read from ``<img_dir>/<patient_name>.nii`` with SimpleITK.
    When enabled, each augmentation is applied with probability 0.5, in the
    order shift, rotation, scaling, flip. The volume is then passed through
    ``transform`` and finally window/level normalised (level 40, window 80).

    Args:
        annotations: sequence of ``[patient_name, label]`` pairs.
        img_dir: folder containing the ``.nii`` files.
        transform: callable applied to the volume after augmentation.
        target_transform: optional callable applied to the label.
        rotate_augment: enable random rotation in the plane of axes 1 and 2.
        scale_augment: enable random zoom between 80% and 120%.
        flip_augment: enable random flip along the first axis.
        shift_augment: enable random translation of up to 10 voxels per axis.
    """

    def __init__(self, annotations, img_dir, transform= data_transform, target_transform=None, rotate_augment=False, scale_augment=False, flip_augment=False, shift_augment=False):
        self.img_labels = annotations
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.flips = flip_augment
        self.rotations = rotate_augment
        self.scaling = scale_augment
        self.shifts = shift_augment

    def __len__(self):
        """Number of patients in the dataset."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load, optionally augment, transform and normalise patient `idx`.

        Returns:
            ``(image, label)`` for the requested patient.
        """
        img_path = os.path.join(self.img_dir, self.img_labels[idx][0]+".nii")
        image_sitk = sitk.ReadImage(img_path)
        image = sitk.GetArrayFromImage(image_sitk)
        label = self.img_labels[idx][1]

        if self.shifts and random.random()<0.5:
            image = self._shift(image)

        if self.target_transform:
            label = self.target_transform(label)

        if self.rotations and random.random()<0.5:
            # Small rotation in the plane of axes 1 and 2: the angle is drawn
            # from N(0, 3) degrees and limited to [-10, 10]
            roll_angle = np.clip(np.random.normal(loc=0,scale=3), -10, 10)
            image = self.rotation(image, roll_angle, rotation_plane=(1,2))

        if self.scaling and random.random()<0.5:
            # Zoom factor limited to 80-120%.
            # NOTE(review): with scale=0.5 most draws are clipped to exactly
            # 0.8 or 1.2; an earlier version used scale=0.05 - confirm.
            scale_factor = np.clip(np.random.normal(loc=1.0,scale=0.5), 0.8, 1.2)
            image = self.scale(image, scale_factor)

        if self.flips and random.random()<0.5:
            image = self.flip(image)

        if self.transform:
            image = self.transform(image)

        # window and levelling
        image = windowLevelNormalize(image, level=40, window=80)

        return image, label

    def _shift(self, image, max_shift=10, fill_value=-1024):
        """Translate `image` by a random whole number of voxels along each axis.

        The volume is padded with `fill_value` by `max_shift` voxels on every
        side and a window of the original shape is cropped back out at a
        random offset, so the output shape always equals the input shape
        (the crop was previously hard-coded to 246 voxels per axis).
        """
        depth, height, width = image.shape
        # find shift values
        cc_shift = random.randint(-max_shift, max_shift)
        ap_shift = random.randint(-max_shift, max_shift)
        lr_shift = random.randint(-max_shift, max_shift)
        # pad for shifting into
        padded = np.pad(image, pad_width=max_shift, mode='constant', constant_values=fill_value)
        # crop to complete shift
        cc_start = max_shift + cc_shift
        ap_start = max_shift + ap_shift
        lr_start = max_shift + lr_shift
        return padded[cc_start:cc_start+depth, ap_start:ap_start+height, lr_start:lr_start+width]

    def scale(self, image, scale_factor):
        """Zoom `image` by `scale_factor` about its centre, keeping its shape.

        Zooming out pastes the shrunken volume into the centre of a volume
        filled with 0; zooming in crops the centre of the enlarged volume.
        A factor of exactly 1.0 returns the input unchanged.

        NOTE(review): the fill value here is 0, while shifting and rotation
        fill with -1024 - confirm this difference is intended.
        """
        # scale the image or mask using scipy zoom function
        order, cval = (3, 0) # cubic interpolation; fill value changed from -1024 to 0
        height, width, depth = image.shape
        zheight = int(np.round(scale_factor*height))
        zwidth = int(np.round(scale_factor*width))
        zdepth = int(np.round(scale_factor*depth))
        if scale_factor == 1.0:
            return image

        zoomed = zoom(input=image, zoom=scale_factor, order=order, mode='constant', cval=cval)[0:zheight, 0:zwidth, 0:zdepth]
        if scale_factor < 1.0:
            # zoomed out: centre the smaller volume inside a filled canvas
            new_image = np.full_like(image, cval)
            ud_buffer = (height-zheight) // 2
            ap_buffer = (width-zwidth) // 2
            lr_buffer = (depth-zdepth) // 2
            new_image[ud_buffer:ud_buffer+zheight, ap_buffer:ap_buffer+zwidth, lr_buffer:lr_buffer+zdepth] = zoomed
            return new_image

        # zoomed in: keep only the central region of the original size
        ud_extra = (zoomed.shape[0] - height) // 2
        ap_extra = (zoomed.shape[1] - width) // 2
        lr_extra = (zoomed.shape[2] - depth) // 2
        return zoomed[ud_extra:ud_extra+height, ap_extra:ap_extra+width, lr_extra:lr_extra+depth]

    def rotation(self, image, rotation_angle, rotation_plane):
        """Rotate `image` by `rotation_angle` degrees in the plane `rotation_plane`.

        The output keeps the input shape; voxels rotated in from outside the
        volume are filled with -1024.
        """
        # rotate the image using scipy rotate function
        order, cval = (3, -1024) # cubic interpolation, fill with -1024
        return rotate(input=image, angle=rotation_angle, axes=rotation_plane, reshape=False, order=order, mode='constant', cval=cval)

    def flip(self, image):
        """Return a copy of `image` reversed along its first axis."""
        # .copy() turns the flipped view into a regular array
        image = np.flipud(image).copy()
        return image


## Build Datasets
This block uses the class and arrays defined previously to build datasets for training, testing and validation.

In [None]:

# All three splits read from the same cropped-image folder. No augmentation is
# switched on for any of them: the training set relies on the class defaults
# and the validation/test sets disable every augmentation explicitly.
training_data = ImageDataset(
    annotations=outcomes_train,
    img_dir=project_folder + "crop/Images/",
)
validation_data = ImageDataset(
    annotations=outcomes_validate,
    img_dir=project_folder + "crop/Images/",
    rotate_augment=False,
    scale_augment=False,
    flip_augment=False,
    shift_augment=False,
)
test_data = ImageDataset(
    annotations=outcomes_test,
    img_dir=project_folder + "crop/Images/",
    rotate_augment=False,
    scale_augment=False,
    flip_augment=False,
    shift_augment=False,
)
print(len(training_data))


4


## View binary masks in 3d
This block lets you view a binary mask of an image in 3D by extracting that image from a chosen dataset. This helps to confirm that the data has not been altered by being read into PyTorch.

In [None]:
  # Set which dataset to look at, and the index of the patient to view
dataset = training_data
index = 1
print('flipud')
# The patient name is looked up in outcomes_train, so this line only matches
# the plotted volume while `dataset` is the training set.
print(outcomes_train[index])
# NOTE(review): matplotlib was switched to the non-interactive 'Agg' backend in
# the import cell, so this figure may not be displayed inline - confirm.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

#print(dataset[0])

#array = dataset[index][0].numpy()
# dataset[index] returns (image, label); keep only the image
array = dataset[index][0]
print(type(array))
print('array shape')
print(array.shape)
# Coordinates of every voxel that is above zero after window/level normalisation
x,y,z = np.where(array > 0.) # NOTE(review): strictly greater than zero - confirm ">=" is not wanted
ax.scatter(x, y, z, c=z, alpha=1)

# Fix the axes to the full extent of a 246-voxel cube
ax.set_xlim(0,246)
ax.set_ylim(0,246)
ax.set_zlim(0,246)


flipud
['HN-CHUM-007', 0]


  


(246, 246, 246)
<class 'torch.Tensor'>
array shape
torch.Size([246, 246, 246])


(0.0, 246.0)

# Dataloader

In [None]:
# One loader per split, each yielding shuffled batches of four patients.
train_dataloader = DataLoader(dataset=training_data, batch_size=4, shuffle=True)
validate_dataloader = DataLoader(dataset=validation_data, batch_size=4, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=4, shuffle=True)

print()




# Define CNN Class

In [None]:
# class CNN(nn.Module):
#   def __init__(self):
#     super(CNN, self).__init__()
#     out1 = 4
#     out2 = 4
#     out3 = 2
#     self.cnn_layers = nn.Sequential(
#       # Layer 1
#       nn.Conv3d(1,out1,4,1,1),
#       nn.BatchNorm3d(out1),
#       #nn.ReLU(inplace=True),
#       nn.LeakyReLU(inplace=True),
#       nn.MaxPool3d(kernel_size=2, stride=2),
#       # Layer 2
#       nn.Conv3d(out1, out2, 4, 1, 1),
#       nn.BatchNorm3d(out2),
#       #nn.ReLU(inplace=True),
#       nn.LeakyReLU(inplace=True),
#       nn.MaxPool3d(kernel_size=2, stride=2),
#       # Layer 3
#       nn.Conv3d(out2, out3, 4, 1, 1),
#       nn.BatchNorm3d(out3),
#       #nn.ReLU(inplace=True),
#       nn.LeakyReLU(inplace=True),
#       nn.MaxPool3d(kernel_size=2, stride=2),
#     )
    # self.linear_layers = nn.Sequential(
    #   nn.Linear(48778, 2)
    # )
#   def forward(self, x):
#     x = self.cnn_layers(x)
#     x = x.view(x.size(0), -1)
#     x = self.linear_layers(x)
#     return x


class CNN(nn.Module):
  """Fully convolutional 3-D classifier that outputs two scores per volume.

  Three down-sampling stages (2x2x2 convolution with stride 2, batch norm,
  LeakyReLU, 2x max-pooling) widen the channels 1 -> 32 -> 64 -> 128. Three
  1x1x1 convolution stages (each with batch norm and LeakyReLU) then narrow
  them 128 -> 64 -> 16 -> 2, and a final 2x average pooling follows.
  forward() flattens whatever is left per sample.
  """

  def __init__(self):
    super().__init__()
    layers = []
    channels_in = 1

    # Layers 1-3: strided convolution + max-pooling
    for channels_out in (32, 64, 128):
      layers.extend([
        nn.Conv3d(channels_in, channels_out, 2, 2),
        nn.BatchNorm3d(channels_out),
        nn.LeakyReLU(inplace=True),
        nn.MaxPool3d(kernel_size=2, stride=2),
      ])
      channels_in = channels_out

    # Layers 4-6: 1x1x1 convolutions that only mix channels
    for channels_out in (64, 16, 2):
      layers.extend([
        nn.Conv3d(channels_in, channels_out, 1, 1),
        nn.BatchNorm3d(channels_out),
        nn.LeakyReLU(inplace=True),
      ])
      channels_in = channels_out

    layers.append(nn.AvgPool3d(2))
    self.cnn_layers = nn.Sequential(*layers)

  def forward(self, x):
    """Run the convolutional stack and flatten the output to (batch, features)."""
    features = self.cnn_layers(x)
    return features.view(features.size(0), -1)


# Build the network on the selected device and print its layer structure.
model = CNN().to(device)
print(model)


CNN(
  (cnn_layers): Sequential(
    (0): Conv3d(1, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv3d(32, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (5): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01, inplace=True)
    (7): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv3d(64, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (9): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.01, inplace=True)
    (11): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv3d(128, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (13): BatchNorm3d

# Define Train and Test Loops

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train `model` for one pass over `dataloader`.

    Every batch of volumes is given a channel axis, moved to the device the
    model lives on, and its integer labels (0 or 1) are one-hot encoded
    before the loss is computed. The loss of every batch is printed.

    Args:
        dataloader: yields ``(X, y)`` batches; ``X`` holds 3-D volumes and
            ``y`` the integer class labels.
        model: network mapping ``(batch, 1, D, H, W)`` to ``(batch, 2)`` scores.
        loss_fn: loss taking ``(scores, one_hot_targets)``.
        optimizer: optimiser updating the parameters of `model`.

    Returns:
        The loss of the last batch as a float (NaN if the loader is empty).

    Raises:
        RuntimeError: if a label is outside ``{0, 1}`` (raised by ``one_hot``).
    """
    size = len(dataloader.dataset)
    # Send the data wherever the model's parameters are
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Make sure layers such as BatchNorm run in training mode
    model.train()
    loss_value = float("nan")
    for batch, (X, y) in enumerate(dataloader):
        # Insert the single channel axis the 3-D convolutions expect
        X = X.reshape(X.shape[0], 1, *X.shape[-3:]).float().to(model_device)
        y = y.to(model_device)
        # One-hot targets; this rejects unexpected labels instead of leaving
        # uninitialised values in the target tensor
        hot_y = torch.nn.functional.one_hot(y.long(), num_classes=2).float()

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, hot_y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print results after each batch
        loss_value, current = loss.item(), batch * len(X)
        print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")
    return loss_value

def validate_loop(dataloader, model, loss_fn):
    """Evaluate `model` on `dataloader` without updating it.

    The model is switched to evaluation mode for the duration of the call
    (so BatchNorm uses its running statistics rather than the statistics of
    the validation batch) and is put back into its previous mode afterwards.

    Args:
        dataloader: yields ``(X, y)`` batches; ``X`` holds 3-D volumes and
            ``y`` the integer class labels (0 or 1).
        model: network mapping ``(batch, 1, D, H, W)`` to ``(batch, 2)`` scores.
        loss_fn: loss taking ``(scores, one_hot_targets)``.

    Returns:
        ``(validate_loss, accuracy)``: the loss averaged over batches and the
        percentage of correctly classified samples.

    Raises:
        ZeroDivisionError: if the loader or its dataset is empty.
        RuntimeError: if a label is outside ``{0, 1}`` (raised by ``one_hot``).
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    validate_loss, correct = 0, 0

    # Send the data wherever the model's parameters are
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                # Insert the single channel axis the 3-D convolutions expect
                X = X.reshape(X.shape[0], 1, *X.shape[-3:]).float().to(model_device)
                targets = y.to(model_device).long()
                hot_y = torch.nn.functional.one_hot(targets, num_classes=2).float()

                pred = model(X)
                # Predicted class = index of the larger of the two scores
                _, predictions = torch.max(pred, 1)
                batch_correct = (predictions == targets).sum().item()
                print(f'Correct this batch = {batch_correct}')

                validate_loss += loss_fn(pred, hot_y).item()
                correct += batch_correct
    finally:
        # Restore the mode the caller had, so later training is unaffected
        model.train(was_training)

    validate_loss /= num_batches
    correct /= size
    accuracy = 100*correct
    print(f"Validate Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {validate_loss:>8f} \n")
    return validate_loss, accuracy

learning_rate = 0.00101
# defining the model (moved to the device before the optimizer is built)
model = CNN()
model.to(device)
# defining the optimizer
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# defining the loss function
# One positive-class weight per output column: column 0 is "no event" and
# column 1 is "event", so each weight is (#negatives / #positives) for that
# column. `pos_weights` itself (the no-event/event ratio) is left untouched so
# that this cell can be re-run safely.
loss_pos_weight = torch.tensor([(1/pos_weights), pos_weights], dtype=torch.float32)
# The weights must be passed by keyword: the first positional argument of
# BCEWithLogitsLoss is `weight`, which only rescales the output columns and
# does not up-weight the rarer class.
# loss_fn = nn.BCEWithLogitsLoss()
loss_fn = nn.BCEWithLogitsLoss(pos_weight=loss_pos_weight)
# loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

summary(model=model, input_size=(1, 246, 246, 246), batch_size=2)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [2, 32, 123, 123, 123]             288
       BatchNorm3d-2     [2, 32, 123, 123, 123]              64
         LeakyReLU-3     [2, 32, 123, 123, 123]               0
         MaxPool3d-4        [2, 32, 61, 61, 61]               0
            Conv3d-5        [2, 64, 30, 30, 30]          16,448
       BatchNorm3d-6        [2, 64, 30, 30, 30]             128
         LeakyReLU-7        [2, 64, 30, 30, 30]               0
         MaxPool3d-8        [2, 64, 15, 15, 15]               0
            Conv3d-9          [2, 128, 7, 7, 7]          65,664
      BatchNorm3d-10          [2, 128, 7, 7, 7]             256
        LeakyReLU-11          [2, 128, 7, 7, 7]               0
        MaxPool3d-12          [2, 128, 3, 3, 3]               0
           Conv3d-13           [2, 64, 3, 3, 3]           8,256
      BatchNorm3d-14           [2, 64, 

In [None]:
# NOTE(review): --logdir is a relative path (no leading slash); check that it
# matches the folder the SummaryWriter in this notebook actually writes to.
# The command keeps the cell busy until it is interrupted (see the saved output).
!tensorboard --logdir=content/drive/logsdir


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.8.0 at http://localhost:6006/ (Press CTRL+C to quit)
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/util.py", line 320, in _exit_function
    def _exit_function(info=info, debug=debug, _run_finalizers=_run_finalizers,
KeyboardInterrupt


# Run Network

In [None]:
# Number of passes over the training set
epochs = 2
# Each history is a pair of parallel lists filled in by the training loop:
# index 0 holds the epoch numbers, index 1 the corresponding values.
train_losses = [[],[]]
validate_losses = [[],[]]
validate_accuracies = [[],[]]




"""
Custom tensorboard writer class
"""

class customWriter(SummaryWriter):
    """TensorBoard writer with helpers for plotting batches, slices and losses.

    Args:
        log_dir: directory the event files are written to.
        batch_size: maximum number of samples shown per figure.
        epoch: epoch number used in plot titles and as the step of the
            per-class loss scalars.
        num_classes: highest class index; per-class losses are tracked for
            classes ``0..num_classes``.
        dataloader: loader kept on the instance for convenience.
    """

    def __init__(self, log_dir, batch_size, epoch, num_classes, dataloader):
        # Forward log_dir to SummaryWriter; it was previously dropped, so the
        # events were written to the default directory instead.
        super(customWriter, self).__init__(log_dir=log_dir)
        self.batch_size = batch_size
        self.epoch = epoch
        self.num_classes = num_classes
        self.train_loss = []
        self.val_loss = []
        self.class_loss = {n: [] for n in range(num_classes+1)}
        self.dataloader = dataloader

    @staticmethod
    def sigmoid(x):
        """Element-wise logistic function of a tensor."""
        return 1/(1+torch.exp(-x))

    @staticmethod
    def _grid_shape(n_panels):
        """Return (rows, cols) of a near-square grid with room for `n_panels` subplots."""
        cols = max(1, int(np.ceil(np.sqrt(n_panels))))
        rows = max(1, int(np.ceil(n_panels / cols)))
        return rows, cols

    def reset_losses(self):
        """Clear the stored training, validation and per-class losses."""
        self.train_loss, self.val_loss, self.class_loss = [], [], {
            n: [] for n in range(self.num_classes+1)}

    def plot_batch(self, tag, images):
        """
        Plot batches in grid

        Args: tag = identifier for plot (string)
              images = input batch (torch.tensor)
        """
        # nrow must be at least 1, also when batch_size is 1
        img_grid = torchvision.utils.make_grid(images, nrow=max(1, self.batch_size // 2))
        self.add_image(tag, img_grid)

    def plot_pred(self, tag, prediction):
        """
        Plot channel 0 of each prediction in the batch.

        Args: tag = identifier for plot (string)
              prediction = batch output of trained model (torch.tensor)
        """
        # Never index past the end of a short (final) batch
        n_panels = min(self.batch_size, len(prediction))
        rows, cols = self._grid_shape(n_panels)
        fig = plt.figure(figsize=(24, 24))
        for idx in range(n_panels):
            # add_subplot takes (rows, cols, index)
            ax = fig.add_subplot(rows, cols, idx+1, label='images')
            ax.imshow(prediction[idx, 0].detach().cpu().numpy(), cmap='viridis')
            ax.set_title('prediction @ epoch: {} - idx: {}'.format(self.epoch, idx))
        self.add_figure(tag, fig)

    def plot_tumour(self, tag, dataloader):
        """
        Log the central slice (along the last axis) of every volume in `dataloader`.

        Args: tag = integer tag of the first figure; it is increased by one
                    for every volume and written as a string
              dataloader = loader yielding (volumes, labels) batches
        """
        for X, _ in dataloader:
            volumes = X.detach().float().cpu()
            # Drop a channel axis if there is one: (batch, D, H, W)
            volumes = volumes.reshape(volumes.shape[0], *volumes.shape[-3:]).numpy()
            for volume in volumes:
                middle = volume.shape[-1] // 2
                # One figure per volume; add_figure closes it once written
                fig = plt.figure(figsize=(24, 24))
                ax = fig.add_subplot()
                ax.imshow(volume[:, :, middle], cmap='viridis')
                self.add_figure(str(tag), fig)
                tag += 1

    def plot_histogram(self, tag, prediction):
        """
        Plot a histogram of the min-max normalised channel 0 of each prediction.

        Args: tag = identifier for plot (string)
              prediction = batch output of trained model (torch.tensor)
        """
        print('Plotting histogram')
        n_panels = min(self.batch_size, len(prediction))
        rows, cols = self._grid_shape(n_panels)
        fig = plt.figure(figsize=(24, 24))
        for idx in range(n_panels):
            ax = fig.add_subplot(rows, cols, idx+1, yticks=[], label='histogram')
            sample = prediction[idx, 0].detach().float().cpu()
            span = sample.max() - sample.min()
            # A constant prediction has zero range; avoid dividing by zero
            if span > 0:
                pred_norm = (sample - sample.min())/span
            else:
                pred_norm = torch.zeros_like(sample)
            ax.hist(pred_norm.flatten().numpy(), bins=100)
            ax.set_title(
                f'Prediction histogram @ epoch: {self.epoch} - idx: {idx}')
        self.add_figure(tag, fig)

    def per_class_loss(self, prediction, target, criterion, alpha=None):
        """
        Compute `criterion` separately for every class and store the values.

        For each class the prediction and the target are set to zero wherever
        the target belongs to another class before the loss is computed.

        Args: prediction = batch output of trained model (torch.tensor)
              target = ground truth with the same shape (torch.tensor)
              criterion = loss called as criterion(pred, target[, alpha])
              alpha = optional extra argument forwarded to criterion
        """
        pred = prediction
        # Zeros on the tensors' own devices, so this also works without a GPU
        zero_pred = torch.zeros(1, dtype=torch.float32, device=pred.device)
        zero_tgt = torch.zeros(1, dtype=torch.float32, device=target.device)
        for class_ in range(self.num_classes + 1):
            class_mask = target == class_
            class_pred = torch.where(class_mask, pred, zero_pred)
            class_tgt = torch.where(class_mask, target, zero_tgt)

            if alpha is not None:
                loss = criterion(class_pred, class_tgt, alpha)
            else:
                loss = criterion(class_pred, class_tgt)
            self.class_loss[class_].append(loss.item())

    def write_class_loss(self):
        """Write the mean stored loss of each class as a scalar at step `self.epoch`."""
        for class_ in range(self.num_classes+1):
            losses = self.class_loss[class_]
            if not losses:
                # Nothing recorded for this class; the mean of an empty list is undefined
                continue
            self.add_scalar(f'Per Class loss for class {class_}', np.mean(losses), self.epoch)












# Positional arguments: log_dir, batch_size, epoch, num_classes, dataloader.
# NOTE(review): batch_size is 2 here while the data loaders use batches of 4 - confirm.
writer = customWriter(project_folder, 2, 0, 1, train_dataloader)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # Keep the writer's epoch in step with the loop; its plots and per-class scalars use it
    writer.epoch = t
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer)
    # validate_loop returns (average loss, accuracy in percent)
    validate_loss, validate_accuracy = validate_loop(validate_dataloader, model, loss_fn)

    # Histories used by the plotting cells at the end of the notebook
    train_losses[0].append(t)
    train_losses[1].append(train_loss)
    validate_losses[0].append(t)
    validate_losses[1].append(validate_loss)
    validate_accuracies[0].append(t)
    validate_accuracies[1].append(validate_accuracy)

    writer.add_scalar('Train Loss', train_loss, t)
    writer.add_scalar('Validate Loss', validate_loss, t)
    writer.add_scalar('Validate Accuracy', validate_accuracy, t)
    # Log the central slice of every training volume
    writer.plot_tumour(dataloader = train_dataloader, tag=tag)
writer.close()
print("Done!")



Epoch 1
-------------------------------
(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
y issssssssssssssss
tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], device='cuda:0')
loss: 2.452642  [    0/    4]
(246, 246, 246)
Correct this batch = 1
Validate Error: 
 Accuracy: 100.0%, Avg loss: 2.449444 





(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          ...,

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
      



(246, 246, 246)
<class 'numpy.ndarray'>
(246, 246)
<class 'numpy.ndarray'>
(246, 246, 246)
<class 'numpy.ndarray'>
(246, 246)
<class 'numpy.ndarray'>
Epoch 2
-------------------------------
(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
y issssssssssssssss
tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], device='cuda:0')
loss: 2.451806  [    0/    4]
(246, 246, 246)
Correct this batch = 1
Validate Error: 
 Accuracy: 100.0%, Avg loss: 2.449681 

(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
(246, 246, 246)
tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  

# Test

In [None]:
def test_loop(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X = reshape(X, (X.shape[0],1,246,246,246))
            X = X.float()
            X = X.to(device)
            y = y.to(device)
            #y = reshape(y, (y.shape[0],1))
            hot_y = torch.empty((X.shape[0],2)).to(device)
            for index in range(len(y)):
              if y[index] == 0:
                hot_y[index,0] = 1
                hot_y[index,1] = 0
              elif y[index] == 1:
                hot_y[index,0] = 0
                hot_y[index,1] = 1
                
            pred = model(X)
            _,predictions = torch.max(pred , 1)
            _,targets = torch.max(hot_y, 1)
            torch.squeeze(pred)
            #test_loss += loss_fn(pred, y.float()).item()
            test_loss += loss_fn(pred, hot_y.float()).item()
            #correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            correct += (predictions == targets).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss

# Evaluate the trained model on the held-out test set; the returned average
# loss is shown as the cell output.
test_loop(test_dataloader, model)

# Plot Losses

In [None]:
# Training and validation loss per epoch; each history is [epochs, values]
fig, ax = plt.subplots()
ax.plot(*train_losses, label="Train Loss")
ax.plot(*validate_losses, label="Validate Loss")
ax.set(xlabel='Epoch', ylabel='Loss')
ax.legend()

# Plot Accuracies

In [None]:
# Validation accuracy per epoch; the history is [epochs, values]
fig, ax = plt.subplots()
ax.plot(*validate_accuracies, label="Validate Accuracies")
ax.set(xlabel='Epoch', ylabel='Accuracy / %')
ax.legend()