In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import torchvision.transforms.v2 as v2
from torchvision.transforms import transforms
import os 
import sys
import glob
from scipy import fftpack, ndimage
import re
from skimage.morphology import opening, disk
from skimage import img_as_float
from torch.utils.data import DataLoader
from model import BaseUnet, BaseUnet3D
from data_processing_tools import remove_repeating_pattern
from evaluation_metrics import dice_overlap, intersection_over_union, accuracy, sensitivity, specificity
import timeit
from dataset import BrightfieldMicroscopyDataset

## Loading the trained models

In [None]:
def load_eval_model(model, checkpoint_path, map_location=torch.device('cpu')):
    """Load a saved state dict into ``model`` and switch it to evaluation mode.

    Parameters
    ----------
    model : torch.nn.Module
        Freshly constructed network whose architecture matches the checkpoint.
    checkpoint_path : str
        Path of the checkpoint file passed to ``torch.load``.
    map_location : torch.device
        Device the stored tensors are mapped onto (CPU by default, so the
        notebook also runs on machines without a GPU).

    Returns
    -------
    torch.nn.Module
        The same ``model`` object, holding the loaded weights, in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    # On recent PyTorch versions consider passing weights_only=True.
    model.load_state_dict(torch.load(checkpoint_path, map_location=map_location))
    model.eval()
    return model


# Channel-ablation models: BaseUnet(num_inputs=k) for k = 1..11, saved as
# early_stopping_modelablation_channel_base_unet_<k>.pth. Uncomment to load them.
# channel_ablation_unets = {
#     k: load_eval_model(BaseUnet(num_inputs=k), f'early_stopping_modelablation_channel_base_unet_{k}.pth')
#     for k in range(1, 12)
# }

# 2D and 3D U-Nets trained without / with the repeating-pattern preprocessing.
base_unet_no_preprocessing = load_eval_model(
    BaseUnet(), 'early_stopping_modelUnet_simple_11_nopreprocess.pth')
base_unet_no_preprocessing_3d = load_eval_model(
    BaseUnet3D(), 'early_stopping_modelUnet3D_simple_11_no_preprocess.pth')
base_unet_preprocessing = load_eval_model(
    BaseUnet(), 'early_stopping_modelUnet_simple_11_preprocess.pth')
base_unet_preprocessing_3d = load_eval_model(
    BaseUnet3D(), 'early_stopping_modelUnet3D_simple_11_preprocess.pth')

## Test DataLoader and evaluation function

In [6]:
def get_dataloader_test(sample_size=512, batch_size=1, channels=None,
                        image_root='data/brightfield', mask_root='data/masks'):
    """Build a DataLoader over the brightfield test split (train=False, validation=False).

    Parameters
    ----------
    sample_size : int
        Images are resized to ``(sample_size, sample_size)``.
    batch_size : int
        Batch size of the returned loader.
    channels : list of int, optional
        Passed to the dataset as ``channels_to_use``. ``None`` (the default)
        means ``[0, 1, ..., 10]``, the previous hard-coded default.
    image_root, mask_root : str
        Directories of the brightfield images and masks. The cluster paths
        used previously were
        '/zhome/70/5/14854/nobackup/deeplearningf24/forcebiology/data/brightfield'
        and '.../data/masks'.

    Returns
    -------
    torch.utils.data.DataLoader
        Unshuffled loader over ``BrightfieldMicroscopyDataset``.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if channels is None:
        channels = list(range(11))

    # NOTE(review): v2.ToDtype only converts tensor inputs; if the dataset yields
    # PIL images it is presumably a pass-through here and the (deprecated) v2.ToTensor
    # does the float conversion. The torchvision-recommended replacement is
    # v2.ToImage() followed by v2.ToDtype(torch.float32, scale=True) — confirm the
    # dataset's output type before switching.
    transform_test = v2.Compose([
        v2.Resize((sample_size, sample_size)),
        v2.ToDtype(torch.float32, scale=True),
        v2.ToTensor(),
    ])

    test_dataset = BrightfieldMicroscopyDataset(
        root_dir_images=image_root,
        root_dir_labels=mask_root,
        train=False,
        validation=False,
        transform=transform_test,
        channels_to_use=channels,
    )

    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
def test_model(model, test_loader, criterion, device, preprocessing=False, test_3d=False):
    test_loss = []
    dice = []
    iou = []
    acc = []
    sens = []
    spec = []

    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            # remove repeating pattern
            if preprocessing:
                for i in range(images.shape[0]):
                    images[i] = torch.tensor(remove_repeating_pattern(images[i].numpy()))

            images, labels = images.to(device), labels.to(device)

            if test_3d:
                images = images.unsqueeze(2)
                labels = labels.unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels.float())
            test_loss.append(loss.item())

            Y_pred = (outputs > 0.45).float()
            
            dice.append(dice_overlap(Y_pred, labels))
            iou.append(intersection_over_union(Y_pred, labels))
            acc.append(accuracy(Y_pred, labels))
            sens.append(sensitivity(Y_pred, labels))
            spec.append(specificity(Y_pred, labels))

    return test_loss, dice, iou, acc, sens, spec
        

## Testing the 2D and 3D U-Nets with and without preprocessing

In [None]:
# Evaluate the 3D U-Net trained without preprocessing on all 11 channels (CPU).
(
    test_loss_3d_no_preprocess,
    dice_3d_no_preprocess,
    iou_3d_no_preprocess,
    acc_3d_no_preprocess,
    sens_3d_no_preprocess,
    spec_3d_no_preprocess,
) = test_model(
    base_unet_no_preprocessing_3d,
    get_dataloader_test(channels=list(range(11))),
    nn.BCEWithLogitsLoss(),
    torch.device('cpu'),
    preprocessing=False,
    test_3d=True,
)