In [None]:
import datetime
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchmetrics.functional.image import peak_signal_noise_ratio

from tqdm import tqdm

from utils.dataLoaders import get_rotated_mnist_dataloader
from utils.checkpoints import load_gen_disc_from_checkpoint, load_checkpoint, print_checkpoint

In [None]:
def load_regression_results(result_path):
    """Load a saved latent-regression result file together with its generator.

    Prints every entry of the results dict whose value is neither a tensor nor
    a list (the scalar run settings) as a dot-padded, aligned table.

    Args:
        result_path: path to a file written with ``torch.save`` holding a dict
            with at least the keys ``'path_to_model'``, ``'latent_noise'``,
            ``'snrs'`` and ``'loss_per_batch'``; ``'class_to_search'`` is
            optional.

    Returns:
        Tuple ``(generator, latent_noise, snrs, losses, class_to_search)``:
        ``generator`` is the first element returned by
        ``load_gen_disc_from_checkpoint`` for ``'../' + path_to_model``; the
        next three are the stored ``'latent_noise'``, ``'snrs'`` and
        ``'loss_per_batch'`` values; ``class_to_search`` is the stored value,
        or ``None`` when the file does not contain that key.
    """
    # NOTE(review): torch.load unpickles the file; only load result files you trust.
    res = torch.load(result_path, map_location='cpu')
    path_to_model = res['path_to_model']
    # The stored path is resolved one directory up (presumably it is relative to
    # the project root, which sits above this notebook — confirm).
    # NOTE(review): unclear whether the checkpoint loader puts the generator in eval mode.
    generator, _ = load_gen_disc_from_checkpoint(f'../{path_to_model}')
    latent_noise = res['latent_noise']
    snrs = res['snrs']
    losses = res['loss_per_batch']
    # Read the stored value (not a literal list); None when the key is absent.
    class_to_search = res.get('class_to_search')

    print('REGRESSION DETAILS:')
    for key, value in res.items():
        # Only scalar-like settings are printed; tensors and lists are too large.
        if not isinstance(value, (torch.Tensor, list)):
            # Pad with dots so the values line up in a 28-character column.
            key = key + ': ' + '.' * (28 - len(key) - 2)
            print(f'{key : <28} {value}')
    print('\n')
    return generator, latent_noise, snrs, losses, class_to_search


def plot_comparison(tar: torch.Tensor, approx: torch.Tensor, title: str):
    """Display a target image next to its approximation in one figure.

    Both tensors are detached, moved to the CPU and drawn in grayscale. Callers
    in this notebook pass single-channel slices (``[..., 0]``), so each image is
    presumably 2-D — confirm before passing other shapes.

    Args:
        tar: target image tensor, shown in the left panel.
        approx: approximation tensor, shown in the right panel.
        title: figure-level title placed above both panels.
    """
    fig, axes = plt.subplots(1, 2)
    panels = (('Target', tar), ('Approximation', approx))
    for axis, (panel_title, image) in zip(axes, panels):
        axis.imshow(image.detach().cpu().numpy(), cmap='gray')
        axis.set_title(panel_title)
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Saved regression run to analyse (path relative to this notebook's directory).
res_path = '../regressions/2023-11-10_21:01:43_vanilla_small'
gen, latent_noise, snrs, losses, class_to_search = load_regression_results(res_path)

# Fetch the evaluation examples as a single, unshuffled batch whose size equals the
# number of regressed latents. Later cells assume row i of this batch corresponds to
# row i of latent_noise / snrs.
_, new_loader = get_rotated_mnist_dataloader(root='..',
                                             batch_size=latent_noise.shape[0],
                                             shuffle=False,
                                             one_hot_encode=True,
                                             num_examples=10000,
                                             num_rotations=0,
                                             train=False)

all_targets, all_labels = next(iter(new_loader))

# Targets, labels, latents and PSNRs are indexed with the same positions below; fail
# loudly instead of silently plotting mismatched pairs.
num_regressed = latent_noise.shape[0]
assert all_targets.shape[0] == num_regressed and len(snrs) == num_regressed, (
    f'expected {num_regressed} targets and PSNRs, '
    f'got {all_targets.shape[0]} targets and {len(snrs)} PSNRs')

In [None]:
# Distribution of PSNR values over every regressed example.
snrs_numpy = snrs.numpy()
hist_ax = sns.histplot(snrs_numpy)
hist_ax.set_title(f'Histogram of PSNRs\nAll digits\nTotal number of examples: {len(snrs_numpy)}')
hist_ax.set_xlabel('PSNR values')
plt.show()

In [None]:
# Examples with the lowest and highest PSNR across the whole evaluation set.
idx_worst_snr = torch.argmin(snrs).item()
idx_best_snr = torch.argmax(snrs).item()

# Inference only: no_grad avoids building an autograd graph for these forward passes.
with torch.no_grad():
    worst_approx = gen(latent_noise[idx_worst_snr].unsqueeze(0), all_labels[idx_worst_snr].unsqueeze(0))
    best_approx = gen(latent_noise[idx_best_snr].unsqueeze(0), all_labels[idx_best_snr].unsqueeze(0))

plot_comparison(all_targets[idx_worst_snr, 0], worst_approx[0, 0],
                title=f'Worst approximation. SNR: {snrs[idx_worst_snr].item():.2f}')
plot_comparison(all_targets[idx_best_snr, 0], best_approx[0, 0],
                title=f'Best approximation. SNR: {snrs[idx_best_snr].item():.2f}')

In [None]:
# Per-digit breakdown: for each class, plot its PSNR histogram and its worst and best
# reconstructions. all_labels is one-hot (argmax over dim 1), one column per class.
labels_decoded = torch.argmax(all_labels, dim=1)
num_classes = all_labels.shape[1]
total_num_examples = 0
for c in range(num_classes):
    class_indices = torch.where(labels_decoded == c)[0].numpy()
    single_class_snrs = snrs_numpy[class_indices]
    num_examples = len(single_class_snrs)
    total_num_examples += num_examples

    if num_examples == 0:
        # np.argmax/np.argmin raise on empty arrays; a small evaluation subset may miss a digit.
        print(f'Digit {c}: no examples, skipping')
        continue

    sns.histplot(single_class_snrs)
    plt.title(f'Histogram of PSNRs\nDigit: {c}\nTotal number of examples: {num_examples}')
    plt.xlabel('PSNR values')
    plt.show()

    class_targets = all_targets[class_indices]
    class_inputs = latent_noise[class_indices]
    class_labels = all_labels[class_indices]

    i_best = np.argmax(single_class_snrs)
    i_worst = np.argmin(single_class_snrs)
    # Inference only: skip autograd bookkeeping for these forward passes.
    with torch.no_grad():
        best_approx = gen(class_inputs[i_best].unsqueeze(0), class_labels[i_best].unsqueeze(0))
        worst_approx = gen(class_inputs[i_worst].unsqueeze(0), class_labels[i_worst].unsqueeze(0))

    plot_comparison(tar=class_targets[i_worst, 0], approx=worst_approx[0, 0],
                    title=f'Worst approximation. SNR: {single_class_snrs[i_worst]:.2f}')
    plot_comparison(tar=class_targets[i_best, 0], approx=best_approx[0, 0],
                    title=f'Best approximation. SNR: {single_class_snrs[i_best]:.2f}')

    print(f'best snr: {single_class_snrs[i_best]}')

# Sanity check: the per-class counts together cover every loaded example.
assert total_num_examples == all_targets.shape[0]
print(f'total number of examples: {total_num_examples}')
    

In [None]:
# Latent inputs of the last digit processed by the per-digit loop; show the shape
# rather than dumping the whole tensor into the notebook output.
class_inputs.shape