# Low-dose CT scan image de-scattering

## Intro
A high-dose CT scan produces high-quality images, but the radiation exposure may harm the patient's health. A low-dose CT scan is safer, but the image quality is worse. One way to deal with this trade-off is to de-scatter the low-dose CT image. The process is very much like denoising, except that the 'noise' we deal with here is scattering. Here I present a solution for de-scattering using a 3D CNN autoencoder. I do not own the data, so I present this solution without the data. The idea can be applied to other similar denoising problems.

## Data
Volumetric CT scan images from 63 patients. Each volumetric image is 256 x 256 x 95, float32, stored in '.dat' format. Each patient has a scattered image and a non-scattered image (used as ground truth).  
The training set consists of images from 63 patients.  
The testing set consists of images from 3 patients.  
An example image is shown below:  
<img style="float:left;" src='sample/sample_img.jpg'/>  
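
As a minimal sketch of how one such volume can be read, the snippet below loads a raw float32 file and reshapes it in column-major order, which is the order the generator code in the implementation section uses. The helper names `load_volume` and `pad_depth` and the tiny synthetic file are hypothetical stand-ins, since the real data is not included.

```python
import os
import tempfile
import numpy as np

def load_volume(path, shape=(256, 256, 95)):
    """Read a raw float32 volume and reshape it in column-major ('F') order."""
    flat = np.fromfile(path, dtype='float32')
    expected = int(np.prod(shape))
    if flat.size != expected:
        raise ValueError(f'{path}: expected {expected} values, got {flat.size}')
    return np.reshape(flat, shape, order='F')

def pad_depth(volume, target_depth=96):
    """Append zero slices along the last axis until it has target_depth slices."""
    missing = target_depth - volume.shape[2]
    return np.pad(volume, ((0, 0), (0, 0), (0, missing)), mode='constant')

# Round trip on a small synthetic volume (hypothetical stand-in for a patient file)
demo_shape = (4, 4, 3)
demo = np.arange(np.prod(demo_shape), dtype='float32').reshape(demo_shape, order='F')
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, 'demo_Sca.dat')
    demo.ravel(order='F').tofile(demo_path)  # write values in column-major order
    loaded = load_volume(demo_path, shape=demo_shape)
padded = pad_depth(loaded, target_depth=4)
```

For a real patient file the defaults apply: `pad_depth(load_volume(path))` would return a 256 x 256 x 96 array.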

## CNN autoencoder solution
I built a 3D CNN autoencoder to capture the features of the image, and then used deconvolutional layers to reconstruct the image. I used a 3D CNN rather than a 2D CNN because the scattering is not independent across slices, and I wanted to capture and use the information along the z-axis. To simplify the learning task, instead of mapping a scattered image to its ground-truth counterpart, I trained the autoencoder to map the scattered image to the scattering residual. The scattering residual is simply the difference between the ground-truth image and the scattered image.
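
The residual idea can be illustrated with a few lines of NumPy. This is a sketch on toy arrays, assuming the sign convention residual = ground truth - scattered and assuming scatter adds signal; the array names are hypothetical and the "prediction" is simply the true residual.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy volumes standing in for one patient
ground_truth = rng.random((8, 8, 4)).astype('float32')
scatter = (0.1 * rng.random((8, 8, 4))).astype('float32')
scattered = ground_truth + scatter

# Training target under the residual formulation
residual = ground_truth - scattered

# At inference time the predicted residual is added back to the input.
# Here a perfect prediction is used, for illustration only.
predicted_residual = residual
descattered = scattered + predicted_residual
```

In this toy setup the residual is non-positive everywhere, so a network whose output is confined to [0, 1] could not represent it directly.
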
The architecture of the CNN is shown below. It's a fairly simple one, with only 3 convolutional layers, whose in-plane kernel sizes are 9, 3, and 5 respectively. Batch normalization is added after each convolutional and deconvolutional layer, and is not shown in the figure.  
<img style="float:left;" src='CNN arch.jpg'/>  
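
The implementation below pads each volume from 95 to 96 slices along z. A likely motivation, though it is not stated here, is that three 2x pooling stages followed by three 2x upsampling stages only return the original size when that size is divisible by 8. The sketch below checks this arithmetic, assuming 'same'-padded pooling rounds up; `encoder_decoder_size` is a hypothetical helper, not part of the model code.

```python
import math

def encoder_decoder_size(size, n_stages=3, factor=2):
    """Trace one axis through n_stages of 'same'-padded pooling, then upsampling."""
    pooled = size
    for _ in range(n_stages):
        pooled = math.ceil(pooled / factor)  # 'same' pooling rounds up
    restored = pooled * factor ** n_stages   # each upsampling multiplies by the factor
    return pooled, restored

print(encoder_decoder_size(95))   # (12, 96): output depth no longer matches the input
print(encoder_decoder_size(96))   # (12, 96)
print(encoder_decoder_size(256))  # (32, 256)
```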

## Implementation in Keras

In [None]:
####################################################################
#################### custom image data generator ###################
####################################################################
import numpy as np
import keras
class DataGenerator(keras.utils.Sequence):
    def __init__(self, list_IDs, labels, batch_size=4, dim=(256,256,96), n_channels=1, n_classes=None, shuffle=True, sess='train'):
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
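        self.sess = sess  # 'train' reads Training_dataset/, anything else reads Testing_dataset/
        self.n = len(list_IDs)
        self.on_epoch_end()  # build (and optionally shuffle) the sample order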
    
    def __len__(self):
        return int(np.floor(len(self.list_IDs) / self.batch_size))
    
    def __getitem__(self, index):
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(list_IDs_temp)
        return X, y
    
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)
            
    def __data_generation(self, list_IDs_temp):
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim, self.n_channels))
        
        # Pick the dataset directory once; scattered inputs carry the '_Sca' suffix
        data_dir = 'Training_dataset/' if self.sess == 'train' else 'Testing_dataset/'
        for i, ID in enumerate(list_IDs_temp):
            x_temp = np.reshape(np.fromfile(data_dir + ID + '_Sca.dat', dtype='float32'), (256,256,95), 'F')
            y_temp = np.reshape(np.fromfile(data_dir + ID + '.dat', dtype='float32'), (256,256,95), 'F')
            # Zero-pad the z-axis from 95 to 96 slices to match self.dim
            x_temp = np.concatenate((x_temp, np.zeros((256,256,1), dtype='float32')), axis=2)
            X[i,] = np.expand_dims(x_temp, axis=3)
            y_temp = np.concatenate((y_temp, np.zeros((256,256,1), dtype='float32')), axis=2)
            # Target is the ground-truth volume; the residual variant (minus x_temp) is left disabled
            y[i,] = np.expand_dims(y_temp, axis=3)  # - np.expand_dims(x_temp, axis=3)
        return X, y
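
One consequence of `__len__` flooring the batch count is that a trailing partial batch is never served. The sketch below works through the numbers for the dataset sizes and batch size used in this notebook; the helper names are hypothetical.

```python
import numpy as np

def batches_per_epoch(n_samples, batch_size):
    """Number of whole batches, matching the floor in DataGenerator.__len__."""
    return int(np.floor(n_samples / batch_size))

def samples_dropped(n_samples, batch_size):
    """Samples left out of each epoch because they do not fill a whole batch."""
    return n_samples - batches_per_epoch(n_samples, batch_size) * batch_size

# Training set: 63 patients, batch size 2
train_batches = batches_per_epoch(63, 2)
train_dropped = samples_dropped(63, 2)

# Testing set: 3 patients, batch size 2
test_batches = batches_per_epoch(3, 2)
test_dropped = samples_dropped(3, 2)
```

With shuffling enabled, the sample that is left out changes from epoch to epoch, so every patient is still seen over the course of training.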

In [None]:
# Generate IDs of training samples
# Scattered files end with '_Sca'
# Ground-truth files have no suffix in the file name
import os
training_dir = 'Training_dataset/'
file_list = os.listdir(training_dir)
training_IDs = []
for i in file_list:
    if 'Sca' not in i:
        ID = i.split('.')[0]
        # Keep an ID only if its scattered counterpart exists
        if os.path.isfile(os.path.join(training_dir, ID + '_Sca.dat')):
            training_IDs.append(ID)

# Generate IDs of testing samples
testing_IDs = []
testing_list = 'Testing_dataset/'
file_list = os.listdir(testing_list)
for i in file_list:
    if 'Sca' not in i:
        ID = i.split('.')[0]
        testing_IDs.append(ID)

In [None]:
# Set up data generators
from keras.models import Sequential
params = {
    'dim': (256,256,96),
    'batch_size': 2,
    'n_classes': None,
    'n_channels': 1,
    'shuffle': True
}
training_generator = DataGenerator(training_IDs, None, sess='train', **params)
testing_generator = DataGenerator(testing_IDs, None, sess='test', **params)

In [None]:
####################################################################
#################### build 3D CNN autoencoder ######################
####################################################################
from keras.layers import Input, Dense, Conv3D, UpSampling3D, MaxPooling3D, Conv3DTranspose, BatchNormalization
from keras.models import Model
input_img = Input(shape=(256,256,96,1))
x = Conv3D(16, (9,9,3), activation='relu',padding='same')(input_img)
x = BatchNormalization(axis=-1)(x)
x = MaxPooling3D(pool_size=(2,2,2), padding='same')(x)
x = Conv3D(32, (3,3,3), activation='relu',padding='same')(x)
x = BatchNormalization(axis=-1)(x)
x = MaxPooling3D(pool_size=(2,2,2), padding='same')(x)
x = Conv3D(64, (5,5,3), activation='relu',padding='same')(x)
x = BatchNormalization(axis=-1)(x)
encoded = MaxPooling3D(pool_size=(2,2,2), padding='same')(x)
x = Conv3DTranspose(64, (5,5,3), activation='relu',padding='same')(encoded)
x = BatchNormalization(axis=-1)(x)
x = UpSampling3D((2,2,2))(x)
x = Conv3DTranspose(32, (3,3,3), activation='relu',padding='same')(x)
x = BatchNormalization(axis=-1)(x)
x = UpSampling3D((2,2,2))(x)
x = Conv3DTranspose(16, (9,9,3), activation='relu',padding='same')(x)
x = BatchNormalization(axis=-1)(x)
x = UpSampling3D((2,2,2))(x)
decoded = Conv3DTranspose(1, (3,3,3), activation='sigmoid',padding='same')(x)

denoiser = Model(input_img, decoded)
denoiser.compile(optimizer='adadelta', loss='mse')
denoiser.summary()

In [None]:
####################################################################
#################### train 3D CNN autoencoder ######################
####################################################################
from keras.callbacks import TensorBoard
from time import time
tensorboard = TensorBoard(log_dir='logs/{}'.format(time()))

# Note: fit_generator is deprecated in newer Keras releases, where fit() accepts a Sequence directly
history = denoiser.fit_generator(
    generator=training_generator,
    epochs=3,
    steps_per_epoch=len(training_generator),    # whole batches only; n/batch_size can be fractional
    validation_data=testing_generator,
    validation_steps=len(testing_generator),
    callbacks=[tensorboard]
)

## Result
<img style="float:left;" src='result1.jpg'/>
The training result after 25 epochs is shown above. The MSE loss on the training and testing sets is plotted on a log scale, and the loss is still decreasing at this point. On the right are the predicted scattering and the predicted de-scattered image. The model learned the pattern of scattering near the edges of the image. The root mean square error of this testing sample decreased from 0.0232 to 0.0102. Training for more epochs would likely improve the result further.
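
For reference, a before/after RMSE comparison of this kind can be computed as sketched below. The arrays are synthetic stand-ins with constant offsets, not the actual test sample, and `rmse` is a hypothetical helper. Whether the zero-padded slice was included in the reported numbers is not specified, so the sketch ignores padding.

```python
import numpy as np

def rmse(a, b):
    """Root mean square error between two arrays of the same shape."""
    diff = np.asarray(a, dtype='float64') - np.asarray(b, dtype='float64')
    return float(np.sqrt(np.mean(diff ** 2)))

rng = np.random.default_rng(0)
ground_truth = rng.random((8, 8, 4))

# Synthetic inputs: a constant error of 0.02 before and 0.01 after de-scattering
scattered = ground_truth + 0.02
descattered = ground_truth + 0.01

rmse_before = rmse(scattered, ground_truth)
rmse_after = rmse(descattered, ground_truth)
```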

## Discussion
This model is fairly simple, yet after 25 epochs it already captures some large-scale features of the scatter, such as the scatter near the edges. A more complex model would very likely perform better.