# Generate a latent-space class map using a simple trained CNN classifier

# 1. Build Classifier

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras import models, layers, applications
from plotly import express as px
from plotly import graph_objects as go
import matplotlib.pyplot as plt

# Load Fashion-MNIST and rescale pixel intensities from [0, 255] to [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train, x_test = (split.astype('float32') / 255.0 for split in (x_train, x_test))

# Append a trailing channel axis so every image is (28, 28, 1), matching the
# classifier's Input layer.
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

# Number of distinct labels present in the training set.
num_classes = np.unique(y_train).size

# One-hot encode the integer labels for categorical cross-entropy.
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test)
)

In [None]:
# Convolutional encoder (architecture borrowed from an autoencoder's encoder)
# followed by a softmax classification head.
# With padding='same', each 2x2 max-pool halves the spatial size, rounding up:
# 28 -> 14 -> 7 -> 4 -> 2, so the final feature map is 2 x 2 x 64.

input_img = layers.Input(shape=(28, 28, 1))

x = input_img
for n_filters in (8, 16, 24, 32):
    # One down-sampling stage: Conv -> BatchNorm -> 2x2 max-pool.
    x = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)

# Final conv stage, without pooling.
x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = layers.BatchNormalization()(x)

# Flatten the feature map and project it onto per-class probabilities.
x = layers.Flatten()(x)
output = layers.Dense(num_classes, activation='softmax')(x)

# Assemble and summarize the classifier.
classifier = models.Model(input_img, output, name='classifier')
classifier.summary()

In [None]:
# Compile and train the classifier.
# The configured optimizer object (not the string "adam") is passed to compile,
# so the learning rate set here actually takes effect.
# The optimizer comes from the same `keras` package that provided `models` and
# `layers`, so it matches the model's Keras implementation.
from keras import optimizers

learning_rate = 0.001
opt = optimizers.Adam(learning_rate=learning_rate)

classifier.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

# Hold out 10% of the training data for validation during training.
classifier.fit(x_train, y_train, batch_size=32, epochs=15, validation_split=0.1)

In [None]:
# Evaluate on the test set. With metrics=["accuracy"], the first two entries are
# the loss and the accuracy.
score = classifier.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy)

In [None]:
# Persist the trained classifier (.h5 / HDF5 file) so later cells can reload it.
classifier.save(filepath='classifier.h5')

In [None]:
# Reload the saved classifier so the map-building steps can run without retraining.
classifier = tf.keras.models.load_model(filepath='classifier.h5')

# 2. Generate Map from Decoder's outputs

In [None]:
# Load a decoder model from disk; it is not built in this notebook.
# NOTE(review): presumably the decoder half of an autoencoder trained elsewhere — confirm.
decoder = tf.keras.models.load_model(filepath='decoder.h5')
decoder  # echo the model object as the cell output

In [None]:
# Decode the single point (0.5, 0.5) and display channel 0 of the resulting image.
# The input is an explicit (1, 2) NumPy array (a batch of one point): a nested
# Python list can be interpreted by Keras as a structure of several inputs
# rather than as one batch.
px.imshow(decoder.predict(np.array([[0.5, 0.5]]))[0, :, :, 0])

In [None]:
# Regular grid of points covering the unit square [0, 1] x [0, 1], used as
# inputs to the decoder.
# Rows are ordered exactly like itertools.product(x, y): x varies slowest and
# y fastest, i.e. grid[i * grid_resolution + j] == (x[i], y[j]).
# Built with vectorized NumPy instead of a Python loop over all
# grid_resolution**2 pairs.
grid_resolution = 500
x = np.linspace(0, 1, grid_resolution)
y = np.linspace(0, 1, grid_resolution)
grid = np.stack(np.meshgrid(x, y, indexing='ij'), axis=-1).reshape(-1, 2)

In [None]:
# Inspect the grid of input points (one (x, y) pair per row).
grid

## Decode the grid points into images

In [None]:
# Decode every grid point into an image.
# An explicit, larger batch_size cuts the number of predict steps compared with
# Keras's default of 32. predict runs the model in inference mode, so per-sample
# outputs should not depend on the batch size (up to floating-point rounding).
decoder_preds = decoder.predict(grid, batch_size=1024)

In [None]:
# Check the decoded batch: one image per grid point.
# NOTE(review): presumably (len(grid), 28, 28, 1) so it matches the classifier's Input — confirm.
decoder_preds.shape

In [None]:
# Cache the decoded images on disk so the decoding step can be skipped on re-runs.
np.save(file='decoder_preds.npy', arr=decoder_preds)

In [None]:
# Restore the cached decoded images.
decoder_preds = np.load(file='decoder_preds.npy')

## Classify the decoded images

In [None]:
# Classify every decoded image. The result has one row of class probabilities per image.
# An explicit, larger batch_size cuts the number of predict steps compared with
# Keras's default of 32. Per-sample outputs in inference mode should not depend
# on the batch size (up to floating-point rounding).
classifier_preds = classifier.predict(decoder_preds, batch_size=1024)

In [None]:
# Check the prediction matrix: one row of class probabilities per decoded image.
classifier_preds.shape

In [None]:
# Collapse each row of class probabilities to the index of the most likely class.
classifier_preds = classifier_preds.argmax(axis=1)
classifier_preds.shape  # now one integer label per grid point

In [None]:
# Cache the per-point class labels on disk.
np.save(file='classifier_preds.npy', arr=classifier_preds)

In [None]:
# Restore the cached per-point class labels.
classifier_preds = np.load(file='classifier_preds.npy')

## Build map from classifications

In [None]:
# Colour every grid point by the class the classifier assigned to its decoded
# image, and add a legend so the colours can be read back as classes.
from matplotlib.patches import Patch

# Ten distinct colours, one per class index.
color_map = {
    0: 'red',
    1: 'blue',
    2: 'green',
    3: 'purple',
    4: 'orange',
    5: 'cyan',
    6: 'magenta',
    7: 'yellow',
    8: 'lime',
    9: 'brown',
}

# Fashion-MNIST label names, indexed by class id (the dataset loaded at the top).
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

# Map each predicted label to its colour.
colors = [color_map[label] for label in classifier_preds]

# Scatter the grid points, coloured by predicted class.
plt.scatter(grid[:, 0], grid[:, 1], c=colors, s=5)

# Equal aspect ratio so the unit square is not distorted.
plt.axis('equal')

plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.title('Discrete Colored Map by Labels')

# Without a legend the colours are meaningless. Use proxy patches (one per class)
# and place the legend outside the axes so it does not hide part of the map.
legend_handles = [
    Patch(color=color_map[class_id], label=f'{class_id}: {name}')
    for class_id, name in enumerate(class_names)
]
plt.legend(handles=legend_handles, loc='upper left', bbox_to_anchor=(1.02, 1))

plt.show()