imports

In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2DTranspose, Flatten, Conv2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.losses import MeanSquaredError

Dataset preparation and grid loading

In [None]:
# Model inputs, expected to be index-aligned with the images loaded below
# (row i of `grids` is paired with image i during training).
grids = np.load(file='grids.npy')

def load_images(directory, target_size=(224, 224), color_mode='grayscale'):
    """Load every PNG/JPEG file in ``directory`` as a [0, 1]-normalized array.

    Files are read in *sorted* filename order. ``os.listdir`` returns entries
    in arbitrary order, and the result of this function is paired
    index-by-index with another array (``grids``). An unsorted listing could
    therefore silently mis-pair inputs and targets.

    Args:
        directory: Folder to scan (non-recursive).
        target_size: ``(height, width)`` that each image is resized to.
        color_mode: Forwarded to ``load_img``. Must be ``'grayscale'``,
            ``'rgb'`` or ``'rgba'``.

    Returns:
        A NumPy array with one entry per image file, values scaled to [0, 1].
        Assuming Keras' default channels-last format, its shape is
        ``(num_images, height, width, channels)``. If no image files are
        found, an empty array of that 4-D shape is returned, so
        ``.shape[0]`` checks still work.

    Raises:
        ValueError: If ``color_mode`` is not one of the supported modes.
    """
    channels_by_mode = {'grayscale': 1, 'rgb': 3, 'rgba': 4}
    if color_mode not in channels_by_mode:
        raise ValueError(
            f"color_mode must be one of {sorted(channels_by_mode)}, got {color_mode!r}"
        )

    images = []
    # sorted(): deterministic order so image i lines up with grids[i].
    for filename in sorted(os.listdir(directory)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_path = os.path.join(directory, filename)
        img = load_img(img_path, target_size=target_size, color_mode=color_mode)
        images.append(img_to_array(img) / 255.0)  # normalize to [0, 1]

    if not images:
        # Keep the result 4-D even when empty, rather than np.array([]) -> (0,).
        return np.empty((0, *target_size, channels_by_mode[color_mode]), dtype=np.float32)
    return np.array(images)

# Training targets; image i must correspond to grids[i].
images = load_images('./dataset')

# Explicit check rather than `assert`, which is stripped under `python -O`.
if grids.shape[0] != images.shape[0]:
    raise ValueError(
        "The number of grids must match the number of images "
        f"(got {grids.shape[0]} grids and {images.shape[0]} images)"
    )

In [None]:
# Maps the 'mse' name to a loss object. Presumably needed only when the
# saved file's compile configuration refers to 'mse' by name.
custom_objects = {
    'mse': MeanSquaredError()
}

# compile=False: the model is recompiled immediately below with a fresh
# optimizer and loss. Restoring the saved training configuration would be
# discarded anyway and only adds a fragile deserialization step.
# NOTE(review): this loads 'autoencoder.h5' from the working directory,
# but the save cell writes to './models/autoencoder.h5'. A re-run therefore
# starts from the original weights, not the retrained ones. Confirm which
# path is intended.
autoencoder = load_model('autoencoder.h5', custom_objects=custom_objects, compile=False)

autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')

Training

In [None]:
# Keep the History object so per-epoch loss / val_loss can be plotted later.
# Keras takes validation_split from the *last* 20% of samples (before
# shuffling), so the sample order in `grids`/`images` determines which
# samples are held out.
history = autoencoder.fit(grids, images, epochs=200, batch_size=16, validation_split=0.2)

In [None]:
# Save the retrained model under ./models/, creating the folder on first run
# instead of relying on the save backend to create it.
# NOTE(review): the load cell reads 'autoencoder.h5' from the working
# directory, not this path.
model_path = './models/autoencoder.h5'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
autoencoder.save(model_path)

# Visualise reconstructions: original images (top row) vs. decoded outputs (bottom row)
import matplotlib.pyplot as plt

def display_images(orig, decoded, n=10, image_shape=(224, 224)):
    """Plot original images (top row) above their reconstructions (bottom row).

    Args:
        orig: Sequence of ground-truth images. Each element must be
            reshapeable to ``image_shape``.
        decoded: Model outputs, aligned index-by-index with ``orig``.
        n: Maximum number of pairs to show. Clamped to the number of
            available images, so shorter inputs do not raise ``IndexError``.
            Nothing is plotted if either input is empty.
        image_shape: Shape each image is reshaped to before plotting.
            Defaults to the 224x224 size used throughout this notebook.
    """
    n = min(n, len(orig), len(decoded))
    if n == 0:
        return

    # 2 inches per column; equals the original (20, 4) figure for n=10.
    plt.figure(figsize=(2 * n, 4))
    for i in range(n):
        # Row 0 holds the originals, row 1 the reconstructions.
        for row, batch in enumerate((orig, decoded)):
            ax = plt.subplot(2, n, row * n + i + 1)
            # Per-axes colormap instead of plt.gray(), which changes the global default.
            ax.imshow(np.asarray(batch[i]).reshape(image_shape), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

# Reconstruct a small preview batch and compare it against the targets.
# With validation_split=0.2, Keras holds out the *last* 20% of samples, so
# these leading samples typically come from the training portion.
decoded_imgs = autoencoder.predict(x=grids[:10])
display_images(orig=images[:10], decoded=decoded_imgs)