## Evaluation

In [48]:
import torch
import numpy as np
import random
from models.architectures import srgan, srflow
import PIL
import os
import torchvision
from torchvision import transforms
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd

# Dataset loading
from data import dataloading
from data.era5_temp_dataset import InverseMinMaxScaler

from os.path import exists, join
import matplotlib.pyplot as plt
from matplotlib import transforms
import timeit
import pdb
import argparse
import seaborn as sns

from utils.metrics import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
print('GPUs avail:', torch.cuda.device_count())

# Parse Settings
parser = argparse.ArgumentParser()

# train configs
parser.add_argument("--model", type=str, default="srflow",
                    help="Model you want to train.")
parser.add_argument("--modeltype", type=str, default="srflow",
                    help="Specify modeltype you would like to train [srflow, cdiff, srgan].")
parser.add_argument("--model_path", type=str, default="runs/",
                    help="Directory where models are saved.")
parser.add_argument("--modelname", type=str, default=None,
                    help="Sepcify modelname to be tested.")
parser.add_argument("--epochs", type=int, default=10000,
                    help="number of epochs")
parser.add_argument("--max_steps", type=int, default=2000000,
                    help="For training on a large dataset.")
parser.add_argument("--log_interval", type=int, default=100,
                    help="Interval in which results should be logged.")

# runtime configs
parser.add_argument("--visual", action="store_true",
                    help="Visualizing the samples at test time.")
parser.add_argument("--noscaletest", action="store_true",
                    help="Disable scale in coupling layers only at test time.")
parser.add_argument("--noscale", action="store_true",
                    help="Disable scale in coupling layers.")
parser.add_argument("--testmode", action="store_true",
                    help="Model run on test set.")
parser.add_argument("--train", action="store_true",
                    help="If model should be trained.")
parser.add_argument("--resume_training", action="store_true",
                    help="If training should be resumed.")
parser.add_argument("--constraint", type=str, default='scaddDS',
                    help="Physical Constraint to be applied during training.")                   

# hyperparameters
parser.add_argument("--nbits", type=int, default=8,
                    help="Images converted to n-bit representations.")
parser.add_argument("--s", type=int, default=16, help="Upscaling factor.")
parser.add_argument("--bsz", type=int, default=16, help="batch size")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
parser.add_argument("--filter_size", type=int, default=512, help="filter size NN in Affine Coupling Layer")
parser.add_argument("--L", type=int, default=3, help="# of levels")
parser.add_argument("--K", type=int, default=2,
                    help="# of flow steps, i.e. model depth")
parser.add_argument("--nb", type=int, default=16,
                    help="# of residual-in-residual blocks LR network.")
parser.add_argument("--condch", type=int, default=128//8,
                    help="# of residual-in-residual blocks in LR network.")

# data
parser.add_argument("--datadir", type=str, default="/home/christina/Documents/clim-var-ds-cnf-own/data/",
                    help="Dataset to train the model on.")
parser.add_argument("--trainset", type=str, default="era5-TCW", help='[era5-TCW, era5-T2M]')
parser.add_argument("--testset", type=str, default="era5-TCW", help="Specify test dataset")

args = parser.parse_args('')
config = vars(args)

GPUs avail: 1


In [50]:
if torch.cuda.is_available():
    args.device = torch.device("cuda")
    args.num_gpus = torch.cuda.device_count()
    args.parallel = False

else:
    args.device = "cpu"

In [51]:
# use min-max due to non-gaussian distribution of data and outlier handling
def min_max_scaler(x, min_val=0, max_val=124):
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()
    scaled_x = (x - min_val) / (max_val - min_val)
    return scaled_x, min_val, max_val

def inv_min_max_scaler(scaled_x, min_val=0, max_val=124):
    x = scaled_x * (max_val - min_val) + min_val
    return x

### Metric Evaluation
Evaluations: 
- CNF unconstrained vs. constrained
- CNF vs. GAN vs. CDiff
- Residual Error Plots for ???
- Metrics for 2x,4x,8x,16x with and without constraints

In [59]:
import os
import json
import torch
import torchvision.utils as vutils
import sys
sys.path.append("../../")
from utils import metrics

def save_snapshot(tensor, path):
    # Function to save tensor as image in Viridis color scheme
    plt.figure()
    plt.imshow(tensor[0,...].cpu().numpy().squeeze(), cmap='viridis')
    plt.axis('off')
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()

def save_grid_images(tensor_batch, save_path, nrow=3, title=None):
    
    # Create a grid of images
    grid_img = vutils.make_grid(tensor_batch[0:9], nrow=nrow, normalize=True, scale_each=True)

    # Convert the grid to numpy and plot
    np_grid = grid_img.cpu().numpy().transpose((1, 2, 0))
    plt.figure(figsize=(10, 10))
    plt.imshow(np_grid[:, :, 0], cmap='viridis')  # Use Viridis colormap for grayscale images
    plt.axis('off')
    plt.title(title)
    plt.tight_layout(pad=0)

    # Save the grid image
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
def metric_eval(dataloader, model, exp_name, args):
    metric_dict = {
        'rmse0': [], 'mse0': [], 'mae0': [],
        'rmse05': [], 'mse05': [], 'mae05': [],
        'rmse08': [], 'mse08': [], 'mae08': [],
        'rmse1': [], 'mse1': [], 'mae1': [],
        'crps_mean': [], 'crps_std': []
    }

    # Create directories
    results_dir = os.path.join(exp_name, 'experiment_results')
    os.makedirs(results_dir, exist_ok=True)
    
    snapshots_dir = os.path.join(results_dir, 'snapshots')
    os.makedirs(snapshots_dir, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for batch_idx, item in enumerate(dataloader):
            y = item[0].to(args.device)
            x = item[1].to(args.device)

            y_unorm = item[2].to(args.device)
            x_unorm = item[3].to(args.device)

            z, _ = model.forward(x_hr=y, xlr=x)

            # Evaluate for different temperatures
            mu0, _, _ = model(xlr=x, reverse=True, eps=0)
            mu05, _, _ = model(xlr=x, reverse=True, eps=0.5)
            mu08, _, _ = model(xlr=x, reverse=True, eps=0.8)
            mu1, _, _ = model(xlr=x, reverse=True, eps=1.0)

            # Compute and store MSE for each temperature
            mse0 = metrics.MSE(inv_min_max_scaler(mu0), y_unorm).detach().cpu().numpy()
            mse05 = metrics.MSE(inv_min_max_scaler(mu05), y_unorm).detach().cpu().numpy()
            mse08 = metrics.MSE(inv_min_max_scaler(mu08), y_unorm).detach().cpu().numpy()
            mse1 = metrics.MSE(inv_min_max_scaler(mu1), y_unorm).detach().cpu().numpy()
            
            metric_dict['mse0'].append(mse0)
            metric_dict['mse05'].append(mse05)
            metric_dict['mse08'].append(mse08)
            metric_dict['mse1'].append(mse1)

            # Compute and store MAE for each temperature
            mae0 = metrics.MAE(inv_min_max_scaler(mu0), y_unorm).detach().cpu().numpy()
            mae05 = metrics.MAE(inv_min_max_scaler(mu05), y_unorm).detach().cpu().numpy()
            mae08 = metrics.MAE(inv_min_max_scaler(mu08), y_unorm).detach().cpu().numpy()
            mae1 = metrics.MAE(inv_min_max_scaler(mu1), y_unorm).detach().cpu().numpy()

            metric_dict['mae0'].append(mae0)
            metric_dict['mae05'].append(mae05)
            metric_dict['mae08'].append(mae08)
            metric_dict['mae1'].append(mae1)

            # Compute and store RMSE for each temperature
            rmse0 = metrics.RMSE(inv_min_max_scaler(mu0), y_unorm).detach().cpu().numpy()
            rmse05 = metrics.RMSE(inv_min_max_scaler(mu05), y_unorm).detach().cpu().numpy()
            rmse08 = metrics.RMSE(inv_min_max_scaler(mu08), y_unorm).detach().cpu().numpy()
            rmse1 = metrics.RMSE(inv_min_max_scaler(mu1), y_unorm).detach().cpu().numpy()

            metric_dict['rmse0'].append(rmse0)
            metric_dict['rmse05'].append(rmse05)
            metric_dict['rmse08'].append(rmse08)
            metric_dict['rmse1'].append(rmse1)

            # Save grid of images for visualization (adjust as needed)
            if batch_idx == 0:  # Save only for the first batch for simplicity
                # Save snapshots
                save_grid_images(y_unorm, os.path.join(snapshots_dir, f'ground_truth_{batch_idx}.png'))
                save_grid_images(mu0, os.path.join(snapshots_dir, f'prediction_mu0_{batch_idx}.png'))
                save_grid_images(mu05, os.path.join(snapshots_dir, f'prediction_mu05_{batch_idx}.png'))
                save_grid_images(mu08, os.path.join(snapshots_dir, f'prediction_mu08_{batch_idx}.png'))
                save_grid_images(mu1, os.path.join(snapshots_dir, f'prediction_mu1_{batch_idx}.png'))
                save_grid_images(x_unorm, os.path.join(snapshots_dir, f'low_res_{batch_idx}.png'))

            # Calculate CRPS for ensemble
            crps_stack = torch.stack([mu0, mu05, mu08, mu1], dim=1)
            print(crps_ensemble(y_unorm, crps_stack))
            crps_sum, crps_mean = crps_ensemble(y_unorm, crps_stack)
            crps_std = np.std(crps_sum, axis=0, ddof=1)
            metric_dict['crps_mean'].append(crps_mean)
            metric_dict['crps_std'].append(crps_std)
            
            print(f'Current RMSE - mu0: {rmse0}, mu05: {rmse05}, mu08: {rmse08}, mu1: {rmse1}')
            print(f'Current MAE - mu0: {mae0}, mu05: {mae05}, mu08: {mae08}, mu1: {mae1}')
            print(f'Current CRPS:{ crps_mean}')

            # if batch_idx == 200:
            #     break

    # Create a string representation of the metric_dict
    mean_dict = {}
    for key, value in metric_dict.items():
        if len(value) > 0:
            mean_dict[key] = np.mean(value)   
            
    metric_str = "\n".join([f"{key}: {value}" for key, value in mean_dict.items()])

    # Save metric_dict to a text file
    with open(os.path.join(results_dir, 'metrics.txt'), 'w') as f:
        json.dump(metric_str, f, indent=4)
        
    print(metric_dict)
    return metric_dict

# Experiments: Comparing across upsampling factors


In [60]:
# load model
print('Num of avail GPUs:', torch.cuda.device_count())

Num of avail GPUs: 1


In [61]:
import os
import torch

def load_model(exp_dir, model, mpath):
    """
    Load a model from the specified directory and checkpoint file.
    
    Parameters:
    exp_dir (str): The experiment directory path.
    model (torch.nn.Module): The model instance to load the state dictionary into.
    mpath (str): The path to the model checkpoint file.
    
    Returns:
    torch.nn.Module: The loaded model with the state dictionary.
    
    Raises:
    FileNotFoundError: If the model file does not exist at the specified path.
    """
    try:
        # Check if the file exists
        if not os.path.exists(mpath):
            raise FileNotFoundError(f"Model file not found at {mpath}")
        
        # Load the checkpoint
        ckpt = torch.load(mpath)
        
        # Load the state dictionary into the model
        model.load_state_dict(ckpt['model_state_dict'])
        
        print("Model loaded successfully.")
        return model

    except FileNotFoundError as e:
        print(e)
        return None

In [62]:
import gc
gc.collect()

9911

In [63]:
# args.s = 2
# train_loader, val_loader, test_loader, args = dataloading.load_data(args)
# in_channels = next(iter(test_loader))[0].shape[1]
# height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
# cnf2x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 2, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
# exp_dir = 'runs/srflow_era5-TCW_None_2024_06_22_09_31_04_2x/'
# cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_31_04_2x/model_checkpoints/model_epoch_0_step_2000.tar'
# # Load the model
# cnf2x = load_model(exp_dir, cnf2x, cnf_path).to(args.device)
# metric_eval(test_loader, cnf2x, exp_dir, args)

In [64]:
gc.collect()

44

In [65]:
args.s = 4
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf4x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
# Load the model
cnf4x = load_model(exp_dir, cnf4x, cnf_path).to(args.device)
metric_eval(test_loader, cnf4x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
[0.03641408 0.00161226 0.00587853 0.00709524 0.03521129 0.02694691
 0.02166406 0.01085579 0.0178852  0.00764915 0.04622082 0.01562121
 0.02291531 0.04516992 0.02306044 0.00509276]


ValueError: too many values to unpack (expected 2)

In [None]:
gc.collect()

In [13]:
args.s = 8
args.device = 'cuda'
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf8x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 8, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_20_8x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_20_8x/model_checkpoints/model_epoch_2_step_5750.tar'
# Load the model
cnf8x = load_model(exp_dir, cnf8x, cnf_path).to(args.device)
metric_eval(test_loader, cnf8x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
Current RMSE - mu0: 1.3318904638290405, mu05: 1.4634594917297363, mu08: 1.6496049165725708, mu1: 1.798222541809082
Current MAE - mu0: 0.9427525997161865, mu05: 1.039659023284912, mu08: 1.1646144390106201, mu1: 1.2741026878356934
Current CRPS:0.020576827228069305
Current RMSE - mu0: 1.6682360172271729, mu05: 1.8778321743011475, mu08: 2.146379232406616, mu1: 2.398204803466797
Current MAE - mu0: 1.1627733707427979, mu05: 1.3058044910430908, mu08: 1.4846324920654297, mu1: 1.6501340866088867
Current CRPS:0.023372437804937363
Current RMSE - mu0: 1.3846279382705688, mu05: 1.577899694442749, mu08: 1.8343181610107422, mu1: 2.004870891571045
Current MAE - mu0: 0.9700607061386108, mu05: 1.1018891334533691, mu08: 1.2743875980377197, mu1: 1.400654911994934
Current CRPS:0.022555585950613022
Current RMSE - mu0: 1.7158615589141846, mu05: 1.9204659461975098, mu08: 2.219501256942749, mu1: 2.430448055267334
Current MAE - mu0: 1.1724963188171387, mu05: 1.314

{'rmse0': [array(1.3318905, dtype=float32),
  array(1.668236, dtype=float32),
  array(1.3846279, dtype=float32),
  array(1.7158616, dtype=float32),
  array(1.4444795, dtype=float32),
  array(1.3700814, dtype=float32),
  array(1.4878238, dtype=float32),
  array(1.4114443, dtype=float32),
  array(1.6176355, dtype=float32),
  array(1.423278, dtype=float32),
  array(1.248455, dtype=float32),
  array(1.6778626, dtype=float32),
  array(1.4797704, dtype=float32),
  array(1.6305771, dtype=float32),
  array(1.2824069, dtype=float32),
  array(1.4159994, dtype=float32),
  array(1.6363208, dtype=float32),
  array(1.4604397, dtype=float32),
  array(1.3084271, dtype=float32),
  array(1.4166197, dtype=float32),
  array(1.3754351, dtype=float32),
  array(1.1534939, dtype=float32),
  array(1.3029748, dtype=float32),
  array(1.34077, dtype=float32),
  array(1.4120996, dtype=float32),
  array(1.355625, dtype=float32),
  array(1.6338646, dtype=float32),
  array(1.4115021, dtype=float32),
  array(1.7309506

In [14]:
gc.collect()

9859

In [16]:
args.s = 16
args.device = 'cuda'
train_loader, val_loader, test_loader, args = dataloading.load_data(args)
in_channels = next(iter(test_loader))[0].shape[1]
height, width = next(iter(test_loader))[0].shape[2], next(iter(test_loader))[0].shape[3]
cnf16x = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 16, args.constraint, args.nb, args.condch, args.noscale, args.noscaletest)
exp_dir = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_13_16x/'
cnf_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_13_16x/model_checkpoints/model_epoch_2_step_5750.tar'
# Load the model
cnf16x = load_model(exp_dir, cnf16x, cnf_path).to(args.device)
metric_eval(test_loader, cnf16x, exp_dir, args)

Loading ERA5 TCW ...
Model loaded successfully.
Current RMSE - mu0: 2.7445549964904785, mu05: 2.9587297439575195, mu08: 3.3401589393615723, mu1: 3.689572811126709
Current MAE - mu0: 2.0021347999572754, mu05: 2.1640753746032715, mu08: 2.4407784938812256, mu1: 2.708587646484375
Current CRPS:0.020582452416419983
Current RMSE - mu0: 3.402545928955078, mu05: 3.72341251373291, mu08: 4.442249298095703, mu1: 5.0418500900268555
Current MAE - mu0: 2.4763975143432617, mu05: 2.7151293754577637, mu08: 3.2648348808288574, mu1: 3.6934452056884766
Current CRPS:0.023375775665044785
Current RMSE - mu0: 2.7831525802612305, mu05: 3.074458122253418, mu08: 3.569326400756836, mu1: 4.135376453399658
Current MAE - mu0: 2.026686668395996, mu05: 2.2538907527923584, mu08: 2.607243537902832, mu1: 3.010141134262085
Current CRPS:0.022560130804777145
Current RMSE - mu0: 3.433812379837036, mu05: 3.6976215839385986, mu08: 4.333578109741211, mu1: 4.780289649963379
Current MAE - mu0: 2.47659969329834, mu05: 2.68243360519

{'rmse0': [array(2.744555, dtype=float32),
  array(3.402546, dtype=float32),
  array(2.7831526, dtype=float32),
  array(3.4338124, dtype=float32),
  array(3.0747707, dtype=float32),
  array(2.9587274, dtype=float32),
  array(3.1042962, dtype=float32),
  array(2.8602123, dtype=float32),
  array(3.271292, dtype=float32),
  array(2.7811575, dtype=float32),
  array(2.7513556, dtype=float32),
  array(3.2106988, dtype=float32),
  array(3.0894556, dtype=float32),
  array(3.196508, dtype=float32),
  array(2.7010746, dtype=float32),
  array(3.0330849, dtype=float32),
  array(3.250852, dtype=float32),
  array(2.586711, dtype=float32),
  array(2.6933346, dtype=float32),
  array(2.887167, dtype=float32),
  array(2.9204872, dtype=float32),
  array(2.5403838, dtype=float32),
  array(2.8249826, dtype=float32),
  array(2.901198, dtype=float32),
  array(2.907194, dtype=float32),
  array(2.674665, dtype=float32),
  array(3.3994582, dtype=float32),
  array(3.1073718, dtype=float32),
  array(3.296148, dty

In [None]:
gc.collect()

In [None]:
def plot_std(model):
    """
    For this experiment we visualize the super-resolution space for a single
    low-resolution image and its possible HR target predictions. We visualize
    the standard deviation of these predictions from the mean of the model.
    """
    color = 'plasma'
    savedir_viz = "experiments/{}_{}_{}/snapshots/population_std/".format(exp_name, modelname, args.trainset)
    os.makedirs(savedir_viz, exist_ok=True)
    model.eval()
    cmap = 'viridis' if args.trainset == 'era5-TCW' else 'inferno'
    with torch.no_grad():
        for batch_idx, item in enumerate(test_loader):

            y = item[0].to(args.device)
            x = item[1].to(args.device)

            y_unorm = item[2].to(args.device)
            x_unnorm = item[3].to(args.device)

            mu0 = model(x)

            samples = []
            n = 20
            sq_diff = torch.zeros_like(mu0)
            for n in range(n):
                mu1 = model(x)
                samples.append(mu0)
                sq_diff += (mu1 - mu0)**2

            # compute population standard deviation
            sigma = torch.sqrt(sq_diff / n)

            # create plot
            plt.figure()
            plt.imshow(sigma[0,...].permute(1,2,0).cpu().numpy(), cmap=color)
            plt.axis('off')
            # plt.show()
            plt.savefig(savedir_viz + '/sigma_{}.png'.format(batch_idx), dpi=300, bbox_inches='tight')
            plt.close()

            plt.figure()
            plt.imshow(mu0[0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            plt.axis('off')
            # plt.show()
            plt.savefig(savedir_viz + '/mu0_{}.png'.format(batch_idx), dpi=300, bbox_inches='tight')
            plt.close()

            fig, (ax1, ax3, ax4, ax5, ax6, ax7) = plt.subplots(1,6)
            # fig.suptitle('Y, Y_hat, mu, sigma')
            ax1.imshow(y[0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax1)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax1.set_title('Ground Truth', fontsize=5)
            ax1.axis('off')
            # ax2.imshow(mu0[0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            # divider = make_axes_locatable(ax2)
            # cax = divider.append_axes("right", size="5%", pad=0.05)
            # cax.set_axis_off()
            # ax2.set_title('Mean', fontsize=5)
            # ax2.axis('off')
            ax3.imshow(samples[1][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax3)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax3.set_title('Sample 1', fontsize=5)
            ax3.axis('off')
            ax4.imshow(samples[2][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax4)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax4.set_title('Sample 2', fontsize=5)
            ax4.axis('off')
            ax5.imshow(samples[2][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax5)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax5.set_title('Sample 3', fontsize=5)
            ax5.axis('off')
            ax6.imshow(samples[2][0,...].permute(1,2,0).cpu().numpy(), cmap=cmap)
            divider = make_axes_locatable(ax6)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cax.set_axis_off()
            ax6.set_title('Sample 4', fontsize=5)
            ax6.axis('off')
            divider = make_axes_locatable(ax7)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            im7 = ax7.imshow(sigma[0,...].permute(1,2,0).cpu().numpy(), cmap='magma')
            cbar = fig.colorbar(im7,cmap='magma', cax=cax)
            cbar.ax.tick_params(labelsize=5)
            ax7.set_title('Std. Dev.', fontsize=5)
            ax7.axis('off')
            plt.tight_layout()
            plt.savefig(savedir_viz + '/std_multiplot_{}.png'.format(batch_idx), dpi=300, bbox_inches='tight')
            # plt.show()
            plt.close()

    return None

In [None]:
# especially focus on analyzing extreme value predictions

### Scatter Plot (Predicted vs Target)
![image.png](attachment:image.png)

### Cumulative Distribution of  Residual Errors
![image.png](attachment:image.png)

### Experiments: Constraint placement
Comparing model with no constraint vs. multiplicative constraint at output layer

### Load CNF with constraints 4x

In [None]:
cnf_4x_add = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'add', args.nb, args.condch, args.noscale, args.noscaletest)
cnf_4x_mult = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'mult', args.nb, args.condch, args.noscale, args.noscaletest)
cnf_4x_scadd = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'scadd', args.nb, args.condch, args.noscale, args.noscaletest)
cnf_4x_softmax = srflow.SRFlow((in_channels, height, width), args.filter_size, args.L, args.K, args.bsz, 4, 'softmax', args.nb, args.condch, args.noscale, args.noscaletest)

In [None]:
# load vanilla conditional flow 4x
exp_dir_add_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
exp_dir_mult_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
exp_dir_scadd_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'
exp_dir_softmax_4x = 'runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/'

cnf_4x_add_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
cnf_4x_mult_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
cnf_4x_scadd_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'
cnf_4x_softmax_path = '/home/christina/Documents/clim-var-ds-cnf-own/runs/srflow_era5-TCW_None_2024_06_22_09_00_21_4x/model_checkpoints/model_epoch_0_step_2000.tar'

In [None]:
# 2x watercontent mul with constraint at the end 
modelname = 'model_epoch_0_step_1000'
modelpath = '/home/mila/c/christina.winkler/clim-var-ds-cnf/runs/srflow_era5-TCW_mul_ constr_in_end__2024_06_03_17_35_33_2x/model_checkpoints/{}.tar'.format(modelname)

model = srflow.SRFlow((in_channels, args.height, args.width), args.filter_size, args.L, args.K,
                        args.bsz, args.s, 'mul', args.nb, args.condch, args.noscale, args.noscaletest)

params = sum(x.numel() for x in model.parameters() if x.requires_grad)
print('Nr of Trainable Params {}:  '.format(args.device), params)
model = model.to(args.device)
exp_name = "flow-{}-level-{}-k--constraint-{}".format(args.L, args.K, 'mul')
print(exp_name)

In [None]:
ckpt = torch.load(modelpath, map_location='cuda:0')
model.load_state_dict(ckpt['model_state_dict'])