In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
%load_ext line_profiler

In [None]:
import json
import os
import sys
sys.path.insert(0, os.path.abspath('../src/'))

import matplotlib.pyplot as plt
%matplotlib inline

import torch

In [None]:
from predict import Model
from predict import load_data
from utils import imsetshow

## 1.0 Configuration

In [None]:
# Which trained run to evaluate. `model_name` is also used as the checkpoint
# sub-folder and as the sub-folder for saved example images.
model_name = 'lpips1_vgg'
# Backbone passed as `net=` to lpips.LPIPS when computing the metric.
lpips_mode = 'vgg'

config_file_path = "../config/config.json"
# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
with open(config_file_path, "r", encoding="utf-8") as read_file:
    config = json.load(read_file)

checkpoint_dir = config["paths"]["checkpoint_dir"]
run_subfolder = model_name
checkpoint_filename = 'HRNet.pth'
# NOTE(review): checkpoint_dir is presumably relative to the repo root while this
# notebook runs one directory below it, hence the leading '..' — confirm.
checkpoint_file = os.path.join('..', checkpoint_dir, run_subfolder, checkpoint_filename)
# Explicit error (not `assert`, which is stripped under -O) that names the missing path.
if not os.path.isfile(checkpoint_file):
    raise FileNotFoundError(f"Checkpoint not found: {os.path.abspath(checkpoint_file)}")

## 1.1 Load model

In [None]:
# Build the network from the loaded config, then restore the trained weights of
# the run selected in the configuration cell.
model = Model(config)
model.load_checkpoint(checkpoint_file=checkpoint_file)

## 1.2 Load data

In [None]:
# Load all three splits plus the baseline cPSNR scores used by the evaluation.
# NOTE(review): val_proportion=0.10 presumably holds out 10% of the training image
# sets for validation, and top_k=-1 presumably disables a top-k filter — confirm
# against predict.load_data.
train_dataset, val_dataset, test_dataset, baseline_cpsnrs = load_data(
    config_file_path,
    val_proportion=0.10,
    top_k=-1,
)

## 1.3 Run evaluation

In [None]:
# Evaluate the model on every split. Later cells treat `results` as a DataFrame
# with at least the columns 'part', 'ESA', 'model', 'score', 'mean_clr', 'std_clr'.
results = model.evaluate(
    train_dataset,
    val_dataset,
    test_dataset,
    baseline_cpsnrs,
)

## 1.4 Benchmark as % of the ESA baseline

In [None]:
# Summary statistics over all splits, one row per metric column.
results.describe().transpose()

In [None]:
# Mean of each numeric column over the training split.
results.query("part == 'train'").describe().loc['mean']

In [None]:
# Mean of each numeric column over the validation split.
results.query("part == 'val'").describe().loc['mean']

In [None]:
# Mean of each numeric column over the test split.
results.query("part == 'test'").describe().loc['mean']

In [None]:
# Training split: ESA vs. model score distributions stacked on shared axes,
# followed by the distribution of the combined 'score' column.
train_part = results[results['part'].eq('train')]
train_part.hist(column=['ESA', 'model'], sharex=True, sharey=True,
                bins=100, layout=(2, 1), figsize=(10, 3));
train_part.hist(column=['score'], bins=100, figsize=(10, 1));

In [None]:
# Validation split: same plots as for training, with coarser bins (smaller split).
val_part = results[results['part'].eq('val')]
val_part.hist(column=['ESA', 'model'], sharex=True, sharey=True,
              bins=20, layout=(2, 1), figsize=(10, 3));
val_part.hist(column=['score'], bins=20, figsize=(10, 1));

In [None]:
# Does the score depend on the 'mean_clr' statistic?
results.plot(kind='scatter', x='mean_clr', y='score', s=100, alpha=0.1);

In [None]:
# Does the score depend on the 'std_clr' statistic?
results.plot(kind='scatter', x='std_clr', y='score', s=100, alpha=0.1);

## 1.5 Visualize example super-resolved images

In [None]:
# Validation image sets ordered by ascending score.
results.query("part == 'val'").sort_values(by='score')

In [None]:
# Render and save side-by-side LR / SR / HR comparisons for the first few
# validation image sets.
N_EXAMPLES = 3  # number of validation image sets to render

output_dir = os.path.join('..', 'images', model_name)
os.makedirs(output_dir, exist_ok=True)

# min() guards against validation splits smaller than N_EXAMPLES (IndexError otherwise).
for i in range(min(N_EXAMPLES, len(val_dataset))):
    imset = val_dataset[i]
    sr, _ = model(imset)  # second output (scPSNR) is not needed for the figure

    fig, (ax_lr, ax_sr, ax_hr) = plt.subplots(1, 3, figsize=(30, 10))
    ax_lr.imshow(imset['lr'][0])  # first low-resolution frame of the image set
    ax_lr.set_title('Low-Resolution-0 (300m / pixel)')
    ax_sr.imshow(sr)
    ax_sr.set_title('Super-Resolution (100m / pixel)')
    ax_hr.imshow(imset['hr'])
    ax_hr.set_title('Ground-truth high-resolution (100m / pixel)')

    output_file = os.path.join(output_dir, f'val_{i}.png')
    fig.savefig(output_file)
    # Display inline and release the figure so open figures don't accumulate.
    plt.show()

## 1.6 Calculate LPIPS

In [None]:
# Mean LPIPS distance between super-resolved and ground-truth HR images over the
# whole validation split (lower = perceptually closer).
import lpips

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

val_srs = []
val_hrs = []

# Inference only: no autograd graph is needed to generate the SR images.
with torch.no_grad():
    for i in range(len(val_dataset)):
        imset = val_dataset[i]
        sr, _ = model(imset)  # scPSNR is not needed here
        val_srs.append(sr)
        val_hrs.append(imset['hr'])

assert val_hrs, "Validation split is empty; cannot compute LPIPS."

# torch.as_tensor accepts both NumPy arrays and tensors. Stacking avoids the slow
# torch.tensor(list_of_arrays) path. .float() forces float32 to match the LPIPS
# network weights (NOTE(review): the dtype of `sr` is not visible here — it may be float64).
# Tensors stay on the CPU; only per-batch slices are moved to `device` below.
val_hrs_tensor = torch.stack([torch.as_tensor(hr) for hr in val_hrs]).float()
val_srs_tensor = torch.stack([torch.as_tensor(sr) for sr in val_srs]).float()


def to_lpips_input(images):
    """Map single-channel images to the 3-channel [-1, 1] input LPIPS expects.

    NOTE(review): assumes `images` is (N, H, W) with values in [0, 1] — confirm
    against the dataset / model output normalization.
    """
    scaled = (images - 0.5) * 2  # [0, 1] -> [-1, 1]
    return scaled.unsqueeze(1).repeat(1, 3, 1, 1)  # grey -> 3 identical channels


val_hrs_normalized = to_lpips_input(val_hrs_tensor)
val_srs_normalized = to_lpips_input(val_srs_tensor)

lpips_fn = lpips.LPIPS(net=lpips_mode).to(device)
all_lpips_values = []

batch_size = 10  # image pairs per LPIPS forward pass; reduce if GPU memory is tight

with torch.no_grad():
    for start in range(0, len(val_hrs_normalized), batch_size):
        hr_batch = val_hrs_normalized[start:start + batch_size].to(device)
        sr_batch = val_srs_normalized[start:start + batch_size].to(device)
        # By default LPIPS returns one distance per pair as (B, 1, 1, 1); flatten
        # to plain floats so the average below is a scalar, not a nested array.
        all_lpips_values.extend(lpips_fn(hr_batch, sr_batch).flatten().cpu().tolist())

average_lpips = sum(all_lpips_values) / len(all_lpips_values)

print(f'Average LPIPS for model {model_name}: {average_lpips}')

## 1.7 Generate submission file

In [None]:
# Write predictions for the test image sets into the submission folder.
submission_dir = '../submission'
model.generate_submission_file(imset_dataset=test_dataset, out=submission_dir)