## Inference 4D dynunet pipeline with NeuroI ROI dataset
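* Test based on `inference.py` (`inference.py` 를 바탕으로 테스트)
* Brain image + mask image -> 4D input (the two images are stacked as channels)
* in_channels: 2 (brain image, mask image), out_channels: num_classes (109 with this config)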

In [10]:
import os
# from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import numpy as np

import torch
import torch.distributed as dist
from monai.inferers import SlidingWindowInferer
from torch.nn.parallel import DistributedDataParallel

# from create_dataset import get_data
from create_network import get_network_ke
from inferrer import DynUNetInferrer
# from task_params import patch_size, task_name
from monai.utils import first
from config import get_config
from dataset_roi_4d import get_test_loader

In [2]:
# Inputs for this inference run: training YAML config, trained checkpoint and the test-split id list.
config = "/data/kehyeong/project/MONAI_examples/dynunet_pipeline/config_roi_earlystop_toy_220209.yaml"
checkpoint = "/data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt"
test_dataset = "/work/NeuroI-models/ke-monai/data/roi/dataset_test_roi_toy2.csv"

# Directory that receives the inferrer's output files.
infer_output_dir = "./runs_inference"

# Loader / validation settings (val_num_workers is overwritten from the config in the next cell).
train_num_workers = 4  # num_workers of the training dataloader
val_num_workers = 1
interval = 5  # validation interval, in epochs
eval_overlap = 0.5  # overlap parameter of SlidingWindowInferer

spacing = [1.0, 1.0, 1.0]
deep_supr_num = 3
window_mode = "gaussian"  # blending mode of SlidingWindowInferer
# num_samples parameter of RandCropByPosNegLabeld; must be an int (was the float literal `3.`).
num_samples = 3
pos_sample_num = 1
neg_sample_num = 1
cache_rate = 1.0
amp = False
tta_val = False
multi_gpu = False
local_rank = 0

# Flag name used by the device / loader setup below.
multi_gpu_flag = multi_gpu

In [3]:


# Parse the YAML config; `config` is rebound from the file path to the parsed config.
config = get_config(config)
image_file_path = config["image_file_path"]
mask_file_path = config["mask_file_path"]
val_num_workers = config["val"]["num_workers"]
data_dir = config["data_dir"]
num_classes = config["num_classes"]
patch_size = config["patch_size"]
# sw_batch_size of SlidingWindowInferer: fixed to the value used during training.
sw_batch_size = config["val"]["batch_size"]

# exist_ok avoids the check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(data_dir, exist_ok=True)
os.makedirs(infer_output_dir, exist_ok=True)

if multi_gpu_flag:
    # One process per GPU, rendezvous through environment variables.
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    # Single-process inference is pinned to the CPU here; switch to torch.device("cuda") for a GPU.
    device = torch.device("cpu")

# Test loader: batch_size=1, so each batch is one whole volume.
test_loader = get_test_loader(
    data_dir=data_dir,
    id_file=test_dataset,
    image_file_pattern=image_file_path,
    mask_file_pattern=mask_file_path,
    batch_size=1,
    num_workers=val_num_workers,
    multi_gpu_flag=multi_gpu_flag
)

{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0454 00_0
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0454 01_d
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0303 00_0
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0303 01_d
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0197 00_0
{dir}/s_{id}_{aid}_b.nii.gz /data/train/running/l/input_augmented SU0197 01_d


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████| 6/6 [00:29<00:00,  4.96s/it]


In [4]:
# Peek at the first test batch: which keys it carries and the image tensor's shape and dtype.
test_data = first(test_loader)
image_batch = test_data["image"]
print(test_data.keys())
print("image shape:", image_batch.shape)
print("image dtype:", image_batch.dtype)

dict_keys(['image', 'mask', 'image_meta_dict', 'mask_meta_dict', 'image_transforms', 'mask_transforms', 'original_shape', 'bbox', 'crop_shape', 'resample_flag', 'anisotrophy_flag'])
image shape: torch.Size([1, 2, 134, 152, 132])
image dtype: torch.float32


In [7]:
# Properties handed to get_network_ke: input modality indices (two channels) and the label ids.
properties = dict(modality=[0, 1], labels=np.arange(num_classes))
in_channels = len(properties["modality"])
n_class = len(properties["labels"])
in_channels, n_class

(2, 109)

In [11]:
# produce the network
# Build the DynUNet and load the trained weights from `checkpoint`.
# Author's note: val_output_dir is not used for inference and should be removed from this call.
val_output_dir = "./runs_{}/".format("inference")
net = get_network_ke(properties, patch_size, spacing, deep_supr_num,
                     val_output_dir, checkpoint)
print('Loading nnUNET Done!!!')
net = net.to(device)
# Report the device actually used: in this run `device` may be the CPU, so "GPU devices" was misleading.
print(f'Loading nnUNET to {device} Done!!!')
print(net)

pretrained checkpoint: /data/train/running/l/model_roi_try1_220217/models/net_key_metric=0.3331.pt loaded
Loading nnUNET Done!!!
Loading nnUNET to GPU devices Done!!!
DynUNet(
  (input_block): UnetBasicBlock(
    (conv1): Convolution(
      (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (conv2): Convolution(
      (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
    (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (downsamples): ModuleList(
    (0): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv3d(64, 64, kernel_size=(3, 

In [12]:
# Wrap the model in DistributedDataParallel only in multi-GPU mode; otherwise `net` stays as it is.
net = (
    DistributedDataParallel(module=net, device_ids=[device], find_unused_parameters=True)
    if multi_gpu_flag
    else net
)

In [13]:
# Put the network in evaluation mode before inference.
net.eval()

# Sliding-window inference using the training patch size as the window.
window_inferer = SlidingWindowInferer(
    roi_size=patch_size,
    sw_batch_size=sw_batch_size,
    overlap=eval_overlap,
    mode=window_mode,
)

inferrer = DynUNetInferrer(
    device=device,
    val_data_loader=test_loader,
    network=net,
    output_dir=infer_output_dir,
    num_classes=len(properties["labels"]),
    inferer=window_inferer,
    amp=amp,
    tta_val=tta_val,
)
# inferrer.run() is called in the next cell.

In [14]:
# Run inference over test_loader; the inferrer logs a "save ..." line per case and writes its
# outputs under output_dir=infer_output_dir.
inferrer.run()

save s_SU0454_00_0_b.nii.gz with shape: (186, 230, 230), mean values: 3.6415143199788607
save s_SU0454_01_d_b.nii.gz with shape: (186, 230, 230), mean values: 4.753923918125089
save s_SU0303_00_0_b.nii.gz with shape: (186, 230, 230), mean values: 5.029365713356506
save s_SU0303_01_d_b.nii.gz with shape: (186, 230, 230), mean values: 4.6382436937211615
save s_SU0197_00_0_b.nii.gz with shape: (186, 230, 230), mean values: 3.925215053763441
save s_SU0197_01_d_b.nii.gz with shape: (186, 230, 230), mean values: 3.6964137040876475


In [None]:
# --- Training configuration (re-assigns several of the inference settings above) ---

# Input files: YAML config, train/val id lists and the training log file.
config = "/data/kehyeong/project/MONAI_examples/dynunet_pipeline/config_roi_earlystop_toy_220209.yaml"
train_dataset = "/work/NeuroI-models/ke-monai/data/roi/dataset_train_roi_toy2.csv"
val_dataset = "/work/NeuroI-models/ke-monai/data/roi/dataset_val_roi_toy2.csv"
log_file = "/data/train/running/l/model_roi_toy_220210/train.log"
checkpoint = None  # no pretrained checkpoint is passed to get_network_ke

# max_epochs, num_samples and the train/val num_workers are read from the config in the next cell.

# Run-specific hyper-parameters.
patch_size = [64, 64, 64]  # other sizes tried: [128, 128, 128], [96, 96, 96]
learning_rate = 1.0e-3  # 1e-1 was also noted as working
interval = 2
multi_gpu = False  # set to True for multi-GPU training
local_rank = 0

# Default settings -- normally no need to change these.
window_mode = "gaussian"  # "constant" or "gaussian"
eval_overlap = 0.5
tta_val = True
batch_dice = False
lr_decay_flag = False
spacing = [1.0, 1.0, 1.0]
deep_supr_num = 3
expr_name = "baseline"

In [None]:
# --- Training setup: parse the config, seed the RNGs, pick the device and build the data loaders. ---
import os

# GPU selection. CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialised in
# this process, so it is set first thing in the cell (it used to come after the seeding calls).
# If CUDA was already initialised earlier in the kernel, restart the kernel for this to apply.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"

import monai
from config import get_config
from dataset_roi_4d import get_train_loader, get_val_loader  # loaders for the 4D (multi-channel) input

multi_gpu_flag = multi_gpu
config = get_config(config)  # rebinds `config` from the YAML path to the parsed config
data_dir = config["data_dir"]
image_file_path = config["image_file_path"]
label_file_path = config["label_file_path"]
mask_file_path = config["mask_file_path"]
random_seed = config["random_seed"]
max_epochs = config["train"]["max_epoches"]  # key spelling must match the YAML file
num_classes = config["num_classes"]

# patch_size comes from the training-config cell, not from config["patch_size"].
lr = config["train"]["lr"]
train_batch_size = config["train"]["batch_size"]
train_num_samples = config["train"]["num_samples"]
train_num_workers = config["train"]["num_workers"]
val_batch_size = config["val"]["batch_size"]
sw_batch_size = val_batch_size  # sw_batch_size of the evaluator's SlidingWindowInferer
val_num_workers = config["val"]["num_workers"]
log_dir = config["log_dir"]
model_dir = config["model_dir"]
mlflow_dir = os.path.join(log_dir, "mlruns")

# Enable AMP on torch >= 1.6. The previous `(True if ... else False,)` had a trailing comma that
# built a 1-tuple, which is always truthy, so AMP was switched on regardless of the torch version.
amp_flag = monai.utils.get_torch_version_tuple() >= (1, 6)

monai.utils.set_determinism(random_seed)
np.random.seed(random_seed)

if multi_gpu_flag:
    # One process per GPU, rendezvous through environment variables.
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    device = torch.device("cuda")

#
# data loaders
#
train_loader = get_train_loader(
    data_dir=data_dir,
    id_file=train_dataset,
    image_file_pattern=image_file_path,
    label_file_pattern=label_file_path,
    mask_file_pattern=mask_file_path,
    batch_size=train_batch_size,
    patch_size=patch_size,
    num_samples=train_num_samples,
    num_workers=train_num_workers,
    multi_gpu_flag=multi_gpu_flag
)
val_loader = get_val_loader(
    data_dir=data_dir,
    id_file=val_dataset,
    image_file_pattern=image_file_path,
    label_file_pattern=label_file_path,
    mask_file_pattern=mask_file_path,
    batch_size=val_batch_size,
    num_workers=val_num_workers,
    multi_gpu_flag=multi_gpu_flag
)

In [None]:
# Display the number of classes read from config["num_classes"].
num_classes

In [None]:
# Inspect the first training batch. As the first message states, its batch dimension is expected
# to be num_samples x batch_size.
test_data = first(train_loader)
print(f"배치 수: num_sample({train_num_samples}) x batch_num({train_batch_size}) -> {train_num_samples*train_batch_size}")
print(test_data.keys())
images, labels = test_data['image'], test_data['label']
# Distinct label values present anywhere in this batch.
total_labels = np.unique(labels)
print("image shape:", images.shape)
print("label shape:", labels.shape)
print("image dtype:", images.dtype)
print("label dtype:", labels.dtype)
print("1번 배치의 유니크한 라벨 리스트:", total_labels)
print(f'1번 배치의 유니크한 라벨 class 수: {len(total_labels)}')

In [None]:
# 배치 1번의 전체 라벨의 shape과 유니크한 라벨 수 조사.
for each in test_data['label']:
    print(each.shape)
    total_labels = np.unique(each)
    print(f'class 수 {len(total_labels)}')

In [None]:
# Properties handed to get_network_ke: input modality indices (two channels) and the label ids.
properties = dict(modality=[0, 1], labels=np.arange(num_classes))
in_channels = len(properties["modality"])
n_class = len(properties["labels"])
in_channels, n_class

In [None]:
# --- Build the network, optimizer, evaluator and trainer; configure logging; train. ---
# The handler / loss / engine classes below were never imported anywhere in this notebook, so the
# cell raised NameError. evaluator.py / trainer.py follow the MONAI dynunet_pipeline layout
# (next to inferrer.py) -- confirm they exist in this project.
import logging
import os

from monai.handlers import (
    CheckpointSaver,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    ValidationHandler,
    from_engine,
)
from monai.inferers import SimpleInferer
from monai.losses import DiceCELoss

from evaluator import DynUNetEvaluator
from trainer import DynUNetTrainer

# produce the network (`checkpoint` optionally supplies weights; None in the training config)
val_output_dir = "./runs_fold{}_{}/".format(1, expr_name)
net = get_network_ke(properties, patch_size, spacing, deep_supr_num,
                     val_output_dir, checkpoint)
net = net.to(device)
print(net)

if multi_gpu_flag:
    net = DistributedDataParallel(module=net, device_ids=[device])

# NOTE(review): the optimizer uses `learning_rate` from the training-config cell; the `lr` value read
# from the YAML config is not used -- confirm which one is intended.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=learning_rate,
    momentum=0.99,
    weight_decay=3e-5,
    nesterov=True,
)

# Multiplies the initial lr by (1 - epoch / max_epochs) ** 0.9; only attached when lr_decay_flag is set.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9
)

# produce evaluator: CheckpointSaver stores `net` whenever the key metric (val_mean_dice) improves.
val_handlers = [
    StatsHandler(output_transform=lambda x: None),
    CheckpointSaver(
        save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True
    ),
]

evaluator = DynUNetEvaluator(
    device=device,
    val_data_loader=val_loader,
    network=net,
    num_classes=len(properties["labels"]),
    inferer=SlidingWindowInferer(
        roi_size=patch_size,
        sw_batch_size=sw_batch_size,
        overlap=eval_overlap,
        mode=window_mode,
    ),
    postprocessing=None,
    key_val_metric={
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=from_engine(["pred", "label"]),
        )
    },
    val_handlers=val_handlers,
    amp=amp_flag,
    tta_val=tta_val,
)

# produce trainer
loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
train_handlers = []
if lr_decay_flag:
    train_handlers += [LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)]

train_handlers += [
    # Run the evaluator every `interval` epochs.
    ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
    StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
]

trainer = DynUNetTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_loader,
    network=net,
    optimizer=optimizer,
    loss_function=loss,
    inferer=SimpleInferer(),
    postprocessing=None,
    key_train_metric=None,
    train_handlers=train_handlers,
    amp=amp_flag,
)

# Non-zero ranks only emit warnings from the engines.
if local_rank > 0:
    evaluator.logger.setLevel(logging.WARNING)
    trainer.logger.setLevel(logging.WARNING)

# Root logger: INFO-level records go both to `log_file` and to the console.
logger = logging.getLogger()

formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# logging.FileHandler raises FileNotFoundError when the directory is missing, so create it first.
log_dirname = os.path.dirname(log_file)
if log_dirname:
    os.makedirs(log_dirname, exist_ok=True)

# Setup file handler
fhandler = logging.FileHandler(log_file)
fhandler.setLevel(logging.INFO)
fhandler.setFormatter(formatter)
logger.addHandler(fhandler)

chandler = logging.StreamHandler()
chandler.setLevel(logging.INFO)
chandler.setFormatter(formatter)
logger.addHandler(chandler)

logger.setLevel(logging.INFO)

trainer.run()

In [None]:
# Re-fetch the first training batch and dump its keys, tensor shapes/dtypes and label values.
test_data = first(train_loader)
print(test_data.keys())
print('image, label shape')
for key in ('image', 'label'):
    print(test_data[key].shape)
for key in ('image', 'label'):
    print(test_data[key].dtype)
unique_labels = np.unique(test_data['label'])
print(unique_labels)
print(len(unique_labels))

In [None]:
# Compare the label voxel count of one batch with the expected count. The old hard-coded
# 24*96*96*96 no longer matched patch_size ([64, 64, 64] in the training-config cell).
# Expected = (num_samples x batch_size) crops of patch_size, assuming one label channel -- TODO confirm.
test_data['label'].numel(), train_num_samples * train_batch_size * int(np.prod(patch_size))

In [None]:
import matplotlib.pyplot as plt

In [None]:
H = 60  # slice index along the last axis that is displayed

test_data = first(train_loader)
# First sample, first channel of the image and of the label tensor.
image = test_data["image"][0][0]
label = test_data["label"][0][0]
print(test_data['image'].shape, test_data['label'].shape)
print(f"image shape: {image.shape}, label shape: {label.shape}")
print(f"image dtype: {image.dtype}, label dtype: {label.dtype}")

# Show slice [:, :, H]: grayscale image on the left, label map (default colormap) on the right.
plt.figure("check", (15, 10))
for position, (panel_title, volume, colormap) in enumerate(
    (("image", image, "gray"), ("label", label, None)), start=1
):
    plt.subplot(1, 2, position)
    plt.title(panel_title)
    plt.imshow(volume[:, :, H], cmap=colormap)
plt.show()

print(np.unique(label))

In [None]:
# Display the distinct values in `label` together with the full label tensor.
np.unique(label), label

In [None]:
# Display the image tensor (first sample, first channel of the batch).
image

In [None]:
# Intensity histogram of `image`: 256 bins spanning its minimum to maximum value.
pixel_min = image.min().item()
pixel_max = image.max().item()
print(pixel_min, pixel_max)
histogram, bin_edges = np.histogram(image, bins=256, range=(pixel_min, pixel_max))

plt.figure()
plt.title("Grayscale Histogram")
plt.xlabel("grayscale value")
plt.ylabel("pixel count")
plt.xlim([pixel_min, pixel_max])
# bin_edges has one more entry than histogram; plot each count at its bin's left edge.
plt.plot(bin_edges[:-1], histogram)
plt.show()

In [None]:
# Replace negative intensities with the constant 2 and compare with the unmodified image.
# NOTE(review): the meaning of the fill value 2 is not documented here -- confirm intent.
new_img = np.where(image < 0, 2, image)

plt.figure("check", (15, 10))
# Each title is set after its plt.subplot() call. Calling plt.title() before the first subplot put
# the title on a separate full-figure axes instead of on the first panel.
plt.subplot(1, 2, 1)
plt.title("new_img (negatives -> 2)")
plt.imshow(new_img[:, :, H], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("image")
plt.imshow(image[:, :, H], cmap="gray")

In [None]:
# Raw values of slice H of the modified image.
new_img[:, :, H]

In [None]:
# Shape of the full image batch tensor.
test_data['image'].shape

In [None]:
# Raw label values on slice H.
label[:, :, H]

In [None]:
# Raw image values on slice H.
image[:, :, H]

In [None]:
# Keys available in the batch dictionary.
test_data.keys()

In [None]:
# Metadata dictionary the data loader attached to the image.
test_data['image_meta_dict']