# Generate a latent-space map using the trained decoder + classifier

## 0. Imports

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import models, layers, applications, backend as K
from tensorflow.keras.losses import MeanSquaredError, KLDivergence
from plotly import express as px

## 1. Load SOTA classifier

In [None]:
# Load the trained classifier without its saved training configuration, then
# recompile it. `keras` itself is never bound by the imports above
# (`from keras import models, ...` only binds `models`), so load via `models`.
# NOTE(review): only `predict` is used in this notebook, so these compile settings
# do not affect the map; but `binary_crossentropy` / binary accuracy look
# mismatched with the 10-class argmax used later — confirm before training/evaluating.
classifier = models.load_model('best_classifier.h5', compile=False)
classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

## 2. Load SOTA decoder (Latent Space Z -> Img)

In [None]:
# Load the trained decoder (latent point -> image). `keras` is not a bound name in
# this notebook, so use the imported `models` submodule directly.
decoder = models.load_model('best_decoder.h5')

## 3. Generate a grid of points in the 2-D latent space

In [None]:
# Sample a regular 500 x 500 grid over [0, 1] x [0, 1] in the 2-D latent space.
# Row order matches itertools.product(x, y): x varies slowest, y fastest.
# meshgrid(indexing='ij') builds the same array in one vectorized step instead of
# materializing 250,000 Python tuples.
x = np.linspace(0, 1, 500)
y = np.linspace(0, 1, 500)
xx, yy = np.meshgrid(x, y, indexing='ij')
grid = np.column_stack((xx.ravel(), yy.ravel()))

## 4. Decode the grid to generated images using the decoder

In [None]:
# Run every grid point through the decoder, yielding one generated image per point.
decoder_preds = decoder.predict(grid)

# Persist the decoded images so they can be reused later without re-decoding.
output_path = 'decoder_preds.npy'
np.save(output_path, decoder_preds)

print(np.shape(decoder_preds))

## 5. Classify the decoder's outputs into classes using the classifier

In [None]:
# Score every decoded image: one row of per-class outputs per grid point.
class_scores = classifier.predict(decoder_preds)
print(class_scores.shape)

# Collapse each row to the index of its highest-scoring class.
classifier_preds = class_scores.argmax(axis=1)

# Save the per-point labels for reuse (a useful artifact for the mapSpace function).
np.save('classifier_preds.npy', classifier_preds)

## 6. Build and display map

In [None]:
# Class index -> a distinct named color (10 classes, indices 0..9).
color_map = dict(enumerate([
    'red', 'blue', 'green', 'purple', 'orange',
    'cyan', 'magenta', 'yellow', 'lime', 'brown',
]))

# Class index -> human-readable label (appears to be the Fashion-MNIST label set).
class_labels = dict(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))


In [None]:
# Decode one hand-picked latent point, display the generated image, and print the
# class the classifier assigns to it.
point = [0.66, 0.232]
# Explicit (1, 2) batch: a bare nested Python list can be treated by Keras as a
# structure of inputs rather than a single batch, so pass an unambiguous array.
point_batch = np.array([point])
# Decode once and reuse the result for both the plot and the classification.
decoded = decoder.predict(point_batch)
# Indexing takes the first (only) sample and channel 0 of a rank-4 output;
# presumably (batch, height, width, channels) — confirm against the decoder.
fig = px.imshow(decoded[0, :, :, 0])
fig.show()
print(class_labels[np.argmax(classifier.predict(decoded))])

In [None]:
# Map the labels to their corresponding colors
# Plot every latent grid point, colored by the class the classifier assigned to its
# decoded image. matplotlib's `plt` is never imported in this notebook (NameError),
# so the map is drawn with the already-imported plotly express instead.

# Map the labels to their corresponding colors (kept so later cells can still use it).
colors = [color_map[label] for label in classifier_preds]

# Legend uses readable class names; each name is pinned to its color_map color.
label_names = [class_labels[label] for label in classifier_preds]
name_to_color = {class_labels[k]: color_map[k] for k in color_map}

fig = px.scatter(
    x=grid[:, 0],
    y=grid[:, 1],
    color=label_names,
    color_discrete_map=name_to_color,
    category_orders={'color': [class_labels[k] for k in sorted(class_labels)]},
    labels={'x': 'X Coordinate', 'y': 'Y Coordinate', 'color': 'Class'},
    title='Discrete Colored Map by Labels',
    render_mode='webgl',  # WebGL keeps a very large scatter responsive
)
fig.update_traces(marker={'size': 3})
# Same scale on both axes (the equivalent of plt.axis('equal')).
fig.update_yaxes(scaleanchor='x', scaleratio=1)
fig.show()