In [None]:
import os
import torch
import torchvision as tv
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.widgets import Slider
import seaborn as sns
import plotly.graph_objects as go

from utils.data_loaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint, load_glow_from_checkpoint

from glow_regression import glow_regression

import warnings
warnings.simplefilter("ignore", UserWarning)

%matplotlib notebook
#matplotlib.use("nbagg")

In [None]:
device = 'cpu'

IMG_SIZE = 16
num_classes = 10

# Seed torch's global RNG so every stochastic cell below (latent-mean estimation,
# random samples) produces the same output on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

model_path = 'trained_models/glow/2023-11-23_12:25:04/checkpoint_93750'
# Checkpoint paths are resolved relative to the parent of the working directory ('../').
model = load_glow_from_checkpoint(f'../{model_path}')
# The notebook only runs inference: put the model in eval mode and on the configured
# device so it matches the `.to(device)` inputs built in later cells.
# NOTE(review): assumes the loader returns a torch.nn.Module (it exposes q0 / sample /
# forward_and_log_det) — confirm against utils.checkpoints.
model = model.to(device).eval()

In [None]:
'''
DISPLAY MEAN IMAGE FOR EACH CLASS
'''
num_samples = 10000

# Estimate the per-class mean latent for every entry of model.q0 (presumably one
# class-conditional base distribution per multiscale level — confirm) by Monte-Carlo
# sampling. Only values are needed, so everything runs under no_grad instead of building
# an autograd graph over num_samples draws per class and per level.
z = []
with torch.no_grad():
    for q in model.q0:
        level_means = []
        for i in range(num_classes):
            label = torch.full((num_samples,), i, dtype=torch.long, device=device)
            # q(...) returns a tuple; element 0 is used as the sample batch
            # (presumably (samples, log_prob) — verify against the flow library).
            samples = q(num_samples, label)[0]
            level_means.append(samples.mean(dim=0, keepdim=True))
        z.append(torch.cat(level_means))

    # Map the mean latents through the flow to image space.
    x, _ = model.forward_and_log_det(z)

# Display the class means on a 2-row grid. Ceil the column count so an odd number of
# classes still fits, and use squeeze=False so `ax` is always 2-D (also for one column).
n_cols = (num_classes + 1) // 2
fig, ax = plt.subplots(2, n_cols, squeeze=False)
for i, a in enumerate(ax.flat):
    if i >= num_classes:
        a.axis('off')  # unused cell when num_classes is odd
        continue
    a.imshow(x[i].squeeze().cpu(), cmap='gray')
    a.grid(False)
    a.set_xticks([])
    a.set_yticks([])
fig.suptitle(f'GLOW\nMean per class\nEstimated on {num_samples} samples')

plt.show()

In [None]:
'''
SHOW RANDOM SAMPLES FOR EACH CLASS
'''
n_examples = 10

# Labels cycle through 0..num_classes-1, n_examples times. With nrow=num_classes in the
# grid below, each row holds one draw per class and each column shows a single class.
y = torch.arange(num_classes).repeat(n_examples).to(device)

# Only the sampling itself needs autograd disabled; plotting does not.
with torch.no_grad():
    x, _ = model.sample(y=y)

# Keep pixel values inside the [0, 1] range expected for float image display.
x_ = torch.clamp(x, 0, 1)

fig, ax = plt.subplots(figsize=(10, 10))
sample_grid = tv.utils.make_grid(x_, nrow=num_classes)
ax.imshow(sample_grid.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC for imshow
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
fig.suptitle('Random examples for each class')
plt.show()

In [None]:
'''
PLOT TRAINING LOSS
'''
# The training-loss history is read from the 'loss_hist' entry of the same checkpoint file
# the model was loaded from. torch.load deserializes via pickle, so only use trusted files.
loss_hist = torch.load(f'../{model_path}', map_location='cpu')['loss_hist']

fig, ax = plt.subplots()
ax.plot(loss_hist)
ax.set(
    title='Training Loss over Iterations',
    xlabel='Iterations',
    ylabel='Training Loss',
)
plt.show()