In [152]:
# Imports

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import tensorflow_datasets as tfds
from imgaug import augmenters as iaa
import imgaug as ia

#----------------------------------------
!pip install simpleitk
import SimpleITK as sitk
import sys
import os
import matplotlib.pyplot as plt

Defaulting to user installation because normal site-packages is not writeable


In [153]:

#INITIAL VARS —————————————————————————————————————————————————————————————
tfds.disable_progress_bar()
tf.random.set_seed(42)
ia.seed(42)

x_train123 = []
y_train = []
x_test123 = []
y_test = []

"""
## Load the Brain Classification dataset
"""

#PATH INFORMATION —————————————————————————————————————————————————————————————
healthy_train = "RandAugmentTest/Data/Training/Healthy" #Relative Path to Healthy Train Data
diseased_train = "RandAugmentTest/Data/Training/Diseased/" #Relative Path to Diseased Train Data
healthy_test = "RandAugmentTest/Data/Testing/Healthy/" #Relative Path to Healthy Test Data
diseased_test = "RandAugmentTest/Data/Testing/Diseased/" #Relative Path to Diseased Test Data


#POPULATE TRAIN/TEST DATASETS —————————————————————————————————————————————————————————————
for file_name in os.listdir(healthy_train):
  if file_name.endswith('.mhd'):
    img = sitk.ReadImage(os.path.join(healthy_train, file_name))
    imgdata = sitk.GetArrayFromImage(img)
    x_train123.append(imgdata)
    y_train.append(0)

for file_name in os.listdir(diseased_train):
  if file_name.endswith('.mhd'):
    img = sitk.ReadImage(os.path.join(diseased_train, file_name))
    imgdata = sitk.GetArrayFromImage(img)
    x_train123.append(imgdata)
    y_train.append(1)
count = 0
for file_name in os.listdir(healthy_test):
  if file_name.endswith('.mhd'):
    img = sitk.ReadImage(os.path.join(healthy_test, file_name))
    imgdata = sitk.GetArrayFromImage(img)
    x_test123.append(imgdata)
    y_test.append(0) 

for file_name in os.listdir(diseased_test):
  if file_name.endswith('.mhd'):
    img = sitk.ReadImage(os.path.join(diseased_test, file_name))
    imgdata = sitk.GetArrayFromImage(img)
    x_test123.append(imgdata)
    y_test.append(1) 


x_train = np.empty(shape=[378,256,256, 3])
for i in range(len(x_train)):
    x_train[i] = np.stack((x_train123[i],)*3, axis=-1)
y_train = np.array(y_train)

x_test = np.empty(shape=[50,256,256, 3])
for i in range(len(x_test)):
    x_train[i] = np.stack((x_test123[i],)*3, axis=-1)
y_test = np.array(y_test)


# print(x_train.shape)
# print(y_train.shape)

# print(x_test.shape)
# print(y_test.shape)


print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")



Total training examples: 378
Total test examples: 50


In [161]:
# Define hyperparameters

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 256
EPOCHS = 4
IMAGE_SIZE = 256

rand_aug = iaa.RandAugment(n=3, m=4)


In [162]:
#Augment our datasets with rand_aug

def augment(images):
    images = tf.cast(images, tf.uint8)
    return rand_aug(images=images.numpy())



In [163]:
# Convert the numpy train/test datasets into tensorflow train/test datasets

train_ds_rand = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .map(
        lambda x, y: (tf.py_function(augment, [x], [tf.float32])[0], y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

In [164]:
#Define the CNN Model's Architecture


def get_training_model():
    resnet50_v2 = tf.keras.applications.ResNet50V2(
        weights=None,
        include_top=True,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        classes=2,
    )
    model = tf.keras.Sequential(
        [
            layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
            layers.Rescaling(scale=1.0 / 127.5, offset=-1),
            resnet50_v2,
        ]
    )
    return model


get_training_model().summary()


Model: "sequential_40"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rescaling_40 (Rescaling)    (None, 256, 256, 3)       0         
                                                                 
 resnet50v2 (Functional)     (None, 2)                 23568898  
                                                                 
Total params: 23,568,898
Trainable params: 23,523,458
Non-trainable params: 45,440
_________________________________________________________________


In [165]:
#Get initial training parameters

initial_model = get_training_model()
initial_model.save_weights("initial_weights.h5")


In [166]:
#Train data model using RandAugment

rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial_weights.h5")
rand_aug_model.compile(
    loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
rand_aug_model.fit(train_ds_rand, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = rand_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))


Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
Test accuracy: 76.00%
