# Model Inference Creation

In [None]:
## Importing Libs

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    DivisiblePadd,
    RandAffined,
    RandRotated,
    RandGaussianNoised,
    ToTensor,
    Resized,
    FillHolesd,
    RemoveSmallObjectsd
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference, SimpleInferer
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
from torch.utils.data import ConcatDataset
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from datetime import datetime
import nibabel as nib
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
## Getting the dataset

# IMPORTANT: point this at the directory holding the test volumes.
# The default is the directory layout used on the cluster; it can be overridden
# without editing the notebook by exporting TEST_IMAGES_DIR before starting the kernel.
test_images_dir = os.environ.get(
    "TEST_IMAGES_DIR", "/tsi/data_education/data_challenge/test/volume"
)

# All files matching *.nii* (e.g. .nii and .nii.gz), sorted so the volumes are
# always processed in the same order.
test_images = sorted(glob.glob(os.path.join(test_images_dir, "*.nii*")))
if not test_images:
    # Without this, an empty list would make the inference cell silently do nothing.
    print(f"WARNING: no NIfTI volumes found in {test_images_dir!r}")

# One dict per volume, keyed the way the dictionary transforms below expect.
test_data = [{"image": image} for image in test_images]

In [None]:
## TRANSFORMS - PreProcessing & PostProcessing

# IMPORTANT: change PREDICTIONS_DIR to the place where predictions should be saved.
PREDICTIONS_DIR = "./predictionsSWIN"
# Spatial size every volume is resized to before it is fed to the network.
SPATIAL_SIZE = (192, 192, 192)

# Pre-processing applied to every test volume before inference.
testD_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        CropForegroundd(keys=["image"], source_key="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        DivisiblePadd(keys=["image"], k=16),
        Resized(keys=["image"], spatial_size=SPATIAL_SIZE),
    ]
)

# Post-processing applied to each prediction after inference.
post_transforms = Compose(
    [
        # Undo testD_transforms on "pred", using the operations recorded on "image",
        # so the prediction is mapped back onto the original image grid.
        # nearest_interp=False: interpolate the raw (pre-argmax) network output;
        # the discrete label map is only taken afterwards.
        Invertd(
            keys="pred",
            transform=testD_transforms,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        # Channel-wise argmax -> one label index per voxel.
        AsDiscreted(keys="pred", argmax=True),
        FillHolesd(keys="pred"),
        # resample=False: the prediction is already on the original grid after Invertd.
        SaveImaged(
            keys="pred",
            meta_keys="pred_meta_dict",
            output_dir=PREDICTIONS_DIR,
            output_postfix="",
            resample=False,
        ),
    ]
)

In [None]:
## Creating the dataset

# Wrap the file list in a MONAI Dataset that runs testD_transforms on access,
# and iterate over it one volume per batch.
test_org_ds = Dataset(transform=testD_transforms, data=test_data)
test_org_loader = DataLoader(
    test_org_ds,
    batch_size=1,
)

In [None]:
## Loading the model

# IMPORTANT: change MODEL_PATH to the place where the trained model is stored.
MODEL_PATH = "./modelSWIN.h5"

# The file holds a whole pickled model (the result is used directly as a module),
# not a state_dict:
# - map_location=device lets a model saved on a GPU be loaded on a CPU-only machine;
#   without it torch.load fails before .to(device) is ever reached.
# - weights_only=False is needed to unpickle a whole module on PyTorch >= 2.6, where
#   the default became weights_only=True (the argument exists since PyTorch 1.13).
# SECURITY: unpickling can execute arbitrary code; only load checkpoint files you trust.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False).to(device)

In [None]:
## Starting the inference

inferer = SimpleInferer()
model.eval()

# The loop variable is deliberately NOT called `test_data`. Reusing that name would
# overwrite the file list built in the "Getting the dataset" cell, so re-running the
# "Creating the dataset" cell afterwards would build a Dataset from the wrong data.
with torch.no_grad():
    for counter, batch in enumerate(test_org_loader, start=1):
        test_inputs = batch["image"].to(device)
        # Presumably per-class scores (an argmax is taken in post_transforms).
        batch["pred"] = inferer(test_inputs, model)
        print(f"predicting {counter}")

        # Split the batch into per-volume dicts and post-process each one;
        # the SaveImaged step inside post_transforms writes the prediction to disk.
        test_outputs = [post_transforms(item) for item in decollate_batch(batch)]