In [None]:
import os
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.widgets import Slider
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px

from utils.dataLoaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint

from singleRegression import get_trace_single_regression, get_start_position

%matplotlib notebook
#matplotlib.use("nbagg")

In [None]:
def visualize_regression(loss, images_np, tar_np):
    """Interactively browse a single regression run.

    The left panel shows log(loss) over iterations, with a marker and value
    label at the selected iteration. The right panels show the approximation
    at that iteration next to the target. A slider under the figure selects
    the iteration.

    Args:
        loss: Sequence of loss values, one per regression iteration. Values
            should be > 0 because their natural log is plotted.
        images_np: Indexable sequence of 2-D images; ``images_np[i]`` is the
            approximation at regression step ``i``.
        tar_np: 2-D target image.

    Returns:
        The ``matplotlib.widgets.Slider`` that drives the figure. Keep a
        reference to it: matplotlib widgets stop responding once they are
        garbage-collected.

    Raises:
        ValueError: If there is no iteration to show (empty ``loss`` or
            ``images_np``).
    """
    log_loss = np.log(loss)
    # Bound the slider by the shorter sequence so no update can index past either one.
    n_iterations = min(len(log_loss), len(images_np))
    if n_iterations == 0:
        raise ValueError('visualize_regression needs at least one iteration to display')

    fig = plt.figure(figsize=(9, 5))
    loss_ax = fig.add_subplot(1, 2, 1)
    approx_ax = fig.add_subplot(3, 4, 3)
    target_ax = fig.add_subplot(3, 4, 4)
    fig.subplots_adjust(bottom=0.25)

    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # valmax is the last valid index and the upper end is closed (default), so
    # dragging to the right edge lands on the final iteration instead of being ignored.
    iter_slider = Slider(ax=slider_ax, label='Iteration', orientation='horizontal',
                         valinit=0, valmin=0, valmax=n_iterations - 1, valstep=1)

    def draw_loss(idx):
        """Redraw the log-loss curve with a marker and value label at ``idx``."""
        loss_ax.cla()
        sns.lineplot(log_loss, ax=loss_ax)
        value = log_loss[idx]
        loss_ax.plot(idx, value, 'ko')
        loss_ax.annotate(f'   {value:.2f}\n', xy=(idx, value))
        loss_ax.set_title('log(MSE) over iterations')
        loss_ax.set_xlabel('Iteration')
        loss_ax.set_ylabel('log(MSE)')

    draw_loss(0)
    approx_im = approx_ax.imshow(images_np[0], cmap='gray')
    target_ax.imshow(tar_np, cmap='gray')
    approx_ax.set_title('Approximation')
    target_ax.set_title('Target')
    approx_ax.grid(False)
    target_ax.grid(False)

    def iter_update(val):
        # The slider value may arrive as a float; sequence indexing needs an int.
        idx = int(round(val))
        draw_loss(idx)
        # Reuse the existing image artist rather than stacking a new imshow per
        # update; autoscale() re-normalizes the gray range like a fresh imshow would.
        approx_im.set_data(images_np[idx])
        approx_im.autoscale()
        fig.canvas.draw_idle()

    iter_slider.on_changed(iter_update)

    plt.show()
    return iter_slider

In [None]:
def visualize_three_regression(losses, images_np, tar_np, archs=('p4', 'z2', 'vanilla')):
    """Interactively compare the regression runs of several models.

    The left panel shows the log(loss) curve of every run, each with a marker
    and value label at the selected iteration. On the right, one row per run
    shows its approximation at that iteration next to the target. A slider
    under the figure selects the iteration.

    Args:
        losses: Sequence with one loss sequence per run; values should be > 0
            because their natural log is plotted.
        images_np: Sequence with one image sequence per run;
            ``images_np[r][i]`` is run ``r``'s approximation at step ``i``.
        tar_np: 2-D target image, shown next to every approximation.
        archs: Labels for the runs (legend entries and row labels), in the
            same order as ``losses``. Defaults to the three architectures
            compared in this notebook.

    Returns:
        The ``matplotlib.widgets.Slider`` that drives the figure. Keep a
        reference to it: matplotlib widgets stop responding once they are
        garbage-collected.

    Raises:
        ValueError: If there are no runs, if ``images_np`` and ``losses``
            differ in length, if there are fewer labels than runs, or if some
            run has no iteration to show.
    """
    n_runs = len(losses)
    if n_runs == 0:
        raise ValueError('visualize_three_regression needs at least one run')
    if len(images_np) != n_runs:
        raise ValueError(f'got {n_runs} loss curves but {len(images_np)} image sequences')
    if len(archs) < n_runs:
        raise ValueError(f'got {n_runs} runs but only {len(archs)} labels in archs')

    log_losses = [np.log(run_loss) for run_loss in losses]
    # Bound the slider by the shortest sequence so no update can index past any run.
    n_iterations = min(min(len(run_log) for run_log in log_losses),
                       min(len(run_images) for run_images in images_np))
    if n_iterations == 0:
        raise ValueError('every run needs at least one iteration to display')

    fig = plt.figure(figsize=(9, 5))
    loss_ax = fig.add_subplot(1, 2, 1)
    # One row per run: approximation in column 3, target in column 4 of a 4-column grid.
    approx_axes, target_axes = [], []
    for row in range(n_runs):
        approx_axes.append(fig.add_subplot(n_runs, 4, 4 * row + 3))
        target_axes.append(fig.add_subplot(n_runs, 4, 4 * row + 4))
    fig.subplots_adjust(bottom=0.25)

    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # valmax is the last valid index and the upper end is closed (default), so
    # dragging to the right edge lands on the final iteration instead of being ignored.
    iter_slider = Slider(ax=slider_ax, label='Iteration', orientation='horizontal',
                         valinit=0, valmin=0, valmax=n_iterations - 1, valstep=1)

    def draw_loss(idx):
        """Redraw all log-loss curves with markers and value labels at ``idx``."""
        loss_ax.cla()
        for arch, run_log in zip(archs, log_losses):
            sns.lineplot(run_log, ax=loss_ax, label=arch, legend='brief')
            value = run_log[idx]
            loss_ax.plot(idx, value, 'ko')
            loss_ax.annotate(f'   {value:.2f}\n', xy=(idx, value))
        loss_ax.set_title('log(MSE) over iterations')
        loss_ax.set_xlabel('Iteration')
        loss_ax.set_ylabel('log(MSE)')

    draw_loss(0)

    approx_images = []
    for row, (arch, approx_ax, target_ax) in enumerate(zip(archs, approx_axes, target_axes)):
        approx_images.append(approx_ax.imshow(images_np[row][0], cmap='gray'))
        target_ax.imshow(tar_np, cmap='gray')
        approx_ax.set_ylabel(arch)
        for image_ax in (approx_ax, target_ax):
            image_ax.grid(False)
            # Hide X and Y axes tick marks
            image_ax.set_xticks([])
            image_ax.set_yticks([])
    approx_axes[0].set_title('Approximation')
    target_axes[0].set_title('Target')

    def iter_update(val):
        # The slider value may arrive as a float; sequence indexing needs an int.
        idx = int(round(val))
        draw_loss(idx)
        # Reuse the existing image artists rather than stacking new imshows per
        # update; autoscale() re-normalizes the gray range like a fresh imshow would.
        for approx_im, run_images in zip(approx_images, images_np):
            approx_im.set_data(run_images[idx])
            approx_im.autoscale()
        fig.canvas.draw_idle()

    iter_slider.on_changed(iter_update)

    plt.show()
    return iter_slider

Load the trained generator models and the complete test dataset.

In [None]:
device = 'cpu'

IMG_SIZE = 28

# Checkpoints to compare; archs[i] is the architecture label of paths[i].
paths = [
    '../trained_models/p4_rot_mnist/2023-10-31_14:16:50/checkpoint_20000',
    '../trained_models/z2_rot_mnist/2023-10-31_12:34:55/checkpoint_20000',
    '../trained_models/vanilla_small/2023-10-31_17:13:13/checkpoint_20000'
]
archs = ['p4', 'z2', 'vanilla']
assert len(paths) == len(archs), 'every checkpoint path needs exactly one architecture label'

generators = []
latent_dims = []

for p in paths:
    gen, _ = load_gen_disc_from_checkpoint(p, device, print_to_console=True)
    generators.append(gen.eval())
    checkpoint = load_checkpoint(p)
    print_checkpoint(checkpoint)
    latent_dims.append(checkpoint['latent_dim'])

# Every generator is later driven with the single LATENT_DIM, so all checkpoints
# must agree on it (previously only the last checkpoint's value was used, silently).
assert all(dim == latent_dims[0] for dim in latent_dims), \
    f'checkpoints disagree on latent_dim: {dict(zip(archs, latent_dims))}'
LATENT_DIM = latent_dims[0]

project_root = os.getcwd()
# batch_size equals num_examples, so the first batch holds the whole loaded test set.
test_dataset, loader = get_rotated_mnist_dataloader(root='..',
                                                    batch_size=10000,
                                                    shuffle=False,
                                                    one_hot_encode=False,
                                                    num_examples=10000,
                                                    num_rotations=0,
                                                    train=False,
                                                    single_class=None)
all_targets, labels = next(iter(loader))

Choose the target by specifying its index in the test set, then run the regression for every loaded model.

In [None]:
target_idx = 402
target = all_targets[target_idx]
label = labels[target_idx].item()

n_iterations = 1000
n_start_pos = 128  # passed as get_start_position's n_start_pos
lr = 0.1
scheduler_step_size = 100
noise_amplitude = 0.0
SEED = 0

# Seed torch so the target noise (and any random sampling the regression
# helpers may do) is reproducible under Restart & Run All.
torch.manual_seed(SEED)
noisy_target = target + noise_amplitude * torch.randn_like(target)

losses, reg_images = [], []

for i, gen in enumerate(generators):
    print(f'gen_arch: {archs[i]}')
    start_pos, _ = get_start_position(gen, LATENT_DIM, noisy_target, label, n_start_pos=n_start_pos)
    # trace_results: [x_coord, y_coord, losses]; (x_coords, y_coords) are the first two
    #     coords from a (possibly) higher dimensional input vector.
    # images: [image], list of images as np.arrays. Every image is an approximation
    #     at the corresponding regression step.
    trace_results, images = get_trace_single_regression(gen,
                                                        start_pos,
                                                        noisy_target,
                                                        label,
                                                        n_iter=n_iterations,
                                                        lr=lr,
                                                        scheduler_step_size=scheduler_step_size,
                                                        ret_intermediate_images=True)
    losses.append(trace_results[-1])
    reg_images.append(images)

In [None]:
# Close the currently active figure (if any) before drawing the comparison figure.
plt.close()
target_image = noisy_target.squeeze().numpy()
visualize_three_regression(losses, reg_images, target_image)