In [None]:
import warnings
from typing import List, Tuple, Dict, Any, Union, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.nets import UNet

import os, sys

#TODO: fix this, this is temporary
# NOTE(review): hard-coded, machine-specific absolute path. It appends the `src`
# directory itself, while the imports below use the `src.` package prefix (which
# needs the directory *containing* `src` on sys.path). So this line is presumably
# not what makes those imports resolve. Confirm, then replace it with a path
# relative to the repository root (or install the package).
sys.path.append(os.path.abspath('M:/Users/netan/Desktop/Coding Folder/Openu-DS-Workshop/src'))

import src.config as cfg
from src.loaders import load_dicom_series

In [65]:
# The backbone hyperparameters should not be changed: they have to match the
# pretrained checkpoint loaded below.
# The model is a MONAI UNet; this class turns it into a volume-level classifier.
class UNetClassifier(nn.Module):
    """Classify a 3D volume using a truncated, pretrained MONAI UNet as backbone.

    The last element of ``UNet.model`` (the segmentation head) is removed. The
    remaining feature maps are global-average-pooled, and a linear layer maps the
    pooled vector to ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        weights_path: checkpoint loaded into the UNet backbone. ``None`` (the
            default) means ``cfg.MODEL_PATH``.
    """

    def __init__(self, num_classes=25, weights_path=None):
        super().__init__()

        # Backbone channel widths (shared with the classifier width below).
        channels = (16, 32, 64, 128, 128)

        # Normal 3D UNet model
        self.unet = UNet(
            spatial_dims=3,  # 3D UNet
            in_channels=1,
            out_channels=128,  # NOTE(review): presumably only sizes the head removed below
            channels=channels,
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )

        # Loads the weights of the 'Wholebrainseg large unest' model.
        # Change cfg.MODEL_PATH to match your path.
        # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
        # machine; load_state_dict copies into the (CPU) parameters either way.
        if weights_path is None:
            weights_path = cfg.MODEL_PATH
        state_dict = torch.load(weights_path, map_location="cpu")

        # strict=False silently skips keys that don't match. Surface the mismatch,
        # so that a checkpoint from a different architecture does not leave the
        # backbone randomly initialised without anyone noticing.
        load_result = self.unet.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {weights_path} does not match the UNet backbone exactly: "
                f"{len(load_result.missing_keys)} missing keys (left randomly "
                f"initialised) and {len(load_result.unexpected_keys)} unexpected keys "
                "(ignored).",
                stacklevel=2,
            )

        # Remove the head of the segmentation network (last element of the top-level block)
        self.unet.model = self.unet.model[:-1]

        # Global Average Pooling to convert feature maps to a vector
        self.global_avg_pool = nn.AdaptiveAvgPool3d(1)

        # 2 * channels[0] == 32, the same width as before.
        # NOTE(review): assumes the truncated UNet ends in a skip connection that
        # concatenates two channels[0]-wide feature maps. Confirm if `channels` changes.
        self.fc = nn.Linear(2 * channels[0], num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is presumably ``(batch, 1, D, H, W)``: the backbone is a 3D UNet with
        a single input channel.
        """
        features = self.unet(x)  # truncated backbone feature maps
        pooled = self.global_avg_pool(features)  # (batch, C, 1, 1, 1)
        # Flatten the pooled output and pass it through the classifier
        return self.fc(torch.flatten(pooled, 1))

In [66]:
# Uses trilinear interpolation to resize the image
def resize_image_stack(image_stack: np.ndarray, new_size: Tuple[int, int, int]) -> torch.Tensor:
    """Resize a 3D volume with trilinear interpolation.

    Args:
        image_stack: 3D array of any numeric dtype (e.g. stacked DICOM pixel data).
        new_size: target spatial size, one entry per axis of ``image_stack``.

    Returns:
        float32 tensor of shape ``(1, 1, *new_size)``. Batch and channel
        dimensions are added because ``F.interpolate`` needs a 5D input in
        trilinear mode.

    Raises:
        ValueError: if ``image_stack`` is not 3-dimensional.
    """
    # Convert to float32 explicitly: integer dtypes such as uint16 (common for
    # DICOM pixel data) may be rejected by torch's tensor constructors on older
    # versions, and interpolation needs floating point anyway.
    volume = np.asarray(image_stack, dtype=np.float32)
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D image stack, got shape {volume.shape}")

    image_tensor = torch.tensor(volume)[None, None]  # (D, H, W) -> (1, 1, D, H, W)
    return F.interpolate(image_tensor, size=new_size, mode='trilinear', align_corners=False)

In [None]:
# Smoke test: push the configured example series (cfg.EXAMPLE_STIR_ID) through the
# classifier. nn.Module.eval() returns the module itself, so it can be chained.
model = UNetClassifier(num_classes=25).eval()

# Read the DICOM series from disk
image3d = load_dicom_series(cfg.DATA_PATH / cfg.EXAMPLE_STIR_ID)

# Resample to 96x96x96, the input size the model was trained on
resized_image_tensor = resize_image_stack(image3d, new_size=(96, 96, 96))

# TODO: placeholder inference just to exercise the model; replace with the real pipeline.
with torch.no_grad():
    output = model(resized_image_tensor)

# Index of the highest logit for each item in the batch, as a NumPy array
predicted_class = output.argmax(dim=1)
predicted_class_np = predicted_class.cpu().numpy()

print("Predicted Class:", predicted_class_np)