# In [None]:
import os
import sys
import tempfile
import shutil
from glob import glob
import logging
import nibabel as nib
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset, create_test_image_3d
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.visualize.img2tensorboard import plot_2d_or_3d_image
from monai.transforms import \
    Compose, AddChannel, LoadNifti, \
    ScaleIntensity, RandSpatialCrop, \
    ToTensor, CastToType, SpatialPad

# Print the MONAI / dependency version report at import time.
# NOTE(review): main() prints the same report again, so running this file as a
# script prints it twice; this looks like a leftover from the notebook cell the
# file was exported from.
monai.config.print_config()


# Main function
def _collect_series(data_dir):
    """Return ``(image_paths, label_paths, series_dirs)`` for ``data_dir/<patient>/<series>/``.

    Only directories are considered at both levels, so stray files in the tree
    are skipped instead of producing bogus paths. Entries are sorted because
    ``os.listdir`` makes no ordering guarantee.
    """
    image_paths, label_paths, series_dirs = [], [], []
    for patient in sorted(os.listdir(data_dir)):
        patient_dir = os.path.join(data_dir, patient)
        if not os.path.isdir(patient_dir):
            continue
        for series in sorted(os.listdir(patient_dir)):
            series_dir = os.path.join(patient_dir, series)
            if not os.path.isdir(series_dir):
                continue
            image_paths.append(os.path.join(series_dir, 'image.nii.gz'))
            label_paths.append(os.path.join(series_dir, 'segmentation.nii.gz'))
            series_dirs.append(series_dir)
    return image_paths, label_paths, series_dirs


def main(data_dir='/home/marafath/scratch/iran_organized_data/test',
         model_path='/home/marafath/scratch/saved_models/eu_model_f0.pth'):
    """Segment every test series with a trained 3D UNet and save the label maps.

    For each ``<data_dir>/<patient>/<series>/image.nii.gz`` the per-voxel argmax
    of the 6-channel UNet output is written next to it as
    ``segmentation_lobes.nii.gz`` (uint8), using the source image's affine.

    Args:
        data_dir: root directory laid out as ``<patient>/<series>/``. Each series
            directory must contain ``image.nii.gz`` and ``segmentation.nii.gz``;
            the latter is loaded by the dataset but not used for inference.
        model_path: ``state_dict`` file for the UNet architecture built below.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Test data to run inference on.
    test_images, test_labels, test_dir = _collect_series(data_dir)

    # The image is deliberately NOT spatially padded: sliding_window_inference
    # pads inputs smaller than roi_size itself and crops the output back, so the
    # prediction stays on the source image's voxel grid and can be saved with
    # its affine. Padding here would make the saved mask larger than the image
    # on any axis shorter than 96 voxels.
    test_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        CastToType(),
        ToTensor()
    ])
    test_segtrans = Compose([
        AddChannel(),
        CastToType(),
        SpatialPad((96, 96, 96), mode='constant'),
        ToTensor()
    ])

    test_ds = NiftiDataset(test_images, test_labels, transform=test_imtrans, seg_transform=test_segtrans)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

    # Fall back to CPU so the script also runs on hosts without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=6,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    roi_size = (160, 160, 96)
    sw_batch_size = 4
    with torch.no_grad():
        # DataLoader does not shuffle by default, so batches arrive in the same
        # order as test_images / test_dir and can be zipped with them.
        for i, (batch, image_path, series_dir) in enumerate(zip(test_loader, test_images, test_dir)):
            images = batch[0].to(device)
            outputs = sliding_window_inference(images, roi_size, sw_batch_size, model)
            # Index the single batch item instead of np.squeeze, which would also
            # drop any spatial axis of length 1. Labels 0-5 fit in uint8, and
            # recent nibabel refuses int64 NIfTI data without an explicit dtype.
            mask = torch.argmax(outputs, dim=1)[0].cpu().numpy().astype(np.uint8)

            # Reuse the source image's affine so the mask overlays the scan in
            # world space. nib.load only reads the header here.
            # NOTE(review): assumes the dataset loader does not reorient the
            # volume; the shape check below flags a mismatch if it does.
            reference = nib.load(image_path)
            if mask.shape != tuple(reference.shape[:3]):
                logging.warning('Prediction shape %s does not match %s shape %s',
                                mask.shape, image_path, reference.shape)
            nib.save(nib.Nifti1Image(mask, reference.affine),
                     os.path.join(series_dir, 'segmentation_lobes.nii.gz'))
            print('Done ' + str(i))

# Run the inference pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()