## Evaluation

In [22]:
import torch
import numpy as np
import random
from models.architectures import srgan, srflow
import PIL
import os
import torchvision
from torchvision import transforms
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import numpy as np

# Set a seed for reproducibility
np.random.seed(42)

# Dataset loading
from data import dataloading
from data.era5_temp_dataset import InverseMinMaxScaler

from os.path import exists, join
import matplotlib.pyplot as plt
from matplotlib import transforms
import timeit
import pdb
import argparse
import seaborn as sns

from utils.metrics import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
print('GPUs avail:', torch.cuda.device_count())

# Parse Settings
parser = argparse.ArgumentParser()

# train configs
parser.add_argument("--model", type=str, default="srflow",
                    help="Model you want to train.")
parser.add_argument("--modeltype", type=str, default="srflow",
                    help="Specify modeltype you would like to train [srflow, cdiff, srgan].")
parser.add_argument("--model_path", type=str, default="runs/",
                    help="Directory where models are saved.")
parser.add_argument("--modelname", type=str, default=None,
                    help="Sepcify modelname to be tested.")
parser.add_argument("--epochs", type=int, default=10000,
                    help="number of epochs")
parser.add_argument("--max_steps", type=int, default=2000000,
                    help="For training on a large dataset.")
parser.add_argument("--log_interval", type=int, default=100,
                    help="Interval in which results should be logged.")

# runtime configs
parser.add_argument("--visual", action="store_true",
                    help="Visualizing the samples at test time.")
parser.add_argument("--noscaletest", action="store_true",
                    help="Disable scale in coupling layers only at test time.")
parser.add_argument("--noscale", action="store_true",
                    help="Disable scale in coupling layers.")
parser.add_argument("--testmode", action="store_true",
                    help="Model run on test set.")
parser.add_argument("--train", action="store_true",
                    help="If model should be trained.")
parser.add_argument("--resume_training", action="store_true",
                    help="If training should be resumed.")
parser.add_argument("--constraint", type=str, default='scaddDS',
                    help="Physical Constraint to be applied during training.")                   

# hyperparameters
parser.add_argument("--nbits", type=int, default=8,
                    help="Images converted to n-bit representations.")
parser.add_argument("--s", type=int, default=16, help="Upscaling factor.")
parser.add_argument("--bsz", type=int, default=16, help="batch size")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
parser.add_argument("--filter_size", type=int, default=512, help="filter size NN in Affine Coupling Layer")
parser.add_argument("--L", type=int, default=3, help="# of levels")
parser.add_argument("--K", type=int, default=2,
                    help="# of flow steps, i.e. model depth")
parser.add_argument("--nb", type=int, default=16,
                    help="# of residual-in-residual blocks LR network.")
parser.add_argument("--condch", type=int, default=128//8,
                    help="# of residual-in-residual blocks in LR network.")

# data
parser.add_argument("--datadir", type=str, default="/home/christina/Documents/clim-var-ds-cnf-own/data/",
                    help="Dataset to train the model on.")
parser.add_argument("--trainset", type=str, default="era5-TCW", help='[era5-TCW, era5-T2M]')
parser.add_argument("--testset", type=str, default="era5-TCW", help="Specify test dataset")

args = parser.parse_args('')
config = vars(args)

GPUs avail: 1


In [3]:
if torch.cuda.is_available():
    args.device = torch.device("cuda")
    args.num_gpus = torch.cuda.device_count()
    args.parallel = False

else:
    args.device = "cpu"

In [4]:
# use min-max due to non-gaussian distribution of data and outlier handling
def min_max_scaler(x, min_val=0, max_val=124):
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()
    scaled_x = (x - min_val) / (max_val - min_val)
    return scaled_x, min_val, max_val

def inv_min_max_scaler(scaled_x, min_val=0, max_val=124):
    x = scaled_x * (max_val - min_val) + min_val
    return x

In [5]:
import os
import torch

def load_model(exp_dir, model, mpath):
    """
    Load a model from the specified directory and checkpoint file.
    
    Parameters:
    exp_dir (str): The experiment directory path.
    model (torch.nn.Module): The model instance to load the state dictionary into.
    mpath (str): The path to the model checkpoint file.
    
    Returns:
    torch.nn.Module: The loaded model with the state dictionary.
    
    Raises:
    FileNotFoundError: If the model file does not exist at the specified path.
    """
    try:
        # Check if the file exists
        if not os.path.exists(mpath):
            raise FileNotFoundError(f"Model file not found at {mpath}")
        
        # Load the checkpoint
        ckpt = torch.load(mpath)
        
        # Load the state dictionary into the model
        model.load_state_dict(ckpt['model_state_dict'])
        
        print("Model loaded successfully.")
        return model

    except FileNotFoundError as e:
        print(e)
        return None

### Metric Evaluation
Evaluations: 
- CNF unconstrained vs. constrained
- CNF vs. GAN vs. CDiff
- Residual Error Plots for ???
- Metrics for 2x,4x,8x,16x with and without constraints

In [18]:
import os
import json
import torch
import torchvision.utils as vutils
import sys
sys.path.append("../../")
from utils import metrics

def save_snapshot(tensor, path):
    # Function to save tensor as image in Viridis color scheme
    plt.figure()
    plt.imshow(tensor[0,...].cpu().numpy().squeeze(), cmap='viridis')
    plt.axis('off')
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()

def save_grid_images(tensor_batch, save_path, nrow=3, title=None):
    
    # Create a grid of images
    grid_img = vutils.make_grid(tensor_batch[0:9], nrow=nrow, normalize=True, scale_each=True)

    # Convert the grid to numpy and plot
    np_grid = grid_img.cpu().numpy().transpose((1, 2, 0))
    plt.figure(figsize=(10, 10))
    plt.imshow(np_grid[:, :, 0], cmap='viridis')  # Use Viridis colormap for grayscale images
    plt.axis('off')
    plt.title(title)
    plt.tight_layout(pad=0)

    # Save the grid image
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
def metric_eval(dataloader, model, exp_name, args):
    metric_dict = {
        'rmse0': [], 'mse0': [], 'mae0': [],
        'rmse05': [], 'mse05': [], 'mae05': [],
        'rmse08': [], 'mse08': [], 'mae08': [],
        'rmse1': [], 'mse1': [], 'mae1': [],
        'crps0': [], 'crps05': [],'crps08': [],'crps1': [],
    }

    # Create directories
    results_dir = os.path.join(exp_name, 'experiment_results')
    os.makedirs(results_dir, exist_ok=True)
    
    snapshots_dir = os.path.join(results_dir, 'snapshots')
    os.makedirs(snapshots_dir, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for batch_idx, item in enumerate(dataloader):
            y = item[0].to(args.device)
            x = item[1].to(args.device)

            y_unorm = item[2].to(args.device)
            x_unorm = item[3].to(args.device)

            z, _ = model.forward(x_hr=y, xlr=x)

            # Evaluate for different temperatures
            mu0, _, _ = model(xlr=x, reverse=True, eps=0)
            mu05, _, _ = model(xlr=x, reverse=True, eps=0.5)
            mu08, _, _ = model(xlr=x, reverse=True, eps=0.8)
            mu1, _, _ = model(xlr=x, reverse=True, eps=1.0)

            # Compute and store MSE for each temperature
            mse0 = metrics.MSE(inv_min_max_scaler(mu0), y_unorm).detach().cpu().numpy()
            mse05 = metrics.MSE(inv_min_max_scaler(mu05), y_unorm).detach().cpu().numpy()
            mse08 = metrics.MSE(inv_min_max_scaler(mu08), y_unorm).detach().cpu().numpy()
            mse1 = metrics.MSE(inv_min_max_scaler(mu1), y_unorm).detach().cpu().numpy()
            
            metric_dict['mse0'].append(mse0)
            metric_dict['mse05'].append(mse05)
            metric_dict['mse08'].append(mse08)
            metric_dict['mse1'].append(mse1)

            # Compute and store MAE for each temperature
            mae0 = metrics.MAE(inv_min_max_scaler(mu0), y_unorm).detach().cpu().numpy()
            mae05 = metrics.MAE(inv_min_max_scaler(mu05), y_unorm).detach().cpu().numpy()
            mae08 = metrics.MAE(inv_min_max_scaler(mu08), y_unorm).detach().cpu().numpy()
            mae1 = metrics.MAE(inv_min_max_scaler(mu1), y_unorm).detach().cpu().numpy()

            metric_dict['mae0'].append(mae0)
            metric_dict['mae05'].append(mae05)
            metric_dict['mae08'].append(mae08)
            metric_dict['mae1'].append(mae1)

            # Compute and store RMSE for each temperature
            rmse0 = metrics.RMSE(inv_min_max_scaler(mu0), y_unorm).detach().cpu().numpy()
            rmse05 = metrics.RMSE(inv_min_max_scaler(mu05), y_unorm).detach().cpu().numpy()
            rmse08 = metrics.RMSE(inv_min_max_scaler(mu08), y_unorm).detach().cpu().numpy()
            rmse1 = metrics.RMSE(inv_min_max_scaler(mu1), y_unorm).detach().cpu().numpy()

            metric_dict['rmse0'].append(rmse0)
            metric_dict['rmse05'].append(rmse05)
            metric_dict['rmse08'].append(rmse08)
            metric_dict['rmse1'].append(rmse1)

             # Calculate CRPS for ensemble
            metric_dict['crps0'].append(crps_ensemble(y_unorm, inv_min_max_scaler(mu0)))
            metric_dict['crps05'].append(crps_ensemble(y_unorm, inv_min_max_scaler(mu05)))
            metric_dict['crps08'].append(crps_ensemble(y_unorm, inv_min_max_scaler(mu08)))
            metric_dict['crps1'].append(crps_ensemble(y_unorm, inv_min_max_scaler(mu1)))
            
            # Save grid of images for visualization (adjust as needed)
            if batch_idx == 0:  # Save only for the first batch for simplicity
                # Save snapshots
                save_grid_images(y_unorm, os.path.join(snapshots_dir, f'ground_truth_{batch_idx}.png'))
                save_grid_images(mu0, os.path.join(snapshots_dir, f'prediction_mu0_{batch_idx}.png'))
                save_grid_images(mu05, os.path.join(snapshots_dir, f'prediction_mu05_{batch_idx}.png'))
                save_grid_images(mu08, os.path.join(snapshots_dir, f'prediction_mu08_{batch_idx}.png'))
                save_grid_images(mu1, os.path.join(snapshots_dir, f'prediction_mu1_{batch_idx}.png'))
                save_grid_images(x_unorm, os.path.join(snapshots_dir, f'low_res_{batch_idx}.png'))
                
            print(f'Current RMSE - mu0: {rmse0}, mu05: {rmse05}, mu08: {rmse08}, mu1: {rmse1}')
            print(f'Current MAE - mu0: {mae0}, mu05: {mae05}, mu08: {mae08}, mu1: {mae1}')
            # print(f'Current CRPS - mu0:{crps0}, mu05: {crps05}, mu08: {crps08}, m1: {crps1}')

            # if batch_idx == 200:
            #     break

    # Create a string representation of the metric_dict
    mean_dict = {}
    for key, value in metric_dict.items():
        if len(value) > 0:
            mean_dict[key] = np.mean(value)   
            
    metric_str = "\n".join([f"{key}: {value}" for key, value in mean_dict.items()])

    # Save metric_dict to a text file
    with open(os.path.join(results_dir, 'metrics.txt'), 'w') as f:
        json.dump(metric_str, f, indent=4)
        
    print(mean_dict)
    return mean_dict

# Experiments: Comparing across upsampling factors


In [7]:
# load model
print('Num of avail GPUs:', torch.cuda.device_count())

Num of avail GPUs: 1


In [8]:
import gc
gc.collect()

44

In [9]:
args.s = 2
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf2x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 2, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_29_13_26_09_2x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_29_13_26_09_2x/model_checkpoints/model_epoch_6_step_15750.tar'
# Load the model
cnf2x = load_model(exp_dir, cnf2x, cnf_path).to(args.device)
metric_eval(test_loader, cnf2x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
Current RMSE - mu0: 0.19845393300056458, mu05: 0.23489730060100555, mu08: 0.2743164598941803, mu1: 0.3107762336730957
Current MAE - mu0: 0.1074303388595581, mu05: 0.13323159515857697, mu08: 0.1619873046875, mu1: 0.18586765229701996
Current RMSE - mu0: 0.264850378036499, mu05: 0.30392539501190186, mu08: 0.35547173023223877, mu1: 0.4030383825302124
Current MAE - mu0: 0.14183129370212555, mu05: 0.1705910861492157, mu08: 0.20671749114990234, mu1: 0.23533710837364197
Current RMSE - mu0: 0.21360093355178833, mu05: 0.25308507680892944, mu08: 0.3063626289367676, mu1: 0.35042810440063477
Current MAE - mu0: 0.1193913072347641, mu05: 0.14680622518062592, mu08: 0.18184220790863037, mu1: 0.20845787227153778
Current RMSE - mu0: 0.29184892773628235, mu05: 0.34031933546066284, mu08: 0.41634663939476013, mu1: 0.4647655785083771
Current MAE - mu0: 0.1666335165500641, mu05: 0.2002718448638916, mu08: 0.2452668845653534, mu1: 0.27708011865615845
Current RMSE 

{'rmse0': [array(0.19845393, dtype=float32),
  array(0.26485038, dtype=float32),
  array(0.21360093, dtype=float32),
  array(0.29184893, dtype=float32),
  array(0.21131109, dtype=float32),
  array(0.22658798, dtype=float32),
  array(0.22285461, dtype=float32),
  array(0.22350156, dtype=float32),
  array(0.2587837, dtype=float32),
  array(0.25235236, dtype=float32),
  array(0.17166302, dtype=float32),
  array(0.29554474, dtype=float32),
  array(0.23239002, dtype=float32),
  array(0.2578716, dtype=float32),
  array(0.17107825, dtype=float32),
  array(0.20037082, dtype=float32),
  array(0.25941825, dtype=float32),
  array(0.23065871, dtype=float32),
  array(0.19629909, dtype=float32),
  array(0.2141766, dtype=float32),
  array(0.20737872, dtype=float32),
  array(0.14502652, dtype=float32),
  array(0.20748231, dtype=float32),
  array(0.18101114, dtype=float32),
  array(0.21279448, dtype=float32),
  array(0.21451163, dtype=float32),
  array(0.272429, dtype=float32),
  array(0.19924569, dtyp

In [10]:
gc.collect()

9859

In [13]:
args.s = 4
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf4x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_29_13_26_13_4x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_29_13_26_13_4x/model_checkpoints/model_epoch_10_step_26250.tar'
# Load the model
cnf4x = load_model(exp_dir, cnf4x, cnf_path).to(args.device)
metric_eval(test_loader, cnf4x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
Current RMSE - mu0: 0.48445749282836914, mu05: 0.5497838258743286, mu08: 0.6331164836883545, mu1: 0.7063823938369751
Current MAE - mu0: 0.2833460867404938, mu05: 0.3329157829284668, mu08: 0.39429035782814026, mu1: 0.4425612688064575
Current RMSE - mu0: 0.660796046257019, mu05: 0.7459709644317627, mu08: 0.8538327217102051, mu1: 0.9509994983673096
Current MAE - mu0: 0.38327863812446594, mu05: 0.44709357619285583, mu08: 0.5285980105400085, mu1: 0.5908297300338745
Current RMSE - mu0: 0.5357289910316467, mu05: 0.602929949760437, mu08: 0.6921342015266418, mu1: 0.7740520238876343
Current MAE - mu0: 0.3102187514305115, mu05: 0.36552053689956665, mu08: 0.43306758999824524, mu1: 0.4888858497142792
Current RMSE - mu0: 0.7021902203559875, mu05: 0.799588680267334, mu08: 0.9147934317588806, mu1: 1.0608819723129272
Current MAE - mu0: 0.4194105565547943, mu05: 0.48660188913345337, mu08: 0.5703456401824951, mu1: 0.6496960520744324
Current RMSE - mu0: 0.53

{'rmse0': [array(0.4844575, dtype=float32),
  array(0.66079605, dtype=float32),
  array(0.535729, dtype=float32),
  array(0.7021902, dtype=float32),
  array(0.5345534, dtype=float32),
  array(0.53129745, dtype=float32),
  array(0.58081996, dtype=float32),
  array(0.54456306, dtype=float32),
  array(0.6393016, dtype=float32),
  array(0.5850318, dtype=float32),
  array(0.45357272, dtype=float32),
  array(0.6907858, dtype=float32),
  array(0.5726379, dtype=float32),
  array(0.6146871, dtype=float32),
  array(0.46337536, dtype=float32),
  array(0.51045215, dtype=float32),
  array(0.6622318, dtype=float32),
  array(0.5587851, dtype=float32),
  array(0.49564293, dtype=float32),
  array(0.5329642, dtype=float32),
  array(0.5232176, dtype=float32),
  array(0.37734157, dtype=float32),
  array(0.5012482, dtype=float32),
  array(0.45465568, dtype=float32),
  array(0.5130371, dtype=float32),
  array(0.49372, dtype=float32),
  array(0.6436701, dtype=float32),
  array(0.50554013, dtype=float32),
  a

In [14]:
gc.collect()

9867

In [23]:
args.s = 8
args.device = 'cuda'
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf8x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 8, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_29_13_26_12_8x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_29_13_26_12_8x/model_checkpoints/model_epoch_4_step_11500.tar'
# Load the model
cnf8x = load_model(exp_dir, cnf8x, cnf_path).to(args.device)
metric_eval(test_loader, cnf8x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
Current RMSE - mu0: 1.1586923599243164, mu05: 1.30537748336792, mu08: 1.483165979385376, mu1: 1.664674162864685
Current MAE - mu0: 0.7625858187675476, mu05: 0.8745710849761963, mu08: 1.0155210494995117, mu1: 1.1532011032104492
Current RMSE - mu0: 1.5189896821975708, mu05: 1.6796889305114746, mu08: 1.9570293426513672, mu1: 2.290147542953491
Current MAE - mu0: 1.0023727416992188, mu05: 1.1351418495178223, mu08: 1.3347281217575073, mu1: 1.5625380277633667
Current RMSE - mu0: 1.2048746347427368, mu05: 1.368869662284851, mu08: 1.6379868984222412, mu1: 1.8971459865570068
Current MAE - mu0: 0.8027985095977783, mu05: 0.9323151111602783, mu08: 1.1227161884307861, mu1: 1.2983425855636597
Current RMSE - mu0: 1.5562050342559814, mu05: 1.7503681182861328, mu08: 2.0842785835266113, mu1: 2.449570655822754
Current MAE - mu0: 1.0227371454238892, mu05: 1.180488109588623, mu08: 1.4070734977722168, mu1: 1.6582939624786377
Current RMSE - mu0: 1.27882850170135

{'rmse0': 1.2616284,
 'mse0': 2.1301696,
 'mae0': 0.8331672,
 'rmse05': 1.4058713,
 'mse05': 2.6702805,
 'mae05': 0.9482584,
 'rmse08': 362.5076,
 'mse08': 1302008400.0,
 'mae08': 8.969207,
 'rmse1': 10385172000000.0,
 'mse1': 1.0785181e+30,
 'mae1': 225755900000.0,
 'crps0': 0.51439875,
 'crps05': 0.6055622,
 'crps08': 4.1272182,
 'crps1': 98239996000.0}

In [16]:
gc.collect()

9859

In [25]:
args.s = 16
args.device = 'cuda'
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf16x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 16, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_29_13_26_12_16x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_29_13_26_12_16x/model_checkpoints/model_epoch_5_step_13750.tar'
# Load the model
cnf16x = load_model(exp_dir, cnf16x, cnf_path).to(args.device)
metric_eval(test_loader, cnf16x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
Current RMSE - mu0: 2.4515905380249023, mu05: 2.5347533226013184, mu08: 2.8433117866516113, mu1: 3.0719504356384277
Current MAE - mu0: 1.746960163116455, mu05: 1.8089313507080078, mu08: 2.084209442138672, mu1: 2.2718653678894043
Current RMSE - mu0: 3.29221248626709, mu05: 3.4046177864074707, mu08: 3.8066582679748535, mu1: 4.165844917297363
Current MAE - mu0: 2.3271420001983643, mu05: 2.4309678077697754, mu08: 2.789163827896118, mu1: 3.046720504760742
Current RMSE - mu0: 2.564744472503662, mu05: 2.663745403289795, mu08: 3.030823230743408, mu1: 3.393923282623291
Current MAE - mu0: 1.824601411819458, mu05: 1.9104586839675903, mu08: 2.2209253311157227, mu1: 2.4984545707702637
Current RMSE - mu0: 3.23589825630188, mu05: 3.407365560531616, mu08: 3.6915693283081055, mu1: 4.160919189453125
Current MAE - mu0: 2.294903516769409, mu05: 2.4319913387298584, mu08: 2.696916103363037, mu1: 3.06524658203125
Current RMSE - mu0: 2.8463778495788574, mu05: 2.

{'rmse0': 2.751227,
 'mse0': 9.48738,
 'mae0': 1.9537641,
 'rmse05': 2.8596761,
 'mse05': 10.173089,
 'mae05': 2.0535772,
 'rmse08': 3.1954575,
 'mse08': 12.727322,
 'mae08': 2.334102,
 'rmse1': 3.551482,
 'mse1': 15.932216,
 'mae1': 2.6122048,
 'crps0': 0.8439635,
 'crps05': 1.0339391,
 'crps08': 1.2165021,
 'crps1': 1.3688538}

In [None]:
gc.collect()

In [None]:
def plot_std(model):
    """
    For this experiment we visualize the super-resolution space for a single
    low-resolution image and its possible HR target predictions. We visualize
    the standard deviation of these predictions from the mean of the model.
    """
    color = 'plasma'
    savedir_viz = "experiments/{}_{}_{}/snapshots/population_std/".format(exp_name, modelname, args.trainset)
    os.makedirs(savedir_viz, exist_ok=True)
    model.eval()
    cmap = 'viridis' if args.trainset == 'era5-TCW' else 'inferno'
    with torch.no_grad():
        for batch_idx, item in enumerate(test_loader):

            y = item[0].to(args.device)
            x = item[1].to(args.device)

            y_unorm = item[2].to(args.device)
            x_unnorm = item[3].to(args.device)

            mu0 = model(x)

            samples = []
            n = 20
            sq_diff = torch.zeros_like(mu0)
            for n in range(n):
                mu1 = model(x)
                samples.append(mu0)
                sq_diff += (mu1 - mu0)**2

            # compute population standard deviation
            sigma = torch.sqrt(sq_diff / n)

            # create plot
            plt.figure()
            plt.imshow(sigma[0,...].permute(1,2,0).cpu().numpy(), cmap=color)
            plt.axis('off')
            # plt.show()
            plt.savefig(savedir_viz + '/sigma_{}.png'.format(batch_idx), dpi=300, bbox_inches='tight')
            plt.close()

            plt.figure()
            plt.imshow(mu0[0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            plt.axis('off')
            # plt.show()
            plt.savefig(savedir_viz + '/mu0_{}.png'.format(batch_idx), dpi=300, bbox_inches='tight')
            plt.close()

            fig, (ax1, ax3, ax4, ax5, ax6, ax7) = plt.subplots(1,6)
            # fig.suptitle('Y, Y_hat, mu, sigma')
            ax1.imshow(y[0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax1)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax1.set_title('Ground Truth', fontsize=5)
            ax1.axis('off')
            # ax2.imshow(mu0[0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            # divider = make_axes_locatable(ax2)
            # cax = divider.append_axes("right", size="5%", pad=0.05)
            # cax.set_axis_off()
            # ax2.set_title('Mean', fontsize=5)
            # ax2.axis('off')
            ax3.imshow(samples[1][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax3)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax3.set_title('Sample 1', fontsize=5)
            ax3.axis('off')
            ax4.imshow(samples[2][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax4)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax4.set_title('Sample 2', fontsize=5)
            ax4.axis('off')
            ax5.imshow(samples[2][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax5)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax5.set_title('Sample 3', fontsize=5)
            ax5.axis('off')
            ax6.imshow(samples[2][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax6)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax6.set_title('Sample 4', fontsize=5)
            ax6.axis('off')
            divider = make_axes_locatable(ax7)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            im7 = ax7.imshow(sigma[0,...].permute(1,2,0).cpu().numpy(), cmap='magma')
            cbar = fig.colorbar(im7,cmap='magma', cax=cax)
            cbar.ax.tick_params(labelsize=5)
            ax7.set_title('Std. Dev.', fontsize=5)
            ax7.axis('off')
            plt.tight_layout()
            plt.savefig(savedir_viz + '/std_multiplot_{}.png'.format(batch_idx), dpi=300, bbox_inches='tight')
            # plt.show()
            plt.close()

    return None

In [None]:
# especially focus on analyzing extreme value predictions

### Scatter Plot (Predicted vs Target)
![image.png](attachment:image.png)

### Cumulative Distribution of  Residual Errors
![image.png](attachment:image.png)

### Experiments: Constraint placement
Comparing model with no constraint vs. multiplicative constraint at output layer

### Load CNF with constraints 4x

In [None]:
cnf_4x_add = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'add', args.nb, args.condch, args.noscale, args.noscaletest)
cnf_4x_mult = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'mult', args.nb, args.condch, args.noscale, args.noscaletest)
cnf_4x_scadd = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'scadd', args.nb, args.condch, args.noscale, args.noscaletest)
cnf_4x_softmax = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'softmax', args.nb, args.condch, args.noscale, args.noscaletest)

In [None]:
# load vanilla conditional flow 4x
exp_dir_add_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
exp_dir_mult_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
exp_dir_scadd_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
exp_dir_softmax_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'

cnf_4x_add_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
cnf_4x_mult_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
cnf_4x_scadd_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
cnf_4x_softmax_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'

In [None]:
# 2x watercontent mul with constraint at the end 
modelname = 'model_epoch_0_step_1000'
modelpath = '/home/mila/c/christina.winkler/clim-var-ds-cnf/runs/srflow_era5-TCW_mul_ constr_in_end__2024_06_03_17_35_33_2x/model_checkpoints/{}.tar'.format(modelname)

model = srflow.SRFlow((in_channels, args.height, args.width), args.filter_size, args.L, args.K,
                        args.bsz, args.s, 'mul', args.nb, args.condch, args.noscale, args.noscaletest)

params = sum(x.numel() for x in model.parameters() if x.requires_grad)
print('Nr of Trainable Params {}:  '.format(args.device), params)
model = model.to(args.device)
exp_name = "flow-{}-level-{}-k--constraint-{}".format(args.L, args.K, 'mul')
print(exp_name)

In [None]:
ckpt = torch.load(modelpath, map_location='cuda:0')
model.load_state_dict(ckpt['model_state_dict'])

## Evaluate Power Spectrum

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage import io

# Directory containing the test set images
image_dir = 'path_to_your_test_set_directory'
output_dir = 'path_to_save_power_spectra'  # Directory to save power spectra images

# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

def compute_power_spectrum(image):
    # Apply FFT
    f_transform = np.fft.fft2(image)
    f_transform_shifted = np.fft.fftshift(f_transform)

    # Compute the power spectrum
    power_spectrum = np.abs(f_transform_shifted) ** 2

    # Use logarithmic scale for visualization
    log_power_spectrum = np.log(1 + power_spectrum)
    
    return log_power_spectrum

def process_images(image_dir, output_dir):
    for filename in os.listdir(image_dir):
        if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
            # Load the image
            image_path = os.path.join(image_dir, filename)
            image = io.imread(image_path)
            
            # Ensure the image is single-channel
            if len(image.shape) == 3:
                raise ValueError("Image is not single-channel.")
            
            # Compute the power spectrum
            power_spectrum = compute_power_spectrum(image)

            # Save the power spectrum image
            output_path = os.path.join(output_dir, f'power_spectrum_{filename}')
            plt.imsave(output_path, power_spectrum, cmap='gray')
            
            print(f'Processed and saved: {filename}')

# Process all images in the directory
process_images(image_dir, output_dir)


FileNotFoundError: [Errno 2] No such file or directory: 'path_to_your_test_set_directory'