In [1]:
from Annau2022.models import generator, critic
from Annau2022.SRModel import SRModelData, SuperResolver
from Annau2022.RAPSD import compute_rapsd
from Annau2022.progress_bar import progress_bar

import math
import torch
import torch.utils.data as data_utils
import numpy as np
import pytorch_msssim

import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "font.size": 14,
    'mathtext.default': 'regular',
    'text.latex.preamble': r'\usepackage{mathrsfs}'
    })


In [2]:
pfs_path = "/workspace/Annau2022/data/pfs_data/"

model_list = [
        # CNNs
        SRModelData(region="southeast", sr_model_name="CNN", exp_id="e25c6b40324643c3afc1cf42981b11b5"),
        SRModelData(region="central", sr_model_name="CNN", exp_id="fbe44b0423204805bc6af4d7d6ac562e"),
        SRModelData(region="west", sr_model_name="CNN", exp_id="f76c0170818244629de4544805f93a59"),

        # NFS GANs
        SRModelData(region="southeast", sr_model_name="NFS", exp_id="feda42500d2b45549be96f1bf62b0b03"),
        SRModelData(region="central", sr_model_name="NFS", exp_id="0c5ee480663f4f9eb7200f8879aa1244"),
        SRModelData(region="west", sr_model_name="NFS", exp_id="db9f0fae83c949eaad5d1176a43dae47"),

        # FS GANs
        SRModelData(region="southeast", sr_model_name="L_5", exp_id="1824682ae27c48669665cf042052d584"),
        SRModelData(region="southeast", sr_model_name="L_9", exp_id="3f48868c52404eb0a833897aa4642871"),
        SRModelData(region="southeast", sr_model_name="L_13", exp_id="e1d15a0615ca489aa6a17ec60247d0af"),
        SRModelData(region="central", sr_model_name="L_5", exp_id="202ea9f8a73b401fa22e62c24d9ab2d0"),
        SRModelData(region="central", sr_model_name="L_9", exp_id="079a94c41ad3482996cc2b9f95adba8d"),
        SRModelData(region="central", sr_model_name="L_13", exp_id="bcf7e7cfa8ab4c4196ad6a2bb18e8601"),
        SRModelData(region="west", sr_model_name="L_5", exp_id="70f5be887eff42e8a216780752644b2f"),
        SRModelData(region="west", sr_model_name="L_9", exp_id="6abe7a9940c04b47819689070100e5e6"),
        SRModelData(region="west", sr_model_name="L_13", exp_id="c4ec13e65fe74b399fc9e325a9966fef"),
        # PFS GANs alpha 500
        SRModelData(region="southeast", sr_model_name="PFS (alpha = 500) L_13", exp_id="3d2ea1e5f805454ea485a3a7c783fd5a", data_path=pfs_path),
        SRModelData(region="southeast", sr_model_name="PFS (alpha = 500) L_9", exp_id="caf7f501306848f8bc746605c4994e31", data_path=pfs_path),
        SRModelData(region="southeast", sr_model_name="PFS (alpha = 500) L_5", exp_id="90375b9266eb442cb15073895e14d691", data_path=pfs_path),
        SRModelData(region="central", sr_model_name="PFS (alpha = 500) L_13", exp_id="c5154f8f03c74cba924d789357e5ca84", data_path=pfs_path),
        SRModelData(region="central", sr_model_name="PFS (alpha = 500) L_9", exp_id="e54c953370974e2db09a37e9c0c7cdb5", data_path=pfs_path),
        SRModelData(region="central", sr_model_name="PFS (alpha = 500) L_5", exp_id="1570ac86f8e94e83b85447618ca576f5", data_path=pfs_path),
        SRModelData(region="west", sr_model_name="PFS (alpha = 500) L_13", exp_id="2e78fba6814545f0be62896cd14b031f", data_path=pfs_path),
        SRModelData(region="west", sr_model_name="PFS (alpha = 500) L_9", exp_id="ad5772150e7547ee8d14aa7bac192f54", data_path=pfs_path),
        SRModelData(region="west", sr_model_name="PFS (alpha = 500) L_5", exp_id="c5c5e0e8aad5411783329f31db91ff78", data_path=pfs_path),

        # PFS GANs alpha 50
        SRModelData(region="southeast", sr_model_name="PFS (alpha = 50) L_13", exp_id="5c9745ff961e46f9af206d36b6567fae", data_path=pfs_path),
        SRModelData(region="southeast", sr_model_name="PFS (alpha = 50) L_9", exp_id="3858c673c9344e7caf24144335981752", data_path=pfs_path),
        SRModelData(region="southeast", sr_model_name="PFS (alpha = 50) L_5", exp_id="328e5221158147a9ba9b41ab2ac385c7", data_path=pfs_path),
        SRModelData(region="central", sr_model_name="PFS (alpha = 50) L_13", exp_id="eedf0cd864204866b98e5de5e710f9c3", data_path=pfs_path),
        SRModelData(region="central", sr_model_name="PFS (alpha = 50) L_9", exp_id="1d568d304d7546f78c57e98ff1366b9d", data_path=pfs_path),
        SRModelData(region="central", sr_model_name="PFS (alpha = 50) L_5", exp_id="9400ee7db2004aa3b03e91ff710061eb", data_path=pfs_path),
        SRModelData(region="west", sr_model_name="PFS (alpha = 50) L_13", exp_id="faa34028b516487185c994f48621050a", data_path=pfs_path),
        SRModelData(region="west", sr_model_name="PFS (alpha = 50) L_9", exp_id="2faf762448b54ae2b96234d6c77c38b3", data_path=pfs_path),
        SRModelData(region="west", sr_model_name="PFS (alpha = 50) L_5", exp_id="4f0574ec4f7147f1b0555cafeb1cc98f", data_path=pfs_path),

]

In [3]:
# Store results
# Model name -> metric -> value
results = {f"{model.region}_{model.sr_model_name}" : {} for model in model_list}

### 1. Pixel-wise MAE
$|G(\textbf{x}) - \textbf{y}|_1$

In [4]:
from functools import partial

def compute_mae(x, y):
    return torch.mean(torch.abs(x - y))

def mean_stats(x, y, f):
    sx = torch.hypot(x[:, 0, ...], x[:, 1, ...])
    sy = torch.hypot(y[:, 0, ...], y[:, 1, ...])

    sx = f(sx)
    sy = f(sy)

    return torch.mean(sx - sy)

def median_symmetric_accuracy(x, y):
    return 100*(np.exp(np.median(np.abs(np.log(y/x))))-1)

def compute_ms_ssim(x, y):
    x[:, 0, ...] = (x[:, 0, ...] - torch.min(x[:, 0, ...])) / (torch.max(x[:, 0, ...]) - torch.min(x[:, 0, ...]))
    y[:, 0, ...] = (y[:, 0, ...] - torch.min(y[:, 0, ...])) / (torch.max(y[:, 0, ...]) - torch.min(y[:, 0, ...]))

    x[:, 1, ...] = (x[:, 1, ...] - torch.min(x[:, 1, ...])) / (torch.max(x[:, 1, ...]) - torch.min(x[:, 1, ...]))
    y[:, 1, ...] = (y[:, 1, ...] - torch.min(y[:, 1, ...])) / (torch.max(y[:, 1, ...]) - torch.min(y[:, 1, ...]))

    return pytorch_msssim.msssim(x, y, window_size=7)

for model in model_list:
    progress_bar(model_list.index(model), len(model_list))

    key_str = f"{model.region}_{model.sr_model_name}"
    
    lr, hr = model.load_test()
    sr = SuperResolver(model=model.load_generator(), lr=lr, hr=hr, region=model.region)
    hr = sr.super_resolve()
    y = sr.ground_truth()

    mae = np.empty(lr.shape[0])
    bias_spatial_mean = np.empty(lr.shape[0])
    bias_spatial_std = np.empty(lr.shape[0])
    bias_spatial_q90 = np.empty(lr.shape[0])
    ms_ssim_score = np.empty(lr.shape[0])

    hr_fields = torch.cat([x.cpu().detach() for x in hr], dim = 0)
    y_fields = torch.cat([x.cpu().detach() for x in y], dim = 0)

    hr_spectra = compute_rapsd(hr_fields.unsqueeze(1))
    y_spectra = compute_rapsd(y_fields.unsqueeze(1))

    model_results =  dict(
        mae = compute_mae(hr_fields.flatten(), y_fields.flatten()).item(),
        ms_ssim_score = compute_ms_ssim(hr_fields, y_fields).item(),
        bias_spatial_mean = mean_stats(hr_fields, y_fields, partial(torch.mean, dim=0)).item(),
        bias_spatial_std = mean_stats(hr_fields, y_fields, partial(torch.std, dim=0)).item(),
        bias_spatial_q90 = mean_stats(hr_fields, y_fields, partial(torch.quantile, dim=0, q=0.90)).item(),
        median_symmetric_accuracy_u10 = median_symmetric_accuracy(hr_spectra["u10"], y_spectra["u10"]),
        median_symmetric_accuracy_v10 = median_symmetric_accuracy(hr_spectra["v10"], y_spectra["v10"]),
    )
    results[key_str] = model_results


Progress: [------------------> ] 96%

In [8]:
# results

In [9]:
def print_table(results, model_list):
    alphabet = {"southeast": "(a)", "central": "(b)", "west": "(c)"}
    latex_nameref = {
        "PFS (alpha = 500) L_13": "PFS ($\\alpha = 500$) $\\mathscr{{L}}_{13}$",
        "PFS (alpha = 500) L_9": "PFS ($\\alpha = 500$) $\\mathscr{{L}}_9$",
        "PFS (alpha = 500) L_5": "PFS ($\\alpha = 500$) $\\mathscr{{L}}_5$",
        "PFS (alpha = 50) L_13": "PFS ($\\alpha = 50$) $\\mathscr{{L}}_{13}$",
        "PFS (alpha = 50) L_9": "PFS ($\\alpha = 50$) $\\mathscr{{L}}_9$",
        "PFS (alpha = 50) L_5": "PFS ($\\alpha = 50$) $\\mathscr{{L}}_5$",
        "NFS": "NFS GAN",
        "L_13": "FS $\\mathscr{{L}}_{13}$",
        "L_9": "FS $\\mathscr{{L}}_9$",
        "L_5": "FS $\\mathscr{{L}}_5$",
        "CNN": "CNN",
        }
    lines = []
    for region, value in alphabet.items():
        print(
            f"{value} {region.capitalize()} & MAE [m/s] & MS-SSIM & $\mu$ [m/s] & $\sigma$ [m/s] & $Q^{{90}}$ [m/s] & $\\xi_{{u10}}$ [\\%] & $\\xi_{{v10}}$ [\\%] &\\\\ \hline"
        )

        metric_dict = {}
        for metric in results["west_NFS"].keys():
            metric_dict[metric] = []
            for model in model_list:
                if model.region == region:
                    metric_dict[metric].append(results[f"{model.region}_{model.sr_model_name}"][metric])

            if metric == "ms_ssim_score":
                metric_dict[metric] = round(np.array(metric_dict[metric]).max(), 3)
            else:
                metric_dict[metric] = round(min(metric_dict[metric], key=abs), 3)

        for model in model_list:
            if model.region == region:
                print(f"{latex_nameref[model.sr_model_name]}", end=" &")
                for metric in results[f"{model.region}_{model.sr_model_name}"].keys():
                    result_value_comp = round(results[f"{model.region}_{model.sr_model_name}"][metric], 3)
                    result_value = results[f"{model.region}_{model.sr_model_name}"][metric]
                    if math.isclose(
                        metric_dict[metric],
                        result_value_comp
                    ):
                        if math.isclose(result_value_comp, 0.00):
                            result_value = 0.00

                        print(f"\\textbf{{{result_value:.3f}}}", end=" & ")
                    else:
                        print(f"{result_value:.3f}", end=" & ")
                print("\\\\")

        print(f"\\hline")

In [10]:
print_table(results, model_list)

(a) Southeast & MAE [m/s] & MS-SSIM & $\mu$ [m/s] & $\sigma$ [m/s] & $Q^{90}$ [m/s] & $\xi_{u10}$ [\%] & $\xi_{v10}$ [\%] &\\ \hline
CNN &\textbf{1.225} & 0.850 & -0.045 & 0.026 & -0.011 & 221.010 & 286.392 & \\
NFS GAN &1.437 & 0.843 & -0.105 & -0.010 & -0.117 & 6.569 & \textbf{3.977} & \\
FS $\mathscr{{L}}_5$ &1.323 & \textbf{0.866} & -0.042 & -0.005 & -0.047 & 7.451 & 5.044 & \\
FS $\mathscr{{L}}_9$ &1.345 & 0.854 & -0.096 & -0.010 & -0.108 & 6.647 & 8.559 & \\
FS $\mathscr{{L}}_{13}$ &1.376 & 0.853 & -0.021 & -0.011 & -0.035 & \textbf{2.955} & 5.130 & \\
PFS ($\alpha = 500$) $\mathscr{{L}}_{13}$ &1.597 & 0.819 & -0.043 & \textbf{-0.004} & -0.047 & 9.625 & 9.465 & \\
PFS ($\alpha = 500$) $\mathscr{{L}}_9$ &1.624 & 0.820 & \textbf{0.013} & -0.010 & \textbf{0.000} & 14.451 & 14.726 & \\
PFS ($\alpha = 500$) $\mathscr{{L}}_5$ &1.587 & 0.827 & -0.034 & -0.008 & -0.044 & 11.803 & 8.789 & \\
PFS ($\alpha = 50$) $\mathscr{{L}}_{13}$ &1.545 & 0.828 & -0.045 & -0.029 & -0.082 & 23.690 & 20.1

In [25]:
results['west_PFS (alpha = 500) L_9']

{'mae': 1.3770686388015747,
 'ms_ssim_score': 0.8542086482048035,
 'bias_spatial_mean': -0.026783594861626625,
 'bias_spatial_std': 0.0016041324706748128,
 'bias_spatial_q90': -0.023856569081544876,
 'median_symmetric_accuracy_u10': 16.145759766187773,
 'median_symmetric_accuracy_v10': 20.906680367535692}

In [9]:
results['west_NFS']

{'median_symmetric_accuracy_u10': 14.346124081244715,
 'median_symmetric_accuracy_v10': 22.505034102917556,
 'MAE': 1.2555190324783325,
 'bias_spatial_mean': -0.2304215133190155,
 'bias_spatial_std': -0.10931239277124405,
 'bias_spatial_q90': -0.36636579036712646,
 'ms_ssim_score': 0.8646016120910645}

In [None]:
# model = model_list['PFS GAN']['(\\alpha = 500)'][0]
for model in model_list['PFS GAN']:
    lr, hr = model.load_test()
    sr = SuperResolver(model=model.load_generator(), lr=lr, hr=hr, region=model.region)
    hr = sr.super_resolve()
    y = sr.ground_truth()
    MAE = []
    for hr_field, y_field in zip(hr, y):
        hr_field = hr_field.flatten().cpu().detach()
        y_field = y_field.flatten().cpu().detach()

        MAE.append(
            torch.mean(
                torch.abs(hr_field - y_field)
            )
        )
    print(f"{model.region} {model.sr_model_name} MAE: {np.mean(MAE)}")



southeast PFS (lpha = 500) L_13 MAE: 1.3976166248321533
southeast PFS (lpha = 500) L_9 MAE: 1.4027231931686401
southeast PFS (lpha = 500) L_5 MAE: 1.3592710494995117
central PFS (lpha = 500) L_13 MAE: 1.1967296600341797
central PFS (lpha = 500) L_9 MAE: 1.1960387229919434
central PFS (lpha = 500) L_5 MAE: 1.1584988832473755
west PFS (lpha = 500) L_13 MAE: 1.3590545654296875
west PFS (lpha = 500) L_9 MAE: 1.3454322814941406
west PFS (lpha = 500) L_5 MAE: 1.3030283451080322
southeast PFS (lpha = 50) L_13 MAE: 1.4911034107208252
southeast PFS (lpha = 50) L_9 MAE: 1.4329015016555786
southeast PFS (lpha = 50) L_5 MAE: 1.4113613367080688
central PFS (lpha = 50) L_13 MAE: 1.2291878461837769
central PFS (lpha = 50) L_9 MAE: 1.227510690689087
central PFS (lpha = 50) L_5 MAE: 1.2410467863082886
west PFS (lpha = 50) L_13 MAE: 1.4131944179534912
west PFS (lpha = 50) L_9 MAE: 1.410981297492981
west PFS (lpha = 50) L_5 MAE: 1.4086753129959106
