In [None]:
import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

from mnist_style.models import ClassifyingAutoEncoder
from mnist_style.persistence import load_models

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from functools import partial

sns.set_theme()  # apply seaborn's default theme to every figure drawn after this cell

In [None]:
# Rebuild the model skeleton and restore its saved weights.
# NOTE(review): n_classes / style_dim presumably must match whatever produced
# the checkpoint in ./pt-aae — confirm before changing them.
n_classes = 10
style_dim = 4
autoencoder = ClassifyingAutoEncoder(n_classes, style_dim)

# Only the encoder and decoder sub-modules are handed to the project loader.
modules_to_restore = dict(encoder=autoencoder.encoder, decoder=autoencoder.decoder)
load_models(modules_to_restore, "./pt-aae")

In [None]:
# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then rescales them to
# [-1, 1]. NOTE(review): presumably this mirrors the training preprocessing — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST test split read from ./data; download=False means the files must already be there.
test_dataset = MNIST(
    root='./data',
    train=False,
    download=False,
    transform=transform,
)

### Style Vector Distribution Visualization

In [None]:
# Encode the whole test split once and collect, per image, its style vector,
# true digit, and the classifier's predicted digit into one DataFrame.
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Column names feat_a, feat_b, ... — one letter per style dimension, so this
# naming scheme only makes sense for style_dim <= 26.
feat_names = ['feat_' + chr(ord('a') + i) for i in range(style_dim)]
test_enc_dfs = []
autoencoder.eval()
with torch.inference_mode():
    for batch, labels in test_dataloader:
        class_logits, style_feats = autoencoder.forward_encoder(batch)
        # Tensors created under inference_mode carry no autograd graph, so they
        # convert to numpy directly without .detach().
        df = pd.DataFrame(style_feats.numpy(), columns=feat_names)
        df['digit'] = labels.numpy()
        df['prediction'] = class_logits.argmax(dim=1).numpy()
        test_enc_dfs.append(df)
encoder_df = pd.concat(test_enc_dfs, ignore_index=True)

In [None]:
# Pairwise scatter plots of the style features coloured by true digit, with
# stacked per-digit histograms on the diagonal. `vars=feat_names` keeps the grid
# to the style features only; without it PairGrid plots every numeric non-hue
# column, which would add the `prediction` column as an extra row/column.
g = sns.PairGrid(encoder_df, vars=feat_names, hue="digit", diag_sharey=False, height=3, palette="tab10")
g.map_diag(sns.histplot, multiple="stack", element="bars")
g.map_offdiag(sns.scatterplot)
g.add_legend()
# Thin black reference lines at zero on each axis.
for i, row_axes in enumerate(g.axes):
    for j, ax in enumerate(row_axes):
        ax.axvline(color='black', linewidth=0.5)  # x = 0
        if i != j:  # diagonal panels hold histograms, so no y = 0 line there
            ax.axhline(color='black', linewidth=0.5)

### Compare Random Test Images with Their Autoencoder Reconstructions

In [None]:
# Draw one random batch of test images, run the full autoencoder on it, print the
# classifier's (predicted digit, confidence) pairs, and plot the originals (top
# row) above their reconstructions (bottom row).
imgs_per_row = 10
test_dataloader = DataLoader(test_dataset, batch_size=imgs_per_row, shuffle=True)

autoencoder.eval()
batch, _ = next(iter(test_dataloader))
# inference_mode: no autograd graph is recorded, so no .detach() is needed later.
with torch.inference_mode():
    class_logits, style_feats, decoded_batch = autoencoder(batch)
    vals, idxs = F.softmax(class_logits, dim=1).max(dim=1)
# Plain Python floats print as 0.998 rather than as numpy scalar reprs.
print(list(zip(idxs.tolist(), [round(v, 3) for v in vals.tolist()])))

fig, axs = plt.subplots(2, imgs_per_row, figsize=(16, 3))
for i, (image, decoded) in enumerate(zip(batch, decoded_batch)):
    axs[0, i].set_axis_off()
    axs[1, i].set_axis_off()
    # NOTE: imshow autoscales each panel to its own min/max unless vmin/vmax are set.
    axs[0, i].imshow(image[0], cmap="viridis")
    axs[1, i].imshow(decoded[0], cmap="viridis")
fig.tight_layout(pad=0, h_pad=1)

### Classifier Accuracy and Style Vector Goodness of Fit

In [None]:
# Overall and per-digit classifier accuracy on the test split, computed from the
# predictions already collected in `encoder_df` (no further inference needed;
# the unused 400-image DataLoader previously built here was dead code).
enc_acc_df = encoder_df.assign(accuracy=encoder_df['digit'] == encoder_df['prediction'])
print('mean accuracy:', enc_acc_df['accuracy'].mean())
enc_acc_df.groupby('digit')[['accuracy']].mean().T

In [None]:
def goodness_of_fit_metric(samples, norm_scale=2):
    """Score how strongly ``samples`` deviate from a zero-mean normal distribution.

    Runs a one-sample Kolmogorov-Smirnov test against N(0, norm_scale**2) and
    returns -log10 of its p-value: 0 means no evidence against normality, and
    larger values mean stronger evidence that the samples do not follow that normal.

    Args:
        samples: 1-D array-like of observations.
        norm_scale: standard deviation of the reference normal distribution.

    Returns:
        float: -log10(p-value). ``inf`` when the p-value underflows to exactly 0;
        ``nan`` when the test yields a NaN p-value (e.g. NaNs in ``samples``)
        rather than misreporting that case as the worst possible fit.
    """
    # ks_test.statistic is an alternative score that does not saturate at inf.
    ks_test = stats.ks_1samp(samples, stats.norm(loc=0, scale=norm_scale).cdf)
    # -log10(0) evaluates to +inf; silence numpy's divide-by-zero warning for it.
    with np.errstate(divide='ignore'):
        return float(-np.log10(ks_test.pvalue))

def _fit_row(df):
    """Goodness-of-fit scores for one subset: all style features pooled, then each feature alone."""
    pooled = goodness_of_fit_metric(df[feat_names].to_numpy().ravel())
    return [pooled] + [goodness_of_fit_metric(df[feat]) for feat in feat_names]


# One row for the whole test split, then one per digit. groupby yields the digits
# in sorted order and only those actually present, instead of assuming 10 classes.
fit_rows = {'all digits': _fit_row(encoder_df)}
for digit, df_dig in encoder_df.groupby('digit'):
    fit_rows[f'digit {digit}'] = _fit_row(df_dig)
pd.DataFrame.from_dict(fit_rows, orient='index', columns=['all features'] + feat_names)

### Generate New Images from Random Style Vectors (one style vector shared by each row)

In [None]:
# Decode synthetic codes: each row draws ONE random style vector and pairs it with
# every class one-hot, so a row shows all digits rendered in a shared style.
num_rows = 3
# Std of the style prior sampled here. NOTE(review): goodness_of_fit_metric
# defaults to norm_scale=2 — confirm which scale matches the training prior.
norm_scale = 1
autoencoder.eval()
# Loop-invariant: one float one-hot row per class (made float explicitly rather
# than relying on int64/float32 promotion inside torch.concat).
classes_onehot = F.one_hot(torch.arange(n_classes), n_classes).to(torch.float32)
fig, axs = plt.subplots(num_rows, n_classes, figsize=(16, 4.75))
with torch.inference_mode():
    for i in range(num_rows):
        # A single (1, style_dim) sample broadcast to every class in this row.
        style_feats = (torch.randn((1, style_dim), dtype=torch.float32) * norm_scale).expand(n_classes, -1)
        encoded_batch = torch.concat((classes_onehot, style_feats), dim=1)
        decoded_batch = autoencoder.decoder(encoded_batch)

        for j, decoded in enumerate(decoded_batch):
            axs[i, j].set_axis_off()
            axs[i, j].imshow(decoded[0], cmap="viridis")
fig.tight_layout(pad=0, h_pad=1)