In [1]:
import SimpleITK as sitk
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from scipy import ndimage
import sys
sys.path.append("/home/xczhu/xzhu_clean/")
import MRI_tools.IO as IO

  from ._conv import register_converters as _register_converters


In [2]:
from PIL import Image
import os
def load_imgs(img_dir):
    Imgs = []
    img_lst = os.listdir(img_dir)
    img_lst.sort(reverse=True)
    for filename in img_lst:
        Imgs.append(np.array(Image.open(img_dir+filename)))
        
    return np.asarray(Imgs)

In [3]:
# sampling 
# 2.5 D sampling
import random 

def CHAOS_sampling(data_dir, data_id, key = 'CT', nslice = 5):
    full_data_dir = data_dir + '/' + key + '/' + str(data_id) 
    img = IO.read_dcm.read_dcms(full_data_dir + '/DICOM_anon/')
    label = IO.image_io.load_imgs(full_data_dir + '/Ground/')

    imgs, labels = sample_2p5(img, label, 100, nslice)
    
    return imgs, labels

def sample_2p5(img_in, img_out, N_sample, nslice = 5):
    N_sample = min([img_in.shape[0]-nslice, N_sample])
    N = img_out.shape[0]
    ind_list = [*range(nslice//2,N-nslice//2)]
    random.shuffle(ind_list)
    
    imgs_in = []
    imgs_out = []
    for i in range(N_sample):
        ind_t = ind_list[i]
        imgs_in.append(img_in[ind_t-nslice//2:ind_t+(1+nslice)//2,...])
        imgs_out.append(img_out[ind_t:ind_t+1,...])
        
    return imgs_in, imgs_out

In [None]:
from DL_torch import network
import torch.optim as optim
import torch
import torch.nn as nn

nslice = 5
unet_level = 4
net = network.UnetModule(nslice, 1, base_cn = 8, unet_level = unet_level)

train_dir = '/home/xczhu/tools/CHAOS/Train_Sets/'
train_id_lst = os.listdir('/home/xczhu/tools/CHAOS/Train_Sets/CT/')
train_id_lst = train_id_lst[:5]

test_dir = '/home/xczhu/tools/CHAOS/Test_Sets/'
test_id = os.listdir('/home/xczhu/tools/CHAOS/Test_Sets/CT/')

dir_checkpoint = './'

lr = 1e-3
epochs = 20
batch_size = 20
n_channel = nslice

gpu=torch.cuda.is_available()
sig = nn.Sigmoid()
save_cp=True
print('''
    Starting training:
    Epochs: {}
    Batch size: {}
    Learning rate: {}
    Checkpoints: {}
    CUDA: {}
'''.format(epochs, batch_size, lr, str(save_cp), str(gpu)))

cp_prefix = 'Seg_channel{}_unet_level{}'.format(n_channel,unet_level)
optimizer = optim.Adam(net.parameters(),lr=lr)
criterion = nn.MSELoss()

if gpu:
    device = torch.device("cpu")
    net = net.to(device)
    

for epoch in range(epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    net.train()
    i_imgs = []
    o_imgs = []
    for data_id in train_id_lst:
        i_img, o_img = CHAOS_sampling(train_dir, data_id)
        i_imgs = i_imgs + i_img
        o_imgs = o_imgs + o_img
    epoch_loss = 0
    N_train = len(i_imgs)

    for i in range(len(i_imgs)//batch_size):
        i_img = np.array([i_imgs[k] for k in range(i*batch_size,min([(i+1)*batch_size,len(i_imgs)]))]).astype(np.float32)
        o_img = np.array([o_imgs[k] for k in range(i*batch_size,min([(i+1)*batch_size,len(i_imgs)]))]).astype(np.float32)
        i_img = torch.from_numpy(i_img)
        o_img = torch.from_numpy(o_img)

        if gpu:
            i_img = i_img.to(device)
            o_img = o_img.to(device)
        optimizer.zero_grad()
        img_pred = net(i_img)
        img_pred = sig(img_pred)
        # img_mean = torch.mean(i_img.view(batch_size,-1,1,1,1),1,True)
        img_pred = img_pred.view(-1)
        o_img = o_img.view(-1)
        #img_pred = img_pred.view(-1)
        #o_img = o_img.view(-1)
        loss = criterion(img_pred, o_img)
        epoch_loss += loss

        # loss.item()
        # print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, epoch_loss[0]))

        loss.backward()
        optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / i))

    if save_cp and (epoch+1)%10==0:
        dir_checkpoint_cp = dir_checkpoint + cp_prefix +'CP{}.pth'.format(epoch + 1)
        if os.path.exists(dir_checkpoint) is False:
            os.mkdir(dir_checkpoint)

        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, dir_checkpoint_cp)
        # torch.save(net.state_dict(),dir_checkpoint_cp)
        print('Checkpoint {} saved !'.format(epoch + 1))


    Starting training:
    Epochs: 20
    Batch size: 20
    Learning rate: 0.001
    Checkpoints: True
    CUDA: True

Starting epoch 1/20.
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/5/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/14/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/1/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/2/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/8/DICOM_anon/
Epoch finished ! Loss: inf
Epoch finished ! Loss: 0.5712921023368835
Epoch finished ! Loss: 0.4374381899833679
Epoch finished ! Loss: 0.3878590166568756
Epoch finished ! Loss: 0.35584962368011475
Epoch finished ! Loss: 0.3377639651298523
Epoch finished ! Loss: 0.32386112213134766
Epoch finished ! Loss: 0.31387999653816223
Epoch finished ! Loss: 0.30651533603668213
Epoch finished ! Loss: 0.3015059232711792
Epoch finished ! Loss: 0.2966545820236206
Epoch finished !

Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/14/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/1/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/2/DICOM_anon/
Reading Dicom directory: /home/xczhu/tools/CHAOS/Train_Sets//CT/8/DICOM_anon/
Epoch finished ! Loss: inf
Epoch finished ! Loss: 0.500004768371582
Epoch finished ! Loss: 0.375003457069397
Epoch finished ! Loss: 0.33333590626716614
Epoch finished ! Loss: 0.3125022351741791
Epoch finished ! Loss: 0.30000191926956177
Epoch finished ! Loss: 0.29166850447654724
Epoch finished ! Loss: 0.2857159972190857
Epoch finished ! Loss: 0.2812517285346985
Epoch finished ! Loss: 0.27777937054634094
Epoch finished ! Loss: 0.2750026285648346
Epoch finished ! Loss: 0.2727504074573517
Epoch finished ! Loss: 0.2708546221256256
Epoch finished ! Loss: 0.26925045251846313
Epoch finished ! Loss: 0.26787978410720825
Epoch finished ! Loss: 0.26669755578041077
Epoch finished ! Loss: 0.

In [None]:
# network definition
import torch.optim as optim
import torch
import torch.nn as nn
from super_resnet import *
from util import *
import os

data_dir = '../tmp_data/'
train_id = [1,2,3,4,5,6,7,8,9,10,11,13,14,15,16]
test_id = [17,18,19,20]
train_data_dir = [data_dir+'case{}/'.format(i) for i in train_id]
test_data_dir = [data_dir+'case{}/'.format(i) for i in test_id]
dir_checkpoint = './'

def train(lr,epochs,batch_size,patch_size,n_channel,n_layers):
    # prep step 
    net = super_resnet(n_channel,n_layers)
    gpu=torch.cuda.is_available()
    save_cp=True
    patch_size = [patch_size,patch_size,patch_size]
    print('''
        Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, str(save_cp), str(gpu)))
    
    cp_prefix = 'Sup_patch{}_layer{}_channel{}'.format(patch_size[0],n_layers,n_channel)
    optimizer = optim.Adam(net.parameters(),lr=lr)
    criterion = nn.MSELoss()

    if gpu:
        net = net.cuda()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        i_imgs, o_imgs = random_patch_sampler(train_data_dir, patch_size, 100)    
        epoch_loss = 0
        N_train = len(i_imgs)

        for i in range(len(i_imgs)//batch_size):
            i_img = np.array([[i_imgs[k]] for k in range(i*batch_size,min([(i+1)*batch_size,len(i_imgs)]))]).astype(np.float32)
            o_img = np.array([[o_imgs[k]] for k in range(i*batch_size,min([(i+1)*batch_size,len(i_imgs)]))]).astype(np.float32)
            i_img = torch.from_numpy(i_img)
            o_img = torch.from_numpy(o_img)

            if gpu:
                i_img = i_img.cuda()
                o_img = o_img.cuda()
            optimizer.zero_grad()
            img_pred = net(i_img)
            # img_mean = torch.mean(i_img.view(batch_size,-1,1,1,1),1,True)
            img_pred = img_pred.view(-1)
            o_img = o_img.view(-1)
            #img_pred = img_pred.view(-1)
            #o_img = o_img.view(-1)
            loss = criterion(img_pred, o_img)
            epoch_loss += loss

            # loss.item()
            # print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, epoch_loss[0]))

            loss.backward()
            optimizer.step()

        print('Epoch finished ! Loss: {}'.format(epoch_loss / i))

        if save_cp and (epoch+1)%10==0:
            dir_checkpoint_cp = dir_checkpoint + cp_prefix +'CP{}.pth'.format(epoch + 1)
            if os.path.exists(dir_checkpoint) is False:
                os.mkdir(dir_checkpoint)

            torch.save({
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, dir_checkpoint_cp)
            # torch.save(net.state_dict(),dir_checkpoint_cp)
            print('Checkpoint {} saved !'.format(epoch + 1))

def main():
    import argparse
    parser = argparse.ArgumentParser()
    # Training args
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-3)
    parser.add_argument('--epochs', '-n', type=int, default=50)
    parser.add_argument('--batch_size', '-b', type=int, default=40)
    # Neural network architecture args
    parser.add_argument('--patch_size', '-p', type=int, default=64)
    parser.add_argument('--n_layers', '-l', type=int, default=4)
    parser.add_argument('--n_channel', '-c', type=int, default=16)

    args = parser.parse_args()
    
    train(lr = args.learning_rate,
          epochs = args.epochs,
          batch_size = args.batch_size,
          patch_size = args.patch_size,
          n_channel = args.n_channel,
          n_layers = args.n_layers)


if __name__ == '__main__':
    main()

In [None]:
N = 40
plt.figure(1,figsize=(5,5))
plt.imshow(img[N],cmap='gray')
plt.imshow(label[N],cmap='jet', alpha=0.5)

