In [None]:
import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

from mnist_style.models import ClassifyingAutoEncoder
from mnist_style.persistence import load_models

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from functools import partial

# Apply seaborn's default theme to the figures drawn later in the notebook.
sns.set_theme()

In [None]:
# Model hyper-parameters: number of digit classes and length of the style vector.
n_classes = 10
style_dim = 4
autoencoder = ClassifyingAutoEncoder(n_classes, style_dim)

# Hand the encoder and decoder sub-modules to the project's loader together with
# the checkpoint location (presumably restores their saved weights -- see
# mnist_style.persistence.load_models to confirm).
models_to_load = {"encoder": autoencoder.encoder, "decoder": autoencoder.decoder}
load_models(models_to_load, "./pt-aae")

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalise its
# single channel with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST test split; download=False means the files must already exist under ./data.
test_dataset = MNIST(
    root='./data',
    train=False,
    download=False,
    transform=transform,
)

### Visualize Style Vector Distribution

In [None]:
# Pair plot of the encoder's style features for one random test batch,
# coloured by the true digit label.
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)

autoencoder.eval()
# Only a single batch is plotted, so fetch it directly instead of looping and breaking.
batch, labels = next(iter(test_dataloader))
with torch.no_grad():  # inference only: skip building the autograd graph
    _, style_feats = autoencoder.forward_encoder(batch)
features = style_feats.numpy()

# One column per style dimension: 'fa', 'fb', 'fc', ...
columns = ['f' + chr(i + ord('a')) for i in range(features.shape[1])]
# Labels are converted to a NumPy array (as the accuracy cell further down does)
# so the 'digit' column holds plain integers rather than tensor objects.
df = pd.DataFrame(features, columns=columns).assign(digit=labels.numpy())
g = sns.pairplot(df, hue="digit", palette="tab10")  # "hls" was noted as an alternative palette

### Compare Random Dataset Images and Corresponding AutoEncoder Results

In [None]:
# Show one random batch of test images (top row) above the autoencoder's
# reconstructions of the same images (bottom row).
n_examples = 10  # drives both the batch size and the number of subplot columns
test_dataloader = DataLoader(test_dataset, batch_size=n_examples, shuffle=True)

autoencoder.eval()
# Only one batch is displayed, so fetch it directly instead of looping and breaking.
batch, _ = next(iter(test_dataloader))
with torch.no_grad():  # inference only: skip building the autograd graph
    class_logits, _, decoded_batch = autoencoder(batch)

# Predicted class index and its softmax probability for every image in the batch.
vals, idxs = F.softmax(class_logits, dim=1).max(dim=1)
print(list(zip(idxs.tolist(), vals.numpy().round(3))))

fig, axs = plt.subplots(2, n_examples, figsize=(16, 3))
for i, (image, decoded) in enumerate(zip(batch, decoded_batch)):
    axs[0, i].set_axis_off()
    axs[1, i].set_axis_off()
    # Index 0 selects the first (only) channel of each image tensor for imshow.
    axs[0, i].imshow(image[0], cmap="viridis")
    axs[1, i].imshow(decoded[0], cmap="viridis")
fig.tight_layout(pad=0, h_pad=1)

### Classifier Accuracy and Goodness of Fit of the Style Vector Distribution

In [None]:
# Encode the whole test set and collect, per image, its style features, its true
# digit and whether the classifier head predicted that digit correctly.
# Every statistic below is order-independent, so the data is not shuffled; this
# keeps the resulting frame identical across runs.
test_dataloader = DataLoader(test_dataset, batch_size=400, shuffle=False)

feat_names = ['feat ' + chr(ord('a') + i) for i in range(style_dim)]
batch_dfs = []

autoencoder.eval()
with torch.no_grad():  # inference only: skip building the autograd graph
    for batch, labels in test_dataloader:
        class_logits, style_feats = autoencoder.forward_encoder(batch)
        batch_df = pd.DataFrame(style_feats.numpy(), columns=feat_names)
        batch_df['digit'] = labels.numpy()
        predicted = np.argmax(class_logits.numpy(), axis=1)
        # Per-sample correctness flag; its mean over any group is that group's accuracy.
        batch_df['accuracy'] = labels.numpy() == predicted
        batch_dfs.append(batch_df)

# ignore_index gives one clean 0..N-1 index instead of repeating each batch's 0..399.
df = pd.concat(batch_dfs, ignore_index=True)
print('mean accuracy:', df['accuracy'].mean())
df[['digit', 'accuracy']].groupby('digit').mean().T

In [None]:
def goodness_of_fit_metric(samples, norm_scale=2):
    """Score how strongly ``samples`` deviate from a zero-mean normal distribution.

    Runs a one-sample Kolmogorov-Smirnov test of ``samples`` against a normal
    distribution with mean 0 and standard deviation ``norm_scale`` and returns
    ``-log10(p-value)``. Values near 0 mean the samples are consistent with
    that distribution; larger values mean stronger evidence against it.

    Args:
        samples: 1-D array-like of observations.
        norm_scale: Standard deviation of the reference normal distribution.

    Returns:
        ``-log10`` of the KS-test p-value. The p-value is floored at the
        smallest positive normal float, so the result is capped at roughly
        307.65 instead of becoming ``inf`` when the p-value underflows to 0
        (which happens readily for large, poorly fitting samples).
    """
    # `args` is forwarded to the cdf as stats.norm.cdf(x, loc, scale).
    ks_test = stats.ks_1samp(samples, stats.norm.cdf, args=(0, norm_scale))
    # Floor the p-value to avoid log10(0) -> inf plus a divide-by-zero warning.
    # (The raw ks_test.statistic was noted as an alternative metric.)
    p_value = max(ks_test.pvalue, np.finfo(float).tiny)
    return -np.log10(p_value)

def _goodness_row(frame):
    """Return [pooled score, score of each style feature] for the rows of ``frame``."""
    pooled = goodness_of_fit_metric(frame[feat_names].values.ravel())
    return [pooled] + [goodness_of_fit_metric(frame[feat]) for feat in feat_names]


# First row: all digits together; then one row per digit class.
row_labels = ['all digits'] + [f'digit {digit}' for digit in range(n_classes)]
feat_wise_logps = [_goodness_row(df)]
feat_wise_logps += [_goodness_row(df[df['digit'] == digit]) for digit in range(n_classes)]
pd.DataFrame(feat_wise_logps, columns=['all features'] + feat_names, index=row_labels)

### Generate New Images for Random Style Vectors (fixed per row)

In [None]:
# Decode every digit class with a random style vector. Each row of the figure
# uses ONE style vector shared by all digits in that row, so a row shows the
# same style rendered as each digit.
n_rows = 4
autoencoder.eval()
# One-hot class codes, one per digit; identical for every row, so built once.
classes_batch = np.eye(n_classes, dtype=np.float32)
fig, axs = plt.subplots(n_rows, n_classes, figsize=(16, 6.35))
for i in range(n_rows):
    # Draw a single style vector and repeat it for every class in this row.
    # NOTE(review): styles are sampled with scale=1 while goodness_of_fit_metric
    # tests against norm_scale=2 by default -- confirm which matches the training prior.
    style = np.random.normal(loc=0, scale=1, size=(1, style_dim)).astype(np.float32)
    style_batch = np.repeat(style, n_classes, axis=0)
    # The decoder input is the class code followed by the style vector.
    encoded_batch = np.concatenate((classes_batch, style_batch), axis=1)
    with torch.no_grad():  # inference only: skip building the autograd graph
        decoded_batch = autoencoder.decoder(torch.tensor(encoded_batch))

    for j, decoded in enumerate(decoded_batch):
        axs[i, j].set_axis_off()
        # Index 0 selects the first (only) channel of each generated image for imshow.
        axs[i, j].imshow(decoded[0], cmap="viridis")
fig.tight_layout(pad=0, h_pad=1)