## Inference 4D dynunet pipeline with NeuroI ROI dataset
* `inference.py` 를 바탕으로 테스트
* Brain image + mask image -> 4D modalities
* in_channels: 2 (brain image + mask), out_channels:1

In [1]:
import os
# from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import numpy as np

import torch
import torch.distributed as dist
from monai.inferers import SlidingWindowInferer
from torch.nn.parallel import DistributedDataParallel

# from create_dataset import get_data
from create_network import get_network_ke
from inferrer import DynUNetInferrer, DynUNetInferrer_PostMapping
# from task_params import patch_size, task_name
from monai.utils import first
from config import get_config
from dataset_roi_4d import get_test_loader
import time

In [2]:
# Inference settings for this notebook.
# The commented-out remnants this cell used to carry show that these values
# were command-line options (`args.*`) in the script the notebook is adapted
# from; here they are fixed by hand.
# The assignments are plain Python statements: a leading `!` would turn each
# line into an IPython shell escape and leave the name unbound.

# Paths to the experiment configuration, trained weights and the test ID list.
config = "/data/kehyeong/project/MONAI_examples/dynunet_pipeline/config_roi_earlystop_toy_220209.yaml"
checkpoint = "/data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt"
test_dataset = "/work/NeuroI-models/ke-monai/data/roi/dataset_test_roi_toy2.csv"

# Directory handed to the inferrer as its output directory.
infer_output_dir = "./runs_inference"

train_num_workers = 4   # the num_workers parameter of training dataloader.
val_num_workers = 1     # replaced by config["val"]["num_workers"] in the next cell.
interval = 5   # the validation interval under epoch level.
eval_overlap = 0.5     # the overlap parameter of SlidingWindowInferer.

spacing = [1.0, 1.0, 1.0]
deep_supr_num = 3
window_mode = "gaussian"     # the mode parameter for SlidingWindowInferer.
# NOTE(review): float literal; a sample count would normally be an int — confirm.
num_samples = 3.          # the num_samples parameter of RandCropByPosNegLabeld.
pos_sample_num = 1
neg_sample_num = 1
cache_rate = 1.0
amp = False
tta_val = False
multi_gpu = False
local_rank = 0

# Later cells read the flag under this name.
multi_gpu_flag = multi_gpu

In [3]:


# Load the experiment configuration from the path held in `config`.
# NOTE: this rebinds `config` from the path string to the loaded object, so
# re-running this cell on its own passes that object back into get_config.
config = get_config(config)
image_file_path = config["image_file_path"]
mask_file_path = config["mask_file_path"]
val_num_workers = config["val"]["num_workers"]
data_dir = config["data_dir"]
num_classes = config["num_classes"]
patch_size = config["patch_size"]
# The sw_batch_size parameter of SlidingWindowInferer; pinned to the value
# that was used for training.
sw_batch_size = config["val"]["batch_size"]

# exist_ok=True creates missing directories without a separate existence
# check (no race between the check and the creation).
os.makedirs(data_dir, exist_ok=True)
os.makedirs(infer_output_dir, exist_ok=True)

if multi_gpu_flag:
    # One process per GPU; rendezvous settings are read from the environment.
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    # Single-process run on CPU; use torch.device("cuda") here to run on a GPU.
    device = torch.device("cpu")

# Loader over the test IDs; batch size is fixed to one item per step.
test_loader = get_test_loader(
    data_dir=data_dir,
    id_file=test_dataset,
    image_file_pattern=image_file_path,
    mask_file_pattern=mask_file_path,
    batch_size=1,
    num_workers=val_num_workers,
    multi_gpu_flag=multi_gpu_flag,
)

{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0303 00_0
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0197 00_0


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.85it/s]


In [4]:
# Pull one batch from the test loader and show its keys plus the image
# tensor's shape and dtype as a quick sanity check.
test_data = first(test_loader)
print(test_data.keys())
first_image = test_data['image']
print("image shape:", first_image.shape)
print("image dtype:", first_image.dtype)

dict_keys(['image', 'mask', 'image_meta_dict', 'mask_meta_dict', 'image_transforms', 'mask_transforms', 'original_shape', 'bbox', 'crop_shape', 'resample_flag', 'anisotrophy_flag'])
image shape: torch.Size([1, 2, 154, 154, 152])
image dtype: torch.float32


In [5]:
# Dataset description handed to the network factory: one entry per input
# modality and one label index per class.
properties = dict(
    modality=[0, 1],
    labels=np.arange(num_classes),
)
in_channels = len(properties["modality"])
n_class = len(properties["labels"])
# Trailing expression so the notebook displays both counts.
in_channels, n_class

(2, 109)

In [6]:
# Produce the network and move it to the selected device.
# Original author's note: val_output_dir is not used and should be removed
# from this call.
val_output_dir = "./runs_inference/"
net = get_network_ke(
    properties,
    patch_size,
    spacing,
    deep_supr_num,
    val_output_dir,
    checkpoint,
)
print('Loading nnUNET Done!!!')
net = net.to(device)
print('Loading nnUNET to CPU devices Done!!!')
# Dump the module tree for inspection.
print(net)

pretrained checkpoint: /data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt loaded
Loading nnUNET Done!!!
Loading nnUNET to CPU devices Done!!!
DynUNet(
  (input_block): UnetBasicBlock(
    (conv1): Convolution(
      (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (conv2): Convolution(
      (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
    (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (downsamples): ModuleList(
    (0): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv3d(64, 64, kernel_size=(3, 

In [7]:
# Wrap the network for distributed execution only when multi-GPU mode is on.
if multi_gpu_flag:
    net = DistributedDataParallel(module=net, device_ids=[device], find_unused_parameters=True)

In [8]:
# Switch the network to evaluation mode before inference.
net.eval()

# Sliding-window settings come from the configuration cells above.
sliding_window = SlidingWindowInferer(
    roi_size=patch_size,
    sw_batch_size=sw_batch_size,
    overlap=eval_overlap,
    mode=window_mode,
)

# The inferrer is only constructed here; it is run (and timed) in the next cell.
inferrer = DynUNetInferrer_PostMapping(
    device=device,
    val_data_loader=test_loader,
    network=net,
    output_dir=infer_output_dir,
    num_classes=len(properties["labels"]),
    inferer=sliding_window,
    amp=amp,
    tta_val=tta_val,
)

In [9]:
# Time the whole inference pass.
# perf_counter() is a monotonic high-resolution clock, so the measured
# interval is not affected by system clock adjustments during the run.
start = time.perf_counter()
inferrer.run()
print("time :", time.perf_counter() - start)

변형전 dtype float64
[  0.   1.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.  14.
  15.  16.  17.  18.  20.  21.  23.  24.  25.  26.  27.  28.  29.  30.
  31.  32.  33.  34.  36.  37.  39.  40.  41.  42.  43.  44.  45.  46.
  47.  48.  49.  50.  51.  52.  53.  54.  55.  56.  57.  58.  59.  60.
  61.  62.  63.  64.  65.  66.  67.  68.  69.  70.  71.  72.  73.  74.
  75.  76.  77.  78.  79.  80.  81.  82.  83.  84.  85.  86.  87.  88.
  89.  90.  91.  92.  93.  94.  95.  96.  97.  98.  99. 100. 101. 102.
 103. 104. 105. 106. 107. 108.]
변형후 dtype float32
[0.000e+00 2.000e+00 4.000e+00 5.000e+00 7.000e+00 8.000e+00 1.000e+01
 1.100e+01 1.200e+01 1.300e+01 1.400e+01 1.500e+01 1.600e+01 1.700e+01
 1.800e+01 2.400e+01 2.600e+01 2.800e+01 3.100e+01 4.100e+01 4.300e+01
 4.400e+01 4.600e+01 4.700e+01 4.900e+01 5.000e+01 5.100e+01 5.200e+01
 5.300e+01 5.400e+01 5.800e+01 6.000e+01 6.300e+01 7.700e+01 8.500e+01
 2.510e+02 2.520e+02 2.530e+02 2.540e+02 2.550e+02 1.000e+03 1.002e+03
 1.003e+0

### 잘 mapping되었는지 확인
테스트 이미지 로드하여 테스트 진행

In [14]:
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    MapLabelValued,
    Compose
)
from monai.data import (
    CacheDataset,
    DataLoader,
    partition_dataset,
)

def get_test_transform():
    """Return the transform pipeline used to load the "pred" entry.

    The pipeline is a ``Compose`` holding a single ``LoadImaged`` step on
    the ``"pred"`` key. Further steps (channel insertion, label mapping,
    re-orientation, preprocessing, casting) are deliberately not applied.
    """
    pred_keys = ["pred"]
    return Compose([LoadImaged(keys=pred_keys)])
# Instantiate the loading pipeline.
transform = get_test_transform()

# Single file to inspect from the inference output directory (presumably
# written by the inference run — confirm the file name for other subjects).
pred_file = os.path.join(infer_output_dir, "s_SU0197_00_0_b.nii.gz")
data = [{'pred': pred_file}]

# Cache the whole (one-item) dataset up front.
dataset = CacheDataset(data=data, transform=transform, num_workers=8, cache_rate=1.0)

data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4, drop_last=True)

Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.93it/s]


In [15]:
# Fetch the single prediction batch and show its keys, shape and dtype.
test_data = first(data_loader)
print(test_data.keys())
pred_volume = test_data['pred']
print(pred_volume.shape)
print(pred_volume.dtype)

dict_keys(['pred', 'pred_meta_dict'])
torch.Size([1, 186, 230, 230])
torch.float32


In [16]:
# Distinct label values in the loaded prediction, and how many there are.
unique_labels = np.unique(test_data['pred'])
len(unique_labels), unique_labels

(104,
 array([0.000e+00, 2.000e+00, 4.000e+00, 5.000e+00, 7.000e+00, 8.000e+00,
        1.000e+01, 1.100e+01, 1.200e+01, 1.300e+01, 1.400e+01, 1.500e+01,
        1.600e+01, 1.700e+01, 1.800e+01, 2.400e+01, 2.600e+01, 2.800e+01,
        3.100e+01, 4.100e+01, 4.300e+01, 4.400e+01, 4.600e+01, 4.700e+01,
        4.900e+01, 5.000e+01, 5.100e+01, 5.200e+01, 5.300e+01, 5.400e+01,
        5.800e+01, 6.000e+01, 6.300e+01, 7.700e+01, 8.500e+01, 2.510e+02,
        2.520e+02, 2.530e+02, 2.540e+02, 2.550e+02, 1.000e+03, 1.002e+03,
        1.003e+03, 1.005e+03, 1.006e+03, 1.007e+03, 1.008e+03, 1.009e+03,
        1.010e+03, 1.011e+03, 1.012e+03, 1.013e+03, 1.014e+03, 1.015e+03,
        1.016e+03, 1.017e+03, 1.018e+03, 1.019e+03, 1.020e+03, 1.021e+03,
        1.022e+03, 1.023e+03, 1.024e+03, 1.025e+03, 1.026e+03, 1.027e+03,
        1.028e+03, 1.029e+03, 1.030e+03, 1.031e+03, 1.034e+03, 1.035e+03,
        2.000e+03, 2.002e+03, 2.003e+03, 2.005e+03, 2.006e+03, 2.007e+03,
        2.008e+03, 2.009e+03, 2.

In [17]:
# Voxel indices where the prediction equals one specific label value.
# Original note: "2018, 93" — presumably the mapped label value and the
# class index it came from; confirm against the label mapping.
target_label = 2018
result = np.where(test_data['pred'] == target_label)
result

(array([0, 0, 0, ..., 0, 0, 0]),
 array([125, 125, 125, ..., 150, 150, 150]),
 array([132, 133, 133, ..., 134, 134, 134]),
 array([132, 131, 132, ..., 135, 136, 137]))