In [None]:
%matplotlib widget
from matplotlib import pyplot as plt
import numpy as np
import torch

# import dependencies for dimensionality reduction (PCA and t-SNE)
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from mlc.data.mnist.dataset import MNIST
from mlc.util.model import load_checkpoint

In [None]:
# Checkpoints that can be inspected; exactly one assignment should be active.
# model_version, args = "20250411150333312014", {"neck_dim": 8, "init_dim": 32}
# model_version, args = "20250411144310508707", {"neck_dim": 16, "init_dim": 32}
model_version, args = "20250411153623227400", {"neck_dim": 16, "init_dim": 64}
model = load_checkpoint("mlp_autoencoder", args, model_version, "latest")
# The notebook only runs inference, so put the model in evaluation mode: this
# makes layers such as dropout / batch norm (if the architecture has any)
# behave deterministically. Calling eval() again is harmless if the loader
# already did it.
# NOTE(review): assumes load_checkpoint returns a torch.nn.Module - confirm.
model.eval()
print(model.name())

In [None]:
# Validation fold of the MNIST dataset (the dataset is built from an empty config dict).
dataset = MNIST({})
validation_data = dataset.get_fold("validation")

# Keep a direct handle on the encoder part of the loaded model.
encoder = model.encoder

In [None]:
# Compare inputs with their reconstructions: the top row shows validation
# samples, the bottom row shows what the model outputs for each of them.
n_examples = 10
fig, ax = plt.subplots(2, n_examples, figsize=(10, 5), sharey=True)
# The axes are switched off below, so a figure-level title is the only place
# left to tell the reader which row is which.
fig.suptitle("Validation samples (top) vs. autoencoder reconstructions (bottom)")
for i in range(n_examples):
    data, _ = validation_data[i]
    # Apply the full model to a batch of one; gradients are not needed for plotting.
    with torch.no_grad():
        data = data.unsqueeze(0)
        data_rec = model(data)
    ax[0, i].imshow(data.squeeze().cpu().numpy(), cmap="gray")
    ax[0, i].axis("off")
    # Bring the model output back to a 28x28 image for display. reshape() is
    # used instead of view() so a non-contiguous output cannot raise here.
    data_rec = data_rec.reshape(28, 28)
    ax[1, i].imshow(data_rec.cpu().numpy(), cmap="gray")
    ax[1, i].axis("off")

In [None]:
# Run every validation sample through the encoder, collecting the encoder
# output together with the label the dataset reports for that index.
latent_codes = []
label_list = []
with torch.no_grad():
    for idx in range(len(validation_data)):
        sample, _ = validation_data[idx]
        latent = encoder(sample.unsqueeze(0))  # batch of one
        latent_codes.append(latent.cpu().numpy())
        label_list.append(validation_data.get_label(idx))

# Join the per-sample outputs along axis 0 and convert the labels to an array.
encoded_data = np.concatenate(latent_codes, axis=0)
label = np.array(label_list)
print(encoded_data.shape)
print(label.shape)

In [None]:
# Project the encoded data onto its first two principal components.
# random_state is pinned (same seed as the t-SNE cells) because scikit-learn's
# default svd_solver="auto" picks the solver from the data shape and may choose
# the randomized SVD, whose result depends on the random state.
pca = PCA(n_components=2, random_state=42)
# fit_transform fits the model and projects the same data in one call.
encoded_data_2d = pca.fit_transform(encoded_data)
print(encoded_data_2d.shape)

In [None]:
# 2D PCA projection, drawn as one scatter series per class index 0-9 so every
# class gets its own legend entry.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
for digit in range(10):
    mask = label == digit  # boolean selector for the samples of this class
    points = encoded_data_2d[mask]
    ax.scatter(points[:, 0], points[:, 1], alpha=0.5, label=f"Class {digit}")
ax.set_xlabel("PCA 1")
ax.set_ylabel("PCA 2")
ax.legend()
plt.show()

In [None]:
# Project the encoded data onto its first three principal components.
# random_state is pinned (same seed as the t-SNE cells) because scikit-learn's
# default svd_solver="auto" picks the solver from the data shape and may choose
# the randomized SVD, whose result depends on the random state.
pca = PCA(n_components=3, random_state=42)
# fit_transform fits the model and projects the same data in one call.
encoded_data_3d = pca.fit_transform(encoded_data)
print(encoded_data_3d.shape)

In [None]:
# 3D PCA projection on a 3D axes, one scatter series per class index 0-9.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
for digit in range(10):
    mask = label == digit  # boolean selector for the samples of this class
    points = encoded_data_3d[mask]
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], alpha=0.5, label=f"Class {digit}")
ax.set_xlabel("PCA 1")
ax.set_ylabel("PCA 2")
ax.set_zlabel("PCA 3")
ax.legend()
plt.show()

In [None]:
# Embed the encoded data in two dimensions with t-SNE, using a fixed
# random_state of 42.
tsne = TSNE(random_state=42, n_components=2)
# t-SNE is fitted and applied to the same data in a single call.
encoded_data_tsne = tsne.fit_transform(encoded_data)
print(encoded_data_tsne.shape)

In [None]:
# 2D t-SNE embedding, drawn as one scatter series per class index 0-9.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
for digit in range(10):
    mask = label == digit  # boolean selector for the samples of this class
    points = encoded_data_tsne[mask]
    ax.scatter(points[:, 0], points[:, 1], alpha=0.5, label=f"Class {digit}")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
ax.legend()
plt.show()

In [None]:
# Embed the encoded data in three dimensions with t-SNE, using a fixed
# random_state of 42, then plot the result.
tsne = TSNE(random_state=42, n_components=3)
encoded_data_tsne_3d = tsne.fit_transform(encoded_data)
print(encoded_data_tsne_3d.shape)

# 3D axes with one scatter series per class index 0-9.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
for digit in range(10):
    mask = label == digit  # boolean selector for the samples of this class
    points = encoded_data_tsne_3d[mask]
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], alpha=0.5, label=f"Class {digit}")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
ax.set_zlabel("t-SNE 3")
ax.legend()
plt.show()