Sources: 
 https://github.com/DLTK/DLTK/tree/master/examples/tutorials
 
 
 https://github.com/fitushar/3DCNNs_TF2Modelhub
 
 
 https://keras.io/examples/vision/3D_image_classification/

In [None]:
import os
import zipfile
import numpy as np

import tensorflow as tf
import keras 
from tensorflow import keras

from tensorflow.keras import optimizers,layers
from tensorflow.keras.optimizers import schedules

initializer = tf.keras.initializers.HeUniform()  # He-uniform weight initializer (not referenced by the visible models)
tf.random.set_seed(42)  # fixed global TensorFlow seed for reproducible TF-level randomness
import os
# Let TensorFlow grow GPU memory on demand instead of reserving it all at start-up.
os.environ.update(TF_FORCE_GPU_ALLOW_GROWTH='true')


## Downloading and processing the CT dataset (3D NIfTI files)
Here I follow the same download steps as the Keras tutorial https://keras.io/examples/vision/3D_image_classification/
to obtain the dataset.

In [None]:
# Archives to fetch: "CT-0" holds normal CT scans, "CT-23" holds abnormal ones.
MOSMED_ARCHIVES = (
    ("CT-0.zip", "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip"),
    ("CT-23.zip", "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"),
)

# Make a directory to store the data. exist_ok=True lets this cell be re-run
# without raising FileExistsError.
os.makedirs("MosMedData", exist_ok=True)

for archive_name, url in MOSMED_ARCHIVES:
    # An absolute fname makes get_file save the archive at exactly this path.
    filename = os.path.join(os.getcwd(), archive_name)
    keras.utils.get_file(filename, url)

    # Unzip the archive into the shared data directory.
    with zipfile.ZipFile(filename, "r") as z_fp:
        z_fp.extractall("./MosMedData/")


In [None]:
import nibabel as nib

from scipy import ndimage


def read_nifti_file(filepath):
    """Load the NIfTI file at `filepath` and return its voxel data.

    The array is whatever nibabel's ``get_fdata()`` returns for the loaded image.
    """
    image = nib.load(filepath)
    return image.get_fdata()


def normalise_zero_one(image):
    """Min-max rescale `image` (cast to float32) into the [0, 1] range.

    A volume whose maximum does not exceed its minimum (a constant volume) has
    no range to stretch; it is returned multiplied by zero.
    """
    as_float = image.astype(np.float32)
    lo, hi = np.min(as_float), np.max(as_float)
    if not hi > lo:
        return as_float * 0.
    return (as_float - lo) / (hi - lo)

def whitening(image):
    """Standardise `image` (cast to float32) to zero mean and unit variance.

    When the standard deviation is not positive (a constant volume), the cast
    volume multiplied by zero is returned instead of dividing by zero.
    """
    data = image.astype(np.float32)
    mu = np.mean(data)
    sigma = np.std(data)
    return (data - mu) / sigma if sigma > 0 else data * 0.

def remove_slices(scan):
    """Keep only slices 35..134 (inclusive) along the third axis of `scan`."""
    kept = slice(35, 135)
    return scan[:, :, kept]

def crop3D(scan, start=(90, 90), end=(290, 290)):
    """Crop `scan` to the window [start, end) along its leading axes (bonus helper).

    With the defaults, the first two axes are cut to indices 90..289 and every
    remaining axis is kept whole.

    Args:
        scan: array-like volume supporting tuple-of-slices indexing.
        start: inclusive start index per leading axis.
        end: exclusive end index per leading axis (same length as `start`).

    Returns:
        The cropped view/array ``scan[start[0]:end[0], start[1]:end[1], ...]``.
    """
    window = tuple(slice(lo, hi) for lo, hi in zip(start, end))
    return scan[window]
def resize_volume(img, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a volume 90 degrees in-plane, then resample it to a fixed size.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth.
    The rotation uses ``ndimage.rotate(..., reshape=False)``, so the shape is
    kept; non-square slices therefore lose their corners. Resampling uses
    linear interpolation (``order=1``).

    Args:
        img: 3-D array (width, height, depth).
        desired_depth: output size of the last axis (default 64).
        desired_width: output size of axis 0 (default 128).
        desired_height: output size of axis 1 (default 128).

    Returns:
        The rotated volume zoomed to (desired_width, desired_height, desired_depth).
    """
    current_width = img.shape[0]
    current_height = img.shape[1]
    current_depth = img.shape[-1]

    # Per-axis zoom factors: output size / input size.
    width_factor = desired_width / current_width
    height_factor = desired_height / current_height
    depth_factor = desired_depth / current_depth

    img = ndimage.rotate(img, 90, reshape=False)
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img
def process_scan(path):
    """Load the NIfTI scan at `path`, standardise it with `whitening`, and
    resample it to the network input size with `resize_volume`."""
    raw = read_nifti_file(path)
    standardised = whitening(raw)
    return resize_volume(standardised)


In [None]:
# Folder "CT-0" consist of CT scans having normal lung tissue,
# no CT-signs of viral pneumonia.
# Listings are sorted: os.listdir order is arbitrary, and the seeded
# (random_state=42) train/validation split downstream is only reproducible
# when the input order is fixed.
normal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-0", x)
    for x in sorted(os.listdir("MosMedData/CT-0"))
]
# Folder "CT-23" consist of CT scans having several ground-glass opacifications,
# involvement of lung parenchyma.
abnormal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-23", x)
    for x in sorted(os.listdir("MosMedData/CT-23"))
]

print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))


In [None]:
# Read and process the scans.
# Each scan is standardised (whitening) and resized across height, width, and depth.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])

# For the CT scans having presence of viral pneumonia
# assign 1, for the normal ones assign 0.
abnormal_labels = np.ones(len(abnormal_scans), dtype=np.int64)
normal_labels = np.zeros(len(normal_scans), dtype=np.int64)

data = np.concatenate((abnormal_scans, normal_scans), axis=0)
label = np.concatenate((abnormal_labels, normal_labels), axis=0)

from sklearn.model_selection import train_test_split
# stratify keeps the normal/abnormal ratio the same in the training and
# validation splits; without it a small dataset can get a skewed validation set.
x_train, x_val, y_train, y_val = train_test_split(
    data, label, test_size=0.2, random_state=42, stratify=label
)


In [None]:
import random

from scipy import ndimage


@tf.function
def rotate(volume):
    """Randomly rotate one volume in-plane (training-time augmentation).

    One angle is drawn from {-20, -10, -5, 5, 10, 20} degrees with Python's
    `random` module, not tf.random, so `tf.random.set_seed` does not make it
    reproducible. The angle is applied with ``ndimage.rotate(reshape=False)``,
    so the volume keeps its shape.

    Args:
        volume: tensor for a single scan without a channel axis (presumably
            float32, as produced by process_scan — confirm if inputs change).

    Returns:
        float32 tensor with the same static shape as `volume`.
    """

    def scipy_rotate(volume):
        # define some rotation angles
        angles = [-20, -10, -5, 5, 10, 20]
        # pick angles at random
        angle = random.choice(angles)
        # Input value range before interpolation, widened to include 0 (the
        # constant fill ndimage.rotate uses for voxels rotated in from outside).
        lo = min(float(volume.min()), 0.0)
        hi = max(float(volume.max()), 0.0)
        volume = ndimage.rotate(volume, angle, reshape=False)
        # Clip spline overshoot back into the input's own range. A hard-coded
        # [0, 1] clip zeroed every negative voxel of the whitened (zero-mean)
        # volumes produced by process_scan, while validation data stayed
        # unclipped. For data already in [0, 1] this is the same [0, 1] clip.
        volume = np.clip(volume, lo, hi)
        # Tout below is tf.float32, so hand back exactly that dtype.
        return volume.astype(np.float32, copy=False)

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    # numpy_function drops static shape information; restore it for the model.
    augmented_volume.set_shape(volume.shape)
    return augmented_volume


def train_preprocessing(volume, label):
    """tf.data map fn for training: random in-plane rotation, then a trailing channel axis."""
    augmented = rotate(volume)
    with_channel = tf.expand_dims(augmented, axis=3)
    return with_channel, label


def validation_preprocessing(volume, label):
    """tf.data map fn for validation: no augmentation, only a trailing channel axis."""
    return tf.expand_dims(volume, axis=3), label


In [None]:
batch_size = 2

train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))

# Augment the training volumes on the fly. The scipy rotation is the
# expensive step, so let tf.data run the map in parallel (order is preserved).
train_dataset = (
    train_loader.shuffle(len(x_train))
    .map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation: no augmentation and no shuffling (order does not affect the
# aggregated metrics); only the channel axis is added.
validation_dataset = (
    validation_loader
    .map(validation_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)


# VGG16-style 3D CNN


In [None]:

def VGG3D(inputs, use_bn=None):
    """Build a VGG16-style 3D CNN with a single sigmoid output.

    The network has five convolutional stages of 2, 2, 3, 3 and 3 Conv3D layers
    (32, 64, 128, 256 and 512 filters; 3x3x3 kernels; ReLU). Each of the first
    four stages ends with 2x2x2 max pooling, optionally followed by batch
    normalisation. The last stage ends with global max pooling, followed by
    Dense(32, relu) -> Dropout(0.5) -> Dense(1, sigmoid).

    Args:
        inputs: Keras input tensor, e.g. ``layers.Input((128, 128, 64, 1))``.
        use_bn: insert BatchNormalization after each intermediate pooling layer.
            ``None`` (default) uses a module-level ``TRAIN_CLASSIFY_USE_BN`` if
            one is defined, otherwise ``False``.

    Returns:
        A ``keras.Model`` mapping ``inputs`` to a (batch, 1) probability.
    """
    if use_bn is None:
        # Previously this flag was read from a global the notebook never
        # defined (NameError); still honour it when it is present.
        use_bn = bool(globals().get('TRAIN_CLASSIFY_USE_BN', False))

    # (filters, number of conv layers) for each stage.
    stages = ((32, 2), (64, 2), (128, 3), (256, 3), (512, 3))

    x = inputs
    for stage_index, (filters, n_convs) in enumerate(stages):
        for _ in range(n_convs):
            x = layers.Conv3D(filters, (3, 3, 3), padding='same', activation='relu')(x)
        if stage_index == len(stages) - 1:
            # Last stage: collapse the spatial axes instead of pooling further.
            x = layers.GlobalMaxPooling3D()(x)
        else:
            x = layers.MaxPooling3D(pool_size=(2, 2, 2))(x)
            if use_bn:
                x = layers.BatchNormalization()(x)

    x = layers.Dense(32, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(units=1, activation='sigmoid')(x)
    return keras.Model(inputs=inputs, outputs=x)

model = VGG3D(layers.Input((128, 128, 64, 1)))


## This section will be the same for all the classifiers.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Precision(), tf.keras.metrics.AUC()])

# One checkpoint file per architecture, so later models do not overwrite this one.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification_vgg3d.h5", save_best_only=True
)
# BinaryAccuracy() is logged as "binary_accuracy", so the validation key is
# "val_binary_accuracy". "val_acc" is never logged, so early stopping would never trigger.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_binary_accuracy", mode="max", patience=15
)

# Train the model, doing validation at the end of each epoch.
# (shuffle= is ignored for tf.data inputs; train_dataset already shuffles.)
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=100,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)


# Inception-V3-style 3D CNN

In [None]:
# Hyper-parameters for the Inception-style 3D network below.
INCEPTION_BLOCKS = 6           # number of inception_block calls in Inception3D
INCEPTION_REDUCTION_STEPS = 2  # a reduction_block follows every N-th inception block (never the last)
INCEPTION_KEEP_FILTERS = 128   # filters per inception branch; reduction blocks use half of this
INCEPTION_ENABLE_DEPTHWISE_SEPARABLE_CONV_SHRINKAGE = 0.333  # bottleneck ratio of the 1x1x1 convs feeding the larger kernels
INCEPTION_ENABLE_SPATIAL_SEPARABLE_CONV = True  # add the factorised (5,1,1)->(1,5,1)->(1,1,5) branch
INCEPTION_DROPOUT = 0.5        # dropout rate before the output layer

def conv_bn_relu(x, filters, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='same'):
    """Apply Conv3D -> BatchNormalization -> ReLU to `x` and return the result."""
    pipeline = (
        layers.Conv3D(filters, kernel_size=kernel_size, strides=strides, padding=padding),
        layers.BatchNormalization(),
        layers.Activation('relu'),
    )
    out = x
    for layer in pipeline:
        out = layer(out)
    return out

def inception_base(x):
    """Stem of the Inception-style network.

    Three stride-1 conv_bn_relu layers (32, 32, 64 filters) are followed by a
    two-branch downsampling step: 2x2x2 max pooling in parallel with a stride-2
    conv (64 filters). The two results are concatenated on the channel axis
    (axis 4). Shapes are printed for inspection.
    """
    stem = x
    for width in (32, 32, 64):
        stem = conv_bn_relu(stem, filters=width)

    pooled = layers.MaxPooling3D(pool_size=(2, 2, 2))(stem)
    strided = conv_bn_relu(stem, 64, strides=(2, 2, 2))
    merged = layers.Concatenate(axis=4)([pooled, strided])

    print('inception_base')
    for tensor in (pooled, strided, merged):
        print(tensor.get_shape())

    return merged

def inception_block(x, filters=256):
    """Inception module: parallel conv branches concatenated on the channel axis.

    Branches (each ending with `filters` channels):
      0. 1x1x1 conv.
      1. 1x1x1 bottleneck -> 3x3x3 conv.
      2. 1x1x1 bottleneck -> two stacked 3x3x3 convs.
      3. stride-1 3x3x3 average pooling -> 1x1x1 conv.
      4. (if INCEPTION_ENABLE_SPATIAL_SEPARABLE_CONV) 1x1x1 bottleneck ->
         (5,1,1) -> (1,5,1) -> (1,1,5) convs.
    The bottleneck width is ``int(filters * INCEPTION_ENABLE_DEPTHWISE_SEPARABLE_CONV_SHRINKAGE)``.
    Branch shapes are printed for inspection.
    """
    bottleneck = int(filters * INCEPTION_ENABLE_DEPTHWISE_SEPARABLE_CONV_SHRINKAGE)
    one = (1, 1, 1)
    three = (3, 3, 3)

    branch0 = conv_bn_relu(x, filters=filters, kernel_size=one)

    branch1 = conv_bn_relu(x, filters=bottleneck, kernel_size=one)
    branch1 = conv_bn_relu(branch1, filters=filters, kernel_size=three)

    branch2 = conv_bn_relu(x, filters=bottleneck, kernel_size=one)
    for _ in range(2):
        branch2 = conv_bn_relu(branch2, filters=filters, kernel_size=three)

    branch3 = layers.AveragePooling3D(pool_size=three, strides=one, padding='same')(x)
    branch3 = conv_bn_relu(branch3, filters=filters, kernel_size=one)

    branches = [branch0, branch1, branch2, branch3]

    print('inception_block')
    for branch in branches:
        print(branch.get_shape())

    if INCEPTION_ENABLE_SPATIAL_SEPARABLE_CONV:
        # A 5x5x5 receptive field factorised into three one-axis convolutions.
        branch4 = conv_bn_relu(x, filters=bottleneck, kernel_size=one)
        for axis_kernel in ((5, 1, 1), (1, 5, 1), (1, 1, 5)):
            branch4 = conv_bn_relu(branch4, filters=filters, kernel_size=axis_kernel)
        branches.append(branch4)
        print(branch4.get_shape())

    merged = layers.Concatenate(axis=4)(branches)
    print(merged.get_shape())

    return merged

def reduction_block(x, filters=256):
    """Downsampling module: parallel stride-2 branches concatenated on channels.

    Branches (each ending with `filters` channels):
      0. stride-2 3x3x3 conv.
      1. 1x1x1 conv -> 3x3x3 conv -> stride-2 3x3x3 conv.
      2. stride-2 3x3x3 max pooling -> 1x1x1 conv.
      3. (if INCEPTION_ENABLE_SPATIAL_SEPARABLE_CONV) 1x1x1 conv ->
         (5,1,1) -> (1,5,1) -> (1,1,5) convs -> stride-2 3x3x3 conv.
    Branch shapes are printed for inspection.
    """
    b0 = conv_bn_relu(x, filters=filters, kernel_size=(3, 3, 3), strides=(2, 2, 2), padding='same')

    b1 = conv_bn_relu(x, filters=filters, kernel_size=(1, 1, 1))
    b1 = conv_bn_relu(b1, filters=filters, kernel_size=(3, 3, 3))
    b1 = conv_bn_relu(b1, filters=filters, kernel_size=(3, 3, 3), strides=(2, 2, 2), padding='same')

    # Qualified with `layers.`: the bare MaxPooling3D name was never imported (NameError).
    b2 = layers.MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same')(x)
    b2 = conv_bn_relu(b2, filters=filters, kernel_size=(1, 1, 1))

    bs = [b0, b1, b2]

    print('reduction_block')
    print(b0.get_shape())
    print(b1.get_shape())
    print(b2.get_shape())

    if INCEPTION_ENABLE_SPATIAL_SEPARABLE_CONV:
        b3 = conv_bn_relu(x, filters=filters, kernel_size=(1, 1, 1))
        b3 = conv_bn_relu(b3, filters=filters, kernel_size=(5, 1, 1))
        b3 = conv_bn_relu(b3, filters=filters, kernel_size=(1, 5, 1))
        b3 = conv_bn_relu(b3, filters=filters, kernel_size=(1, 1, 5))
        b3 = conv_bn_relu(b3, filters=filters, kernel_size=(3, 3, 3), strides=(2, 2, 2), padding='same')
        bs.append(b3)
        print(b3.get_shape())

    # Qualified with `layers.`: the bare Concatenate name was never imported (NameError).
    x = layers.Concatenate(axis=4)(bs)
    print(x.get_shape())

    return x

def Inception3D(inputs, num_classes=1):
    """Build the Inception-style 3D classifier.

    The network is inception_base, then INCEPTION_BLOCKS inception blocks (with
    INCEPTION_KEEP_FILTERS filters). A reduction block (half the filters)
    follows every INCEPTION_REDUCTION_STEPS-th block except the last. The head
    is global max pooling -> dropout -> Dense(num_classes, sigmoid).

    Args:
        inputs: Keras input tensor, e.g. ``layers.Input((128, 128, 64, 1))``.
        num_classes: number of sigmoid outputs. Defaults to 1 (binary task);
            previously required, although the only caller omitted it (TypeError).

    Returns:
        A ``keras.Model`` mapping ``inputs`` to (batch, num_classes) probabilities.
    """
    x = inception_base(inputs)

    for i in range(INCEPTION_BLOCKS):
        x = inception_block(x, filters=INCEPTION_KEEP_FILTERS)

        # Downsample after every INCEPTION_REDUCTION_STEPS blocks, never after the last.
        if (i + 1) % INCEPTION_REDUCTION_STEPS == 0 and i != INCEPTION_BLOCKS - 1:
            x = reduction_block(x, filters=INCEPTION_KEEP_FILTERS // 2)

    print('top')
    x = layers.GlobalMaxPooling3D()(x)
    x = layers.Dropout(INCEPTION_DROPOUT)(x)
    x = layers.Dense(units=num_classes, activation='sigmoid')(x)
    return keras.Model(inputs=inputs, outputs=x)


model = Inception3D(layers.Input((128, 128, 64, 1)))


## This section will be the same for all the classifiers.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Precision(), tf.keras.metrics.AUC()])

# One checkpoint file per architecture, so models do not overwrite each other.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification_inception3d.h5", save_best_only=True
)
# BinaryAccuracy() is logged as "binary_accuracy", so the validation key is
# "val_binary_accuracy". "val_acc" is never logged, so early stopping would never trigger.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_binary_accuracy", mode="max", patience=15
)

# Train the model, doing validation at the end of each epoch.
# (shuffle= is ignored for tf.data inputs; train_dataset already shuffles.)
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=100,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)


In [None]:
# Hyper-parameters for Resnet3D (names kept as-is: they are referenced as defaults below).
TRAIN_NUM_RES_UNIT = 3                 # residual units per resolution scale
TRAIN_NUM_FILTERS = (16, 32, 64, 128)  # initial conv filters, then one entry per scale
TRAIN_STRIDES = (                      # stride paired with each TRAIN_NUM_FILTERS entry
    (1, 1, 1),
    (2, 2, 2),
    (2, 2, 2),
    (2, 2, 2),
)
TRAIN_CLASSIFY_ACTICATION = tf.nn.relu6  # activation used throughout the ResNet
TRAIN_KERNAL_INITIALIZER = tf.keras.initializers.VarianceScaling(distribution='uniform')

###Residual Block
def Residual_Block(inputs,
                 out_filters,
                 kernel_size=(3, 3, 3),
                 strides=(1, 1, 1),
                 use_bias=False,
                 activation=tf.nn.relu6,
                 kernel_initializer=tf.keras.initializers.VarianceScaling(distribution='uniform'),
                 bias_initializer=tf.zeros_initializer(),
                 kernel_regularizer=tf.keras.regularizers.l2(l=0.001),
                 bias_regularizer=None,
                 **kwargs):
    """Pre-activation 3D residual unit with a shortcut connection.

    Main path: BN -> activation -> Conv3D(strides) -> BN -> activation -> Conv3D.
    The first conv's kernel is widened to 2*stride along strided axes.
    Shortcut: max-pooled by `strides` when they are not all 1, then made to
    match `out_filters` by zero-padding the channel axis (in < out) or by a
    Conv3D projection (in > out). The two paths are summed.

    Args:
        inputs: 5-D Keras tensor; its last axis is read as the channel count.
        out_filters: number of output channels.
        kernel_size: conv kernel size per spatial axis.
        strides: stride of the first conv and pool size/stride of the shortcut.
        use_bias, kernel_initializer, bias_initializer, kernel_regularizer,
            bias_regularizer: forwarded to every Conv3D. Initializer objects are
            cloned per layer (see `_fresh_initializer`) so that the shared
            default instance is not reused by every layer.
        activation: callable applied after each BatchNormalization.
        **kwargs: accepted for call-site compatibility (e.g. ``name``); ignored.

    Returns:
        The sum of the main path and the shortcut.
    """

    def conv3d(filters, kernel, conv_strides):
        # A new Conv3D layer with its own initializer instances.
        return tf.keras.layers.Conv3D(
            filters=filters,
            kernel_size=kernel,
            strides=conv_strides,
            padding='same',
            use_bias=use_bias,
            kernel_initializer=_fresh_initializer(kernel_initializer),
            bias_initializer=_fresh_initializer(bias_initializer),
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer)

    in_filters = inputs.get_shape().as_list()[-1]
    x = inputs
    orig_x = inputs

    # Adjust the strided conv kernel size to prevent losing information.
    strided_kernel = [s * 2 if s > 1 else k for k, s in zip(kernel_size, strides)]

    if np.prod(strides) != 1:
        orig_x = tf.keras.layers.MaxPool3D(pool_size=strides, strides=strides, padding='valid')(orig_x)

    # sub-unit 0
    x = tf.keras.layers.BatchNormalization()(x)
    x = activation(x)
    x = conv3d(out_filters, strided_kernel, strides)(x)

    # sub-unit 1
    x = tf.keras.layers.BatchNormalization()(x)
    x = activation(x)
    x = conv3d(out_filters, kernel_size, (1, 1, 1))(x)

    # Handle differences in input and output filter sizes.
    if in_filters < out_filters:
        pad_total = out_filters - in_filters
        channel_pad = [pad_total // 2, pad_total - pad_total // 2]
        orig_x = tf.pad(tensor=orig_x,
                        paddings=[[0, 0]] * (len(x.get_shape().as_list()) - 1) + [channel_pad])
    elif in_filters > out_filters:
        orig_x = conv3d(out_filters, kernel_size, (1, 1, 1))(orig_x)

    return x + orig_x


def _fresh_initializer(initializer):
    """Return a new initializer with the same configuration as `initializer`.

    Objects exposing ``get_config``/``from_config`` are rebuilt from their own
    config, so each layer gets its own instance. Newer tf.keras releases warn
    that one unseeded initializer instance called for several layers returns
    identical values each time. Anything else (None, strings) is returned as-is.
    """
    if hasattr(initializer, 'get_config') and hasattr(type(initializer), 'from_config'):
        return type(initializer).from_config(initializer.get_config())
    return initializer



## Resnet----3D
def Resnet3D(inputs,
              num_classes=1,
              num_res_units=TRAIN_NUM_RES_UNIT,
              filters=TRAIN_NUM_FILTERS,
              strides=TRAIN_STRIDES,
              use_bias=False,
              activation=TRAIN_CLASSIFY_ACTICATION,
              kernel_initializer=TRAIN_KERNAL_INITIALIZER,
              bias_initializer=tf.zeros_initializer(),
              kernel_regularizer=tf.keras.regularizers.l2(l=0.001),
              bias_regularizer=None,
              **kwargs):
    """Build a 3D ResNet classifier.

    Structure: an initial Conv3D (filters[0], strides[0]); then, for each
    later scale s, `num_res_units` Residual_Blocks with filters[s] (the first
    block of a scale uses strides[s], the rest stride 1); then
    BN -> activation -> global average pooling -> Dropout(0.5) ->
    Dense(num_classes, sigmoid).

    Args:
        inputs: Keras input tensor, e.g. ``layers.Input((128, 128, 64, 1))``.
        num_classes: number of sigmoid outputs. Defaults to 1 (binary task);
            previously required, although the only caller omitted it (TypeError).
        num_res_units, filters, strides: network layout (see TRAIN_* constants).
        use_bias, kernel_initializer, bias_initializer, kernel_regularizer,
            bias_regularizer: conv configuration, used for the initial conv and
            forwarded to every Residual_Block. Previously the residual blocks
            silently fell back to their own defaults; those defaults match this
            function's defaults, so default behaviour is unchanged.
        activation: activation callable used throughout.
        **kwargs: accepted for compatibility; ignored.

    Returns:
        A ``tf.keras.Model`` mapping ``inputs`` to (batch, num_classes) probabilities.
    """
    conv_params = {'padding': 'same',
                   'use_bias': use_bias,
                   'kernel_initializer': kernel_initializer,
                   'bias_initializer': bias_initializer,
                   'kernel_regularizer': kernel_regularizer,
                   'bias_regularizer': bias_regularizer}
    # Layer configuration handed to each residual unit.
    block_params = {'use_bias': use_bias,
                    'kernel_initializer': kernel_initializer,
                    'bias_initializer': bias_initializer,
                    'kernel_regularizer': kernel_regularizer,
                    'bias_regularizer': bias_regularizer}

    # Initial conv kernel: doubled along strided axes, 3 elsewhere.
    k = [s * 2 if s > 1 else 3 for s in strides[0]]

    x = inputs
    x = tf.keras.layers.Conv3D(filters[0], k, strides[0], **conv_params)(x)

    for res_scale in range(1, len(filters)):
        # First unit of each scale applies that scale's stride.
        x = Residual_Block(
                inputs=x,
                out_filters=filters[res_scale],
                strides=strides[res_scale],
                activation=activation,
                name='unit_{}_0'.format(res_scale),
                **block_params)
        for i in range(1, num_res_units):
            x = Residual_Block(
                    inputs=x,
                    out_filters=filters[res_scale],
                    strides=(1, 1, 1),
                    activation=activation,
                    name='unit_{}_{}'.format(res_scale, i),
                    **block_params)

    x = tf.keras.layers.BatchNormalization()(x)
    x = activation(x)
    x = tf.keras.layers.GlobalAveragePooling3D()(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    classifier = tf.keras.layers.Dense(units=num_classes, activation='sigmoid')(x)
    return tf.keras.Model(inputs=inputs, outputs=classifier)


model = Resnet3D(layers.Input((128, 128, 64, 1)))


## This section will be the same for all the classifiers.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Precision(), tf.keras.metrics.AUC()])

# One checkpoint file per architecture, so models do not overwrite each other.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification_resnet3d.h5", save_best_only=True
)
# BinaryAccuracy() is logged as "binary_accuracy", so the validation key is
# "val_binary_accuracy". "val_acc" is never logged, so early stopping would never trigger.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_binary_accuracy", mode="max", patience=15
)

# Train the model, doing validation at the end of each epoch.
# (shuffle= is ignored for tf.data inputs; train_dataset already shuffles.)
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=100,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)