In [1]:
!python -c "import monai; import nibabel; import tqdm" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Comparison experiment against run #21 (21번과 비교실험)

In [2]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import shutil
import tempfile
import nibabel as nib
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import functools
from torch.nn import init
from monai.losses import DiceCELoss
from monai.losses.ssim_loss import SSIMLoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    Resized,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandSpatialCropSamplesd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandZoomd,
    RandFlip,
    RandRotate90,
    RandAdjustContrast,
    RandShiftIntensity,
    RandGibbsNoise,
    ScaleIntensity,
    RandSimulateLowResolutiond
)

from monai.config import print_config
from monai.metrics import MAEMetric
from monai.networks.nets import (SwinUNETR, UNETR, UNet, DynUNet, SegResNet)

from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)
import argparse
import torch
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler

import math
import copy
from torch.nn import init
import functools
from torch.optim import lr_scheduler

import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleConv(nn.Module):
    """One ``Conv3d -> InstanceNorm3d(affine=True) -> ReLU`` unit.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv3d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding):
        super(SingleConv, self).__init__()
        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride, bias=True),
            nn.InstanceNorm3d(out_ch, affine=True),
            nn.ReLU(inplace=True),
        ]
        # The attribute name and layer order fix the state_dict keys
        # (single_conv.0.*, single_conv.1.*), so checkpoints stay loadable.
        self.single_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.single_conv(x)
        return out


class DenseConvolve(nn.Module):
    """Dense layer: computes ``growth_rate`` new feature maps and prepends them to the input.

    Args:
        in_ch: number of input channels.
        growth_rate: number of channels produced by the 3x3x3 convolution.
        stride: convolution stride. The concatenation only works when the
            stride leaves the spatial size unchanged, as the default (1, 1, 1) does.

    Output channels = ``growth_rate + in_ch``; the new features come first.
    """

    def __init__(self, in_ch, growth_rate=16, stride=(1, 1, 1)):
        super(DenseConvolve, self).__init__()
        conv = nn.Conv3d(in_ch, growth_rate, kernel_size=(3, 3, 3), padding=1, stride=stride, bias=True)
        norm = nn.InstanceNorm3d(growth_rate, affine=True)
        act = nn.ReLU(inplace=True)
        self.single_conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        new_features = self.single_conv(x)
        return torch.cat([new_features, x], dim=1)


class DenseDownsample(nn.Module):
    """Dense down-sampling layer.

    Concatenates a learned strided ``Conv3d -> InstanceNorm3d -> ReLU`` branch
    (``growth_rate`` channels) with a 2x2x2 max-pooled copy of the input
    (``in_ch`` channels). The two branches only line up spatially when
    ``stride`` halves every dimension, as the default (2, 2, 2) does.
    """

    def __init__(self, in_ch, growth_rate=16, stride=(2, 2, 2)):
        super(DenseDownsample, self).__init__()
        conv = nn.Conv3d(in_ch, growth_rate, kernel_size=(3, 3, 3), padding=1, stride=stride, bias=True)
        norm = nn.InstanceNorm3d(growth_rate, affine=True)
        self.single_conv = nn.Sequential(conv, norm, nn.ReLU(inplace=True))

        self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

    def forward(self, x):
        learned = self.single_conv(x)
        pooled = self.pooling(x)
        return torch.cat((learned, pooled), dim=1)


class UNetUpsample(nn.Module):
    """Trilinear x2 up-sampling followed by a 3x3x3 ``Conv3d -> InstanceNorm3d -> ReLU``.

    Args:
        in_ch: channels of the low-resolution input.
        out_ch: channels produced after the convolution.
    """

    def __init__(self, in_ch, out_ch):
        super(UNetUpsample, self).__init__()
        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=(3, 3, 3), padding=1, stride=(1, 1, 1), bias=True),
            nn.InstanceNorm3d(out_ch, affine=True),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
        return self.conv(upsampled)


class Encoder(nn.Module):
    """Densely connected 3-D encoder with five stages.

    Each dense layer appends ``growth_rate`` channels to its input, so after
    ``k`` dense layers the features carry ``in_ch + k * growth_rate`` channels.
    Stage 1 has two ``DenseConvolve`` layers. Stages 2-4 each start with a
    ``DenseDownsample`` followed by two ``DenseConvolve``. Stage 5 starts with a
    ``DenseDownsample`` followed by four ``DenseConvolve``.

    ``forward`` returns the outputs of all five stages (shallowest first).
    """

    def __init__(self, in_ch, growth_rate=16):
        super(Encoder, self).__init__()

        def width(k):
            # Channel count after k dense layers.
            return in_ch + k * growth_rate

        def dense_stage(first_k, n_layers, downsample):
            # Build the stage's layers in order; only the first may down-sample.
            layers = []
            for offset in range(n_layers):
                block = DenseDownsample if (downsample and offset == 0) else DenseConvolve
                layers.append(block(width(first_k + offset), growth_rate))
            return nn.Sequential(*layers)

        self.encoder_1 = dense_stage(0, 2, downsample=False)
        self.encoder_2 = dense_stage(2, 3, downsample=True)
        self.encoder_3 = dense_stage(5, 3, downsample=True)
        self.encoder_4 = dense_stage(8, 3, downsample=True)
        self.encoder_5 = dense_stage(11, 5, downsample=True)

    def forward(self, x):
        features = []
        for stage in (self.encoder_1, self.encoder_2, self.encoder_3, self.encoder_4, self.encoder_5):
            x = stage(x)
            features.append(x)
        return features


class Decoder(nn.Module):
    """U-Net style decoder that consumes the five ``Encoder`` outputs.

    At each level the deeper features are up-sampled to ``upsample_chan``
    channels, concatenated with the matching encoder skip output, and fused by
    two ``SingleConv`` layers (256, 128, 64 and 32 channels from deep to
    shallow). A final 1x1x1 convolution yields one output channel.

    Skip channel counts follow the encoder's dense growth: ``in_ch + k * growth_rate``
    with k = 11, 8, 5, 2 for levels 4..1. The deepest input has k = 16.
    """

    def __init__(self, in_ch, growth_rate, upsample_chan):
        super(Decoder, self).__init__()

        def skip_ch(k):
            return in_ch + k * growth_rate

        def double_conv(c_in, c_out):
            return nn.Sequential(
                SingleConv(c_in, c_out, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=1),
                SingleConv(c_out, c_out, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=1),
            )

        # Registration order matches the original so state_dict ordering is unchanged.
        self.upconv_4 = UNetUpsample(skip_ch(16), upsample_chan)
        self.decoder_conv_4 = double_conv(skip_ch(11) + upsample_chan, 256)
        self.upconv_3 = UNetUpsample(256, upsample_chan)
        self.decoder_conv_3 = double_conv(skip_ch(8) + upsample_chan, 128)
        self.upconv_2 = UNetUpsample(128, upsample_chan)
        self.decoder_conv_2 = double_conv(skip_ch(5) + upsample_chan, 64)
        self.upconv_1 = UNetUpsample(64, upsample_chan)
        self.decoder_conv_1 = double_conv(skip_ch(2) + upsample_chan, 32)

        self.final_conv = nn.Conv3d(32, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=True)

    def forward(self, out_encoder):
        out_encoder_1, out_encoder_2, out_encoder_3, out_encoder_4, out_encoder_5 = out_encoder

        x = out_encoder_5
        for upsample, fuse, skip in (
            (self.upconv_4, self.decoder_conv_4, out_encoder_4),
            (self.upconv_3, self.decoder_conv_3, out_encoder_3),
            (self.upconv_2, self.decoder_conv_2, out_encoder_2),
            (self.upconv_1, self.decoder_conv_1, out_encoder_1),
        ):
            x = fuse(torch.cat((upsample(x), skip), dim=1))

        return self.final_conv(x)


class HD_UNet(nn.Module):
    """Dense ``Encoder`` followed by a U-Net ``Decoder``.

    Args:
        in_ch: number of input channels.
        growth_rate: channels added by every dense layer of the encoder.
        upsample_chan: output channels of each decoder up-sampling block.

    All ``Conv3d`` and ``InstanceNorm3d`` weights are re-initialised at
    construction time (see ``init_conv_IN``).
    """

    def __init__(self, in_ch, growth_rate, upsample_chan):
        super(HD_UNet, self).__init__()
        self.encoder = Encoder(in_ch, growth_rate)
        self.decoder = Decoder(in_ch, growth_rate, upsample_chan)

        # init
        self.initialize()

    @staticmethod
    def init_conv_IN(modules):
        """Initialise the conv and instance-norm layers yielded by ``modules()``.

        Args:
            modules: a zero-argument callable returning an iterable of modules,
                e.g. the bound method ``some_module.modules``.

        ``Conv3d``: Kaiming-uniform weights (fan_in, ReLU gain) and zero bias.
        ``InstanceNorm3d``: weight 1 and bias 0. These are skipped when the layer
        has ``affine=False`` (the PyTorch default), since ``weight``/``bias``
        are ``None`` then.
        """
        for m in modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.InstanceNorm3d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

    def initialize(self):
        """Re-initialise decoder, then encoder.

        The decoder-first order is kept so random draws match earlier runs;
        each log line now names the part actually being initialised.
        """
        print('# random init decoder weight using nn.init.kaiming_uniform !')
        self.init_conv_IN(self.decoder.modules)
        print('# random init encoder weight using nn.init.kaiming_uniform !')
        self.init_conv_IN(self.encoder.modules)

    def forward(self, x):
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)

        # A single tensor from the decoder's final 1-channel conv (not a list).
        return out_decoder


class Model(nn.Module):
    """Thin wrapper that holds an ``HD_UNet`` under the attribute ``model``.

    It adds no computation. Note that the ``model`` attribute name determines
    the ``model.``-prefixed keys of this module's state_dict.
    """

    def __init__(self, in_ch, growth_rate, upsample_chan):
        super().__init__()
        self.model = HD_UNet(in_ch, growth_rate, upsample_chan)

    def forward(self, x):
        prediction = self.model(x)
        return prediction
    
class NLayerDiscriminator(nn.Module):
    """3-D convolutional discriminator that outputs a map of per-patch scores.

    Layout: a stride-2 4x4x4 conv + LeakyReLU(0.2); then ``n_layers - 1``
    stride-2 ``conv -> norm -> LeakyReLU`` blocks with widths ``ndf * min(2**n, 8)``;
    then one stride-1 block; then a stride-1 conv to a single channel. An
    optional Sigmoid follows.

    Convolutions that are followed by a norm layer only get a bias when the
    norm is ``nn.InstanceNorm3d`` (directly or wrapped in ``functools.partial``).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm3d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm3d

        kw, padw = 4, 1

        def conv_norm_act(c_in, c_out, stride):
            return [
                nn.Conv3d(c_in, c_out, kernel_size=kw, stride=stride, padding=padw, bias=use_bias),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, True),
            ]

        layers = [
            nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        width = 1
        for n in range(1, n_layers):
            prev, width = width, min(2 ** n, 8)
            layers += conv_norm_act(ndf * prev, ndf * width, 2)

        prev, width = width, min(2 ** n_layers, 8)
        layers += conv_norm_act(ndf * prev, ndf * width, 1)

        layers.append(nn.Conv3d(ndf * width, 1, kernel_size=kw, stride=1, padding=padw))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
def get_kernels_strides(patch_size, spacing):
    """
    Derive DynUNet-style per-level kernel sizes and strides from a patch size and voxel spacing.

    This function is only used for decathlon datasets with the provided patch sizes.
    When referring to this method for other tasks, make sure the patch size in each spatial
    dimension is divisible by the product of all strides in that dimension.
    In addition, at least one spatial dimension should be at least twice the product of all
    strides. For patch sizes that cannot find suitable strides, a ValueError is raised.

    A dimension is down-sampled (stride 2, kernel 3) while its spacing is at most twice the
    finest spacing and its current size is >= 8. Otherwise it gets stride 1, and its kernel
    is 1 when the spacing ratio exceeds 2.

    Returns:
        (kernels, strides): lists of per-dimension lists. ``strides`` starts with all-ones
        and ``kernels`` ends with all-threes.
    """
    input_size = patch_size
    cur_sizes, cur_spacings = patch_size, spacing
    kernels, strides = [], []
    while True:
        smallest = min(cur_spacings)
        ratios = [sp / smallest for sp in cur_spacings]
        stride = [2 if r <= 2 and sz >= 8 else 1 for r, sz in zip(ratios, cur_sizes)]
        kernel = [3 if r <= 2 else 1 for r in ratios]
        if max(stride) == 1:
            break
        bad_dim = next(
            (idx for idx, (sz, st) in enumerate(zip(cur_sizes, stride)) if sz % st != 0),
            None,
        )
        if bad_dim is not None:
            raise ValueError(
                f"Patch size is not supported, please try to modify the size {input_size[bad_dim]} in the spatial dimension {bad_dim}."
            )
        cur_sizes = [sz / st for sz, st in zip(cur_sizes, stride)]
        cur_spacings = [sp * st for sp, st in zip(cur_spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    n_dims = len(cur_spacings)
    return kernels + [[3] * n_dims], [[1] * n_dims] + strides
class PolyLRScheduler(_LRScheduler):
    """Polynomial learning-rate decay: ``lr = initial_lr * (1 - step / max_steps) ** exponent``.

    Args:
        optimizer: wrapped optimizer; every param group gets the same lr.
        initial_lr: lr at step 0.
        max_steps: step at which the lr reaches 0. Later steps are clamped to
            ``max_steps``, so the lr stays 0 instead of becoming negative or, for a
            fractional exponent, complex.
        exponent: decay power.
        current_step: forwarded as ``last_epoch`` to the base class (-1 when None).

    Calling ``step()`` without an argument advances an internal counter ``ctr``.
    ``ctr`` is part of ``state_dict()``, so resuming continues the schedule.
    """

    def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_steps = max_steps
        self.exponent = exponent
        self.ctr = 0
        # The base __init__ calls self.step() once, which applies the step-0 lr.
        # `verbose` is omitted: its default is equivalent and passing it is deprecated.
        super().__init__(optimizer, current_step if current_step is not None else -1)

    def step(self, current_step=None):
        if current_step is None or current_step == -1:
            current_step = self.ctr
            self.ctr += 1

        # Clamp so (1 - progress) never goes negative (negative ** 0.9 is complex in Python).
        progress = min(max(current_step, 0), self.max_steps) / self.max_steps
        new_lr = self.initial_lr * (1 - progress) ** self.exponent
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        # Keep get_last_lr() working; the base class step() that normally sets this is overridden.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
class GANLoss(nn.Module):
    """Adversarial loss against a constant "real" or "fake" target.

    Args:
        use_lsgan: True -> ``nn.MSELoss`` (least-squares GAN). False -> ``nn.BCELoss``,
            which expects probabilities (e.g. a discriminator ending in Sigmoid).
        target_real_label: constant target used when ``target_is_real`` is truthy.
        target_fake_label: constant target used otherwise.

    The targets are registered as buffers, so ``.to(device)`` moves them with the module.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real/fake target broadcast (a stride-0 view, no copy) to ``input``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(input)

    def forward(self, input, target_is_real):
        """Loss between discriminator output ``input`` and the constant target.

        Implemented as ``forward`` (not ``__call__``) so ``nn.Module.__call__``
        still runs registered hooks; calling ``criterion(pred, flag)`` works as before.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
    
# Norm-layer factory for NLayerDiscriminator: non-affine InstanceNorm3d that tracks running stats.
norm_layer = functools.partial(nn.InstanceNorm3d, track_running_stats=True, affine=False)

In [3]:
def oar_run_train():
    """Train the HD-UNet dose-prediction generator ``netG`` with an MSE objective.

    Reads the ``training``/``validation`` splits of ``dataset_f0.json`` under
    ``root_dir`` and trains with CUDA mixed precision. It saves a generator
    checkpoint every ``eval_num * save_period`` steps (and at ``max_iterations``)
    to ``model_savepath``. After every epoch it saves the list of per-epoch mean
    training losses to ``loss_savepath`` as ``<global_step>_loss.npy``.

    Input layout as used below (inferred from indexing; TODO confirm against the
    dataset builder): channel 0 (``tx``) and channel 2 (``ptv``) are rescaled,
    channel 1 (``one``) and channels 3.. pass through unchanged. The model takes
    12 channels in total.

    Returns:
        None. All results are written to disk.
    """
    import gc

    # --- paths (machine-specific) ---------------------------------------------
    root_dir = r"C:\! Project\2024_unitydoseprediction\traindata"
    ckpt_dir = r"G:\! project\2024- UnityDosePrediction\ckpt"
    data_dir = root_dir + r"/"
    model_savepath = os.path.join(ckpt_dir, "20240813model_contrast_hdunet")
    loss_savepath = os.path.join(ckpt_dir, "20240813loss_contrast_hdunet")
    os.makedirs(model_savepath, exist_ok=True)
    os.makedirs(loss_savepath, exist_ok=True)

    # --- hyper-parameters -----------------------------------------------------
    # One device for both the cached data (EnsureTyped) and the model.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_samples = 2                             # random crops drawn from each volume per iteration
    spatial_size_xyz = (96, 96, 64)             # training patch / sliding-window ROI size
    space_x, space_y, space_z = 1.1, 1.1, 2.0   # only used by test_transforms
    a_min, a_max = -175, 350                    # only used by test_transforms
    initial_lr = 1e-4
    max_iterations = 62503
    eval_num = 125        # checkpoint / LR-schedule period, in optimizer steps
    save_period = 10      # a checkpoint is written every save_period eval periods
    T_0 = 20              # carried through train() unchanged; currently unused

    # --- data -----------------------------------------------------------------
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
            RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3, spatial_axes=(0, 1)),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.3),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.3),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.3),
            RandSpatialCropSamplesd(
                keys=["image", "label"],
                roi_size=spatial_size_xyz,
                num_samples=num_samples,
                random_size=False,
            ),
            RandSimulateLowResolutiond(keys=["image"], prob=0.2),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
        ]
    )

    # NOTE(review): test_transforms is not used in this function.
    test_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(keys=["image"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            Orientationd(keys=["image", "label"], axcodes="LPS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(space_x, space_y, space_z),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
        ]
    )

    datasets = data_dir + "dataset_f0" + ".json"
    datalist = load_decathlon_datalist(datasets, True, "training")
    val_files = load_decathlon_datalist(datasets, True, "validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=45,
        cache_rate=1.0,
        num_workers=8,
    )
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=0, cache_rate=1.0, num_workers=4)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    # --- model / optimisation -------------------------------------------------
    netG = Model(in_ch=12, growth_rate=16, upsample_chan=64).to(device)
    torch.backends.cudnn.benchmark = True
    loss_function = nn.MSELoss()
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=initial_lr, betas=(0.5, 0.999))
    # The poly schedule is stepped once per eval period. It must therefore span
    # max_iterations // eval_num (= 500) steps to decay to zero by the end of training.
    scheduler_G = PolyLRScheduler(optimizer_G, initial_lr=initial_lr, max_steps=max_iterations // eval_num)
    scaler = torch.cuda.amp.GradScaler()
    scale_intensity = ScaleIntensity(minv=-1, maxv=1)
    MAE_metric = MAEMetric(reduction="mean", get_not_nans=False)

    # To resume, uncomment these lines and the `global_step = ckpt["global_step"]` line below:
    # ckpt = torch.load(r"G:\! project\2024- UnityDosePrediction\ckpt\20240813model_contrast_hdunet\doseprediction_modelG_70.pth")
    # netG.load_state_dict(ckpt["netG_state_dict"])
    # optimizer_G.load_state_dict(ckpt["optimizer_G_state_dict"])
    # scheduler_G.load_state_dict(ckpt["scheduler_G_state_dict"])

    def _prepare_batch(images, labels):
        """Rescale channel 0, channel 2 and the label to [-1, 1]; other channels pass through.

        NOTE(review): ScaleIntensity min-max normalises each whole tensor (all
        samples in the batch together), so the label's absolute scale is removed.
        Confirm this is intended for dose prediction.
        """
        tx = scale_intensity(images[:, 0:1].clone().detach())
        one = images[:, 1:2].clone().detach()
        ptv = scale_intensity(images[:, 2:3].clone().detach())
        return torch.cat((tx, one, ptv, images[:, 3:]), 1), scale_intensity(labels)

    def validation(epoch_iterator_val):
        """Return the mean MAE of ``netG`` over ``epoch_iterator_val`` (a tqdm-wrapped loader).

        Inputs and labels get the same rescaling as in training. Training rescales
        per patch, whereas here whole volumes are rescaled.
        NOTE(review): not called anywhere in this notebook yet.
        """
        netG.eval()
        with torch.no_grad():
            for batch in epoch_iterator_val:
                gc.collect()
                torch.cuda.empty_cache()
                val_inputs, val_labels = _prepare_batch(batch["image"].to(device), batch["label"].to(device))
                with torch.cuda.amp.autocast():
                    val_outputs = sliding_window_inference(val_inputs, spatial_size_xyz, 4, netG)
                MAE_metric(y_pred=val_outputs, y=val_labels)
                epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
            mean_mae_val = MAE_metric.aggregate().item()
            MAE_metric.reset()
        return mean_mae_val

    def train(global_step, train_loader, mae_val_best, global_step_best, t_0):
        """Run (at most) one epoch and return the updated loop state.

        Appends the epoch's mean training loss to ``epoch_loss_values``. Stops
        early once step ``max_iterations`` has been processed.
        """
        netG.train()
        epoch_loss_G = 0.0
        step = 0
        epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator, start=1):
            realA, realB = _prepare_batch(batch["image"].to(device), batch["label"].to(device))

            with torch.cuda.amp.autocast():
                fakeB = netG(realA)
                mse_loss = loss_function(fakeB, realB)

            scaler.scale(mse_loss).backward()
            epoch_loss_G += mse_loss.item()
            scaler.unscale_(optimizer_G)
            scaler.step(optimizer_G)
            optimizer_G.zero_grad()
            scaler.update()

            epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (L2_loss={mse_loss:2.5f})")
            is_eval_step = global_step % eval_num == 0 and global_step != 0
            if is_eval_step or global_step == max_iterations:
                if (global_step // eval_num) % save_period == 0:
                    global_step_best = global_step
                    torch.save({
                        'global_step': global_step,
                        'netG_state_dict': netG.state_dict(),
                        'optimizer_G_state_dict': optimizer_G.state_dict(),
                        'scheduler_G_state_dict': scheduler_G.state_dict(),
                    }, os.path.join(model_savepath, "doseprediction_modelG_%d.pth" % (global_step // eval_num)))
            if is_eval_step:
                # Advance the LR schedule once per eval period (not only on save
                # periods); max_steps above is sized for exactly this cadence.
                scheduler_G.step()
            global_step += 1
            if global_step > max_iterations:
                break

        if step > 0:
            epoch_loss_values.append(epoch_loss_G / step)
        return global_step, mae_val_best, global_step_best, t_0

    global_step = 0
    # global_step = ckpt["global_step"]
    mae_val_best = 1000.0
    global_step_best = 0
    epoch_loss_values = []  # mean training loss of each epoch, appended by train()
    t_0 = T_0
    while global_step < max_iterations:
        global_step, mae_val_best, global_step_best, t_0 = train(global_step, train_loader, mae_val_best, global_step_best, t_0=t_0)
        np.save(os.path.join(loss_savepath, "%d_loss.npy" % (int(global_step))), np.array(epoch_loss_values))
    return None

    # with torch.no_grad():
    #     for case_num in range(total_case_num):
    #         start = time.time()
    #         template_nib = nib.load(os.path.join(original_nib_path, original_nib_path_list[case_num]))

    #         img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
    #         img = val_ds[case_num]["image"]
    #         label = val_ds[case_num]["label"]
    #         val_inputs = torch.unsqueeze(img, 0).cuda()
    #         val_labels = torch.unsqueeze(label, 0).cuda()
    #         bone_min = -1000
    #         bone_max = 2000
    #         soft_min = -160
    #         soft_max = 350
    #         brain_min = -5
    #         brain_max = 65
    #         stroke_min = 15
    #         stroke_max = 45

    #         # box_start, box_end = FG_cropper.compute_bounding_box(val_inputs)
    #         tx =  torch.reshape(val_inputs[:, 0, :, :, :].clone().detach(), (val_inputs.shape[0], 1, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]))
    #         x1 =  torch.reshape(val_inputs[:, 0, :, :, :].clone().detach(), (val_inputs.shape[0], 1, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]))
    #         x2 =  torch.reshape(val_inputs[:, 0, :, :, :].clone().detach(), (val_inputs.shape[0], 1, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]))
    #         x3 =  torch.reshape(val_inputs[:, 0, :, :, :].clone().detach(), (val_inputs.shape[0], 1, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]))
    #         x4 =  torch.reshape(val_inputs[:, 0, :, :, :].clone().detach(), (val_inputs.shape[0], 1, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]))
    #         tx2 = torch.reshape(val_inputs[:, 1, :, :, :].clone().detach(), (val_inputs.shape[0], 1, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]))
    #         x1[x1<bone_min] = bone_min
    #         x1[x1>bone_max] = bone_max
    #         x1 = (x1-bone_min)/(bone_max-bone_min)
    #         x2[x2<soft_min] = soft_min
    #         x2[x2>soft_max] = soft_max
    #         x2 = (x2-soft_min)/(soft_max-soft_min)
    #         x3[x3<brain_min] = brain_min
    #         x3[x3>brain_max] = brain_max 
    #         x3 = (x3-brain_min)/(brain_max-brain_min)
    #         x4[x4<stroke_min] = stroke_min
    #         x4[x4>stroke_max] = stroke_max
    #         x4 = (x4 - stroke_min)/(stroke_max - stroke_min)
    #         val_inputs = torch.cat((tx, x1, x2, x3, x4, tx2), 1)

    #         val_outputs = sliding_window_inference(val_inputs, spatial_size_xyz, 4, model, overlap=0.5)
    #         last_outputs = torch.argmax(val_outputs, dim=1).detach().cpu()[0].numpy()
    #         img_npy = img.cpu()[0].numpy()
    #         label_npy = label.cpu()[0].numpy()
    #         print("time taken : %f" %(time.time()- start))
    #         # nib.save(
    #                 # nib.Nifti1Image(img_npy.astype(np.uint8), np.ones((4, 4))), os.path.join(r"D:\!HaN_Challenge\HaN-Seg_NRRD\!output_dir\label_1to10_fold0", "val_ct_%02d" %(case_num+1))
    #             # )

    #         nib.save(
    #                 nib.Nifti1Image(label_npy.astype(np.uint8), template_nib.affine, template_nib.header), os.path.join(root_dir + "/result", "val_label_%02d" %(case_num+1))
    #             )
    #         nib.save(
    #                 nib.Nifti1Image(last_outputs.astype(np.uint8), template_nib.affine, template_nib.header), os.path.join(root_dir + "/result", "infered_label_%02d" %(case_num+1))
    #             )

In [4]:


# Launch the full training run defined above (writes checkpoints and loss history to disk).
oar_run_train()
    

Loading dataset: 100%|██████████| 45/45 [00:08<00:00,  5.13it/s]


# random init encoder weight using nn.init.kaiming_uniform !
# random init decoder weight using nn.init.kaiming_uniform !


Training (8794 / 62503 Steps) (L2_loss=0.00195): 100%|██████████| 45/45 [00:10<00:00,  4.46it/s]
Training (8839 / 62503 Steps) (L2_loss=0.00237): 100%|██████████| 45/45 [00:06<00:00,  6.86it/s]
Training (8884 / 62503 Steps) (L2_loss=0.00865): 100%|██████████| 45/45 [00:06<00:00,  6.85it/s]
Training (8929 / 62503 Steps) (L2_loss=0.00695): 100%|██████████| 45/45 [00:06<00:00,  6.83it/s]
Training (8974 / 62503 Steps) (L2_loss=0.00295): 100%|██████████| 45/45 [00:06<00:00,  6.71it/s]
Training (9019 / 62503 Steps) (L2_loss=0.00269): 100%|██████████| 45/45 [00:06<00:00,  6.65it/s]
Training (9064 / 62503 Steps) (L2_loss=0.00552): 100%|██████████| 45/45 [00:06<00:00,  6.59it/s]
Training (9109 / 62503 Steps) (L2_loss=0.00317): 100%|██████████| 45/45 [00:06<00:00,  6.65it/s]
Training (9154 / 62503 Steps) (L2_loss=0.00436): 100%|██████████| 45/45 [00:06<00:00,  6.66it/s]
Training (9199 / 62503 Steps) (L2_loss=0.00619): 100%|██████████| 45/45 [00:06<00:00,  6.58it/s]
Training (9244 / 62503 Steps) 

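Result notes (per-structure scores):

fold 0
arytenoid 0.64386
a_carotid_l 0.83728 (probably higher than this; 아마 이것보단 높을 것)
a_carotid_r 0.88483
bone_mandible 0.964496
oralcavity 0.924
cochlea_l 0.85533
optnr_1 0.77

fold 1
arytenoid 0.45... (Gaussian noise + Gaussian blurring; 가우시안 노이즈 + 가우시안 블러링)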