## Note:
- This Jupyter notebook is an example of SETD model inference.


In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import shutil
import tempfile
import pandas as pd
import torch
from torch.nn import MSELoss
from monai.apps import download_url, download_and_extract
from monai.config import print_config
from monai.data import DataLoader, Dataset, CacheDataset
from monai.losses import BendingEnergyLoss, MultiScaleLoss, DiceLoss
from monai.metrics import DiceMetric
from monai.networks.blocks import Warp
from monai.networks.nets import LocalNet
from monai.transforms import (
    Compose,
    LoadImaged,
    RandAffined,
    Resized,
    ScaleIntensityRanged,
    CropForegroundd,
    RandRotated,
)
from monai.utils import set_determinism, first
import glob
import torch
from resnet50 import generate_model
import torch.nn as nn
# Prefer the first CUDA GPU; fall back to CPU so the notebook still runs without one.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class ResNet50_GRU_Model(nn.Module):
    """Classify a pair of 3-D volumes with a shared ResNet-50 encoder and a GRU head.

    Both input volumes are encoded by the same backbone (``self.resnet50model``,
    built by ``resnet50.generate_model``). The two feature vectors are
    concatenated, repeated ``seq_len`` times along a new time axis and fed to a
    GRU. The GRU output at the last time step is batch-normalised and projected
    to ``output_size`` raw scores (no softmax is applied here).

    Backbone parameters are frozen, except those whose name contains one of
    ``_TRAINABLE_KEYWORDS``.

    Args:
        input_size: GRU input width. Must equal the width of the concatenated
            backbone features (presumably twice the backbone's output
            channels -- TODO confirm against ``generate_model``).
        hidden_size: GRU hidden width, also the BatchNorm/Linear input width.
        output_size: number of output scores per sample.
        num_layers: number of stacked GRU layers.
        dropout: dropout value passed to ``nn.GRU``.
        seq_len: how many times the fused feature vector is repeated as a sequence.
        pretrain_path: backbone checkpoint handed to ``generate_model``.
        input_W, input_H, input_D: backbone input size handed to
            ``generate_model``. The defaults reproduce the previously
            hard-coded 96 x 96 x 14.
    """

    # Substrings of backbone parameter names that remain trainable.
    _TRAINABLE_KEYWORDS = ("conv_seg", "reduce_channels", "reduce_bn", "reduce_relu")

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        num_layers=2,
        dropout=0.2,
        seq_len=128,
        pretrain_path='./path/to/resnet_50_23dataset.pth',
        input_W=96,
        input_H=96,
        input_D=14,
    ):
        super().__init__()
        self.resnet50model = generate_model(
            input_W=input_W,
            input_H=input_H,
            input_D=input_D,
            pretrain_path=pretrain_path,
            pretrained=True,
        )
        # Freeze the pretrained backbone apart from the explicitly named layers.
        for name, param in self.resnet50model.named_parameters():
            if not any(keyword in name for keyword in self._TRAINABLE_KEYWORDS):
                param.requires_grad = False
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True)
        self.seq_len = seq_len
        self.bn = nn.BatchNorm1d(hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    @staticmethod
    def _squeeze_spatial(features):
        """Remove the last three axes of a backbone output when they have size 1."""
        # squeeze(dim) leaves non-singleton axes untouched, exactly like the
        # original chained call; the trailing axes are assumed to be pooled
        # spatial axes (TODO confirm against generate_model).
        return features.squeeze(-3).squeeze(-2).squeeze(-1)

    def forward(self, image2, image3):
        """Return ``output_size`` scores per sample for the pair ``(image2, image3)``.

        Both inputs go through the same backbone. Their expected layout is
        whatever ``generate_model`` consumes (presumably (B, C, D, H, W) --
        TODO confirm).
        """
        feat2 = self._squeeze_spatial(self.resnet50model(image2))
        feat3 = self._squeeze_spatial(self.resnet50model(image3))
        fused = torch.cat([feat2, feat3], dim=1)
        # Present the fused vector to the GRU as a constant sequence of length seq_len.
        sequence = fused.unsqueeze(1).repeat(1, self.seq_len, 1)
        output, _ = self.gru(sequence)
        last_step = output[:, -1, :]
        return self.fc(self.bn(last_step))

In [None]:
import SimpleITK as sitk
def maskcroppingbox(data, mask, use2D=False):
    """Crop ``data`` to the bounding box of the non-zero voxels of ``mask``.

    Inside the box, voxels where ``mask < 1`` are set to 0, so only the masked
    region keeps its intensities.

    Args:
        data: 3-D array. Indices are unpacked as (z, y, x), matching the array
            order of ``sitk.GetArrayFromImage``.
        mask: array expected to have the same shape as ``data``. Any non-zero
            voxel counts towards the bounding box.
        use2D: unused; kept for backward compatibility.

    Returns:
        A new array holding the cropped, masked region. ``data`` itself is
        not modified.

    Raises:
        ValueError: if ``mask`` contains no non-zero voxels.
    """
    coords = np.argwhere(mask)
    if coords.size == 0:
        raise ValueError("mask has no non-zero voxels; cannot compute a bounding box")
    # NOTE(review): the start is padded by one voxel, but the (exclusive) stop is
    # max + 1, so the high side gets no margin. Left unchanged to preserve the
    # preprocessing the model was presumably trained with.
    (zstart, ystart, xstart), (zstop, ystop, xstop) = coords.min(axis=0) - 1, coords.max(axis=0) + 1
    zstart = max(0, zstart)
    ystart = max(0, ystart)
    xstart = max(0, xstart)
    zstop = min(data.shape[0], zstop)
    ystop = min(data.shape[1], ystop)
    xstop = min(data.shape[2], xstop)
    box = (slice(zstart, zstop), slice(ystart, ystop), slice(xstart, xstop))
    # Copy: basic slicing returns a view, and zeroing it would write into `data`.
    roi_image = data[box].copy()
    roi_mask = mask[box]
    roi_image[roi_mask < 1] = 0
    return roi_image

# Image volumes and segmentation masks for the 'artirial' and 'nephrogenic' inputs.
artirial_imgpath = './path/to/3T_artirialorg.nii.gz'
nephrogenic_imgpath = './path/to/3T_nephrogenicorg.nii.gz'
artirial_maskpath = './path/to/3T_artirialseg.nii.gz'
nephrogenic_maskpath = './path/to/3T_nephrogenicseg.nii.gz'


def _read_volume(path):
    """Read an image file with SimpleITK and return its voxel array."""
    image = sitk.ReadImage(path)
    return sitk.GetArrayFromImage(image)


# Load every volume/mask as a NumPy array, in the same order as before.
artirial_img = _read_volume(artirial_imgpath)
artirial_mask = _read_volume(artirial_maskpath)
nephrogenic_img = _read_volume(nephrogenic_imgpath)
nephrogenic_mask = _read_volume(nephrogenic_maskpath)

# Crop each image to its mask's bounding box, zeroing voxels outside the mask.
artirial_croppedimg = maskcroppingbox(artirial_img, artirial_mask)
nephrogenic_croppedimg = maskcroppingbox(nephrogenic_img, nephrogenic_mask)

In [None]:
# Write both cropped arrays to .npy files.
np.save(file="./path/to/artirial_croppedimg.npy", arr=artirial_croppedimg)
np.save(file="./path/to/nephrogenic_croppedimg.npy", arr=nephrogenic_croppedimg)

In [None]:
# Paths of the cropped volumes fed to the model. They must be two *different*
# files (the artirial and nephrogenic crops written with np.save), not one
# shared image for both inputs.
artirial_croppedimg_path = './path/to/artirial_croppedimg.npy'  # Provide the file path or PathLike object for the artirial cropped image
nephrogenic_croppedimg_path = './path/to/nephrogenic_croppedimg.npy'  # Provide the file path or PathLike object for the nephrogenic cropped image

In [None]:
# Dictionary keys under which the two input volumes travel through the pipeline.
_phase_keys = ["image2", "image3"]

# Preprocessing:
#   - load each file with a leading channel axis;
#   - map [0, 250] onto [0, 250] with clipping (identical ranges, so this
#     effectively just clips the intensities);
#   - resample to a fixed spatial size.
val_external_transforms = Compose(
    [
        LoadImaged(keys=_phase_keys, ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=_phase_keys, a_min=0, a_max=250, b_min=0, b_max=250, clip=True
        ),
        Resized(
            keys=_phase_keys, mode="trilinear", align_corners=True, spatial_size=(96, 96, 14)
        ),
    ]
)

input_data = {"image2": artirial_croppedimg_path, "image3": nephrogenic_croppedimg_path}
input_data = val_external_transforms(input_data)

# Turn each transformed volume into a tensor on `device` with a batch axis of size 1.
image2, image3 = (
    torch.from_numpy(np.array(input_data[key])).to(device).unsqueeze(0)
    for key in _phase_keys
)

In [None]:
import pandas as pd
# Run on the first CUDA GPU when available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet50_GRU_Model(input_size=256, hidden_size=128, output_size=2).to(device)
model_path = "./path/to/best_metric_model_classification3d_array.pth"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
with torch.no_grad():
    # Inputs may have been created on another device in an earlier cell.
    output = model(image2.to(device), image3.to(device))
    # Predicted class index per sample, and softmax probabilities per class.
    output_list = output.argmax(dim=1).cpu().numpy().tolist()
    proba = torch.nn.functional.softmax(output, dim=1).cpu().numpy().tolist()
print(output_list)
print(proba)