In [1]:
from Annau2022.models import generator, critic
from Annau2022.SRModel import SRModelData, SuperResolver
from Annau2022.RAPSD import compute_rapsd
from Annau2022.progress_bar import progress_bar
from Annau2022.metrics import compute_mae, compute_ms_ssim, median_symmetric_accuracy, mean_stats

import math
import torch
import torch.utils.data as data_utils
import numpy as np
from functools import partial


import matplotlib.pyplot as plt
import matplotlib

# Global figure styling: LaTeX-rendered serif text (text.usetex needs a working
# LaTeX installation); the preamble loads mathrsfs, which provides \mathscr.
PLOT_RC_PARAMS = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "font.size": 14,
    "mathtext.default": "regular",
    "text.latex.preamble": r"\usepackage{mathrsfs}",
}
matplotlib.rcParams.update(PLOT_RC_PARAMS)


In [2]:
fs_data_path = "/workspace/Annau2022/data/fs_data/"
pfs_nfs_path = "/workspace/Annau2022/data/nfs_pfs/"

# One row per trained model: (region, sr_model_name, exp_id, data_path).
# Row order is significant: it fixes the row order of the results table.
MODEL_SPECS = [
    # CNNs
    ("southeast", "CNN", "e25c6b40324643c3afc1cf42981b11b5", fs_data_path),
    ("central", "CNN", "fbe44b0423204805bc6af4d7d6ac562e", fs_data_path),
    ("west", "CNN", "f76c0170818244629de4544805f93a59", fs_data_path),

    # NFS GANs
    ("southeast", "NFS", "feda42500d2b45549be96f1bf62b0b03", pfs_nfs_path),
    ("central", "NFS", "0c5ee480663f4f9eb7200f8879aa1244", pfs_nfs_path),
    ("west", "NFS", "db9f0fae83c949eaad5d1176a43dae47", pfs_nfs_path),

    # FS GANs
    ("southeast", "L_5", "1824682ae27c48669665cf042052d584", fs_data_path),
    ("southeast", "L_9", "3f48868c52404eb0a833897aa4642871", fs_data_path),
    ("southeast", "L_13", "e1d15a0615ca489aa6a17ec60247d0af", fs_data_path),
    ("central", "L_5", "202ea9f8a73b401fa22e62c24d9ab2d0", fs_data_path),
    ("central", "L_9", "079a94c41ad3482996cc2b9f95adba8d", fs_data_path),
    ("central", "L_13", "bcf7e7cfa8ab4c4196ad6a2bb18e8601", fs_data_path),
    ("west", "L_5", "70f5be887eff42e8a216780752644b2f", fs_data_path),
    ("west", "L_9", "6abe7a9940c04b47819689070100e5e6", fs_data_path),
    ("west", "L_13", "c4ec13e65fe74b399fc9e325a9966fef", fs_data_path),

    # PFS GANs alpha 500
    ("southeast", "PFS (alpha = 500) L_13", "3d2ea1e5f805454ea485a3a7c783fd5a", pfs_nfs_path),
    ("southeast", "PFS (alpha = 500) L_9", "caf7f501306848f8bc746605c4994e31", pfs_nfs_path),
    ("southeast", "PFS (alpha = 500) L_5", "90375b9266eb442cb15073895e14d691", pfs_nfs_path),
    ("central", "PFS (alpha = 500) L_13", "c5154f8f03c74cba924d789357e5ca84", pfs_nfs_path),
    ("central", "PFS (alpha = 500) L_9", "e54c953370974e2db09a37e9c0c7cdb5", pfs_nfs_path),
    ("central", "PFS (alpha = 500) L_5", "1570ac86f8e94e83b85447618ca576f5", pfs_nfs_path),
    ("west", "PFS (alpha = 500) L_13", "2e78fba6814545f0be62896cd14b031f", pfs_nfs_path),
    ("west", "PFS (alpha = 500) L_9", "ad5772150e7547ee8d14aa7bac192f54", pfs_nfs_path),
    ("west", "PFS (alpha = 500) L_5", "c5c5e0e8aad5411783329f31db91ff78", pfs_nfs_path),

    # PFS GANs alpha 50
    ("southeast", "PFS (alpha = 50) L_13", "5c9745ff961e46f9af206d36b6567fae", pfs_nfs_path),
    ("southeast", "PFS (alpha = 50) L_9", "3858c673c9344e7caf24144335981752", pfs_nfs_path),
    ("southeast", "PFS (alpha = 50) L_5", "328e5221158147a9ba9b41ab2ac385c7", pfs_nfs_path),
    ("central", "PFS (alpha = 50) L_13", "eedf0cd864204866b98e5de5e710f9c3", pfs_nfs_path),
    ("central", "PFS (alpha = 50) L_9", "1d568d304d7546f78c57e98ff1366b9d", pfs_nfs_path),
    ("central", "PFS (alpha = 50) L_5", "9400ee7db2004aa3b03e91ff710061eb", pfs_nfs_path),
    ("west", "PFS (alpha = 50) L_13", "faa34028b516487185c994f48621050a", pfs_nfs_path),
    ("west", "PFS (alpha = 50) L_9", "2faf762448b54ae2b96234d6c77c38b3", pfs_nfs_path),
    ("west", "PFS (alpha = 50) L_5", "4f0574ec4f7147f1b0555cafeb1cc98f", pfs_nfs_path),
]

model_list = [
    SRModelData(region=region, sr_model_name=name, exp_id=exp_id, data_path=path)
    for region, name, exp_id, path in MODEL_SPECS
]

In [3]:
# Metric store keyed by "{region}_{sr_model_name}", each mapping to its own
# (initially empty) dict of metric name -> value.
results = {
    key: {}
    for key in (f"{m.region}_{m.sr_model_name}" for m in model_list)
}

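### 1. Pixel-wise MAE and other evaluation metrics
MAE $= \frac{1}{N}\lVert G(\mathbf{x}) - \mathbf{y} \rVert_1$, where $N$ is the total number of values compared. The cell below also computes MS-SSIM, the biases in the spatial mean ($\mu$), standard deviation ($\sigma$) and 90th percentile ($Q^{90}$), and the median symmetric accuracy ($\xi$) of the $u_{10}$ and $v_{10}$ spectra.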

In [8]:
def evaluate_model(model):
    """Super-resolve one model's test set and score it against the ground truth.

    Args:
        model: an ``SRModelData`` entry of ``model_list``.

    Returns:
        dict with ``mae``, ``ms_ssim_score``, the ``bias_spatial_*`` statistics
        (``mean_stats`` applied with a mean / std / 0.90-quantile reduction over
        dim 0) and the median symmetric accuracy of the ``u10`` / ``v10``
        spectra produced by ``compute_rapsd``.
    """
    lr, hr = model.load_test()
    resolver = SuperResolver(
        G=model.load_generator(),
        lr=lr.float(),
        hr=hr.float(),
        region=model.region,
        batch_size=1024,
    )
    # Evaluation only: no autograd graph is needed for the generator outputs.
    with torch.no_grad():
        sr_fields = resolver.super_resolve().detach().clone()
        truth = resolver.ground_truth().detach().clone()

    sr_spectra = compute_rapsd(sr_fields)
    truth_spectra = compute_rapsd(truth)

    return dict(
        mae=compute_mae(sr_fields.flatten(), truth.flatten()).item(),
        ms_ssim_score=compute_ms_ssim(sr_fields.cpu(), truth.cpu()).item(),
        bias_spatial_mean=mean_stats(sr_fields, truth, partial(torch.mean, dim=0)).item(),
        bias_spatial_std=mean_stats(sr_fields, truth, partial(torch.std, dim=0)).item(),
        bias_spatial_q90=mean_stats(sr_fields, truth, partial(torch.quantile, dim=0, q=0.90)).item(),
        median_symmetric_accuracy_u10=median_symmetric_accuracy(sr_spectra["u10"], truth_spectra["u10"]),
        median_symmetric_accuracy_v10=median_symmetric_accuracy(sr_spectra["v10"], truth_spectra["v10"]),
    )


n_models = len(model_list)
progress_bar(0, n_models)
for i, model in enumerate(model_list):
    results[f"{model.region}_{model.sr_model_name}"] = evaluate_model(model)
    # Report completed models so the bar ends at 100% (the former 0-based call
    # made before each model stopped at 96%). Assumes progress_bar(done, total).
    progress_bar(i + 1, n_models)


Progress: [------------------> ] 96%

In [9]:
def print_table(results, model_list):
    """Print a LaTeX table comparing every model's metrics, grouped by region.

    For each region that has at least one model in ``model_list``, prints a
    header row, one row per model (in ``model_list`` order) and a closing
    ``\\hline``. In every metric column the best value after rounding to three
    decimals is wrapped in ``\\textbf{}``: the largest value for MS-SSIM, the
    value closest to zero for every other metric. Ties are all bolded.

    Args:
        results: mapping ``"{region}_{sr_model_name}"`` -> ``{metric: float}``.
            Cells follow each model's dict key order; the fixed header assumes
            the seven metrics produced by the evaluation loop.
        model_list: iterable of objects with ``region`` and ``sr_model_name``.
    """
    region_labels = {"southeast": "(a)", "central": "(b)", "west": "(c)"}
    # Ordinary (non-f) strings: "{L}" is already a literal brace pair here.
    latex_nameref = {
        "PFS (alpha = 500) L_13": "PFS ($\\alpha = 500$) $\\mathscr{L}_{13}$",
        "PFS (alpha = 500) L_9": "PFS ($\\alpha = 500$) $\\mathscr{L}_9$",
        "PFS (alpha = 500) L_5": "PFS ($\\alpha = 500$) $\\mathscr{L}_5$",
        "PFS (alpha = 50) L_13": "PFS ($\\alpha = 50$) $\\mathscr{L}_{13}$",
        "PFS (alpha = 50) L_9": "PFS ($\\alpha = 50$) $\\mathscr{L}_9$",
        "PFS (alpha = 50) L_5": "PFS ($\\alpha = 50$) $\\mathscr{L}_5$",
        "NFS": "NFS GAN",
        "L_13": "FS $\\mathscr{L}_{13}$",
        "L_9": "FS $\\mathscr{L}_9$",
        "L_5": "FS $\\mathscr{L}_5$",
        "CNN": "CNN",
    }
    maximize = {"ms_ssim_score"}  # higher is better; all others: closest to 0

    def score(metric, value):
        # Rounded value on a scale where the best entry is the max (maximize)
        # or the min (|value| for closest-to-zero metrics).
        rounded = round(value, 3)
        return rounded if metric in maximize else abs(rounded)

    for region, label in region_labels.items():
        rows = [
            (m.sr_model_name, results[f"{m.region}_{m.sr_model_name}"])
            for m in model_list
            if m.region == region
        ]
        if not rows:
            continue

        print(
            f"{label} {region.capitalize()} & MAE [m/s] & MS-SSIM & $\\mu$ [m/s] & $\\sigma$ [m/s] & $Q^{{90}}$ [m/s] & $\\xi_{{u10}}$ [\\%] & $\\xi_{{v10}}$ [\\%] &\\\\ \\hline"
        )

        best = {}
        for metric in rows[0][1]:
            scores = [score(metric, model_results[metric]) for _, model_results in rows]
            best[metric] = max(scores) if metric in maximize else min(scores)

        for name, model_results in rows:
            cells = []
            for metric, value in model_results.items():
                if round(value, 3) == 0:
                    value = 0.0  # print "0.000", never "-0.000"
                text = f"{value:.3f}"
                if math.isclose(score(metric, value), best[metric]):
                    text = f"\\textbf{{{text}}}"
                cells.append(text)
            print(f"{latex_nameref.get(name, name)} &" + "".join(f"{cell} & " for cell in cells) + "\\\\")

        print("\\hline")

In [10]:
# Print the LaTeX comparison table for every region.
print_table(model_list=model_list, results=results)

(a) Southeast & MAE [m/s] & MS-SSIM & $\mu$ [m/s] & $\sigma$ [m/s] & $Q^{90}$ [m/s] & $\xi_{u10}$ [\%] & $\xi_{v10}$ [\%] &\\ \hline
CNN &\textbf{1.230} & \textbf{0.868} & -0.217 & -0.127 & -0.376 & 220.643 & 281.435 & \\
NFS GAN &1.433 & 0.848 & -0.043 & \textbf{-0.007} & \textbf{-0.072} & 3.382 & \textbf{2.426} & \\
FS $\mathscr{{L}}_5$ &1.303 & \textbf{0.868} & -0.228 & -0.080 & -0.339 & 6.932 & 6.557 & \\
FS $\mathscr{{L}}_9$ &1.325 & 0.851 & -0.196 & -0.054 & -0.296 & 7.664 & 10.094 & \\
FS $\mathscr{{L}}_{13}$ &1.360 & 0.843 & -0.158 & -0.029 & -0.202 & \textbf{2.912} & 5.254 & \\
PFS ($\alpha = 500$) $\mathscr{{L}}_{13}$ &1.413 & 0.849 & -0.107 & -0.071 & -0.230 & 14.705 & 12.563 & \\
PFS ($\alpha = 500$) $\mathscr{{L}}_9$ &1.414 & 0.847 & -0.198 & -0.111 & -0.372 & 16.499 & 16.384 & \\
PFS ($\alpha = 500$) $\mathscr{{L}}_5$ &1.382 & 0.856 & -0.120 & -0.094 & -0.264 & 18.101 & 14.727 & \\
PFS ($\alpha = 50$) $\mathscr{{L}}_{13}$ &1.458 & 0.765 & \textbf{-0.012} & 0.018 & -0.081 

In [None]:
# NOTE(review): re-prints the same table as the earlier print_table cell;
# candidate for removal.
print_table(model_list=model_list, results=results)

In [None]:
# Raw metric dict for one model (displayed as the cell output).
results["west_PFS (alpha = 500) L_9"]

In [None]:
# Raw metric dict for one model (displayed as the cell output).
results["west_NFS"]

In [None]:
# Per-field MAE for every PFS GAN: mean |SR - truth| of each field, averaged
# over all fields of the test set.
# model_list is a flat list, so PFS models are selected by name rather than by
# indexing it with a string (which raised TypeError).
pfs_models = [m for m in model_list if m.sr_model_name.startswith("PFS")]
for model in pfs_models:
    lr, hr = model.load_test()
    # Same construction as the evaluation loop (G=, float inputs, batch size).
    resolver = SuperResolver(
        G=model.load_generator(),
        lr=lr.float(),
        hr=hr.float(),
        region=model.region,
        batch_size=1024,
    )
    with torch.no_grad():
        sr_fields = resolver.super_resolve().detach().cpu()
        truth = resolver.ground_truth().detach().cpu()

    # Assumes dim 0 indexes fields (as the former per-field zip loop did):
    # flatten each field, take its mean absolute error, then average.
    per_field_mae = (sr_fields - truth).abs().flatten(start_dim=1).mean(dim=1)
    print(f"{model.region} {model.sr_model_name} MAE: {per_field_mae.mean().item()}")

