In [46]:
import sys
sys.path.append("../../../0_CNN_total_Pytorch_new")
import os
from glob import glob
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from glob import glob
import ast
import cv2
from matplotlib import pyplot as plt
from src.data_set.segmentation import SegDataset
from src.data_set.utils import read_json_as_dict
import random
import pandas as pd
import SimpleITK as sitk
from torch import nn
from natsort import natsorted
from utils import read_dicom_series_to_array, get_dicom_series_shape, read_nii_to_array, get_nii_shape
from utils import resize_dicom_series, write_series_to_path
from utils import get_parent_dir_name
from volumentations import Compose, ElasticTransform, RandomGamma, GaussianNoise, Transpose, Flip, RandomRotate90, GlassBlur, RandomCrop, GridDropout

from itertools import chain
import deepspeed
import math
from copy import deepcopy
from src.loss.seg_loss import get_dice_score, get_loss_fn, get_bce_loss_class, accuracy_metric
from src.model.inception_resnet_v2.multi_task.multi_task_3d import InceptionResNetV2MultiTask3D
from src.util.deepspeed import get_deepspeed_config_dict, average_across_gpus, toggle_grad, load_deepspeed_model_to_torch_model
from src.util.common import set_dropout_probability
import torch.nn.functional as F
from src.model.inception_resnet_v2.multi_task.multi_task_3d import InceptionResNetV2MultiTask3D
from src.model.train_util.logger import CSVLogger
import torch.distributed as dist
import pandas as pd
import csv
import nibabel as nib
import SimpleITK as sitk

In [51]:
def save_array_as_nifti(array, output_path, reverse=True):
    """Convert ``array`` to a SimpleITK image and write it to ``output_path``.

    Args:
        array: Volume handed to ``sitk.GetImageFromArray``.
        output_path: Destination path handed to ``sitk.WriteImage``.
        reverse: When truthy, the array is flipped along its first axis
            before conversion.

    The image geometry is filled with placeholder values (unit spacing,
    zero origin, identity direction); no real geometry is available here.
    """
    volume = array[::-1] if reverse else array

    # Convert the (possibly flipped) array into a SimpleITK image.
    image = sitk.GetImageFromArray(volume)

    # Placeholder geometry: arbitrary spacing/origin and an identity direction matrix.
    unit_spacing = (1.0, 1.0, 1.0)
    zero_origin = (0.0, 0.0, 0.0)
    identity_direction = (1.0, 0.0, 0.0,
                          0.0, 1.0, 0.0,
                          0.0, 0.0, 1.0)
    image.SetSpacing(unit_spacing)
    image.SetOrigin(zero_origin)
    image.SetDirection(identity_direction)

    # Write the image to disk.
    sitk.WriteImage(image, output_path)

def full_data_collate_fn(batch):
    """Collate ``(image, mask)`` samples without stacking them.

    The samples can differ in shape, so instead of building one batched
    tensor the images and the masks are returned as two separate tuples.

    Args:
        batch: Iterable of two-element samples.

    Returns:
        ``(images, masks)``, each a tuple with one entry per sample.
    """
    per_field = tuple(zip(*batch))
    images, masks = per_field
    return images, masks

def apply_augmentation(transform, image_array, mask_array):
    """Apply ``transform`` jointly to an image and its mask.

    Args:
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries.
        image_array: Image passed as the ``image`` keyword.
        mask_array: Mask passed as the ``mask`` keyword.

    Returns:
        ``(augmented_image, augmented_mask)``.
    """
    augmented = transform(image=image_array, mask=mask_array)
    return tuple(augmented[key] for key in ("image", "mask"))

def get_augmentation():
    """Build the augmentation pipeline (a volumentations ``Compose``).

    Active transforms, in order: Gaussian noise, a flip for each of the axes
    0, 1 and 2, and a 90-degree rotation over axes (1, 2); each is applied
    with probability 0.5.
    """
    # Previously tried and currently disabled: GridDropout, ElasticTransform,
    # GlassBlur, RandomGamma and Transpose.
    enabled_transforms = [
        GaussianNoise(var_limit=(0, 5), p=0.5),
        Flip(0, p=0.5),
        Flip(1, p=0.5),
        Flip(2, p=0.5),
        RandomRotate90((1, 2), p=0.5),
    ]
    return Compose(enabled_transforms, p=1.0)

def read_image(image_path):
    """Read ``image_path`` with OpenCV in grayscale mode and return the result.

    The value returned by ``cv2.imread`` is passed through without validation.
    """
    return cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

def image_preprocess(image_array):
    """Clip, rescale and add a leading axis to an image array.

    Values are clipped to [0, 255], mapped to [0, 1], divided by 255 once
    more, cast to float32, and returned with a new leading axis of size 1.
    """
    lower, upper = 0, 255
    clipped = image_array.clip(lower, upper)
    unit_scaled = (clipped - lower) / (upper - lower)
    # NOTE(review): this second division by 255 shrinks the range to
    # [0, 1/255]. It looks unintended, but it is kept because existing
    # checkpoints were presumably trained with it - confirm before changing.
    rescaled = unit_scaled / 255
    return rescaled.astype("float32")[None]

def mask_preprocess(mask_array):
    """Divide the mask by 255 and return it as a float32 array."""
    return (mask_array / 255).astype("float32")

def get_partial_3d_model(target_z_dim, num_classes, get_class, get_recon):
    """Instantiate ``InceptionResNetV2MultiTask3D`` with this notebook's settings.

    Args:
        target_z_dim: Depth used in the model's ``input_shape``
            ``(1, target_z_dim, 512, 512)``.
        num_classes: Used for both ``class_channel`` and ``seg_channels``.
        get_class: Forwarded as the model's ``get_class`` flag.
        get_recon: Forwarded as the model's ``get_recon`` flag.

    Returns:
        The constructed model.
    """
    model_kwargs = dict(
        input_shape=(1, target_z_dim, 512, 512),
        class_channel=num_classes,
        seg_channels=num_classes,
        validity_shape=(1, 8, 8, 8),
        inject_class_channel=None,
        block_size=6,
        decode_init_channel=None,
        skip_connect=True,
        dropout_proba=0.05,
        # NOTE(review): "instace" is passed through unchanged; it may be a
        # typo for "instance" - confirm which spelling the model expects.
        norm="instace",
        act="relu6",
        class_act="softmax",
        seg_act="softmax",
        validity_act="sigmoid",
        get_seg=True,
        get_class=get_class,
        get_recon=get_recon,
        get_validity=False,
        use_class_head_simple=True,
        use_decode_pixelshuffle_only=False,
        use_decode_simpleoutput=False,
    )
    return InceptionResNetV2MultiTask3D(**model_kwargs)
def get_log_path(test_model, loss_select, batch_size, target_z_dim, current_fold,
                 use_class=None, use_recon=None):
    """Build the run's log directory path and make sure it exists on disk.

    The path has the form
    ``./result/3d_{test_model}_{loss_select}_{batch_size}_{target_z_dim}``
    followed by optional ``_class`` / ``_recon`` suffixes and
    ``_fold_{current_fold}``. A ``weights`` sub-directory is created inside it.

    Args:
        test_model: Model name used in the path.
        loss_select: Loss name used in the path.
        batch_size: Batch size used in the path.
        target_z_dim: Depth value used in the path.
        current_fold: Fold index used in the path.
        use_class: Whether to append ``_class``. When ``None`` (default) the
            module-level ``get_class`` setting is used, as before.
        use_recon: Whether to append ``_recon``. When ``None`` (default) the
            module-level ``get_recon`` setting is used, as before.

    Returns:
        The log directory path as a string.
    """
    # Fall back to the notebook-level flags so existing five-argument calls
    # keep producing the same path.
    if use_class is None:
        use_class = get_class
    if use_recon is None:
        use_recon = get_recon

    log_path = f"./result/3d_{test_model}_{loss_select}_{batch_size}_{target_z_dim}"
    if use_class:
        log_path = f"{log_path}_class"
    if use_recon:
        log_path = f"{log_path}_recon"
    log_path = f"{log_path}_fold_{current_fold}"
    os.makedirs(f"{log_path}/weights", exist_ok=True)
    return log_path

def get_best_dice_epoch(csv_path, select_mode="dice_score", min_epoch=5):
    """Pick the best epoch from a per-epoch training log stored as CSV.

    Every column of the CSV is parsed as float; row ``i`` (0-based) is
    treated as epoch ``i + 1``.

    Args:
        csv_path: Path to a CSV with at least ``val_loss`` and
            ``val_dice_score`` columns (plus ``dice_score`` for
            ``"dice_score_diff"``).
        select_mode: One of
            ``"loss"`` - epoch with the lowest ``val_loss``;
            ``"dice_score"`` - epoch with the highest ``val_dice_score``;
            ``"dice_score_diff"`` - epoch maximising ``val_dice_score`` minus
            the positive part of ``dice_score - val_dice_score``, ignoring
            the first ``min_epoch`` epochs.
        min_epoch: Number of leading epochs skipped in ``"dice_score_diff"``
            mode (default 5, the previously hard-coded value).

    Returns:
        ``(epoch, val_loss, val_dice_score)`` of the selected epoch, with
        ``epoch`` 1-based.

    Raises:
        AssertionError: If ``select_mode`` is not a supported mode.
        ValueError: If the log has no rows, or has no rows after skipping
            ``min_epoch`` epochs in ``"dice_score_diff"`` mode.
        KeyError: If a required column is missing.
    """
    select_mode_list = ["loss", "dice_score", "dice_score_diff"]
    assert select_mode in select_mode_list, f"check your select_mode: {select_mode} in {select_mode_list}"

    # newline="" is what the csv module expects for files it parses.
    with open(csv_path, newline="") as csv_file:
        reader = csv.DictReader(csv_file)
        field_names = reader.fieldnames or []
        dict_from_csv = {field_name: [] for field_name in field_names}
        for row in reader:
            for field_name in field_names:
                dict_from_csv[field_name].append(float(row[field_name]))

    if not field_names or not dict_from_csv[field_names[0]]:
        raise ValueError(f"no epochs found in log file: {csv_path}")

    val_loss = np.asarray(dict_from_csv["val_loss"], dtype=float)
    val_dice = np.asarray(dict_from_csv["val_dice_score"], dtype=float)

    if select_mode == "loss":
        best_idx = int(np.argmin(val_loss))
    elif select_mode == "dice_score":
        best_idx = int(np.argmax(val_dice))
    else:  # "dice_score_diff"
        if len(val_dice) <= min_epoch:
            raise ValueError(
                f"log has {len(val_dice)} epochs; need more than min_epoch={min_epoch} "
                "for select_mode='dice_score_diff'"
            )
        train_dice = np.asarray(dict_from_csv["dice_score"], dtype=float)
        # Penalise only epochs where the training score exceeds the validation score.
        overfit_gap = np.maximum(train_dice - val_dice, 0)
        penalised_score = (val_dice - overfit_gap)[min_epoch:]
        best_idx = int(np.argmax(penalised_score)) + min_epoch

    return best_idx + 1, float(val_loss[best_idx]), float(val_dice[best_idx])
    
class SegDataset(Dataset):
    """Dataset yielding ``(image, mask)`` tensors read from per-case NIfTI files.

    Each case folder is expected to contain ``nii.gz/image.nii.gz`` and
    ``nii.gz/mask.nii.gz``. Volumes are cropped or zero-padded along the
    first (slice) axis, augmented, and preprocessed.

    NOTE: this name is also imported from ``src.data_set.segmentation`` and an
    identical class is defined again later in this file; the last definition
    executed is the one bound to ``SegDataset``.
    """

    def __init__(self, data_folder_list, z_dim_list, target_z_dim, use_full):
        """Store the case list and cropping configuration.

        Args:
            data_folder_list: Case folders, one per sample.
            z_dim_list: Slice count of each case, aligned with
                ``data_folder_list``.
            target_z_dim: Number of slices per sample when cropping/padding.
            use_full: When true, volumes with at least ``target_z_dim``
                slices are returned whole instead of randomly cropped.
        """
        super().__init__()
        self.data_folder_list = data_folder_list
        self.z_dim_list = z_dim_list

        self.target_z_dim = target_z_dim
        self.use_full = use_full
        self.transform = get_augmentation()

    def __len__(self):
        """Return the number of cases."""
        return len(self.data_folder_list)

    def __getitem__(self, idx):
        """Load, augment and preprocess case ``idx``; return two tensors."""
        data_folder, z_dim = self.data_folder_list[idx], self.z_dim_list[idx]
        image_path = f"{data_folder}/nii.gz/image.nii.gz"
        mask_path = f"{data_folder}/nii.gz/mask.nii.gz"
        image_array, mask_array = self.get_array_from_folder(image_path, mask_path, z_dim)
        # NOTE(review): augmentation is applied unconditionally, including
        # when use_full is set - confirm this is intended for evaluation.
        image_array, mask_array = apply_augmentation(self.transform, image_array, mask_array)
        image_array = image_preprocess(image_array)
        mask_array = mask_preprocess(mask_array)

        return torch.tensor(image_array), torch.tensor(mask_array)

    def get_array_from_folder(self, image_path, mask_path, z_dim):
        """Read image and mask volumes and crop/pad them along the slice axis.

        * ``target_z_dim > z_dim``: all slices are kept and zero padding is
          split between both ends to reach ``target_z_dim`` slices.
        * ``use_full``: all slices are kept, no padding.
        * otherwise: a random contiguous window of ``target_z_dim`` slices is
          taken; every valid start, including the last one, can be drawn.

        Args:
            image_path: Path of the image NIfTI file.
            mask_path: Path of the mask NIfTI file.
            z_dim: Slice count of the stored volume (taken from metadata;
                assumed to match the file - an out-of-range index raises).

        Returns:
            ``(image_array, mask_array)`` with identical slice selection.
        """
        target_z_dim = self.target_z_dim

        if target_z_dim > z_dim:
            z_start, z_stop = 0, z_dim
            padding_num = target_z_dim - z_dim
        elif self.use_full:
            z_start, z_stop = 0, z_dim
            padding_num = 0
        else:
            # "+ 1" keeps the start range non-empty when z_dim == target_z_dim
            # and makes the final window reachable.
            z_start = random.randrange(z_dim - target_z_dim + 1)
            z_stop = z_start + target_z_dim
            padding_num = 0
        top_pad_num = padding_num // 2
        bottom_pad_num = padding_num - top_pad_num

        z_indices = list(range(z_start, z_stop))
        pad_width = [(top_pad_num, bottom_pad_num), (0, 0), (0, 0)]

        image_array = self.read_nii_gz(image_path)[z_indices]
        mask_array = self.read_nii_gz(mask_path)[z_indices]
        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=0)
        mask_array = np.pad(mask_array, pad_width, mode="constant", constant_values=0)
        return image_array, mask_array

    def read_nii_gz(self, nii_path):
        """Load a NIfTI file and move its third axis to the front.

        The transposed axis is presumably the slice axis - confirm against
        how the files were written.
        """
        image_obj = nib.load(nii_path)
        image_array = image_obj.get_fdata()
        image_array = image_array.transpose(2, 0, 1)
        return image_array

class SegDataset(Dataset):
    """Dataset yielding ``(image, mask)`` tensors read from per-case NIfTI files.

    Each case folder is expected to contain ``nii.gz/image.nii.gz`` and
    ``nii.gz/mask.nii.gz``. Volumes are cropped or zero-padded along the
    first (slice) axis, augmented, and preprocessed.

    NOTE: this definition rebinds ``SegDataset``, replacing both the class
    imported from ``src.data_set.segmentation`` and the identical class
    defined earlier in this file.
    """

    def __init__(self, data_folder_list, z_dim_list, target_z_dim, use_full):
        """Store the case list and cropping configuration.

        Args:
            data_folder_list: Case folders, one per sample.
            z_dim_list: Slice count of each case, aligned with
                ``data_folder_list``.
            target_z_dim: Number of slices per sample when cropping/padding.
            use_full: When true, volumes with at least ``target_z_dim``
                slices are returned whole instead of randomly cropped.
        """
        super().__init__()
        self.data_folder_list = data_folder_list
        self.z_dim_list = z_dim_list

        self.target_z_dim = target_z_dim
        self.use_full = use_full
        self.transform = get_augmentation()

    def __len__(self):
        """Return the number of cases."""
        return len(self.data_folder_list)

    def __getitem__(self, idx):
        """Load, augment and preprocess case ``idx``; return two tensors."""
        data_folder, z_dim = self.data_folder_list[idx], self.z_dim_list[idx]
        image_path = f"{data_folder}/nii.gz/image.nii.gz"
        mask_path = f"{data_folder}/nii.gz/mask.nii.gz"
        image_array, mask_array = self.get_array_from_folder(image_path, mask_path, z_dim)
        # NOTE(review): augmentation is applied unconditionally, including
        # when use_full is set - confirm this is intended for evaluation.
        image_array, mask_array = apply_augmentation(self.transform, image_array, mask_array)
        image_array = image_preprocess(image_array)
        mask_array = mask_preprocess(mask_array)

        return torch.tensor(image_array), torch.tensor(mask_array)

    def get_array_from_folder(self, image_path, mask_path, z_dim):
        """Read image and mask volumes and crop/pad them along the slice axis.

        * ``target_z_dim > z_dim``: all slices are kept and zero padding is
          split between both ends to reach ``target_z_dim`` slices.
        * ``use_full``: all slices are kept, no padding.
        * otherwise: a random contiguous window of ``target_z_dim`` slices is
          taken; every valid start, including the last one, can be drawn.

        Args:
            image_path: Path of the image NIfTI file.
            mask_path: Path of the mask NIfTI file.
            z_dim: Slice count of the stored volume (taken from metadata;
                assumed to match the file - an out-of-range index raises).

        Returns:
            ``(image_array, mask_array)`` with identical slice selection.
        """
        target_z_dim = self.target_z_dim

        if target_z_dim > z_dim:
            z_start, z_stop = 0, z_dim
            padding_num = target_z_dim - z_dim
        elif self.use_full:
            z_start, z_stop = 0, z_dim
            padding_num = 0
        else:
            # "+ 1" keeps the start range non-empty when z_dim == target_z_dim
            # and makes the final window reachable.
            z_start = random.randrange(z_dim - target_z_dim + 1)
            z_stop = z_start + target_z_dim
            padding_num = 0
        top_pad_num = padding_num // 2
        bottom_pad_num = padding_num - top_pad_num

        z_indices = list(range(z_start, z_stop))
        pad_width = [(top_pad_num, bottom_pad_num), (0, 0), (0, 0)]

        image_array = self.read_nii_gz(image_path)[z_indices]
        mask_array = self.read_nii_gz(mask_path)[z_indices]
        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=0)
        mask_array = np.pad(mask_array, pad_width, mode="constant", constant_values=0)
        return image_array, mask_array

    def read_nii_gz(self, nii_path):
        """Load a NIfTI file and move its third axis to the front.

        The transposed axis is presumably the slice axis - confirm against
        how the files were written.
        """
        image_obj = nib.load(nii_path)
        image_array = image_obj.get_fdata()
        image_array = image_array.transpose(2, 0, 1)
        return image_array
    
# Module-level L1 and MSE loss modules, built with PyTorch's default arguments.
# Neither is referenced in the code visible in this file.
get_l1_loss = nn.L1Loss()
get_l2_loss = nn.MSELoss()
# y_recon_pred.shape = [B, C, H, W] (any number of trailing image axes works)
def get_recon_loss_follow_seg(y_recon_pred, y_recon_gt, y_seg_pred):
    """Mean absolute reconstruction error weighted by the segmentation output.

    The weight is ``2 * sigmoid(25 * y_seg_pred[:, 1]) - 1``, i.e. it is
    derived from channel 1 of the segmentation prediction, and is repeated
    across every reconstruction channel before multiplying the absolute
    error element-wise.

    Args:
        y_recon_pred: Predicted reconstruction, channels on axis 1.
        y_recon_gt: Reconstruction target, same shape as ``y_recon_pred``.
        y_seg_pred: Segmentation prediction with at least two channels on
            axis 1.

    Returns:
        Scalar tensor holding the mean weighted absolute error.
    """
    num_image_dims = y_recon_pred.dim() - 2
    num_recon_channels = y_recon_pred.size(1)

    seg_weight = 2 * torch.sigmoid(25 * y_seg_pred[:, 1]) - 1
    # Re-insert the channel axis and tile it to match the reconstruction channels.
    repeat_dims = [1, num_recon_channels] + [1] * num_image_dims
    seg_weight = seg_weight.unsqueeze(1).repeat(*repeat_dims)

    weighted_error = torch.abs(y_recon_pred - y_recon_gt) * seg_weight
    return torch.mean(weighted_error)

def compute_loss_metric_full(model, x_list, y_list, target_z_dim, process_batch_size,
                             get_class, get_recon, use_class_in_predict, device, dtype):
    """Forward every argument to ``model_predict_full`` and return its result.

    Returns:
        ``(y_pred_list, dice_score_list)`` exactly as produced by
        ``model_predict_full``.
    """
    predictions, dice_scores = model_predict_full(
        model=model,
        x_list=x_list,
        y_list=y_list,
        target_z_dim=target_z_dim,
        process_batch_size=process_batch_size,
        get_class=get_class,
        get_recon=get_recon,
        use_class_in_predict=use_class_in_predict,
        device=device,
        dtype=dtype,
    )
    return predictions, dice_scores

def model_predict_full(model, x_list, y_list, target_z_dim, process_batch_size,
                  get_class, get_recon, use_class_in_predict, device, dtype):
    """Sliding-window inference over whole volumes, plus a Dice score per volume.

    Each volume is cut along its depth axis into windows of ``target_z_dim``
    slices taken every ``target_z_dim // 4`` slices. A short final window is
    zero-padded on both ends. Windows are sent through ``model`` in batches
    of up to ``process_batch_size``; the per-voxel argmax of each window is
    merged into the full-volume prediction with an element-wise maximum, so
    overlapping windows act as a union of predicted labels.

    Args:
        model: Callable returning the segmentation output, optionally followed
            by the class output and/or the reconstruction output, depending
            on ``get_class`` / ``get_recon``.
        x_list: Image tensors with depth on axis 1 (after a leading axis).
        y_list: Mask tensors with depth on axis 0.
        target_z_dim: Window depth fed to the model.
        process_batch_size: Maximum number of windows per forward pass.
        get_class: Whether the model also returns a class prediction.
        get_recon: Whether the model also returns a reconstruction.
        use_class_in_predict: With ``get_class``, multiply the segmentation
            output by the class prediction before taking the argmax.
        device: Device the windows are moved to for the forward pass.
        dtype: Dtype the windows are cast to for the forward pass.

    Returns:
        ``(y_pred_list, dice_score_list)``: one prediction per volume, shaped
        like the mask with an extra leading axis of size 1, and one scalar
        Dice tensor per volume on ``device``.
    """
    # range() rejects a zero step, so keep the stride at least 1 for tiny windows.
    stride = max(target_z_dim // 4, 1)
    y_pred_list = []
    dice_score_list_per_gpu = []

    with torch.no_grad():
        for batch_x, batch_y in zip(x_list, y_list):
            batch_x, batch_y = batch_x[None], batch_y[None]
            batch_y_pred = torch.zeros_like(batch_y)
            z_dim = batch_x.shape[2]
            # Clamp so that a volume shallower than the window still gets one window.
            z_start_list = list(range(0, max(z_dim - target_z_dim, 0) + stride, stride))
            last_window_idx = len(z_start_list) - 1
            pending_windows = []  # (z_start, pad_num, padded window tensor)

            for window_idx, z_start in enumerate(z_start_list):
                x_window = batch_x[:, :, z_start:z_start + target_z_dim]
                pad_num = target_z_dim - x_window.shape[2]
                pad_half = pad_num // 2
                x_window = F.pad(x_window, (0, 0, 0, 0, pad_half, pad_num - pad_half), "constant", 0)
                pending_windows.append((z_start, pad_num, x_window))

                # Run the model once the batch is full or no windows remain.
                batch_is_full = len(pending_windows) == process_batch_size
                if not batch_is_full and window_idx != last_window_idx:
                    continue

                window_batch = torch.cat([window for _, _, window in pending_windows], dim=0)
                window_batch = window_batch.to(device=device, dtype=dtype)

                batch_label_predict = None
                if get_class and get_recon:
                    batch_predict, batch_label_predict, _ = model(window_batch)
                elif get_class:
                    batch_predict, batch_label_predict = model(window_batch)
                elif get_recon:
                    batch_predict, _ = model(window_batch)
                else:
                    batch_predict = model(window_batch)

                if get_class and use_class_in_predict:
                    # batch_label_predict.shape = [B, C]
                    # batch_predict.shape = [B, C, D, H, W]
                    batch_predict = batch_predict * batch_label_predict[:, :, None, None, None]

                for batch_pos, (window_z_start, window_pad_num, _) in enumerate(pending_windows):
                    window_pad_half = window_pad_num // 2
                    end_idx = min(window_z_start + target_z_dim, z_dim)
                    z_slice = slice(window_z_start, end_idx)
                    # Drop the padded slices so only real depth positions are written back.
                    part_z_slice = slice(window_pad_half, end_idx - window_z_start + window_pad_half)
                    window_labels = batch_predict[batch_pos:batch_pos + 1, :, part_z_slice].argmax(1)
                    window_labels = window_labels.to(device=batch_y_pred.device, dtype=batch_y_pred.dtype)
                    batch_y_pred[:, z_slice] = torch.maximum(batch_y_pred[:, z_slice], window_labels)
                pending_windows = []
            y_pred_list.append(batch_y_pred)

        epsilon = 1e-7
        for y_pred, y in zip(y_pred_list, y_list):
            # y_pred carries the extra leading axis added above; y does not,
            # so only y_pred is indexed to line the two volumes up.
            y_pred_np = y_pred[0].cpu().numpy()
            y_np = y.cpu().numpy()
            tp = np.sum(y_pred_np * y_np)
            fp = np.sum(y_pred_np) - tp
            fn = np.sum(y_np) - tp
            dice_score = (2 * tp + epsilon) / (2 * tp + fp + fn + epsilon)
            dice_score = torch.tensor(dice_score).to(device=device, dtype=dtype)
            dice_score_list_per_gpu.append(dice_score)

    return y_pred_list, dice_score_list_per_gpu

In [12]:
# Cross-validation and window settings.
n_fold = 10
target_z_dim = 32

# Model / loss selection used to locate the run directory.
test_model = "unet_custom"
loss_list = ["dice_bce", "dice_bce_focal", "tversky_bce", "propotional_bce"]
loss_select = "propotional_bce"

# Schedule settings; decay_epoch is what remains after the stage coefficients.
total_epoch = 10
stage_coef_list = [2, 5]
decay_epoch = total_epoch - sum(stage_coef_list)
decay_dropout_ratio = 0.25 ** (1 / decay_epoch)
lr_setting_list = [4e-5, 2e-4, 0.25]

# Channel counts and which model outputs are enabled.
in_channels = 1
num_classes = 2
get_class = True
get_recon = True
use_seg_in_recon = True

# Batch size and number of visible CUDA devices.
batch_size = 2
num_gpu = torch.cuda.device_count()

In [31]:
# Evaluate every fold: load that fold's final-epoch checkpoint, run
# full-volume inference on the fold's test cases, and write image / mask /
# prediction volumes into a folder named after the case and its Dice score.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
meta_df_path = "./phase.csv"
check_result_folder = "./check_output"

# The fold-annotated metadata ("<name>_fold<ext>") is the same for every
# fold, so it is read once up front.
meta_df_base, meta_df_ext = os.path.splitext(meta_df_path)
meta_fold_df_path = f"{meta_df_base}_fold{meta_df_ext}"
if not os.path.exists(meta_fold_df_path):
    raise FileNotFoundError(f"check meta_fold_df_path existence: {meta_fold_df_path}")
meta_fold_df = pd.read_csv(meta_fold_df_path)

for current_fold in range(n_fold):
    test_fold = current_fold
    check_result_fold_folder = f"{check_result_folder}/fold_{current_fold}"
    # exist_ok lets the cell be re-run without failing on existing folders.
    os.makedirs(check_result_fold_folder, exist_ok=True)

    test_fold_df = meta_fold_df[meta_fold_df["Fold"] == test_fold]
    test_folder_list = list(test_fold_df["data_folder"])
    test_z_dim_list = list(test_fold_df["Depth"])
    test_dataset = SegDataset(test_folder_list, test_z_dim_list, target_z_dim, use_full=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True,
                                 collate_fn=full_data_collate_fn, shuffle=False)
    print(f"test: {len(test_dataset)}")

    model = get_partial_3d_model(target_z_dim, num_classes, get_class, get_recon)
    model_param_num = sum(p.numel() for p in model.parameters())
    print(f"model_param_num = {model_param_num}")
    log_path = get_log_path(test_model, loss_select, batch_size, target_z_dim, current_fold)
    weight_path = f"{log_path}/weights/{total_epoch:03d}.ckpt"
    load_deepspeed_model_to_torch_model(model, weight_path)
    model = model.to(device=device).eval()

    with torch.no_grad():
        test_pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
        for batch_idx, (x_list, y_list) in test_pbar:
            y_pred_list, dice_score_list = compute_loss_metric_full(
                model, x_list, y_list, target_z_dim, process_batch_size=batch_size,
                get_class=get_class, get_recon=get_recon, use_class_in_predict=False,
                device=device, dtype=dtype)
            # shuffle=False, so batch order matches the order of test_folder_list.
            batch_test_folder_list = test_folder_list[batch_size * batch_idx:batch_size * (batch_idx + 1)]
            for test_folder, x, y, y_pred, dice_score in zip(batch_test_folder_list, x_list, y_list,
                                                             y_pred_list, dice_score_list):
                data_basename = os.path.basename(test_folder)
                output_folder = f"{check_result_fold_folder}/{float(dice_score):.3f}_{data_basename}"
                os.makedirs(output_folder, exist_ok=True)
                image_path = f"{output_folder}/image.nii.gz"
                mask_path = f"{output_folder}/mask.nii.gz"
                pred_path = f"{output_folder}/pred.nii.gz"

                # x and y_pred carry a leading singleton axis that y lacks;
                # drop it so all three volumes are saved with the same axis
                # layout and the slice-order flip hits the same axis.
                x, y, y_pred = x[0].numpy(), y.numpy(), y_pred[0].numpy()
                save_array_as_nifti(x, image_path)
                save_array_as_nifti(y, mask_path)
                save_array_as_nifti(y_pred, pred_path)

test: 92
model_param_num = 24001133
