In [None]:
import os
import sys
import time
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Set the GPUs to use
sys.path.append(os.path.join(os.getcwd(), ".."))

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
import torch.optim as optim
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

print(f"Using device: {device}")

from src.mnist.datasets import get_mnist_dataloaders, convert_flattened_to_image, get_merged_labels
from src.mnist.models import InfoBottleneckClassifier, ProxyRep2Label
from src.mnist.losses import InfoBottleneck_Loss


In [None]:
# Configuration for the group-level Information Bottleneck (IB) classifier.
BATCH_SIZE = 500  # NOTE(review): not passed to get_mnist_dataloaders in this cell -- confirm where it is consumed
BETA = 0.001  # Scaled by batch size for stability (5e-4 was the previous setting)
REDUCTION = 'mean'
EPOCHS = 50
LR = 5e-4  # Learning rate
IBGROUP_PATH = "../checkpoints/mnist/group_similar/ib_groupifier.pth"

mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=False)

# Digits are merged into groups; the model is built with one output per group
# and is trained/evaluated against the labels returned by get_merged_labels.
merge_group = [
    [0, 6],
    [1],
    [4, 7, 9],
    [2, 3, 5, 8]
]

ib_groupifier = InfoBottleneckClassifier(
    input_dim=28 * 28,
    encoder_layer_sizes=[128, 128],
    latent_dim=64,
    mlp_layer_sizes=[256, 512, 256],
    nb_labels=len(merge_group)
).to(device)

if os.path.exists(IBGROUP_PATH):
    print(f"Loading InfoBottleneck model from {IBGROUP_PATH}")
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    ib_groupifier.load_state_dict(torch.load(IBGROUP_PATH, map_location=device))
else:
    print(f"InfoBottleneck model not found at {IBGROUP_PATH},\nstarting training from scratch.")
    # Create the checkpoint directory before training: torch.save does not
    # create missing parent directories, and failing after the last epoch
    # would throw away the whole run.
    os.makedirs(os.path.dirname(IBGROUP_PATH), exist_ok=True)

    optimizer = optim.Adam(ib_groupifier.parameters(), lr=LR)
    loss_fn = InfoBottleneck_Loss(beta=BETA, reduction=REDUCTION)

    # Training loss is recorded per mini-batch; validation metrics per epoch.
    losses_train = []
    losses_val = []
    accuracies_val = []

    for epoch in range(EPOCHS):
        ib_groupifier.train()

        for i, (x, y) in enumerate(mnist_train):
            x = x.to(device)
            y = get_merged_labels(y, merge_group)
            y = y.float().to(device)

            optimizer.zero_grad()
            y_pred, mu, logvar = ib_groupifier(x)
            loss = loss_fn(y, y_pred, mu, logvar)
            loss.backward()
            optimizer.step()

            losses_train.append(loss.item())
            print(f"Epoch [{epoch + 1}/{EPOCHS}], Step [{i + 1}/{len(mnist_train)}], Loss: {loss.item():.4f}", end='\r')

        with torch.no_grad():
            ib_groupifier.eval()
            val_loss = 0.0
            correct = 0
            total = 0

            for x, y in mnist_val:
                x = x.to(device)
                y = get_merged_labels(y, merge_group)
                y = y.float().to(device)

                y_pred, mu, logvar = ib_groupifier(x)
                loss = loss_fn(y, y_pred, mu, logvar)
                val_loss += loss.item()
                # Both the prediction and the merged label are reduced to a
                # group index with argmax over the last axis before comparing.
                group_pred = torch.argmax(y_pred, dim=-1)
                y = torch.argmax(y, dim=-1)
                correct += (group_pred == y).sum().item()
                total += y.size(0)
            # Mean of the per-batch losses over the validation loader.
            val_loss /= len(mnist_val)
            accuracy = correct / total
            losses_val.append(val_loss)
            accuracies_val.append(accuracy)
            print(f"\nEpoch [{epoch + 1}/{EPOCHS}], Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}")
    torch.save(ib_groupifier.state_dict(), IBGROUP_PATH)

    # Training curves: validation points sit at the end of each epoch,
    # expressed in mini-batch iterations so both series share one x axis.
    nb_minibatches = len(mnist_train)
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    axes[0].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('VAE Proxy to Group Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Reuse the x-limits of the loss panel so the accuracy panel lines up with it.
    xlim = axes[0].get_xlim()
    axes[1].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('VAE Proxy to Group Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()

In [None]:
# Re-plot the training curves of the group model.
# The history lists are only created when the training branch of the previous
# cell ran; when the model was restored from a checkpoint they do not exist,
# so the plot is skipped instead of raising NameError on a fresh "Run All".
nb_minibatches = len(mnist_train)

_history_names = ("losses_train", "losses_val", "accuracies_val")
if all(_name in globals() for _name in _history_names):
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    # Validation metrics are recorded once per epoch, i.e. every nb_minibatches iterations.
    axes[0].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('VAE Proxy to Group Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Share the x-limits of the loss panel with the accuracy panel.
    xlim = axes[0].get_xlim()
    axes[1].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('VAE Proxy to Group Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()
else:
    print("No training history in this session (model was loaded from a checkpoint); skipping the training-curve plot.")

In [None]:
# Evaluate the group model on the test loader and draw its confusion matrix.
ib_groupifier.eval()

true_batches = []
pred_batches = []
with torch.no_grad():
    for inputs, targets in mnist_test:
        inputs = inputs.to(device)
        targets = get_merged_labels(targets, merge_group).float().to(device)
        # The model returns three values; only the first (class scores) is needed here.
        scores, _, _ = ib_groupifier(inputs)
        # Reduce both targets and scores to a group index over the last axis.
        true_batches.append(torch.argmax(targets, dim=-1).cpu().numpy())
        pred_batches.append(torch.argmax(scores, dim=-1).cpu().numpy())

group_true_all = np.concatenate(true_batches)
group_pred_all = np.concatenate(pred_batches)
cm = confusion_matrix(group_true_all, group_pred_all)

# One tick per group, labelled with the digits that make up the group.
tick_positions = range(len(merge_group))
tick_labels = [str(group) for group in merge_group]

plt.figure(figsize=(8, 6))
plt.imshow(cm, cmap='Blues', interpolation='nearest')
plt.title('Confusion Matrix for Group Predictions')
plt.xlabel('Predicted Group')
plt.ylabel('True Group')
plt.xticks(tick_positions, tick_labels, rotation=45)
plt.yticks(tick_positions, tick_labels)
plt.colorbar()
plt.tight_layout()
plt.show()

In [None]:
# Digit classifier built on top of the group model's representation.
IBCLASS_PATH = "../checkpoints/mnist/group_similar/ib_classifier.pth"

# NOTE(review): this rebinds mnist_train / mnist_val / mnist_test to loaders
# created with one_hot=True. Earlier cells were written against the
# one_hot=False loaders, so re-running them after this cell may behave
# differently -- confirm before executing cells out of order.
mnist_train, mnist_val, mnist_test = get_mnist_dataloaders("../data", one_hot=True)

z2lab = ProxyRep2Label(autoencoder=ib_groupifier, reparameterize=True, nb_labels=10)
z2lab = z2lab.to(device)

if os.path.exists(IBCLASS_PATH):
    print(f"Loading ProxyRep2Label model from {IBCLASS_PATH}")
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    z2lab.load_state_dict(torch.load(IBCLASS_PATH, map_location=device))
else:
    print(f"ProxyRep2Label model not found at {IBCLASS_PATH},\nstarting training from scratch.")
    # Create the checkpoint directory before training: torch.save does not
    # create missing parent directories, and failing after the last epoch
    # would throw away the whole run.
    os.makedirs(os.path.dirname(IBCLASS_PATH), exist_ok=True)

    # NOTE(review): z2lab.parameters() may also contain the parameters of the
    # wrapped ib_groupifier -- confirm whether it is meant to stay frozen.
    optimizer = optim.Adam(z2lab.parameters(), lr=LR)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Training loss is recorded per mini-batch; validation metrics per epoch.
    losses_train = []
    losses_val = []
    accuracies_val = []

    for epoch in range(EPOCHS):
        z2lab.train()

        for i, (x, y) in enumerate(mnist_train):
            x = x.to(device)
            y = y.float().to(device)

            optimizer.zero_grad()
            y_pred = z2lab(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()

            losses_train.append(loss.item())
            print(f"Epoch [{epoch + 1}/{EPOCHS}], Step [{i + 1}/{len(mnist_train)}], Loss: {loss.item():.4f}", end='\r')

        with torch.no_grad():
            z2lab.eval()
            val_loss = 0.0
            correct = 0
            total = 0

            for x, y in mnist_val:
                x = x.to(device)
                y = y.float().to(device)

                y_pred = z2lab(x)
                loss = loss_fn(y_pred, y)
                val_loss += loss.item()
                # Reduce scores and targets to a class index over the last axis before comparing.
                y_pred = y_pred.argmax(dim=-1)
                y = y.argmax(dim=-1)
                correct += (y_pred == y).sum().item()
                total += y.size(0)
            # Mean of the per-batch losses over the validation loader.
            val_loss /= len(mnist_val)
            accuracy = correct / total
            losses_val.append(val_loss)
            accuracies_val.append(accuracy)
            print(f"\nEpoch [{epoch + 1}/{EPOCHS}], Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}")
    torch.save(z2lab.state_dict(), IBCLASS_PATH)

    # Defined locally so the plot does not depend on a value left behind by an
    # earlier cell (the original assigned `minibatches` but read `nb_minibatches`).
    nb_minibatches = len(mnist_train)

    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    axes[0].plot(losses_train, label='Training Loss', color='blue')
    # Validation metrics are recorded once per epoch, i.e. every nb_minibatches iterations.
    axes[0].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                losses_val, label='Validation Loss', color='orange')
    axes[0].set_ylabel('Loss')
    # Titles name the digit task; the previous "to Group" titles were copied from the group model.
    axes[0].set_title('VAE Proxy to Digit Cross Entropy')
    axes[0].legend()
    axes[0].set_xscale('log')
    axes[0].set_xticks([])
    # Share the x-limits of the loss panel with the accuracy panel.
    xlim = axes[0].get_xlim()
    axes[1].plot(range(nb_minibatches, nb_minibatches * EPOCHS + 1, nb_minibatches),
                accuracies_val, label='Validation Accuracy', color='green')
    axes[1].set_xlabel('Batch Iterations')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('VAE Proxy to Digit Accuracy')
    axes[1].legend()
    axes[1].set_xscale('log')
    axes[1].set_xlim(xlim)
    plt.tight_layout()
    plt.show()

In [None]:
# Evaluate the digit classifier on the test loader and draw its confusion matrix.
z2lab.eval()

true_batches = []
pred_batches = []

with torch.no_grad():
    for inputs, targets in mnist_test:
        inputs = inputs.to(device)
        targets = targets.float().to(device)
        scores = z2lab(inputs)
        # Reduce both targets and scores to a class index over the last axis.
        true_batches.append(torch.argmax(targets, dim=-1).cpu().numpy())
        pred_batches.append(torch.argmax(scores, dim=-1).cpu().numpy())

lab_true_all = np.concatenate(true_batches)
lab_pred_all = np.concatenate(pred_batches)

# Fraction of test samples whose predicted index equals the true index.
accuracy = np.mean(lab_true_all == lab_pred_all)
cm = confusion_matrix(lab_true_all, lab_pred_all)

digit_ticks = range(10)

plt.figure(figsize=(8, 6))
plt.imshow(cm, cmap='Blues', interpolation='nearest')
plt.colorbar()
plt.title(f"Confusion Matrix for Digit Predictions (acc.: {accuracy:.3f})")
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(ticks=digit_ticks, labels=digit_ticks)
plt.yticks(ticks=digit_ticks, labels=digit_ticks)
plt.grid(False)
plt.show()