In [None]:
import os
import torch
import torchvision as tv
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.widgets import Slider
import seaborn as sns
import plotly.graph_objects as go

from utils.data_loaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint, load_glow_from_checkpoint

from glow_regression import glow_regression

import warnings
warnings.simplefilter("ignore", UserWarning)

#%matplotlib notebook
#matplotlib.use("nbagg")

In [None]:
def plot_100_random_samples(model, title):
    """Draw 100 samples from ``model`` and show them tiled as one 10x10 image grid.

    Sample values are clipped to [0, 1] before tiling, so all patches are shown
    on the same absolute intensity scale.

    Args:
        model: object exposing ``sample(num_samples=..., y=...)`` that returns a
            ``(samples, _)`` pair (presumably a normflows-style GLOW model — confirm).
        title: figure super-title.
    """
    n_rows = n_cols = 10
    with torch.no_grad():
        samples, _ = model.sample(num_samples=n_rows * n_cols, y=None)
        clipped = samples.clamp(0, 1)
        grid = tv.utils.make_grid(clipped, nrow=n_cols)
        # make_grid yields a (C, H, W) tensor; imshow expects (H, W, C).
        grid_hwc = np.transpose(grid.cpu().numpy(), (1, 2, 0))

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(grid_hwc)
        ax.grid(False)
        ax.set_yticks([])
        ax.set_xticks([])
        fig.suptitle(title)
        fig.tight_layout()
        plt.show()
        
def plot_100_random_samples_unnormalized(model, title):
    """Draw 100 samples from ``model`` and show each in its own subplot (10x10 grid).

    Unlike ``plot_100_random_samples``, no clamping is applied. Each patch is
    drawn with ``imshow``'s default colour scaling, which stretches every patch
    to its own min/max independently (normalization is per patch).

    Args:
        model: object exposing ``sample(num_samples=..., y=...)`` that returns a
            ``(samples, _)`` pair. Each sample is squeezed to 2-D for display
            (assumes single-channel patches — TODO confirm).
        title: figure super-title.
    """
    n_rows, n_cols = 10, 10
    with torch.no_grad():
        x, _ = model.sample(num_samples=n_rows * n_cols, y=None)
    # ``.numpy()`` only works on host tensors. Move to CPU once (a no-op when
    # already there), matching plot_100_random_samples.
    patches = x.detach().cpu()

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 10))
    for i in range(n_rows):
        for j in range(n_cols):
            a = ax[i, j]
            a.imshow(patches[n_cols * i + j].squeeze().numpy(), cmap='gray')
            a.set_xticks([])
            a.set_yticks([])
            # Disable the grid per axes: plt.grid(False) would only reach the
            # current (last) subplot.
            a.grid(False)
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()
        
def plot_patch_at_zero(model, description):
    """Push an all-zeros latent through ``model`` and display the resulting patch.

    Args:
        model: flow exposing ``q0`` (an iterable of base distributions, each with
            a ``shape``) and ``forward_and_log_det(z_list)`` returning
            ``(output, log_det)``. Presumably this is the latent-to-image direction
            of a normflows-style multiscale flow — confirm against the loader.
        description: text appended to the figure title.
    """
    # Inference only: avoid building an autograd graph through the whole flow.
    with torch.no_grad():
        # One zero tensor per base distribution, with a leading batch dim of 1.
        z_zero = [torch.zeros(q.shape).unsqueeze(0) for q in model.q0]
        zero_patch, _ = model.forward_and_log_det(z_zero)

    # Use a fresh figure so the patch never lands on a previously active axes.
    plt.figure()
    plt.imshow(zero_patch.squeeze().cpu().numpy(), cmap='gray')
    plt.title(f'Patch at z = zero\n{description}')
    plt.tight_layout()
    plt.show()

In [None]:
# Run configuration.
device = 'cpu'
IMG_SIZE = 8

# Checkpoint locations; they are loaded from the parent of the working directory ('../').
model_path_128_train_img = 'trained_models/glow/2023-11-30_13:26:30/checkpoint_100000'
model_path_single_img = 'trained_models/glow/2023-12-01_09:01:05/checkpoint_12307'

# GLOW trained on 128 images, then GLOW trained on a single image.
model_128 = load_glow_from_checkpoint('../' + model_path_128_train_img, arch='lodopab')
model_single_img = load_glow_from_checkpoint('../' + model_path_single_img, arch='lodopab')

In [None]:
# --- Show random patches ---
# Samples are clamped to [0, 1] and tiled into one grid before display.
plot_100_random_samples(model_128, 'Random patches\nGLOW trained on 128 images for 100000 iterations')
plot_100_random_samples(model_single_img, 'Random patches\nGLOW trained on single image for 12307 iterations')

In [None]:
# --- Random patches, each autoscaled independently ---
# No clamping here. Every patch gets its own subplot and imshow's default colour
# scaling, so each patch is stretched to its own min/max (per patch, not per batch).
plot_100_random_samples_unnormalized(model_128, 'Random patches\nGLOW trained on 128 images for 100000 iterations')
plot_100_random_samples_unnormalized(model_single_img, 'Random patches\nGLOW trained on single image for 12307 iterations')

In [None]:
# Show the patch each model produces from the all-zeros latent.
plot_patch_at_zero(model_128, 'Trained on 128 images')
plot_patch_at_zero(model_single_img, 'Trained on single image')