In [None]:
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from vae import VariationalAutoencoder

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fixed seed passed to Lightning so repeated runs start from the same RNG state.
pl.seed_everything(42)

In [None]:
# Batch size handed to the test DataLoader; the interpolation cell further down
# also reuses this value as its number of interpolation steps.
BATCH_SIZE = 100

In [None]:
def plot(model, img, num_rows=5):
    """Show input images next to the model's reconstructions.

    The figure has ``num_rows`` rows and four columns arranged as two
    (original, reconstruction) column pairs, so the first ``2 * num_rows``
    images of ``img`` are displayed.

    Args:
        model: Object providing ``to(device)`` and ``reconstruct(batch)``.
        img: Tensor of images that can be viewed as ``(-1, 28, 28)``.
        num_rows: Rows per column pair. Defaults to 5, the value that used
            to be hard-coded.

    Raises:
        ValueError: If ``num_rows`` is smaller than 1 or ``img`` holds fewer
            than ``2 * num_rows`` images.
    """
    if num_rows < 1:
        raise ValueError(f"num_rows must be at least 1, got {num_rows}")

    model.to(device)
    # Only the values are needed for plotting, so skip autograd bookkeeping.
    with torch.no_grad():
        reconstruction = model.reconstruct(img.to(device))

    original = img.view(-1, 28, 28).cpu().detach().numpy()
    reconstruction = reconstruction.view(-1, 28, 28).cpu().detach().numpy()

    num_shown = 2 * num_rows
    if original.shape[0] < num_shown:
        raise ValueError(
            f"need at least {num_shown} images to fill the grid, "
            f"got {original.shape[0]}"
        )

    # squeeze=False keeps `axs` two-dimensional even when num_rows == 1.
    fig, axs = plt.subplots(num_rows, 4, figsize=(5, 2 * num_rows), squeeze=False)
    for i in range(num_rows):
        for j in range(2):
            # Column pair j shows samples j * num_rows .. (j + 1) * num_rows - 1.
            sample = i + num_rows * j
            axs[i, 0 + j * 2].imshow(original[sample], cmap="gray", vmin=0, vmax=1)
            axs[i, 1 + j * 2].imshow(
                reconstruction[sample], cmap="gray", vmin=0, vmax=1
            )

    for ax in axs.flat:
        # remove x, y ticks
        ax.axis("off")

    axs[0, 0].set_title("Original")
    axs[0, 1].set_title("Reconstruction")
    axs[0, 2].set_title("Original")
    axs[0, 3].set_title("Reconstruction")

    plt.tight_layout()
    plt.show()

In [None]:
# MNIST with train=False (the held-out split); each image goes through ToTensor.
test_data = datasets.MNIST(
    root="../raw-data/",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
# Serve the split in batches of BATCH_SIZE images.
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)

In [None]:
# load trained model
model = VariationalAutoencoder()
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine still loads when only a CPU is present.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load("trained_vae_state_dict.pt", map_location=device)
model.load_state_dict(state_dict)

# Sanity check: show the first test batch next to its reconstructions
plot(model, next(iter(test_loader))[0])

In [None]:
# Encode the whole test set and keep the first output of `encode` (mu) per image
model.to(device).eval()
latent_batches, label_batches = [], []
with torch.no_grad():
    for images, targets in test_loader:
        mu, _ = model.encode(images.to(device))
        latent_batches.append(mu)
        label_batches.append(targets)

# Stitch the per-batch results together as numpy arrays
z = torch.cat(latent_batches, dim=0).cpu().numpy()
y = torch.cat(label_batches, dim=0).cpu().numpy()


# Average embedding per digit class; row k of mean_z belongs to label k
mean_z = np.stack([z[y == digit].mean(axis=0) for digit in range(10)], axis=0)

# Decode each class mean back to a 28x28 image
with torch.no_grad():
    decoded_means = model.decode(torch.from_numpy(mean_z).float().to(device))
    x_hat = decoded_means.view(-1, 28, 28).cpu().numpy()

In [None]:
# Decoded class means on a 2 x 5 grid, one panel per digit
fig, axs = plt.subplots(2, 5, figsize=(5, 2))
# axs.flat walks the grid row by row, so the running index is i * 5 + j
for digit, panel in enumerate(axs.flat):
    panel.set_title(f"{digit}")
    panel.imshow(x_hat[digit], cmap="gray", vmin=0, vmax=1)
    panel.axis("off")
plt.tight_layout()
plt.show()

In [None]:
from IPython.display import HTML

## Set the 2 digits to interpolate between
m = 9
n = 4

# End points of the interpolation: the mean embeddings of the two digits
a = mean_z[m]
b = mean_z[n]

# One decoded frame per interpolation step
NUM_INTERPOLATIONS = BATCH_SIZE
t = np.linspace(0, 1, NUM_INTERPOLATIONS)

# t[:, None] broadcasts over the latent dimensions: row k is a + t[k] * (b - a)
z_interp = a + t[:, None] * (b - a)

# Decode without tracking gradients, like the mean-embedding decode above
with torch.no_grad():
    x_interp = model.decode(torch.from_numpy(z_interp).float().to(device))
x_interp = x_interp.view(-1, 28, 28).cpu().numpy()

# plot the interpolation as an animation
fig, ax = plt.subplots()
im = ax.imshow(x_interp[0], cmap="gray", vmin=0, vmax=1)


def update_frame(i):
    """Render interpolation step ``i``: swap the image data and update the title."""
    frame = x_interp[i]
    im.set_data(frame)
    ax.set_title(f"{m} -> {n}; t = {t[i]:.4f}")


ani = animation.FuncAnimation(
    fig,
    update_frame,
    frames=range(NUM_INTERPOLATIONS),
    interval=40,  # milliseconds between frames
    repeat=True,
)

# Convert the animation to HTML format
html_animation = ani.to_jshtml()

# The frames are now embedded in the HTML string; close the figure so the
# notebook does not also render it as a static image next to the animation.
plt.close(fig)

# Display the animation in the notebook
HTML(html_animation)

In [None]:
# First two latent coordinates, coloured by digit:
# one figure for every test embedding, then one for the ten class means
for points, colours in ((z, y), (mean_z, list(range(10)))):
    plt.figure(figsize=(10, 10))
    plt.scatter(points[:, 0], points[:, 1], c=colours, cmap="tab10")
    plt.colorbar()
    plt.grid()
    plt.show()

In [None]:
# First three latent coordinates, coloured by digit:
# one 3d figure for every test embedding, then one for the ten class means
for cloud, hues in ((z, y), (mean_z, list(range(10)))):
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")
    scatter = ax.scatter(
        cloud[:, 0], cloud[:, 1], cloud[:, 2], c=hues, cmap="tab10"
    )
    plt.colorbar(scatter)
    plt.show()

In [None]:
## TODO: for the mean embeddings, calculate either a travelling salesman path or a minimum spanning tree

# Pairwise Euclidean distances between the class-mean embeddings
dists = np.sqrt(((mean_z[:, None] - mean_z[None, :]) ** 2).sum(axis=-1))

from itertools import permutations
from tqdm import tqdm
import math

# Number of nodes (kept out of `n`, which the interpolation cell uses as a digit)
num_nodes = dists.shape[0]

# A closed tour is the same cycle whichever node it starts from, so the start
# is pinned to node 0 and only the remaining nodes are permuted. This visits
# (n - 1)! orderings instead of n!; every cycle is still covered, and the
# mean/std of the tour lengths are unchanged because each cycle used to be
# counted once per possible start node. Percentiles can shift marginally,
# since they are interpolated over a sample with one tenth the entries.
start_node = 0
tours = permutations(range(1, num_nodes))

# Initial shortest distance is infinity
shortest_dist = float("inf")
shortest_tour = None

# save all tour distances
tour_dists = []

# Check all possible tours
for order in tqdm(tours, total=math.factorial(num_nodes - 1)):
    # Close the loop by returning to the start node
    tour = (start_node,) + order + (start_node,)
    # Sum the num_nodes edges between consecutive stops
    tour_dist = sum(dists[tour[i], tour[i + 1]] for i in range(num_nodes))
    tour_dists.append(tour_dist)

    # If this tour is shorter, update shortest_dist and shortest_tour
    if tour_dist < shortest_dist:
        shortest_dist = tour_dist
        shortest_tour = tour

print(f"Shortest tour: {shortest_tour}")
print(f"Shortest distance: {shortest_dist:.3f}")

# tour dist stats
mean_tour_dist = np.mean(tour_dists)
std_tour_dist = np.std(tour_dists)

# 5th and 95th percentile of the tour lengths (the range that holds 90% of tours)
dist_5_percentile = np.percentile(tour_dists, 5)
dist_95_percentile = np.percentile(tour_dists, 95)

print(
    f"90% distance confidence interval [{dist_5_percentile:.3f}, {dist_95_percentile:.3f}]"
)

In [None]:
# plot the travelling salesman tour
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
scatter = ax.scatter(
    mean_z[:, 0], mean_z[:, 1], mean_z[:, 2], c=list(range(10)), cmap="tab10"
)
plt.colorbar(scatter)

# shortest_tour ends on its start node, so pairing every stop with its
# successor draws every edge that was counted in the tour length, including
# the closing edge back to the start.
for src, dst in zip(shortest_tour[:-1], shortest_tour[1:]):
    ax.plot(
        [mean_z[src, 0], mean_z[dst, 0]],
        [mean_z[src, 1], mean_z[dst, 1]],
        [mean_z[src, 2], mean_z[dst, 2]],
        color="black",
    )
plt.show()

In [None]:
# List the ten tour edges as "from, to -> length"
for src, dst in zip(shortest_tour[:10], shortest_tour[1:11]):
    print(f"{src}, {dst} -> {dists[src, dst]:.2f}")

In [None]:
# Calculate the minimum spanning tree
from scipy.sparse.csgraph import minimum_spanning_tree

# Sparse matrix holding one non-zero entry (the edge length) per tree edge
mst = minimum_spanning_tree(dists)

# Convert to dense array
mst_array = mst.toarray().astype(float)

# Set diagonal to 0
np.fill_diagonal(mst_array, 0)

# (row, col) index pair of every tree edge, wherever it is stored in the matrix
mst_edges = np.argwhere(mst_array > 0)

# plot the minimum spanning tree
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
scatter = ax.scatter(
    mean_z[:, 0], mean_z[:, 1], mean_z[:, 2], c=list(range(10)), cmap="tab10"
)
plt.colorbar(scatter)

# Draw from mst_edges rather than scanning only the upper triangle: an edge
# stored as (j, i) with j > i would otherwise be left out of the figure.
for i, j in mst_edges:
    ax.plot(
        [mean_z[i, 0], mean_z[j, 0]],
        [mean_z[i, 1], mean_z[j, 1]],
        [mean_z[i, 2], mean_z[j, 2]],
        color="black",
    )
plt.show()

In [None]:
# List each tree edge as "from, to -> length"
for src, dst in mst_edges:
    print(f"{src}, {dst} -> {mst_array[src, dst]:.2f}")

# Total length is the sum of all stored edge weights
total_mst_length = mst_array.sum()
print(f"Minimum spanning tree total distance: {total_mst_length:.2f}")

In [None]:
from itertools import combinations

test_list = [*range(10)]

# Every 5-element subset of test_list, in the order itertools produces them
five_subsets = list(combinations(test_list, 5))
print(five_subsets)

In [None]:
import itertools
from itertools import combinations


def generate_pairings(n):
    """Return every way to split ``range(2 * n)`` into ``n`` disjoint pairs.

    Each pairing is a tuple of ``n`` pairs ``(a, b)`` with ``a < b``, sorted by
    their first element. Pairings come out in the same lexicographic order as
    filtering ``combinations(combinations(elements, 2), n)`` would give, but
    they are built directly: the smallest unused element is matched with each
    remaining element in turn, so only the ``(2n - 1)!!`` valid pairings are
    ever created.

    Args:
        n: Number of pairs; the elements are ``0 .. 2 * n - 1``.

    Returns:
        List of pairings. ``n == 0`` yields ``[()]``, the single empty pairing.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")

    def _pairings(remaining):
        # `remaining` is an ascending tuple of the elements not yet paired.
        if not remaining:
            yield ()
            return
        first, rest = remaining[0], remaining[1:]
        for idx, partner in enumerate(rest):
            leftover = rest[:idx] + rest[idx + 1 :]
            for tail in _pairings(leftover):
                yield ((first, partner),) + tail

    return list(_pairings(tuple(range(2 * n))))


n = 5  # 2n = 10 elements, one per row/column of `dists`
all_pairings = generate_pairings(n)

# Number of pairings found; for n = 5 this is (2n - 1)!! = 9 * 7 * 5 * 3 * 1 = 945
print(len(all_pairings))

# Print first 5 pairings
for pairing in all_pairings[:5]:
    print(pairing)

In [None]:
# Total distance of every pairing: the sum of dists over its pairs
pair_dists = [
    sum(dists[i, j] for i, j in candidate) for candidate in all_pairings
]

# Keep the first pairing that reaches the smallest total
min_sum = float("inf")
min_pairing = None
for candidate, total in zip(all_pairings, pair_dists):
    if total < min_sum:
        min_sum, min_pairing = total, candidate

print(f"Minimum sum of distances: {min_sum:.2f}")
print(f"Minimum sum pairing: {min_pairing}")

# pair dist stats
mean_pair_dist = np.mean(pair_dists)
std_pair_dist = np.std(pair_dists)

# 5th and 95th percentile of the pairing totals
dist_5_percentile = np.percentile(pair_dists, 5)
dist_95_percentile = np.percentile(pair_dists, 95)

print(
    f"90% distance confidence interval [{dist_5_percentile:.3f}, {dist_95_percentile:.3f}]"
)

# Distance of each pair inside the best pairing
for first, second in min_pairing:
    print(f"{first}, {second} -> {dists[first, second]:.2f}")