In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2DTranspose, Conv2D, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import os

In [None]:
# Model inputs, read from disk. Row i is paired by index with image i further
# down, so the row order must match the image loading order.
# NOTE(review): presumably shaped (num_samples, 49) to fit the model input - confirm.
grids = np.load(file='grids.npy')

def load_images(directory, target_size=(224, 224), color_mode='grayscale', sort_files=True):
    """Load every PNG/JPEG file in ``directory`` as a [0, 1]-normalized array.

    Args:
        directory: Folder to scan. Only direct children are considered and
            files are selected by extension (.png, .jpg, .jpeg,
            case-insensitive).
        target_size: Size passed through to ``load_img``; every image is
            resized to it.
        color_mode: Color mode passed through to ``load_img``. Defaults to
            'grayscale'; use 'rgb' for color images.
        sort_files: When True (default) files are visited in sorted filename
            order. ``os.listdir`` returns entries in arbitrary order, which
            would make the row order of the result - and therefore its
            index-wise pairing with any other array - depend on the
            filesystem. Pass False to keep the raw ``os.listdir`` order.

    Returns:
        A NumPy array stacking one array per image, with pixel values divided
        by 255. If no file matches, an empty array of shape (0,) is returned.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    filenames = os.listdir(directory)
    if sort_files:
        # Deterministic order keeps image i aligned with row i of other arrays.
        filenames = sorted(filenames)

    images = [
        # Scale 8-bit pixel values into [0, 1].
        img_to_array(
            load_img(
                os.path.join(directory, filename),
                target_size=target_size,
                color_mode=color_mode,
            )
        ) / 255.0
        for filename in filenames
        if filename.lower().endswith(image_extensions)
    ]
    return np.array(images)

# Target images for training, read from the local dataset folder.
images = load_images(directory='./dataset')

# Inputs and targets are paired by position, so their sample counts must agree.
assert images.shape[0] == grids.shape[0], (
    "The number of grids must match the number of images"
)

In [None]:
# Network mapping a flattened 7x7 grid (49 values) to a 224x224 single-channel image.
input_grid = Input(shape=(49,))  # 7*7 grids flattened

# Fully connected stem; the last layer is sized so it can be reshaped
# into a 14x14 feature map with 128 channels.
x = input_grid
for units in (128, 256, 128 * 14 * 14):
    x = Dense(units, activation='relu')(x)
x = Reshape((14, 14, 128))(x)

# Each stride-2 transposed convolution doubles the spatial size:
# 14x14 -> 28x28 -> 56x56 -> 112x112.
for filters in (64, 32, 16):
    x = Conv2DTranspose(
        filters,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='same',
        activation='relu',
    )(x)

# Final upscaling to 224x224 with one output channel; sigmoid bounds the
# output to the same [0, 1] range as the normalized target images.
decoded = Conv2DTranspose(
    1,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding='same',
    activation='sigmoid',
)(x)

autoencoder = Model(inputs=input_grid, outputs=decoded)
autoencoder.compile(loss='mse', optimizer=Adam(learning_rate=0.001))

autoencoder.summary()

In [None]:
# Train the grid -> image mapping. With validation_split, Keras holds out the
# trailing 20% of the samples (taken before shuffling) for validation.
autoencoder.fit(
    x=grids,
    y=images,
    batch_size=16,
    epochs=200,
    validation_split=0.2,
)

In [None]:
# Persist the trained model to an HDF5 file in the working directory.
autoencoder.save(filepath='autoencoder.h5')

import matplotlib.pyplot as plt

def display_images(orig, decoded, n=10, image_shape=(224, 224)):
    """Show target images (top row) above the model outputs (bottom row).

    Args:
        orig: Indexable collection of target images (NumPy arrays); each must
            be reshapeable to ``image_shape``.
        decoded: Indexable collection of model outputs, same layout as ``orig``.
        n: Maximum number of image pairs to display (default 10). It is
            clamped to the number of pairs actually available, so passing
            fewer than ``n`` samples no longer raises ``IndexError``.
        image_shape: 2-D shape each image is reshaped to before plotting
            (default (224, 224)).

    Returns:
        None. The figure is rendered with ``plt.show()``; nothing is drawn
        when there are no pairs to show.
    """
    n = min(n, len(orig), len(decoded))
    if n <= 0:
        return

    plt.figure(figsize=(20, 4))
    # Loop-invariant: switch pyplot to the gray colormap once up front.
    plt.gray()
    for i in range(n):
        # Row 0 holds the originals, row 1 the reconstructions.
        for row, batch in enumerate((orig, decoded)):
            ax = plt.subplot(2, n, row * n + i + 1)
            plt.imshow(batch[i].reshape(image_shape))
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

# Run the model on the first ten grids and plot the outputs under the
# corresponding target images for a visual comparison.
decoded_imgs = autoencoder.predict(x=grids[:10])
display_images(orig=images[:10], decoded=decoded_imgs)