# MNIST


In [None]:
# Standard-library modules.
import os
import sys
import time
# Expose only GPU 0 to CUDA; this assignment happens before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Set the GPUs to use
# Put the parent of the current working directory on the import path
# (presumably so the `src.mnist` imports further down resolve when the notebook
# is run from a sub-directory of the project -- confirm against the repo layout).
sys.path.append(os.path.join(os.getcwd(), ".."))

# Third-party modules.
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
import torch.optim as optim
from tqdm import tqdm

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# Project-local dataset helpers, model classes and loss.
from src.mnist.datasets import get_mnist_dataloaders, convert_flattened_to_image, get_merged_labels
from src.mnist.models import CVAE, InvariantAutoEncoder, InvariantVariationalAutoEncoder, ProxyRep2InvarRep, ProxyRep2Label
from src.mnist.losses import CVAE_Loss


# Merged group

In [None]:
# Hyper-parameters and checkpoint location for the conditional VAE (CVAE) stage.
BATCH_SIZE = 500
BETA = 5e-4  # Scaled by batch size for stability
REDUCTION = 'mean'
EPOCHS = 50
LR = 5e-4  # Learning rate
CVAE_PATH = "../checkpoints/mnist/group_similar/cvae.pth"

# BATCH_SIZE is passed explicitly, as in every later loader construction of
# this notebook; before, the constant was defined but never reached the loader.
mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=False, batch_size=BATCH_SIZE)

# Digits listed in the same sub-list are merged into one group label.
merge_group = [
    [0, 6],
    [1],
    [4, 7, 9],
    [2, 3, 5, 8]
]

# One conditioning class per merged group (derived instead of hard-coded).
cvae = CVAE(28 * 28, 128, 64, num_classes=len(merge_group))
# Use the detected device rather than a literal "cuda", which raises on
# CPU-only machines even though `device` already falls back to the CPU.
cvae = cvae.to(device)

if os.path.exists(CVAE_PATH):
    print("Loading pre-trained CVAE model...")
    cvae.load_state_dict(torch.load(CVAE_PATH, map_location=device))
else:
    print("No pre-trained model found. Starting from scratch.")
    # Create the checkpoint directory before training: torch.save() does not
    # create parent folders, and failing after all epochs would lose the run.
    os.makedirs(os.path.dirname(CVAE_PATH), exist_ok=True)

    optimizer = optim.Adam(cvae.parameters(), lr=LR)
    loss_fn = CVAE_Loss(beta=BETA, reduction=REDUCTION)

    losses_train = []  # one entry per training mini-batch
    losses_val = []    # one entry per epoch (mean over validation batches)
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        cvae.train()
        for i, (x, y) in enumerate(mnist_train):
            x = x.to(device)
            # Map digit labels to merged-group labels used as the condition.
            y = get_merged_labels(y, merge_group).to(device)
            optimizer.zero_grad()
            recon_x, mu, logvar = cvae(x, y)
            loss = loss_fn(recon_x, x, mu, logvar)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            losses_train.append(loss_value)
            # In-place progress line (carriage return, no newline).
            print(f"{i + 1:>3}/{len(mnist_train)}: {loss_value:.4f}", end="\r")

        cvae.eval()
        with torch.no_grad():
            total_val_loss = 0
            for x, y in mnist_val:
                x = x.to(device)
                y = get_merged_labels(y, merge_group).to(device)
                recon_x, mu, logvar = cvae(x, y)
                val_loss = loss_fn(recon_x, x, mu, logvar)
                total_val_loss += val_loss.item()
            avg_val_loss = total_val_loss / len(mnist_val)
            losses_val.append(avg_val_loss)
            print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")

    torch.save(cvae.state_dict(), CVAE_PATH)

    nb_minibatches = len(mnist_train)

    # Training loss is per mini-batch; validation loss is placed at the
    # iteration index where each epoch ends.
    plt.figure(figsize=(10, 5))
    plt.plot(losses_train, label='Training Loss', color='blue')
    plt.plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
             losses_val, label='Validation Loss', color='orange')
    plt.xlabel('Batch Iterations')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.show()


In [None]:
# Number of test images shown per row.
NB_SAMPLES = 8

# Only the first test batch is needed for the figure.
x, y_real = next(iter(mnist_test))
x = x.to(device)
# `merge_group` is passed explicitly, as in every other call of the notebook.
y = get_merged_labels(y_real, merge_group).to(device)

# Inference only: evaluation mode (as before test inference elsewhere in this
# notebook) and no autograd graph.
cvae.eval()
with torch.no_grad():
    # Permute the group labels across the batch to condition each image on
    # another sample's group.
    y_shuffled = y[torch.randperm(y.size(0))]
    x_shuffled_recon, mu, logvar = cvae(x, y_shuffled)
    x_recon = cvae(x, y)[0]

# Rows: original / reconstruction with true group / reconstruction with shuffled group.
fig, axs = plt.subplots(3, NB_SAMPLES, figsize=(NB_SAMPLES*2, 6))
for idx in range(NB_SAMPLES):
    axs[0, idx].imshow(x[idx].cpu().reshape(28, 28), cmap="gray")
    axs[0, idx].set_title(f"{y_real[idx].item()} in {merge_group[y[idx].argmax().item()]}")
    axs[0, idx].set_xticks([])
    axs[0, idx].set_yticks([])

    axs[1, idx].imshow(x_recon[idx].detach().cpu().reshape(28, 28), cmap="gray")
    axs[1, idx].set_title(" ")
    axs[1, idx].set_xticks([])
    axs[1, idx].set_yticks([])

    axs[2, idx].imshow(x_shuffled_recon[idx].detach().cpu().reshape(28, 28), cmap="gray")
    axs[2, idx].set_title(f"$S$ = {merge_group[y_shuffled[idx].argmax().item()]}")
    axs[2, idx].set_xticks([])
    axs[2, idx].set_yticks([])
axs[0, 0].set_ylabel("Original")
axs[1, 0].set_ylabel("Reconstructed")
axs[2, 0].set_ylabel("Shuffled")
plt.tight_layout()
plt.show()

In [None]:
Z2_DIM = 4  # Latent size of the invariant VAE
INVAE_PATH = "../checkpoints/mnist/group_similar/invariant_vae.pth"

# The CVAE is kept fixed from here on: disable gradients on all its parameters.
cvae.requires_grad_(False)

# Invariant VAE built around the frozen CVAE; it is the model optimised below.
invae = InvariantVariationalAutoEncoder(28 * 28, 128, Z2_DIM, cvae=cvae).to(device)

if not os.path.exists(INVAE_PATH):
    print("No pre-trained model found. Starting from scratch.")

    optimizer = optim.Adam(invae.parameters(), lr=1e-3)
    loss_fn = CVAE_Loss(beta=BETA, reduction=REDUCTION)

    # Per-mini-batch training losses and per-epoch mean validation losses.
    losses_train, losses_val = [], []
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")

        # ---- training pass ----
        invae.train()
        for x, y in tqdm(mnist_train):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            recon_x, mu, logvar = invae(x)
            loss = loss_fn(recon_x, x, mu, logvar)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        # ---- validation pass ----
        invae.eval()
        with torch.no_grad():
            val_batch_losses = []
            for x, y in mnist_val:
                x, y = x.to(device), y.to(device)
                recon_x, mu, logvar = invae(x)
                val_batch_losses.append(loss_fn(recon_x, x, mu, logvar).item())
        avg_val_loss = sum(val_batch_losses) / len(mnist_val)
        losses_val.append(avg_val_loss)
        print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")

    torch.save(invae.state_dict(), INVAE_PATH)

    # Validation points sit at the iteration index closing each epoch.
    nb_minibatches = len(mnist_train)
    epoch_ends = range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches)

    plt.figure(figsize=(10, 5))
    plt.plot(losses_train, label='Training Loss', color='blue')
    plt.plot(epoch_ends, losses_val, label='Validation Loss', color='orange')
    plt.xlabel('Batch Iterations')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.show()
else:
    print("Loading pre-trained Invariant AE model...")
    invae.load_state_dict(torch.load(INVAE_PATH, map_location=device))

In [None]:
NB_SAMPLES = 8   # total number of test images to show
NB_subplots = 8  # images per figure
# Ceiling division: a trailing, partially filled figure is still drawn (the
# min(...) below already anticipates it; floor division silently dropped it).
nb_figures = -(-NB_SAMPLES // NB_subplots)

# Only the first test batch is used.
x, y_real = next(iter(mnist_test))
x = x.to(device)
# `merge_group` is passed explicitly, as in every other call of the notebook.
y = get_merged_labels(y_real, merge_group).to(device)

# Inference only: evaluation mode and no autograd graph.
cvae.eval()
invae.eval()
with torch.no_grad():
    x_recon_cvae = cvae(x, y)[0]
    # Group labels permuted across the batch.
    y_shuffled = y[torch.randperm(y.size(0))]
    x_shuffled_recon, mu, logvar = cvae(x, y_shuffled)
    x_recon_invae = invae(x)[0]

for f in range(nb_figures):
    fig, axs = plt.subplots(4, NB_subplots, figsize=(NB_subplots*2, 8), tight_layout=True)
    nb_cols = min(NB_subplots, NB_SAMPLES - f * NB_subplots)
    for idx in range(nb_cols):
        # Index into the batch; all four rows and both titles must use the
        # same sample (titles and the shuffled row previously used `idx`).
        sample = f * NB_subplots + idx

        axs[0, idx].imshow(x[sample].cpu().reshape(28, 28), cmap="gray")
        axs[0, idx].set_title(f"{y_real[sample].item()} in {merge_group[y[sample].argmax().item()]}")
        axs[0, idx].set_xticks([])
        axs[0, idx].set_yticks([])

        axs[1, idx].imshow(x_recon_cvae[sample].detach().cpu().reshape(28, 28), cmap="gray")
        axs[1, idx].set_title(" ")
        axs[1, idx].set_xticks([])
        axs[1, idx].set_yticks([])

        axs[2, idx].imshow(x_shuffled_recon[sample].detach().cpu().reshape(28, 28), cmap="gray")
        axs[2, idx].set_title(f"$S$ = {merge_group[y_shuffled[sample].argmax().item()]}")
        axs[2, idx].set_xticks([])
        axs[2, idx].set_yticks([])

        axs[3, idx].imshow(x_recon_invae[sample].detach().cpu().reshape(28, 28), cmap="gray")
        axs[3, idx].set_title(" ")
        axs[3, idx].set_xticks([])
        axs[3, idx].set_yticks([])
    # Hide the unused columns of a partially filled last figure.
    for idx in range(nb_cols, NB_subplots):
        for row in range(4):
            axs[row, idx].axis('off')
    axs[0, 0].set_ylabel("Original", fontsize=15)
    axs[1, 0].set_ylabel("CVAE Recon.", fontsize=15)
    axs[2, 0].set_ylabel("Shuffled Recon.", fontsize=15)
    axs[3, 0].set_ylabel("InvVAE Recon.", fontsize=15)
    plt.tight_layout()
    plt.show()
    print(" ")

In [None]:
LIMIT = 3  # axis limit referenced only by the disabled set_xlim/set_ylim lines below

# Gather the whole test set into one tensor pair.
x = []
y = []
for x_tmp, y_tmp in mnist_test:
    x.append(x_tmp)
    y.append(y_tmp)
x = torch.cat(x).to(device)
y = torch.cat(y).to(device)
label = y.detach().cpu().flatten()

# Encode without building an autograd graph over the full test set. The former
# full forward pass `invae(x)` was dropped: its result was never used.
invae.eval()
with torch.no_grad():
    z = invae.encoder(x)
    # Encoder output is split at `latent_dim` into mu and logvar.
    mu = z[:, :invae.latent_dim]
    logvar = z[:, invae.latent_dim:]
    z = invae.reparameterize(mu, logvar)
z = z.cpu()

# Mean latent position of each digit 0-9.
z_centers = []
for i in range(10):
    center = z[label == i].mean(axis=0)
    z_centers.append(center)
z_centers = torch.stack(z_centers)

# Upper-triangular grid of pairwise latent-dimension scatter plots:
# row i / column j-1 shows dimension i against dimension j for i < j.
fig, axes = plt.subplots(Z2_DIM - 1, Z2_DIM - 1, figsize=(5 * (Z2_DIM - 1), 5 * (Z2_DIM - 1)), tight_layout=True)
for i in range(Z2_DIM - 1):
    for j in range(1, Z2_DIM):
        if i < j:
            # The centre scatter fixes vmin/vmax and is the mappable used for the colorbar.
            im = axes[i, j-1].scatter(z_centers[:, i], z_centers[:, j],
                                    c=range(10), cmap="tab10", vmin=-.5, vmax=9.5)
            axes[i, j-1].scatter(z[:, i], z[:, j],
                            c=label, cmap="tab10", alpha=0.1)
            # axes[i, j-1].set_aspect('equal', adjustable='box')
            # axes[i, j-1].set_xlim(-LIMIT, LIMIT)
            # axes[i, j-1].set_ylim(-LIMIT, LIMIT)
            # Write each digit at its class centre.
            for d in range(10):
                axes[i, j-1].text(z_centers[d][i], z_centers[d][j], str(d), fontsize=15, ha='center', va='center')
        else:
            axes[i, j-1].axis('off')

# Axis names only on the diagonal panels; panels below it stay hidden.
for i in range(Z2_DIM - 1):
    for j in range(Z2_DIM - 1):
        if i == j :
            axes[i, j].set_ylabel(f"$z_2$[{i}]", fontsize=15)
            axes[i, j].set_xlabel(f"$z_2$[{j + 1}]", fontsize=15)
        elif i > j:
            axes[i, j].axis('off')
        # else:
        #     axes[i, j].set_xticks([])
        #     axes[i, j].set_yticks([])

plt.suptitle("Latent Space of Invariant VAE", fontsize=20)
cbar = plt.colorbar(im, ax=axes, orientation='horizontal',
                    fraction=0.02, pad=0.0)
cbar.set_label('Digit Label', fontsize=15)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks(range(10))
plt.show()

In [None]:
PROXY2INVAR_PATH = "../checkpoints/mnist/group_similar/vaeproxy2invar.pth"

# Freeze the invariant VAE: no parameter of it receives gradients any more.
invae.requires_grad_(False)

proxy2invar = ProxyRep2InvarRep(autoencoder=invae, reparameterize=True).to(device)


def _sample_latent(cvae_model, batch):
    """Encode `batch` with `cvae_model` and draw one latent sample.

    The encoder output is split at `cvae_model.latent_dim`; the first part is
    passed to `reparameterize` as mu and the remainder as logvar.
    """
    stats = cvae_model.encoder(batch)
    split = cvae_model.latent_dim
    return cvae_model.reparameterize(stats[:, :split], stats[:, split:])


if not os.path.exists(PROXY2INVAR_PATH):
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2invar.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss(reduction='mean')

    # Per-mini-batch training losses and per-epoch mean validation losses.
    losses_train, losses_val = [], []
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")

        # ---- training pass ----
        proxy2invar.train()
        for x, y in tqdm(mnist_train):
            x = x.to(device)

            # Regression target: latent sampled from the CVAE inside `invae`,
            # computed without gradients and before the prediction.
            with torch.no_grad():
                z_invar = _sample_latent(invae.cvae, x)

            z_invar_pred = proxy2invar(x)
            optimizer.zero_grad()
            loss = loss_fn(z_invar_pred, z_invar)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        # ---- validation pass ----
        proxy2invar.eval()
        with torch.no_grad():
            val_batch_losses = []
            for x, y in mnist_val:
                x = x.to(device)
                z_invar = _sample_latent(invae.cvae, x)
                z_invar_pred = proxy2invar(x)
                val_batch_losses.append(loss_fn(z_invar_pred, z_invar).item())
        avg_val_loss = sum(val_batch_losses) / len(mnist_val)
        losses_val.append(avg_val_loss)
        print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")
    torch.save(proxy2invar.state_dict(), PROXY2INVAR_PATH)

    # Validation points sit at the iteration index closing each epoch.
    nb_minibatches = len(mnist_train)
    epoch_ends = range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches)

    plt.figure(figsize=(10, 5))
    plt.plot(losses_train, label='Training Loss', color='blue')
    plt.plot(epoch_ends, losses_val, label='Validation Loss', color='orange')
    plt.xlabel('Batch Iterations')
    plt.ylabel('Loss')
    plt.title('VAE Proxy to Invariant Representation MSE')
    plt.legend()
    plt.xscale('log')
    plt.ylim(0.94, 1.1)

    plt.show()
else:
    print("Loading pre-trained Proxy2Invar model...")
    proxy2invar.load_state_dict(torch.load(PROXY2INVAR_PATH, map_location=device))

In [None]:
# Build a ProxyRep2Label head on top of `invae` with one output per merged
# group. It is not moved to `device` here; the next cell rebuilds it.
proxy2group = ProxyRep2Label(autoencoder=invae, reparameterize=True, nb_labels=len(merge_group))
# Bare expression as the last statement of the cell: a notebook displays its
# value, i.e. the `mlp` attribute of the head.
proxy2group.mlp

In [None]:
PROXY2GROUP_PATH = "../checkpoints/mnist/group_similar/vaeproxy2group.pth"

# Classifier predicting the merged group from the invariant-VAE representation.
proxy2group = ProxyRep2Label(autoencoder=invae, reparameterize=True, nb_labels=len(merge_group))
proxy2group = proxy2group.to(device)
if os.path.exists(PROXY2GROUP_PATH):
    print("Loading pre-trained Proxy2Group model...")
    proxy2group.load_state_dict(torch.load(PROXY2GROUP_PATH, map_location=device))
else:
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2group.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    losses_train = []    # one entry per training mini-batch
    losses_val = []      # one entry per epoch
    accuracies_val = []  # one entry per epoch
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        proxy2group.train()
        for i, (x, y) in enumerate(tqdm(mnist_train)):
            x = x.to(device)
            # Merged-group targets, cast to float for the cross-entropy loss.
            y = get_merged_labels(y, merge_group).float().to(device)
            optimizer.zero_grad()
            y_pred = proxy2group(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        proxy2group.eval()
        with torch.no_grad():
            total_val_loss = 0
            nb_correct = 0
            nb_seen = 0
            for x, y in mnist_val:
                x = x.to(device)
                y = get_merged_labels(y, merge_group).float().to(device)
                y_pred = proxy2group(x)
                val_loss = loss_fn(y_pred, y)
                total_val_loss += val_loss.item()

                group_pred = y_pred.argmax(dim=1)
                group_true = y.argmax(dim=1)
                # Count hits and samples so a smaller final batch is not
                # over-weighted, as a mean of per-batch accuracies would do.
                nb_correct += (group_pred == group_true).sum().item()
                nb_seen += group_true.size(0)

            avg_val_loss = total_val_loss / len(mnist_val)
            avg_val_accu = nb_correct / nb_seen
            accuracies_val.append(avg_val_accu)
            losses_val.append(avg_val_loss)
            print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_accu:.4f}")
    torch.save(proxy2group.state_dict(), PROXY2GROUP_PATH)

    nb_minibatches = len(mnist_train)

    # Titles name the actual target of this cell (the merged group); they
    # previously read "Label", which belongs to the digit-label cell.
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    axes[0].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('VAE Proxy to Group Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Reuse the x-range of the loss panel for the accuracy panel below.
    xlim = axes[0].get_xlim()

    axes[1].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('VAE Proxy to Group Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()

In [None]:
from sklearn.metrics import confusion_matrix

# Switch the group classifier to evaluation mode for the test pass.
proxy2group.eval()

# Per-batch true / predicted group indices, concatenated after the loop.
group_true_all, group_pred_all = [], []

with torch.no_grad():
    for x, y in mnist_test:
        x = x.to(device)
        y = get_merged_labels(y, merge_group).float().to(device)
        y_pred = proxy2group(x)
        # Group index = position of the maximum along dim 1.
        group_pred, group_true = y_pred.argmax(dim=1), y.argmax(dim=1)
        group_true_all += [group_true.cpu()]
        group_pred_all += [group_pred.cpu()]

group_true_all, group_pred_all = torch.cat(group_true_all), torch.cat(group_pred_all)

cm = confusion_matrix(group_true_all, group_pred_all)

# One tick per merged group, labelled with the digits it contains.
tick_positions = range(len(merge_group))
tick_names = [str(g) for g in merge_group]

plt.figure(figsize=(8, 6))
plt.imshow(cm, cmap='Blues', interpolation='nearest')
plt.title('Confusion Matrix for Group Predictions')
plt.xlabel('Predicted Group')
plt.ylabel('True Group')
plt.xticks(tick_positions, tick_names, rotation=45)
plt.yticks(tick_positions, tick_names)
plt.colorbar()
plt.tight_layout()
plt.show()

In [None]:
PROXY2LAB_PATH = "../checkpoints/mnist/group_similar/vaeproxy2lab.pth"

# Classifier predicting the digit (10 classes) from the invariant-VAE representation.
proxy2lab = ProxyRep2Label(autoencoder=invae, reparameterize=True, nb_labels=10)
proxy2lab = proxy2lab.to(device)
# Reload the data with one_hot=True: the loops below take argmax over axis 1
# of `y` and feed it to the loss as a float target.
mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=True, batch_size=BATCH_SIZE)


if os.path.exists(PROXY2LAB_PATH):
    print("Loading pre-trained Proxy2Label model...")
    proxy2lab.load_state_dict(torch.load(PROXY2LAB_PATH, map_location=device))
else:
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2lab.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    losses_train = []    # one entry per training mini-batch
    losses_val = []      # one entry per epoch
    accuracies_val = []  # one entry per epoch
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        proxy2lab.train()
        for i, (x, y) in enumerate(tqdm(mnist_train)):
            x = x.to(device)
            y = y.float().to(device)
            optimizer.zero_grad()
            y_pred = proxy2lab(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        proxy2lab.eval()
        with torch.no_grad():
            total_val_loss = 0
            nb_correct = 0
            nb_seen = 0

            for x, y in mnist_val:
                x = x.to(device)
                y = y.float().to(device)
                y_pred = proxy2lab(x)
                val_loss = loss_fn(y_pred, y)
                total_val_loss += val_loss.item()

                lab_pred = y_pred.argmax(axis=1)
                lab_true = y.argmax(axis=1)
                # Count hits and samples so a smaller final batch is not
                # over-weighted, as a mean of per-batch accuracies would do.
                nb_correct += (lab_pred == lab_true).sum().item()
                nb_seen += lab_true.size(0)

            avg_val_loss = total_val_loss / len(mnist_val)
            avg_val_accu = nb_correct / nb_seen
            accuracies_val.append(avg_val_accu)
            losses_val.append(avg_val_loss)
            # Single summary line per epoch (the loss used to be printed twice).
            print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_accu:.4f}")


    torch.save(proxy2lab.state_dict(), PROXY2LAB_PATH)

    nb_minibatches = len(mnist_train)

    # Titles name the actual target of this cell (the digit label); they
    # previously read "Group", which belongs to the group cell.
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    axes[0].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('VAE Proxy to Label Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Reuse the x-range of the loss panel for the accuracy panel below.
    xlim = axes[0].get_xlim()

    axes[1].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('VAE Proxy to Label Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()

In [None]:
proxy2lab.eval()

NB_LABELS = 10  # digits 0-9

lab_true_all = []
lab_pred_all = []

# Test inference needs no autograd graph.
with torch.no_grad():
    for x, y in mnist_test:
        x = x.to(device)
        y = y.float().to(device)
        lab_true = y.argmax(axis=1)
        lab_true_all.append(lab_true.cpu())

        y_pred = proxy2lab(x)
        lab_pred = y_pred.argmax(axis=1)
        lab_pred_all.append(lab_pred.cpu())
lab_true_all = torch.cat(lab_true_all)
lab_pred_all = torch.cat(lab_pred_all)

# Vectorised count: each (true, pred) pair is encoded as true * NB_LABELS + pred,
# counted with bincount, then reshaped so rows are true labels and columns are
# predictions. This replaces a Python loop over every test sample.
# NOTE: the name `confusion_matrix` is kept for compatibility; it shadows the
# sklearn function of the same name until that is imported again.
confusion_matrix = torch.bincount(
    lab_true_all * NB_LABELS + lab_pred_all, minlength=NB_LABELS * NB_LABELS
).reshape(NB_LABELS, NB_LABELS)
accuracy = (lab_true_all == lab_pred_all).float().mean().item()

plt.figure(figsize=(8, 6))
plt.imshow(confusion_matrix, cmap='Blues', interpolation='nearest')
plt.colorbar()
# `proxy2lab` here sits on the invariant VAE, hence "VAE Proxy" (was "AE Proxy").
plt.title(f"Confusion Matrix of VAE Proxy to Label (acc.: {accuracy:.3f})")
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(ticks=range(NB_LABELS), labels=range(NB_LABELS))
plt.yticks(ticks=range(NB_LABELS), labels=range(NB_LABELS))
plt.grid(False)
plt.show()

In [None]:
Z2_DIM = 4  # Latent size of the invariant auto-encoder
INAE_PATH = "../checkpoints/mnist/group_similar/invariant_ae.pth"

# The CVAE stays fixed: disable gradients on all of its parameters.
cvae.requires_grad_(False)

# Invariant (non-variational) auto-encoder built around the frozen CVAE.
inae = InvariantAutoEncoder(28 * 28, 128, Z2_DIM, cvae=cvae).to(device)

if not os.path.exists(INAE_PATH):
    print("No pre-trained model found. Starting from scratch.")

    optimizer = optim.Adam(inae.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss(reduction='mean')

    # Per-mini-batch training losses and per-epoch mean validation losses.
    losses_train, losses_val = [], []
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")

        # ---- training pass ----
        inae.train()
        for x, y in tqdm(mnist_train):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            recon_x = inae(x)
            loss = loss_fn(recon_x, x)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        # ---- validation pass ----
        inae.eval()
        with torch.no_grad():
            val_batch_losses = []
            for x, y in mnist_val:
                x, y = x.to(device), y.to(device)
                recon_x = inae(x)
                val_batch_losses.append(loss_fn(recon_x, x).item())
        avg_val_loss = sum(val_batch_losses) / len(mnist_val)
        losses_val.append(avg_val_loss)
        print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")
    torch.save(inae.state_dict(), INAE_PATH)

    # Validation points sit at the iteration index closing each epoch.
    nb_minibatches = len(mnist_train)
    epoch_ends = range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches)

    plt.figure(figsize=(10, 5))
    plt.plot(losses_train, label='Training Loss', color='blue')
    plt.plot(epoch_ends, losses_val, label='Validation Loss', color='orange')
    plt.xlabel('Batch Iterations')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()

    plt.show()
else:
    print("Loading pre-trained Invariant AE model...")
    inae.load_state_dict(torch.load(INAE_PATH, map_location=device))

In [None]:
NB_SAMPLES = 8   # total number of test images to show
NB_subplots = 8  # images per figure
# Ceiling division: a trailing, partially filled figure is still drawn (the
# min(...) below already anticipates it; floor division silently dropped it).
nb_figures = -(-NB_SAMPLES // NB_subplots)

# Reload the loaders with one_hot=False before mapping digits to groups.
mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=False, batch_size=BATCH_SIZE)

# Only the first test batch is used.
x, y_real = next(iter(mnist_test))
x = x.to(device)
y = get_merged_labels(y_real, merge_group).to(device)

# Inference only: evaluation mode and no autograd graph.
cvae.eval()
inae.eval()
with torch.no_grad():
    x_recon_cvae = cvae(x, y)[0]
    # Group labels permuted across the batch.
    y_shuffled = y[torch.randperm(y.size(0))]
    x_shuffled_recon, mu, logvar = cvae(x, y_shuffled)
    x_recon_inae = inae(x)

for f in range(nb_figures):
    fig, axs = plt.subplots(4, NB_subplots, figsize=(NB_subplots*2, 8), tight_layout=True)
    nb_cols = min(NB_subplots, NB_SAMPLES - f * NB_subplots)
    for idx in range(nb_cols):
        # Index into the batch; all four rows and both titles must use the
        # same sample (titles and the shuffled row previously used `idx`).
        sample = f * NB_subplots + idx

        axs[0, idx].imshow(x[sample].cpu().reshape(28, 28), cmap="gray")
        axs[0, idx].set_title(f"{y_real[sample].item()} in {merge_group[y[sample].argmax().item()]}")
        axs[0, idx].set_xticks([])
        axs[0, idx].set_yticks([])

        axs[1, idx].imshow(x_recon_cvae[sample].detach().cpu().reshape(28, 28), cmap="gray")
        axs[1, idx].set_title(" ")
        axs[1, idx].set_xticks([])
        axs[1, idx].set_yticks([])

        axs[2, idx].imshow(x_shuffled_recon[sample].detach().cpu().reshape(28, 28), cmap="gray")
        axs[2, idx].set_title(f"$S$ = {merge_group[y_shuffled[sample].argmax().item()]}")
        axs[2, idx].set_xticks([])
        axs[2, idx].set_yticks([])

        axs[3, idx].imshow(x_recon_inae[sample].detach().cpu().reshape(28, 28), cmap="gray")
        axs[3, idx].set_title(" ")
        axs[3, idx].set_xticks([])
        axs[3, idx].set_yticks([])
    # Hide the unused columns of a partially filled last figure.
    for idx in range(nb_cols, NB_subplots):
        for row in range(4):
            axs[row, idx].axis('off')
    axs[0, 0].set_ylabel("Original", fontsize=15)
    axs[1, 0].set_ylabel("CVAE Recon.", fontsize=15)
    axs[2, 0].set_ylabel("Shuffled Recon.", fontsize=15)
    axs[3, 0].set_ylabel("Inv AE Recon.", fontsize=15)
    plt.tight_layout()
    plt.show()
    print(" ")

In [None]:
LIMIT = 3  # axis limit referenced only by the disabled set_xlim/set_ylim lines below

# Gather the whole test set into one tensor pair.
x = []
y = []
for x_tmp, y_tmp in mnist_test:
    x.append(x_tmp)
    y.append(y_tmp)
x = torch.cat(x).to(device)
y = torch.cat(y).to(device)
label = y.detach().cpu().flatten()

# Encode without building an autograd graph over the full test set. The former
# full forward pass `inae(x)` was dropped: its result was never used.
inae.eval()
with torch.no_grad():
    z = inae.encoder(x)
z = z.cpu()

# Mean latent position of each digit 0-9.
z_centers = []
for i in range(10):
    center = z[label == i].mean(axis=0)
    z_centers.append(center)
z_centers = torch.stack(z_centers)

# Upper-triangular grid of pairwise latent-dimension scatter plots:
# row i / column j-1 shows dimension i against dimension j for i < j.
fig, axes = plt.subplots(Z2_DIM - 1, Z2_DIM - 1, figsize=(5 * (Z2_DIM - 1), 5 * (Z2_DIM - 1)), tight_layout=True)
for i in range(Z2_DIM - 1):
    for j in range(1, Z2_DIM):
        if i < j:
            # The centre scatter fixes vmin/vmax and is the mappable used for the colorbar.
            im = axes[i, j-1].scatter(z_centers[:, i], z_centers[:, j],
                                    c=range(10), cmap="tab10", vmin=-.5, vmax=9.5)
            axes[i, j-1].scatter(z[:, i], z[:, j],
                            c=label, cmap="tab10", alpha=0.1)
            # axes[i, j-1].set_aspect('equal', adjustable='box')
            # axes[i, j-1].set_xlim(-LIMIT, LIMIT)
            # axes[i, j-1].set_ylim(-LIMIT, LIMIT)
            # Write each digit at its class centre.
            for d in range(10):
                axes[i, j-1].text(z_centers[d][i], z_centers[d][j], str(d), fontsize=15, ha='center', va='center')
        else:
            axes[i, j-1].axis('off')

# Axis names only on the diagonal panels; panels below it stay hidden.
for i in range(Z2_DIM - 1):
    for j in range(Z2_DIM - 1):
        if i == j :
            axes[i, j].set_ylabel(f"$z_2$[{i}]", fontsize=15)
            axes[i, j].set_xlabel(f"$z_2$[{j + 1}]", fontsize=15)
        elif i > j:
            axes[i, j].axis('off')
        # else:
        #     axes[i, j].set_xticks([])
        #     axes[i, j].set_yticks([])

plt.suptitle("Latent Space of Invariant AE", fontsize=20)
cbar = plt.colorbar(im, ax=axes, orientation='horizontal',
                    fraction=0.02, pad=0.0)
cbar.set_label('Digit Label', fontsize=15)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks(range(10))
plt.show()

In [None]:
PROXY2INVAR_PATH = "../checkpoints/mnist/group_similar/aeproxy2invar.pth"

# Both the CVAE and the invariant AE are fixed from here on.
cvae.requires_grad_(False)
inae.requires_grad_(False)

proxy2invar = ProxyRep2InvarRep(autoencoder=inae, reparameterize=False).to(device)


def _sample_latent(cvae_model, batch):
    """Encode `batch` with `cvae_model` and draw one latent sample.

    The encoder output is split at `cvae_model.latent_dim`; the first part is
    passed to `reparameterize` as mu and the remainder as logvar.
    """
    stats = cvae_model.encoder(batch)
    split = cvae_model.latent_dim
    return cvae_model.reparameterize(stats[:, :split], stats[:, split:])


if not os.path.exists(PROXY2INVAR_PATH):
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2invar.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss(reduction='mean')

    # Per-mini-batch training losses and per-epoch mean validation losses.
    losses_train, losses_val = [], []
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")

        # ---- training pass ----
        proxy2invar.train()
        for x, y in tqdm(mnist_train):
            x = x.to(device)

            # Regression target: latent sampled from the CVAE inside `inae`,
            # computed without gradients and before the prediction.
            with torch.no_grad():
                z_invar = _sample_latent(inae.cvae, x)

            z_invar_pred = proxy2invar(x)
            optimizer.zero_grad()
            loss = loss_fn(z_invar_pred, z_invar)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        # ---- validation pass ----
        proxy2invar.eval()
        with torch.no_grad():
            val_batch_losses = []
            for x, y in mnist_val:
                x = x.to(device)
                z_invar = _sample_latent(inae.cvae, x)
                z_invar_pred = proxy2invar(x)
                val_batch_losses.append(loss_fn(z_invar_pred, z_invar).item())
        avg_val_loss = sum(val_batch_losses) / len(mnist_val)
        losses_val.append(avg_val_loss)
        print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")

    torch.save(proxy2invar.state_dict(), PROXY2INVAR_PATH)

    # Validation points sit at the iteration index closing each epoch.
    nb_minibatches = len(mnist_train)
    epoch_ends = range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches)

    plt.figure(figsize=(10, 5))
    plt.plot(losses_train, label='Training Loss', color='blue')
    plt.plot(epoch_ends, losses_val, label='Validation Loss', color='orange')
    plt.xlabel('Batch Iterations')
    plt.ylabel('Loss')
    plt.title('AE Proxy to Invariant Representation MSE')
    plt.legend()
    plt.xscale('log')

    plt.show()
else:
    print("Loading pre-trained Proxy2Invar model...")
    proxy2invar.load_state_dict(torch.load(PROXY2INVAR_PATH, map_location=device))

In [None]:
PROXY2GROUP_PATH = "../checkpoints/mnist/group_similar/aeproxy2group.pth"
# Classifier predicting the merged group from the invariant-AE representation.
proxy2group = ProxyRep2Label(autoencoder=inae, reparameterize=False, nb_labels=len(merge_group))
proxy2group = proxy2group.to(device)

# Reload the loaders with one_hot=False before mapping digits to groups.
mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=False, batch_size=BATCH_SIZE)
if os.path.exists(PROXY2GROUP_PATH):
    print("Loading pre-trained Proxy2Group model...")
    proxy2group.load_state_dict(torch.load(PROXY2GROUP_PATH, map_location=device))
else:
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2group.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    losses_train = []    # one entry per training mini-batch
    losses_val = []      # one entry per epoch
    accuracies_val = []  # one entry per epoch
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        proxy2group.train()
        for i, (x, y) in enumerate(tqdm(mnist_train)):
            x = x.to(device)
            # Merged-group targets, cast to float for the cross-entropy loss.
            y = get_merged_labels(y, merge_group).float().to(device)
            optimizer.zero_grad()
            y_pred = proxy2group(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        proxy2group.eval()
        with torch.no_grad():
            total_val_loss = 0
            nb_correct = 0
            nb_seen = 0
            for x, y in mnist_val:
                x = x.to(device)
                y = get_merged_labels(y, merge_group).float().to(device)
                y_pred = proxy2group(x)
                val_loss = loss_fn(y_pred, y)
                total_val_loss += val_loss.item()

                group_pred = y_pred.argmax(dim=1)
                group_true = y.argmax(dim=1)
                # Count hits and samples so a smaller final batch is not
                # over-weighted, as a mean of per-batch accuracies would do.
                nb_correct += (group_pred == group_true).sum().item()
                nb_seen += group_true.size(0)

            avg_val_loss = total_val_loss / len(mnist_val)
            avg_val_accu = nb_correct / nb_seen
            accuracies_val.append(avg_val_accu)
            losses_val.append(avg_val_loss)
            print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_accu:.4f}")
    torch.save(proxy2group.state_dict(), PROXY2GROUP_PATH)

    nb_minibatches = len(mnist_train)

    # Titles name what this cell trains: the AE proxy predicting the merged
    # group (they previously read "VAE Proxy to Label").
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    axes[0].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('AE Proxy to Group Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Reuse the x-range of the loss panel for the accuracy panel below.
    xlim = axes[0].get_xlim()

    axes[1].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('AE Proxy to Group Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()

In [None]:
from sklearn.metrics import confusion_matrix

# Switch the group classifier to evaluation mode for the test pass.
proxy2group.eval()

# Per-batch true / predicted group indices, concatenated after the loop.
group_true_all, group_pred_all = [], []

with torch.no_grad():
    for x, y in mnist_test:
        x = x.to(device)
        y = get_merged_labels(y, merge_group).float().to(device)
        y_pred = proxy2group(x)
        # Group index = position of the maximum along dim 1.
        group_pred, group_true = y_pred.argmax(dim=1), y.argmax(dim=1)
        group_true_all += [group_true.cpu()]
        group_pred_all += [group_pred.cpu()]

group_true_all, group_pred_all = torch.cat(group_true_all), torch.cat(group_pred_all)

cm = confusion_matrix(group_true_all, group_pred_all)

# One tick per merged group, labelled with the digits it contains.
tick_positions = range(len(merge_group))
tick_names = [str(g) for g in merge_group]

plt.figure(figsize=(8, 6))
plt.imshow(cm, cmap='Blues', interpolation='nearest')
plt.title('Confusion Matrix for Group Predictions')
plt.xlabel('Predicted Group')
plt.ylabel('True Group')
plt.xticks(tick_positions, tick_names, rotation=45)
plt.yticks(tick_positions, tick_names)
plt.colorbar()
plt.tight_layout()
plt.show()

In [None]:
# Build the `proxy2lab` classifier (ProxyRep2Label on top of `inae`, 10 labels),
# then either load its checkpoint or train it and plot the training curves.
PROXY2LAB_PATH = "../checkpoints/mnist/group_similar/aeproxy2lab.pth"

proxy2lab = ProxyRep2Label(autoencoder=inae, reparameterize=False, nb_labels=10)
proxy2lab = proxy2lab.to(device)
# Re-create the loaders with one_hot=True; targets are reduced with argmax below.
mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=True, batch_size=BATCH_SIZE)

if os.path.exists(PROXY2LAB_PATH):
    print("Loading pre-trained Proxy2Label model...")
    proxy2lab.load_state_dict(torch.load(PROXY2LAB_PATH, map_location=device))
else:
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2lab.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    losses_train = []    # one entry per training minibatch
    losses_val = []      # one entry per epoch
    accuracies_val = []  # one entry per epoch
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        proxy2lab.train()
        for x, y in tqdm(mnist_train):
            x = x.to(device)
            y = y.float().to(device)
            optimizer.zero_grad()
            y_pred = proxy2lab(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        # Validation pass: accumulate per-batch loss and accuracy.
        with torch.no_grad():
            proxy2lab.eval()
            total_val_loss = 0
            total_val_accu = 0

            for x, y in mnist_val:
                x = x.to(device)
                y = y.float().to(device)
                y_pred = proxy2lab(x)
                val_loss = loss_fn(y_pred, y)
                total_val_loss += val_loss.item()

                lab_pred = y_pred.argmax(axis=1)
                lab_true = y.argmax(axis=1)
                acc = (lab_pred == lab_true).float().mean().item()
                total_val_accu += acc

            # Both metrics are means over validation batches (not over samples).
            avg_val_loss = total_val_loss / len(mnist_val)
            avg_val_accu = total_val_accu / len(mnist_val)
            accuracies_val.append(avg_val_accu)
            losses_val.append(avg_val_loss)
            print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_accu:.4f}")

    # Make sure the checkpoint directory exists so a finished training run is
    # not lost to a missing-folder error at save time.
    os.makedirs(os.path.dirname(PROXY2LAB_PATH), exist_ok=True)
    torch.save(proxy2lab.state_dict(), PROXY2LAB_PATH)

    nb_minibatches = len(mnist_train)
    # Validation metrics are recorded once per epoch; place them at the
    # minibatch index where each epoch ends so they share the training x-axis.
    epoch_ends = range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches)

    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    axes[0].plot(epoch_ends, losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('AE Proxy to Label Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Reuse the top panel's x-limits so both panels line up.
    xlim = axes[0].get_xlim()

    axes[1].plot(epoch_ends, accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('AE Proxy to Label Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()


In [None]:
# Build the `proxy2lab` classifier (ProxyRep2Label on top of `inae`, 10 labels),
# then either load its checkpoint or train it and plot the training curves.
# NOTE(review): this cell repeats the preceding Proxy2Label cell, except that it
# does not re-create the dataloaders; it relies on `mnist_train` / `mnist_val`
# already being defined. Consider removing one of the two cells.
PROXY2LAB_PATH = "../checkpoints/mnist/group_similar/aeproxy2lab.pth"

proxy2lab = ProxyRep2Label(autoencoder=inae, reparameterize=False, nb_labels=10)
proxy2lab = proxy2lab.to(device)

if os.path.exists(PROXY2LAB_PATH):
    print("Loading pre-trained Proxy2Label model...")
    proxy2lab.load_state_dict(torch.load(PROXY2LAB_PATH, map_location=device))
else:
    print("No pre-trained model found. Starting from scratch.")
    optimizer = optim.Adam(proxy2lab.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    losses_train = []    # one entry per training minibatch
    losses_val = []      # one entry per epoch
    accuracies_val = []  # one entry per epoch
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        proxy2lab.train()
        for x, y in tqdm(mnist_train):
            x = x.to(device)
            y = y.float().to(device)
            optimizer.zero_grad()
            y_pred = proxy2lab(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            losses_train.append(loss.item())
            optimizer.step()

        # Validation pass: accumulate per-batch loss and accuracy.
        with torch.no_grad():
            proxy2lab.eval()
            total_val_loss = 0
            total_val_accu = 0

            for x, y in mnist_val:
                x = x.to(device)
                y = y.float().to(device)
                y_pred = proxy2lab(x)
                val_loss = loss_fn(y_pred, y)
                total_val_loss += val_loss.item()

                lab_pred = y_pred.argmax(axis=1)
                lab_true = y.argmax(axis=1)
                acc = (lab_pred == lab_true).float().mean().item()
                total_val_accu += acc

            # Both metrics are means over validation batches (not over samples).
            avg_val_loss = total_val_loss / len(mnist_val)
            avg_val_accu = total_val_accu / len(mnist_val)
            accuracies_val.append(avg_val_accu)
            losses_val.append(avg_val_loss)
            print(f"\nEpoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_accu:.4f}")

    # Make sure the checkpoint directory exists so a finished training run is
    # not lost to a missing-folder error at save time.
    os.makedirs(os.path.dirname(PROXY2LAB_PATH), exist_ok=True)
    torch.save(proxy2lab.state_dict(), PROXY2LAB_PATH)

    nb_minibatches = len(mnist_train)
    # Validation metrics are recorded once per epoch; place them at the
    # minibatch index where each epoch ends so they share the training x-axis.
    epoch_ends = range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches)

    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    axes[0].plot(epoch_ends, losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('AE Proxy to Label Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Reuse the top panel's x-limits so both panels line up.
    xlim = axes[0].get_xlim()

    axes[1].plot(epoch_ends, accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('AE Proxy to Label Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluate `proxy2lab` on the test loader, then plot the 10x10 confusion
# matrix of true vs. predicted labels together with the overall accuracy.
proxy2lab.eval()

lab_true_all = []
lab_pred_all = []

# Inference only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for x, y in mnist_test:
        x = x.to(device)
        y = y.float().to(device)
        lab_true = y.argmax(axis=1)
        lab_true_all.append(lab_true)

        y_pred = proxy2lab(x)
        lab_pred = y_pred.argmax(axis=1)
        lab_pred_all.append(lab_pred)
lab_true_all = torch.cat(lab_true_all).detach().cpu()
lab_pred_all = torch.cat(lab_pred_all).detach().cpu()

# Count (true, pred) pairs in one vectorised pass: each pair maps to the flat
# index true * 10 + pred, and the 100 bin counts are reshaped so that rows are
# true labels and columns are predicted labels.
# NOTE: the name `confusion_matrix` rebinds the sklearn function imported by
# the group-evaluation cell; that cell re-imports it on each run.
confusion_matrix = torch.bincount(
    lab_true_all * 10 + lab_pred_all, minlength=100
).reshape(10, 10)
accuracy = (lab_true_all == lab_pred_all).float().mean().item()

plt.figure(figsize=(8, 6))
plt.imshow(confusion_matrix, cmap='Blues', interpolation='nearest')
plt.colorbar()
plt.title(f"Confusion Matrix of AE Proxy to Label (acc.: {accuracy:.3f})")
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(ticks=range(10), labels=range(10))
plt.yticks(ticks=range(10), labels=range(10))
plt.grid(False)
plt.show()