In [None]:
"""SynthSeg testing in MONAI
"""
import math
import os
import pickle
from pathlib import Path
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
import matplotlib.pyplot as plt
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    ScaleIntensityd,
    ResizeWithPadOrCropd,
    RandAffined,
    Invertd,
    RandZoomd,
)
from monai.data import CacheDataset, DataLoader, decollate_batch, TestTimeAugmentation
from monai.metrics import DiceMetric
from monai.utils import set_determinism
from monai.utils import first

import transforms_synthseg as transforms
import utils_synthseg as utils

seed = 0
# MONAI's set_determinism seeds the RNGs and, for a non-None seed, also sets
# torch.backends.cudnn.deterministic = True and cudnn.benchmark = False.
set_determinism(seed=seed)
# Keep cuDNN autotuning off: benchmark=True lets cuDNN choose kernels at run
# time, which undoes the determinism requested above. Switch to True only if
# speed matters more than reproducible predictions.
torch.backends.cudnn.benchmark = False

In [2]:
testing = True
# Quick-check mode: small volumes, whole-volume inference (patch_size=None).
# Full mode: full-resolution volumes with patch-based sliding-window inference.
spatial_size, patch_size, dout = (
    ((96,) * 3, None, "./results-testing")
    if testing
    else ((256,) * 3, (160,) * 3, "./results")
)
dir_data = "./data/neurite_10"
# Best checkpoint written to the output directory; None when it does not exist.
model_pth = os.path.join(dout, 'model_best.pth')
if not os.path.exists(model_pth):
    model_pth = None
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Pair each `orig.nii.gz` volume with its `seg35.nii.gz` segmentation. Both
# lists are sorted by full path, so zipping them only pairs correctly when
# every subject has exactly one of each file.
images = [str(f) for f in sorted(Path(dir_data).rglob('orig.nii.gz'))]
labels = [str(f) for f in sorted(Path(dir_data).rglob('seg35.nii.gz'))]
if not images:
    raise FileNotFoundError(f"No 'orig.nii.gz' files found under {dir_data!r}")
if len(images) != len(labels):
    # zip() would silently drop the tail and mis-pair subjects.
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} labels under {dir_data!r}"
    )
val_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(images, labels)]
# Only the first subject is used in this notebook.
val_files = [val_files[0]]
print(val_files)

In [4]:
# Deterministic validation preprocessing: load image and label, move channels
# first, reorient to RAS, apply the project-specific resize, then rescale the
# image intensities (label left untouched).
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS", lazy=True),
    transforms.Resize(spatial_size, testing),
    ScaleIntensityd(keys=["image"]),
])

In [None]:
# Cache every preprocessed validation case in memory (cache_rate=1.0) and
# iterate over them one at a time, in file order.
val_loader = DataLoader(
    CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0),
    batch_size=1,
    shuffle=False,
)

In [None]:
# Label info: target label values produced by the project's Neurite label mapping.
target_labels = [*transforms.MapLabelsNeurite.label_mapping().values()]
n_labels = len(target_labels)
print(f"n_labels = {n_labels}")

In [None]:
# One output channel per target label plus one extra channel
# (presumably background — TODO confirm against the training code).
out_channels = n_labels + 1
model = utils.get_model(out_channels)
model = model.to(device)

# model_pth is set to None by the config cell when no checkpoint exists;
# fail with a clear message instead of inside torch.load.
if model_pth is None:
    raise FileNotFoundError(
        f"No checkpoint found at {os.path.join(dout, 'model_best.pth')}"
    )
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(model_pth, map_location=device)
# Weights saved from a (Distributed)DataParallel wrapper carry a leading
# "module." prefix; strip only that prefix so inner key names stay intact.
prefix = "module."
checkpoint = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(checkpoint)

model.eval()

In [8]:
# Dice including the background channel: one metric reduced to a single mean,
# one reduced over the batch only (one value per channel).
dice_metric, dice_metric_batch = (
    DiceMetric(include_background=True, reduction=reduction)
    for reduction in ("mean", "mean_batch")
)

# Map "pred" back to the original image space by inverting val_transforms
# (as recorded on "image"), using non-nearest interpolation, output on CPU.
post_pred = Compose([
    Invertd(
        keys="pred",
        transform=val_transforms,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=True,
        device="cpu",
    ),
])

In [9]:
# check_data = first(val_loader)
# image, label = (check_data["image"][0][0], check_data["label"][0][0])
# print(f"image shape: {image.shape}, label shape: {label.shape}")
# # plot the slice [:, :, 80]
# plt.figure("check", (12, 6))
# plt.subplot(1, 2, 1)
# plt.title("image")
# plt.imshow(image[:, :, 128], cmap="gray")
# plt.subplot(1, 2, 2)
# plt.title("label")
# plt.imshow(label[:, :, 128])
# plt.show()

In [None]:
# Baseline prediction (no test-time augmentation) on the first validation case.
check_data = first(val_loader)
with torch.no_grad():
    # Only the image is needed for inference; the label stays in check_data
    # (and the global `labels` file list is no longer overwritten).
    inputs = check_data["image"].to(device)

    check_data["pred"] = sliding_window_inference(
        inputs=inputs,
        roi_size=patch_size,
        sw_batch_size=1,
        predictor=model,
        overlap=0.5,
    )
    # Invert the validation preprocessing so the prediction is in the
    # original image space.
    post_outputs = [post_pred(i) for i in decollate_batch(check_data)]
    # argmax over channels; softmax is monotonic, so it does not change labels.
    pred_orig = post_outputs[0]["pred"].argmax(0)

# Middle slice along the last axis instead of a hard-coded index, so the plot
# works for any volume size.
plt.subplot(1, 1, 1)
plt.title("pred_orig")
plt.imshow(pred_orig[:, :, pred_orig.shape[-1] // 2].cpu())
plt.show()

In [None]:
# Test-time augmentation pipeline (and a loader for inspecting its samples).
# Every transform is applied to both image and label; the random spatial
# transforms interpolate the image bilinearly and the label with nearest.
keys = ["image", "label"]
tta_transforms = Compose(
    [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        Orientationd(keys=keys, axcodes="RAS", lazy=True),
        ResizeWithPadOrCropd(
            keys=keys,
            spatial_size=spatial_size,
            lazy=True,
        ),
        # NOTE(review): if these transforms execute eagerly, confirm that
        # "bilinear" is accepted by RandZoomd for 3D volumes (torch's
        # interpolate expects "trilinear" for 5D input).
        RandZoomd(
            keys=keys,
            prob=1.0,
            mode=("bilinear", "nearest"),
            min_zoom=0.75,
            max_zoom=1.25,
            lazy=True,
        ),
        RandAffined(
            keys=keys,
            prob=1.0,
            # Up to +/-25 degrees around each of the 3 axes.
            rotate_range=((25 * np.pi / 180,) * 3),
            # Up to +/-25 voxels along each of the 3 axes. A bare scalar such
            # as `(25)` is only applied to the first spatial axis by MONAI.
            translate_range=((25,) * 3),
            mode=("bilinear", "nearest"),
            lazy=True,
        ),
        ScaleIntensityd(keys=["image"]),
    ]
)

tta_loader = DataLoader(
    CacheDataset(
        data=val_files,
        transform=tta_transforms,
        cache_rate=1.0,
    ),
    batch_size=1,
    shuffle=False,
)

In [None]:
# Visual check of one augmented sample produced by the TTA pipeline.
check_data = first(tta_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
# Middle slice of the last axis: a fixed index of 128 is out of range for the
# 96^3 volumes produced when testing=True.
slice_idx = image.shape[-1] // 2
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, slice_idx], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, slice_idx])
plt.show()

In [None]:
def forward(inputs, patch_size, model, post_pred=None, overlap=0.5, sw_batch_size=1):
    """Run sliding-window inference and optionally post-process the result.

    Args:
        inputs: batched input passed to ``sliding_window_inference``.
        patch_size: ROI size of the sliding window; components that are
            ``None`` fall back to the input size (MONAI behaviour).
        model: predictor applied to each window.
        post_pred: optional callable applied to each decollated item of the
            output; when given, only the first post-processed item is
            returned. Falsy values (``None``/``False``) skip post-processing.
        overlap: fractional overlap between neighbouring windows.
        sw_batch_size: number of windows evaluated per predictor call.

    Returns:
        The raw batched prediction, or the first post-processed item.

    Raises:
        TypeError: if ``post_pred`` is truthy but not callable (e.g. ``True``),
            checked before running inference.
    """
    if post_pred and not callable(post_pred):
        raise TypeError(
            f"post_pred must be a callable or None, got {type(post_pred).__name__}"
        )
    outputs = sliding_window_inference(
        inputs=inputs,
        roi_size=patch_size,
        sw_batch_size=sw_batch_size,
        predictor=model,
        overlap=overlap,
    )
    if post_pred:
        outputs = [post_pred(i) for i in decollate_batch(outputs)]
        outputs = outputs[0]
    return outputs

# TTA wrapper: augments the input with tta_transforms, runs `forward` on each
# augmented copy and inverts the outputs; return_full_data=True keeps every
# per-example prediction instead of summary statistics.
tt_aug = TestTimeAugmentation(
    tta_transforms,
    batch_size=1,
    num_workers=0,
    inferrer_fn=lambda batch: forward(batch, patch_size, model),
    device=device,
    return_full_data=True,
)

# Draw one validation case at random, predict 16 augmented copies of it and
# turn the stacked logits into probabilities over the channel axis (dim 1).
with torch.no_grad():
    sampled_files = np.random.choice(val_files, size=1, replace=False)
    for sample in sampled_files:
        logits_tta = tt_aug(sample, num_examples=16)
        pred_tta_all = logits_tta.softmax(1)

In [None]:
# pred_tta_all holds per-example probabilities, dim 0 = augmented example,
# dim 1 = channel (see the softmax(1) in the TTA cell).
pred_tta_mean = pred_tta_all.mean(0).argmax(0)
# torch.median returns the lower of the two middle values for an even number
# of examples (16 here), not their average.
pred_tta_median = pred_tta_all.median(0).values.argmax(0)
# Majority vote: most frequent hard label across the augmented examples.
# (The mode of float probabilities is not meaningful, so vote on argmax labels.)
labels_tta = pred_tta_all.argmax(1)
pred_tta_mode = labels_tta.mode(0).values

preds_to_plot = {
    "pred_orig": pred_orig,
    "pred_tta_mean": pred_tta_mean,
    "pred_tta_median": pred_tta_median,
    "pred_tta_mode": pred_tta_mode,
}
plt.figure("check", (12, 12))
for panel, (title, pred) in enumerate(preds_to_plot.items(), start=1):
    plt.subplot(2, 2, panel)
    plt.title(title)
    # Middle slice of the last axis rather than a hard-coded index.
    plt.imshow(pred[:, :, pred.shape[-1] // 2].cpu())
plt.show()

In [17]:
# with torch.no_grad():

#     for ix, batch_data in enumerate(val_loader):

#         inputs, labels = (
#             batch_data["image"].to(device),
#             batch_data["label"].to(device),
#         )

#         outputs = inference(inputs, model, patch_size=patch_size)
#         post_outputs = [post_pred(i) for i in decollate_batch(outputs)]

#         dice_metric(y_pred=post_outputs, y=labels)
#         dice_metric_batch(y_pred=post_outputs, y=labels)
#         break

#     # aggregate the final mean dice result
#     metric = dice_metric.aggregate().item()
#     metric_batch = dice_metric_batch.aggregate()        
#     # reset the status for next validation round
#     dice_metric.reset()
#     dice_metric_batch.reset()                        

#     # print metric to terminal
#     print(f"METRIC={metric:.4f}")
#     # for i in range(0, len(metric_batch), 10): 
#     #     print( " " * (13 + len(str(max_epochs))) + "|______" + ", "
#     #             .join([f"{k + i:3.0f}={v:0.3f}".format(k, v) for k, v in enumerate(metric_batch[i:i + 10])]))

In [18]:
# print(metric_batch)