In [1]:
# Import PyTorch Data Loader Library
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.models.segmentation import fcn_resnet50

# Other Library Imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import SimpleITK as sitk

# Custom Imports
from DRAC_Dataloading import DRAC_Loader
from DRAC_Models import ResNetBinary
from DRAC_Training import train_model
from DRAC_Testing import test_model
from DRAC_Criterion import DiceLoss

# Warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.get_device_name(0))

cuda:0
NVIDIA RTX A5000


In [3]:
batch_size = 1

In [4]:
# Load in the custom models
name = os.listdir("./models/")
name

['fcn_resnet50_3_250_DiceLoss',
 'fcn_resnet50_2_100_CrossEntropyLoss',
 'fcn_resnet50_2_50_CrossEntropyLoss']

In [5]:
file = "./models/" + name[0] + "/"
models = os.listdir(file)
models = [file + model for model in models]

# If any item in the list is a directory, remove it
models = [model for model in models if not os.path.isdir(model)]

# Reorder the list so that the .pth files are in the order Intraretinal.pth, Neovascularization.pth, Nonperfusion.pth


models

['./models/fcn_resnet50_3_250_DiceLoss/Intraretinal.pth',
 './models/fcn_resnet50_3_250_DiceLoss/Nonperfusion.pth',
 './models/fcn_resnet50_3_250_DiceLoss/Neovascularization.pth']

In [6]:
# Within file create folder called "Predictions"
if not os.path.exists(file + "Predictions"):
    os.makedirs(file + "Predictions")

# Set another location
predictions = file + "Predictions/"

In [7]:
test_data_intra = DRAC_Loader(data_type = 'test', transform = None, mask = "intraretinal")
test_loader_intra = DataLoader(test_data_intra, batch_size = batch_size, shuffle = True)
test_data_neo = DRAC_Loader(data_type = 'test', transform = None, mask = "neovascular")
test_loader_neo = DataLoader(test_data_neo, batch_size = batch_size, shuffle = True)
test_data_nonper = DRAC_Loader(data_type = 'test', transform = None, mask = "nonperfusion")
test_loader_nonper = DataLoader(test_data_nonper, batch_size = batch_size, shuffle = True)

# Create the list of Data Loaders
test_loaders = [test_loader_intra, test_loader_neo, test_loader_nonper]

In [8]:
count = 0

for model1 in models:
    # Load in the model 
    model = fcn_resnet50(num_classes=2)
    model.load_state_dict(torch.load(model1))
    model.eval()
    model.to(device)
    
    # Test the model, run if statements...
    if "Intraretinal" in model1:
        loader = test_loaders[0]
        loader_name = "Intraretinal"
        nii_file = "1.nii.gz"
    elif "Neovascularization" in model1:
        loader = test_loaders[1]
        loader_name = "Neovascularization"
        nii_file = "3.nii.gz"
    elif "Nonperfusion" in model1:
        loader = test_loaders[2]
        loader_name = "Nonperfusion"
        nii_file = "2.nii.gz"
    else:
        print("Error")
        break
    
    # Create an empty array to store the predicted masks
    predicted_masks = np.zeros((len(loader), 1024, 1024))
    
    for i, (inputs, name) in enumerate(loader):
        # Permute
        inputs = inputs.permute(0, 3, 1, 2)
        
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs['out'], 1)
        preds = preds.cpu().numpy()
        
        # Store the predicted mask in the array
        predicted_masks[i] = preds[0]
    print(predicted_masks.shape)    
    predicted_masks_sitk = sitk.GetImageFromArray(predicted_masks)
    sitk.WriteImage(predicted_masks_sitk, f"{predictions}{nii_file}")
    
    saveimages = False
    if saveimages:
        for inputs, name in loader:
            # Permute
            inputs = inputs.permute(0, 3, 1, 2)
            
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs['out'], 1)
            preds = preds.cpu().numpy()
            
            if not os.path.exists(f"{predictions}{loader_name}"):
                os.makedirs(f"{predictions}{loader_name}")
            
            for i in range(len(preds)):
                # Set the name of the file
                name = f"{predictions}{loader_name}/{name[0]}.png"
                
                plt.imsave(f"{name}", preds[i], cmap = 'gray')
                count += 1

(65, 1024, 1024)
(65, 1024, 1024)
(65, 1024, 1024)


In [9]:
# Load "./models/fcn_resnet50_2_50_CrossEntropyLoss/Predictions/1.nii.gz"
image = sitk.ReadImage("./models/fcn_resnet50_3_250_DiceLoss/Predictions/1.nii.gz")

In [10]:
# Get the image dimensions
size = image.GetSize()
print("Image size:", size)

# Get the voxel spacing
spacing = image.GetSpacing()
print("Voxel spacing:", spacing)

# Get the image origin
origin = image.GetOrigin()
print("Image origin:", origin)

# Access the image data as a numpy array
image_array = sitk.GetArrayFromImage(image)
print("Image array shape:", image_array.shape)

Image size: (1024, 1024, 65)
Voxel spacing: (1.0, 1.0, 1.0)
Image origin: (0.0, 0.0, 0.0)
Image array shape: (65, 1024, 1024)
