In [None]:
import datetime
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchmetrics.functional.image import peak_signal_noise_ratio

from tqdm import tqdm

from utils.dataLoaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint

In [None]:
def load_regression_results(result_path):
    """
    Load a saved latent-regression result file together with the generator it was run against.

    Every entry of the result dict that is neither a tensor nor a list is printed as a summary.

    :param result_path: path of a file readable by torch.load that holds a dict with at least the
        keys 'path_to_model', 'latent_noise', 'snrs' and 'loss_per_batch'
    :return: tuple (generator, latent_noise, snrs, losses, class_to_search); class_to_search is the
        stored res['class_to_search'] value, or None if the file has no such entry
    """
    res = torch.load(result_path, map_location='cpu')
    path_to_model = res['path_to_model']
    # NOTE(review): the stored model path is prefixed with '../', presumably because it is relative
    # to the repo root while this notebook runs one directory below it — confirm
    generator, _ = load_gen_disc_from_checkpoint(f'../{path_to_model}')
    latent_noise = res['latent_noise']
    snrs = res['snrs']
    losses = res['loss_per_batch']
    # read the stored value (not a literal list); .get keeps result files lacking this key loadable
    class_to_search = res.get('class_to_search')

    print('REGRESSION DETAILS:')
    for key, value in res.items():
        # tensors and lists are too large to be useful in the summary, print only the rest
        if not isinstance(value, (torch.Tensor, list)):
            key = key + ': ' + '.' * (28 - len(key) - 2)
            print(f'{key : <28} {value}')
    print('\n')
    return generator, latent_noise, snrs, losses, class_to_search


def plot_comparison(tar: torch.Tensor, approx: torch.Tensor, title: str):
    """
    Display a target image and its approximation side by side in one figure.

    :param tar: target image tensor; singleton dimensions are squeezed away before display
    :param approx: approximated image tensor; singleton dimensions are squeezed away before display
    :param title: title placed above both panels
    """
    _, (target_ax, approx_ax) = plt.subplots(1, 2)
    panels = ((target_ax, tar, 'Target'), (approx_ax, approx, 'Approximation'))
    for panel_ax, image, caption in panels:
        pixels = image.detach().squeeze().cpu().numpy()
        panel_ax.imshow(pixels, cmap='gray')
        panel_ax.set_title(caption)
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
    

def single_class_plot(hist_data, tar_best, approx_best, snr_best, tar_worst, approx_worst, snr_worst, digit):
    """
    Plot histogram as well as best and worst approximation for a digit.

    Layout: the histogram fills the left half of the figure; the right half holds the best
    target/approximation pair in the top row and the worst pair in the bottom row.

    :param hist_data: numpy array containing PSNR values
    :param tar_best: target image of best approx
    :param approx_best: image of best approx
    :param snr_best: snr of best approx
    :param tar_worst: target image of worst approx
    :param approx_worst: image of worst approx
    :param snr_worst: snr of worst approx
    :param digit: digit to plot (only used in the figure title)
    :return: -
    """
    plt.figure(figsize=(12, 6))
    hist_ax = plt.subplot(1, 2, 1)
    # cells 3, 4 (top row) and 7, 8 (bottom row) of a 2x4 grid form the right half of the figure
    image_panels = [
        (plt.subplot(2, 4, 3), tar_best, 'Target best'),
        (plt.subplot(2, 4, 4), approx_best, f'Best approx, PSNR: {snr_best:.2f}'),
        (plt.subplot(2, 4, 7), tar_worst, 'Target worst'),
        (plt.subplot(2, 4, 8), approx_worst, f'Worst approx, PSNR: {snr_worst:.2f}'),
    ]

    sns.histplot(ax=hist_ax, data=hist_data, kde=True)
    hist_ax.set_xlabel('PSNR')
    hist_ax.set_ylabel('Count')
    hist_ax.set_title(f'Histogram of PSNR\nNumber of examples: {len(hist_data)}\nMean: {np.mean(hist_data):.2f}')

    for image_ax, image, caption in image_panels:
        image_ax.imshow(image.detach().squeeze().cpu().numpy(), cmap='gray')
        image_ax.set_title(caption)
        # a grid (e.g. from a seaborn theme) would be drawn over the pixels
        image_ax.grid(False)

    plt.suptitle(f'DIGIT: {digit}', fontsize='xx-large')
    plt.show()

In [None]:
# Regression results to analyse; the commented-out paths are the corresponding vanilla_small runs.
res_path = '../regressions/2023-11-11_13:17:20_p4_rot_mnist'
#res_path = '../regressions/2023-11-11_12:12:48_vanilla_small'
gen, latent_noise, snrs, losses, class_to_search = load_regression_results(res_path)

res_path2 = '../regressions/2023-11-10_12:35:35_p4_rot_mnist'
#res_path2 = '../regressions/2023-11-10_21:01:43_vanilla_small'
gen2, latent_noise2, snrs2, losses2, class_to_search2 = load_regression_results(res_path2)

# Fetch all targets as ONE unshuffled batch whose size equals the number of regressed examples.
# NOTE(review): assumes the regression was run over this same test set in this same order — confirm.
_, new_loader = get_rotated_mnist_dataloader(root='..',
                                             batch_size=latent_noise.shape[0],
                                             shuffle=False,
                                             one_hot_encode=True,
                                             num_examples=10000,
                                             num_rotations=0,
                                             train=False)


all_targets, all_labels = next(iter(new_loader))

# Later cells index targets, labels, latent noise and PSNRs with the same indices,
# so a size mismatch would silently pair the wrong examples.
assert all_targets.shape[0] == latent_noise.shape[0] == snrs.shape[0], (
    f'size mismatch: targets={all_targets.shape[0]}, '
    f'latent_noise={latent_noise.shape[0]}, snrs={snrs.shape[0]}')

In [None]:
# Overlaid PSNR histograms of both regression runs, over all digits.
snrs_numpy = snrs.numpy()
snrs_numpy2 = snrs2.numpy()
mean_overall_snr = np.mean(snrs_numpy)
data = {'With 128 start pos': snrs_numpy, 'Start at 0': snrs_numpy2}
sns.displot(data=data, kde=True)
# report count and mean for every plotted run, not only the first one
run_stats = '\n'.join(f'{name}: n={len(values)}, mean={np.mean(values):.2f}' for name, values in data.items())
plt.title(f'Histogram of PSNRs\nAll digits\n{run_stats}')
plt.xlabel('PSNR values')
plt.show()

In [None]:
# Best and worst approximation over the whole test set (all digits together).
idx_worst_snr = torch.argmin(snrs).item()
idx_best_snr = torch.argmax(snrs).item()

# inference only: no autograd graph is needed to plot the generated images
with torch.no_grad():
    worst_approx = gen(latent_noise[idx_worst_snr].unsqueeze(0), all_labels[idx_worst_snr].unsqueeze(0))
    best_approx = gen(latent_noise[idx_best_snr].unsqueeze(0), all_labels[idx_best_snr].unsqueeze(0))

single_class_plot(hist_data=snrs_numpy,
                  tar_best=all_targets[idx_best_snr],
                  approx_best=best_approx,
                  snr_best=snrs_numpy[idx_best_snr],
                  tar_worst=all_targets[idx_worst_snr],
                  approx_worst=worst_approx,
                  snr_worst=snrs_numpy[idx_worst_snr],
                  digit='ALL')

In [None]:
# PREPARE DATA FOR EACH DIGIT
# For each digit, collect its PSNRs and their mean, and generate its best and worst approximation.
labels_decoded = torch.argmax(all_labels, dim=1)  # one-hot labels -> digit index
total_num_examples = 0
snrs_per_digit = []
worst_approximations = [] # tuples (target, approx, snr)
best_approximations = [] # tuples (target, approx, snr)
means = []

# loop through digits
for c in range(10):
    # extract indices of current digit
    class_indices = torch.where(labels_decoded == c)[0].numpy()

    # extract snrs
    single_class_snrs = snrs_numpy[class_indices]
    num_examples = len(single_class_snrs)
    mean_curr_class = np.mean(single_class_snrs)

    means.append(mean_curr_class)
    snrs_per_digit.append(single_class_snrs)

    # extract targets, latent noise and labels for current digit
    class_targets = all_targets[class_indices]
    class_inputs = latent_noise[class_indices]
    class_labels = all_labels[class_indices]

    # find and generate best and worst approximation (inference only, no autograd graph)
    i_best = np.argmax(single_class_snrs)
    i_worst = np.argmin(single_class_snrs)
    with torch.no_grad():
        best_approx = gen(class_inputs[i_best].unsqueeze(0), class_labels[i_best].unsqueeze(0))
        worst_approx = gen(class_inputs[i_worst].unsqueeze(0), class_labels[i_worst].unsqueeze(0))

    # append tuples (target, approx, snr); [.., 0] / [0, 0] drop the leading index dims
    worst_approximations.append((class_targets[i_worst, 0], worst_approx[0, 0], single_class_snrs[i_worst]))
    best_approximations.append((class_targets[i_best, 0], best_approx[0, 0], single_class_snrs[i_best]))

    total_num_examples += num_examples
# every example must belong to exactly one of the ten digits
assert (total_num_examples == all_targets.shape[0])
print(f'total number of examples: {total_num_examples}')

In [None]:
# Horizontal violin plot of the PSNR distribution per digit (row index = digit).
# For a taller figure: sns.set(rc={'figure.figsize': (11.7, 16.27)})
violin_ax = sns.violinplot(data=snrs_per_digit, orient='h', split=True)
violin_ax.set_title('PSNRs per Digit')
violin_ax.set_xlabel('PSNR value')
violin_ax.set_ylabel('Digit')
plt.show()

In [None]:
# Side by side: per-digit violins (left) and overlaid per-digit density estimates (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 8))
sns.violinplot(data=snrs_per_digit, orient='h', split=True, ax=axes[0])
axes[0].set_title('PSNR per digit')
axes[0].set_xlabel('PSNR')
axes[0].set_ylabel('Digit')
sns.kdeplot(data=snrs_per_digit, ax=axes[1])
axes[1].set_title('PSNR density per digit')
axes[1].set_xlabel('PSNR')
# target the KDE axes explicitly; plt.ylabel would hit whichever axes happens to be current
axes[1].set_ylabel('')
kde_legend = axes[1].get_legend()
if kde_legend is not None:
    # legend entries are the list positions of snrs_per_digit, i.e. the digits
    kde_legend.set_title('Digit')
fig.suptitle('PSNR distributions per digit')
plt.show()

In [None]:
# PLOT HISTOGRAM, BEST AND WORST APPROXIMATION FOR EACH DIGIT
for digit in range(10):
    worst_target, worst_image, worst_psnr = worst_approximations[digit]
    best_target, best_image, best_psnr = best_approximations[digit]
    single_class_plot(snrs_per_digit[digit],
                      best_target, best_image, best_psnr,
                      worst_target, worst_image, worst_psnr,
                      digit)
