In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader
from model import BaseUnet, BaseUnet3D
from data_processing_tools import remove_repeating_pattern
from evaluation_metrics import dice_overlap, intersection_over_union, accuracy, sensitivity, specificity
import timeit
from dataset import BrightfieldMicroscopyDataset

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhndrkjs[0m ([33mhndrkjs-danmarks-tekniske-universitet-dtu[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Loading all the models

In [12]:
one_channel_unet = BaseUnet(num_inputs=1)
one_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_1.pth', map_location=torch.device('cpu')))

two_channel_unet = BaseUnet(num_inputs=2)
two_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_2.pth', map_location=torch.device('cpu')))

three_channel_unet = BaseUnet(num_inputs=3)
three_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_3.pth', map_location=torch.device('cpu')))

four_channel_unet = BaseUnet(num_inputs=4)
four_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_4.pth', map_location=torch.device('cpu')))

five_channel_unet = BaseUnet(num_inputs=5)
five_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_5.pth', map_location=torch.device('cpu')))

six_channel_unet = BaseUnet(num_inputs=6)
six_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_6.pth', map_location=torch.device('cpu')))

seven_channel_unet = BaseUnet(num_inputs=7)
seven_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_7.pth', map_location=torch.device('cpu')))

eight_channel_unet = BaseUnet(num_inputs=8)
eight_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_8.pth', map_location=torch.device('cpu')))

nine_channel_unet = BaseUnet(num_inputs=9)
nine_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_9.pth', map_location=torch.device('cpu')))

ten_channel_unet = BaseUnet(num_inputs=10)
ten_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_10.pth', map_location=torch.device('cpu')))

eleven_channel_unet = BaseUnet(num_inputs=11)
eleven_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_11.pth', map_location=torch.device('cpu')))

base_unet_no_preprocessing = BaseUnet()
base_unet_no_preprocessing.load_state_dict(torch.load('early_stopping_modelUnet_simple_11_no_preprocess.pth', map_location=torch.device('cpu')))

base_unet_no_preprocessing_3d = BaseUnet3D()
base_unet_no_preprocessing_3d.load_state_dict(torch.load('early_stopping_modelUnet3D_simple_11_no_preprocess.pth', map_location=torch.device('cpu')))

base_unet_preprocessing = BaseUnet()
base_unet_preprocessing.load_state_dict(torch.load('final_model_classificationUnet_simple_11_preprocess.pth', map_location=torch.device('cpu')))

base_unet_preprocessing_3d = BaseUnet3D()
base_unet_preprocessing_3d.load_state_dict(torch.load('early_stopping_modelUnet3D_simple_11_preprocess.pth', map_location=torch.device('cpu')))

one_channel_unet.eval()
two_channel_unet.eval()
three_channel_unet.eval()
four_channel_unet.eval()
five_channel_unet.eval()
six_channel_unet.eval()
seven_channel_unet.eval()
eight_channel_unet.eval()
nine_channel_unet.eval()
ten_channel_unet.eval()
eleven_channel_unet.eval()
base_unet_no_preprocessing.eval()
base_unet_no_preprocessing_3d.eval()
base_unet_preprocessing.eval()
base_unet_preprocessing_3d.eval()

  one_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_1.pth', map_location=torch.device('cpu')))
  two_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_2.pth', map_location=torch.device('cpu')))
  three_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_3.pth', map_location=torch.device('cpu')))
  four_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_4.pth', map_location=torch.device('cpu')))
  five_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_5.pth', map_location=torch.device('cpu')))
  six_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_6.pth', map_location=torch.device('cpu')))
  seven_channel_unet.load_state_dict(torch.load('early_stopping_modelablation_channel_base_unet_7.pth', map_location=torch.device('cpu')))
  eight_channel_unet.load_state_dic

BaseUnet3D(
  (enc_conv0): Conv3d(11, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (downconv0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (batchnorm0): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (enc_conv1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (downconv1): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (batchnorm1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (enc_conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (downconv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (batchnorm2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (enc_conv3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (downconv3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2),

## Dataloader and Testing function

In [3]:
def get_dataloader_test(sample_size=512, batch_size=1, channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]):

    # image_root = '/zhome/70/5/14854/nobackup/deeplearningf24/forcebiology/data/brightfield'
    # mask_root = '/zhome/70/5/14854/nobackup/deeplearningf24/forcebiology/data/masks'
    image_root = 'data/brightfield'
    mask_root = 'data/masks'

    transform_test = v2.Compose([
        v2.Resize((sample_size, sample_size)),
        v2.ToDtype(torch.float32, scale=True),
        v2.ToTensor(),
    ])

    brightfield_test_datatset = BrightfieldMicroscopyDataset(root_dir_images=image_root, root_dir_labels=mask_root, train=False, validation=False, transform=transform_test, channels_to_use=channels)

    brightfield_loader_test = DataLoader(brightfield_test_datatset,  batch_size=batch_size, shuffle=False)

    return brightfield_loader_test

In [4]:
def test_model(model, test_loader, criterion, device, preprocessing=False, test_3d=False):
    test_loss = []
    dice = []
    iou = []
    acc = []
    sens = []
    spec = []

    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            # remove repeating pattern
            if preprocessing:
                for i in range(images.shape[0]):
                    images[i] = torch.tensor(remove_repeating_pattern(images[i].numpy()))

            images, labels = images.to(device), labels.to(device)

            if test_3d:
                images = images.unsqueeze(2)
                labels = labels.unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels.float())
            test_loss.append(loss.item())

            Y_pred = (outputs > 0.45).float()
            
            dice.append(dice_overlap(Y_pred, labels))
            iou.append(intersection_over_union(Y_pred, labels))
            acc.append(accuracy(Y_pred, labels))
            sens.append(sensitivity(Y_pred, labels))
            spec.append(specificity(Y_pred, labels))

    return test_loss, dice, iou, acc, sens, spec
        

## Testing the 2D and 3D Unets with and without preprocessing

In [13]:
test_loss_3d_no_preprocess, dice_3d_no_preprocess, iou_3d_no_preprocess, acc_3d_no_preprocess, sens_3d_no_preprocess, spec_3d_no_preprocess = test_model(base_unet_no_preprocessing_3d, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=True)

test_loss_3d_preprocess, dice_3d_preprocess, iou_3d_preprocess, acc_3d_preprocess, sens_3d_preprocess, spec_3d_preprocess = test_model(base_unet_preprocessing_3d, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=True, test_3d=True)

tess_loss_2d_no_preprocess, dice_2d_no_preprocess, iou_2d_no_preprocess, acc_2d_no_preprocess, sens_2d_no_preprocess, spec_2d_no_preprocess = test_model(base_unet_no_preprocessing, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_2d_preprocess, dice_2d_preprocess, iou_2d_preprocess, acc_2d_preprocess, sens_2d_preprocess, spec_2d_preprocess = test_model(base_unet_preprocessing, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=True, test_3d=False)

print('3D no preprocess')
print('Test loss: ', np.mean(test_loss_3d_no_preprocess))
print('Dice: ', np.mean(dice_3d_no_preprocess))
print('IoU: ', np.mean(iou_3d_no_preprocess))
print('Accuracy: ', np.mean(acc_3d_no_preprocess))
print('Sensitivity: ', np.mean(sens_3d_no_preprocess))
print('Specificity: ', np.mean(spec_3d_no_preprocess))

print('3D preprocess')
print('Test loss: ', np.mean(test_loss_3d_preprocess))
print('Dice: ', np.mean(dice_3d_preprocess))
print('IoU: ', np.mean(iou_3d_preprocess))
print('Accuracy: ', np.mean(acc_3d_preprocess))
print('Sensitivity: ', np.mean(sens_3d_preprocess))
print('Specificity: ', np.mean(spec_3d_preprocess))

print('2D no preprocess')
print('Test loss: ', np.mean(tess_loss_2d_no_preprocess))
print('Dice: ', np.mean(dice_2d_no_preprocess))
print('IoU: ', np.mean(iou_2d_no_preprocess))
print('Accuracy: ', np.mean(acc_2d_no_preprocess))
print('Sensitivity: ', np.mean(sens_2d_no_preprocess))
print('Specificity: ', np.mean(spec_2d_no_preprocess))

print('2D preprocess')
print('Test loss: ', np.mean(test_loss_2d_preprocess))
print('Dice: ', np.mean(dice_2d_preprocess))
print('IoU: ', np.mean(iou_2d_preprocess))
print('Accuracy: ', np.mean(acc_2d_preprocess))
print('Sensitivity: ', np.mean(sens_2d_preprocess))
print('Specificity: ', np.mean(spec_2d_preprocess))

3D no preprocess
Test loss:  0.32766258848064084
Dice:  0.66471165
IoU:  0.5213731
Accuracy:  0.9010626
Sensitivity:  0.5811094
Specificity:  0.9794943
3D preprocess
Test loss:  0.303239405885631
Dice:  0.7123706
IoU:  0.5669817
Accuracy:  0.9005842
Sensitivity:  0.68898076
Specificity:  0.9514856
2D no preprocess
Test loss:  0.2897508324069135
Dice:  0.6941363
IoU:  0.56036824
Accuracy:  0.9063084
Sensitivity:  0.64471835
Specificity:  0.9740524
2D preprocess
Test loss:  0.221010604179373
Dice:  0.73376006
IoU:  0.5942269
Accuracy:  0.91118246
Sensitivity:  0.6848869
Specificity:  0.9653142


## Testing whether results are statisitcally significant

### 2D Unet vs 3D Unet

In [7]:
import numpy as np
from scipy.stats import wilcoxon


modelA_scores_dice = np.array(dice_3d_no_preprocess)
modelB_scores_dice = np.array(dice_2d_no_preprocess)

modelA_scores_iou = np.array(iou_3d_no_preprocess)
modelB_scores_iou = np.array(iou_2d_no_preprocess)


# Perform a Wilcoxon signed-rank test (non-parametric)
stat_dice, p_value_dice = wilcoxon(modelA_scores_dice, modelB_scores_dice)
stat_iou, p_value_iou = wilcoxon(modelA_scores_iou, modelB_scores_iou)

print("Wilcoxon signed-rank test results for dice:")
print(f"Statistic: {stat_dice:.4f}, p-value: {p_value_dice:.4f}")

print("Wilcoxon signed-rank test results for iou:")
print(f"Statistic: {stat_iou:.4f}, p-value: {p_value_iou:.4f}")


Wilcoxon signed-rank test results for dice:
Statistic: 550.0000, p-value: 0.2895
Wilcoxon signed-rank test results for iou:
Statistic: 588.0000, p-value: 0.4820


### 3D Unet preprocessing vs no preprocessing

In [8]:
modelA_scores_dice = np.array(dice_3d_no_preprocess)
modelB_scores_dice = np.array(dice_3d_preprocess)

modelA_scores_iou = np.array(iou_3d_no_preprocess)
modelB_scores_iou = np.array(iou_3d_preprocess)


# Perform a Wilcoxon signed-rank test (non-parametric)
stat_dice, p_value_dice = wilcoxon(modelA_scores_dice, modelB_scores_dice)
stat_iou, p_value_iou = wilcoxon(modelA_scores_iou, modelB_scores_iou)

print("Wilcoxon signed-rank test results for dice:")
print(f"Statistic: {stat_dice:.4f}, p-value: {p_value_dice:.4f}")

print("Wilcoxon signed-rank test results for iou:")
print(f"Statistic: {stat_iou:.4f}, p-value: {p_value_iou:.4f}")

Wilcoxon signed-rank test results for dice:
Statistic: 270.0000, p-value: 0.0002
Wilcoxon signed-rank test results for iou:
Statistic: 276.0000, p-value: 0.0003


### 2D Unet preprocessing vs no preprocessing

In [9]:
modelA_scores_dice = np.array(dice_2d_no_preprocess)
modelB_scores_dice = np.array(dice_2d_preprocess)

modelA_scores_iou = np.array(iou_2d_no_preprocess)
modelB_scores_iou = np.array(iou_2d_preprocess)


# Perform a Wilcoxon signed-rank test (non-parametric)
stat_dice, p_value_dice = wilcoxon(modelA_scores_dice, modelB_scores_dice)
stat_iou, p_value_iou = wilcoxon(modelA_scores_iou, modelB_scores_iou)

print("Wilcoxon signed-rank test results for dice:")
print(f"Statistic: {stat_dice:.4f}, p-value: {p_value_dice:.4f}")

print("Wilcoxon signed-rank test results for iou:")
print(f"Statistic: {stat_iou:.4f}, p-value: {p_value_iou:.4f}")

Wilcoxon signed-rank test results for dice:
Statistic: 128.0000, p-value: 0.0000
Wilcoxon signed-rank test results for iou:
Statistic: 114.0000, p-value: 0.0000


### 2D vs 3D preprocess

In [10]:
modelA_scores_dice = np.array(dice_3d_preprocess)
modelB_scores_dice = np.array(dice_2d_preprocess)

modelA_scores_iou = np.array(iou_3d_preprocess)
modelB_scores_iou = np.array(iou_2d_preprocess)


# Perform a Wilcoxon signed-rank test (non-parametric)
stat_dice, p_value_dice = wilcoxon(modelA_scores_dice, modelB_scores_dice)
stat_iou, p_value_iou = wilcoxon(modelA_scores_iou, modelB_scores_iou)

print("Wilcoxon signed-rank test results for dice:")
print(f"Statistic: {stat_dice:.4f}, p-value: {p_value_dice:.10f}")

print("Wilcoxon signed-rank test results for iou:")
print(f"Statistic: {stat_iou:.4f}, p-value: {p_value_iou:.10f}")

Wilcoxon signed-rank test results for dice:
Statistic: 270.0000, p-value: 0.0002298074
Wilcoxon signed-rank test results for iou:
Statistic: 250.0000, p-value: 0.0001082823


## Evaluate different channels

In [9]:
test_loss_1_channel, dice_1_channel, iou_1_channel, acc_1_channel, sens_1_channel, spec_1_channel = test_model(one_channel_unet, get_dataloader_test(channels=[0]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_2_channel, dice_2_channel, iou_2_channel, acc_2_channel, sens_2_channel, spec_2_channel = test_model(two_channel_unet, get_dataloader_test(channels=[0, 1]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_3_channel, dice_3_channel, iou_3_channel, acc_3_channel, sens_3_channel, spec_3_channel = test_model(three_channel_unet, get_dataloader_test(channels=[0, 1, 2]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_4_channel, dice_4_channel, iou_4_channel, acc_4_channel, sens_4_channel, spec_4_channel = test_model(four_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_5_channel, dice_5_channel, iou_5_channel, acc_5_channel, sens_5_channel, spec_5_channel = test_model(five_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_6_channel, dice_6_channel, iou_6_channel, acc_6_channel, sens_6_channel, spec_6_channel = test_model(six_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_7_channel, dice_7_channel, iou_7_channel, acc_7_channel, sens_7_channel, spec_7_channel = test_model(seven_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_8_channel, dice_8_channel, iou_8_channel, acc_8_channel, sens_8_channel, spec_8_channel = test_model(eight_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_9_channel, dice_9_channel, iou_9_channel, acc_9_channel, sens_9_channel, spec_9_channel = test_model(nine_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_10_channel, dice_10_channel, iou_10_channel, acc_10_channel, sens_10_channel, spec_10_channel = test_model(ten_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

test_loss_11_channel, dice_11_channel, iou_11_channel, acc_11_channel, sens_11_channel, spec_11_channel = test_model(eleven_channel_unet, get_dataloader_test(channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), nn.BCEWithLogitsLoss(), torch.device('cpu'), preprocessing=False, test_3d=False)

print('1 channel')
print('Dice overlap: ', np.mean(dice_1_channel))
print('IoU: ', np.mean(iou_1_channel))

print('2 channel')
print('Dice overlap: ', np.mean(dice_2_channel))
print('IoU: ', np.mean(iou_2_channel))

print('3 channel')
print('Dice overlap: ', np.mean(dice_3_channel))
print('IoU: ', np.mean(iou_3_channel))

print('4 channel')
print('Dice overlap: ', np.mean(dice_4_channel))
print('IoU: ', np.mean(iou_4_channel))

print('5 channel')
print('Dice overlap: ', np.mean(dice_5_channel))
print('IoU: ', np.mean(iou_5_channel))

print('6 channel')
print('Dice overlap: ', np.mean(dice_6_channel))
print('IoU: ', np.mean(iou_6_channel))

print('7 channel')
print('Dice overlap: ', np.mean(dice_7_channel))
print('IoU: ', np.mean(iou_7_channel))

print('8 channel')
print('Dice overlap: ', np.mean(dice_8_channel))
print('IoU: ', np.mean(iou_8_channel))

print('9 channel')
print('Dice overlap: ', np.mean(dice_9_channel))
print('IoU: ', np.mean(iou_9_channel))

print('10 channel')
print('Dice overlap: ', np.mean(dice_10_channel))
print('IoU: ', np.mean(iou_10_channel))

print('11 channel')
print('Dice overlap: ', np.mean(dice_11_channel))
print('IoU: ', np.mean(iou_11_channel))




1 channel
Dice overlap:  0.61972314
IoU:  0.45893863
2 channel
Dice overlap:  0.49061936
IoU:  0.3397203
3 channel
Dice overlap:  0.5790177
IoU:  0.4189411
4 channel
Dice overlap:  0.54791534
IoU:  0.38958633
5 channel
Dice overlap:  0.5881931
IoU:  0.42858973
6 channel
Dice overlap:  0.4247808
IoU:  0.27752945
7 channel
Dice overlap:  0.5807479
IoU:  0.42537078
8 channel
Dice overlap:  0.7251035
IoU:  0.5784673
9 channel
Dice overlap:  0.61174107
IoU:  0.44883892
10 channel
Dice overlap:  0.5632103
IoU:  0.40399128
11 channel
Dice overlap:  0.6468386
IoU:  0.49341366


## How to train the model (not recommended if cuda is not available. Our experiments were conducted on the HPC)
- All the experiments were logged using Weights&Biases. This is commented out in the below code

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
#from torchsummary import summary
import torch.optim as optim
from torchvision.transforms import v2
from time import time
import wandb
from model import BaseUnet, BaseUnet3D
from dataset import BrightfieldMicroscopyDataset
from early_stopping import EarlyStopping
from arguments import parse_args, parse_args_3dunet
from evaluation_metrics import dice_overlap, intersection_over_union, accuracy, sensitivity, specificity
from data_processing_tools import remove_repeating_pattern

torch.manual_seed(276)

def get_dataloader(sample_size, batch_size):
 
    image_root = '/zhome/70/5/14854/nobackup/deeplearningf24/forcebiology/data/brightfield'
    mask_root = '/zhome/70/5/14854/nobackup/deeplearningf24/forcebiology/data/masks'

    # image_root = 'data/brightfield'
    # mask_root = 'data/masks'

    transform_train = v2.Compose([
        v2.Resize((sample_size, sample_size)),
        v2.RandomRotation(degrees=15),
        v2.RandomHorizontalFlip(p=0.3),
        v2.ToDtype(torch.float32, scale=True),
        v2.ToTensor(),
    ])

    transform_val = v2.Compose([
        v2.Resize((sample_size, sample_size)),
        v2.ToDtype(torch.float32, scale=True),
        v2.ToTensor(),
    ])

    brightfield_train_datatset = BrightfieldMicroscopyDataset(root_dir_images=image_root, root_dir_labels=mask_root, train=True, transform=transform_train)
    brightfield_val_datatset = BrightfieldMicroscopyDataset(root_dir_images=image_root, root_dir_labels=mask_root, train=False, validation=True, transform=transform_val)
    brightfield_test_datatset = BrightfieldMicroscopyDataset(root_dir_images=image_root, root_dir_labels=mask_root, train=False, validation=False, transform=transform_val)

    brightfield_loader_train = DataLoader(brightfield_train_datatset,  batch_size=batch_size, shuffle=True)
    brightfield_loader_val = DataLoader(brightfield_val_datatset,  batch_size=batch_size, shuffle=True)
    brightfield_loader_test = DataLoader(brightfield_test_datatset,  batch_size=1, shuffle=False)

    return brightfield_loader_train, brightfield_loader_val, brightfield_loader_test

def checkpoint_model(model, optimiser, epoch, path='model.pth'):
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimiser_state_dict': optimiser.state_dict(),
            }, path)

def save_model(model, path='model.pth'):
    torch.save(model.state_dict(), path)

def train_model(model, train_loader, val_loader, test_loader, optimiser, lr_scheduler, criterion, device, args, early_stopping, num_epochs=10):
    # Initialize W&B run
    # wandb.init(
    #     project=args.project_name,         
    #     entity="hndrkjs-danmarks-tekniske-universitet-dtu",           
    #     config={
    #         "epochs": num_epochs,
    #         "learning_rate": optimiser.param_groups[0]['lr'],
    #         "batch_size": args.batch_size,
    #         "model_name": args.model_name,
    #     }
    # )
    
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            images, labels = data

            # remove repeating pattern
            for i in range(images.shape[0]):
                images[i] = torch.tensor(remove_repeating_pattern(images[i].numpy()))

            images, labels = images.to(device), labels.to(device)

            if args.train_3d:
                images = images.unsqueeze(2)
                labels = labels.unsqueeze(1)
         
            optimiser.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimiser.step()
            running_loss += loss.item()

        # Calculate average training loss for the epoch
        avg_train_loss = running_loss / len(train_loader)
        
        # Validation step
        model.eval()
        val_loss = 0.0
        dice = 0
        iou = 0
        acc = 0
        sens = 0
        spec = 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                # remove repeating pattern
                for i in range(images.shape[0]):
                    images[i] = torch.tensor(remove_repeating_pattern(images[i].numpy()))
                
                images, labels = images.to(device), labels.to(device)

                if args.train_3d:
                    images = images.unsqueeze(2)
                    labels = labels.unsqueeze(1)

                outputs = model(images)
                loss = criterion(outputs, labels.float())
                val_loss += loss.item()

                Y_pred = (outputs > 0.45).float()
                dice += dice_overlap(Y_pred, labels)
                iou += intersection_over_union(Y_pred, labels)
                acc += accuracy(Y_pred, labels)
                sens += sensitivity(Y_pred, labels)
                spec += specificity(Y_pred, labels)

                # concatenate y_batch and y_pred to log
                image_array = np.concatenate([labels[0].cpu().numpy(), Y_pred[0].detach().cpu().numpy()], axis=1)

                # images = wandb.Image(image_array, caption="Top: Input, Bottom: Output")
            
            dice /= len(val_loader)
            iou /= len(val_loader)
            acc /= len(val_loader)
            sens /= len(val_loader)
            spec /= len(val_loader)

        avg_val_loss = val_loss / len(val_loader)

        # Adjust learning rate after each epoch
        if args.lr_scheduler:
            lr_scheduler.step()

        # Log metrics to W&B
        # wandb.log({
        #     "epoch": epoch,
        #     "train_loss": avg_train_loss,
        #     "val_loss": avg_val_loss,
        #     "learning_rate": optimiser.param_groups[0]['lr'],
        #     "Dice": dice,
        #     "IoU": iou,
        #     "Accuracy": acc,
        #     "Sensitivity": sens,
        #     "Specificity": spec
        # })
        # wandb.log({"Predicted segmentation": images})

        # Save model checkpoint to W&B
        # if epoch % 10 == 0:
        #     #wandb.save('checkpoint_model_{}_{}.pth'.format(args.project_name, epoch))
        #     checkpoint_model(model, optimiser, epoch, path='checkpoint_model_{}_{}.pth'.format(args.project_name, epoch))

        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break
    
    # Save the final model and upload to W&B
    final_model_path = 'final_model_classification{}.pth'.format(args.model_name)
    save_model(model, path=final_model_path)
    
    # test the model
    model.eval()
    test_loss = 0
    dice = 0
    iou = 0
    acc = 0
    sens = 0
    spec = 0

    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            # remove repeating pattern
            for i in range(images.shape[0]):
                images[i] = torch.tensor(remove_repeating_pattern(images[i].numpy()))
            
            images, labels = images.to(device), labels.to(device)

            if args.train_3d:
                images = images.unsqueeze(2)
                labels = labels.unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels.float())
            test_loss += loss.item()

            Y_pred = (outputs > 0.45).float()
            dice += dice_overlap(Y_pred, labels)
            iou += intersection_over_union(Y_pred, labels)
            acc += accuracy(Y_pred, labels)
            sens += sensitivity(Y_pred, labels)
            spec += specificity(Y_pred, labels)
        
        dice /= len(test_loader)
        iou /= len(test_loader)
        acc /= len(test_loader)
        sens /= len(test_loader)
        spec /= len(test_loader)
    
    # wandb.log({
    #     "test_loss": test_loss / len(test_loader),
    #     "Dice Test": dice,
    #     "IoU Test": iou,
    #     "Accuracy Test": acc,
    #     "Sensitivity Test": sens,
    #     "Specificity Test": spec
    # })

    # # Finish the W&B run
    # wandb.finish()

## Cell to run the training
- This cell calls the training function for the 2D Unet with 11 channels as input. 
- It is generally preferrred to do the training by calling the train.py file.
- Arguments for the model can be specified in the arguments.py file

In [None]:
args = parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.deterministic = True

model = BaseUnet(num_inputs=11)

train_loader, val_loader, test_loader = get_dataloader(args.sample_size, args.batch_size)

criterion = nn.BCEWithLogitsLoss()

# Initialize optimiser and learning rate scheduler
if args.optimiser == 'adam':
    optimiser = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
elif args.optimiser == 'sgd':
    optimiser = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
else:
    raise ValueError('optimiser should be either adam or sgd')
    
lr_scheduler = optim.lr_scheduler.StepLR(optimiser, step_size=args.step_size, gamma=args.gamma)

early_stopping = EarlyStopping(patience=args.patience, delta=args.delta, verbose=False, path='early_stopping_model{}.pth'.format(args.model_name))

# Train the model
train_model(model, train_loader, val_loader, test_loader,
                num_epochs=args.num_epochs, 
                optimiser=optimiser, lr_scheduler=lr_scheduler, 
                criterion=criterion, device=device, args=args, 
                early_stopping=early_stopping)