# Align directories (repository root on the import path, working directory)

In [None]:
import os
import sys

# 1. repo_dir (used later): parent of the current working directory
repo_dir = os.path.dirname(os.path.abspath(""))

# 2. put the repository root on the import path so project packages
#    (models.*, utils.*) can be imported below
sys.path.append(repo_dir)

# 3. pin the working directory; this re-applies the current directory, so
#    relative paths such as "../data" keep resolving against it
os.chdir(os.path.abspath(""))

# Import libraries and assign device (CPU or GPU)

In [None]:
# True selects the discriminative model modules, False the generative ones;
# model_type is also part of the checkpoint log path used further down.
test_discrim = True

if test_discrim:
    model_type = 'discrim'
    import models.cm_discrim as cm
    from models.lo_discrim import bins_lo, fast_bins_lo
else:
    model_type = 'gen'
    import models.cm_gen as cm
    from models.lo_gen import bins_lo, fast_bins_lo

from torch.utils.data import DataLoader
from torchvision import datasets
from utils.reproducibility import seed_everything

import glob
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load test data

In [None]:
# choose True for MNIST and False for binary MNIST
use_mnist = False

# create data directory (no-op if it already exists)
data_dir = "../data"
os.makedirs(data_dir, exist_ok=True)

# download the MNIST test split into the data directory (skipped if already present)
mnist_test = datasets.MNIST(root=data_dir, train=False, download=True)
labels_mnist_test = mnist_test.targets

# flatten each 28x28 image into a 784-dim float vector; -1 infers the image
# count instead of hard-coding the size of the test split
mnist_test = mnist_test.data.view(-1, 784).float()

# define test set and binarise if use_mnist is False
X_test = mnist_test if use_mnist else (mnist_test / 255 >= 0.5).float()
y_test = labels_mnist_test

# load the test set into a dataloader (not shuffled, so row order matches y_test)
batch_size = 128
test_loader = DataLoader(X_test, batch_size=batch_size)

# Load model

In [None]:
# model hyperparams
dataset = 'mnist' if use_mnist else 'bmnist'
latent_dim = 16
num_bins_trained = 2 ** 14
version_num = 0

# load model
decoder_arch = 'tConv'
ckpt_pattern = repo_dir + f'/logs/{dataset}/{model_type}/{decoder_arch}/latent_dim_{latent_dim}/num_bins_{num_bins_trained}/version_{version_num}/checkpoints/*.ckpt'  # for bmnist
# ckpt_pattern = repo_dir + f'/logs/{dataset}/{model_type}/latent_dim_{latent_dim}/num_bins_{num_bins_trained}/version_{version_num}/checkpoints/*.ckpt'  # for mnist
# glob returns matches in arbitrary order: sort so the same checkpoint is
# picked on every run, and fail with a readable message when none exists
ckpt_paths = sorted(glob.glob(ckpt_pattern))
if not ckpt_paths:
    raise FileNotFoundError(f"No checkpoint matches {ckpt_pattern}")
model_path = ckpt_paths[0]
print(model_path)
model = cm.ContinuousMixture.load_from_checkpoint(model_path).to(device)
model.n_chunks = 32
model.missing = False  # this was True before, check if difference (hope not)
model.eval();  # semi-colon to prevent printing model architecture

## Evaluate classification accuracy (on full and missing data) and log-likelihood of test data

In [None]:
def _write_label_channels(X, labels, use_mnist):
    """Overwrite the label channel(s) of every row of ``X`` in place and return ``X``.

    MNIST rows carry the class as a raw value in the last column; binary-MNIST
    rows carry it as 4 bits (most significant bit first) in the last four
    columns, the same layout as ``bin(label)[2:].zfill(4)``.

    Args:
        X: 2-D float tensor, one flattened image per row (modified in place).
        labels: a single int applied to every row, or a 1-D integer tensor
            holding one label per row.
        use_mnist: True for the raw-value encoding, False for the 4-bit one.
    """
    labels = torch.as_tensor(labels)
    if use_mnist:
        X[:, -1] = labels.to(X.dtype)
    else:
        shifts = torch.arange(3, -1, -1, device=labels.device)
        bits = (labels.unsqueeze(-1) >> shifts) & 1
        X[:, -4:] = bits.to(X.dtype)
    return X


def _batched_log_likelihoods(model, X, z, log_w, batch_size, seed):
    """Return ``model.forward`` scores for every row of ``X``, concatenated on the CPU.

    Rows are processed in order (no shuffling) in batches of ``batch_size`` on
    the module-level ``device``, without tracking gradients.
    """
    loader = DataLoader(X, batch_size=batch_size, shuffle=False)
    chunks = []
    with torch.no_grad():
        for xb in loader:
            llb = model.forward(xb.to(device), z, log_w, k=None, seed=seed)
            chunks.append(llb.cpu())
    return torch.cat(chunks, dim=0)


def compute_accuracies(
    model,
    X_test,
    y_test,
    lower_power_bound,
    upper_power_bound,
    latent_opt,
    use_mnist=True,
    missing=False,
    missing_rate=0.6,
    batch_size_full=512,
    batch_size_missing=256,
    seed=42,
):
    """Classify test images by maximum log-likelihood for several bin counts.

    For each ``n_bins`` in ``2**lower_power_bound, ..., 2**(upper_power_bound - 1)``
    the mixture latents are obtained from ``model.sampler`` (or, with
    ``latent_opt``, optimised with ``bins_lo``). Every test image is scored once
    per candidate digit written into its label channel(s), and the arg-max digit
    is the prediction. Accuracies are printed per bin count. For complete data,
    the mean test log-likelihood (with each image's true label embedded) is
    printed as well.

    Args:
        model: mixture model exposing ``sampler``, ``forward``, ``eval_loader``
            and a ``missing`` flag.
        X_test: 2-D float tensor of flattened test images (not modified).
        y_test: integer tensor of true labels, one per row of ``X_test``.
        lower_power_bound, upper_power_bound: exponent range ``[lower, upper)``
            of the bin counts to evaluate.
        latent_opt: optimise the latents with ``bins_lo`` instead of sampling.
        use_mnist: label encoding; see ``_write_label_channels``.
        missing: hide pixels (set to NaN) before scoring.
        missing_rate: probability that each pixel is hidden when ``missing``.
        batch_size_full, batch_size_missing: evaluation batch sizes.
        seed: passed to ``seed_everything``, the sampler and ``model.forward``.

    Returns:
        ``(bins_list, accuracies)``; when ``latent_opt`` is True,
        ``(bins_list, accuracies, z)`` where ``z`` are the latents of the last
        bin count.

    Relies on module-level ``device``, ``seed_everything``, ``bins_lo`` and,
    when ``latent_opt`` is True, ``train_loader`` / ``valid_loader``.
    """
    accuracies = []
    test_lls = []
    bins_list = [2 ** k for k in range(lower_power_bound, upper_power_bound)]
    eval_batch_size = batch_size_missing if missing else batch_size_full
    seed_everything(seed)

    # Draw ONE missingness mask for the whole call. Every digit hypothesis (and
    # every bin count) must be scored on the same observed pixels; otherwise
    # the per-digit log-likelihoods are not comparable and the arg-max is
    # driven by the masks rather than by the model.
    mask = torch.rand_like(X_test) < missing_rate if missing else None

    # Test log-likelihood is evaluated with each image's TRUE label written into
    # the label channel(s), i.e. the same layout as the labelled training data.
    ll_loader = None
    if not missing:
        X_labelled = _write_label_channels(X_test.clone(), y_test.reshape(-1), use_mnist)
        ll_loader = DataLoader(X_labelled, batch_size=batch_size_full)

    previous_missing = getattr(model, 'missing', False)
    for n_bins in bins_list:
        model.sampler.n_bins = n_bins
        if latent_opt:
            z, log_w = bins_lo(model, n_bins, train_loader, valid_loader, max_epochs=20, lr=1e-3, patience=5, device=device)
        else:
            z, log_w = model.sampler(seed=seed)
        if not missing:
            test_lls.append((n_bins, model.eval_loader(ll_loader, z, log_w, device=device).mean().item()))

        all_ll = torch.zeros(len(X_test), 10)
        # Switch missing-data handling on only while scoring masked images and
        # restore the caller's setting afterwards (also on error). This keeps the
        # flag from leaking into later calls or into the next latent optimisation.
        if missing:
            model.missing = True
        try:
            for digit in range(10):
                Xd = X_test.clone()
                if missing:
                    Xd[mask] = float('nan')
                # label channels are written after masking, so they are always observed
                _write_label_channels(Xd, digit, use_mnist)
                all_ll[:, digit] = _batched_log_likelihoods(model, Xd, z, log_w, eval_batch_size, seed)
        finally:
            if missing:
                model.missing = previous_missing

        # classification by max log-likelihood
        preds = all_ll.argmax(dim=1)
        acc = (preds == y_test.squeeze()).float().mean().item()
        accuracies.append(acc)
        print(f"Accuracy for {n_bins:5d} bins: {acc:.4f} : ")

    if not missing:
        print()
        for n_bins, test_ll in test_lls:
            print(f"Test LL for {n_bins:5d} bins: {test_ll:.4f} : ")
    print()

    if latent_opt:
        return bins_list, accuracies, z

    return bins_list, accuracies

# Test on full data

In [None]:
# Complete (unmasked) test images, sampled latents, 2^15 and 2^16 bins.
bins_full, acc_full = compute_accuracies(
    model, X_test, y_test,
    lower_power_bound=15, upper_power_bound=17,  # exponents 15..16 (upper bound exclusive)
    latent_opt=False,
    use_mnist=use_mnist,
    missing=False,
)

# (TEMP) Automate benchmarking across models trained with different numbers of bins

In [None]:
# model hyperparams
dataset = 'mnist' if use_mnist else 'bmnist'
decoder_arch = 'tConv'
latent_dim = 16
version_num = 0

# accuracies keyed by the number of bins each model was trained with, so the
# results of earlier models survive the loop (bins_full / acc_full still hold
# the last model's results afterwards, as before)
benchmark_results = {}

bins_list = [2 ** k for k in range(8, 14)]
for num_bins_trained in bins_list:
    ckpt_pattern = repo_dir + f'/logs/{dataset}/{model_type}/{decoder_arch}/latent_dim_{latent_dim}/num_bins_{num_bins_trained}/version_{version_num}/checkpoints/*.ckpt'  # for bmnist
    # ckpt_pattern = repo_dir + f'/logs/{dataset}/{model_type}/latent_dim_{latent_dim}/num_bins_{num_bins_trained}/version_{version_num}/checkpoints/*.ckpt'  # for mnist
    # glob order is arbitrary: sort for a deterministic pick, fail clearly if none
    ckpt_paths = sorted(glob.glob(ckpt_pattern))
    if not ckpt_paths:
        raise FileNotFoundError(f"No checkpoint matches {ckpt_pattern}")
    model_path = ckpt_paths[0]
    print(model_path)  # label the accuracies printed below with their checkpoint
    model = cm.ContinuousMixture.load_from_checkpoint(model_path).to(device)
    model.n_chunks = 32
    model.missing = False  # this was True before, check if difference (hope not)
    model.eval()

    bins_full, acc_full = compute_accuracies(
        model=model,
        X_test=X_test,
        y_test=y_test,
        lower_power_bound=8,
        upper_power_bound=12,
        latent_opt=False,
        use_mnist=use_mnist,
        missing=False,
    )
    benchmark_results[num_bins_trained] = (bins_full, acc_full)

# Test on missing data

In [None]:
# Masked test images: each pixel hidden with probability 0.2; 2^8 .. 2^11 bins.
bins_missing, acc_missing = compute_accuracies(
    model, X_test, y_test,
    lower_power_bound=8, upper_power_bound=12,  # exponents 8..11 (upper bound exclusive)
    latent_opt=False,
    use_mnist=use_mnist,
    missing=True,
    missing_rate=0.2,
)

# Plot both

In [None]:
# Place every bin count at x = log2(n_bins). Powers of two are equally spaced,
# and the two curves stay aligned even when they were evaluated over different
# bin lists (e.g. after re-running only one of the evaluation cells).
x_full = np.log2(bins_full)
x_missing = np.log2(bins_missing)

plt.figure(figsize=(8, 5))

plt.plot(x_full, [100 * a for a in acc_full], marker='o', linestyle='-', label='Full Data')
plt.plot(x_missing, [100 * a for a in acc_missing], marker='s', linestyle='--', label='Missing Data')

# Label each tick (union of both bin lists) by its 2^k value
tick_exponents = sorted({int(k) for k in np.rint(np.concatenate([x_full, x_missing]))})
plt.xticks(tick_exponents, [f"$2^{{{k}}}$" for k in tick_exponents])

plt.yticks(np.arange(10, 101, 10))
plt.xlabel("Number of Bins")
plt.ylabel("Accuracy (%)")
plt.title("Accuracy vs. Number of Bins")
plt.legend()
plt.grid(True)
plt.show()

## Evaluate sample quality

In [None]:
def reconstruct_image(grayscale_vector):
    """Show a flattened 784-pixel vector as a 28x28 grayscale image without axes."""
    plt.imshow(grayscale_vector.reshape((28, 28)), cmap='gray')
    plt.axis('off')
    plt.show()

from random import randint

# Discretise the mixture into 32 latents and decode them.
n_bins = 32
model.sampler.n_bins = n_bins
z, log_w = model.sampler(seed=42)

# Pure inference: no autograd graph is needed for the decoder pass.
with torch.no_grad():
    logits_tensor = model.decoder.net(z.to(device))
# pick one decoded latent uniformly at random (randint bounds are inclusive)
logits_sample = logits_tensor[randint(0, logits_tensor.shape[0] - 1)]

if decoder_arch == 'tConv':
    # presumably (classes, 28, 28) per latent: move the class axis last so each
    # pixel gets its own categorical distribution, then sample pixel values
    logits_permute = logits_sample.permute(1, 2, 0)
    logits_flat = logits_permute.reshape(-1, logits_permute.shape[-1])
    sampled_pixel_vals = torch.distributions.Categorical(logits=logits_flat).sample()
    reconstruct_image(sampled_pixel_vals.view(28, 28).cpu())
else:
    # otherwise the logits are treated as independent Bernoulli pixels
    probs = torch.sigmoid(logits_sample)
    sample = torch.bernoulli(probs)
    reconstruct_image(sample.cpu())

# Latent optimisation

# Load training and validation data (for latent optimisation)

In [None]:
# download the MNIST training split into the data directory (skipped if already present)
mnist_train_and_val = datasets.MNIST(root=data_dir, train=True, download=True)

# assign labels
labels_mnist_train_and_val = mnist_train_and_val.targets

# flatten each 28x28 image into a 784-dim float vector (-1 infers the image count)
mnist_train_and_val = mnist_train_and_val.data.view(-1, 784).float()

# embed class label in final pixel(s) of every sample, for the whole dataset at
# once instead of a Python loop over 60k rows:
#   - MNIST: raw label value in the last pixel
#   - binary MNIST: 4 label bits (most significant first) in the last four
#     pixels, scaled by 255 so they survive the ">= 0.5 after /255" binarisation
if use_mnist:
    mnist_train_and_val[:, -1] = labels_mnist_train_and_val.float()
else:
    bit_shifts = torch.arange(3, -1, -1)
    label_bits = (labels_mnist_train_and_val.unsqueeze(1) >> bit_shifts) & 1
    mnist_train_and_val[:, -4:] = 255 * label_bits.float()

# define train (first 50k samples) and validation (remaining samples)
if use_mnist:
    X_train = mnist_train_and_val[:50_000]
    X_val   = mnist_train_and_val[50_000:]
else:  # if use_mnist is False then binarise
    X_train = (mnist_train_and_val[:50_000] / 255 >= 0.5).float()
    X_val   = (mnist_train_and_val[50_000:] / 255 >= 0.5).float()

y_train = labels_mnist_train_and_val[:50_000]
y_val   = labels_mnist_train_and_val[50_000:]

# load data into data loaders (training batches shuffled, last partial batch dropped)
batch_size = 128
train_loader = DataLoader(X_train, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(X_val, batch_size=batch_size)

# Latent optimisation: test on full data

In [None]:
print(model_path)
# TODO: run the latent optimisation once here and reuse it for the
# missing-data evaluation below instead of optimising again there
bins_full, acc_full, z_opt = compute_accuracies(
    model, X_test, y_test,
    lower_power_bound=8, upper_power_bound=9,  # a single bin count: 2^8
    latent_opt=True,
    use_mnist=use_mnist,
    missing=False,
)

# Latent optimisation: test on missing data

In [None]:
# Optimised latents on masked test images (each pixel hidden with probability 0.2).
bins_missing, acc_missing, z_opt = compute_accuracies(
    model, X_test, y_test,
    lower_power_bound=8, upper_power_bound=12,  # exponents 8..11 (upper bound exclusive)
    latent_opt=True,
    use_mnist=use_mnist,
    missing=True,
    missing_rate=0.2,
)

# Evaluate sample quality

In [None]:
def reconstruct_image(grayscale_vector):
    """Render a 784-element pixel vector as an axis-free 28x28 grayscale plot."""
    pixels = grayscale_vector.reshape((28, 28))
    plt.imshow(pixels, cmap='gray')
    plt.axis('off')
    plt.show()

from random import randint

n_bins = 32
# kept for parity with the earlier sampling cell; it does not affect z_opt below
model.sampler.n_bins = n_bins
# latent z_opt used here! Decode the optimised latents without tracking gradients.
with torch.no_grad():
    logits_tensor = model.decoder.net(z_opt.to(device))
# z_opt holds as many latents as the last bin count optimised by
# compute_accuracies, so draw the index from its actual size, not from n_bins
logits_sample = logits_tensor[randint(0, logits_tensor.shape[0] - 1)]

if decoder_arch == 'tConv':
    # presumably (classes, 28, 28) per latent: move the class axis last so each
    # pixel gets its own categorical distribution, then sample pixel values
    logits_permute = logits_sample.permute(1, 2, 0)
    logits_flat = logits_permute.reshape(-1, logits_permute.shape[-1])
    sampled_pixel_vals = torch.distributions.Categorical(logits=logits_flat).sample()
    reconstruct_image(sampled_pixel_vals.view(28, 28).cpu())
else:
    # otherwise the logits are treated as independent Bernoulli pixels
    probs = torch.sigmoid(logits_sample)
    sample = torch.bernoulli(probs)
    reconstruct_image(sample.cpu())