In [None]:
%matplotlib widget
from matplotlib import pyplot as plt
import numpy as np
import torch

# import dependencies for PCA
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from mlc.data.mnist.dataset import MNIST
from mlc.util.model import load_checkpoint

In [None]:
# Restore the trained MLP-GAN from a specific training run and report which model was loaded.
# The version id below is a run the author marked as "good".
model_version = "20250523114243213422"
checkpoint_args = ("mlp_gan", model_version, "latest")
model = load_checkpoint(*checkpoint_args)
print(model.name())

In [None]:
# Load the MNIST split that the preview cell below draws images from.
# The config dict presumably disables added noise and enables normalization — TODO confirm
# against the MNIST dataset class.
# NOTE(review): this is the "train" fold even though the variable is named
# `validation_data`; change PREVIEW_FOLD if a held-out fold was intended. The variable name
# is kept because the preview cell reads it.
PREVIEW_FOLD = "train"
dataset = MNIST({"noise": 0, "normalize": True})
validation_data = dataset.get_fold(PREVIEW_FOLD)

In [None]:
# Show the first 10 images of the loaded split as a quick sanity check of the data pipeline.
# `squeeze()` drops singleton dims (e.g. a channel axis) so imshow gets a 2-D image.
fig, ax = plt.subplots(1, 10, figsize=(10, 2))
for i, axis in enumerate(ax):
    image, _ = validation_data[i]
    axis.imshow(image.squeeze().cpu().numpy(), cmap="gray")
    axis.axis("off")
plt.show()

In [None]:
# Move the generator to the GPU when one is available (falling back to CPU so the notebook
# still runs without CUDA) and switch it to eval mode for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
G = model.generator.to(device).eval()

In [None]:
# Draw a fixed batch of latent vectors and render the generator's output as a 10x10 grid.
# The seeded generator makes Z reproducible, so the hand-picked indices used by later cells
# (Z[27], Z[28], Z[19], ...) refer to the same samples on every run.
# NOTE(review): the original draw was unseeded, so those indices must be re-picked once
# against this seeded grid.
LATENT_SEED = 0
GRID_ROWS, GRID_COLS = 10, 10
gen_device = next(G.parameters()).device  # keep Z on whatever device G lives on

latent_rng = torch.Generator().manual_seed(LATENT_SEED)  # CPU RNG: same Z on CPU or GPU
Z = torch.randn(GRID_ROWS * GRID_COLS, model.latent_dimension(), generator=latent_rng).to(gen_device) * 0.1
# NOTE(review): (Z - 0.5) / 0.5 == 2 * Z - 1, which shifts the latents to mean -1, std 0.2.
# This looks like an image-normalization formula applied to noise; confirm it matches the
# latent prior the generator was trained with.
Z = (Z - 0.5) / 0.5

with torch.no_grad():  # inference only: don't build an autograd graph
    X = G(Z).cpu().numpy()

fig, ax = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(10, 10))
for k, axis in enumerate(ax.flat):  # row-major, so panel (i, j) shows X[i * GRID_COLS + j]
    axis.imshow(X[k].reshape(28, 28), cmap="gray")
    axis.axis("off")
plt.show()

In [None]:
# Linearly interpolate between two latent vectors and decode each step.
# NOTE(review): indices 27/28 were chosen by eye from the sample grid; re-check them if Z changes.
z00, z01 = Z[27], Z[28]
n_steps = 10

# (n_steps, 1) weights broadcast against the (latent_dim,) endpoints -> (n_steps, latent_dim)
alphas = torch.linspace(0, 1, n_steps, device=Z.device).unsqueeze(1)
Z_interp = (1 - alphas) * z00 + alphas * z01  # row k moves from z00 (k=0) to z01 (k=n_steps-1)

with torch.no_grad():  # inference only
    X_interp = G(Z_interp).cpu().numpy()

fig, ax = plt.subplots(1, n_steps, figsize=(10, 2))
for axis, image in zip(ax, X_interp):
    axis.imshow(image.reshape(28, 28), cmap="gray")
    axis.axis("off")
plt.show()

In [None]:
# Bilinear interpolation across four corner latent vectors, decoded into a 10x10 grid.
# Moving down the rows goes from the (z00 -> z01) edge to the (z10 -> z11) edge;
# moving across the columns goes from the z*0 corners to the z*1 corners.
# NOTE(review): corner indices were chosen by eye from the sample grid; re-check if Z changes.
z00, z01 = Z[27], Z[28]
z10, z11 = Z[70], Z[56]
n_steps = 10

weights = torch.linspace(0, 1, n_steps, device=Z.device)
alpha = weights.view(-1, 1, 1)  # row weight, varies along axis 0
beta = weights.view(1, -1, 1)   # column weight, varies along axis 1
Z_grid = (
    (1 - alpha) * (1 - beta) * z00
    + (1 - alpha) * beta * z01
    + alpha * (1 - beta) * z10
    + alpha * beta * z11
)  # (n_steps, n_steps, latent_dim)
Z_interp = Z_grid.reshape(n_steps * n_steps, -1)  # row-major: index = n_steps * i + j

with torch.no_grad():  # inference only
    X_interp = G(Z_interp).cpu().numpy()

fig, ax = plt.subplots(n_steps, n_steps, figsize=(10, 10))
for k, axis in enumerate(ax.flat):  # ax.flat is row-major, matching Z_interp's ordering
    axis.imshow(X_interp[k].reshape(28, 28), cmap="gray")
    axis.axis("off")
plt.show()

In [None]:
# Latent-space arithmetic: decode z1 - z2 + z3 and compare it with the three operands.
# NOTE(review): the "7 - 1 + 4" label records the digits the author saw at indices 19/49/98
# in one particular random draw; re-check it against the current sample grid.
z1 = Z[19]
z2 = Z[49]
z3 = Z[98]
z_r = z1 - z2 + z3

with torch.no_grad():  # inference only
    X_operands = G(torch.stack([z1, z2, z3])).cpu().numpy()  # one decoded image per row
    X_r = G(z_r.unsqueeze(0)).cpu().numpy()

# the three operands
fig, ax = plt.subplots(1, 3, figsize=(6, 2))
for axis, image, title in zip(ax, X_operands, ["z1", "z2", "z3"]):
    axis.imshow(image.reshape(28, 28), cmap="gray")
    axis.set_title(title)
    axis.axis("off")
plt.show()

# the result of the arithmetic
fig, ax = plt.subplots(1, 1, figsize=(2, 2))
ax.imshow(X_r.reshape(28, 28), cmap="gray")
ax.set_title("z1 - z2 + z3")
ax.axis("off")
plt.show()