## Evaluation 4D dynunet pipeline with NeuroI ROI dataset
* `train.py`를 `val` 모드로 실행하는 흐름을 기반으로 함
* `commands/val.sh`를 참고
* Brain image + mask image -> 4D modalities
* in_channels: 4, out_channels:1

In [1]:
import os
import numpy as np
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from monai.utils import first
from monai.handlers import (
    MeanDice,
    from_engine,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
# from monai.utils import set_determinism

from create_network import get_network_ke
# from evaluator import (
#     DynUNetEvaluator,
#     DynUNetEvaluator_gpu,
#     DynUNetEvaluator_SaveResult,
#     DynUNetEvaluator_GPU_SaveResult_PostMapping
# )
from config import get_config
from dataset_roi_4d import get_val_loader   # 4D modality test

In [26]:
# Machine-specific inputs for this evaluation run; adjust before re-use.
config_path = "/data/kehyeong/project/MONAI_examples/dynunet_pipeline/config_roi_earlystop_toy_220209.yaml"
checkpoint = "/data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt"
val_dataset = "/work/NeuroI-models/ke-monai/data/roi/dataset_test_roi_toy2.csv"
# val_output_dir = "./runs_eval2"
val_output_dir = "/home/kehyeong/tmp_result"

# tta_val = True    # whether to use test time augmentation.
tta_val = False

multi_gpu_flag = False
spacing = [1.0, 1.0, 1.0]
deep_supr_num = 3
window_mode = "gaussian"     # the mode parameter for SlidingWindowInferer.
eval_overlap = 0.5
amp = False
local_rank = 0

# Keep the YAML path and the parsed settings under separate names so that
# `config` always refers to the parsed mapping.
config = get_config(config_path)
data_dir = config["data_dir"]
image_file_path = config["image_file_path"]
label_file_path = config["label_file_path"]
mask_file_path = config["mask_file_path"]
val_config = config["val"]
val_batch_size = val_config["batch_size"]
val_num_workers = val_config["num_workers"]
num_classes = config["num_classes"]
patch_size = config["patch_size"]

TTA(Test Time Augmentation)는 말 그대로 Inference(Test) 과정에서 Augmentation을 적용한 뒤, 각 예측 확률을 평균(또는 다른 방법)하여 최종 예측을 도출하는 기법입니다. 모델 학습 시 다양한 Augmentation을 적용하여 학습하였다면, Inference 과정에서도 유사한 Augmentation을 적용하여 정확도를 높일 수 있습니다. 또한 이미지에서 객체가 너무 작게 위치한 경우 원본 이미지와 Crop 이미지를 함께 넣는 등 다양하게 활용할 수 있습니다.

In [3]:
# Create the output directory; exist_ok avoids the check-then-create race
# of testing os.path.exists() first.
os.makedirs(val_output_dir, exist_ok=True)

if multi_gpu_flag:
    # init_method="env://" reads the rendezvous settings from environment
    # variables; each process is then pinned to the GPU given by local_rank.
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    # Single-process run pinned to one specific GPU index.
    device = torch.device("cuda:6")
#     device = torch.device("cpu")

# Build the validation loader from the ID list and the file-name patterns
# taken from the config.
val_loader = get_val_loader(
    data_dir=data_dir,
    id_file=val_dataset,
    image_file_pattern=image_file_path,
    label_file_pattern=label_file_path,
    mask_file_pattern=mask_file_path,
    batch_size=val_batch_size,
    num_workers=val_num_workers,
    multi_gpu_flag=multi_gpu_flag,
)

{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0303 00_0
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0197 00_0


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:08<00:00,  4.01s/it]


In [None]:
# test_data = first(val_loader)
# print(test_data.keys())
# print("image shape:", test_data['image'].shape)
# print("image dtype:", test_data['image'].dtype)
# print("label shape:", test_data['label'].shape)
# print("label dtype:", test_data['label'].dtype)
# print("1번 배치의 유니크한 라벨 리스트:", np.unique(test_data['label']))
# total_labels = np.unique(test_data['label'])
# print(f'1번 배치의 유니크한 라벨 class 수: {len(total_labels)}')

------------------------------

In [4]:
# Task description handed to get_network_ke: input modalities and label ids.
# NOTE(review): the notebook header states "in_channels: 4", but this cell
# yields 2 input channels (the printed network also starts with Conv3d(2, ...)).
# Confirm which one is intended.
properties = dict(modality=[0, 1], labels=np.arange(num_classes))
in_channels, n_class = len(properties["modality"]), len(properties["labels"])
in_channels, n_class

(2, 109)

In [5]:
# Build the network and move it to the selected device.
# TODO: drop the val_output_dir argument; per the original note it is unused.
net = get_network_ke(
    properties,
    patch_size,
    spacing,
    deep_supr_num,
    val_output_dir,
    checkpoint,
)
print('Loading nnUNET Done!!!')
net = net.to(device)
print('Loading nnUNET to GPU devices Done!!!')
print(net)

pretrained checkpoint: /data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt loaded
Loading nnUNET Done!!!
Loading nnUNET to GPU devices Done!!!
DynUNet(
  (input_block): UnetBasicBlock(
    (conv1): Convolution(
      (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (conv2): Convolution(
      (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
    (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (downsamples): ModuleList(
    (0): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv3d(64, 64, kernel_size=(3, 

In [6]:
# Number of output classes, derived from the label list given to the network.
num_classes = len(properties["labels"])

# Wrap the network for distributed evaluation only in multi-GPU runs.
if multi_gpu_flag:
    net = DistributedDataParallel(net, device_ids=[device])

In [7]:
# Echo the selected torch device as the cell output.
device

device(type='cuda', index=6)

In [8]:
# ## 1. 기본형
# net.eval()
# evaluator = DynUNetEvaluator(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=None,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [9]:
# ## 2. GPU inference 형
# net.eval()
# evaluator = DynUNetEvaluator_gpu(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=None,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [10]:
# ## 3. GPU ver. + inference image 저장 + tta False

# net.eval()
# evaluator = DynUNetEvaluator_SaveResult(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     output_dir=val_output_dir,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=None,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [11]:
# ## 4. GPU ver. + inference image 저장형 + tta False + posttransform (mapping index)
# from monai.transforms import (
#     LoadImaged,
#     AddChanneld,
#     MapLabelValued,
#     Compose
# )

# orig_label_classes, target_label_classes = (
#     np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
#          14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
#          42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
#          58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
#         255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
#        1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
#        1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
#        2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
#        2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
#        2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
#     np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
#         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
#         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
#         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
#         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
#         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
#         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
#         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
#        104, 105, 106, 107, 108])
# )

# post_trans = MapLabelValued(
#     keys=["pred"], 
#     orig_labels=target_label_classes, 
#     target_labels=orig_label_classes
# )


# net.eval()
# evaluator = DynUNetEvaluator_SaveResult(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     output_dir=val_output_dir,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=post_trans,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [None]:
import os
import numpy as np
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.data import decollate_batch
from monai.data.nifti_writer import write_nifti
from monai.engines import SupervisedEvaluator
from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.inferers import Inferer
from monai.networks.utils import eval_mode
from monai.transforms import AsDiscrete, Transform
from torch.utils.data import DataLoader

from transforms import recovery_prediction

from monai.transforms import Compose, AddChannel, MapLabelValue, CastToType

class DynUNetEvaluator_GPU_SaveResult_PostMapping(SupervisedEvaluator):
    """
    Evaluator that, besides feeding the mean-Dice metric for every image it is
    given, also saves each inference result as a NIfTI label volume.

    The forward pass stays on the engine device (no early copy to the CPU).
    Before saving, the argmax class indices are mapped back to the original
    label ids listed in ``ORIG_LABEL_CLASSES``.

    Only element ``[0]`` of every batch is post-processed and saved, so the
    loader is expected to yield one image per batch.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be
            torch.DataLoader.
        network: use the network to run model forward.
        output_dir: directory the predicted label volumes are written to.
        num_classes: the number of classes (output channels) for the task.
        epoch_length: number of iterations for one epoch, default to
            `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        tta_val: whether to average the prediction over the original input plus its 7 flipped
            variants (8 = 2 ** 3 predictions, where 3 represents the three spatial dimensions),
            default is False.

    """

    # Original label ids; entry ``i`` is the id written to disk for class index ``i``.
    ORIG_LABEL_CLASSES = np.array([
        0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
        14, 15, 16, 17, 18, 24, 26, 28, 30, 31, 41,
        42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54,
        58, 60, 62, 63, 77, 80, 85, 251, 252, 253, 254,
        255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
        1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
        1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
        2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
        2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
        2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035,
    ])
    # Contiguous class indices (0..108) used by the network output channels.
    TARGET_LABEL_CLASSES = np.arange(ORIG_LABEL_CLASSES.size)

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        output_dir: str,                                      # infer specific
        num_classes: Union[str, int],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        postprocessing: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
        tta_val: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            network=network,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        if not isinstance(num_classes, int):
            num_classes = int(num_classes)
        self.output_dir = output_dir               # infer specific
        self.num_classes = num_classes
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)              # eval specific
        self.tta_val = tta_val
        # Built once instead of on every iteration: maps the contiguous class
        # indices (0..108) back to the original label ids.
        self.restore_labels = MapLabelValue(
            orig_labels=self.TARGET_LABEL_CLASSES,
            target_labels=self.ORIG_LABEL_CLASSES,
        )

    def _compute_pred(
        self, inputs: torch.Tensor, args: Tuple, kwargs: Dict
    ) -> torch.Tensor:
        """Return softmax probabilities for ``inputs``, flip-averaged when ``tta_val`` is set."""
        pred = self.inferer(inputs, self.network, *args, **kwargs)
        pred = nn.functional.softmax(pred, dim=1)
        if not self.tta_val:
            return pred

        ct = 1.0
        for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
            flip_inputs = torch.flip(inputs, dims=dims)
            # Forward the same extra arguments as the un-flipped pass, then
            # flip the result back so it lines up with the original input.
            flip_pred = torch.flip(
                self.inferer(flip_inputs, self.network, *args, **kwargs), dims=dims
            )
            flip_pred = nn.functional.softmax(flip_pred, dim=1)
            del flip_inputs
            pred += flip_pred
            del flip_pred
            ct += 1
        return pred / ct

    def _iteration(
        self, engine: Engine, batchdata: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, moved back to the CPU.
            - LABEL: one-hot label Tensor corresponding to the image, on the CPU.
            - PRED: one-hot prediction padded back to the original image shape.

        As a side effect the predicted label volume (with original label ids)
        is written to ``self.output_dir`` under the input image's file name.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        targets = targets.cpu()

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = self._compute_pred(inputs, args, kwargs)
            else:
                predictions = self._compute_pred(inputs, args, kwargs)

        inputs = inputs.cpu()

        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])                # eval specific

        affine = batchdata["image_meta_dict"]["affine"].numpy()[0]            # infer specific
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape.
            # The prediction is presumably still on the engine device here
            # (inference is no longer followed by .cpu()), and Tensor.numpy()
            # only works on CPU tensors, so copy it to the host first.
            predictions = recovery_prediction(
                predictions.detach().cpu().numpy(),
                [self.num_classes, *crop_shape],
                anisotrophy_flag,
            )
            predictions = torch.tensor(predictions)

        # Collapse the one-hot prediction into a label map for saving.
        predictions_cpu = predictions.cpu()
        print(type(predictions_cpu))
        print(predictions_cpu.shape)
        label_patch = np.argmax(predictions_cpu.numpy(), axis=0)
        print(label_patch.shape)
        label_volume = np.zeros([*original_shape])

        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
        engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])
        # pad the prediction back to the original shape
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end

        engine.state.output[Keys.PRED][
            0, :, h_start:h_end, w_start:w_end, d_start:d_end
        ] = predictions
        del predictions

        # Place the cropped label map inside the full-size volume, then map
        # the class indices (0..108) back to the original label ids.
        label_volume[h_start:h_end, w_start:w_end, d_start:d_end] = label_patch
        del label_patch
        print('변형전 dtype', label_volume.dtype)
        print(np.unique(label_volume))
        label_volume = self.restore_labels(label_volume)
        print('변형후 dtype', label_volume.dtype)
        print(np.unique(label_volume))

        # Save under the same file name as the input image.
        filename = os.path.basename(batchdata["image_meta_dict"]["filename_or_obj"][0])
        print(f"save {filename} with shape: {label_volume.shape}")
        write_nifti(
            data=label_volume,
            file_name=os.path.join(self.output_dir, filename),
            affine=affine,
            resample=False,
            output_dtype=np.uint32,
        )
        print(f'MRI img:{ filename } eval done .................')
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output

In [None]:
import os
import numpy as np
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.data import decollate_batch
from monai.data.nifti_writer import write_nifti
from monai.engines import SupervisedEvaluator
from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.inferers import Inferer
from monai.networks.utils import eval_mode
from monai.transforms import AsDiscrete, Transform
from torch.utils.data import DataLoader

from transforms import recovery_prediction

from monai.transforms import Compose, AddChannel, MapLabelValue, CastToType

class DynUNetEvaluator_GPU_SaveResult_PostMapping2(SupervisedEvaluator):
    """
    Evaluator that, besides feeding the attached metrics for every image, also
    saves each inference result as a NIfTI file.

    The forward pass runs on the engine device. The prediction is then brought to
    the CPU, pasted back into the original image shape, its class indices
    (0..108) are mapped back to the original label values, and the resulting
    label volume is written into ``output_dir``.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be
            torch.DataLoader.
        network: use the network to run model forward.
        output_dir: directory the predicted label volumes are written to. It is
            created if it does not exist yet.
        num_classes: the number of classes (output channels) for the task.
        epoch_length: number of iterations for one epoch, default to
            `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        tta_val: whether to do flip based test time augmentation (the un-flipped input plus
            7 flipped copies over the three spatial dimensions are averaged), default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        output_dir: str,                                      # infer specific
        num_classes: Union[str, int],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        postprocessing: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
        tta_val: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            network=network,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        if not isinstance(num_classes, int):
            num_classes = int(num_classes)
        self.output_dir = output_dir               # infer specific
        # Make sure the destination exists before any volume is written, so a
        # missing directory does not surface only after inference has run.
        os.makedirs(self.output_dir, exist_ok=True)
        self.num_classes = num_classes
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)              # eval specific
        self.tta_val = tta_val

        # Original label values, listed in the order of the class indices the
        # network predicts (index i <-> orig_label_classes[i]).
        self.orig_label_classes = np.array(
            [
                0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
                14, 15, 16, 17, 18, 24, 26, 28, 30, 31, 41,
                42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54,
                58, 60, 62, 63, 77, 80, 85, 251, 252, 253, 254,
                255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
                1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
                1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
                2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
                2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
                2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035,
            ],
            dtype=np.float64,
        )
        # Contiguous class indices 0..108, one per original label value.
        self.target_label_classes = np.arange(
            self.orig_label_classes.size, dtype=np.float64
        )
        # Maps the 0..108 class indices back to the original label values.
        self.post_trans = MapLabelValue(
            orig_labels=self.target_label_classes,
            target_labels=self.orig_label_classes,
        )

    def _iteration(
        self, engine: Engine, batchdata: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Besides filling ``engine.state.output`` it writes the predicted label volume
        (with the original label values restored) to ``self.output_dir``.
        Return below items in a dictionary (all of them are CPU tensors):
            - IMAGE: image Tensor data that was used as model input.
            - LABEL: one-hot label of the first item in the batch, with a leading batch dimension.
            - PRED: one-hot prediction placed back into a zero volume of the original image shape.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        # The label is not needed for the forward pass; keep it on the CPU.
        targets = targets.cpu()

        def _compute_pred():
            # The inferer output is deliberately left on the inference device so
            # the softmax and the flip averaging run there as well.
            ct = 1.0
            pred = self.inferer(inputs, self.network, *args, **kwargs)
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                flip_inputs = torch.flip(inputs, dims=dims)
                # Flip the prediction back so it lines up with the un-flipped one.
                flip_pred = torch.flip(
                    self.inferer(flip_inputs, self.network), dims=dims
                )
                flip_pred = nn.functional.softmax(flip_pred, dim=1)
                del flip_inputs
                pred += flip_pred
                del flip_pred
                ct += 1
            return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        inputs = inputs.cpu()

        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])                # eval specific

        affine = batchdata["image_meta_dict"]["affine"].numpy()[0]            # infer specific
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape.
            # The prediction may still live on the GPU here and Tensor.numpy()
            # only works for CPU tensors, hence the explicit .cpu().
            predictions = recovery_prediction(
                predictions.cpu().numpy(), [self.num_classes, *crop_shape], anisotrophy_flag
            )
            predictions = torch.tensor(predictions)

        # Everything below is CPU work; transfer the prediction once.
        predictions = predictions.cpu()

        # Class-index map used for the saved image.
        print(type(predictions))
        print(predictions.shape)
        label_map = np.argmax(predictions.numpy(), axis=0)
        print(label_map.shape)

        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
        engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])
        # pad the prediction back to the original shape
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end

        engine.state.output[Keys.PRED][
            0, :, h_start:h_end, w_start:w_end, d_start:d_end
        ] = predictions
        del predictions

        # Save the image: place the class-index map into a zero volume of the
        # original shape, then restore the original label values.
        label_volume = np.zeros([*original_shape])
        label_volume[h_start:h_end, w_start:w_end, d_start:d_end] = label_map
        del label_map
        print('변형전 dtype', label_volume.dtype)
        print(np.unique(label_volume))
        label_volume = self.post_trans(label_volume)
        print('변형후 dtype', label_volume.dtype)
        print(np.unique(label_volume))

        filename = batchdata["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
        print(
            "save {} with shape: {}".format(
                filename, label_volume.shape
            )
        )
        write_nifti(
            data=label_volume,
            file_name=os.path.join(self.output_dir, filename),
            affine=affine,
            resample=False,
            output_dtype=np.uint32,
        )
        print(f'MRI img:{ filename } eval done .................')
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output

In [None]:
## 5. GPU version + saves the inference images + TTA off + post-transform 2 (index mapping)
# The label mapping transform is applied inside the evaluator class.
net.eval()
evaluator = DynUNetEvaluator_GPU_SaveResult_PostMapping2(
    # model and data
    network=net,
    device=device,
    val_data_loader=val_loader,
    num_classes=num_classes,
    # where the predicted volumes go
    output_dir=val_output_dir,
    # inference behaviour
    amp=amp,
    tta_val=tta_val,
    inferer=SlidingWindowInferer(
        mode=window_mode,
        overlap=eval_overlap,
        roi_size=patch_size,
        sw_batch_size=val_batch_size,
    ),
    # metrics
    key_val_metric={
        "val_mean_dice": MeanDice(
            output_transform=from_engine(["pred", "label"]),
            include_background=False,
        )
    },
    additional_metrics=None,
    postprocessing=None,
)

In [None]:
# Run the current `evaluator` and print the elapsed wall-clock time in seconds.
start = time.time()
evaluator.run()
print("time :", time.time() - start)

In [27]:
import os
import numpy as np
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.data import decollate_batch
from monai.data.nifti_writer import write_nifti
from monai.engines import SupervisedEvaluator
from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.inferers import Inferer
from monai.networks.utils import eval_mode
from monai.transforms import AsDiscrete, Transform
from torch.utils.data import DataLoader

from transforms import recovery_prediction

from monai.transforms import Compose, AddChannel, MapLabelValue, CastToType

class DynUNetEvaluator_GPU_test(SupervisedEvaluator):
    """
    GPU inference test variant of the result-saving evaluator.

    It performs the same steps as the saving evaluator (forward pass, padding the
    prediction back to the original shape, mapping the class indices 0..108 back
    to the original label values) and prints the wall-clock time of each stage.
    The NIfTI write itself is not performed in this variant.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be
            torch.DataLoader.
        network: use the network to run model forward.
        output_dir: output directory; only stored on the instance in this variant.
        num_classes: the number of classes (output channels) for the task.
        epoch_length: number of iterations for one epoch, default to
            `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        tta_val: whether to do flip based test time augmentation (the un-flipped input plus
            7 flipped copies over the three spatial dimensions are averaged), default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        output_dir: str,                                      # infer specific
        num_classes: Union[str, int],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        postprocessing: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
        tta_val: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            network=network,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        if not isinstance(num_classes, int):
            num_classes = int(num_classes)
        self.output_dir = output_dir               # infer specific
        self.num_classes = num_classes
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)              # eval specific
        self.tta_val = tta_val

        # Original label values, listed in the order of the class indices the
        # network predicts (index i <-> orig_label_classes[i]).
        self.orig_label_classes = np.array(
            [
                0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
                14, 15, 16, 17, 18, 24, 26, 28, 30, 31, 41,
                42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54,
                58, 60, 62, 63, 77, 80, 85, 251, 252, 253, 254,
                255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
                1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
                1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
                2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
                2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
                2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035,
            ],
            dtype=np.float64,
        )
        # Contiguous class indices 0..108, one per original label value.
        self.target_label_classes = np.arange(
            self.orig_label_classes.size, dtype=np.float64
        )
        # Maps the 0..108 class indices back to the original label values.
        self.post_trans = MapLabelValue(
            orig_labels=self.target_label_classes,
            target_labels=self.orig_label_classes,
        )

    def _iteration(
        self, engine: Engine, batchdata: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Prints the elapsed time of the inference, post-processing, image preparation
        and event firing stages.
        Return below items in a dictionary (all of them are CPU tensors):
            - IMAGE: image Tensor data that was used as model input.
            - LABEL: one-hot label of the first item in the batch, with a leading batch dimension.
            - PRED: one-hot prediction placed back into a zero volume of the original image shape.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        # ---- stage 1: forward pass, then inputs back to the CPU ----
        start = time.time()

        # The label is not needed for the forward pass; keep it on the CPU.
        targets = targets.cpu()

        def _compute_pred():
            # The inferer output is deliberately left on the inference device so
            # the softmax and the flip averaging run there as well.
            ct = 1.0
            pred = self.inferer(inputs, self.network, *args, **kwargs)
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                flip_inputs = torch.flip(inputs, dims=dims)
                # Flip the prediction back so it lines up with the un-flipped one.
                flip_pred = torch.flip(
                    self.inferer(flip_inputs, self.network), dims=dims
                )
                flip_pred = nn.functional.softmax(flip_pred, dim=1)
                del flip_inputs
                pred += flip_pred
                del flip_pred
                ct += 1
            return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        inputs = inputs.cpu()
        print("inference & GPU to CPU time :", time.time() - start)

        # ---- stage 2: discretise and (optionally) undo the resampling ----
        start = time.time()
        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])                # eval specific

        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape.
            # The prediction may still live on the GPU here and Tensor.numpy()
            # only works for CPU tensors, hence the explicit .cpu().
            predictions = recovery_prediction(
                predictions.cpu().numpy(), [self.num_classes, *crop_shape], anisotrophy_flag
            )
            predictions = torch.tensor(predictions)
        print("post process time :", time.time() - start)

        # ---- stage 3: build the outputs and the label volume ----
        start = time.time()
        # Everything below is CPU work; transfer the prediction once.
        predictions = predictions.cpu()
        label_map = np.argmax(predictions.numpy(), axis=0)

        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
        engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])
        # pad the prediction back to the original shape
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end

        engine.state.output[Keys.PRED][
            0, :, h_start:h_end, w_start:w_end, d_start:d_end
        ] = predictions
        del predictions

        # Place the class-index map into a zero volume of the original shape,
        # then restore the original label values.
        label_volume = np.zeros([*original_shape])
        label_volume[h_start:h_end, w_start:w_end, d_start:d_end] = label_map
        del label_map
        label_volume = self.post_trans(label_volume)

        filename = batchdata["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
        print(
            "save {} with shape: {}".format(
                filename, label_volume.shape
            )
        )
        # NOTE: the volume is not written to disk in this timing variant, so the
        # time reported below covers only the preparation of the label volume.
        print("save img time :", time.time() - start)

        # ---- stage 4: notify the attached handlers / metrics ----
        start = time.time()
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        print("fire event time :", time.time() - start)

        return engine.state.output

In [28]:
## 6. GPU inference test - check how much of the pipeline can run on the GPU
net.eval()
evaluator = DynUNetEvaluator_GPU_test(
    # model and data
    network=net,
    device=device,
    val_data_loader=val_loader,
    num_classes=num_classes,
    # output directory handed to the evaluator
    output_dir=val_output_dir,
    # inference behaviour
    amp=amp,
    tta_val=tta_val,
    inferer=SlidingWindowInferer(
        mode=window_mode,
        overlap=eval_overlap,
        roi_size=patch_size,
        sw_batch_size=val_batch_size,
    ),
    # metrics
    key_val_metric={
        "val_mean_dice": MeanDice(
            output_transform=from_engine(["pred", "label"]),
            include_background=False,
        )
    },
    additional_metrics=None,
    postprocessing=None,
)

In [93]:
class DynUNetEvaluator(SupervisedEvaluator):
    """
    This class inherits from SupervisedEvaluator in MONAI, and is used with DynUNet
    on Decathlon datasets.

    In this variant the forward pass and the softmax are kept on the inference
    device; tensors are moved to the CPU only when they are stored in
    ``engine.state.output``. Each stage of ``_iteration`` prints its wall-clock
    duration.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be
            torch.DataLoader.
        network: use the network to run model forward.
        num_classes: the number of classes (output channels) for the task.
        epoch_length: number of iterations for one epoch, default to
            `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        tta_val: whether to do the 8 flips (8 = 2 ** 3, where 3 represents the three dimensions)
            test time augmentation, default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        num_classes: Union[str, int],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        postprocessing: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
        tta_val: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            network=network,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        # The class count may arrive as a string (e.g. from a config file).
        if not isinstance(num_classes, int):
            num_classes = int(num_classes)
        self.num_classes = num_classes
        # Prediction: argmax over channels, then one-hot with `num_classes` channels.
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        # Label: one-hot encoding only (evaluation specific).
        self.post_label = AsDiscrete(to_onehot=num_classes)
        self.tta_val = tta_val

    def _iteration(
        self, engine: Engine, batchdata: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data used as model input, moved back to the CPU.
            - LABEL: one-hot label Tensor with a leading batch dimension, on the CPU.
            - PRED: one-hot prediction written into a zero Tensor of the original
              image shape at the location given by ``batchdata["bbox"]``.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
                Besides the image/label it must provide the keys ``resample_flag``,
                ``anisotrophy_flag``, ``crop_shape``, ``original_shape`` and ``bbox``.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        # NOTE(review): CUDA work is launched asynchronously, so without an explicit
        # torch.cuda.synchronize() the printed timings may charge GPU time to the
        # next call that has to wait for the device (e.g. a `.cpu()` copy).
        start = time.time()

        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        print("batch 나누고 target cpu time :", time.time() - start)

        start = time.time()

        def _compute_pred():
            """Return softmax probabilities, averaged over flips when TTA is on."""
            ct = 1.0
            # Keep the prediction on the inference device.
            pred = self.inferer(inputs, self.network, *args, **kwargs)
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            # Flip along every combination of the three spatial axes (dims 2, 3, 4),
            # predict, flip the result back and accumulate.
            for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                flip_inputs = torch.flip(inputs, dims=dims)
                # The flipped prediction must stay on the same device as `pred`;
                # moving it to the CPU here would make `pred += flip_pred` fail
                # with a device mismatch when inference runs on the GPU.
                flip_pred = torch.flip(
                    self.inferer(flip_inputs, self.network), dims=dims
                )
                flip_pred = nn.functional.softmax(flip_pred, dim=1)
                del flip_inputs
                pred += flip_pred
                del flip_pred
                ct += 1
            return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        # The input image is only needed for the output dictionary from here on.
        inputs = inputs.cpu()
        print("inference & input cpu time :", time.time() - start)

        start = time.time()
        # Batch size is taken to be 1: only the first decollated item is used.
        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])                # eval specific
        print("post transform time :", time.time() - start)

        start = time.time()
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape;
            # the tensor may still live on the GPU, so copy it to host memory
            # before converting to NumPy.
            predictions = recovery_prediction(
                predictions.detach().cpu().numpy(),
                [self.num_classes, *crop_shape],
                anisotrophy_flag,
            )
            predictions = torch.tensor(predictions)
        print("post process time :", time.time() - start)

        start = time.time()
        targets = targets.cpu()
        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
        engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])
        print("engine state output IMAGE, LABEL time :", time.time() - start)

        start = time.time()
        # pad the prediction back to the original shape
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end
        print("padding step1 :", time.time() - start)

        start = time.time()
        # Slice assignment copies the prediction into the CPU output tensor.
        engine.state.output[Keys.PRED][
            0, :, h_start:h_end, w_start:w_end, d_start:d_end
        ] = predictions
        del predictions
        print("engine state output PRED time :", time.time() - start)

        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output



In [94]:
# 6. GPU inference test: check how much of the pipeline can stay on the GPU.
net.eval()

# Sliding-window inference over patches of `patch_size`.
sliding_window_inferer = SlidingWindowInferer(
    roi_size=patch_size,
    sw_batch_size=val_batch_size,
    overlap=eval_overlap,
    mode=window_mode,
)

# Key validation metric: mean Dice over the foreground classes only.
mean_dice_metric = {
    "val_mean_dice": MeanDice(
        include_background=False,
        output_transform=from_engine(["pred", "label"]),
    )
}

evaluator = DynUNetEvaluator(
    device=device,
    val_data_loader=val_loader,
    network=net,
    num_classes=num_classes,
    inferer=sliding_window_inferer,
    postprocessing=None,
    key_val_metric=mean_dice_metric,
    additional_metrics=None,
    amp=amp,
    tta_val=tta_val,
)

In [95]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.021812915802001953
inference & input cpu time : 0.9027938842773438
post transform time : 0.0031576156616210938
post process time : 0.0003266334533691406
engine state output IMAGE, LABEL time : 8.171430587768555
padding step1 : 0.0018572807312011719
engine state output PRED time : 2.464026689529419
batch 나누고 target cpu time : 0.023492097854614258
inference & input cpu time : 0.3713648319244385
post transform time : 0.001184225082397461
post process time : 0.00028824806213378906
engine state output IMAGE, LABEL time : 2.120267629623413
padding step1 : 0.0009603500366210938
engine state output PRED time : 0.61195969581604
time : 57.13516139984131


In [83]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.010008573532104492
inference & input cpu time : 0.6507704257965088
post transform time : 0.0016832351684570312
post process time : 0.0003173351287841797
engine state output IMAGE, LABEL time : 9.368163585662842
padding step1 : 0.0011758804321289062
engine state output PRED time : 1.7084453105926514
batch 나누고 target cpu time : 0.007800102233886719
inference & input cpu time : 0.22352385520935059
post transform time : 0.0010840892791748047
post process time : 8.344650268554688e-05
engine state output IMAGE, LABEL time : 2.3483664989471436
padding step1 : 0.00038814544677734375
engine state output PRED time : 0.6502490043640137
time : 46.154680490493774


In [82]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.00948786735534668
inference & input cpu time : 0.6287620067596436
post transform time : 0.0023174285888671875
post process time : 0.0005431175231933594
engine state output IMAGE, LABEL time : 5.265115022659302
padding step1 : 0.00231170654296875
engine state output PRED time : 2.101475238800049
batch 나누고 target cpu time : 0.009245872497558594
inference & input cpu time : 0.221177339553833
post transform time : 0.00106048583984375
post process time : 7.43865966796875e-05
engine state output IMAGE, LABEL time : 2.5195584297180176
padding step1 : 0.0003178119659423828
engine state output PRED time : 0.6775071620941162
time : 44.54240584373474


In [81]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.012660741806030273
inference & input cpu time : 0.6355206966400146
post transform time : 0.00115203857421875
post process time : 9.846687316894531e-05
engine state output IMAGE, LABEL time : 4.576287508010864
padding step1 : 0.0018472671508789062
engine state output PRED time : 2.8230185508728027
batch 나누고 target cpu time : 0.009508371353149414
inference & input cpu time : 0.2380521297454834
post transform time : 0.0009570121765136719
post process time : 5.91278076171875e-05
engine state output IMAGE, LABEL time : 2.501232147216797
padding step1 : 0.0008628368377685547
engine state output PRED time : 0.7110633850097656
time : 50.182448387145996


In [80]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.008741378784179688
inference & input cpu time : 0.643047571182251
post transform time : 0.0014870166778564453
post process time : 0.00011467933654785156
engine state output IMAGE, LABEL time : 7.032412052154541
padding step1 : 0.002337932586669922
engine state output PRED time : 1.9719617366790771
batch 나누고 target cpu time : 0.008804559707641602
inference & input cpu time : 0.22884535789489746
post transform time : 0.0010373592376708984
post process time : 0.0003533363342285156
engine state output IMAGE, LABEL time : 2.4040772914886475
padding step1 : 0.0015950202941894531
engine state output PRED time : 0.8416135311126709
time : 46.7656512260437


In [79]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.010185956954956055
inference & input cpu time : 0.6356122493743896
post transform time : 0.001585245132446289
post process time : 0.00017881393432617188
engine state output IMAGE, LABEL time : 33.64306569099426
padding step1 : 0.0016582012176513672
engine state output PRED time : 22.492120027542114
batch 나누고 target cpu time : 0.006925821304321289
inference & input cpu time : 0.22810983657836914
post transform time : 0.0013251304626464844
post process time : 0.0004665851593017578
engine state output IMAGE, LABEL time : 2.1834659576416016
padding step1 : 0.0003027915954589844
engine state output PRED time : 0.7112226486206055
time : 173.87773871421814


In [78]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.01003265380859375
inference & input cpu time : 0.6416950225830078
post transform time : 0.013015031814575195
post process time : 0.0005283355712890625
engine state output IMAGE, LABEL time : 4.64120078086853
padding step1 : 0.0021772384643554688
engine state output PRED time : 1.2226996421813965
batch 나누고 target cpu time : 0.007966995239257812
inference & input cpu time : 0.24000000953674316
post transform time : 0.0009665489196777344
post process time : 6.771087646484375e-05
engine state output IMAGE, LABEL time : 2.509558916091919
padding step1 : 0.0010721683502197266
engine state output PRED time : 0.8421065807342529
time : 98.78075790405273


In [72]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.13880324363708496
inference & input cpu time : 0.7224211692810059
post transform time : 8.644158363342285
post process time : 0.008751869201660156
engine state output IMAGE, LABEL time : 9.609595537185669
padding step1 : 0.002599000930786133
engine state output PRED time : 5.351591110229492
batch 나누고 target cpu time : 0.4457237720489502
inference & input cpu time : 0.33050537109375
post transform time : 0.9618697166442871
post process time : 0.0004565715789794922
engine state output IMAGE, LABEL time : 0.4404935836791992
padding step1 : 0.000682830810546875
engine state output PRED time : 0.7321929931640625
time : 313.5694110393524


In [71]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.03365278244018555
inference & input cpu time : 0.6299469470977783
post transform time : 1.790766716003418
post process time : 0.008248329162597656
engine state output IMAGE, LABEL time : 0.26795172691345215
padding step1 : 0.02272510528564453
engine state output PRED time : 0.9441828727722168
batch 나누고 target cpu time : 0.016907691955566406
inference & input cpu time : 0.22605657577514648
post transform time : 0.29419970512390137
post process time : 0.00028514862060546875
engine state output IMAGE, LABEL time : 0.2901937961578369
padding step1 : 0.009162664413452148
engine state output PRED time : 0.8744888305664062
time : 36.87901425361633


In [42]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.02635049819946289
inference & input cpu time : 2.158511161804199
post transform time : 1.7476403713226318
post process time : 0.007225751876831055
engine state output IMAGE, LABEL time : 0.902824878692627
padding step1 : 0.0026252269744873047
engine state output PRED time : 0.24624967575073242
batch 나누고 target cpu time : 0.013247489929199219
inference & input cpu time : 1.3362798690795898
post transform time : 1.1730046272277832
post process time : 0.0009338855743408203
engine state output IMAGE, LABEL time : 0.4753870964050293
padding step1 : 0.00940847396850586
engine state output PRED time : 0.2286827564239502
time : 27.524839639663696


In [40]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.03188371658325195
inference & input cpu time : 2.4179773330688477
post transform time : 3.0485117435455322
post process time : 0.005202770233154297
engine state output IMAGE, LABEL time : 1.761392593383789
padding step1 : 0.0015418529510498047
engine state output PRED time : 0.2728283405303955
batch 나누고 target cpu time : 0.014521121978759766
inference & input cpu time : 1.3466269969940186
post transform time : 1.2124178409576416
post process time : 0.0007166862487792969
engine state output IMAGE, LABEL time : 0.37946271896362305
padding step1 : 0.0004837512969970703
engine state output PRED time : 0.2589864730834961
time : 30.914069652557373


In [39]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.028201580047607422
inference & input cpu time : 2.6068193912506104
post transform time : 1.3740296363830566
post process time : 0.006607532501220703
engine state output IMAGE, LABEL time : 1.0256309509277344
padding step1 : 0.0025115013122558594
engine state output PRED time : 0.22584128379821777
batch 나누고 target cpu time : 0.014957666397094727
inference & input cpu time : 1.3947994709014893
post transform time : 1.2087392807006836
post process time : 0.0009834766387939453
engine state output IMAGE, LABEL time : 0.3890504837036133
padding step1 : 0.0004909038543701172
engine state output PRED time : 0.24766159057617188
time : 32.28419256210327


In [38]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.029281139373779297
inference & input cpu time : 2.3431434631347656
post transform time : 1.6407890319824219
post process time : 0.001245737075805664
engine state output IMAGE, LABEL time : 1.2049484252929688
padding step1 : 0.002322673797607422
engine state output PRED time : 0.31487178802490234
batch 나누고 target cpu time : 0.02144336700439453
inference & input cpu time : 1.507706880569458
post transform time : 1.1656773090362549
post process time : 0.0008268356323242188
engine state output IMAGE, LABEL time : 0.39177393913269043
padding step1 : 0.00037217140197753906
engine state output PRED time : 0.2459874153137207
time : 57.72733664512634


In [37]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

batch 나누고 target cpu time : 0.04864764213562012
inference & input cpu time : 5.725265026092529
post transform time : 36.36089324951172
post process time : 0.022902965545654297
engine state output IMAGE, LABEL time : 38.6579794883728
padding step1 : 0.001954317092895508
engine state output PRED time : 0.6220638751983643
batch 나누고 target cpu time : 0.023790597915649414
inference & input cpu time : 1.4547474384307861
post transform time : 1.4021012783050537
post process time : 0.0006260871887207031
engine state output IMAGE, LABEL time : 0.385875940322876
padding step1 : 0.0010504722595214844
engine state output PRED time : 0.27706384658813477
time : 124.70445346832275


In [25]:
# Measure the wall-clock duration of one full evaluation run.
start = time.time()
evaluator.run()
elapsed = time.time() - start
print("time :", elapsed)

inference & GPU to CPU time : 0.833308219909668
post process time : 20.509447813034058
<class 'torch.Tensor'>
torch.Size([109, 154, 154, 152])
torch.Size([154, 154, 152])
save s_SU0303_00_0_b.nii.gz with shape: (186, 230, 230)
save img time : 72.48131132125854
fire event time : 0.006493330001831055
inference & GPU to CPU time : 0.3933119773864746
post process time : 0.47392749786376953
<class 'torch.Tensor'>
torch.Size([109, 141, 146, 133])
torch.Size([141, 146, 133])
save s_SU0197_00_0_b.nii.gz with shape: (186, 230, 230)
save img time : 4.3263585567474365
fire event time : 0.006649017333984375
time : 311.10440850257874


In [None]:
a[93]

In [20]:
# Show the position and image tensor shape of every batch in the validation loader.
for batch_idx, batch in enumerate(val_loader):
    print(batch_idx, batch['image'].shape, sep="\n")

0
torch.Size([1, 2, 154, 154, 152])
1
torch.Size([1, 2, 141, 146, 133])


이미지를 로드하여 mapping transform 테스트를 진행해야 한다.

In [None]:
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    MapLabelValued,
    Compose
)
from monai.data import (
    CacheDataset,
    DataLoader,
    partition_dataset,
)

# Label IDs as they appear in the ground-truth label volumes
# (presumably FreeSurfer-style segmentation IDs -- TODO confirm).
orig_label_classes = np.array([
    0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
    14, 15, 16, 17, 18, 24, 26, 28, 30, 31, 41,
    42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54,
    58, 60, 62, 63, 77, 80, 85, 251, 252, 253, 254,
    255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
    1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
    1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
    2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
    2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
    2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035,
])

# Contiguous indices 0..108, one per original ID, in the same order.
target_label_classes = np.arange(len(orig_label_classes))

def get_test_transform():
    """Return the transform pipeline used to inspect saved results.

    The pipeline consists of a single `LoadImaged` step that reads the files
    referenced by the "label" and "pred" keys of each data dictionary.
    """
    return Compose([LoadImaged(keys=["label", "pred"])])


transform = get_test_transform()

In [None]:
# import shutil
# original_file = os.path.join(data_dir, "s_SU0197_00_0_b.nii.gz")
# shutil.copy(original_file, './')

In [None]:
# Pair one ground-truth label volume with the prediction saved for the same subject.
label_file = os.path.join(data_dir, "s_SU0197_00_0_l.nii.gz")
pred_file = "runs_eval2/s_SU0197_00_0_b.nii.gz"
data = [{'label': label_file, 'pred': pred_file}]

# Cache the whole (single-item) dataset after applying `transform`.
dataset = CacheDataset(data=data, transform=transform, num_workers=8, cache_rate=1.0)

data_loader = DataLoader(
    dataset, batch_size=1, shuffle=True, num_workers=4, drop_last=True
)

In [None]:
test_data = first(data_loader)

In [None]:
test_data.keys()

In [None]:
# Take the affine matrix of the first (only) item in the batch from the
# prediction's metadata, as a NumPy array; it is reused when writing NIfTI output.
affine = test_data["pred_meta_dict"]["affine"].numpy()[0]
affine

In [None]:
test_data['label'].shape, test_data['pred'].shape

In [None]:
test_data['pred'].dtype, test_data['label'].dtype

In [None]:
len(np.unique(test_data['pred'])), np.unique(test_data['pred'])

In [None]:
np.unique(test_data['label'])

In [None]:
# Collect the coordinates of every voxel whose predicted label equals 2018.
# One index array is returned per dimension of test_data['pred'].
result = np.where(test_data['pred'] == 2018)   # original ID 2018 <-> contiguous index 93
result

In [None]:
idx = 12
test_data['pred'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

In [None]:
test_data['label'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

오케 맵핑이 잘되고 있음을 확인하였음.

------------------
저장된 이미지로부터 post-transform mapping 하여 결과 확인해본다.

In [None]:
from monai.transforms import MapLabelValue
# Label IDs as they appear in the ground-truth label volumes
# (presumably FreeSurfer-style segmentation IDs -- TODO confirm).
orig_label_classes = np.array([
    0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
    14, 15, 16, 17, 18, 24, 26, 28, 30, 31, 41,
    42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54,
    58, 60, 62, 63, 77, 80, 85, 251, 252, 253, 254,
    255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
    1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
    1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
    2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
    2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
    2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035,
])

# Contiguous indices 0..108, one per original ID, in the same order.
target_label_classes = np.arange(len(orig_label_classes))
# Inverse label mapping: converts the contiguous indices 0..108 back to the
# original label IDs (note the swapped roles of the two tables).
post_trans = MapLabelValue(    # map labels stored as 0~108 back to the original IDs
    orig_labels=target_label_classes, 
    target_labels=orig_label_classes
)

In [None]:
test_data['pred'].shape, test_data['label'].shape

In [None]:
aaa = post_trans(test_data['pred'])
aaa.shape

In [None]:
# aaa = test_data['pred']

In [None]:
result = np.where(aaa == 2018)   # 41
result

In [None]:
idx = 226
aaa[result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

In [None]:
test_data['label'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

post transform 한 결과를 이미지로 저장한다.

In [None]:
aaa.shape

In [None]:
aaa.squeeze().shape

In [None]:
from monai.data.nifti_writer import write_nifti

In [None]:
# Save the label map after mapping back to the original label IDs.
# Those IDs go up to 2035, which does not fit into 8 bits: writing with
# np.uint8 would corrupt every label above 255. Use a 32-bit unsigned type.
write_nifti(
    data=aaa.squeeze(),
    file_name='aa2.nii.gz',
    affine=affine,
    resample=False,
    output_dtype=np.uint32,
)

In [None]:
aaa.shape

In [None]:
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    MapLabelValued,
    Compose
)
from monai.data import (
    CacheDataset,
    DataLoader,
    partition_dataset,
)

# Label IDs as they appear in the ground-truth label volumes
# (presumably FreeSurfer-style segmentation IDs -- TODO confirm).
orig_label_classes = np.array([
    0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
    14, 15, 16, 17, 18, 24, 26, 28, 30, 31, 41,
    42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54,
    58, 60, 62, 63, 77, 80, 85, 251, 252, 253, 254,
    255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
    1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
    1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
    2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
    2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
    2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035,
])

# Contiguous indices 0..108, one per original ID, in the same order.
target_label_classes = np.arange(len(orig_label_classes))

def get_test_transform():
    """Return the transform pipeline used to inspect saved results.

    The pipeline consists of a single `LoadImaged` step that reads the files
    referenced by the "label" and "pred" keys of each data dictionary.
    """
    return Compose([LoadImaged(keys=["label", "pred"])])


transform = get_test_transform()

In [None]:
# Pair the ground-truth label volume with the re-mapped prediction written to aa2.nii.gz.
label_file = os.path.join(data_dir, "s_SU0197_00_0_l.nii.gz")
pred_file = "aa2.nii.gz"
data = [{'label': label_file, 'pred': pred_file}]

# Cache the whole (single-item) dataset after applying `transform`.
dataset = CacheDataset(data=data, transform=transform, num_workers=8, cache_rate=1.0)

data_loader = DataLoader(
    dataset, batch_size=1, shuffle=True, num_workers=4, drop_last=True
)

In [None]:
test_data = first(data_loader)

In [None]:
test_data.keys()

In [None]:
# Collect the coordinates of every voxel labelled 2018 in the re-loaded prediction.
# NOTE(review): the match is empty if the file was saved with a dtype too narrow
# to hold 2018 (e.g. an 8-bit type) -- confirm the dtype used when writing it.
result = np.where(test_data['pred'] == 2018)
result

In [None]:
idx = 12
test_data['pred'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

In [None]:
test_data['label'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

## finding 정리
evaluator에 MeanDice 결과가 저장되는 방식.
 * evaluator.state.metrics : 전체 샘플·전체 ROI에 대한 MeanDice 평균 (스칼라)
 * evaluator.state.metric_details["val_mean_dice"] : 샘플별·ROI별 Dice 값 텐서 (행: 샘플, 열: ROI). ROI별 평균은 `results[:, i].mean()` 처럼 열 단위로 평균을 내어 구한다.

In [None]:
# Report metrics on rank 0 only: first the aggregated metrics, then the
# per-label mean Dice taken column-wise from the detailed results.
if local_rank == 0:
    print(evaluator.state.metrics)
    results = evaluator.state.metric_details["val_mean_dice"]
    if num_classes > 2:
        # Column k holds label k + 1 (the reported labels start at 1).
        for label in range(1, num_classes):
            label_mean = results[:, label - 1].mean()
            print(f"mean dice for label {label} is {label_mean}")

In [None]:
dir(evaluator)

In [None]:
evaluator.state_dict()

In [None]:
dir(evaluator.state)

In [None]:
type(evaluator.state.metric_details), evaluator.state.metric_details.keys()

In [None]:
evaluator.state.metric_details['val_mean_dice']

In [None]:
evaluator.state.metric_details['val_mean_dice'].shape