## Setup imports

In [None]:
import sys
import os
import logging
import numpy as np
import torch
import glob
import monai.networks.nets as nets
from monai.transforms import (
    Compose,
    LoadImaged,
    AddChanneld,
    CropForegroundd,
    ToTensord,
    RandFlipd,
    RandAffined,
    SpatialPadd,
    Activationsd,
    Activations,
    Resized,
    AsDiscreted,
    AsDiscrete,
    GaussianSmoothd,
    SpatialCropd,
)
from transforms import (
    CTWindowd,
    CTSegmentation,
    RelativeCropZd,
    RandGaussianNoised,
)
from monai.data import DataLoader, Dataset, PersistentDataset, CacheDataset
from torchsampler import ImbalancedDatasetSampler
from monai.transforms.croppad.batch import PadListDataCollate
from monai.utils import NumpyPadMode, set_determinism
from monai.utils.enums import Method
from monai.config import print_config
from sklearn.model_selection import train_test_split
from trainer import Trainer
from validator import Validator
from tester import Tester
from utils import (
    multi_slice_viewer,
    setup_directories,
    get_data_from_info,
    large_image_splitter,
    calculate_class_imbalance,
    create_device,
    balance_training_data,
    balance_training_data2,
    transform_and_copy,
    convert_labels,
    load_mixed_images,
    replace_suffix,
)
from test_data_loader import TestDataset
# Send INFO-level log records to stdout so they appear in the cell output,
# then print MONAI's configuration/version report.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
print_config()

## Setup directories

In [None]:
# Resolve the project directory layout; this notebook reads the "data",
# "persistent", "out" and "results" entries of the returned mapping.
dirs = setup_directories()

## Setup torch device

In [None]:
# pass "cuda" to use the GPU
# `device` receives the network and input batches; `using_gpu` controls
# DataLoader memory pinning further down.
device, using_gpu = create_device("cuda")

## Load test images

In [None]:
# HACKATHON test images. Each entry pairs an image file name with `None` as the
# second element (presumably "no label" for the test split — confirm against
# get_data_from_info).
hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
image_dir = os.path.join(hackathon_dir, 'images', 'test')
seg_dir = os.path.join(hackathon_dir, 'segmentations', 'test')
# Sort so the dataset (and hence the results file) has a reproducible order;
# bare glob() order depends on the filesystem.
test_image_files = sorted(glob.glob(os.path.join(image_dir, '*')))
if not test_image_files:
    raise FileNotFoundError(f"No test images found in {image_dir}")
test_info_hackathon = [(os.path.basename(entry), None) for entry in test_image_files]

test_data_hackathon = get_data_from_info(image_dir, seg_dir, test_info_hackathon)

## Setup transforms

In [None]:
# Crop foreground
crop_foreground = CropForegroundd(
    keys=["image"],
    source_key="image",
    margin=(5, 5, 0),
    select_fn = lambda x: x != 0
)
# Crop Z
crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.05, 0.15))
# Window width and level (window center)
WW, WL = 1500, -600
ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
# Pad image to have hight at least 30
spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
# Resize image x and y
resize_fator = 0.50
xy_size = int(512*resize_fator)
#resize = Resized(keys=["image"], spatial_size=(int(512*resize_fator), int(512*resize_fator), -1), mode="trilinear")
resize1 = Resized(keys=["image"], spatial_size=(-1, -1, 40), mode="area")
resize2 = Resized(keys=["image"], spatial_size=(xy_size, xy_size, -1), mode="area")
# spatioal crop
crop = SpatialCropd(keys=["image"], roi_start=(0, 0, 4), roi_end=(xy_size, xy_size, 36))
# gaussian smooth
gaussian_noise_smooth = GaussianSmoothd(keys=["image"], sigma=(0.2, 0.2, 0.0))

#### Create transforms

In [None]:
# Shared preprocessing: load, CT windowing, CTSegmentation, add a channel axis,
# foreground crop, smoothing, the two resizes and the fixed spatial crop.
# The relative Z crop (`crop_z`) is currently left out of this pipeline.
common_transform = Compose(
    [
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        crop_foreground,
        gaussian_noise_smooth,
        resize1,
        resize2,
        crop,
    ]
)
# Test-time pipeline: the shared preprocessing followed by tensor conversion;
# .flatten() inlines the nested Compose into a single list of transforms.
hackathon_test_transfrom = Compose(
    [common_transform, ToTensord(keys=["image"])]
).flatten()

## Setup data

In [None]:
# Preprocessed samples are cached on disk under dirs["persistent"].
test_dataset = PersistentDataset(
    data=test_data_hackathon[:],
    transform=hackathon_test_transfrom,
    cache_dir=dirs["persistent"],
)
# `shuffle` is left at its default (False), so batches follow dataset order.
test_loader = DataLoader(
    test_dataset,
    batch_size=2,
    num_workers=2,
    pin_memory=using_gpu,
    # Pad the images of each batch symmetrically with a constant before collating.
    collate_fn=PadListDataCollate(method=Method.SYMMETRIC, mode=NumpyPadMode.CONSTANT),
)

## Setup network

In [None]:
# Single output channel; the inference cell applies a sigmoid to it.
# (A previously tried alternative was nets.EfficientNetBN("efficientnet-b4")
# with the same spatial_dims/in_channels and num_classes=out_channels.)
out_channels = 1
network = nets.DenseNet121(
    spatial_dims=3,
    in_channels=1,
    out_channels=out_channels,
)
network = network.to(device)

## Load network

In [None]:
# Load trained weights from the selected training run.
model_dir = "DenseNet121_1_2021-04-17_15-35-18"
load_dir = os.path.join(dirs['out'], 'training', model_dir)
# Sort so the choice is reproducible when several checkpoints match;
# bare glob() order depends on the filesystem.
checkpoints = sorted(glob.glob(os.path.join(load_dir, 'network_key_metric*')))
if not checkpoints:
    raise FileNotFoundError(f"No 'network_key_metric*' checkpoint found in {load_dir}")
if len(checkpoints) > 1:
    logging.warning(
        "Found %d key-metric checkpoints in %s; using %s",
        len(checkpoints), load_dir, checkpoints[0],
    )
network_path = checkpoints[0]
print(network_path)
# map_location lets a checkpoint saved from a GPU load when `device` is the CPU.
network.load_state_dict(torch.load(network_path, map_location=device))

## Run inference and save predictions

In [None]:
act = Activations(sigmoid=True)  # One channel
# act = Activations(softmax=True)  # Two channel
d = AsDiscrete(threshold_values=True)  # NOTE(review): not used in this cell

# Accumulates [image file name, probability] pairs across all batches.
test_outputs_global = []

network.eval()
with torch.no_grad():
    for test_data in test_loader:
        test_images = test_data["image"].to(device)
        test_image_names = [
            os.path.basename(f) for f in test_data["image_meta_dict"]["filename_or_obj"]
        ]
        test_outputs = act(network(test_images))

        # Assumes a single output channel, i.e. exactly one value per image.
        probabilities = test_outputs.cpu().numpy().ravel()
        if len(probabilities) != len(test_image_names):
            raise ValueError(
                f"Got {len(probabilities)} outputs for {len(test_image_names)} images; "
                "expected one output channel per image"
            )
        # Keep probabilities numeric: packing them into one numpy array with the
        # string names would coerce them to strings.
        test_outputs_global.extend(
            [name, float(prob)] for name, prob in zip(test_image_names, probabilities)
        )

os.makedirs(dirs['results'], exist_ok=True)
results_file = os.path.join(dirs['results'], f'{model_dir}.txt')
# One "<image name> <probability>" line per test image. Explicit fixed-point
# formatting keeps small values correct; truncating a float's string form
# (e.g. '1.2345678e-05') to a fixed width would drop its exponent.
with open(results_file, 'w') as fh:
    for name, prob in test_outputs_global:
        fh.write(f"{name} {prob:.8f}\n")
print(test_outputs_global)