In [None]:
import datetime
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchmetrics.functional.image import peak_signal_noise_ratio

from tqdm import tqdm

from utils.dataLoaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint

In [None]:
def load_regression_results(result_path, model_root='..'):
    """
    Load a saved regression result, print its non-bulky settings and rebuild the generator.

    :param result_path: path of a file written with ``torch.save``; it is loaded onto the CPU
    :param model_root: directory that ``res['path_to_model']`` is relative to
        (default ``'..'``, matching the original hard-coded ``'../'`` prefix)
    :return: tuple ``(generator, latent_noise, snrs, losses, class_to_search)``;
        ``class_to_search`` is ``None`` when the result file has no such entry
    """
    # NOTE(review): torch.load unpickles the file -- only load result files you produced yourself.
    res = torch.load(result_path, map_location='cpu')
    generator, _ = load_gen_disc_from_checkpoint(f'{model_root}/{res["path_to_model"]}')
    latent_noise = res['latent_noise']
    snrs = res['snrs']
    losses = res['loss_per_batch']
    # previously this was the literal list ['class_to_search'], discarding the stored value;
    # .get keeps older result files without the key loadable
    class_to_search = res.get('class_to_search')

    print('REGRESSION DETAILS:')
    for key, value in res.items():
        # tensors and lists are the bulky per-example results; only print the settings
        if isinstance(value, (torch.Tensor, list)):
            continue
        # pad "key: ...." with dots to a fixed width of 28 characters
        label = key + ': ' + '.' * (28 - len(key) - 2)
        print(f'{label : <28} {value}')
    print('\n')
    return generator, latent_noise, snrs, losses, class_to_search


def plot_comparison(tar: torch.Tensor, approx: torch.Tensor, title: str):
    """
    Show a target image and its approximation side by side in one figure.

    :param tar: target image tensor (squeezed and moved to CPU before display)
    :param approx: approximated image tensor (squeezed and moved to CPU before display)
    :param title: super-title of the figure
    """
    _, (left, right) = plt.subplots(1, 2)
    panels = ((left, tar, 'Target'), (right, approx, 'Approximation'))
    for panel, image, label in panels:
        panel.imshow(image.detach().squeeze().cpu().numpy(), cmap='gray')
        panel.set_title(label)
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
    

def single_class_plot(hist_data, tar_best, approx_best, snr_best, tar_worst, approx_worst, snr_worst, digit, title=''):
    """
    Plot histogram as well as best and worst approximation for a digit.

    Layout: the PSNR histogram (with KDE) fills the left half of the figure; the right half
    is a 2x2 grid with the best target/approximation on top and the worst one below.

    :param hist_data: numpy array containing PSNR values
    :param tar_best: target image of best approx (tensor, squeezed before display)
    :param approx_best: image of best approx (tensor, squeezed before display)
    :param snr_best: PSNR of best approx
    :param tar_worst: target image of worst approx (tensor, squeezed before display)
    :param approx_worst: image of worst approx (tensor, squeezed before display)
    :param snr_worst: PSNR of worst approx
    :param digit: digit to plot (only shown in the title)
    :param title: prefix for the figure title
    :return: -
    """
    plt.figure(figsize=(12, 6))
    hist_ax = plt.subplot(1, 2, 1)
    # (axis, image, panel title) for the four image panels on the right-hand side
    image_panels = [
        (plt.subplot(2, 4, 3), tar_best, 'Target best'),
        (plt.subplot(2, 4, 4), approx_best, f'Best approx, PSNR: {snr_best:.2f}'),
        (plt.subplot(2, 4, 7), tar_worst, 'Target worst'),
        (plt.subplot(2, 4, 8), approx_worst, f'Worst approx, PSNR: {snr_worst:.2f}'),
    ]

    sns.histplot(ax=hist_ax, data=hist_data, kde=True)
    hist_ax.set_xlabel('PSNR')
    hist_ax.set_ylabel('Count')
    hist_ax.set_title(f'Histogram of PSNR\nNumber of examples: {len(hist_data)}\nMean: {np.mean(hist_data):.2f}')

    for panel_ax, image, panel_title in image_panels:
        panel_ax.imshow(image.detach().squeeze().cpu().numpy(), cmap='gray')
        panel_ax.set_title(panel_title)
        panel_ax.grid(False)

    # NOTE(review): ha='left' with suptitle's default x=0.5 makes the title start at the
    # horizontal centre of the figure; confirm this placement is intended.
    plt.suptitle(f'{title} DIGIT: {digit}', fontsize='x-large', ha='left')
    plt.show()

In [None]:
# Regression result directories (relative to this notebook). Per the legend labels used in
# plot_single_vs_multi_start, "multi start" runs use 128 start positions and "single start"
# runs start at 0.
#
# Previously used directories, kept for reference:
#   multi start:  '../regressions/2023-11-11_13:17:20_p4_rot_mnist'
#                 '../regressions/2023-11-11_15:27:29_z2_rot_mnist'
#                 '../regressions/2023-11-11_12:12:48_vanilla_small'
#   single start: '../regressions/2023-11-10_12:35:35_p4_rot_mnist'
#                 '../regressions/2023-11-10_20:56:21_z2_rot_mnist'
#                 '../regressions/2023-11-10_21:01:43_vanilla_small'
# NOTE(review): the current single-start directories below are identical to the previous
# multi-start ones -- confirm those runs really used a single start position.
p4_multi_start = '../regressions/2023-11-12_11:25:44_p4_rot_mnist'
p4_single_start = '../regressions/2023-11-11_13:17:20_p4_rot_mnist'

z2_multi_start = '../regressions/2023-11-12_12:48:22_z2_rot_mnist'
z2_single_start = '../regressions/2023-11-11_15:27:29_z2_rot_mnist'

vanilla_multi_start = '../regressions/2023-11-12_12:08:04_vanilla_small'
vanilla_single_start = '../regressions/2023-11-11_12:12:48_vanilla_small'

# Multi-start results supply the generator, its latent inputs and the PSNRs;
# from the single-start results only the PSNRs are kept.
p4_gen, p4_inputs, p4_snrs_multi = load_regression_results(p4_multi_start)[:3]
p4_snrs_single = load_regression_results(p4_single_start)[2]

z2_gen, z2_inputs, z2_snrs_multi = load_regression_results(z2_multi_start)[:3]
z2_snrs_single = load_regression_results(z2_single_start)[2]

van_gen, van_inputs, van_snrs_multi = load_regression_results(vanilla_multi_start)[:3]
van_snrs_single = load_regression_results(vanilla_single_start)[2]

# Load the unshuffled rotated-MNIST test split as a single batch sized like the regression
# inputs. NOTE(review): this presumes the regressions ran over the same examples in the same
# order, so that row i here matches SNR i -- confirm.
_, new_loader = get_rotated_mnist_dataloader(
    root='..',
    batch_size=p4_inputs.shape[0],
    shuffle=False,
    one_hot_encode=True,
    num_examples=10000,
    num_rotations=0,
    train=False,
)
all_targets, all_labels = next(iter(new_loader))

In [None]:
def plot_single_vs_multi_start(multi_start_snrs, single_start_snrs, gen_arch, num_start_positions=128):
    """
    Overlay the PSNR distributions of a multi-start and a single-start regression.

    :param multi_start_snrs: PSNR per example of the multi-start run (tensor or numpy array)
    :param single_start_snrs: PSNR per example of the single-start run (tensor or numpy array)
    :param gen_arch: architecture name shown in the title
    :param num_start_positions: number of start positions of the multi-start run; only used
        in its legend label (default 128)
    """
    if isinstance(multi_start_snrs, torch.Tensor):
        multi_start_snrs = multi_start_snrs.detach().cpu().numpy()
    if isinstance(single_start_snrs, torch.Tensor):
        single_start_snrs = single_start_snrs.detach().cpu().numpy()

    multi_label = f'With {num_start_positions} start pos'
    single_label = 'Start at 0'
    data = {multi_label: multi_start_snrs, single_label: single_start_snrs}
    sns.displot(data=data, kde=True)
    # Label each mean explicitly. The data dict lists the multi-start run first, and the
    # previous unlabeled "single and multi" ordering was easy to misattribute.
    plt.title(
        f'{gen_arch}\nAll digits\nTotal number of examples: {len(multi_start_snrs)}\n'
        f'Means: {multi_label}: {np.mean(multi_start_snrs):.2f}, '
        f'{single_label}: {np.mean(single_start_snrs):.2f}'
    )
    plt.xlabel('PSNR values')
    plt.show()

In [None]:
# Single- vs multi-start PSNR distributions, one figure per generator architecture.
plot_single_vs_multi_start(p4_snrs_multi, p4_snrs_single, 'p4_rot_mnist')
plot_single_vs_multi_start(z2_snrs_multi, z2_snrs_single, 'z2_rot_mnist')
plot_single_vs_multi_start(van_snrs_multi, van_snrs_single, 'vanilla')

In [None]:
def plot_hist_three_archs(p4_snrs, z2_snrs, van_snrs, title=''):
    """
    Compare the PSNR distributions of the three generator architectures.

    The left panel shows histograms with a KDE overlay; the right panel shows KDE curves only.

    :param p4_snrs: PSNRs of the p4_rot_mnist generator (tensor or numpy array)
    :param z2_snrs: PSNRs of the z2_rot_mnist generator (tensor or numpy array)
    :param van_snrs: PSNRs of the vanilla generator (tensor or numpy array)
    :param title: figure super-title
    """
    def as_array(values):
        # seaborn gets numpy arrays; tensors are converted, anything else passes through
        return values.numpy() if isinstance(values, torch.Tensor) else values

    data = {
        'p4_rot_mnist': as_array(p4_snrs),
        'z2_rot_mnist': as_array(z2_snrs),
        'vanilla': as_array(van_snrs),
    }
    _, (hist_ax, kde_ax) = plt.subplots(1, 2, figsize=(12, 6))
    sns.histplot(ax=hist_ax, data=data, kde=True)
    sns.kdeplot(ax=kde_ax, data=data)
    plt.suptitle(f'{title}')
    for panel in (hist_ax, kde_ax):
        panel.set_xlabel('PSNR values')
    plt.show()
plot_hist_three_archs(
    p4_snrs=p4_snrs_multi,
    z2_snrs=z2_snrs_multi,
    van_snrs=van_snrs_multi,
    title='All digits (10000 examples)\n128 start positions',
)

In [None]:
def plot_hist_best_worst_approx(snrs, generator, input_noise, labels, gen_arch='', targets=None):
    """
    Plot the PSNR histogram over all examples plus the overall best and worst approximation.

    :param snrs: PSNR per example (torch tensor or numpy array)
    :param generator: callable taking (latent noise batch, label batch) and returning images
    :param input_noise: latent noise per example, indexed like ``snrs``
    :param labels: label per example, indexed like ``snrs``
    :param gen_arch: architecture name used in the plot title
    :param targets: target image per example, indexed like ``snrs``; defaults to the
        notebook-level ``all_targets`` (the previous, implicit behaviour)
    """
    if targets is None:
        targets = all_targets
    snrs_np = snrs.detach().cpu().numpy() if isinstance(snrs, torch.Tensor) else np.asarray(snrs)
    idx_worst_snr = int(np.argmin(snrs_np))
    idx_best_snr = int(np.argmax(snrs_np))

    # only used for display, so no autograd graph is needed
    with torch.no_grad():
        worst_approx = generator(input_noise[idx_worst_snr].unsqueeze(0), labels[idx_worst_snr].unsqueeze(0))
        best_approx = generator(input_noise[idx_best_snr].unsqueeze(0), labels[idx_best_snr].unsqueeze(0))

    single_class_plot(hist_data=snrs_np,
                      tar_best=targets[idx_best_snr],
                      approx_best=best_approx,
                      snr_best=snrs_np[idx_best_snr],
                      tar_worst=targets[idx_worst_snr],
                      approx_worst=worst_approx,
                      snr_worst=snrs_np[idx_worst_snr],
                      digit='ALL',
                      title=f'ARCH: {gen_arch},')
# Overall best/worst approximation (all digits), one figure per architecture.
plot_hist_best_worst_approx(snrs=p4_snrs_multi, generator=p4_gen, input_noise=p4_inputs,
                            labels=all_labels, gen_arch='p4_rot_mnist')
plot_hist_best_worst_approx(snrs=z2_snrs_multi, generator=z2_gen, input_noise=z2_inputs,
                            labels=all_labels, gen_arch='z2_rot_mnist')
plot_hist_best_worst_approx(snrs=van_snrs_multi, generator=van_gen, input_noise=van_inputs,
                            labels=all_labels, gen_arch='vanilla')

In [None]:
def prepare_data_for_each_digit(snrs, generator, input_noise, labels, targets=None, num_classes=10):
    '''
    Split per-example PSNRs by digit and generate each digit's best and worst approximation.

    :param snrs: PSNR per example (torch tensor or numpy array)
    :param generator: callable taking (latent noise batch, one-hot label batch) and returning an
        image batch; element [0, 0] of its output is kept
    :param input_noise: latent noise per example, indexed like ``snrs``
    :param labels: one-hot labels per example, shape (num_examples, num_classes); the digit is
        recovered with an argmax over dim 1
    :param targets: target images per example, indexed like ``snrs``; index 0 along axis 1
        (presumably the channel) is kept. Defaults to the notebook-level ``all_targets``.
    :param num_classes: number of digits to iterate over (default 10)
    :return: tuple (snrs_per_digit, best_approximations, worst_approximations); all three lists
        are indexed by digit, and each approximation entry is a tuple (target, approx, snr)
    '''
    if targets is None:
        targets = all_targets
    if isinstance(snrs, torch.Tensor):
        snrs = snrs.detach().cpu().numpy()

    labels_decoded = torch.argmax(labels, dim=1)
    total_num_examples = 0
    snrs_per_digit = []
    worst_approximations = []  # tuples (target, approx, snr)
    best_approximations = []  # tuples (target, approx, snr)

    for c in range(num_classes):
        # indices (into the full arrays) of all examples of the current digit
        class_indices = torch.where(labels_decoded == c)[0].numpy()
        single_class_snrs = snrs[class_indices]
        snrs_per_digit.append(single_class_snrs)
        total_num_examples += len(single_class_snrs)

        # map the per-class argmax/argmin back to indices into the full arrays, so the
        # parameters (not notebook globals) provide the inputs, labels and targets
        i_best = int(class_indices[np.argmax(single_class_snrs)])
        i_worst = int(class_indices[np.argmin(single_class_snrs)])

        # the approximations are only displayed, so no autograd graph is kept alive
        with torch.no_grad():
            best_approx = generator(input_noise[i_best].unsqueeze(0), labels[i_best].unsqueeze(0))
            worst_approx = generator(input_noise[i_worst].unsqueeze(0), labels[i_worst].unsqueeze(0))

        worst_approximations.append((targets[i_worst, 0], worst_approx[0, 0], snrs[i_worst]))
        best_approximations.append((targets[i_best, 0], best_approx[0, 0], snrs[i_best]))

    # every example must belong to exactly one of the digits
    assert total_num_examples == targets.shape[0], (
        f'{total_num_examples} examples split over digits, expected {targets.shape[0]}'
    )
    return snrs_per_digit, best_approximations, worst_approximations

# Per-digit PSNRs plus best/worst approximations for each architecture;
# each result is the tuple (snrs_per_digit, best_approximations, worst_approximations).
p4_per_digit = prepare_data_for_each_digit(snrs=p4_snrs_multi, generator=p4_gen,
                                           input_noise=p4_inputs, labels=all_labels)
z2_per_digit = prepare_data_for_each_digit(snrs=z2_snrs_multi, generator=z2_gen,
                                           input_noise=z2_inputs, labels=all_labels)
van_per_digit = prepare_data_for_each_digit(snrs=van_snrs_multi, generator=van_gen,
                                            input_noise=van_inputs, labels=all_labels)

In [None]:
def plot_dist_each_digit(snrs_list, gen_arch='?'):
    """
    Show per-digit PSNR distributions as horizontal violins (left) and KDE curves (right).

    :param snrs_list: sequence of PSNR arrays, one per digit (list position = digit)
    :param gen_arch: architecture name shown in the figure title
    """
    _, (violin_ax, kde_ax) = plt.subplots(1, 2, figsize=(10, 8))

    sns.violinplot(data=snrs_list, orient='h', split=True, ax=violin_ax)
    violin_ax.set_xlabel('PSNR')
    violin_ax.set_ylabel('Digit')

    sns.kdeplot(data=snrs_list, ax=kde_ax)
    kde_ax.set_xlabel('PSNR')
    # clears the y-label of the current pyplot axes (presumably the KDE panel)
    plt.ylabel('')

    plt.suptitle(f'gen_arch: {gen_arch}')
    plt.show()

# Element 0 of each *_per_digit tuple holds the per-digit PSNR arrays.
plot_dist_each_digit(snrs_list=p4_per_digit[0], gen_arch='p4_rot_mnist')
plot_dist_each_digit(snrs_list=z2_per_digit[0], gen_arch='z2_rot_mnist')
plot_dist_each_digit(snrs_list=van_per_digit[0], gen_arch='vanilla')

In [None]:
def plot_histogram_best_worst_each_digit(snr_list, b_approximations, w_approximations, gen_arch='?'):
    '''
    Draw one figure per digit 0-9: PSNR histogram plus best and worst approximation.

    :param snr_list: per-digit PSNR arrays, indexed by digit
    :param b_approximations: per-digit (target, approx, snr) tuples of the best approximation
    :param w_approximations: per-digit (target, approx, snr) tuples of the worst approximation
    :param gen_arch: architecture name used in every figure title
    '''
    plot_title = f'ARCH: {gen_arch},'
    for digit in range(10):
        best_target, best_image, best_snr = b_approximations[digit]
        worst_target, worst_image, worst_snr = w_approximations[digit]
        single_class_plot(
            hist_data=snr_list[digit],
            tar_best=best_target,
            approx_best=best_image,
            snr_best=best_snr,
            tar_worst=worst_target,
            approx_worst=worst_image,
            snr_worst=worst_snr,
            digit=digit,
            title=plot_title,
        )
# Per-digit figures are only drawn for the vanilla generator. To inspect the other
# architectures, pass p4_per_digit / z2_per_digit with 'p4_rot_mnist' / 'z2_rot_mnist'.
plot_histogram_best_worst_each_digit(*van_per_digit, gen_arch='vanilla')

In [None]:
# Number of multi-start examples per architecture whose PSNR is below thresh_l.
# int(...) prints a plain count instead of the 0-d tensor repr "tensor(N)".
thresh_l = 20
print(f'p4 #SNRS < {thresh_l}: {int((p4_snrs_multi < thresh_l).sum())}')
print(f'z2 #SNRS < {thresh_l}: {int((z2_snrs_multi < thresh_l).sum())}')
print(f'va #SNRS < {thresh_l}: {int((van_snrs_multi < thresh_l).sum())}')

In [None]:
# Number of multi-start examples per architecture whose PSNR is above thresh_u.
# int(...) prints a plain count instead of the 0-d tensor repr "tensor(N)".
thresh_u = 30
print(f'p4 #SNRS > {thresh_u}: {int((p4_snrs_multi > thresh_u).sum())}')
print(f'z2 #SNRS > {thresh_u}: {int((z2_snrs_multi > thresh_u).sum())}')
print(f'va #SNRS > {thresh_u}: {int((van_snrs_multi > thresh_u).sum())}')