In [None]:
import os
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.preprocessing.image import ImageDataGenerator

print('Dependencies successfully imported!')

# List the GPUs TensorFlow can see; an empty list means no GPU is available
# to this kernel and training will fall back to the CPU.
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

In [10]:
# Paths to each dataset split; flow_from_directory expects one sub-directory per class.
train_path = os.path.join("dataset", "train")
val_path = os.path.join("dataset", "val")
test_path = os.path.join("dataset", "test")

seed = 42       # seed used when shuffling file order
shuffle = True  # default shuffling (applied to the training split)

# Only rescales pixel values to [0, 1]; no random augmentation is configured.
datagen = ImageDataGenerator(
    rescale=1 / 255.,
)


def augment_data(path, augmenter, batch_size=32, target_size=(256, 256),
                 class_mode='binary', seed=seed, shuffle=shuffle):
    """Create a batched image iterator over ``path`` using ``augmenter``.

    Despite the name, this applies only whatever transforms ``augmenter`` was
    configured with (for ``datagen`` above: rescaling only).

    Args:
        path: directory containing one sub-directory per class.
        augmenter: an ``ImageDataGenerator`` (anything with ``flow_from_directory``).
        batch_size: number of images per batch.
        target_size: (height, width) every image is resized to.
        class_mode: label format passed to ``flow_from_directory``.
        seed: shuffling seed; defaults to the module-level ``seed``.
        shuffle: whether to shuffle file order; defaults to the module-level
            ``shuffle``. Pass False for evaluation splits so batch order lines
            up with the iterator's ``.classes``.

    Returns:
        The iterator returned by ``augmenter.flow_from_directory``.
    """
    return augmenter.flow_from_directory(
        directory=path,
        batch_size=batch_size,
        target_size=target_size,
        class_mode=class_mode,
        seed=seed,
        shuffle=shuffle,
    )


# Load each split with flow_from_directory. Evaluation splits are not shuffled
# so that their batch order matches the labels in `.classes`.
train = augment_data(train_path, augmenter=datagen)
val = augment_data(val_path, augmenter=datagen, shuffle=False)
test = augment_data(test_path, augmenter=datagen, shuffle=False)

Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.


In [11]:
"""Creating CNNBlock"""
class CNNBlock(layers.Layer):
    def __init__(self, out_channels, kernel_size):
        super(CNNBlock, self).__init__()
        self.conv = layers.Conv2D(out_channels, kernel_size, padding = 'same', input_shape = (256,256))
        self.batch_norm = layers.BatchNormalization()

    """Create function to run forward pass"""
    """CNN -> BatchNormalization -> ReLu"""
    def call(self, input, training = False):
        x = self.conv(input)
        x = self.batch_norm(x, training = training)
        x = tf.nn.relu(x)
        return x

        

In [12]:
# Binary classifier: the loaders use class_mode='binary' (one 0/1 label per
# image), so the head emits a single sigmoid probability. BinaryCrossentropy
# defaults to from_logits=False and therefore expects probabilities, not raw
# scores; a two-unit linear output would not match either the labels or the loss.
# NOTE(review): Flatten on the full 256x256x32 feature map feeds ~2.1M values
# into the Dense layer; consider pooling or strided convs if training is slow.
model = keras.Sequential([
    CNNBlock(32, 3),
    layers.Flatten(),
    layers.Dense(1, activation='sigmoid'),
])
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)

In [13]:
# Train on the training iterator, scoring the validation split after each epoch.
# NOTE(review): the validation directory holds only 16 images (see the loader
# output above), so per-epoch validation metrics will be very noisy.
history = model.fit(train, validation_data=val, epochs=10)

Epoch 1/10

KeyboardInterrupt: 