In [None]:
import os
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.widgets import Slider
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px

from utils.dataLoaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint

from singleRegression import get_trace_single_regression, get_start_position

%matplotlib notebook
#matplotlib.use("nbagg")

In [None]:
def plot_comparison(tar: torch.Tensor, approx: torch.Tensor, title: str):
    """Show a target image and its approximation side by side in one figure.

    Args:
        tar: Image tensor drawn in the left panel (titled 'Target'); callers in
            this notebook pass a single H x W channel.
        approx: Image tensor drawn in the right panel (titled 'Approximation').
        title: Figure-level title placed above both panels.

    The figure is built while interactive mode is temporarily switched off
    (``plt.ioff()``) and is displayed through an explicit ``plt.show()``.
    """
    def _as_array(t: torch.Tensor):
        # imshow needs a host-side array without autograd history.
        return t.detach().cpu().numpy()

    panels = ((tar, 'Target'), (approx, 'Approximation'))
    with plt.ioff():
        fig, axes = plt.subplots(1, 2)
        for panel_ax, (img, panel_title) in zip(axes, panels):
            panel_ax.imshow(_as_array(img), cmap='gray')
            panel_ax.set_title(panel_title)
        fig.suptitle(title)
        plt.show()

In [None]:
device = 'cpu'

IMG_SIZE = 28

checkpoint_path = '../trained_models/p4_rot_mnist/2023-11-09_18:14:22/checkpoint_20000'
#checkpoint_path = '../trained_models/z2_rot_mnist/2023-11-10_08:29:53/checkpoint_20000'
#checkpoint_path = '../trained_models/vanilla_small/2023-11-10_09:14:25/checkpoint_20000'
gen, _ = load_gen_disc_from_checkpoint(checkpoint_path, device=device, print_to_console=True)
checkpoint = load_checkpoint(checkpoint_path)
print_checkpoint(checkpoint)
LATENT_DIM = checkpoint['latent_dim']
gen_arch = checkpoint['gen_arch']
gen.eval()

# 2-D input grid: sample [-max_x_y, max_x_y) on both axes with roughly `grid_size`
# points in total; `grid_size` is then overwritten with the exact count produced.
max_x_y = 3
grid_size = 1000
inc = 2 * max_x_y / np.sqrt(grid_size)

x = np.arange(-max_x_y, max_x_y, inc)
y = np.arange(-max_x_y, max_x_y, inc)
grid_size = len(x) * len(y)
grid = np.meshgrid(x, y)
# Row k = j * len(y) + i of `inputs` holds (x[j], y[i]): x is the outer index.
inputs = torch.from_numpy(np.array(grid).T.reshape(-1, 2)).type(torch.float32).to(device)

project_root = os.getcwd()
test_dataset, loader = get_rotated_mnist_dataloader(root='..',
                                                    batch_size=1,
                                                    shuffle=False,
                                                    one_hot_encode=True,
                                                    num_examples=10000,
                                                    num_rotations=0,
                                                    train=False,
                                                    single_class=None)

loader_iterator = iter(loader)
target_images, input_labels = next(loader_iterator)
#target_images, input_labels = next(loader_iterator)
#target_images, input_labels = next(loader_iterator)
#target_images, input_labels = next(loader_iterator)
target_images = target_images.to(device)
# One copy of the (batch_size=1) label per grid point.
input_labels = input_labels.repeat(grid_size, 1).type(torch.float32).to(device)

# Inference only: running the generator under no_grad avoids building an
# autograd graph for all `grid_size` forward passes.
with torch.no_grad():
    out = gen(inputs, input_labels)
    # Per-sample MSE; expand_as broadcasts the single target image as a view
    # instead of materialising `grid_size` copies with repeat().
    loss = torch.mean(torch.nn.functional.mse_loss(out, target_images.expand_as(out), reduction='none'),
                      dim=(1, 2, 3))
# loss[j, i] is the MSE at input (x[j], y[i]), following the ordering of `inputs`.
loss = loss.detach().cpu().numpy().reshape(len(x), len(y))

In [None]:
'''
VISUALIZE GENERATOR OUTPUT FOR X, Y INPUT COORDINATES
'''
plt.ion()
fig, ax = plt.subplots()
# images[j, i] is the generator output at input (x[j], y[i]), same ordering as `inputs`.
images = out.squeeze().detach().cpu().numpy().reshape(len(x), len(y), IMG_SIZE, IMG_SIZE)
fig.subplots_adjust(left=0.25, bottom=0.25)
y_ax = fig.add_axes([0.1, 0.25, 0.0225, 0.63])
x_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
y_slider = Slider(ax=y_ax, label='y', orientation='vertical', valinit=0, valmin=-max_x_y, valmax=max_x_y, valstep=inc, closedmax=False)
x_slider = Slider(ax=x_ax, label='x', valinit=0, valmin=-max_x_y, valmax=max_x_y, valstep=inc, closedmax=False)
# Keep one image artist and swap its pixels on each slider change, rather than
# stacking a new imshow per event. The callback uses this dedicated handle instead
# of the global `ax`, which a later cell rebinds to a list of axes.
grid_image = ax.imshow(images[len(x)//2, len(y)//2], cmap='gray')
ax.grid(False)

def update(val):
    """Slider callback: show the generated image for the grid point nearest the sliders' (x, y)."""
    # Nearest grid index per axis, each against its own coordinate array; this is
    # robust to floating-point error in the slider value.
    x_idx = int(np.abs(x - x_slider.val).argmin())
    y_idx = int(np.abs(y - y_slider.val).argmin())
    grid_image.set_data(images[x_idx, y_idx])
    grid_image.autoscale()  # per-image contrast, as a fresh imshow call would give
    grid_image.axes.figure.canvas.draw_idle()

x_slider.on_changed(update)
y_slider.on_changed(update)

plt.show()

In [None]:
'''
3D SURFACE PLOT OF LOSS LANDSCAPE
Show potential start positions and regression path
'''
n_iter = 100
# Class index of the one-hot label of the current target.
label = int(torch.where(input_labels[0] == 1)[0])
# NOTE(review): the positional `2` presumably is the input dimensionality — TODO confirm.
start_vec, start_pos_data = get_start_position(gen, 2, target_images, label, n_start_pos=10)
#start_vec = torch.tensor([0,0])
trace_data, reg_images = get_trace_single_regression(gen, start_vec, target_images, label=label,
                                                     lr=0.01, n_iter=n_iter, ret_intermediate_images=True)

fig = go.Figure()
# loss[j, i] is the MSE at (x[j], y[i]); plotly's surface expects z[row = y][col = x],
# hence the transpose. This stays correct even if x and y grids ever differ.
fig.add_surface(x=x, y=y, z=loss.T, name='MSE')
fig.add_scatter3d(x=start_pos_data[0], y=start_pos_data[1], z=start_pos_data[2],
                  mode='markers', name='start candidates')
fig.add_scatter3d(x=trace_data[0], y=trace_data[1], z=trace_data[2], name='regression path')
fig.add_scatter3d(x=[trace_data[0][0]], y=[trace_data[1][0]], z=[trace_data[2][0]],
                  marker=dict(symbol='diamond'), name='start')
title = f'Loss Landscape: {gen_arch}'

fig.update_layout(title=title, autosize=False,
                  width=800, height=600,
                  margin=dict(l=65, r=50, b=65, t=90),
                  scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='MSE'),
                  )

fig.show()

In [None]:
fig = plt.figure(figsize=(9, 5))
n_iterations = len(trace_data[0])

# Layout: loss curve on the left half, approximation and target in the two right quarters.
loss_ax = fig.add_subplot(1, 2, 1)
approx_ax = fig.add_subplot(1, 4, 3)
target_ax = fig.add_subplot(1, 4, 4)
# Original names kept so code referring to them keeps working; the callbacks below
# use the dedicated names so a later rebinding of `ax` cannot break this figure.
a0, a1, a2 = loss_ax, approx_ax, target_ax
ax = [a0, a1, a2]

# NOTE(review): trace_data[-1] is taken to be the per-iteration MSE (the series
# plotted as z in the 3-D view) — TODO confirm.
log_loss = np.log(trace_data[-1])
slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
iter_slider = Slider(ax=slider_ax, label='Iteration', orientation='horizontal',
                     valinit=0, valmin=0, valmax=n_iterations, valstep=1, closedmax=False)


def _draw_loss_panel(idx):
    """Redraw the log(MSE) curve on `loss_ax`, marking and labelling iteration `idx`."""
    loss_ax.cla()
    sns.lineplot(log_loss, ax=loss_ax)
    loss_ax.plot(idx, log_loss[idx], 'ko')
    loss_ax.annotate(f'   {log_loss[idx]:.2f}\n', xy=(idx, log_loss[idx]), xytext=(idx, log_loss[idx]))
    loss_ax.set_title('log(MSE) over iterations')
    loss_ax.set_xlabel('Iteration')
    loss_ax.set_ylabel('log(MSE)')


# Initial state drawn exactly like every slider update (replaces a stray debug label).
_draw_loss_panel(0)
# Keep a handle on the approximation image so updates swap pixels instead of stacking images.
approx_image = approx_ax.imshow(reg_images[0], cmap='gray')
target_ax.imshow(target_images[0, 0].detach().cpu().numpy(), cmap='gray')
approx_ax.set_title('Approximation')
target_ax.set_title('Target')
approx_ax.grid(False)
target_ax.grid(False)
fig.subplots_adjust(bottom=0.25)


def iter_update(val):
    """Slider callback: show the loss marker and the approximation for iteration `val`."""
    idx = int(val)  # slider values may arrive as floats; they are used as indices here
    _draw_loss_panel(idx)
    approx_image.set_data(reg_images[idx])
    approx_image.autoscale()  # per-image contrast, matching what a fresh imshow would use


iter_slider.on_changed(iter_update)

plt.show()

In [None]:
# classic matplotlib version:
#fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
#surf = ax.plot_surface(grid[1], grid[0], loss, cmap=cm.coolwarm, linewidth=0, antialiased=False)
#plt.show()
# `go` is already imported in the first cell.
# loss[j, i] is the MSE at (x[j], y[i]); plotly's surface expects z[row = y][col = x],
# hence the transpose.
fig = go.Figure(data=[go.Surface(x=x, y=y, z=loss.T)])

fig.update_layout(title='Loss Landscape', autosize=False,
                  width=800, height=600,
                  margin=dict(l=65, r=50, b=65, t=90),
                  scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='MSE'))

fig.show()

In [None]:
'''
FIND AND PLOT BEST APPROXIMATION BASED ON LOSS LANDSCAPE
'''
# loss[j, i] is the MSE at (x[j], y[i]), so the argmin indexes x and y directly
# (no dependence on x and y happening to be identical arrays).
min_index = np.unravel_index(np.argmin(loss, axis=None), loss.shape)
min_x = x[min_index[0]]
min_y = y[min_index[1]]
min_coords = torch.tensor([min_x, min_y]).type(torch.float32).unsqueeze(0).to(device)
# Inference only: no autograd graph needed.
with torch.no_grad():
    best_approx = gen(min_coords, input_labels[0].unsqueeze(0))
    min_loss = torch.nn.functional.mse_loss(best_approx, target_images[0].unsqueeze(0))
plot_comparison(target_images[0, 0], best_approx[0, 0], f'best approximation\nx: {min_x:.2f}, y: {min_y:.2f}\nloss: {min_loss:.2f}')

In [None]:
# Sanity check: generator output and MSE at input coordinates (0, 0).
# Positive zeros so the title does not render as "-0.00".
sanity_check_input = torch.tensor([0.0, 0.0]).type(torch.float32).unsqueeze(0).to(device)
with torch.no_grad():
    o = gen(sanity_check_input, input_labels[0].unsqueeze(0))
    l = torch.nn.functional.mse_loss(o, target_images[0].unsqueeze(0))
print(l)
plot_comparison(target_images[0, 0], o[0, 0], f'x: {sanity_check_input[0, 0]:.2f}, y: {sanity_check_input[0, 1]:.2f}\nloss: {l:.3f}')