In [None]:
import os
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.widgets import Slider
import seaborn as sns
import plotly.graph_objects as go

from utils.dataLoaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint, load_glow_from_checkpoint

from glow_regression import glow_regression

import warnings
# Silence every UserWarning for the rest of the kernel session.
# NOTE(review): this filter is global and also hides warnings that may matter;
# consider scoping it with `warnings.catch_warnings()` or a narrower `message=` filter.
warnings.simplefilter("ignore", UserWarning)

# Interactive backend: the Slider-based plotting helpers below need live,
# event-handling figures. NOTE(review): the `notebook` (nbagg) backend targets the
# classic Notebook UI; JupyterLab / Notebook 7 usually need `%matplotlib widget`
# (ipympl) instead — confirm for your environment.
%matplotlib notebook
# Script equivalent of the magic above, kept for reference:
#matplotlib.use("nbagg")

In [None]:
def visualize_regression(loss, images_np, tar_np):
    """Interactively browse one GLOW regression run with an iteration slider.

    Left panel: log(loss) per iteration, with a black marker and a value label
    at the iteration selected by the slider. Right panels: the approximation at
    that iteration next to the target image.

    Args:
        loss: per-iteration loss values; passed through ``np.log``, so they
            should be positive.
        images_np: indexable sequence of images, one per iteration;
            ``images_np[i]`` is shown at slider position ``i``.
        tar_np: target image, drawn once.

    Raises:
        ValueError: if ``loss`` is empty (there is nothing to scrub through).
    """
    n_iterations = len(loss)
    if n_iterations == 0:
        raise ValueError('loss history is empty; nothing to visualize')
    log_loss = np.log(loss)

    fig = plt.figure(figsize=(9, 5))
    loss_ax = plt.subplot(1, 2, 1)
    approx_ax = plt.subplot(3, 4, 3)
    target_ax = plt.subplot(3, 4, 4)

    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # valstep=1 with closedmax=False: integer positions 0 .. n_iterations - 1.
    iter_slider = Slider(ax=slider_ax, label='Iteration', orientation='horizontal',
                         valinit=0, valmin=0, valmax=n_iterations, valstep=1, closedmax=False)

    # The loss curve is static; only the marker and its value label follow the slider.
    sns.lineplot(x=np.arange(n_iterations), y=log_loss, ax=loss_ax)
    marker, = loss_ax.plot([0], [log_loss[0]], 'ko')
    value_label = loss_ax.annotate(f'   {log_loss[0]:.2f}\n',
                                   xy=(0, log_loss[0]), xytext=(0, log_loss[0]))
    loss_ax.set_title('log(MSE) over iterations')
    loss_ax.set_xlabel('Iteration')
    loss_ax.set_ylabel('log(MSE)')

    # Keep the AxesImage handle so updates replace pixels instead of stacking
    # a new image on the axes at every slider move.
    approx_img = approx_ax.imshow(images_np[0], cmap='gray')
    target_ax.imshow(tar_np, cmap='gray')
    approx_ax.set_title('Approximation')
    target_ax.set_title('Target')
    approx_ax.grid(False)
    target_ax.grid(False)
    fig.subplots_adjust(bottom=0.25)

    def iter_update(val):
        # valstep=1 keeps val integer-valued; cast defensively in case it is a float.
        idx = int(val)
        y = log_loss[idx]
        marker.set_data([idx], [y])
        value_label.set_text(f'   {y:.2f}\n')
        value_label.xy = (idx, y)
        value_label.xyann = (idx, y)
        approx_img.set_data(images_np[idx])
        # A fresh imshow would rescale the grey levels to each frame; keep that behaviour.
        approx_img.autoscale()

    iter_slider.on_changed(iter_update)
    # Matplotlib widgets must be kept referenced to stay responsive; otherwise the
    # slider can be garbage-collected once this function returns. Parking it on the
    # figure ties its lifetime to the figure's.
    fig.iter_slider = iter_slider

    plt.show()

In [None]:
def visualize_three_regression(losses, images_np, tar_np, tags):
    """Interactively compare several GLOW regression runs on one shared slider.

    Despite the name, any number of runs works (one row per run); the notebook
    calls it with three. Left panel: one log(loss) curve per run, each with a
    black marker and value label at the slider's iteration. Right side: per run,
    the approximation at that iteration next to the shared target.

    Args:
        losses: per-run loss histories; values are passed through ``np.log``.
        images_np: per-run image sequences; ``images_np[r][i]`` is run ``r`` at
            iteration ``i``.
        tar_np: target image, repeated in every row.
        tags: one label per run, used for the legend and the row's x-label.

    Raises:
        ValueError: if there are no runs or any loss history is empty.
    """
    n_runs = len(losses)
    if n_runs == 0:
        raise ValueError('no runs to visualize')
    # One slider drives every run, so it may only reach indices valid for all of them.
    n_iterations = min(len(run_loss) for run_loss in losses)
    if n_iterations == 0:
        raise ValueError('at least one loss history is empty; nothing to visualize')
    log_losses = [np.log(run_loss) for run_loss in losses]

    fig = plt.figure(figsize=(9, 5))
    loss_ax = plt.subplot(1, 2, 1)
    # Right half of an (n_runs x 4) grid: column 3 = approximation, column 4 = target.
    approx_axes = []
    target_axes = []
    for r in range(n_runs):
        approx_axes.append(plt.subplot(n_runs, 4, 4 * r + 3))
        target_axes.append(plt.subplot(n_runs, 4, 4 * r + 4))

    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # valstep=1 with closedmax=False: integer positions 0 .. n_iterations - 1.
    iter_slider = Slider(ax=slider_ax, label='Iteration', orientation='horizontal',
                         valinit=0, valmin=0, valmax=n_iterations, valstep=1, closedmax=False)

    # Curves are static; per-run marker and value label follow the slider.
    markers = []
    value_labels = []
    for r, log_loss in enumerate(log_losses):
        sns.lineplot(x=np.arange(len(log_loss)), y=log_loss, ax=loss_ax,
                     label=tags[r], legend='brief')
        marker, = loss_ax.plot([0], [log_loss[0]], 'ko')
        markers.append(marker)
        value_labels.append(loss_ax.annotate(f'   {log_loss[0]:.2f}\n',
                                             xy=(0, log_loss[0]), xytext=(0, log_loss[0])))
    loss_ax.legend()
    loss_ax.set_title('log(MSE) over iterations')
    loss_ax.set_xlabel('Iteration')
    loss_ax.set_ylabel('log(MSE)')

    # Keep AxesImage handles so updates replace pixels instead of stacking images.
    approx_imgs = []
    for r in range(n_runs):
        approx_imgs.append(approx_axes[r].imshow(images_np[r][0], cmap='gray'))
        target_axes[r].imshow(tar_np, cmap='gray')
        approx_axes[r].set_xlabel(tags[r])
    approx_axes[0].set_title('Approximation')
    target_axes[0].set_title('Target')
    for img_ax in approx_axes + target_axes:
        img_ax.grid(False)
        # Hide X and Y axes tick marks
        img_ax.set_xticks([])
        img_ax.set_yticks([])

    fig.subplots_adjust(bottom=0.25)

    def iter_update(val):
        # valstep=1 keeps val integer-valued; cast defensively in case it is a float.
        idx = int(val)
        for log_loss, marker, value_label in zip(log_losses, markers, value_labels):
            y = log_loss[idx]
            marker.set_data([idx], [y])
            value_label.set_text(f'   {y:.2f}\n')
            value_label.xy = (idx, y)
            value_label.xyann = (idx, y)
        for run_images, approx_img in zip(images_np, approx_imgs):
            approx_img.set_data(run_images[idx])
            # A fresh imshow would rescale grey levels per frame; keep that behaviour.
            approx_img.autoscale()

    iter_slider.on_changed(iter_update)
    # Matplotlib widgets must be kept referenced to stay responsive; otherwise the
    # slider can be garbage-collected once this function returns.
    fig.iter_slider = iter_slider

    plt.show()

In [None]:
device = 'cpu'  # NOTE(review): not referenced by any cell shown in this notebook

IMG_SIZE = 16
# The test split is pulled as a single batch (`next(iter(loader))` below); sharing one
# constant for batch_size and num_examples keeps that true if either is changed.
N_TEST_EXAMPLES = 10000

model_path = 'trained_models/glow/2023-11-23_12:25:04/checkpoint_60000'
model = load_glow_from_checkpoint(f'../{model_path}')

test_dataset, loader = get_rotated_mnist_dataloader(root='..',
                                                    batch_size=N_TEST_EXAMPLES,
                                                    shuffle=False,
                                                    one_hot_encode=False,
                                                    num_examples=N_TEST_EXAMPLES,
                                                    num_rotations=0,
                                                    img_size=IMG_SIZE,
                                                    train=False,
                                                    single_class=None,
                                                    glow=True)
# Presumably the full split, since batch_size == num_examples.
all_targets, labels = next(iter(loader))

In [None]:
# Noise-free regression on a single test example.
target_idx = 404
target, label = all_targets[target_idx], labels[target_idx]

# Optimisation hyper-parameters passed to glow_regression.
n_iter, lr, wd = 200, 1e-1, 0.005
scheduler_step_size = None

z_final, images_no_noise, losses_no_noise = glow_regression(
    model, target, label, n_iter,
    lr=lr, wd=wd, scheduler_step_size=scheduler_step_size,
)

In [None]:
# Scrub through the noise-free run; the target shown is the un-noised example.
visualize_regression(loss=losses_no_noise,
                     images_np=images_no_noise,
                     tar_np=target.squeeze())

In [None]:
# NOISY REGRESSION: same example as the noise-free run (reuses `target_idx`),
# but the regression target is corrupted with Gaussian noise.
torch.manual_seed(0)  # reproducible noise draw (and any torch-RNG use inside glow_regression)
target = all_targets[target_idx]
noise_strength = 0.2
# NOTE(review): the +0.5 shifts the noise mean to 0.5 * noise_strength, i.e. the target
# is also brightened on average — confirm this bias is intended.
noisy_target = target + noise_strength * (torch.randn(target.shape) + 0.5)
label = labels[target_idx]

n_iter = 200
lr = 1e-1
wd = 0.005
scheduler_step_size = None

z_final_noisy, images_noisy, losses_noisy = glow_regression(model, noisy_target, label, n_iter, lr=lr, wd=wd, scheduler_step_size=scheduler_step_size)

In [None]:
# Same viewer for the noisy run; the noisy image is what is shown as "Target".
visualize_regression(loss=losses_noisy,
                     images_np=images_noisy,
                     tar_np=noisy_target.squeeze())

In [None]:
# Training-loss history stored in the same checkpoint file the model was loaded from.
# NOTE(review): torch.load unpickles the file — only use it on trusted checkpoints.
loss_hist = torch.load(f'../{model_path}', map_location='cpu')['loss_hist']

fig, ax = plt.subplots()
ax.plot(loss_hist)
ax.set_title('training loss over iterations')
ax.set_xlabel('Step')
ax.set_ylabel('Training loss')
plt.show()

In [None]:
# Compare regression settings on the same noisy target: the zero_start flag and
# weight decay are varied per run; lr is shared.
torch.manual_seed(0)  # reproducible noise draw (and any torch-RNG use inside glow_regression)
target_idx = 404
target = all_targets[target_idx]
noise_strength = 0.2
noisy_target = target + noise_strength * (torch.randn(target.shape) + 0.5)
label = labels[target_idx]

n_iter = 200
# One entry per run; the three lists are read in lock-step.
lrs = [1e-1, 1e-1, 1e-1]
wds = [0, 0, 0.005]
zero_starts = [True, False, False]
scheduler_step_size = None

loss_list = []
image_list = []
tags = []

for lr, wd, zero_start in zip(lrs, wds, zero_starts):
    _z, run_images, run_losses = glow_regression(model, noisy_target, label, n_iter,
                                                 lr=lr, wd=wd,
                                                 scheduler_step_size=scheduler_step_size,
                                                 zero_start=zero_start)
    tags.append(f'Zero start={zero_start}, wd={wd}')
    loss_list.append(run_losses)
    image_list.append(run_images)
    


In [None]:
# Side-by-side comparison of the runs from the previous cell on one shared slider.
visualize_three_regression(losses=loss_list,
                           images_np=image_list,
                           tar_np=noisy_target.squeeze(),
                           tags=tags)