In [None]:
import os, tqdm, time, random
import cv2 as cv
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import models, layers, losses, optimizers

from loss import *
from utils import *
from network import *

# Seed every RNG in play (Python, NumPy, TensorFlow) so training runs are reproducible.
random.seed(777)
np.random.seed(777)
# `tf.set_random_seed` only exists in TF 1.x; TF 2.x renamed it `tf.random.set_seed`.
try:
    tf.random.set_seed(777)   # TF >= 2.0
except AttributeError:
    tf.set_random_seed(777)   # TF 1.x

In [None]:
root_dir = '../data/for_section/'
fold_idx = 1

train_path = os.path.join(root_dir, 'for_fold_%d' % fold_idx)
val_path = os.path.join(root_dir, 'for_fold_%d_val' % fold_idx)

# Training volumes. `data_loader` comes from a star import; presumably it returns an
# indexable collection of 3-D arrays (TODO confirm).
total = data_loader(train_path)

# Patch geometry. Each label patch is n_size x n_size in-plane and spans
# n_slice * thick_factor thin slices along the third axis; the matching network input
# averages every thick_factor consecutive thin slices, giving n_slice thick slices.
n_size = 32
n_slice = 16
thick_factor = 6
n_depth = n_slice * thick_factor
cs_strides = 16       # stride between patch origins on the first two axes
a_strides = 2         # stride between patch origins on the third axis
n_train_volumes = 2   # only the first volumes of `total` are used for training
train_margin = 10     # training keeps patches brighter than the volume's mean patch + margin
val_threshold = 225   # validation uses a fixed absolute intensity threshold


def _iter_patches(volume, progress=False):
    """Yield every full (n_size, n_size, n_depth) patch of a 3-D volume on the strided grid.

    Start indices run up to and including ``dim - patch_size`` so the patch that ends
    exactly on the volume border is kept (``range(0, dim - size, step)`` stopped one
    position short). With ``progress=True`` the outer axis shows a tqdm bar.
    """
    cor, sag, axi = volume.shape
    rows = range(0, cor - n_size + 1, cs_strides)
    if progress:
        rows = tqdm.tqdm_notebook(rows)
    for idx in rows:
        for jdx in range(0, sag - n_size + 1, cs_strides):
            for kdx in range(0, axi - n_depth + 1, a_strides):
                yield volume[idx:idx+n_size, jdx:jdx+n_size, kdx:kdx+n_depth]


def _to_thick_slices(patch):
    """Average each run of thick_factor thin slices: (X, Y, n_depth) -> (X, Y, n_slice)."""
    thick = np.array(np.dsplit(patch, n_slice)).mean(axis=-1)  # (n_slice, X, Y)
    return np.transpose(thick, [1, 2, 0])


def _collect_pairs(volume, threshold, data_out, label_out):
    """Append (thick-slice input, thin-slice label) pairs for patches whose mean exceeds threshold."""
    for patch in _iter_patches(volume, progress=True):
        if patch.mean() > threshold:
            label_out.append(patch)
            data_out.append(_to_thick_slices(patch))


# ---- Training pairs ----
data = []
label = []
for i in range(n_train_volumes):
    vol = total[i]
    # Per-volume threshold, computed once. np.mean(means) used to be re-evaluated
    # for every candidate patch inside the innermost loop.
    patch_means = [p.mean() for p in _iter_patches(vol)]
    _collect_pairs(vol, np.mean(patch_means) + train_margin, data, label)

label = np.array(label)[..., np.newaxis]
data = np.array(data)[..., np.newaxis]
print(data.shape)
print(label.shape)

# ---- Validation pairs ----
val_data = []
val_label = []
# [1:3] skips the first sorted directory entry and uses the next two scans.
scan_list = sorted(os.listdir(val_path))[1:3]
for scan in scan_list:
    dante_path = os.path.join(val_path, scan, 'T1SPACE09mmISOPOSTwDANTE')
    img_name = [f for f in os.listdir(dante_path) if '.nii' in f and '_rsl' not in f][0]
    # np.asanyarray(img.dataobj) is the supported equivalent of the deprecated get_data().
    val = check_data(np.asanyarray(nib.load(os.path.join(dante_path, img_name)).dataobj))
    # Extract from *every* scan, sizing the grid from this scan itself. Previously the
    # extraction ran once after this loop (only the last scan was used) and read the grid
    # size from `test.shape`, i.e. the last training volume.
    _collect_pairs(val, val_threshold, val_data, val_label)

val_label = np.array(val_label)[..., np.newaxis]
val_data = np.array(val_data)[..., np.newaxis]
print(val_data.shape)
print(val_label.shape)

In [None]:
# Build the network (presumably the 3-D super-resolution model, given the "DeepSR" output
# folders). SR3D comes from the star-imported `network` module; its constructor
# arguments/defaults are not visible here.
net = SR3D()

In [None]:
# Date-stamped output folders, e.g. ./checkpoint/Jan_05_2024/DeepSR_msegrad_test.
# split() with no argument collapses the double space ctime() emits before single-digit
# days ("Sun Jan  5 ..."); split(' ') produced an empty field and int('') crashed.
date = time.ctime().split()
stamp = '%s_%02d_%s' % (date[1], int(date[2]), date[-1])
ckpt_root = './checkpoint/%s/DeepSR_msegrad_test' % stamp
result_root = './result/%s/DeepSR_msegrad_test' % stamp

for out_dir in (ckpt_root, result_root):
    try:
        os.makedirs(out_dir)
        print("\nMake Save Directory!\n")
    except FileExistsError:
        # Only "already exists" is expected; permission/IO errors now surface instead of
        # being reported as an existing directory.
        print("\nDirectory Already Exist!\n")

# Persist the architecture alongside the weights saved after training.
with open(os.path.join(ckpt_root, "model.json"), "w") as json_file:
    json_file.write(net.to_json())
print("\nModel Saved!\n")

In [None]:
# Adam with learning rate 1e-4. `mse_grad_loss` (presumably MSE plus a gradient term,
# per its name) comes from the star-imported `loss` module, as do the extra metrics.
tracked_metrics = ['mse', gradient_3d_loss, mutual_information]
net.compile(
    optimizer=optimizers.Adam(0.0001),
    loss=mse_grad_loss,
    metrics=tracked_metrics,
)

In [None]:
# Train on the patch pairs. Keras documents `validation_data` as an (x, y) tuple;
# a list is not part of that contract.
history = net.fit(data, label, batch_size=16, epochs=100,
                  validation_data=(val_data, val_label))

net.save_weights(os.path.join(ckpt_root, 'weight.h5'))

# Per-epoch loss/metric history for later inspection.
pd.DataFrame(history.history).to_csv(os.path.join(result_root, 'loss.csv'))

In [None]:
# Full-volume inference input. Only the last validation scan is used: the previous loop
# reloaded every scan and kept just the final one, so load that scan directly.
# An empty scan list now raises IndexError instead of silently reusing a stale `val`.
scan_list = sorted(os.listdir(val_path))[1:]
scan = scan_list[-1]
dante_path = os.path.join(val_path, scan, 'T1SPACE09mmISOPOSTwDANTE')
img_name = [i for i in os.listdir(dante_path) if '.nii' in i and '_rsl' not in i][0]
# np.asanyarray(img.dataobj) is the supported equivalent of the deprecated get_data().
val = check_data(np.asanyarray(nib.load(os.path.join(dante_path, img_name)).dataobj))

cor, sag, axi = val.shape
# Split the third axis into axi//6 equal groups and average each one -> (cor, sag, axi//6).
# np.dsplit raises if axi is not divisible by axi//6.
tmp = np.array(np.dsplit(val, axi // 6)).mean(axis=-1)
tmp = np.transpose(tmp, [1, 2, 0])

half_top = int(np.floor(tmp.shape[-1] / 2))
#half_bot = int(np.ceil(tmp.shape[-1]/2))

In [None]:
# Reconstruct the full in-plane field of view from four overlapping 160x160 windows.
# Each window's prediction is cropped to a 128x128 quadrant, dropping the 32-px margin on
# the side that faces the neighbouring windows. The coordinates are hard-coded for
# 256x256 in-plane slices (TODO generalise if other sizes occur).
recon = np.zeros(shape=[cor, sag, axi])

# quadrant -> [input rows, input cols, output-crop rows, output-crop cols]
# Crop ends of 160 equal the prediction's in-plane size (the old 256 was silently clipped).
slice_dict = {
    1: [[0, 160], [0, 160], [0, 128], [0, 128]],
    2: [[0, 160], [96, 256], [0, 128], [32, 160]],
    3: [[96, 256], [0, 160], [32, 160], [0, 128]],
    4: [[96, 256], [96, 256], [32, 160], [32, 160]],
}
# Per-quadrant raw predictions; a dedicated name so the training volume `test` is not clobbered.
quad_preds = {}
for i, (row, col) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    row_start, col_start = row * 128, col * 128
    (r0, r1), (c0, c1), (cr0, cr1), (cc0, cc1) = slice_dict[i + 1]
    print(i, row_start, col_start)
    quad_preds[i] = net.predict(tmp[np.newaxis, r0:r1, c0:c1, :, np.newaxis])
    recon[row_start:row_start+128, col_start:col_start+128] = quad_preds[i][0, cr0:cr1, cc0:cc1, :, 0]

In [None]:
# Save the thick-slice input, the reconstruction and the original volume as NIfTI files.
# The input gets a z-scale of 6 because each of its slices averages ~6 original slices.
thick_affine = np.eye(4)
thick_affine[2, 2] = 6

outputs = [
    (tmp, thick_affine, 'val_input.nii'),
    (recon, np.eye(4), 'val_pred.nii'),
    (val, np.eye(4), 'val_label.nii'),
]
for volume, affine, fname in outputs:
    nib.save(nib.Nifti1Image(volume, affine), os.path.join(result_root, fname))