In [None]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt


import tqdm
from torch.utils.data import Dataset
#from torchvision import transforms, utils
import nibabel as nib
from sklearn.model_selection import KFold, train_test_split, StratifiedKFold

import torchio as tio
from torchio import IntensityTransform
from torchio.transforms.augmentation import RandomTransform
from torchio.data import Subject

from torchio.transforms import (
    RandomFlip,
    RandomAffine,
    RandomElasticDeformation, 
    RandomNoise,
    RandomMotion,
    RandomBiasField,
    RescaleIntensity,
    Resample,
    ToCanonical,
    ZNormalization,
    CropOrPad,
    HistogramStandardization,
    OneOf,
    Compose,
)


def pad_nifti(v, m, d=32):
    """Zero-pad a 3D volume and its mask along axis 2 up to depth ``d``.

    The padding is split as evenly as possible; when the difference is odd
    the extra slice goes after the data. Arrays whose axis 2 already has at
    least ``d`` slices are returned unchanged.

    Args:
        v: 3D image array; axis 2 is padded.
        m: 3D mask array with the same depth as ``v``.
        d: target size of axis 2.

    Returns:
        (padded_v, padded_m)
    """
    depth = np.shape(v)[2]
    # Compare against d itself so non-default targets work and a negative
    # pad width is never computed for volumes that are already deep enough.
    if depth >= d:
        return v, m

    total = d - depth
    before = total // 2
    pad = ((0, 0), (0, 0), (before, total - before))
    nv = np.pad(v, pad, 'constant', constant_values=0)
    nm = np.pad(m, pad, 'constant', constant_values=0)
    return nv, nm

# z score normalization
def z_score(data, lth=0.02, uth=0.98):
    """Z-score normalize ``data`` with trimmed statistics of its positive values.

    The strictly positive entries are sorted, and the lowest ``lth`` and
    highest ``1 - uth`` fractions (by rank) are discarded. The mean and
    standard deviation of the remainder are then used to shift and scale the
    whole array, zeros included.

    Args:
        data: numeric array (e.g. an image volume).
        lth: lower trimming fraction in [0, 1].
        uth: upper trimming fraction in [0, 1], expected >= lth.

    Returns:
        ``(data - mean) / std`` as a new array. A constant trimmed foreground
        (std == 0) yields inf/nan values.

    Raises:
        ValueError: if no positive values remain after trimming.
    """
    foreground = np.sort(data[data > 0])
    # Cut points are taken by rank in the sorted foreground values.
    lo = int(foreground.shape[0] * lth)
    hi = int(foreground.shape[0] * uth)
    core = foreground[lo:hi]
    if core.size == 0:
        raise ValueError("z_score: no foreground (>0) values left after trimming")
    return (data - np.mean(core)) / np.std(core)


        
# Spatial augmentation: the identity Lambda (weight 0.34) leaves the subject
# untouched; otherwise either an isotropic random scaling or an elastic
# deformation is applied (0.33 each).
_spatial_aug = tio.OneOf({
    tio.Lambda(lambda x: x, types_to_apply=None): 0.34,
    tio.RandomAffine(
        scales=(0.95, 1.25),
        degrees=0,
        image_interpolation='linear',
        isotropic=True,
        center='image',
    ): 0.33,
    tio.RandomElasticDeformation(num_control_points=5, image_interpolation='linear'): 0.33,
})

# Intensity augmentation: identity (0.34), a random bias field (0.33) or a
# small random gamma change (0.33).
_intensity_aug = tio.OneOf({
    tio.Lambda(lambda x: x, types_to_apply=None): 0.34,
    tio.RandomBiasField(): 0.33,
    tio.RandomGamma(log_gamma=(-0.05, 0.05)): 0.33,
})

# Training pipeline: reorient, then at most one spatial and at most one
# intensity augmentation.
train_transform = tio.Compose([tio.ToCanonical(), _spatial_aug, _intensity_aug])

# Validation / test pipeline: reorientation only, no random transforms.
valtest_transform = tio.Compose([tio.ToCanonical()])

# custom dataloader
class BrainSegmentationDataset(Dataset):
    """In-memory NIfTI segmentation dataset with a stratified K-fold split.

    Every file matched by the image pattern is assigned to exactly one of the
    train / validation / test partitions of fold ``K``; only the partition
    named by ``mode`` is read into memory. Subjects are stratified by whether
    their mask has more voxels > 0.5 than the median subject.

    Args:
        images_dir: glob pattern (passed to ``glob.glob``) matching the image
            volumes.
        transform: optional torchio transform applied in ``__getitem__``.
        mode: "train", "validation" or "test".
        K: outer fold to use, 0 <= K < num_folds.
        num_folds: number of outer ``StratifiedKFold`` splits; the held-out
            part of each split is the test set.
        random_state: seed for the outer fold shuffle and the inner
            train/validation split.
        val_split: fraction of the outer training part held out for validation.
        masks_dir: glob pattern matching the mask volumes; sorted mask paths
            are paired with sorted image paths by position. Defaults to
            ``images_dir`` for backward compatibility, which pairs each image
            with itself as its "mask" -- pass the real mask pattern.
    """

    def __init__(self, images_dir, transform=None, mode="train", K=0, num_folds=10,
                 random_state=11, val_split=0.1, masks_dir=None):
        assert mode in ["train", "validation", "test"]
        self.transform = transform
        self.mode = mode

        v_list = sorted(glob.glob(images_dir))
        m_list = sorted(glob.glob(images_dir if masks_dir is None else masks_dir))
        if len(v_list) != len(m_list):
            raise ValueError(
                "found {} images but {} masks; the image and mask patterns must "
                "match one-to-one".format(len(v_list), len(m_list)))

        sel = self._partition_indices(m_list, mode, K, num_folds, random_state, val_split)

        print("reading {} images...".format(mode))
        print("The length of {} set is: {}".format(mode, len(sel)))

        self.volumes = []
        self.masks = []
        # Load the selected subjects in sorted-path order.
        for i in np.sort(sel):
            v = nib.load(v_list[i]).get_fdata()
            m = nib.load(m_list[i]).get_fdata()
            v, m = pad_nifti(v, m)
            v = z_score(v)
            print("{} || {}".format(v_list[i], m_list[i]))
            self.volumes.append(v)
            self.masks.append(m)

        print("Number of volumes {} set is: {}".format(mode, len(self.volumes)))

    @staticmethod
    def _partition_indices(m_list, mode, K, num_folds, random_state, val_split):
        """Return the indices into ``m_list`` of partition ``mode`` of fold ``K``."""
        # Binary stratum per subject: 1 if its count of mask voxels > 0.5 is
        # above the median count, else 0, so every fold mixes both groups.
        fg_counts = np.array([
            np.count_nonzero(nib.load(path).get_fdata() > 0.5) for path in m_list
        ])
        strata = np.where(fg_counts > np.median(fg_counts), 1, 0)

        kf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_state)
        folds = list(kf.split(np.arange(len(m_list)), strata))
        if not 0 <= K < len(folds):
            raise ValueError("K must be in [0, {}), got {}".format(len(folds), K))
        train_validation_idx, test_idx = folds[K]
        if mode == "test":
            return test_idx

        train_idx, val_idx = train_test_split(
            train_validation_idx,
            shuffle=True,
            random_state=random_state,
            test_size=val_split,
            stratify=strata[train_validation_idx],
        )
        return train_idx if mode == "train" else val_idx

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors of subject ``idx`` after ``self.transform``.

        Both tensors carry a leading channel axis added with ``unsqueeze(0)``.
        """
        image = torch.unsqueeze(torch.from_numpy(self.volumes[idx]), 0)
        mask = torch.unsqueeze(torch.from_numpy(self.masks[idx]), 0)

        subject = tio.Subject({
            'image': tio.ScalarImage(tensor=image),
            'mask': tio.LabelMap(tensor=mask),
        })
        if self.transform is not None:
            subject = self.transform(subject)

        return subject['image'].tensor, subject['mask'].tensor

In [None]:
GPU_NUM = 4  # index of the GPU to use when CUDA is available

if torch.cuda.is_available():
    device = torch.device(f'cuda:{GPU_NUM}')
    # Make bare .cuda() calls target the chosen GPU as well.
    torch.cuda.set_device(device)
    print('Current cuda device Number =', torch.cuda.current_device())
    print('Current cuda device Name =', torch.cuda.get_device_name(GPU_NUM))
    print('Current cuda device Memory =', torch.cuda.get_device_properties(GPU_NUM).total_memory)
else:
    # set_device / current_device require CUDA, so only select the CPU here.
    device = torch.device('cpu')
    print('CUDA is not available; using CPU')

In [None]:
def plot_img(inputs, label, predicted):
    """Show an input slice, its ground-truth mask and a prediction side by side.

    For each tensor the first element along axis 0 is taken and its axes are
    transposed with (1, 2, 0) before display. The commented-out call in the
    training loop passes slices like ``x[:, :, :, :, i]``.

    Args:
        inputs: image tensor, shown with a gray colormap and automatic range.
        label: ground-truth tensor, cast to uint8 and shown on a fixed 0..1 range.
        predicted: prediction tensor, cast to uint8 and shown on a fixed 0..1 range.
    """
    def _to_display(t, dtype):
        # First item of the batch, first axis moved last for imshow.
        return np.transpose(t.cpu().numpy().astype(dtype)[0], (1, 2, 0))

    panels = [
        (_to_display(inputs, np.float64), {}),  # np.float was removed in NumPy 1.24
        # NOTE(review): vmax=1 makes every label value >= 1 render identically.
        (_to_display(label, np.uint8), {"vmin": 0, "vmax": 1}),
        (_to_display(predicted, np.uint8), {"vmin": 0, "vmax": 1}),
    ]

    fig = plt.figure(figsize=(20, 20))
    for position, (img, kwargs) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position)
        ax.imshow(img, cmap="gray", **kwargs)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    plt.show()
    # Release the figure so repeated calls (e.g. over many slices) do not pile up.
    plt.close(fig)



def flatten(tensor):
    """Move the channel axis (dim 1) to the front and collapse the rest.

    Shape change: (N, C, *spatial) -> (C, N * prod(spatial)).
    """
    n_channels = tensor.size(1)
    # Swap the first two axes; every later axis keeps its position.
    order = (1, 0, *range(2, tensor.dim()))
    channels_first = tensor.permute(order).contiguous()
    return channels_first.view(n_channels, -1)


def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Per-channel Dice coefficient (V-Net formulation, https://arxiv.org/abs/1606.04797).

    Each channel gets 2 * sum(p * t) / (sum(p^2) + sum(t^2)), summing over the
    batch and all spatial positions. ``input`` is expected to be probabilities
    (e.g. after Sigmoid or Softmax).

    Args:
         input (torch.Tensor): N x C x spatial tensor
         target (torch.Tensor): N x C x spatial tensor, same shape as ``input``
         epsilon (float): lower bound on the denominator, avoids division by zero
         weight (torch.Tensor): optional per-channel weights applied to the overlap term

    Returns:
         torch.Tensor of shape (C,) with one Dice score per channel.
    """
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    pred = flatten(input)
    truth = flatten(target).float()

    overlap = (pred * truth).sum(-1)
    if weight is not None:
        overlap = weight * overlap

    # Squared terms in the denominator, as in V-Net.
    denom = (pred * pred).sum(-1) + (truth * truth).sum(-1)
    return 2 * (overlap / denom.clamp(min=epsilon))

     
class DiceCoefficient:
    """Mean per-channel Dice coefficient.

    Scores every channel separately with ``compute_per_channel_dice`` (the
    multi-channel generalization described in
    https://arxiv.org/pdf/1707.03237.pdf) and then averages over channels.
    Inputs are expected to be probabilities, not logits. Most meaningful when
    all channels share one semantic class (e.g. affinities computed with
    different offsets). Avoid it as the evaluation metric when training with
    DiceLoss, otherwise the score is biased towards the loss.

    Args:
        epsilon: lower bound on the Dice denominator.
        **kwargs: accepted and ignored.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        self.epsilon = epsilon

    def __call__(self, input, target):
        per_channel = compute_per_channel_dice(input, target, epsilon=self.epsilon)
        return torch.mean(per_channel)

In [None]:
from monai.networks.nets import BasicUNet

# 3D MONAI BasicUNet with three output channels, which the training loop
# turns into class probabilities with a softmax over dim 1.
unet = BasicUNet(
    spatial_dims=3,
    out_channels=3,
    dropout=0.3,
    features=(32, 32, 64, 128, 256, 32),
)

unet.to(device)

In [None]:
def datasets(images, K=0, num_folds=10, random_state=10, val_split=0.1):
    """Build the train, validation and test datasets of fold ``K``.

    All three receive identical split settings so they index the same folds;
    only the training set uses ``train_transform``.

    Returns:
        (train, valid, test) ``BrainSegmentationDataset`` instances.
    """
    split_kwargs = dict(
        images_dir=images,
        K=K,
        num_folds=num_folds,
        random_state=random_state,
        val_split=val_split,
    )
    train = BrainSegmentationDataset(mode="train", transform=train_transform, **split_kwargs)
    valid = BrainSegmentationDataset(mode="validation", transform=valtest_transform, **split_kwargs)
    test = BrainSegmentationDataset(mode="test", transform=valtest_transform, **split_kwargs)
    return train, valid, test


def data_loaders(image_path, batch_size, K=0):
    """Build the train, validation and test DataLoaders for fold ``K``.

    Args:
        image_path: image glob pattern forwarded to ``datasets``.
        batch_size: batch size of the training loader. Volumes are only padded
            along depth by ``pad_nifti``, so values > 1 presumably require all
            training volumes to share one shape -- TODO confirm.
        K: fold index forwarded to ``datasets``.

    Returns:
        (loader_train, loader_valid, loader_test)
    """
    dataset_train, dataset_valid, dataset_test = datasets(image_path, K=K)

    # Training drops a trailing short batch.
    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    # Evaluation goes one volume at a time and must see every subject, so
    # nothing is dropped.
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )
    loader_test = DataLoader(
        dataset_test,
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )

    return loader_train, loader_valid, loader_test

In [None]:
from monai.losses import DiceLoss


epochs = 100      # number of train/valid passes
LR = 1e-4         # AdamW learning rate
weights = './'    # directory the best checkpoint is saved into

# The network output is softmax-ed before the loss; the target is the
# single-channel label map, hence to_onehot_y=True.
loss1 = DiceLoss(to_onehot_y=True)
optimizer = optim.AdamW(unet.parameters(), lr=LR)

In [None]:
batch_size = 1
# NOTE(review): this string reaches glob.glob() as the image pattern inside
# BrainSegmentationDataset; './' matches only the directory itself, not image
# files. Point it at the volumes (e.g. a pattern ending in '*.nii.gz') -- TODO confirm.
loader_train, loader_valid, loader_test = data_loaders('./', batch_size=batch_size, K=0)
loaders = {"train": loader_train, "valid": loader_valid, "test": loader_test}

In [None]:
best_validation_dsc = 0.0

loss_train = []
loss_valid = []

allloss_train = []
alldsc_train = []
allloss_val = []
alldsc_val = []

dsc = DiceCoefficient()


def _foreground_dice(probs, labels):
    """Mean Dice over the non-background classes of one batch.

    ``probs`` is the softmax output (N, C, *spatial); ``labels`` is the label
    map (N, 1, *spatial), assumed to hold class indices 0..C-1. The prediction
    is argmaxed, both sides are one-hot encoded to C channels, and channel 0
    is dropped before averaging per-channel Dice. Scoring raw label integers
    instead would weight class 2 four times as much as class 1 and reward
    predicting the wrong foreground class.
    NOTE: a class absent from both prediction and label scores 0 for that batch.
    """
    num_classes = probs.shape[1]
    pred = F.one_hot(torch.argmax(probs, dim=1), num_classes).movedim(-1, 1)
    truth = F.one_hot(labels.long().squeeze(1), num_classes).movedim(-1, 1)
    return dsc(pred[:, 1:], truth[:, 1:])


# actual processing...
step = 0  # number of training batches processed
print("1 Fold")

for epoch in range(epochs):
    for phase in ["train", "valid"]:
        is_train = phase == "train"
        unet.train(is_train)  # train(False) is the same as eval()

        train_dsc_list = []
        validation_dsc_list = []
        phase_losses = loss_train if is_train else loss_valid
        phase_dscs = train_dsc_list if is_train else validation_dsc_list

        for x, y_true in loaders[phase]:
            if is_train:
                step += 1
            x, y_true = x.to(device).float(), y_true.to(device).float()

            # Build the autograd graph only while training.
            with torch.set_grad_enabled(is_train):
                y_pred = F.softmax(unet(x), dim=1)
                loss = loss1(y_pred, y_true)

            phase_losses.append(loss.item())
            with torch.no_grad():
                phase_dscs.append(_foreground_dice(y_pred, y_true).item())

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if is_train:  # reporting
            mean_loss, mean_dsc = np.mean(loss_train), np.mean(train_dsc_list)
            print("epoch {:04d}     | {}: {}".format(epoch + 1, "Train loss", mean_loss))
            print("               | {}: {}".format("Train Dice", mean_dsc))
            allloss_train.append(mean_loss)
            alldsc_train.append(mean_dsc)
            loss_train = []
        else:  # reporting
            mean_loss, mean_dsc = np.mean(loss_valid), np.mean(validation_dsc_list)
            print("               | {}: {}".format("Validation loss", mean_loss))
            print("               | {}: {}".format("Validation Dice", mean_dsc))

            # Keep the checkpoint with the best validation Dice so far.
            if mean_dsc > best_validation_dsc:
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join(weights, 'basic_UNet_1fold.pth'))
                print("\n save dice : {:.4f}\n".format(best_validation_dsc))

            alldsc_val.append(mean_dsc)
            allloss_val.append(mean_loss)
            loss_valid = []

print("\nBest validation mean DSC: {:.4f}\n".format(best_validation_dsc))