In [1]:
#Inspect if the outputs from the 10 random runs deviate significantly from each other. 
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import r2_score as R2
from sklearn.model_selection import KFold
from copy import deepcopy
import utils
from unet import UNet_nobatchnorm
from scipy.stats import pearsonr
from pathlib import Path
import numpy.fft as fft
from matplotlib.colors import TwoSlopeNorm
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import helper_functions as hf

In [2]:
root_dir = '/work/uo0780/u241359/project_tide_synergy/data/'
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
Ntrain = np.sum([nc.dimensions['time_counter'].size for nc in nctrains], axis = 0)
Ntest = np.sum([nc.dimensions['time_counter'].size for nc in nctest], axis = 0)

model_folder = '/work/uo0780/u241359/project_tide_synergy/trainedmodels_forpaper/'
#filesuffix1 + randomseednumber + filesuffix2 = file suffix. 
filesuffix1='_ssh_cosssh_sin_nobatchnorm_rp_'
filesuffix2='.pth'

vel_cmap  = 'BrBG' #'viridis'
vort_cmap = 'PRGn'
ssh_cmap  = 'bwr'
sst_cmap = 'inferno'

bottom_slice = slice(0,256)
mid_slice = slice(232, 488)
top_slice = slice(464, 720)

def corr(data, mod):
    return pearsonr(data.flatten(), mod.flatten())[0]
def L2_R(data,mod):
    return R2(data.flatten(), mod.flatten())

nensemble = 10 #How many re-trained U-Net models there are

In [3]:
Nbase = 16
def totorch(x):
    return torch.tensor(x, dtype = torch.float).cpu()
    
def preload_data(nctrains, total_records):
    #total_records = Ntrain#sum(nc.dimensions['time_counter'].size for nc in nctrains)
    #dimensions of data of the nc file.
    max_height = 722
    max_width = 258
    all_input_data = np.zeros((total_records, N_inp, max_height, max_width))*np.nan
    all_output_data = np.zeros((total_records, N_out, max_height, max_width))*np.nan
    current_index = 0
    for ncindex, ncdata in enumerate(nctrains):
        num_recs = ncdata.dimensions['time_counter'].size
        rec_slice = slice(current_index, current_index + num_recs)
        
        for ind, var_name in enumerate(var_input_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            # print('data_slice shape:')
            # print(data_slice.shape)        
            #all_input_data[rec_slice, ind, :, :] = data_slice
            #For some variables, the dimensions in (x, y) may be smaller than (max_height, max_width). Changing the code so that it adapts them.
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_input_data[rec_slice, ind, :slice_height, :slice_width] = data_slice
    

        for ind, var_name in enumerate(var_output_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            #all_output_data[rec_slice, ind, :, :] = data_slice
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_output_data[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs
        
    return all_input_data, all_output_data

# # Modify the loadtrain function to pull data from preloaded memory
# def loaddata_preloaded_train(index, batch_size, all_input_data, all_output_data):
#     rec_slice = slice(index, index + batch_size)
#     lim = 720
#     width = 256
#     yslice = slice(0, lim)
#     xslice = slice(0, width)
#     # print('rec_slice is:')
#     # print(rec_slice)
#     # print('mean of squared values of loaded input data:')
#     # print("{0:0.32f}".format(np.nanmean(all_input_data[rec_slice, :, yslice, xslice]**2)))
#     return (all_input_data[rec_slice, :, yslice, xslice], 
#             all_output_data[rec_slice, :, yslice, xslice])
#Load test data as one single batch
def loaddata_preloaded_test(all_input_data, all_output_data):
    #rec_slice = slice(index, index + batch_size)
    lim = 720
    width = 256
    yslice = slice(0, lim)
    xslice = slice(0, width)
    # print('rec_slice is:')
    # print(rec_slice)
    # print('mean of squared values of loaded input data:')
    # print("{0:0.32f}".format(np.nanmean(all_input_data[rec_slice, :, yslice, xslice]**2)))
    return (all_input_data[:, :, yslice, xslice], 
            all_output_data[:, :, yslice, xslice])


def load_variable(ncdata, ncindex, variable, rec_slice, yslice, xslice):
    data_squeezed = np.squeeze(ncdata[ncindex].variables[variable])
    return data_squeezed[rec_slice, yslice, xslice]

def hwvorticity(u, v, dgrid = 4000):
    return (np.gradient(v, axis =2) - np.gradient(u, axis =1))/dgrid

def hwdivergence(u, v, dgrid = 4000):
    return (np.gradient(u, axis =2) + np.gradient(v, axis =1))/dgrid

def preload_data_vortdiv(nctrains, total_records):
    #total_records = Ntrain#sum(nc.dimensions['time_counter'].size for nc in nctrains)
    #dimensions of data of the nc file.
    max_height = 722
    max_width = 258
    all_input_data = np.zeros((total_records, N_inp, max_height, max_width))*np.nan
    all_output_data = np.zeros((total_records, N_out, max_height, max_width))*np.nan
    current_index = 0
    for ncindex, ncdata in enumerate(nctrains):
        num_recs = ncdata.dimensions['time_counter'].size #how many time stamps are there in each .nc file (i.e., at each turbulence level)
        rec_slice = slice(current_index, current_index + num_recs)
        for ind, var_name in enumerate(var_input_names):
            if var_name == 'vort':
                u = np.squeeze(ncdata.variables['u_xy_ins'])
                v = np.squeeze(ncdata.variables['v_xy_ins'])
                #u.shape: (150, 722, 257); v.shape: (150, 721, 258)
                #as u and v have different number of grid points in x and y, we truncate them so that their shapes agree, enabling the simple way to compute vorticities based on finite diff.
                data_slice = hwvorticity(u[:,:-1,:], v[:,:,:-1])
            elif var_name == 'div':
                u = np.squeeze(ncdata.variables['u_xy_ins'])
                v = np.squeeze(ncdata.variables['v_xy_ins'])
                data_slice = hwdivergence(u[:,:-1,:], v[:,:,:-1])
            else:           
                data_slice = np.squeeze(ncdata.variables[var_name])
            # print('data_slice shape:')
            # print(data_slice.shape)        
            #all_input_data[rec_slice, ind, :, :] = data_slice
            
            #For some variables, the dimensions in (x, y) may be smaller than (max_height, max_width). Changing the code so that it adapts them.
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_input_data[rec_slice, ind, :slice_height, :slice_width] = data_slice
    

        for ind, var_name in enumerate(var_output_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            #all_output_data[rec_slice, ind, :, :] = data_slice
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_output_data[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs
        
    return all_input_data, all_output_data
    


In [4]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'ssh_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.7908 & 0.7964 & 0.7868 & 0.0029 \\
full panel, R2:
, full & 0.6235 & 0.6338 & 0.6175 & 0.0046 \\
midjet panel, correlation:
, mid & 0.6781 & 0.6872 & 0.6689 & 0.0059 \\
midjet panel, R2:
, mid & 0.4544 & 0.4689 & 0.4431 & 0.0078 \\


In [5]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'ssh_ins'
vi2 = 'T_xy_ins'
vi3 = 'u_xy_ins'
vi4 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vi3, vi4, vo1, vo2)
var_input_names = [vi1, vi2, vi3, vi4]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.9708 & 0.9763 & 0.9515 & 0.0077 \\
full panel, R2:
, full & 0.9421 & 0.9528 & 0.9052 & 0.0148 \\
midjet panel, correlation:
, mid & 0.9531 & 0.9629 & 0.9202 & 0.0132 \\
midjet panel, R2:
, mid & 0.9078 & 0.9264 & 0.8467 & 0.0248 \\


In [6]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'T_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.3970 & 0.4168 & 0.3700 & 0.0143 \\
full panel, R2:
, full & 0.1453 & 0.1649 & 0.1227 & 0.0137 \\
midjet panel, correlation:
, mid & 0.0442 & 0.0697 & 0.0066 & 0.0222 \\
midjet panel, R2:
, mid & -0.0491 & -0.0156 & -0.0899 & 0.0239 \\


In [7]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'u_xy_ins'
vi2 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vo1, vo2)
var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.9407 & 0.9415 & 0.9400 & 0.0006 \\
full panel, R2:
, full & 0.8848 & 0.8863 & 0.8835 & 0.0012 \\
midjet panel, correlation:
, mid & 0.9021 & 0.9038 & 0.8996 & 0.0012 \\
midjet panel, R2:
, mid & 0.8136 & 0.8165 & 0.8090 & 0.0022 \\


In [8]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'ssh_ins'
vi2 = 'T_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vo1, vo2)
var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.8568 & 0.8619 & 0.8488 & 0.0038 \\
full panel, R2:
, full & 0.7328 & 0.7415 & 0.7194 & 0.0064 \\
midjet panel, correlation:
, mid & 0.7902 & 0.8015 & 0.7783 & 0.0067 \\
midjet panel, R2:
, mid & 0.6235 & 0.6416 & 0.6052 & 0.0107 \\


In [9]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'ssh_ins'
vi2 = 'u_xy_ins'
vi3 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vi3, vo1, vo2)
var_input_names = [vi1, vi2, vi3]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.9627 & 0.9744 & 0.9529 & 0.0079 \\
full panel, R2:
, full & 0.9267 & 0.9493 & 0.9078 & 0.0152 \\
midjet panel, correlation:
, mid & 0.9384 & 0.9594 & 0.9218 & 0.0141 \\
midjet panel, R2:
, mid & 0.8804 & 0.9202 & 0.8493 & 0.0263 \\


In [10]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'T_xy_ins'
vi2 = 'u_xy_ins'
vi3 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vi3, vo1, vo2)
var_input_names = [vi1, vi2, vi3]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.9522 & 0.9534 & 0.9507 & 0.0008 \\
full panel, R2:
, full & 0.9065 & 0.9087 & 0.9034 & 0.0016 \\
midjet panel, correlation:
, mid & 0.9198 & 0.9217 & 0.9178 & 0.0014 \\
midjet panel, R2:
, mid & 0.8458 & 0.8495 & 0.8422 & 0.0025 \\


In [11]:
#Below are to check the runs with vorticity and divergence

In [12]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'vort'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_vortdiv(nctrains, Ntrain)
all_test_input, all_test_output = preload_data_vortdiv(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.8201 & 0.8258 & 0.8151 & 0.0031 \\
full panel, R2:
, full & 0.6719 & 0.6815 & 0.6642 & 0.0051 \\
midjet panel, correlation:
, mid & 0.7125 & 0.7224 & 0.7049 & 0.0049 \\
midjet panel, R2:
, mid & 0.5072 & 0.5209 & 0.4969 & 0.0069 \\


In [13]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'vort'
vi2 = 'T_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vo1, vo2)
var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_vortdiv(nctrains, Ntrain)
all_test_input, all_test_output = preload_data_vortdiv(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.8470 & 0.8501 & 0.8437 & 0.0019 \\
full panel, R2:
, full & 0.7171 & 0.7221 & 0.7119 & 0.0031 \\
midjet panel, correlation:
, mid & 0.7549 & 0.7637 & 0.7470 & 0.0049 \\
midjet panel, R2:
, mid & 0.5696 & 0.5832 & 0.5577 & 0.0075 \\


In [14]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'div'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_vortdiv(nctrains, Ntrain)
all_test_input, all_test_output = preload_data_vortdiv(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.8792 & 0.8822 & 0.8735 & 0.0022 \\
full panel, R2:
, full & 0.7713 & 0.7763 & 0.7611 & 0.0039 \\
midjet panel, correlation:
, mid & 0.8239 & 0.8294 & 0.8158 & 0.0035 \\
midjet panel, R2:
, mid & 0.6782 & 0.6870 & 0.6649 & 0.0058 \\


In [15]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'vort'
vi2 = 'div'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}{}_{}{}_nobatchnorm'.format(vi1, vi2, vo1, vo2)
var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_vortdiv(nctrains, Ntrain)
all_test_input, all_test_output = preload_data_vortdiv(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

full panel, correlation:
, full & 0.9294 & 0.9300 & 0.9286 & 0.0005 \\
full panel, R2:
, full & 0.8633 & 0.8644 & 0.8611 & 0.0011 \\
midjet panel, correlation:
, mid & 0.8811 & 0.8826 & 0.8794 & 0.0008 \\
midjet panel, R2:
, mid & 0.7757 & 0.7787 & 0.7727 & 0.0016 \\


In [4]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'u_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.8151 & 0.8189 & 0.8090 & 0.0027 \\
full panel, R2:
, full & 0.6627 & 0.6696 & 0.6521 & 0.0047 \\
midjet panel, correlation:
, mid & 0.6946 & 0.7030 & 0.6847 & 0.0052 \\
midjet panel, R2:
, mid & 0.4800 & 0.4932 & 0.4646 & 0.0080 \\


In [5]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
vi1 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
N_inp = len(var_input_names)
N_out = len(var_output_names)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    model_filename = f'any_{combined_names}{filesuffix1}{iensemble}{filesuffix2}'
    model_path = Path(model_folder, model_filename)
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.8560 & 0.8579 & 0.8542 & 0.0011 \\
full panel, R2:
, full & 0.7320 & 0.7351 & 0.7290 & 0.0019 \\
midjet panel, correlation:
, mid & 0.7768 & 0.7798 & 0.7740 & 0.0019 \\
midjet panel, R2:
, mid & 0.6030 & 0.6074 & 0.5990 & 0.0028 \\
