In [None]:
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import sys

sys.path.append("../")

from data.image_datasets import CIFARDataModule, plot_batch
from cifar_vae.vae import VariationalAutoencoder

# The notebook is pinned to the CPU. To use a GPU when one is available, replace the
# assignment below with:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Fix the global random seed so repeated runs of the notebook are comparable.
pl.seed_everything(42)

In [None]:
# Settings for the CIFAR data module, gathered in one place so they are easy to tweak.
datamodule_config = dict(
    data_dir="../raw-data/",
    batch_size=2048,
    num_train=45000,
    num_val=5000,
    num_test=10000,
    pad_for_rotation=False,
    variant="10",
    seed=42,
    random_labels=False,
)

dm = CIFARDataModule(**datamodule_config)
dm.setup(None)

# Only the test split is used in the rest of the notebook.
test_loader = dm.test_dataloader()

In [None]:
# Restore the trained VAE from disk.
# NOTE(review): the file name suggests a bare state dict, while load_from_checkpoint
# presumably expects a full Lightning checkpoint -- confirm the file format.
model = VariationalAutoencoder.load_from_checkpoint("trained_vae_state_dict.pt")
# Switch the model to evaluation mode; being the last expression of the cell,
# the return value of eval() is echoed by the notebook.
model.eval()

In [None]:
def _show_grid(caption, images):
    """Print `caption`, then hand `images` to plot_batch with grid_side_len=(2, 5).

    The figure is not saved. Whatever plot_batch returns is passed through so the
    notebook still echoes it when this call is the last expression of a cell.
    """
    print(caption)
    return plot_batch(images, grid_side_len=(2, 5), save_fig=False)


# First ten images of the first test batch, moved to the device the model lives on.
original_ims = next(iter(test_loader))[0][:10].to(model.device)
_show_grid("originals", original_ims)

# The same ten images after a pass through the model's reconstruct method.
original_ims_decodings = model.reconstruct(original_ims)
_show_grid("original decodings", original_ims_decodings)

# Decode the encodings stored on the model as `best_encodings`, clamped to [-1, 1].
best_decodings = model.decoder(model.best_encodings)[0].cpu().clamp(-1, 1)
_show_grid("best decodings", best_decodings)

In [None]:
from IPython.display import HTML

BATCH_SIZE = 100
# One decoded frame per interpolation step.
NUM_INTERPOLATIONS = BATCH_SIZE

model = model.cpu()
mean_z = model.best_encodings.cpu().detach().numpy()

## Pick the two encodings (rows of mean_z) to interpolate between
m = 8
n = 2

a = mean_z[m]
b = mean_z[n]

t = np.linspace(0, 1, NUM_INTERPOLATIONS)
# Reshape t to (NUM_INTERPOLATIONS, 1, ..., 1) so it broadcasts against a and b
num_dims = len(a.shape)
t = t.reshape(-1, *([1] * num_dims))
print(t.shape)
print(a.shape)
# Linear interpolation: t = 0 gives a, t = 1 gives b
z_interp = a * (1 - t) + b * t

# Decode on whatever device the model actually lives on (it was moved to the CPU
# above), rather than on the global `device`, so the two can never disagree.
# No gradients are needed for visualisation.
with torch.no_grad():
    x_interp = model.decoder(torch.from_numpy(z_interp).float().to(model.device))[0]
print(x_interp.shape)
# (N, 3, 32, 32) -> (N, 32, 32, 3) for imshow, rescaled with x / 2 + 0.5
# (decoder output is assumed to lie roughly in [-1, 1] -- TODO confirm)
x_interp = (x_interp.view(-1, 3, 32, 32).permute(0, 2, 3, 1) / 2 + 0.5).cpu().detach()

# Plot the interpolation as an animation
fig, ax = plt.subplots()
print(x_interp[0].min(), x_interp[0].max())
# The frames are RGB, so no colormap is passed (matplotlib ignores cmap for RGB data)
im = ax.imshow(x_interp[0])


def update_frame(i):
    """Show the i-th interpolated frame in the existing image artist."""
    im.set_data(x_interp[i])


ani = animation.FuncAnimation(
    fig,
    update_frame,
    frames=range(NUM_INTERPOLATIONS),
    interval=40,  # delay between frames in milliseconds
    repeat=True,
)

# Convert the animation to HTML format
html_animation = ani.to_jshtml()

# Display the animation in the notebook
HTML(html_animation)

In [None]:
## Brute-force travelling salesman tour over the mean embeddings
## (the minimum spanning tree is computed in a later cell)

# Flatten every mean embedding to a vector and build the pairwise Euclidean distance matrix
mean_z = mean_z.reshape(mean_z.shape[0], -1)
dists = np.sqrt(((mean_z[:, None] - mean_z[None, :]) ** 2).sum(axis=-1))

from itertools import permutations
from tqdm import tqdm
import math

# Number of nodes
n = dists.shape[0]

# A closed tour is the same cycle whichever node it starts from, so node 0 is pinned
# as the start and only the remaining nodes are permuted: (n - 1)! candidates instead
# of n!, with exactly the same set of cycles covered.
tours = permutations(range(1, n))

# Initial shortest distance is infinity
shortest_dist = float("inf")
shortest_tour = None

# Save all tour distances (one entry per distinct directed cycle)
tour_dists = []

# Check all possible tours
for tour in tqdm(tours, total=math.factorial(n - 1)):
    # Close the cycle: start at node 0 and return to it at the end
    tour = (0,) + tour + (0,)
    # Sum the lengths of the n edges between consecutive nodes of the closed tour
    tour_dist = sum(dists[src, dst] for src, dst in zip(tour, tour[1:]))
    tour_dists.append(tour_dist)

    # If this tour is shorter, update shortest_dist and shortest_tour
    if tour_dist < shortest_dist:
        shortest_dist = tour_dist
        shortest_tour = tour

print(f"Shortest tour: {shortest_tour}")
print(f"Shortest distance: {shortest_dist:.3f}")

# Tour distance stats
mean_tour_dist = np.mean(tour_dists)
std_tour_dist = np.std(tour_dists)

dist_5_percentile = np.percentile(tour_dists, 5)
dist_95_percentile = np.percentile(tour_dists, 95)

# Central 90% range of the tour distances
print(
    f"90% distance confidence interval [{dist_5_percentile:.3f}, {dist_95_percentile:.3f}]"
)

In [None]:
# Plot the travelling salesman path over the first three coordinates of the embeddings
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
# One colour per node; the node count is taken from mean_z instead of being hard-coded
scatter = ax.scatter(
    mean_z[:, 0],
    mean_z[:, 1],
    mean_z[:, 2],
    c=list(range(len(mean_z))),
    cmap="tab10",
)
plt.colorbar(scatter)
# Draw one black segment per pair of consecutive nodes in the (closed) tour
for start, end in zip(shortest_tour, shortest_tour[1:]):
    ax.plot(
        [mean_z[start, 0], mean_z[end, 0]],
        [mean_z[start, 1], mean_z[end, 1]],
        [mean_z[start, 2], mean_z[end, 2]],
        color="black",
    )
plt.show()

In [None]:
# Print the distance of every edge between two consecutive nodes of the tour.
# The edges are read off the tour itself, so this works for any number of nodes.
total_tour_dist = 0
for start, end in zip(shortest_tour, shortest_tour[1:]):
    connection_distance = dists[start, end]
    print(f"{start}, {end} -> {connection_distance:.2f}")
    total_tour_dist += connection_distance

print(f"total distance: {total_tour_dist:.2f}")
print("tour tuple:")
print(shortest_tour)

In [None]:
# Calculate the minimum spanning tree
from scipy.sparse.csgraph import minimum_spanning_tree

# Sparse matrix whose non-zero entries are the tree edges and their lengths
mst = minimum_spanning_tree(dists)

# Convert to dense array
mst_array = mst.toarray().astype(float)

# Set diagonal to 0
np.fill_diagonal(mst_array, 0)

# (row, col) index pairs of all tree edges. The graph is undirected and an edge may
# be stored at [i, j] or at [j, i], so both triangles of the matrix are scanned.
mst_edges = np.argwhere(mst_array > 0)

# Plot the minimum spanning tree over the first three coordinates of the embeddings
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
scatter = ax.scatter(
    mean_z[:, 0],
    mean_z[:, 1],
    mean_z[:, 2],
    c=list(range(len(mean_z))),
    cmap="tab10",
)
plt.colorbar(scatter)
# Draw every tree edge; iterating over mst_edges (instead of only the upper triangle)
# guarantees that edges stored below the diagonal are not silently dropped.
for i, j in mst_edges:
    ax.plot(
        [mean_z[i, 0], mean_z[j, 0]],
        [mean_z[i, 1], mean_z[j, 1]],
        [mean_z[i, 2], mean_z[j, 2]],
        color="black",
    )
plt.show()

In [None]:
# List each edge of the minimum spanning tree together with its length
for src, dst in mst_edges:
    print(f"{src}, {dst} -> {mst_array[src, dst]:.2f}")

# Total length of the tree = sum of all entries of the dense edge matrix
mst_total = mst_array.sum()
print(f"Minimum spanning tree total distance: {mst_total:.2f}")

In [None]:
from itertools import combinations

# Scratch check: enumerate every 5-element subset of the ten node indices
test_list = [*range(10)]

five_subsets = list(combinations(test_list, 5))
print(five_subsets)

In [None]:
import itertools
from itertools import combinations


def generate_pairings(n):
    """Return every way to split the elements 0 .. 2n-1 into n disjoint pairs.

    Each pairing is a tuple of n pairs ``(a, b)`` with ``a < b``; the pairs inside a
    pairing, and the pairings themselves, are in lexicographic order.

    The pairings are built directly: the smallest unused element is matched with each
    remaining element in turn and the rest is paired up recursively. This avoids
    enumerating (and holding in memory) every n-subset of all possible pairs only to
    discard almost all of them.

    Args:
        n: number of pairs; the elements paired up are ``range(2 * n)``.

    Returns:
        A list of pairings. For ``n == 0`` this is ``[()]`` (the single empty pairing).

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")

    def _pair_up(remaining):
        # Nothing left to pair: exactly one (empty) way to finish the pairing
        if not remaining:
            yield ()
            return
        # The smallest unused element must belong to the next pair
        first, rest = remaining[0], remaining[1:]
        for idx, partner in enumerate(rest):
            leftover = rest[:idx] + rest[idx + 1 :]
            for tail in _pair_up(leftover):
                yield ((first, partner),) + tail

    return list(_pair_up(tuple(range(2 * n))))


n = 5  # n pairs, i.e. 2n = 10 elements
all_pairings = generate_pairings(n)

# Number of pairings found
print(len(all_pairings))

# Show the first five pairings
for pairing in itertools.islice(all_pairings, 5):
    print(pairing)

In [None]:
# Total distance of every candidate pairing, in the same order as all_pairings
pair_dists = [sum(dists[i, j] for i, j in pairing) for pairing in all_pairings]

# Keep the first pairing that attains the smallest total distance
min_sum = float("inf")
min_pairing = None
for pairing, sum_dist in zip(all_pairings, pair_dists):
    if sum_dist < min_sum:
        min_sum, min_pairing = sum_dist, pairing

print(f"Minimum sum of distances: {min_sum:.2f}")
print(f"Minimum sum pairing: {min_pairing}")

# Summary statistics over all pairing distances
mean_pair_dist = np.mean(pair_dists)
std_pair_dist = np.std(pair_dists)

dist_5_percentile, dist_95_percentile = np.percentile(pair_dists, [5, 95])

# Central 90% range of the pairing distances
print(
    f"90% distance confidence interval [{dist_5_percentile:.3f}, {dist_95_percentile:.3f}]"
)

# Distance of each individual pair in the best pairing
for pairing in min_pairing:
    print(f"{pairing[0]}, {pairing[1]} -> {dists[pairing[0], pairing[1]]:.2f}")