## Evaluation 4D dynunet pipeline with NeuroI ROI dataset
* `train.py`with mode to `val` 을 기반으로함
* `commands/val.sh` 을 참고
* Brain image + mask image -> 4D modalities
* in_channels: 4, out_channels:1

In [1]:
import os
import numpy as np
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from monai.utils import first
from monai.handlers import (
    MeanDice,
    from_engine,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
# from monai.utils import set_determinism

from create_network import get_network_ke
# from evaluator import (
#     DynUNetEvaluator,
#     DynUNetEvaluator_gpu,
#     DynUNetEvaluator_SaveResult,
#     DynUNetEvaluator_GPU_SaveResult_PostMapping
# )
from config import get_config
from dataset_roi_4d import get_val_loader   # 4D modality test

In [2]:
config = "/data/kehyeong/project/MONAI_examples/dynunet_pipeline/config_roi_earlystop_toy_220209.yaml"
checkpoint = "/data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt"
val_dataset = "/work/NeuroI-models/ke-monai/data/roi/dataset_test_roi_toy2.csv"
val_output_dir = "./runs_eval2"

# tta_val = True    # ?!?!?!  whether to use test time augmentation.
tta_val = False

multi_gpu_flag = False
spacing = [1.0, 1.0, 1.0]
deep_supr_num = 3
window_mode = "gaussian"     # the mode parameter for SlidingWindowInferer.
eval_overlap = 0.5
amp = False
local_rank = 0

config = get_config(config)
data_dir = config["data_dir"]
image_file_path = config["image_file_path"]
label_file_path = config["label_file_path"]
mask_file_path = config["mask_file_path"]
val_batch_size = config["val"]["batch_size"] 
val_num_workers = config["val"]["num_workers"]
num_classes = config["num_classes"]
patch_size = config["patch_size"]

TTA는 말 그대로 Inference(Test) 과정에서 Augmentation 을 적용한 뒤 예측의 확률을 평균(또는 다른 방법)을 통해 도출하는 기법입니다. 모델 학습간에 다양한 Augmentation 을 적용하여 학습하였을시, Inference 과정에서도 유사한 Augmentation 을 적용하여 정확도를 높일 수 있습니다. 또는 이미지에 객체가 너무 작게 위치한 경우 원본 이미지, Crop 이미지를 넣는 등 다양하게 활용이 가능

In [3]:
if not os.path.exists(val_output_dir):
    os.makedirs(val_output_dir)
    
if multi_gpu_flag:
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    device = torch.device("cuda:6")
#     device = torch.device("cpu")

val_loader = get_val_loader(
    data_dir=data_dir,
    id_file=val_dataset,
    image_file_pattern=image_file_path,
    label_file_pattern=label_file_path,
    mask_file_pattern=mask_file_path,
    batch_size=val_batch_size,
    num_workers=val_num_workers,
    multi_gpu_flag=multi_gpu_flag
)

{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0303 00_0
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0197 00_0


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.04s/it]


In [4]:
test_data = first(val_loader)
print(test_data.keys())
print("image shape:", test_data['image'].shape)
print("image dtype:", test_data['image'].dtype)
print("label shape:", test_data['label'].shape)
print("label dtype:", test_data['label'].dtype)
print("1번 배치의 유니크한 라벨 리스트:", np.unique(test_data['label']))
total_labels = np.unique(test_data['label'])
print(f'1번 배치의 유니크한 라벨 class 수: {len(total_labels)}')

dict_keys(['image', 'label', 'mask', 'image_meta_dict', 'label_meta_dict', 'mask_meta_dict', 'image_transforms', 'label_transforms', 'mask_transforms', 'original_shape', 'bbox', 'crop_shape', 'resample_flag', 'anisotrophy_flag'])
image shape: torch.Size([1, 2, 154, 154, 152])
image dtype: torch.float32
label shape: torch.Size([1, 1, 186, 230, 230])
label dtype: torch.uint8
1번 배치의 유니크한 라벨 리스트: [  0   1   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  20  21  23  24  25  26  27  28  29  30  31  32  33  34  36  37  39  40
  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58
  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76
  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94
  95  96  97  98  99 100 101 102 103 104 105 106 107 108]
1번 배치의 유니크한 라벨 class 수: 104


--------------

post transform 테스트

In [None]:
from monai.transforms import AsDiscrete, Transform
post_label = AsDiscrete(to_onehot=num_classes)
post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)

In [None]:
aaa = post_label(test_data['label'][0])
aaa.shape

In [None]:
aaa

In [None]:
aaa[:, 0, 0, 0]

In [None]:
test_data['label'].shape

In [None]:
test_data['label'][0, :, 80, 80, 80]

In [None]:
aaa[:, 80, 80, 80]

In [None]:
bbb = post_pred(aaa)
bbb.shape

In [None]:
bbb[:, 80, 80, 80]

------------------------------

In [5]:
properties = {
    'modality': [0,1],
    'labels': np.arange(num_classes)
}
n_class = len(properties["labels"])
in_channels = len(properties["modality"])
in_channels, n_class

(2, 109)

In [6]:
# produce the network
net = get_network_ke(properties, patch_size, spacing, deep_supr_num, 
                     val_output_dir, checkpoint)    # val_output_dir은 삭제 필요. 안씀
print('Loading nnUNET Done!!!')
net = net.to(device)
print('Loading nnUNET to GPU devices Done!!!')
print(net)

pretrained checkpoint: /data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt loaded
Loading nnUNET Done!!!
Loading nnUNET to GPU devices Done!!!
DynUNet(
  (input_block): UnetBasicBlock(
    (conv1): Convolution(
      (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (conv2): Convolution(
      (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
    (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (downsamples): ModuleList(
    (0): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv3d(64, 64, kernel_size=(3, 

In [7]:
if multi_gpu_flag:
    net = DistributedDataParallel(module=net, device_ids=[device])

num_classes = len(properties["labels"])

In [8]:
device

device(type='cuda', index=6)

In [None]:
# ## 1. 기본형
# net.eval()
# evaluator = DynUNetEvaluator(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=None,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [None]:
# ## 2. GPU inference 형
# net.eval()
# evaluator = DynUNetEvaluator_gpu(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=None,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [None]:
# ## 3. GPU ver. + inference image 저장 + tta False

# net.eval()
# evaluator = DynUNetEvaluator_SaveResult(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     output_dir=val_output_dir,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=None,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [None]:
# ## 4. GPU ver. + inference image 저장형 + tta False + posttransform (mapping index)
# from monai.transforms import (
#     LoadImaged,
#     AddChanneld,
#     MapLabelValued,
#     Compose
# )

# orig_label_classes, target_label_classes = (
#     np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
#          14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
#          42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
#          58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
#         255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
#        1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
#        1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
#        2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
#        2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
#        2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
#     np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
#         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
#         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
#         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
#         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
#         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
#         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
#         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
#        104, 105, 106, 107, 108])
# )

# post_trans = MapLabelValued(
#     keys=["pred"], 
#     orig_labels=target_label_classes, 
#     target_labels=orig_label_classes
# )


# net.eval()
# evaluator = DynUNetEvaluator_SaveResult(
#     device=device,
#     val_data_loader=val_loader,
#     network=net,
#     output_dir=val_output_dir,
#     num_classes=num_classes,
#     inferer=SlidingWindowInferer(
#         roi_size=patch_size,
#         sw_batch_size=val_batch_size,
#         overlap=eval_overlap,
#         mode=window_mode,
#     ),
#     postprocessing=post_trans,      # 이걸 바꿔줘야할듯...
#     key_val_metric={
#         "val_mean_dice": MeanDice(
#             include_background=False,
#             output_transform=from_engine(["pred", "label"]),
#         )
#     },
#     additional_metrics=None,
#     amp=amp,
#     tta_val=tta_val,
# )

In [80]:
import os
import numpy as np
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.data import decollate_batch
from monai.data.nifti_writer import write_nifti
from monai.engines import SupervisedEvaluator
from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.inferers import Inferer
from monai.networks.utils import eval_mode
from monai.transforms import AsDiscrete, Transform
from torch.utils.data import DataLoader

from transforms import recovery_prediction

from monai.transforms import Compose, AddChannel, MapLabelValue, CastToType

class DynUNetEvaluator_GPU_SaveResult_PostMapping(SupervisedEvaluator):
    """
    넣어준 모든 이미지들의 DiceMean 계산 뿐 아니라 inference 결과 이미지도 저장하도록 변경

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be
            torch.DataLoader.
        network: use the network to run model forward.
        num_classes: the number of classes (output channels) for the task.
        epoch_length: number of iterations for one epoch, default to
            `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        tta_val: whether to do the 8 flips (8 = 2 ** 3, where 3 represents the three dimensions)
            test time augmentation, default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        output_dir: str,                                      # infer specific
        num_classes: Union[str, int],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        postprocessing: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
        tta_val: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            network=network,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        if not isinstance(num_classes, int):
            num_classes = int(num_classes)
        self.output_dir = output_dir               # infer specific
        self.num_classes = num_classes
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)              # eval specific
        self.tta_val = tta_val
        
        # orig_label_classes, target_label_classes = (
        #     np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
        #         14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
        #         42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
        #         58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
        #         255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
        #     1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
        #     1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
        #     2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
        #     2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
        #     2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035], dtype=np.float64),
        #     np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        #         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        #         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        #         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        #         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        #         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        #         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        #         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        #     104, 105, 106, 107, 108], dtype=np.float64)
        # )
        # self.orig_label_classes = orig_label_classes
        # self.target_label_classes = target_label_classes
        # self.post_trans = Compose([
        #     AddChannel(),
        #     MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
        #                   orig_labels=self.target_label_classes, 
        #                   target_labels=self.orig_label_classes,
        #                   dtype=np.uint8
        #                   )
        #     ])
        # self.post_trans = MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
        #                                 orig_labels=self.target_label_classes, 
        #                                 target_labels=self.orig_label_classes
        #                   )

    def _iteration(
        self, engine: Engine, batchdata: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        targets = targets.cpu()      # device로 바꿔보자
        # print('CPU 비활성화!!!!!!!!!! ')

        def _compute_pred():
            ct = 1.0
            # pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()    # device로 바꿔보자 (지우면댐)
            pred = self.inferer(inputs, self.network, *args, **kwargs)    # device로 바꿔보자 (지우면댐)
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            else:
                for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                    flip_inputs = torch.flip(inputs, dims=dims)
                    flip_pred = torch.flip(
                        # self.inferer(flip_inputs, self.network).cpu(), dims=dims    # device로 바꿔보자
                        self.inferer(flip_inputs, self.network), dims=dims    # device로 바꿔보자
                    )
                    flip_pred = nn.functional.softmax(flip_pred, dim=1)
                    del flip_inputs
                    pred += flip_pred
                    del flip_pred
                    ct += 1
                return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        inputs = inputs.cpu()    #  # device로 바꿔보자

        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])                # eval specific

        affine = batchdata["image_meta_dict"]["affine"].numpy()[0]            # infer specific
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape
            predictions = recovery_prediction(
                predictions.numpy(), [self.num_classes, *crop_shape], anisotrophy_flag
            )
            predictions = torch.tensor(predictions)

        ## 이미지 저장
        predictions_wirte = predictions.cpu()
        print(type(predictions_wirte))
        print(predictions_wirte.shape)
        predictions_wirte = np.argmax(predictions_wirte, axis=0)
        print(predictions_wirte.shape)
        predictions_wirte_org = np.zeros([*original_shape])
        
        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
        engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])
        # pad the prediction back to the original shape
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end

        engine.state.output[Keys.PRED][
            0, :, h_start:h_end, w_start:w_end, d_start:d_end
        ] = predictions
        del predictions

        
        ## 이미지 저장
        # predictions_wirte = self.post_trans(predictions_wirte)
        orig_label_classes, target_label_classes = (
            np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
                14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
                42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
                58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
                255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
            1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
            1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
            2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
            2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
            2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
            np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
                13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
                26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
                39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
                52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
                65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
                78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
                91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
            104, 105, 106, 107, 108])
        )
        post_trans = MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
            orig_labels=target_label_classes, 
            target_labels=orig_label_classes,
#             dtype=np.uint8
        )
#         post_trans = Compose([
#             MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
#             orig_labels=target_label_classes, 
#             target_labels=orig_label_classes,
# #             dtype=np.uint8
#             ),
#             CastToType(dtype=np.int64)
            
#         ])
               
        
        predictions_wirte_org[h_start:h_end, w_start:w_end, d_start:d_end] = predictions_wirte
        del predictions_wirte
        print('변형전 dtype', predictions_wirte_org.dtype)
        print(np.unique(predictions_wirte_org))
        # predictions_wirte_org = self.post_trans(predictions_wirte_org)   # 원래대로 pred index 복구
        # predictions_wirte_org = predictions_wirte_org.squeeze()
        predictions_wirte_org = post_trans(predictions_wirte_org)
        print('변형후 dtype', predictions_wirte_org.dtype)
        print(np.unique(predictions_wirte_org))
        
        filename = batchdata["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
        print(
            "save {} with shape: {}".format(
                filename, predictions_wirte_org.shape
            )
        )
        write_nifti(
            data=predictions_wirte_org,
            file_name=os.path.join(self.output_dir, filename),
            affine=affine,
            resample=False,
            output_dtype=np.uint32,
        )
        print(f'MRI img:{ filename } eval done .................')
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output

In [96]:
import os
import numpy as np
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.data import decollate_batch
from monai.data.nifti_writer import write_nifti
from monai.engines import SupervisedEvaluator
from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.inferers import Inferer
from monai.networks.utils import eval_mode
from monai.transforms import AsDiscrete, Transform
from torch.utils.data import DataLoader

from transforms import recovery_prediction

from monai.transforms import Compose, AddChannel, MapLabelValue, CastToType

class DynUNetEvaluator_GPU_SaveResult_PostMapping2(SupervisedEvaluator):
    """
    넣어준 모든 이미지들의 DiceMean 계산 뿐 아니라 inference 결과 이미지도 저장하도록 변경

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be
            torch.DataLoader.
        network: use the network to run model forward.
        num_classes: the number of classes (output channels) for the task.
        epoch_length: number of iterations for one epoch, default to
            `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        tta_val: whether to do the 8 flips (8 = 2 ** 3, where 3 represents the three dimensions)
            test time augmentation, default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        output_dir: str,                                      # infer specific
        num_classes: Union[str, int],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        postprocessing: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
        tta_val: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            network=network,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        if not isinstance(num_classes, int):
            num_classes = int(num_classes)
        self.output_dir = output_dir               # infer specific
        self.num_classes = num_classes
        self.post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
        self.post_label = AsDiscrete(to_onehot=num_classes)              # eval specific
        self.tta_val = tta_val
        
        orig_label_classes, target_label_classes = (
            np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
                14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
                42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
                58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
                255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
            1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
            1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
            2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
            2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
            2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035], dtype=np.float64),
            np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
                13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
                26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
                39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
                52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
                65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
                78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
                91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
            104, 105, 106, 107, 108], dtype=np.float64)
        )
        self.orig_label_classes = orig_label_classes
        self.target_label_classes = target_label_classes
        # self.post_trans = Compose([
        #     AddChannel(),
        #     MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
        #                   orig_labels=self.target_label_classes, 
        #                   target_labels=self.orig_label_classes,
        #                   dtype=np.uint8
        #                   )
        #     ])
        self.post_trans = MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
                                        orig_labels=self.target_label_classes, 
                                        target_labels=self.orig_label_classes
                          )

    def _iteration(
        self, engine: Engine, batchdata: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, targets, args, kwargs = batch

        targets = targets.cpu()      # device로 바꿔보자
        # print('CPU 비활성화!!!!!!!!!! ')

        def _compute_pred():
            ct = 1.0
            # pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()    # device로 바꿔보자 (지우면댐)
            pred = self.inferer(inputs, self.network, *args, **kwargs)    # device로 바꿔보자 (지우면댐)
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            else:
                for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                    flip_inputs = torch.flip(inputs, dims=dims)
                    flip_pred = torch.flip(
                        # self.inferer(flip_inputs, self.network).cpu(), dims=dims    # device로 바꿔보자
                        self.inferer(flip_inputs, self.network), dims=dims    # device로 바꿔보자
                    )
                    flip_pred = nn.functional.softmax(flip_pred, dim=1)
                    del flip_inputs
                    pred += flip_pred
                    del flip_pred
                    ct += 1
                return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        inputs = inputs.cpu()    #  # device로 바꿔보자

        predictions = self.post_pred(decollate_batch(predictions)[0])
        targets = self.post_label(decollate_batch(targets)[0])                # eval specific

        affine = batchdata["image_meta_dict"]["affine"].numpy()[0]            # infer specific
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()
        if resample_flag:
            # convert the prediction back to the original (after cropped) shape
            predictions = recovery_prediction(
                predictions.numpy(), [self.num_classes, *crop_shape], anisotrophy_flag
            )
            predictions = torch.tensor(predictions)

        ## 이미지 저장
        predictions_wirte = predictions.cpu()
        print(type(predictions_wirte))
        print(predictions_wirte.shape)
        predictions_wirte = np.argmax(predictions_wirte, axis=0)
        print(predictions_wirte.shape)
        predictions_wirte_org = np.zeros([*original_shape])
        
        # put iteration outputs into engine.state
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets.unsqueeze(0)}
        engine.state.output[Keys.PRED] = torch.zeros([1, self.num_classes, *original_shape])
        # pad the prediction back to the original shape
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end

        engine.state.output[Keys.PRED][
            0, :, h_start:h_end, w_start:w_end, d_start:d_end
        ] = predictions
        del predictions

        
        ## 이미지 저장
        # predictions_wirte = self.post_trans(predictions_wirte)
#         orig_label_classes, target_label_classes = (
#             np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
#                 14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
#                 42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
#                 58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
#                 255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
#             1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
#             1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
#             2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
#             2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
#             2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
#             np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
#                 13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
#                 26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
#                 39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
#                 52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
#                 65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
#                 78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
#                 91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
#             104, 105, 106, 107, 108])
#         )
#         post_trans = MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
#             orig_labels=target_label_classes, 
#             target_labels=orig_label_classes,
#         )
#         post_trans = Compose([
#             MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
#             orig_labels=target_label_classes, 
#             target_labels=orig_label_classes,
# #             dtype=np.uint8
#             ),
#             CastToType(dtype=np.int64)
            
#         ])
               
        
        predictions_wirte_org[h_start:h_end, w_start:w_end, d_start:d_end] = predictions_wirte
        del predictions_wirte
        print('변형전 dtype', predictions_wirte_org.dtype)
        print(np.unique(predictions_wirte_org))
        predictions_wirte_org = self.post_trans(predictions_wirte_org)   # 원래대로 pred index 복구
        # predictions_wirte_org = predictions_wirte_org.squeeze()
#         predictions_wirte_org = post_trans(predictions_wirte_org)
        print('변형후 dtype', predictions_wirte_org.dtype)
        print(np.unique(predictions_wirte_org))
        
        filename = batchdata["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
        print(
            "save {} with shape: {}".format(
                filename, predictions_wirte_org.shape
            )
        )
        write_nifti(
            data=predictions_wirte_org,
            file_name=os.path.join(self.output_dir, filename),
            affine=affine,
            resample=False,
            output_dtype=np.uint32,
        )
        print(f'MRI img:{ filename } eval done .................')
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output

In [97]:
## 5. GPU ver. + inference image 저장형 + tta False + posttransform 2 (mapping index)
# evaluator class 내부에서 transform
net.eval()
evaluator = DynUNetEvaluator_GPU_SaveResult_PostMapping2(
    device=device,
    val_data_loader=val_loader,
    network=net,
    output_dir=val_output_dir,
    num_classes=num_classes,
    inferer=SlidingWindowInferer(
        roi_size=patch_size,
        sw_batch_size=val_batch_size,
        overlap=eval_overlap,
        mode=window_mode,
    ),
    postprocessing=None,
    key_val_metric={
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=from_engine(["pred", "label"]),
        )
    },
    additional_metrics=None,
    amp=amp,
    tta_val=tta_val,
)

In [98]:
start = time.time()
evaluator.run()
print("time :", time.time() - start)

<class 'torch.Tensor'>
torch.Size([109, 154, 154, 152])
torch.Size([154, 154, 152])
변형전 dtype float64
[  0.   1.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.  14.
  15.  16.  17.  18.  20.  21.  23.  24.  25.  26.  27.  28.  29.  30.
  31.  32.  33.  34.  36.  37.  39.  40.  41.  42.  43.  44.  45.  46.
  47.  48.  49.  50.  51.  52.  53.  54.  55.  56.  57.  58.  59.  60.
  61.  62.  63.  64.  65.  66.  67.  68.  69.  70.  71.  72.  73.  74.
  75.  76.  77.  78.  79.  80.  81.  82.  83.  84.  85.  86.  87.  88.
  89.  90.  91.  92.  93.  94.  95.  96.  97.  98.  99. 100. 101. 102.
 103. 104. 105. 106. 107. 108.]
변형후 dtype float32
[0.000e+00 2.000e+00 4.000e+00 5.000e+00 7.000e+00 8.000e+00 1.000e+01
 1.100e+01 1.200e+01 1.300e+01 1.400e+01 1.500e+01 1.600e+01 1.700e+01
 1.800e+01 2.400e+01 2.600e+01 2.800e+01 3.100e+01 4.100e+01 4.300e+01
 4.400e+01 4.600e+01 4.700e+01 4.900e+01 5.000e+01 5.100e+01 5.200e+01
 5.300e+01 5.400e+01 5.800e+01 6.000e+01 6.300e+01 7.700e+01 8.500e

In [None]:
a[93]

이미지를 로드하여 mapping transform 테스트를 진행해야 한다.

In [99]:
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    MapLabelValued,
    Compose
)
from monai.data import (
    CacheDataset,
    DataLoader,
    partition_dataset,
)

orig_label_classes, target_label_classes = (
    np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
         14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
         42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
         58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
        255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
       1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
       1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
       2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
       2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
       2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
    np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108])
)

def get_test_transform():
    keys = ["label", "pred"]
    transforms = [
        LoadImaged(keys=keys),
#         AddChanneld(keys=keys),
#         MapLabelValued(
#             keys=["pred"], 
#             orig_labels=target_label_classes, 
#             target_labels=orig_label_classes
#         ),
#         Orientationd(keys=keys, axcodes="RAS"),
#         ConcatItemsd(keys=["image", "mask"], name="image"),
#         PreprocessAnisotropic(
#             keys=["image"],
#             clip_values=clip_values,
#             pixdim=spacing,
#             normalize_values=normalize_values,
#             model_mode="test",
#         ),
#         CastToTyped(keys=["image"], dtype=(np.float32)),
#         EnsureTyped(keys=["image"]),
    ]
    return Compose(transforms)
transform = get_test_transform()

In [100]:
# import shutil
# original_file = os.path.join(data_dir, "s_SU0197_00_0_b.nii.gz")
# shutil.copy(original_file, './')

In [101]:
data = []
label_file = os.path.join(data_dir, "s_SU0197_00_0_l.nii.gz")
pred_file = "runs_eval2/s_SU0197_00_0_b.nii.gz"
data.append({'label': label_file, 'pred': pred_file})

dataset = CacheDataset(
    data=data,
    transform=transform,
    num_workers=8,
    cache_rate=1.0,
)

data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.78it/s]


In [102]:
test_data = first(data_loader)

In [103]:
test_data.keys()

dict_keys(['label', 'pred', 'label_meta_dict', 'pred_meta_dict'])

In [104]:
affine = test_data["pred_meta_dict"]["affine"].numpy()[0]
affine

array([[ 9.99280810e-01, -2.96830740e-02,  2.35963799e-02,
        -9.13983536e+01],
       [ 2.94676349e-02,  9.99521315e-01,  9.42613184e-03,
        -1.21794128e+02],
       [-2.38648821e-02, -8.72402266e-03,  9.99677122e-01,
        -9.46715775e+01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00]])

In [105]:
test_data['label'].shape, test_data['pred'].shape

(torch.Size([1, 186, 230, 230]), torch.Size([1, 186, 230, 230]))

In [106]:
test_data['pred'].dtype, test_data['label'].dtype

(torch.float32, torch.float32)

In [107]:
len(np.unique(test_data['pred'])), np.unique(test_data['pred'])

(104,
 array([0.000e+00, 2.000e+00, 4.000e+00, 5.000e+00, 7.000e+00, 8.000e+00,
        1.000e+01, 1.100e+01, 1.200e+01, 1.300e+01, 1.400e+01, 1.500e+01,
        1.600e+01, 1.700e+01, 1.800e+01, 2.400e+01, 2.600e+01, 2.800e+01,
        3.100e+01, 4.100e+01, 4.300e+01, 4.400e+01, 4.600e+01, 4.700e+01,
        4.900e+01, 5.000e+01, 5.100e+01, 5.200e+01, 5.300e+01, 5.400e+01,
        5.800e+01, 6.000e+01, 6.300e+01, 7.700e+01, 8.500e+01, 2.510e+02,
        2.520e+02, 2.530e+02, 2.540e+02, 2.550e+02, 1.000e+03, 1.002e+03,
        1.003e+03, 1.005e+03, 1.006e+03, 1.007e+03, 1.008e+03, 1.009e+03,
        1.010e+03, 1.011e+03, 1.012e+03, 1.013e+03, 1.014e+03, 1.015e+03,
        1.016e+03, 1.017e+03, 1.018e+03, 1.019e+03, 1.020e+03, 1.021e+03,
        1.022e+03, 1.023e+03, 1.024e+03, 1.025e+03, 1.026e+03, 1.027e+03,
        1.028e+03, 1.029e+03, 1.030e+03, 1.031e+03, 1.034e+03, 1.035e+03,
        2.000e+03, 2.002e+03, 2.003e+03, 2.005e+03, 2.006e+03, 2.007e+03,
        2.008e+03, 2.009e+03, 2.

In [108]:
np.unique(test_data['label'])

array([0.000e+00, 2.000e+00, 4.000e+00, 5.000e+00, 7.000e+00, 8.000e+00,
       1.000e+01, 1.100e+01, 1.200e+01, 1.300e+01, 1.400e+01, 1.500e+01,
       1.600e+01, 1.700e+01, 1.800e+01, 2.400e+01, 2.600e+01, 2.800e+01,
       3.100e+01, 4.100e+01, 4.300e+01, 4.400e+01, 4.600e+01, 4.700e+01,
       4.900e+01, 5.000e+01, 5.100e+01, 5.200e+01, 5.300e+01, 5.400e+01,
       5.800e+01, 6.000e+01, 6.300e+01, 7.700e+01, 8.500e+01, 2.510e+02,
       2.520e+02, 2.530e+02, 2.540e+02, 2.550e+02, 1.000e+03, 1.002e+03,
       1.003e+03, 1.005e+03, 1.006e+03, 1.007e+03, 1.008e+03, 1.009e+03,
       1.010e+03, 1.011e+03, 1.012e+03, 1.013e+03, 1.014e+03, 1.015e+03,
       1.016e+03, 1.017e+03, 1.018e+03, 1.019e+03, 1.020e+03, 1.021e+03,
       1.022e+03, 1.023e+03, 1.024e+03, 1.025e+03, 1.026e+03, 1.027e+03,
       1.028e+03, 1.029e+03, 1.030e+03, 1.031e+03, 1.034e+03, 1.035e+03,
       2.000e+03, 2.002e+03, 2.003e+03, 2.005e+03, 2.006e+03, 2.007e+03,
       2.008e+03, 2.009e+03, 2.010e+03, 2.011e+03, 

In [109]:
result = np.where(test_data['pred'] == 2018)   # 2018, 93
result

(array([0, 0, 0, ..., 0, 0, 0]),
 array([125, 125, 125, ..., 150, 150, 150]),
 array([132, 133, 133, ..., 134, 134, 134]),
 array([132, 131, 132, ..., 135, 136, 137]))

In [110]:
idx = 12
test_data['pred'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

tensor(2018.)

In [111]:
test_data['label'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

tensor(2018.)

오케 맵핑이 잘되고 있음을 확인하였음.

------------------
저장된 이미지로부터 post-transform mapping 하여 결과 확인해본다.

In [None]:
from monai.transforms import MapLabelValue
orig_label_classes, target_label_classes = (
    np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
        14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
        42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
        58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
        255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
    1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
    1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
    2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
    2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
    2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
    np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
    104, 105, 106, 107, 108])
)
post_trans = MapLabelValue(    # 0~108로 잡혀있는 label을 original index로 변경
    orig_labels=target_label_classes, 
    target_labels=orig_label_classes
)

In [None]:
test_data['pred'].shape, test_data['label'].shape

In [None]:
aaa = post_trans(test_data['pred'])
aaa.shape

In [None]:
# aaa = test_data['pred']

In [None]:
result = np.where(aaa == 2018)   # 41
result

In [None]:
idx = 226
aaa[result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

In [None]:
test_data['label'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

post transform 한 결과를 이미지로 저장한다.

In [None]:
aaa.shape

In [None]:
aaa.squeeze().shape

In [None]:
from monai.data.nifti_writer import write_nifti

In [None]:
write_nifti(
        data=aaa.squeeze(),
        file_name='aa2.nii.gz',
        affine=affine,
        resample=False,
        output_dtype=np.uint8,
    )

In [None]:
aaa.shape

In [None]:
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    MapLabelValued,
    Compose
)
from monai.data import (
    CacheDataset,
    DataLoader,
    partition_dataset,
)

orig_label_classes, target_label_classes = (
    np.array([   0,    2,    3,    4,    5,    7,    8,   10,   11,   12,   13,
         14,   15,   16,   17,   18,   24,   26,   28,   30,   31,   41,
         42,   43,   44,   46,   47,   49,   50,   51,   52,   53,   54,
         58,   60,   62,   63,   77,   80,   85,  251,  252,  253,  254,
        255, 1000, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
       1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022,
       1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035,
       2000, 2002, 2003, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012,
       2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023,
       2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2034, 2035]),
    np.array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108])
)

def get_test_transform():
    keys = ["label", "pred"]
    transforms = [
        LoadImaged(keys=keys),
#         AddChanneld(keys=keys),
#         MapLabelValued(
#             keys=["pred"], 
#             orig_labels=target_label_classes, 
#             target_labels=orig_label_classes
#         ),
#         Orientationd(keys=keys, axcodes="RAS"),
#         ConcatItemsd(keys=["image", "mask"], name="image"),
#         PreprocessAnisotropic(
#             keys=["image"],
#             clip_values=clip_values,
#             pixdim=spacing,
#             normalize_values=normalize_values,
#             model_mode="test",
#         ),
#         CastToTyped(keys=["image"], dtype=(np.float32)),
#         EnsureTyped(keys=["image"]),
    ]
    return Compose(transforms)
transform = get_test_transform()

In [None]:
data = []
label_file = os.path.join(data_dir, "s_SU0197_00_0_l.nii.gz")
pred_file = "aa2.nii.gz"
data.append({'label': label_file, 'pred': pred_file})

dataset = CacheDataset(
    data=data,
    transform=transform,
    num_workers=8,
    cache_rate=1.0,
)

data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

In [None]:
test_data = first(data_loader)

In [None]:
test_data.keys()

In [None]:
result = np.where(test_data['pred'] == 2018)
result

In [None]:
idx = 12
test_data['pred'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

In [None]:
test_data['label'][result[0][idx], result[1][idx], result[2][idx], result[3][idx]]

## finding 정리
evaluator에 MeanDice가 저장되는 개념.
 * evaluator.state.metrics : 전체 ROI MeanDice 평균
 * evaluator.state.metric_details["val_mean_dice"] : ROI별 MeanDice 평균

In [None]:
if local_rank == 0:
    print(evaluator.state.metrics)
    results = evaluator.state.metric_details["val_mean_dice"]
    if num_classes > 2:
        for i in range(num_classes - 1):
            print(
                "mean dice for label {} is {}".format(i + 1, results[:, i].mean())
            )

In [None]:
dir(evaluator)

In [None]:
evaluator.state_dict()

In [None]:
dir(evaluator.state)

In [None]:
type(evaluator.state.metric_details), evaluator.state.metric_details.keys()

In [None]:
evaluator.state.metric_details['val_mean_dice']

In [None]:
evaluator.state.metric_details['val_mean_dice'].shape