In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
%load_ext line_profiler

In [None]:
import json
import os
import sys
sys.path.insert(0, os.path.abspath('../src/'))

import matplotlib.pyplot as plt
%matplotlib inline

import torch

In [None]:
from predict import Model
from predict import load_data
from utils import imsetshow

## Load models

In [None]:
# Names of the training runs to compare; each is a subfolder of the checkpoint dir.
model_names = ['cPSNR_test', 'lpips1_alex_test', 'lpips1_vgg_test']

models = {}

config_file_path = "../config/config.json"
with open(config_file_path, "r", encoding="utf-8") as read_file:
    config = json.load(read_file)

checkpoint_dir = config["paths"]["checkpoint_dir"]
checkpoint_filename = 'HRNet.pth'

for model_name in model_names:
    # The configured checkpoint dir is joined under '..' (presumably relative to the
    # repository root, since the notebook runs from a subfolder).
    checkpoint_file = os.path.join('..', checkpoint_dir, model_name, checkpoint_filename)
    # Explicit check instead of a bare assert: still runs under `python -O` and
    # reports which checkpoint is missing.
    if not os.path.isfile(checkpoint_file):
        raise FileNotFoundError(
            f"Checkpoint for model '{model_name}' not found: {checkpoint_file}"
        )

    model = Model(config)
    model.load_checkpoint(checkpoint_file=checkpoint_file)

    # Add model to dictionary
    models[model_name] = model

## Load data

In [None]:
# Split the data into train / validation / test sets. val_proportion is presumably
# the fraction held out for validation; top_k=-1 is passed through unchanged.
(train_dataset, val_dataset, test_dataset, baseline_cpsnrs) = load_data(
    config_file_path,
    val_proportion=0.10,
    top_k=-1,
)

## Generate predictions

In [None]:
def _run_model_on_imset(predictor, imset):
    """Run one model on one image set and return (first LR frame, HR, SR, cPSNR)."""
    sr, scPSNR = predictor(imset)
    return (imset['lr'][0], imset['hr'], sr, scPSNR)


# model_results[name] is a list of (lr, hr, sr, cPSNR) tuples, one per test image set,
# in test_dataset order.
model_results = {}

for model_name in model_names:
    predictor = models[model_name]
    model_results[model_name] = [_run_model_on_imset(predictor, imset) for imset in test_dataset]

## Generate LPIPS results

In [None]:
import lpips

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def _prepare_lpips_batch(images, target_device):
    """Move a batch of single-channel images to ``target_device`` in LPIPS input format.

    Maps pixel values from [0, 1] to [-1, 1] (the range LPIPS expects) and repeats
    the single channel three times for the RGB backbone.
    Assumes ``images`` is (N, H, W) with values in [0, 1] -- TODO confirm for all datasets.
    """
    scaled = (images.to(target_device) - 0.5) * 2
    return scaled.unsqueeze(1).repeat(1, 3, 1, 1)


def compute_lpips_values(model_name, loss_fn, results=None, batch_size=10, target_device=None):
    """Compute per-image LPIPS distances between HR targets and SR predictions.

    Parameters
    ----------
    model_name : str
        Key into the notebook-level ``model_results`` dict; only used when
        ``results`` is not given.
    loss_fn : callable
        LPIPS-style module called as ``loss_fn(hr_batch, sr_batch)`` on
        3-channel batches scaled to [-1, 1], e.g. ``lpips.LPIPS(net='alex')``.
    results : list of tuple, optional
        ``(lr, hr, sr, cPSNR)`` tuples as built by the prediction cell.
        Defaults to ``model_results[model_name]``.
    batch_size : int, optional
        Number of image pairs passed to ``loss_fn`` per call (default 10).
    target_device : torch.device, optional
        Device to run on; defaults to the notebook-level ``device``.

    Returns
    -------
    list of numpy.ndarray
        One entry per image, in ``results`` order. Each entry is that image's row
        of ``loss_fn``'s output; call ``.item()`` to get a float. An empty input
        gives an empty list.
    """
    if results is None:
        results = model_results[model_name]
    if target_device is None:
        target_device = device
    if not results:
        return []

    # Stack on the host and move one batch at a time, so the whole test set never
    # has to sit in GPU memory at once. torch.as_tensor accepts both tensors and
    # numpy arrays without copying. .float() avoids dtype mismatches with the
    # float32 network weights.
    hrs = torch.stack([torch.as_tensor(r[1]) for r in results]).float()
    srs = torch.stack([torch.as_tensor(r[2]) for r in results]).float()

    all_lpips_values = []
    # Scoring only: skip building autograd graphs through the LPIPS network.
    with torch.no_grad():
        for start in range(0, len(hrs), batch_size):
            hr_batch = _prepare_lpips_batch(hrs[start:start + batch_size], target_device)
            sr_batch = _prepare_lpips_batch(srs[start:start + batch_size], target_device)
            batch_values = loss_fn(hr_batch, sr_batch).detach().cpu().numpy()
            # Iterating over the first axis yields one entry per image.
            all_lpips_values.extend(batch_values)

    return all_lpips_values

In [None]:
import numpy
model_lpips_values = {}

# Free up CUDA memory before loading the AlexNet-backed LPIPS network
torch.cuda.empty_cache()
lpips_fn = lpips.LPIPS(net='alex').to(device)

# Scoring only: the LPIPS linear layers have trainable parameters, so without
# no_grad every forward pass would record an autograd graph.
with torch.no_grad():
    for model_name in model_names:
        model_lpips_values[model_name] = compute_lpips_values(model_name, lpips_fn)

        # Free up CUDA memory between models
        torch.cuda.empty_cache()

In [None]:
import numpy
model_lpips_values_vgg = {}

# Free up CUDA memory before loading the VGG-backed LPIPS network
torch.cuda.empty_cache()
lpips_fn = lpips.LPIPS(net='vgg').to(device)

# Scoring only: the LPIPS linear layers have trainable parameters, so without
# no_grad every forward pass would record an autograd graph.
with torch.no_grad():
    for model_name in model_names:
        model_lpips_values_vgg[model_name] = compute_lpips_values(model_name, lpips_fn)

        # Free up CUDA memory between models
        torch.cuda.empty_cache()

## Generate image comparisons

In [None]:
import matplotlib.pyplot as plt
import os

def generate_images(indices, output_dir):
    """Save one LR / HR / SR comparison figure per requested test image set.

    For the ``idx``-th entry of ``indices``, writes ``comparison_set_{idx}.png`` to
    ``output_dir``. Each figure shows the first LR frame, the HR image, and the SR
    output of every model in ``model_results``. Each SR panel is titled with that
    model's cPSNR and AlexNet / VGG LPIPS scores.

    Reads the notebook globals ``test_dataset``, ``model_results``,
    ``model_lpips_values`` and ``model_lpips_values_vgg``.

    Parameters
    ----------
    indices : iterable of int
        Indices into ``test_dataset``.
    output_dir : str
        Destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    model_keys = list(model_results.keys())
    # Two fixed panels (LR, HR) plus one per model, instead of a hard-coded 5-column
    # grid that only fits exactly three models.
    n_cols = 2 + len(model_keys)

    for idx, imset_index in enumerate(indices):
        # Index the dataset once so LR and HR come from the same loaded sample.
        imset = test_dataset[imset_index]

        fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), facecolor='white', dpi=300)
        try:
            # Trailing newlines pad the LR/HR titles to the height of the 4-line SR titles.
            panels = [(imset['lr'][0], 'LR\n\n\n'), (imset['hr'], 'HR\n\n\n')]

            # SR panels from all models
            for model_name in model_keys:
                model_result = model_results[model_name][imset_index]
                sr_image = model_result[2]
                cPSNR = model_result[3]
                lpips_alex = model_lpips_values[model_name][imset_index].item()
                lpips_vgg = model_lpips_values_vgg[model_name][imset_index].item()
                panels.append((
                    sr_image,
                    f'{model_name}\ncPSNR: {cPSNR:.2f}\nLPIPS (alex): {lpips_alex:.4f}\nLPIPS (vgg): {lpips_vgg:.4f}',
                ))

            for ax, (image, title) in zip(axes, panels):
                ax.imshow(image)
                ax.set_title(title)
                ax.axis('off')

            fig.savefig(os.path.join(output_dir, f'comparison_set_{idx}.png'),
                        bbox_inches='tight', facecolor='white', dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

def adjust_brightness(sr_image, target_brightness):
    """Rescale ``sr_image`` so its mean intensity becomes ``target_brightness``.

    The rescaled image is clipped to [0, 1]. An image whose mean is exactly 0
    cannot be rescaled, so it is returned unchanged (the same object).
    """
    mean_level = np.mean(sr_image)
    if mean_level == 0:
        # Guard against division by zero.
        return sr_image
    scaled = sr_image * (target_brightness / mean_level)
    # Keep pixel values inside the displayable [0, 1] range.
    return np.clip(scaled, 0, 1)

def generate_adjusted_images(indices, output_dir):
    """Save LR / HR / brightness-matched SR comparison figures for selected image sets.

    For the ``idx``-th entry of ``indices``, writes ``comparison_set_{idx}.png`` to
    ``output_dir``. Each SR panel is rescaled with ``adjust_brightness`` so its mean
    intensity matches the HR image. The cPSNR and LPIPS values in the titles are
    the stored scores of the *unadjusted* SR output; only the displayed pixels
    are rescaled.

    Reads the notebook globals ``test_dataset``, ``model_results``,
    ``model_lpips_values`` and ``model_lpips_values_vgg``.

    Parameters
    ----------
    indices : iterable of int
        Indices into ``test_dataset``.
    output_dir : str
        Destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    model_keys = list(model_results.keys())
    # Two fixed panels (LR, HR) plus one per model, instead of a hard-coded 5-column
    # grid that only fits exactly three models.
    n_cols = 2 + len(model_keys)

    for idx, imset_index in enumerate(indices):
        # Index the dataset once so LR and HR come from the same loaded sample.
        imset = test_dataset[imset_index]
        lr_image = imset['lr'][0]
        hr_image = imset['hr']
        # Target mean intensity for the SR panels (torch.mean requires hr_image to be a tensor).
        hr_brightness = torch.mean(hr_image).item()

        fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), facecolor='white', dpi=300)
        try:
            # Trailing newlines pad the LR/HR titles to the height of the 4-line SR titles.
            panels = [(lr_image, 'LR\n\n\n'), (hr_image, 'HR\n\n\n')]

            # SR panels from all models, brightness-matched to the HR image
            for model_name in model_keys:
                model_result = model_results[model_name][imset_index]
                sr_image = model_result[2]
                cPSNR = model_result[3]
                lpips_alex = model_lpips_values[model_name][imset_index].item()
                lpips_vgg = model_lpips_values_vgg[model_name][imset_index].item()

                adjusted_sr_image = adjust_brightness(sr_image, hr_brightness)
                panels.append((
                    adjusted_sr_image,
                    f'{model_name}\ncPSNR: {cPSNR:.2f}\nLPIPS (alex): {lpips_alex:.4f}\nLPIPS (vgg): {lpips_vgg:.4f}',
                ))

            for ax, (image, title) in zip(axes, panels):
                ax.imshow(image)
                ax.set_title(title)
                ax.axis('off')

            fig.savefig(os.path.join(output_dir, f'comparison_set_{idx}.png'),
                        bbox_inches='tight', facecolor='white', dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)


In [None]:
import random
from random import sample

# Fixed seed so the "random" selection is identical on every Restart & Run All.
RANDOM_SEED = 42
N_RANDOM_SETS = 10

# Randomly select up to 10 image sets (sample() raises if asked for more than exist)
random.seed(RANDOM_SEED)
random_indices = sample(range(len(test_dataset)), min(N_RANDOM_SETS, len(test_dataset)))
output_dir = '../images/random'
generate_adjusted_images(random_indices, output_dir)

In [None]:
# Rank the test image sets by AlexNet LPIPS for the 'lpips1_alex_test' model.
# Lower LPIPS means closer to HR, so the ascending order puts the best sets first.
lpips_values = model_lpips_values['lpips1_alex_test']
sorted_indices = sorted(range(len(lpips_values)), key=lpips_values.__getitem__)

# Ten lowest-LPIPS (best) and ten highest-LPIPS (worst) image sets
best_indices, worst_indices = sorted_indices[:10], sorted_indices[-10:]

best_output_dir = '../images/alex_best'
generate_adjusted_images(best_indices, best_output_dir)

worst_output_dir = '../images/alex_worst'
generate_adjusted_images(worst_indices, worst_output_dir)

In [None]:
# Find the 10 best and worst images for the 'lpips1_vgg_test' model with respect
# to VGG LPIPS
lpips_values = model_lpips_values_vgg['lpips1_vgg_test']

# Sort the indices based on the LPIPS values - lower (closer to HR) comes first
sorted_indices = sorted(range(len(lpips_values)), key=lambda i: lpips_values[i])

# Best 10
best_indices = sorted_indices[:10]
best_output_dir = '../images/vgg_best'
generate_adjusted_images(best_indices, best_output_dir)

# Worst 10
worst_indices = sorted_indices[-10:]
worst_output_dir = '../images/vgg_worst'
generate_adjusted_images(worst_indices, worst_output_dir)

## Display average cPSNR and LPIPS for each model

In [None]:
from tabulate import tabulate
import numpy as np

# One row per model: mean cPSNR and mean LPIPS (AlexNet and VGG) over the test set.
# Cell-specific names are used so the `lpips_values` list from the ranking cells
# is not overwritten with a differently-typed list.
table_data = []

for model_name in model_names:
    cpsnr_scores = [r[3] for r in model_results[model_name]]
    alex_scores = [v.item() for v in model_lpips_values[model_name]]
    vgg_scores = [v.item() for v in model_lpips_values_vgg[model_name]]

    table_data.append([
        model_name,
        f'{np.mean(cpsnr_scores):.2f}',
        f'{np.mean(alex_scores):.4f}',
        f'{np.mean(vgg_scores):.4f}',
    ])

print(tabulate(table_data, headers=["Model Name", "Average cPSNR", "Average LPIPS (alexnet)", "Average LPIPS (vgg)"]))
