In [3]:
%pylab inline
import dicom
import os
import time
import pickle
import tensorflow as tf
from tensorflow.python.framework.ops import reset_default_graph
from scipy.misc import imresize, imsave, imread
from skimage import exposure 
from skimage.util import random_noise
from sklearn.preprocessing import scale
from sklearn.feature_extraction import image
from skimage.measure import structural_similarity as ssim
from PIL import Image
import png
import cv2
import keras
from keras.preprocessing.image import ImageDataGenerator
import imreg_dft as ird

Populating the interactive namespace from numpy and matplotlib


Using TensorFlow backend.


### Data Preprocessing Functions and Loss Functions

In [None]:
def _tf_fspecial_gauss(size, sigma):
    """Build a normalised 2-D Gaussian kernel as a TensorFlow tensor.

    Mimics MATLAB's ``fspecial('gaussian', ...)``.

    Args:
        size: number of taps along each side of the kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        A float32 tensor of shape ``[size, size, 1, 1]`` whose entries sum to 1.
    """
    lo = -size//2 + 1
    hi = size//2 + 1
    grid_a, grid_b = np.mgrid[lo:hi, lo:hi]

    # Two trailing singleton axes give the filter layout used by tf.nn.conv2d.
    x = tf.constant(grid_a[..., np.newaxis, np.newaxis], dtype=tf.float32)
    y = tf.constant(grid_b[..., np.newaxis, np.newaxis], dtype=tf.float32)

    unnormalised = tf.exp(-((x**2 + y**2)/(2.0*sigma**2)))
    total = tf.reduce_sum(unnormalised)
    return unnormalised / total

def tf_ssim(img1, img2, cs_map=False, mean_metric=True, size=11, sigma=1.5):
    """Structural similarity (SSIM) between two image tensors.

    Local statistics are computed with a Gaussian window applied through
    ``tf.nn.conv2d`` with 'VALID' padding, so the inputs must be 4-D tensors
    with a single channel (the window has shape ``[size, size, 1, 1]``).
    The stabilising constants assume a dynamic range of ``L = 1``.

    Args:
        img1, img2: image tensors to compare.
        cs_map: if true, additionally return the contrast-structure term.
        mean_metric: if true, reduce the map(s) to their mean.
        size: side length of the Gaussian window.
        sigma: standard deviation of the Gaussian window.

    Returns:
        ``cs_map=False``: the SSIM map, or its scalar mean when ``mean_metric``.
        ``cs_map=True``: a tuple ``(ssim, cs)`` of maps, or of their scalar
        means when ``mean_metric``.
    """
    window = _tf_fspecial_gauss(size, sigma)  # window shape [size, size, 1, 1]
    K1 = 0.01
    K2 = 0.03
    L = 1  # dynamic range of the pixel values (use 255 for 8-bit images)
    C1 = (K1*L)**2
    C2 = (K2*L)**2

    def _local_mean(tensor):
        # Gaussian-weighted local average; 'VALID' padding trims the borders.
        return tf.nn.conv2d(tensor, window, strides=[1,1,1,1], padding='VALID')

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)
    mu1_sq = mu1*mu1
    mu2_sq = mu2*mu2
    mu1_mu2 = mu1*mu2
    sigma1_sq = _local_mean(img1*img1) - mu1_sq
    sigma2_sq = _local_mean(img2*img2) - mu2_sq
    sigma12 = _local_mean(img1*img2) - mu1_mu2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
                (sigma1_sq + sigma2_sq + C2))
    if not cs_map:
        return tf.reduce_mean(ssim_map) if mean_metric else ssim_map

    contrast_map = (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)
    if mean_metric:
        # Reduce each map on its own: averaging the (ssim, cs) pair together,
        # as a single reduce_mean over the tuple would, mixes two different
        # quantities into one number.
        return tf.reduce_mean(ssim_map), tf.reduce_mean(contrast_map)
    return ssim_map, contrast_map


def tf_ms_ssim(img1, img2, mean_metric=True, level=5):
    """Multi-scale SSIM between two image tensors.

    At each of `level` scales the mean SSIM and mean contrast-structure value
    are recorded, then both images are halved with 2x2 average pooling. The
    result is the product of the contrast-structure means of all but the last
    scale and the SSIM mean of the last scale, each raised to its weight.

    Args:
        img1, img2: image tensors accepted by `tf_ssim`.
        mean_metric: if true, pass the result through ``tf.reduce_mean``.
        level: number of scales (five weights are defined).

    Returns:
        The MS-SSIM value as a tensor.
    """
    weight = tf.constant([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=tf.float32)
    ssim_means = []
    cs_means = []
    for _ in range(level):
        full_map, contrast_map = tf_ssim(img1, img2, cs_map=True, mean_metric=False)
        ssim_means.append(tf.reduce_mean(full_map))
        cs_means.append(tf.reduce_mean(contrast_map))
        # Move to the next, coarser scale.
        img1 = tf.nn.avg_pool(img1, [1,2,2,1], [1,2,2,1], padding='SAME')
        img2 = tf.nn.avg_pool(img2, [1,2,2,1], [1,2,2,1], padding='SAME')

    # Turn the per-scale scalar lists into 1-D tensors.
    ssim_stack = tf.pack(ssim_means, axis=0)
    cs_stack = tf.pack(cs_means, axis=0)

    last = level - 1
    coarse_term = ssim_stack[last]**weight[last]
    value = (tf.reduce_prod(cs_stack[0:last]**weight[0:last])*
             coarse_term)

    return tf.reduce_mean(value) if mean_metric else value

def sobel_conv(images, dim=5):
    """Gradient magnitude of `images` from a 5x5 Sobel-style kernel pair.

    Args:
        images: tensor accepted by ``tf.nn.conv2d`` with a one-channel filter.
        dim: kernel side length used for the reshape; the coefficient table
            below holds 25 values, so only ``dim=5`` is consistent with it.

    Returns:
        ``sqrt(gx**2 + gy**2)`` where gx / gy are the responses to the kernel
        and to its transpose ('SAME' padding, stride 1).
    """
    coefficients = [
        [1, 0, -2, 0, 1],
        [4, 0, -8, 0, 4],
        [6, 0, -12, 0, 6],
        [4, 0, -8, 0, 4],
        [1, 0, -2, 0, 1],
    ]
    kernel_x = tf.reshape(tf.constant(coefficients, tf.float32), [dim, dim, 1, 1])
    # Swapping the two spatial axes yields the kernel for the other direction.
    kernel_y = tf.transpose(kernel_x, [1, 0, 2, 3])

    grad_x = tf.nn.conv2d(images, kernel_x, strides=[1, 1, 1, 1], padding='SAME')
    grad_y = tf.nn.conv2d(images, kernel_y, strides=[1, 1, 1, 1], padding='SAME')
    return tf.sqrt(tf.pow(grad_x, 2) + tf.pow(grad_y, 2))

def extract_dicom(files):
    """Load every path in `files` as a grayscale image array.

    Despite the name, files are decoded with OpenCV (``cv2.imread``) and then
    converted from BGR to grayscale; no DICOM parsing happens here.

    Args:
        files: iterable of image file paths.

    Returns:
        A list with one grayscale array per input path, in input order.

    Raises:
        IOError: if OpenCV cannot read or decode one of the files.
    """
    images = []
    for i, filename in enumerate(files):
        print("step: " + filename + " ", i)

        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        bgr = cv2.imread(filename)
        if bgr is None:
            raise IOError("could not read image file: " + filename)
        images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
    return images

def extract_data(num = -1, extension="jpg", paths=None):
    """Collect image files below the data directories and load them.

    Args:
        num: maximum number of files to load after sorting the paths;
            ``-1`` loads all of them.
        extension: file extension to look for, without the leading dot.
            Matching is case-insensitive and anchored to the end of the name.
        paths: directories to walk recursively. Defaults to the notebook's
            original hard-coded location when omitted.

    Returns:
        The list of images returned by `extract_dicom` for the selected files.
    """
    if paths is None:
        paths = [
            "/Users/user/Downloads/to_test/to_test"
        ]
    suffix = "." + extension.lower()

    found = []
    for path in paths:
        for dir_name, _subdirs, file_names in os.walk(path):
            for filename in file_names:
                # Suffix match: a plain substring test would also accept
                # names such as "scan.jpg.bak".
                if filename.lower().endswith(suffix):
                    found.append(os.path.join(dir_name, filename))
    found.sort()

    selected = found if num == -1 else found[:num]
    return extract_dicom(selected)

def crop_to_square(image, upsampling):
    """Make `image` square along its first two axes.

    Args:
        image: array with at least two dimensions.
        upsampling: if truthy, zero-pad the shorter side up to the longer one
            (content stays centred); otherwise centre-crop the longer side
            down to the shorter one.

    Returns:
        `image` itself when it is already square; otherwise a square array
        (a new padded array, or a slice of `image` when cropping).
    """
    rows, cols = image.shape[0], image.shape[1]
    if rows == cols:
        return image

    if upsampling:
        side = max(rows, cols)
        top = (side - rows) // 2
        left = (side - cols) // 2
        # The remainder goes to the trailing edge, so an odd difference still
        # yields exactly side x side.
        pad_width = [(top, side - rows - top), (left, side - cols - left)]
        pad_width += [(0, 0)] * (image.ndim - 2)
        return np.pad(image, pad_width, mode='constant', constant_values=0)

    side = min(rows, cols)
    top = (rows - side) // 2
    left = (cols - side) // 2
    # Slice by start + length: trimming the same margin from both ends leaves
    # one extra row/column whenever the difference is odd.
    return image[top:top + side, left:left + side]

def preprocess(images, upsampling=False):
    """Rescale every image with ``(im + |min|) / (max + |min|)``.

    The arithmetic is carried out in floating point: on integer arrays
    (e.g. uint8) the additions would otherwise wrap around.

    NOTE(review): for an image whose minimum is positive this maps the
    maximum to 1 but does not map the minimum to 0; the formula is kept as
    it was.

    Args:
        images: iterable of numeric arrays.
        upsampling: unused; kept for interface compatibility.

    Returns:
        A list of floating-point arrays, one per input image. An image whose
        denominator is zero (e.g. all zeros) yields an all-zero array instead
        of NaNs.
    """
    normalised = []
    for im in images:
        im = np.asarray(im)
        if not np.issubdtype(im.dtype, np.floating):
            im = im.astype(np.float64)
        offset = abs(im.min())
        denominator = im.max() + offset
        if denominator == 0:
            normalised.append(np.zeros_like(im))
        else:
            normalised.append((im + offset) / denominator)
    return normalised

def resize(images, size):
    """Resample every image to ``size`` x ``size`` with Lanczos interpolation."""
    target_shape = (size,size)
    resized = []
    for picture in images:
        resized.append(imresize(picture, target_shape, "lanczos"))
    return resized

def crop(images, upsampling=False):
    """Apply `crop_to_square` to every image and return the results as a list."""
    squared = []
    for im in images:
        squared.append(crop_to_square(im, upsampling=upsampling))
    return squared

### Separate Readings of Original and Ground Truth Data

In [None]:
# Load every .jpg found below the default data directory (num=-1 means "all").
images = extract_data(num=-1, extension="jpg")

### Preprocessing

In [None]:
size = 1024  # side length, in pixels, of the square images used from here on

# Pass a real boolean: the string 'True' only worked because any non-empty
# string is truthy (the string 'False' would have enabled padding as well).
train_square = crop(images, upsampling=True)
train_resized = resize(train_square, size)
train = preprocess(train_resized)

### Image Preview

In [None]:
plt.figure(figsize=(7,7))
preview = np.asarray(train[0])
# Use the array's own reductions and str.format: concatenating a str with a
# numeric value raises TypeError, and the builtin min/max do not reduce a
# 2-D array.
print('Min: {}'.format(preview.min()))
print('Max: {}'.format(preview.max()))
print('Mean: {}'.format(preview.mean()))
print('Std: {}'.format(preview.std()))
plt.imshow(preview.reshape((size,size)), cmap='gray')

In [None]:
# Stack the preprocessed images into one (N, size, size, 1) array.
trainX = np.asarray(train).reshape((len(train), size, size, 1))

In [None]:
# Stack the preprocessed images into one (N, size, size, 1) array.
trainY = np.asarray(train).reshape((len(train), size, size, 1))

### Registration

In [None]:
# Use the raw (un-preprocessed) images as the registration templates.
# NOTE(review): `images` is whatever extract_data() loaded last; presumably this
# cell is run right after loading the ORIGINAL images — confirm before Run All.
trainX = images

In [None]:
# Use the raw (un-preprocessed) images as the images to be registered.
# NOTE(review): run top to bottom, this binds trainY to the same list as trainX;
# presumably `images` is meant to be reloaded with the ground-truth set first — confirm.
trainY = images

In [None]:
path = '/Users/user/Documents/bone-suppression/database/bs/processed/registered_orig_35'
for i, im0 in enumerate(trainX):
    # im0 is the fixed template; the paired trainY image is first resampled to
    # the template's shape and then aligned to it.
    im1 = imresize(trainY[i], im0.shape, 'lanczos')
    result = ird.similarity(im0, im1, numiter=3)
    # Write the template unchanged and the transformed image next to it.
    imsave(path + '/x/bs_' + str(i) + '.jpg', im0)
    imsave(path + '/y/bs_' + str(i) + '.jpg', result['timg'])

### Augmentation

In [None]:
# Augmentation settings shared by the x and y generators: small rotations,
# shifts, shear and zoom plus horizontal flips; no centring or whitening.
data_gen_args = {
    'featurewise_center': False,
    'samplewise_center': False,
    'featurewise_std_normalization': False,
    'samplewise_std_normalization': False,
    'zca_whitening': False,
    'rotation_range': 5.,
    'width_shift_range': 0.08,
    'height_shift_range': 0.08,
    'shear_range': 0.06,
    'zoom_range': 0.08,
    'channel_shift_range': 0.2,
    'fill_mode': 'constant',
    'cval': 0.,
    'horizontal_flip': True,
    'vertical_flip': False,
    'rescale': None,
}
image_datagen = ImageDataGenerator(**data_gen_args)

seed = 1
image_datagen.fit(trainX, augment=True, seed=seed)

In [None]:
# One batch covers the whole data set, so every pass below writes one augmented
# copy of each image to disk.
batch_size = len(trainX)
x_save_dir = '/Users/user/Documents/bone-suppression/database/bs/processed/augmented_1024_/x/'
y_save_dir = '/Users/user/Documents/bone-suppression/database/bs/processed/augmented_1024_/y/'
for seed in range(115):
    # x and y flows are given the same seed and the same settings.
    flow_kwargs = dict(shuffle=True, seed=seed, batch_size=batch_size)
    x = image_datagen.flow(trainX, save_to_dir=x_save_dir, **flow_kwargs)
    y = image_datagen.flow(trainY, save_to_dir=y_save_dir, **flow_kwargs)
    # Pulling a single batch triggers the save_to_dir side effect.
    _ = x.next()
    _ = y.next()

### GPU Stuff

In [None]:
# Glob patterns for the augmented input (x) and target (y) JPEG files.
x_path_pattern = "/home/theskyabove/resources/augmented_1024_4005/x/*.jpg"
y_path_pattern = "/home/theskyabove/resources/augmented_1024_4005/y/*.jpg"
# Capacity of the filename queues created in input_pipeline().
queue_capacity = 8192

### Loading Previously Serialized Validation Data

In [None]:
def _load_scaled_pickle(pickle_path):
    """Unpickle the object stored at `pickle_path` and divide it by 255."""
    # NOTE(security): pickle.load can execute arbitrary code; only open files
    # from a trusted source.
    with open(pickle_path, 'rb') as f:
        return pickle.load(f) / 255

testX = _load_scaled_pickle('/home/theskyabove/resources/augmented_1024_4005/testX_bs_1024_py3.pckl')
testY = _load_scaled_pickle('/home/theskyabove/resources/augmented_1024_4005/testY_bs_1024_py3.pckl')

### AE-like Model with Pooling as a Size-changing Factor

In [None]:
# Build the convolutional encoder/decoder graph from scratch.
reset_default_graph()

# X: network input, Y_clear: target image; both (batch, size, size, 1) float32.
X = tf.placeholder(tf.float32, [None, size, size, 1])
Y_clear = tf.placeholder(tf.float32, [None, size, size, 1])
X_tensor = tf.reshape(X, [-1, size, size, 1])

# Output channels and kernel side length of each encoder layer.
n_filters = [16, 32, 64]
filter_sizes = [5, 5, 5]

current_input = X_tensor
n_input = 1

# Encoder kernels and the input shape of each encoder layer; both are reused
# (in reverse order) by the decoder below.
Ws = []
shapes = []

# Encoder: conv (stride 1) -> ReLU -> 2x2 max-pool with stride 2.
for layer_i, n_output in enumerate(n_filters):
    with tf.variable_scope("encoder/layer/{}".format(layer_i)):
        shapes.append(current_input.get_shape().as_list())
        W = tf.get_variable(
            name='W',
            shape=[
                filter_sizes[layer_i],
                filter_sizes[layer_i],
                n_input,
                n_output],
            initializer=tf.random_normal_initializer(mean=0.0, stddev=0.02))
        h = tf.nn.conv2d(current_input, W,
            strides=[1, 1, 1, 1], padding='SAME')
        conv = tf.nn.relu(h)
        current_input = tf.nn.max_pool(conv, [1,2,2,1], [1,2,2,1], padding='SAME')
        Ws.append(W)
        n_input = n_output

print(n_filters, filter_sizes, shapes, current_input.get_shape().as_list())
# Reverse so the decoder walks the layers from the innermost outwards.
Ws.reverse()
shapes.reverse()
n_filters.reverse()
# After this point n_filters is only printed; the decoder takes its output
# shapes from `shapes`.
n_filters = n_filters[1:] + [1]
print(n_filters, filter_sizes, shapes)

# Decoder: stride-2 transposed conv with the matching encoder kernel (no new
# variables are created), producing the shape that encoder layer received,
# followed by ReLU.
for layer_i, shape in enumerate(shapes):
    with tf.variable_scope("decoder/layer/{}".format(layer_i)):
        W = Ws[layer_i]
        h = tf.nn.conv2d_transpose(current_input, W,
            tf.pack([tf.shape(X)[0], shape[1], shape[2], shape[3]]),
            strides=[1, 2, 2, 1], padding='SAME')
        current_input = tf.nn.relu(h)

# Network output and the two loss terms: cost_2 is the mean squared difference,
# cost is 1 - MS-SSIM.
Y = current_input
cost_2 = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(Y_clear, Y), 1))
cost = 1 - tf_ms_ssim(Y_clear, Y)

# learning_rate is a non-trainable variable; the training loop feeds its value
# through feed_dict. alpha weights the MS-SSIM term against the squared-error term.
learning_rate = tf.Variable(initial_value=1e-2, trainable=False, dtype=tf.float32)
alpha = 0.99
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(alpha*cost + (1 - alpha)*cost_2)

# Tag embedded in the log / checkpoint / figure file names.
arch_info = 'crs_ms-ssim_mse_mp_1024_16-32-64_5-5-5_0.99'

### GPU Pipeline for Image Processing

In [None]:
def read_jpg(filename_queue):
    """Read one JPEG from `filename_queue` and decode it as a grayscale image.

    Args:
        filename_queue: queue of file names, as produced by
            ``tf.train.string_input_producer``.

    Returns:
        A float32 image tensor with one channel and values divided by 255.
    """
    reader = tf.WholeFileReader()
    # A single read op: the earlier extra reader.read() call only added an
    # unused op to the graph.
    _path, contents = reader.read(filename_queue)
    image = tf.image.decode_jpeg(contents, channels=1)
    # decode_jpeg yields uint8; cast explicitly so the division is a float
    # division regardless of the Python version's `/` semantics.
    image = tf.cast(image, tf.float32) / 255
    return image

def input_pipeline(x_filenames, y_filenames, batch_size):
    """Build the queue-based pipeline yielding shuffled (x, y) image batches.

    Args:
        x_filenames: tensor of input image file names.
        y_filenames: tensor of target image file names; presumably listed in
            the same order as `x_filenames` so that pairs line up — confirm.
        batch_size: number of pairs per batch (int or tensor).

    Returns:
        The list ``[x_batch, y_batch]`` produced by ``tf.train.shuffle_batch``,
        each element shaped ``(batch, size, size, 1)``.
    """
    # Both filename queues get the SAME integer seed so they are shuffled in
    # the same way. np.random.random() returns a float in [0, 1), which is not
    # a valid integer seed.
    seed = np.random.randint(0, 2 ** 31 - 1)
    x_filename_queue = tf.train.string_input_producer(x_filenames, seed=seed, capacity=queue_capacity)
    y_filename_queue = tf.train.string_input_producer(y_filenames, seed=seed, capacity=queue_capacity)
    x_image = read_jpg(x_filename_queue)
    y_image = read_jpg(y_filename_queue)

    min_after_dequeue = 4000
    # One enqueue thread only: x and y come from two independent readers, so
    # concurrent threads can interleave their reads and pair x_i with y_j.
    num_threads = 1
    capacity = min_after_dequeue + (num_threads + 2) * 256
    batch = tf.train.shuffle_batch(
        [x_image, y_image],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        num_threads=num_threads,
        shapes=((size, size, 1), (size, size, 1)))
    return batch


# File-name tensors obtained by globbing the augmented x / y directories.
x_filenames = tf.train.match_filenames_once(x_path_pattern)
y_filenames = tf.train.match_filenames_once(y_path_pattern)

# NOTE: rebinds `batch_size` (a plain int in the augmentation cell) to a
# tf.Variable; the training loop overrides its value through feed_dict.
batch_size = tf.Variable(initial_value=64, trainable=False, dtype=tf.int32)
batch = input_pipeline(x_filenames, y_filenames, batch_size)

### Initialization of the Session and the Variables

In [None]:
# Open a session, initialise every variable in the graph and create the saver
# used for checkpointing.
sess = tf.Session()
init_op = tf.initialize_all_variables()
sess.run(init_op)
saver = tf.train.Saver()

### Training Phase with Full Journalling

In [None]:
# Start the background threads that fill the input queues; `coord` is used
# later to stop and join them.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

In [None]:
%%time
# Run settings.
past_epochs = 0
n_epochs = 150
l_rate = 1e-3
b_size = 5
verbose = True  # also controls whether the log file and the checkpoint are written

final_epoch = past_epochs + n_epochs
sess_output_path = '{}{}_{}_{}-{}.txt'.format(
    '/home/theskyabove/sessions/epoch_info/output_',
    arch_info, l_rate, past_epochs, final_epoch)
sess_checkpoint_path = '{}{}_{}_{}'.format(
    '/home/theskyabove/sessions/session_',
    arch_info, l_rate, final_epoch)
sess_images_path = '{}_{}_{}_'.format(arch_info, l_rate, final_epoch)

for epoch_i in range(n_epochs):
    epoch_time = time.time()

    # 800 optimisation steps per epoch, each on a freshly dequeued batch.
    for _step in range(800):
        x_batch, y_batch = sess.run(batch, feed_dict={batch_size: b_size})
        train_feed = {
            X: x_batch,
            Y_clear: y_batch,
            learning_rate: l_rate,
            batch_size: b_size,
        }
        sess.run(optimizer, feed_dict=train_feed)

    # Evaluate both losses on the first five validation images, one at a time.
    loss = []
    loss_2 = []
    for k in range(5):
        eval_feed = {X: testX[k:k + 1], Y_clear: testY[k:k + 1]}
        a, b = sess.run([cost, cost_2], feed_dict=eval_feed)
        loss.append(a)
        loss_2.append(b)

    epoch_info = '{} {} {}\n'.format(epoch_i + past_epochs,
                                     [mean(loss), mean(loss_2)],
                                     time.time() - epoch_time)
    if verbose:
        with open(sess_output_path, "a") as sess_output_file:
            sess_output_file.write(epoch_info)
    print(epoch_info)

past_epochs += n_epochs
if verbose:
    saver.save(sess, sess_checkpoint_path)

In [None]:
# Ask the queue-runner threads to stop, then wait for them to finish.
coord.request_stop()
coord.join(threads)

### Restore Previously Saved Session

In [None]:
# Load variable values from a previously saved checkpoint into `sess`.
# NOTE(review): this file name (learning rate 0.0001, 900 epochs) is not the one the
# training cell writes with its current settings — confirm the intended checkpoint.
saver.restore(sess, '/home/theskyabove/sessions/session_crs_ms-ssim_mse_mp_1024_16-32-64_5-5-5_0.99_0.0001_900')

### Test Phase

In [None]:
%%time
# Run the network over the validation set one image at a time, collecting the
# reconstruction and both loss values for each image.
recon = []
loss = []
loss_2 = []
for idx in range(testX.shape[0]):
    sample_feed = {X: testX[idx:idx + 1], Y_clear: testY[idx:idx + 1]}
    c, a, b = sess.run([Y, cost, cost_2], feed_dict=sample_feed)
    loss.append(a)
    loss_2.append(b)
    recon.append(c.reshape((size, size)))

# Collapse the per-image losses to their means and drop the channel axis of
# the inputs / targets for plotting.
loss = mean(loss)
loss_2 = mean(loss_2)
orig = testX.reshape((-1, size, size))
supp = testY.reshape((-1, size, size))

### Visualization of Results

In [None]:
n_cols = 5  # number of images shown per row
n_rows = 3  # rows: input, reconstruction and (when 3) ground truth
plt.figure(figsize=(n_cols * n_rows, n_rows ** 2))


def _show_panel(position, picture):
    """Draw `picture` in grayscale at subplot `position`, without axes."""
    ax = plt.subplot(n_rows, n_cols, position)
    plt.imshow(picture.reshape(size, size))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


for i in range(n_cols):
    # top row: original input
    _show_panel(i + 1, orig[i])
    # middle row: network reconstruction
    _show_panel(i + 1 + n_cols, recon[i])
    if n_rows == 3:
        # bottom row: bone-suppressed ground truth image
        _show_panel(i + 1 + n_cols * 2, supp[i])

print("loss: %f" %loss)
figure_name = '{}{}_{}.png'.format(sess_images_path, str(loss), str(loss_2))
plt.savefig(figure_name, bbox_inches='tight', transparent=True, dpi=1000)
plt.show()

### Saving the Results

In [None]:
orig = np.reshape(testX, (-1, size, size))
path = 'tmp/to_test/'
# Make sure the output directory exists before writing into it.
if not os.path.isdir(path):
    os.makedirs(path)
for i in range(len(recon)):
    # Use matplotlib's imsave explicitly: the bare `imsave` name in this
    # notebook is scipy.misc.imsave (imported after %pylab), which does not
    # accept the `cmap` keyword.
    plt.imsave(path + str(i) + 'orig.jpg', orig[i], cmap='gray')
    plt.imsave(path + str(i) + 'recon.jpg', recon[i], cmap='gray')