In [1]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:

# Data Augmentation for training set
train_datagen = ImageDataGenerator(
    rescale = 1./255,
    horizontal_flip = True,
)

# Load in the training dataset with augmentation applied
train_generator = train_datagen.flow_from_directory(
    '../binary-brain-classification/training',
    target_size = (224,224),
    batch_size = 32,
    class_mode = 'binary'
)

# Only rescale the test set
test_datagen = ImageDataGenerator(
    rescale = 1./255
)

# Load in the test dataset with only rescaling
test_generator = test_datagen.flow_from_directory(
    '../binary-brain-classification/testing',
    target_size = (224,224),
    batch_size = 32,
    class_mode = 'binary'
)

Found 3660 images belonging to 2 classes.
Found 392 images belonging to 2 classes.


In [3]:
model = keras.Sequential([

    # Define the input shape
    keras.layers.Input(shape=(224,224,3)),
    
    # CNN layer 1
    keras.layers.Conv2D(16, (3,3), activation=None),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.MaxPooling2D((2,2)),
    
    # Flatten and classify
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation = 'sigmoid')

])
model.summary()

2025-01-09 12:27:59.789721: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2025-01-09 12:27:59.789755: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-01-09 12:27:59.789763: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-01-09 12:27:59.789781: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-01-09 12:27:59.789796: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
# Compile the Model
model.compile(
    optimizer = 'adam',    # Popular optimizer
    loss = 'binary_crossentropy',   # Categorical crossentropy for classification tasks
    metrics = (['accuracy'])    # Measure the accuracy
)

In [5]:
# Run with GPU for better performance
with tf.device('/GPU:0'):   
    # Test the model
    model.fit(
        train_generator,    # Traning dataset from our generator
        epochs = 10    # See all the images 10 times
    )

  self._warn_if_super_not_called()


Epoch 1/10


2025-01-09 12:28:00.193682: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 74ms/step - accuracy: 0.8045 - loss: 1.9083
Epoch 2/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 73ms/step - accuracy: 0.8899 - loss: 0.8556
Epoch 3/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 72ms/step - accuracy: 0.9369 - loss: 0.4543
Epoch 4/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 72ms/step - accuracy: 0.9593 - loss: 0.2734
Epoch 5/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 72ms/step - accuracy: 0.9732 - loss: 0.2106
Epoch 6/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 72ms/step - accuracy: 0.9862 - loss: 0.0762
Epoch 7/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 72ms/step - accuracy: 0.9796 - loss: 0.1048
Epoch 8/10
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 72ms/step - accuracy: 0.9896 - loss: 0.0339
Epoch 9/10
[1m115/115[0m [32m━━━━━━━━━━━

In [6]:
model.evaluate(test_generator)

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 32ms/step - accuracy: 0.8845 - loss: 1.9182


[1.4054574966430664, 0.8877550959587097]

In [7]:
model.save('./models/binary_model.keras')