In [1]:
# Work from the project root so that the relative imports (`scripts.*`, `models.*`)
# and the relative `data/...` paths used in the following cells resolve.
# NOTE(review): absolute, machine-specific path — adjust it when running elsewhere.
%cd "/home/albin/skolarbete/DML_LAsegmentation"

/home/albin/skolarbete/DML_LAsegmentation


In [2]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import SimpleITK as sitk
import torch
from monai.networks.nets import UNETR, SwinUNETR

from models.unet import UNet3D, NormalizationType
from scripts.dataset import TestDataset
from scripts.train import patched_forward

# Global plot styling: white-grid theme first, then override the colour cycle
# with a fixed three-colour palette.
custom_palette = ["#D32F2F", "#1976D2", "#4CAF50"]  # red, blue, light green
sns.set_theme(style="whitegrid")
sns.set_palette(custom_palette)

In [3]:
# Directories holding the validation images and their label volumes.
val_image_path = 'data/Task02_Heart/imagesVl'
val_label_path = 'data/Task02_Heart/labelsVl'

# Patch extent handed to patched_forward in the inference cell.
patch_size = (64, 128, 128)

val_dataset = TestDataset(val_image_path, val_label_path, scale_intensity=True)

# One volume per batch, in fixed (unshuffled) order.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [4]:
def _load_model_results(file_name, model):
    """Load a pickled results file and restore its weights into ``model``.

    The pickle is expected to be a dict with the keys ``'model_state'``,
    ``'train_metrics'`` and ``'val_metrics'``.

    Args:
        file_name: Path of the pickle file to read.
        model: Freshly constructed network whose architecture matches the
            stored state dict; it is updated in place via ``load_state_dict``.

    Returns:
        ``(data, results)`` where ``data`` is the raw unpickled dict and
        ``results`` is the tuple ``(model, train_metrics, val_metrics)``.

    Raises:
        Whatever ``open``, ``pickle.load`` or ``load_state_dict`` raise, e.g. for
        a missing file, a missing key or a state dict that does not fit ``model``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only open result files that you produced yourself or otherwise trust.
    with open(file_name, 'rb') as f:
        data = pickle.load(f)
    # The file handle is closed here; everything below only needs `data`.
    model.load_state_dict(data['model_state'])
    return data, (model, data['train_metrics'], data['val_metrics'])


# Load UNet3D model results
file_name_unet3d = 'unet_model_results.pkl'
data_unet3d, results_unet3d = _load_model_results(
    file_name_unet3d,
    UNet3D(in_channels=1, out_channels=1, features=[32, 64, 128, 256], normalization=NormalizationType.GROUP_NORM),
)
best_model_unet3d, results_train_unet3d, results_val_unet3d = results_unet3d

# Load UNETR model results
file_name_unetr = 'unetR_model_results.pkl'
data_unetr, results_unetr = _load_model_results(
    file_name_unetr,
    UNETR(in_channels=1, out_channels=1, img_size=(64, 128, 128)),
)
best_model_unetr, results_train_unetr, results_val_unetr = results_unetr

# Load SwinUNETR model results
file_name_swinunetr = 'swinUnetR_model_results.pkl'
data_swinunetr, results_swinunetr = _load_model_results(
    file_name_swinunetr,
    SwinUNETR(in_channels=1, out_channels=1, img_size=(64, 128, 128)),
)
best_model_swinunetr, results_train_swinunetr, results_val_swinunetr = results_swinunetr



In [5]:
# Device used for inference in the next cell.
device = 'cuda'

# Directories the test dataset is built from.
test_image_path = 'data/Task02_Heart/imagesVl'
test_label_path = 'data/Task02_Heart/labelsVl'
test_dataset = TestDataset(test_image_path, test_label_path, scale_intensity=True)

# Look up one specific case by its file path and fetch its (image, label) pair.
image_idx = test_dataset.data_paths.index('data/Task02_Heart/imagesVl/la_030.nii')
image, label = test_dataset[image_idx]

# Prepend a length-1 axis at dim 0 before the image is passed to the models.
image = image.unsqueeze(0)

In [6]:
# Move the input volume and all three networks onto the inference device.
image = image.to(device)

# .eval() puts each network into inference mode so training-only behaviour
# (e.g. dropout) is disabled. The visible code never sets it, and it is a
# harmless no-op if patched_forward already does so.
best_model_unet3d = best_model_unet3d.to(device).eval()
best_model_unetr = best_model_unetr.to(device).eval()
best_model_swinunetr = best_model_swinunetr.to(device).eval()

# No gradients are needed for prediction; results are moved back to the CPU
# for the NumPy post-processing in the following cells.
with torch.no_grad():
    pred_unet = patched_forward(best_model_unet3d, image, patch_size, overlap=0.5, device=device).cpu()
    pred_unetr = patched_forward(best_model_unetr, image, patch_size, overlap=0.5, device=device).cpu()
    pred_swinunetr = patched_forward(best_model_swinunetr, image, patch_size, overlap=0.5, device=device).cpu()

  ret = func(*args, **kwargs)


In [7]:
# Pass each raw model output through a sigmoid, drop all singleton axes and
# convert the result to a NumPy array.
pred_unet_sigmoid = pred_unet.sigmoid().squeeze().numpy()
pred_unetr_sigmoid = pred_unetr.sigmoid().squeeze().numpy()
pred_swinunetr_sigmoid = pred_swinunetr.sigmoid().squeeze().numpy()

In [20]:
# Binarise the sigmoid outputs: values strictly above 0.5 become 1, everything
# else 0 (stored with NumPy's default integer dtype).
pred_unet_sigmoid_binary = np.greater(pred_unet_sigmoid, 0.5).astype(int)
pred_unetr_sigmoid_binary = np.greater(pred_unetr_sigmoid, 0.5).astype(int)
pred_swinunetr_sigmoid_binary = np.greater(pred_swinunetr_sigmoid, 0.5).astype(int)


In [22]:
# Save an array as an image file, taking the spatial metadata from a reference file
def save_tensor_as_nii_with_metadata(array, filename, reference_file):
    """Write ``array`` to ``filename`` using the geometry of ``reference_file``.

    The array is converted to a SimpleITK image, the reference image's origin,
    spacing and direction are copied onto it, and the result is written to disk.
    The output format is chosen by SimpleITK from the extension of ``filename``
    (e.g. ``.nii`` or ``.nii.gz``).

    Args:
        array: Array-like accepted by ``sitk.GetImageFromArray``. Its shape must
            correspond to the size of the reference image (SimpleITK reverses
            the NumPy axis order, so a ``(z, y, x)`` array gives an ``(x, y, z)``
            image).
        filename: Destination path. Missing parent directories are created.
        reference_file: Path of an existing image whose metadata is copied.

    Raises:
        An exception from SimpleITK if the reference file cannot be read, if
        the array's size/dimension does not match the reference image, or if
        the file cannot be written.
    """
    # Load the reference image
    reference_image = sitk.ReadImage(reference_file)

    # Create a new image with the provided array
    new_image = sitk.GetImageFromArray(array)

    # Copy metadata from the reference image (fails if the sizes differ)
    new_image.CopyInformation(reference_image)

    # Make sure the target directory exists; the write would otherwise fail
    # on a fresh checkout where e.g. ./output has not been created yet.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    # Save the new image
    sitk.WriteImage(new_image, filename)

# Write each binary prediction to disk, reusing the geometry of the source image
# the predictions were computed from.
reference_file = 'data/Task02_Heart/imagesVl/la_030.nii'
predictions_to_save = (
    (pred_unet_sigmoid_binary, r'./output/pred_unet_030.nii'),
    (pred_unetr_sigmoid_binary, r'./output/pred_unetr_030.nii'),
    (pred_swinunetr_sigmoid_binary, r'./output/pred_swinunetr_030.nii'),
)
for binary_mask, output_file in predictions_to_save:
    save_tensor_as_nii_with_metadata(binary_mask, output_file, reference_file)