In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM, TimeDistributed

In [2]:
# Input / training configuration. Images are resized to img_height x img_width when loaded.
batch_size = 32
img_height = 32
img_width = 32

# Seed the Python, NumPy and TensorFlow global RNGs in one call so dataset shuffling,
# weight initialisation and dropout repeat across Restart & Run All.
# NOTE(review): some GPU kernels may still be non-deterministic; enable
# tf.config.experimental.enable_op_determinism() if bit-exact runs are required.
SEED = 123
tf.keras.utils.set_random_seed(SEED)

In [None]:
def load_image_split(directory):
    """Load one image split from ``directory`` as a batched ``tf.data.Dataset``.

    Reads the notebook-level ``img_height``, ``img_width`` and ``batch_size``
    settings at call time, so both splits are loaded identically.
    """
    return tf.keras.utils.image_dataset_from_directory(
        directory=directory,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


train_ds = load_image_split('train')
# NOTE(review): the 'test' folder doubles as the validation set monitored during
# training, so it is not an untouched hold-out set for final evaluation.
val_ds = load_image_split('test')

In [None]:
# Class names come from the training split. Labels are integer indices into each
# split's own class list, so 'train' and 'test' must expose the same classes in the
# same order; otherwise validation labels would silently point at the wrong class.
# Run this before the caching cell: the cached/prefetched datasets it creates do
# not carry a `.class_names` attribute.
class_names = train_ds.class_names
assert val_ds.class_names == class_names, (
    f"train/test class mismatch: {class_names} vs {val_ds.class_names}"
)
print(class_names)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Keep loaded images in memory after the first pass and let tf.data prepare the
# next batch while the current one trains.
# NOTE(review): the datasets are already batched (batch_size above), so
# shuffle(1000) reorders whole batches, not individual images. Because cache()
# replays the first pass, batch composition likely stays fixed across epochs — TODO confirm.
# Re-running this cell wraps the already-wrapped datasets again.
train_ds = (
    train_ds
    .cache()
    .shuffle(1000)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
# Sanity check: after 1/255 rescaling, pixel values of a training batch should lie in [0, 1].
# This is inspection only. The model below applies its own Rescaling layer, so
# `normalized_ds` is not used for training.
# NOTE(review): pulling one batch through the cached pipeline may log TF's
# "did not fully read the dataset being cached" warning; it is harmless here.
normalization_layer = layers.Rescaling(1./255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]

pixel_min = float(np.min(first_image))
pixel_max = float(np.max(first_image))
# Small tolerance guards against float32 rounding of 255 * (1/255).
assert 0.0 <= pixel_min <= pixel_max <= 1.0 + 1e-6, (
    f"unexpected rescaled pixel range [{pixel_min}, {pixel_max}]"
)
print(pixel_min, pixel_max)

In [None]:
num_classes = len(class_names)

# Architecture: CNN feature extractor -> row LSTM -> column LSTM -> MLP head.
# For the default 32x32x3 input the conv/pool stack produces a (3, 3, 64) feature map:
#   conv(same) 32x32x16 -> pool 16x16x16 -> conv(same) 16x16x32 -> pool 8x8x32
#   -> conv(valid) 6x6x64 -> pool 3x3x64
# No Flatten: TimeDistributed applies the first LSTM to each slice along axis 1 (each row,
# read as a sequence of column feature vectors). The second LSTM then reads the
# resulting sequence of row encodings.
model = models.Sequential()

model.add(layers.Input(shape=(img_height, img_width, 3)))
# Scaling lives inside the model so the saved model accepts unscaled pixel values.
model.add(layers.Rescaling(1./255))
model.add(layers.Conv2D(16, (3, 3), padding='same', activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(32, (3, 3), padding='same', activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))

# Encode each row of the feature map.
model.add(layers.TimeDistributed(layers.LSTM(256)))
model.add(layers.Dropout(0.2))

# Encode the sequence of row encodings (i.e. along the columns of the grid).
model.add(layers.LSTM(256))

# relu: without a non-linearity this hidden Dense layer and the softmax Dense layer
# compose into a single affine map and add no capacity.
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.25))
model.add(layers.Dense(num_classes, activation='softmax'))

# Integer labels -> sparse categorical cross-entropy.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

epochs = 100
# 100 epochs is an upper bound. Stop when val_loss stalls and restore the best epoch's
# weights, so the saved model is the best one seen rather than the last one.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[early_stopping],
)

# NOTE(review): '.h5' is the legacy HDF5 format; newer Keras prefers the '.keras' format.
model.save('cnn-lstm-mlp.h5')