In [0]:
# Mount Google Drive into the Colab runtime at /content/drive.
# Later cells read the dataset through the relative path './drive/My Drive/...'
# (presumably resolved against a working directory of /content -- confirm).
from google.colab import drive
drive.mount('/content/drive')

In [0]:
!pip install SimpleITK
!pip install medpy

In [0]:
# NOTE: everything between the triple quotes below is a plain string literal,
# i.e. disabled code that is never executed. It is a one-off script that reads
# every '.mhd' segmentation mask in `mask_path`, computes a border-distance
# weight map with compute_distance_weight_matrix and writes it to `weight_path`
# as '<prefix>_weight.nrrd' (prefix = file name up to the first underscore).
# Remove the quotes to re-generate the weight maps.
"""
import os
import numpy as np
import SimpleITK as sitk
from scipy import ndimage
#pre-calculate of weight of masks

def compute_distance_weight_matrix(mask, alpha=1, beta=8, omega=6):
    mask = np.asarray(mask)
    distance_to_border = ndimage.distance_transform_edt(mask > 0) + ndimage.distance_transform_edt(mask == 0)    
    weights = alpha + beta*np.exp(-(distance_to_border**2/omega**2))
    return np.asarray(weights, dtype='float32')

mask_path =  './drive/My Drive/isotropic_dataset/test/seg'
weight_path = './drive/My Drive/isotropic_dataset/test/weight'

for f in [f for f in os.listdir(mask_path) if f.endswith('.mhd')]:
  seg_mask = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(mask_path,f)))
  weight = compute_distance_weight_matrix(seg_mask)
  sitk.WriteImage(sitk.GetImageFromArray(weight), os.path.join(weight_path, f.split('_')[0]+'_weight.nrrd'), True)
  print(f)
"""

In [0]:
"""
Created on Tue Sep 17 08:55:35 2019

@author: Gabriel Hsu

ref:https://www.kaggle.com/ori226/data-augmentation-with-elastic-deformations

"""
from __future__ import print_function, division
import os 
from random import randint


import pandas as pd
import numpy as np
from scipy import ndimage

import torch
from torch.utils.data import Dataset, DataLoader

import SimpleITK as sitk

from data_augmentation import elastic_transform, gaussian_blur, gaussian_noise, crop_z

"""
The dataset of MICCAI 2014 Spine Challenge

"""

#%% Build the dataset 
class CSI_Dataset(Dataset):
    """Random-patch dataset over the CSI (MICCAI 2014 spine challenge) scans.

    Indexing the dataset loads one scan together with its label mask and its
    pre-computed weight map, standardizes the image and returns one randomly
    extracted sample produced by ``extract_random_patch``.
    """

    def __init__(self, dataset_path, subset='train', linear_att=1.0, offset=1000.0):
        """
        Args:
            dataset_path(string): Root path to the whole dataset
            subset(string): 'train' or 'test' depend on which subset
            linear_att(float): slope of the raw-value -> HU conversion
                (stored only; the conversion itself is currently disabled)
            offset(float): offset of the raw-value -> HU conversion
                (stored only; the conversion itself is currently disabled)
        """
        # Running sample counter handed to extract_random_patch, which turns
        # every 6th sample into an "empty" sample.
        self.idx = 1

        self.dataset_path = dataset_path
        self.subset = subset
        self.linear_att = linear_att
        self.offset = offset

        self.img_path = os.path.join(dataset_path, subset, 'img')
        self.mask_path = os.path.join(dataset_path, subset, 'seg')
        self.weight_path = os.path.join(dataset_path, subset, 'weight')

        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same scan.
        self.img_names = sorted(f for f in os.listdir(self.img_path) if f.endswith('.mhd'))

    def __len__(self):
        """Number of scans ('.mhd' files) found in the image folder."""
        return len(self.img_names)

    @staticmethod
    def _standardize(img):
        """Return ``img`` as float32 with zero mean and unit variance.

        The conversion to float happens first because in-place arithmetic on
        an integer-typed volume raises a casting error. A constant volume
        (zero standard deviation) is only mean-centred to avoid dividing by 0.
        """
        img = np.asarray(img, dtype=np.float32)
        m = np.mean(img)
        s = np.std(img)
        if s == 0:
            return img - m
        return (img - m) / s

    def __getitem__(self, idx):
        """Load scan ``idx`` and return one random sample from it.

        Returns:
            tuple: (img_patch, ins_patch, gt_patch, weight_patch, c_label) as
            produced by ``extract_random_patch``.
        """
        img_name = self.img_names[idx]
        # mask and weight files share the image's base name (text before the first '.')
        base_name = img_name.split('.')[0]
        mask_name = base_name + '_label.mhd'
        weight_name = base_name + '_weight.nrrd'

        img_file = os.path.join(self.img_path, img_name)
        mask_file = os.path.join(self.mask_path, mask_name)
        weight_file = os.path.join(self.weight_path, weight_name)

        # arrays are indexed z, y, x
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))
        mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))
        weight = sitk.GetArrayFromImage(sitk.ReadImage(weight_file))

        # linear transformation from 12bit reconstruction img to HU unit
        # depend on the original data (CSI data value is from 0 ~ 4095)
        # img = img * self.linear_att - self.offset

        # image standardize
        img = self._standardize(img)

        sample = extract_random_patch(img, mask, weight, self.idx)
        self.idx += 1

        return sample
        
#%% Compute weight distance for loss function
def compute_distance_weight_matrix(mask, alpha=1, beta=8, omega=6):
    """Build a per-voxel loss weight map that emphasises object borders.

    For every voxel the Euclidean distance to the foreground/background
    border is computed (distance to the nearest zero voxel for foreground,
    distance to the nearest non-zero-mask voxel for background) and mapped to
    ``alpha + beta * exp(-d**2 / omega**2)``.

    Args:
        mask: label array; values > 0 count as foreground, values == 0 as background.
        alpha: constant base weight.
        beta: extra weight added at the border.
        omega: width of the Gaussian fall-off.

    Returns:
        float32 array of weights with the same shape as ``mask``.
    """
    labels = np.asarray(mask)

    # distance of foreground voxels to the background, and vice versa
    inside = ndimage.distance_transform_edt(labels > 0)
    outside = ndimage.distance_transform_edt(labels == 0)
    border_distance = inside + outside

    falloff = np.exp(-(border_distance**2/omega**2))
    return np.asarray(alpha + beta*falloff, dtype='float32')
    
    
#%% Extract the 128*128*128 patch
def extract_random_patch(img, mask, weight, i, patch_size=128):
    """Cut one random cubic training sample out of a scan.

    A vertebra label is drawn at random. The "instance memory" marks every
    voxel whose label is greater than the drawn one, the ground truth marks
    the drawn vertebra itself. Every 6th call (``i % 6 == 0``) produces an
    "empty" sample instead: the patch centre is drawn anywhere in the volume,
    the instance memory contains all vertebrae, the ground truth is all zero
    and the weights are all one.

    Args:
        img: 3-D image volume indexed (z, y, x).
        mask: label volume, sliced with the same indices as ``img``; positive
            values are vertebra labels, 0 is background.
        weight: per-voxel loss weights, sliced with the same indices as ``img``.
        i: running sample counter deciding between normal and empty samples.
        patch_size: edge length of the cubic patch.

    Returns:
        tuple: (img_patch, ins_patch, gt_patch, weight_patch, c_label). The
        four patches have shape (1, patch_size, patch_size, patch_size);
        c_label has shape (1,) and is 1 when at least 98% of the chosen
        vertebra is inside the patch, else 0.

    Raises:
        ValueError: if ``mask`` contains no positive label.
    """
    # list available vertebrae: every positive label (0 is background)
    verts = np.unique(mask)
    verts = verts[verts > 0]
    if len(verts) == 0:
        raise ValueError('mask contains no labelled vertebra to sample from')
    chosen_vert = verts[randint(0, len(verts) - 1)]

    # instance memory: 1 for every label above the chosen one
    ins_memory = np.copy(mask)
    ins_memory[ins_memory <= chosen_vert] = 0
    ins_memory[ins_memory > 0] = 1

    flag_empty = not i % 6

    if flag_empty:
        # empty sample: centre anywhere in the volume, the full mask is kept
        # so that it can become the instance memory further down
        gt = np.copy(mask)
        z, y, x = [np.random.randint(0, s) for s in img.shape]
    else:
        print(i, ' normal sample')
        # ground truth: 1 for the chosen vertebra only
        gt = np.copy(mask)
        gt[gt != chosen_vert] = 0
        gt[gt > 0] = 1

        # random patch centre inside the bounding box of the chosen vertebra
        indices = np.nonzero(mask == chosen_vert)
        lower = [np.min(axis_idx) for axis_idx in indices]
        upper = [np.max(axis_idx) for axis_idx in indices]
        x = randint(int(lower[2]), int(upper[2]))
        y = randint(int(lower[1]), int(upper[1]))
        z = randint(int(lower[0]), int(upper[0]))

    # Clip the patch window to the volume and work out, per axis, how much
    # padding restores the full patch size. Padding goes in front when the
    # window was clipped at index 0, otherwise behind.
    window = []
    pad_width = []
    for center, size in zip((z, y, x), img.shape):
        low = int(max(center - patch_size/2, 0))
        up = int(min(center + patch_size/2, size))
        missing = int(patch_size - (up - low))
        if low == 0:
            pad_width.append((missing, 0))
        elif up == size:
            pad_width.append((0, missing))
        else:
            pad_width.append((0, 0))
        window.append(slice(low, up))
    window = tuple(window)

    # extract and pad every volume with its own minimum value
    img_patch = np.pad(img[window], pad_width, 'constant', constant_values=img.min())
    ins_patch = np.pad(ins_memory[window], pad_width, 'constant', constant_values=ins_memory.min())
    gt_patch = np.pad(gt[window], pad_width, 'constant', constant_values=gt.min())
    weight_patch = np.pad(weight[window], pad_width, 'constant', constant_values=weight.min())

    if flag_empty:
        # all vertebrae are already "remembered" (as 0/1, like the normal
        # case), nothing is left to segment and every voxel weighs the same
        ins_patch = (gt_patch > 0).astype(gt_patch.dtype)
        gt_patch = np.zeros_like(ins_patch)
        weight_patch = np.ones_like(weight_patch)

    # give the label of completeness (partial or complete)
    vol = np.count_nonzero(gt == 1)
    sample_vol = np.count_nonzero(gt_patch == 1)
    c_label = 0 if float(sample_vol/(vol+0.0001)) < 0.98 else 1

    # Random data augmentation on normal samples: one of four equally likely
    # outcomes (25% each) -- elastic deformation, blur, noise or nothing.
    # (A random z-crop via crop_z used to follow here and is disabled.)
    aug = randint(0, 3)
    if not flag_empty:
        if aug == 0:
            img_patch, gt_patch, ins_patch, weight_patch = elastic_transform(img_patch, gt_patch, ins_patch, weight_patch, alpha=300, sigma=8)
        elif aug == 1:
            img_patch = gaussian_blur(img_patch)
        elif aug == 2:
            img_patch = gaussian_noise(img_patch)

    # add the leading channel axis
    img_patch = np.expand_dims(img_patch, axis=0)
    ins_patch = np.expand_dims(ins_patch, axis=0)
    gt_patch = np.expand_dims(gt_patch, axis=0)
    weight_patch = np.expand_dims(weight_patch, axis=0)
    c_label = np.expand_dims(c_label, axis=0)

    return img_patch, ins_patch, gt_patch, weight_patch, c_label


#%%% Test purpose
# Build the on-the-fly patch datasets and their loaders.
data_root = './drive/My Drive/isotropic_dataset'
train_set = CSI_Dataset(data_root, subset='train')
test_set = CSI_Dataset(data_root, subset='test')

# shuffle=True: the consumers in this notebook fetch batches with
# next(iter(loader)); without shuffling that always returns the first scan of
# the subset, so the remaining scans would never be sampled.
dataloader_train = DataLoader(train_set, batch_size=1, shuffle=True)
dataloader_test = DataLoader(test_set, batch_size=1, shuffle=True)

def export_patches(loader, out_dir, n_samples, label_file, log_every=1):
    """Draw ``n_samples`` batches from ``loader`` and write them to disk.

    Each sample is written as four compressed volumes,
    ``<out_dir>/{img,gt,ins,weight}/<kind>_<i>.nrrd``, and the completeness
    labels of all samples are saved to ``<out_dir>/<label_file>`` as an Excel
    sheet. The target folders must already exist.

    Args:
        loader: DataLoader yielding (img, ins, gt, weight, c_label) with batch size 1.
        out_dir: root folder of the exported patches.
        n_samples: number of samples to generate.
        label_file: file name of the Excel sheet holding the labels.
        log_every: print a progress line every ``log_every`` samples.
    """
    labels = []
    batches = iter(loader)
    for i in range(n_samples):
        # keep one iterator alive and restart it only when it is exhausted;
        # calling next(iter(loader)) each time would re-read the first batch
        try:
            batch = next(batches)
        except StopIteration:
            batches = iter(loader)
            batch = next(batches)
        img_patch, ins_patch, gt_patch, weight, c_label = batch

        if not i % log_every:
            print(i, ' samples generated...')

        labels.append(c_label.item())

        volumes = (('img', img_patch), ('gt', gt_patch), ('ins', ins_patch), ('weight', weight))
        for kind, tensor in volumes:
            array = torch.squeeze(tensor).numpy()
            target = os.path.join(out_dir, kind, kind + '_' + str(i) + '.nrrd')
            sitk.WriteImage(sitk.GetImageFromArray(array), target, True)

    pd.Series(labels).to_excel(os.path.join(out_dir, label_file))


# Patch export is disabled; uncomment to (re-)generate the patch folders.
#export_patches(dataloader_train, './drive/My Drive/patches/train', 5000, 'train_label.xlsx')
#export_patches(dataloader_test, './drive/My Drive/patches/test', 1000, 'test_label.xlsx', log_every=1000)

In [0]:
"""
Training and Evaluation
"""


# -*- coding: utf-8 -*-
"""
Created on Thu Sep 19 11:21:22 2019

@author: Gabriel Hsu
"""
from __future__ import print_function, division
import os
import argparse
import time


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from model import iterativeFCN
#from dataset_simple import CSI_Dataset_Patched
from metrics import DiceCoeff, ASSD

import SimpleITK as sitk

def seg_loss(pred, target, weight):
    """Weighted false-positive and false-negative terms of the segmentation loss.

    Args:
        pred: predicted foreground probabilities.
        target: ground-truth mask (1 = foreground, 0 = background).
        weight: per-voxel weights, broadcastable against ``pred``.

    Returns:
        tuple: (FP, FN), each summed over all elements and divided by the
        size of the first dimension of ``pred``.
    """
    batch = pred.shape[0]
    # soft false positives: prediction mass on background voxels
    false_pos = (weight*(1-target)*pred).sum()
    # soft false negatives: missing prediction mass on foreground voxels
    false_neg = (weight*(1-pred)*target).sum()
    return false_pos/batch, false_neg/batch
    
#%%
def train_single(args, model, device, img_patch, ins_patch, gt_patch, weight, c_label, optimizer):
    """Run one optimisation step on a single batch.

    Args:
        args: parsed command line arguments (not used here).
        model: network returning ``(S, C)`` -- segmentation probabilities and
            completeness probability -- for a 2-channel input.
        device: torch device the batch is moved to.
        img_patch, ins_patch: image and instance-memory patches; they are
            concatenated along dim 1 to form the network input.
        gt_patch: ground-truth segmentation patch.
        weight: per-voxel loss weights.
        c_label: completeness label(s), 0 or 1.
        optimizer: optimizer updating ``model``.

    Returns:
        tuple: (total loss as float, number of correctly classified samples,
        Dice coefficient of the rounded segmentation).
    """
    torch.cuda.empty_cache()

    model.train()

    # concatenate the img_patch and ins_patch along the channel axis
    input_patch = torch.cat((img_patch.float(), ins_patch.float()), dim=1).to(device)
    gt_patch = gt_patch.float().to(device)
    weight = weight.float().to(device)
    c_label = c_label.float().to(device)

    optimizer.zero_grad()
    S, C = model(input_patch)

    # Calculate DiceCoeff on the binarised prediction
    pred = torch.round(S).detach()
    train_dice_coef = DiceCoeff(pred, gt_patch.detach())

    # segmentation loss: false positives are down-weighted by lamda
    lamda = 0.1
    FP, FN = seg_loss(S, gt_patch, weight)
    s_loss = lamda*FP + FN

    # Classification loss (binary cross entropy). The probability is clamped
    # away from 0 and 1 because log(0) would make the loss inf/NaN once the
    # classifier saturates. Flattening and averaging keeps the loss a scalar
    # for any batch size (identical to the plain formula for batch size 1).
    eps = 1e-7
    c_prob = C.reshape(-1)
    c_true = c_label.reshape(-1)
    c_safe = torch.clamp(c_prob, eps, 1 - eps)
    c_loss = (-1*c_true*torch.log(c_safe) - (1 - c_true)*torch.log(1 - c_safe)).mean()

    print(s_loss.item(), c_loss.item())

    train_loss = s_loss + c_loss

    # number of samples whose rounded completeness prediction is right
    correct = int((c_prob.round() == c_true).sum().item())

    # optimize the parameters
    train_loss.backward()
    optimizer.step()

    return train_loss.item(), correct, train_dice_coef

def test_single(args, model, device, img_patch, ins_patch, gt_patch, weight, c_label):
    """Evaluate the model on a single batch without updating it.

    Args:
        args: parsed command line arguments (not used here).
        model: network returning ``(S, C)`` -- segmentation probabilities and
            completeness probability -- for a 2-channel input.
        device: torch device the batch is moved to.
        img_patch, ins_patch: image and instance-memory patches; they are
            concatenated along dim 1 to form the network input.
        gt_patch: ground-truth segmentation patch.
        weight: per-voxel loss weights.
        c_label: completeness label(s), 0 or 1.

    Returns:
        tuple: (total loss as float, number of correctly classified samples,
        Dice coefficient of the rounded segmentation).
    """
    torch.cuda.empty_cache()

    model.eval()

    # concatenate the img_patch and ins_patch along the channel axis
    input_patch = torch.cat((img_patch.float(), ins_patch.float()), dim=1).to(device)
    gt_patch = gt_patch.float().to(device)
    weight = weight.float().to(device)
    c_label = c_label.float().to(device)

    with torch.no_grad():
        S, C = model(input_patch)

        # Calculate DiceCoeff on the binarised prediction
        pred = torch.round(S).detach()
        test_dice_coef = DiceCoeff(pred, gt_patch.detach())

        # segmentation loss: false positives are down-weighted by lamda
        lamda = 0.1
        FP, FN = seg_loss(S, gt_patch, weight)
        s_loss = lamda*FP + FN

        # Classification loss (binary cross entropy). The probability is
        # clamped away from 0 and 1 because log(0) would make the loss
        # inf/NaN. Flattening and averaging keeps the loss a scalar for any
        # batch size (identical to the plain formula for batch size 1).
        eps = 1e-7
        c_prob = C.reshape(-1)
        c_true = c_label.reshape(-1)
        c_safe = torch.clamp(c_prob, eps, 1 - eps)
        c_loss = (-1*c_true*torch.log(c_safe) - (1 - c_true)*torch.log(1 - c_safe)).mean()

        print(s_loss.item(), c_loss.item())

        # number of samples whose rounded completeness prediction is right
        correct = int((c_prob.round() == c_true).sum().item())

        test_loss = s_loss + c_loss

    return test_loss.item(), correct, test_dice_coef
    
#%%Main
if __name__ == "__main__":
    # Version of Pytorch
    print("Pytorch Version:", torch.__version__)

    # Training args
    parser = argparse.ArgumentParser(description='Fully Convolutional Network')
    parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: 1)')
    parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
                        help='input batch size for testing (default: 1)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum', type=float, default=0.99, metavar='M',
                        help='SGD momentum (default: 0.99)')
    # a store_true flag must default to False, otherwise it can never be switched off
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')

    # parse_known_args ignores the extra arguments injected by the notebook kernel
    args = parser.parse_known_args()[0]

    # Seed every random generator used for sampling and initialisation
    import random
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Use GPU if it is available and not disabled
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Create FCN on the selected device (a hard-coded 'cuda' fails on CPU-only hosts)
    model = iterativeFCN().to(device)
    #model.load_state_dict(torch.load('./drive/My Drive/IterativeFCN_best_norm.pth'))

    # The pre-extracted patch loaders (CSI_Dataset_Patched) are disabled;
    # samples are extracted on the fly by the loaders built in an earlier cell.
    train_loader = dataloader_train
    test_loader = dataloader_test

    def _endless_batches(loader):
        """Yield batches forever, restarting ``loader`` whenever it runs out."""
        while True:
            got_batch = False
            for batch in loader:
                got_batch = True
                yield batch
            if not got_batch:
                raise ValueError('data loader yields no batches')

    # One persistent stream per loader: calling next(iter(loader)) for every
    # step would rebuild the iterator and keep returning its first batch.
    train_batches = _endless_batches(train_loader)
    test_batches = _endless_batches(test_loader)

    #%%
    #optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # per-epoch history (one value per epoch in every list)
    train_loss = []
    test_loss = []
    train_acc = []
    test_acc = []
    train_dice = []
    test_dice = []
    best_test_dice = 0

    total_iteration = 20000
    train_interval = 50   # training steps per "epoch"
    eval_interval = 10    # validation steps per "epoch"

    # Start Training
    for epoch in range(int(total_iteration/train_interval)):

        start_time = time.time()
        epoch_train_dice = []
        epoch_test_dice = []
        epoch_train_loss = []
        epoch_test_loss = []
        correct_train_count = 0
        correct_test_count = 0

        #training process
        for i in range(train_interval):
            img_patch, ins_patch, gt_patch, weight, c_label = next(train_batches)
            t_loss, t_c, t_dice = train_single(args, model, device, img_patch, ins_patch, gt_patch, weight, c_label, optimizer)
            epoch_train_loss.append(t_loss)
            epoch_train_dice.append(t_dice)
            correct_train_count += t_c

        epoch_train_accuracy = correct_train_count/train_interval
        avg_train_loss = sum(epoch_train_loss) / len(epoch_train_loss)
        avg_train_dice = float(sum(epoch_train_dice) / len(epoch_train_dice))

        print('Train Epoch: {} \t Loss: {:.6f}\t acc: {:.6f}%\t dice: {:.6f}%'.format(epoch
              , avg_train_loss
              , epoch_train_accuracy*100
              , avg_train_dice*100))

        #validation process
        for i in range(eval_interval):
            img_patch, ins_patch, gt_patch, weight, c_label = next(test_batches)
            v_loss, v_c, v_dice = test_single(args, model, device, img_patch, ins_patch, gt_patch, weight, c_label)
            epoch_test_loss.append(v_loss)
            epoch_test_dice.append(v_dice)
            correct_test_count += v_c

        epoch_test_accuracy = correct_test_count/eval_interval
        avg_test_loss = sum(epoch_test_loss) / len(epoch_test_loss)
        avg_test_dice = float(sum(epoch_test_dice) / len(epoch_test_dice))

        print('Validation Epoch: {} \t Loss: {:.6f}\t acc: {:.6f}%\t dice: {:.6f}%'.format(epoch
              , avg_test_loss
              , epoch_test_accuracy*100
              , avg_test_dice*100))

        # keep the checkpoint with the best validation Dice
        if avg_test_dice > best_test_dice:
            best_test_dice = avg_test_dice
            if args.save_model:
                print('--- Saving model at Avg Test Dice:{:.2f}%  ---'.format(avg_test_dice*100))
                torch.save(model.state_dict(),'./drive/My Drive/IterativeFCN_best_norm.pth')

        print('-------------------------------------------------------')

        # store one value per epoch so the curves can be plotted against epochs
        train_loss.append(avg_train_loss)
        test_loss.append(avg_test_loss)
        train_acc.append(epoch_train_accuracy)
        test_acc.append(epoch_test_accuracy)
        train_dice.append(avg_train_dice)
        test_dice.append(avg_test_dice)

        print("--- %s seconds ---" % (time.time() - start_time))

        
    

In [0]:
# Plot the training history collected by the training cell.
print("training:", len(train_loss))
print("validation:", len(test_loss))

# Reduce every history entry to one scalar per epoch; np.mean leaves a scalar
# unchanged and averages an entry that is a list of per-step losses.
train_loss_curve = [float(np.mean(entry)) for entry in train_loss]
test_loss_curve = [float(np.mean(entry)) for entry in test_loss]

# epochs are numbered from 1; the axis must be as long as the curves
x = list(range(1, len(train_loss_curve) + 1))

#plot train/validation loss versus epoch
plt.figure()
plt.title("Train/Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Total Loss")
plt.plot(x, train_loss_curve, label="train loss")
plt.plot(x, test_loss_curve, color='red', label="validation loss")
plt.legend(loc='upper right')
plt.grid(True)
plt.show()

#plot train/validation accuracy versus epoch
plt.figure()
plt.title("Train/Validation Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(x, train_acc, label="train acc")
plt.plot(x, test_acc, color='red', label="validation acc")
plt.legend(loc='upper right')
plt.grid(True)
plt.show()