## DynUNet pipeline with Medical Decathlon
* Medical Decathlon을 이용한 4D multi-class segmentation. -> 동작함.
* 아래 파이프라인을 바탕으로 현재 NEUROI ROI 데이터셋을 태워서 뭐가 문제가 있는지 확인해 보자.

In [None]:
import logging
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import torch
import torch.distributed as dist
from monai.config import print_config
from monai.handlers import (
    CheckpointSaver,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    ValidationHandler,
    from_engine,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.utils import set_determinism
from torch.nn.parallel import DistributedDataParallel

from create_dataset import get_data_ke
from create_network import get_network
from evaluator import DynUNetEvaluator
from task_params import data_loader_params, patch_size
from trainer import DynUNetTrainer

from monai.utils import first
import numpy as np


In [None]:

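# Reference command line of the script version of this pipeline; the settings
# in the next cell mirror these flags (multi_gpu is switched off there).
# 	train.py -fold $fold -train_num_workers 4 -interval 10 -num_samples 2 \
# 	-learning_rate $lr -max_epochs 3000 -task_id 01 -pos_sample_num 1 \
# 	-expr_name baseline -tta_val True -multi_gpu True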

In [None]:
# Notebook-wide settings (stand-ins for the command-line flags of train.py).

# --- Experiment identity and data locations ---
expr_name = "baseline"
task_id = "01"
fold = 0
root_dir = "/data/kehyeong/project/MONAI_examples/data/brats/"
datalist_path = "config/"

# --- Run mode and optimisation ---
mode = "train"
checkpoint = None
max_epochs = 3000
learning_rate = 0.1
lr_decay = False
amp = False
batch_dice = False

# --- Data loading and patch sampling ---
train_num_workers = 4
val_num_workers = 1
cache_rate = 1.0
num_samples = 2
pos_sample_num = 1
neg_sample_num = 1

# --- Validation: run interval plus the sliding-window inferer settings ---
interval = 10
tta_val = True
sw_batch_size = 4
eval_overlap = 0.5
window_mode = "gaussian"

# --- Reproducibility ---
determinism_flag = False
determinism_seed = 0

# --- Devices ---
local_rank = 0
multi_gpu = False  # True

In [None]:
# Run setup: derive output paths, pick the device, build the data loaders and
# peek at one training batch. Reads the settings defined in the config cell.

# Output directory for this run and the log file stored inside it.
val_output_dir = f"./runs_{task_id}_fold{fold}_{expr_name}/"
log_filename = os.path.join(val_output_dir, f"nnunet_task{task_id}_fold{fold}.log")

# Aliases used by the later cells (kept from the script version's naming).
multi_gpu_flag = multi_gpu
amp_flag = amp
lr_decay_flag = lr_decay

if determinism_flag:
    set_determinism(seed=determinism_seed)
    if local_rank == 0:
        print("Using deterministic training.")

# Training batch size comes from the per-task table in task_params.
train_batch_size = data_loader_params[task_id]["batch_size"]
if multi_gpu_flag:
    dist.init_process_group(backend="nccl", init_method="env://")

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    # NOTE(review): the GPU index is hard-coded; this needs at least four
    # visible CUDA devices. Adjust for the machine at hand.
    device = torch.device("cuda:3")

print(device)

# Positional arguments shared by both loader requests, in the order the
# original calls passed them.
# NOTE(review): order must match get_data_ke's signature - confirm in create_dataset.
loader_args = (
    fold,
    task_id,
    root_dir,
    datalist_path,
    pos_sample_num,
    neg_sample_num,
    num_samples,
    multi_gpu,
    val_num_workers,
    cache_rate,
    train_num_workers,
)
properties, val_loader = get_data_ke(*loader_args, mode="validation")
_, train_loader = get_data_ke(
    *loader_args, batch_size=train_batch_size, mode="train"
)

# Sanity check on one training batch: keys, tensor shapes/dtypes, label values.
test_data = first(train_loader)
print(test_data.keys())
print('image, label shape')
print(test_data['image'].shape)
print(test_data['label'].shape)
print(test_data['image'].dtype)
print(test_data['label'].dtype)
print(np.unique(test_data['label']))

In [None]:
# Display the properties object returned by the validation-mode get_data_ke call.
properties

In [None]:
# For every label entry of the sampled batch, print its shape and the number
# of distinct label values it contains.
for label_volume in test_data['label']:
    print(label_volume.shape)
    print(f'class 수 {len(np.unique(label_volume))}')

In [None]:
# Build the network, optimizer, evaluator and trainer, configure logging, and
# start training. Relies on the settings, device, loaders and `properties`
# created by the earlier cells.

# produce the network
net = get_network(properties, task_id, val_output_dir, checkpoint)
net = net.to(device)

if multi_gpu_flag:
    net = DistributedDataParallel(module=net, device_ids=[device])

optimizer = torch.optim.SGD(
    net.parameters(),
    lr=learning_rate,
    momentum=0.99,
    weight_decay=3e-5,
    nesterov=True,
)

# Polynomial decay of the learning rate over max_epochs; only stepped when
# lr_decay_flag enables the LrScheduleHandler below.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9
)
# produce evaluator
val_handlers = [
    StatsHandler(output_transform=lambda x: None),
    CheckpointSaver(
        save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True
    ),
]

evaluator = DynUNetEvaluator(
    device=device,
    val_data_loader=val_loader,
    network=net,
    num_classes=len(properties["labels"]),
    inferer=SlidingWindowInferer(
        roi_size=patch_size[task_id],
        sw_batch_size=sw_batch_size,
        overlap=eval_overlap,
        mode=window_mode,
    ),
    postprocessing=None,
    key_val_metric={
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=from_engine(["pred", "label"]),
        )
    },
    val_handlers=val_handlers,
    amp=amp_flag,
    tta_val=tta_val,
)

# produce trainer
loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
train_handlers = []
if lr_decay_flag:
    train_handlers += [LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)]

train_handlers += [
    ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
    StatsHandler(
        tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
    ),
]

trainer = DynUNetTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_loader,
    network=net,
    optimizer=optimizer,
    loss_function=loss,
    inferer=SimpleInferer(),
    postprocessing=None,
    key_train_metric=None,
    train_handlers=train_handlers,
    amp=amp_flag,
)

# Only rank 0 logs at INFO level; other ranks are quietened.
if local_rank > 0:
    evaluator.logger.setLevel(logging.WARNING)
    trainer.logger.setLevel(logging.WARNING)

logger = logging.getLogger()

formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Re-running this cell would otherwise stack another file/console handler pair
# on the root logger, duplicating every log line and leaking file handles.
# Handlers installed here are named so that an earlier pair can be removed.
file_handler_name = "dynunet_notebook_file"
console_handler_name = "dynunet_notebook_console"
stale_handlers = [
    h
    for h in logger.handlers
    if h.get_name() in (file_handler_name, console_handler_name)
]
for stale_handler in stale_handlers:
    logger.removeHandler(stale_handler)
    stale_handler.close()

# Setup file handler; make sure the run directory exists before opening the log.
os.makedirs(val_output_dir, exist_ok=True)
fhandler = logging.FileHandler(log_filename)
fhandler.set_name(file_handler_name)
fhandler.setLevel(logging.INFO)
fhandler.setFormatter(formatter)

logger.addHandler(fhandler)

# Setup console handler
chandler = logging.StreamHandler()
chandler.set_name(console_handler_name)
chandler.setLevel(logging.INFO)
chandler.setFormatter(formatter)
logger.addHandler(chandler)

logger.setLevel(logging.INFO)

trainer.run()

In [None]:
# Fetch one training batch and print its keys, the image/label shapes, the
# image/label dtypes, and the distinct label values.
test_data = first(train_loader)
print(test_data.keys())
print('image, label shape')
for attribute in ("shape", "dtype"):
    for key in ("image", "label"):
        print(getattr(test_data[key], attribute))
print(np.unique(test_data['label']))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Index along the last axis of the slice drawn in both panels.
H = 92

test_data = first(train_loader)
# Select element [0][0] of the image and label tensors of the batch.
image = test_data["image"][0][0]
label = test_data["label"][0][0]
print(test_data['image'].shape, test_data['label'].shape)
print(f"image shape: {image.shape}, label shape: {label.shape}")
print(f"image dtype: {image.dtype}, label dtype: {label.dtype}")

# Draw slice [:, :, H] of the image (grayscale) next to the same slice of the label.
plt.figure("check", (15, 10))
panels = (("image", image, {"cmap": "gray"}), ("label", label, {}))
for panel_index, (panel_title, volume, imshow_options) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.title(panel_title)
    plt.imshow(volume[:, :, H], **imshow_options)
plt.show()

print(np.unique(label))

In [None]:
# Display the dtype of the label selected in the plotting cell.
label.dtype

In [None]:
# Echo the selected image to inspect its raw values.
image

In [None]:
# Intensity histogram of the selected image over its full value range.
pixel_min = image.min().item()
pixel_max = image.max().item()
print(pixel_min, pixel_max)

value_range = (pixel_min, pixel_max)
histogram, bin_edges = np.histogram(image, bins=256, range=value_range)

plt.figure()
plt.title("Grayscale Histogram")
plt.xlabel("grayscale value")
plt.ylabel("pixel count")
plt.xlim([pixel_min, pixel_max])
# np.histogram returns one more edge than bins, so the counts are plotted
# against the left edge of each bin.
plt.plot(bin_edges[:-1], histogram)
plt.show()

In [None]:
# Build a copy of the image in which every value below 0.001 is replaced by
# the constant 9; all other values are kept. Then display the copy's shape.
low_value_mask = image < 0.001
new_img = np.where(low_value_mask, 9, image)
new_img.shape

In [None]:
# Side-by-side view of slice [:, :, H]: the thresholded copy (left) and the
# original image (right).
plt.figure("check", (15, 10))
# Use a figure-level title: calling plt.title() before any subplot exists
# attaches the title to an implicit full-figure axes instead of the panels.
plt.suptitle("image")
plt.subplot(1, 2, 1)
plt.imshow(new_img[:, :, H], cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(image[:, :, H], cmap="gray")
plt.show()

In [None]:
# Echo the thresholded copy to inspect the replaced values.
new_img