In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from mnist_style.models import Encoder, Decoder
from mnist_style.persistence import load_models, save_models

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from functools import partial

# Apply seaborn's default theme (its documented defaults, spelled out) to every later figure.
sns.set_theme(context="notebook", style="darkgrid", palette="deep")

In [None]:
# Build the encoder/decoder pair for an 8-dimensional latent space, then restore
# their weights from "./pt-aae" (presumably written earlier by `save_models`).
latent_dim = 8
encoder, decoder = Encoder(latent_dim), Decoder(latent_dim)

models_to_restore = {"encoder": encoder, "decoder": decoder}
load_models(models_to_restore, "./pt-aae")

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean=0.5, std=0.5 then maps them to [-1, 1].
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0.5,), std=(0.5,))
transform = transforms.Compose([to_tensor, normalize])

# MNIST test split; download=False, so the data must already exist under ./data.
test_dataset = MNIST(root='./data', train=False, download=False, transform=transform)

In [None]:
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)

# A single random batch is enough for the pairwise latent-feature plot.
batch, labels = next(iter(test_dataloader))

encoder.eval()
with torch.no_grad():  # inference only: skip building an autograd graph
    features = encoder(batch).numpy()

# One column per latent feature: 'fa', 'fb', ... plus the true digit as hue.
columns = ['f' + chr(i + ord('a')) for i in range(features.shape[1])]
df = pd.DataFrame(features, columns=columns).assign(digit=labels.numpy())
g = sns.pairplot(df, hue="digit", palette="tab10")  # "hls" is an alternative palette

In [None]:
# Show n_images random test digits (top row) next to their autoencoder reconstructions (bottom row).
n_images = 10
test_dataloader = DataLoader(test_dataset, batch_size=n_images, shuffle=True)

encoder.eval()
decoder.eval()
batch, _ = next(iter(test_dataloader))
with torch.no_grad():  # inference only: skip building an autograd graph
    decoded_batch = decoder(encoder(batch)).numpy()

# squeeze=False keeps `axs` 2-D even when n_images == 1.
fig, axs = plt.subplots(2, n_images, figsize=(1.6 * n_images, 3), squeeze=False)
for i, (image, decoded) in enumerate(zip(batch.numpy(), decoded_batch)):
    for ax, img in ((axs[0, i], image), (axs[1, i], decoded)):
        ax.set_axis_off()
        ax.imshow(img[0], cmap="viridis")  # [0] selects the first (channel) axis
fig.tight_layout(pad=0, h_pad=1)

In [None]:
# Seeded so the sanity-check sample drawn from `rng` (and its KS statistics) is reproducible.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
# Sanity check: a sample genuinely drawn from N(0, norm_scale^2) should fit that CDF well,
# i.e. give a large KS p-value and a small -log10(p).
norm_scale = 2
cdf = partial(stats.norm.cdf, loc=0, scale=norm_scale)
sample = stats.norm.rvs(size=1000, scale=norm_scale, random_state=rng)
ks_result = stats.kstest(sample, cdf)  # computed once and reused below
print(ks_result)
print(-np.log10(ks_result.pvalue))

In [None]:
# Histogram of the sanity-check sample; the trailing `;` suppresses the Axes repr.
ax = sns.histplot(sample)
ax.set(title="Sample used for the KS sanity check", xlabel="value");

In [None]:
# KS statistics are order-invariant, so there is no need to shuffle this full pass.
test_dataloader = DataLoader(test_dataset, batch_size=400, shuffle=False)


def neg_log_fit_goodness(samples, norm_scale=2):
    """Return -log10 of the one-sample KS-test p-value of `samples` against N(0, norm_scale**2).

    Larger values mean stronger evidence that `samples` do NOT follow the
    reference normal distribution. The result is `inf` if the p-value
    underflows to 0.0 (possible for large samples with a poor fit).
    """
    cdf = partial(stats.norm.cdf, loc=0, scale=norm_scale)
    return -np.log10(stats.ks_1samp(samples, cdf).pvalue)


def _goodness_row(frame, columns):
    """Goodness of fit for all `columns` pooled together, followed by each column on its own."""
    pooled = neg_log_fit_goodness(frame[columns].to_numpy().ravel())
    return [pooled] + [neg_log_fit_goodness(frame[col]) for col in columns]


# Encode the whole test split once (inference only, so no autograd graph).
encoder.eval()
feature_batches, label_batches = [], []
with torch.no_grad():
    for batch, labels in test_dataloader:
        feature_batches.append(encoder(batch).numpy())
        label_batches.append(labels.numpy())

test_features = np.concatenate(feature_batches)
# Column names come from this cell's own encoder output, not from a `features`
# variable left over from an earlier cell.
feat_names = ['feat ' + chr(ord('a') + i) for i in range(test_features.shape[1])]
df = pd.DataFrame(test_features, columns=feat_names).assign(digit=np.concatenate(label_batches))
print(len(df.index))
assert len(df.index) == len(test_dataset), "every test sample should be encoded exactly once"

# Row 0: all digits; rows 1..10: one row per digit class.
feat_wise_logps = [_goodness_row(df, feat_names)]
for digit in range(10):
    feat_wise_logps.append(_goodness_row(df[df['digit'] == digit], feat_names))
pd.DataFrame(feat_wise_logps, columns=['all features'] + feat_names, index=['all digits'] + [f'digit {i}' for i in range(10)])