# Test the post-processing transform that pads the network output back to the original image size

In [None]:
import os
import sys
import tempfile
from glob import glob
import logging

import nibabel as nib
import numpy as np
import torch
from matplotlib import pyplot as plt
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
from ignite.handlers import ModelCheckpoint
from torch.utils.data import DataLoader

import monai
from monai.data import NiftiDataset, list_data_collate
from monai.transforms import (
    Activationsd,
    AddChanneld,
    NormalizeIntensityd,
    AsDiscreted,
    Resized,
    Compose,
    KeepLargestConnectedComponentd,
    LoadNiftid,
    RandCropByPosNegLabeld,
    RandRotated,
    RandFlipd,
    ToTensord,
    MapTransform,
    CropForegroundd,
    SpatialCrop
)
from monai.utils import set_determinism

# from ipynb.fs.full.io_utils import create_data_list
sys.path.append("/mnt/data/mranzini/Desktop/GIFT-Surg/FBS_Monai/basic_unet_monai/src/")
from io_utils import create_data_list
from custom_transform import ConverToOneHotd, MinimumPadd, CropForegroundAnisotropicMargind, PadToOriginalSized

monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# GPU index to select. Only applied when that device actually exists, so the
# notebook still runs on CPU-only machines or hosts with fewer GPUs instead of
# failing inside torch.cuda.set_device.
cuda_device = 2
if torch.cuda.is_available() and cuda_device < torch.cuda.device_count():
    torch.cuda.set_device(cuda_device)
else:
    logging.warning("CUDA device %d not available; not calling torch.cuda.set_device", cuda_device)
set_determinism(seed=0)

In [None]:
root_dir = "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset"
# The same label volume is used as image, segmentation and mask, presumably so
# that cropping 'img' and padding it back can be compared against the
# untouched 'seg'.
label_path = os.path.join(root_dir, "GroupA", "a01_02_Label.nii.gz")
val_files = [{'img': label_path, 'seg': label_path, 'mask': label_path}]

# Data preprocessing for inference:
# - load every volume and add a leading channel dimension; the DataLoader then
#   adds the batch dimension, giving [batch, channel, dim, dim, dim]
# - intensity normalisation (NormalizeIntensityd on 'img') is currently disabled
# - crop only 'img', using 'mask' as the foreground source, with a per-axis
#   margin of [20, 20, 5]; 'seg' and 'mask' are not cropped
# - NOTE: resizing needs to be applied after this crop, otherwise the output
#   cannot be remapped back to the original size
val_transforms = Compose([
    LoadNiftid(keys=['img', 'seg', 'mask']),
    AddChanneld(keys=['img', 'seg', 'mask']),
    CropForegroundAnisotropicMargind(keys=['img'], source_key='mask', margin=[20, 20, 5]),
    ToTensord(keys=['img', 'seg', 'mask']),
])

val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=1)

def prepare_batch(batchdata):
    """Split a collated batch dictionary into an ``(input, target)`` pair.

    Args:
        batchdata: dictionary from the DataLoader; must contain ``'img'`` and
            may contain ``'mask'``.

    Returns:
        ``(batchdata['img'], batchdata['mask'])`` when a mask is present,
        otherwise ``(batchdata['img'], None)``.

    Raises:
        AssertionError: if ``batchdata`` is not a dict.
    """
    assert isinstance(batchdata, dict), "prepare_batch expects dictionary input data."
    # Without a 'mask' entry there is no target, but the image is still the
    # network input; dict.get yields None exactly when the key is absent.
    return (batchdata['img'], batchdata.get('mask'))

# Take the first (and, with a single file, only) batch from the loader and
# report the shapes so the effect of the crop on 'img' is visible.
valid_data = monai.utils.misc.first(val_loader)
print("Validation data tensor shapes")
print(*(valid_data[key].shape for key in ('img', 'seg', 'mask')))

# Post-processing under test: undo the crop on 'img' with the same margin the
# crop used. How 'mask' is used to recover the original extent is defined in
# the project-local custom_transform module (not shown here).
val_post_transform = PadToOriginalSized(keys=['img'], source_key='mask', margin=[20, 20, 5])
output_valid_data = val_post_transform(valid_data)

In [None]:
# 'seg' never goes through the crop, so it is the reference; 'img' is the
# cropped copy, and the post-processed 'img' is expected to match 'seg' again.
orig = valid_data['seg'].detach().cpu().numpy()
crop = valid_data['img'].detach().cpu().numpy()
post = output_valid_data['img'].detach().cpu().numpy()

print(orig.shape, crop.shape, post.shape)

# Fail loudly on a size mismatch instead of letting NumPy broadcast, and
# subtract in float64 so an integer/unsigned dtype cannot wrap around and
# hide differences.
if orig.shape != post.shape:
    raise ValueError(f"shape mismatch: original {orig.shape} vs padded {post.shape}")
print(np.sum(np.abs(orig.astype(np.float64) - post.astype(np.float64))))

In [None]:
# Visual check of one slice along the last spatial axis: the original label,
# the padded-back crop, and their difference (all zeros if the round trip is exact).
slice_orig = 18
slice_post = 18
orig_slice = orig[0, 0, :, :, slice_orig]
post_slice = post[0, 0, :, :, slice_post]
# Subtract in float64 so integer label dtypes cannot wrap around.
diff_slice = orig_slice.astype(np.float64) - post_slice.astype(np.float64)

plt.figure(figsize=(10, 5))
panels = [("original", orig_slice), ("padded back", post_slice), ("difference", diff_slice)]
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(image, interpolation="nearest")
    plt.title(title)
plt.show()