In [8]:
#Inspect if the outputs from the 10 random runs deviate significantly from each other. 
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import r2_score as R2
from sklearn.model_selection import KFold
from copy import deepcopy
import utils
from unet import UNet_nobatchnorm
from scipy.stats import pearsonr
from pathlib import Path
import numpy.fft as fft
from matplotlib.colors import TwoSlopeNorm
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import helper_functions as hf
from scipy.signal import convolve2d, convolve


In [9]:
root_dir = '/work/uo0780/u241359/project_tide_synergy/data/'
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
Ntrain = np.sum([nc.dimensions['time_counter'].size for nc in nctrains], axis = 0)
Ntest = np.sum([nc.dimensions['time_counter'].size for nc in nctest], axis = 0)

model_folder = '/work/uo0780/u241359/project_tide_synergy/trainedmodels_forpaper/'

vel_cmap  = 'BrBG' #'viridis'
vort_cmap = 'PRGn'
ssh_cmap  = 'bwr'
sst_cmap = 'inferno'

bottom_slice = slice(0,256)
mid_slice = slice(232, 488)
top_slice = slice(464, 720)

def corr(data, mod):
    return pearsonr(data.flatten(), mod.flatten())[0]
def L2_R(data,mod):
    return R2(data.flatten(), mod.flatten())

nensemble = 10 #How many re-trained U-Net models there are

In [10]:
Nbase = 16
def totorch(x):
    return torch.tensor(x, dtype = torch.float).cpu()
    
def preload_data(nctrains, total_records):
    #total_records = Ntrain#sum(nc.dimensions['time_counter'].size for nc in nctrains)
    #dimensions of data of the nc file.
    max_height = 722
    max_width = 258
    all_input_data = np.zeros((total_records, N_inp, max_height, max_width))*np.nan
    all_output_data = np.zeros((total_records, N_out, max_height, max_width))*np.nan
    current_index = 0
    for ncindex, ncdata in enumerate(nctrains):
        num_recs = ncdata.dimensions['time_counter'].size
        rec_slice = slice(current_index, current_index + num_recs)
        
        for ind, var_name in enumerate(var_input_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            # print('data_slice shape:')
            # print(data_slice.shape)        
            #all_input_data[rec_slice, ind, :, :] = data_slice
            #For some variables, the dimensions in (x, y) may be smaller than (max_height, max_width). Changing the code so that it adapts them.
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_input_data[rec_slice, ind, :slice_height, :slice_width] = data_slice
    

        for ind, var_name in enumerate(var_output_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            #all_output_data[rec_slice, ind, :, :] = data_slice
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_output_data[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs
        
    return all_input_data, all_output_data

# # Modify the loadtrain function to pull data from preloaded memory
# def loaddata_preloaded_train(index, batch_size, all_input_data, all_output_data):
#     rec_slice = slice(index, index + batch_size)
#     lim = 720
#     width = 256
#     yslice = slice(0, lim)
#     xslice = slice(0, width)
#     # print('rec_slice is:')
#     # print(rec_slice)
#     # print('mean of squared values of loaded input data:')
#     # print("{0:0.32f}".format(np.nanmean(all_input_data[rec_slice, :, yslice, xslice]**2)))
#     return (all_input_data[rec_slice, :, yslice, xslice], 
#             all_output_data[rec_slice, :, yslice, xslice])
#Load test data as one single batch
def loaddata_preloaded_test(all_input_data, all_output_data):
    #rec_slice = slice(index, index + batch_size)
    lim = 720
    width = 256
    yslice = slice(0, lim)
    xslice = slice(0, width)
    # print('rec_slice is:')
    # print(rec_slice)
    # print('mean of squared values of loaded input data:')
    # print("{0:0.32f}".format(np.nanmean(all_input_data[rec_slice, :, yslice, xslice]**2)))
    return (all_input_data[:, :, yslice, xslice], 
            all_output_data[:, :, yslice, xslice])


def load_variable(ncdata, ncindex, variable, rec_slice, yslice, xslice):
    data_squeezed = np.squeeze(ncdata[ncindex].variables[variable])
    return data_squeezed[rec_slice, yslice, xslice]

def hwvorticity(u, v, dgrid = 4000):
    return (np.gradient(v, axis =2) - np.gradient(u, axis =1))/dgrid

def hwdivergence(u, v, dgrid = 4000):
    return (np.gradient(u, axis =2) + np.gradient(v, axis =1))/dgrid

def preload_data_vortdiv(nctrains, total_records):
    #total_records = Ntrain#sum(nc.dimensions['time_counter'].size for nc in nctrains)
    #dimensions of data of the nc file.
    max_height = 722
    max_width = 258
    all_input_data = np.zeros((total_records, N_inp, max_height, max_width))*np.nan
    all_output_data = np.zeros((total_records, N_out, max_height, max_width))*np.nan
    current_index = 0
    for ncindex, ncdata in enumerate(nctrains):
        num_recs = ncdata.dimensions['time_counter'].size #how many time stamps are there in each .nc file (i.e., at each turbulence level)
        rec_slice = slice(current_index, current_index + num_recs)
        for ind, var_name in enumerate(var_input_names):
            if var_name == 'vort':
                u = np.squeeze(ncdata.variables['u_xy_ins'])
                v = np.squeeze(ncdata.variables['v_xy_ins'])
                #u.shape: (150, 722, 257); v.shape: (150, 721, 258)
                #as u and v have different number of grid points in x and y, we truncate them so that their shapes agree, enabling the simple way to compute vorticities based on finite diff.
                data_slice = hwvorticity(u[:,:-1,:], v[:,:,:-1])
            elif var_name == 'div':
                u = np.squeeze(ncdata.variables['u_xy_ins'])
                v = np.squeeze(ncdata.variables['v_xy_ins'])
                data_slice = hwdivergence(u[:,:-1,:], v[:,:,:-1])
            else:           
                data_slice = np.squeeze(ncdata.variables[var_name])
            # print('data_slice shape:')
            # print(data_slice.shape)        
            #all_input_data[rec_slice, ind, :, :] = data_slice
            
            #For some variables, the dimensions in (x, y) may be smaller than (max_height, max_width). Changing the code so that it adapts them.
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_input_data[rec_slice, ind, :slice_height, :slice_width] = data_slice
    

        for ind, var_name in enumerate(var_output_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            #all_output_data[rec_slice, ind, :, :] = data_slice
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_output_data[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs
        
    return all_input_data, all_output_data
    


In [11]:
#Functions for low-pass filtering
def gaussian_kernel(decaylength): 
    """Generates a Gaussian kernel."""
    #decaylength is in the unit of grid resolution (4km in Aurelien's data.) So in physical units, the decay lenght would be decaylength*(4 km).
    size=int(2*decaylength)
    sigma=decaylength/(2*np.sqrt(2*np.log(2))) #Interpretting decaylength as the FWHM of Gaussian
    kernel = np.fromfunction(
        lambda x, y: (1 / np.sqrt(2 * np.pi * sigma ** 2)) * 
                      np.exp(-((x- size/2)**2 + (y-size/2)**2) / (2 * sigma ** 2)),
        (size, size)  
    ) #Creating a kernel with 
    return kernel / np.sum(kernel)  # Normalize the kernel
    
def degrade_space_gaussian(field, decaylength):
    nt, nx, ny = np.shape(field)
    kernel = gaussian_kernel(decaylength)
    filtered_field = np.empty([nt, nx, ny])

    for i in range(nt):
        filtered_field[i, : ,:] = convolve2d(field[i, : ,:], kernel, mode = 'same', boundary='symm')#,  fillvalue = np.average(field[i, : ,:]))
    return filtered_field

# Load all data into memory; no normalization is done here yet.
# Apply a spatial lowpass filter to the temperature field 'T_xy_ins'
# decayunits is how many units of grid spacing is the decay length scale. A grid spacing is 4km in Aurelien's data. So in physical units, the decay lenght would be decayunits*(4 km).
def preload_data_filterT(nctrains, total_records,decayunits=25,plot=True):
    #total_records = Ntrain#sum(nc.dimensions['time_counter'].size for nc in nctrains)
    #dimensions of data of the nc file.
    max_height = 722
    max_width = 258
    all_input_data = np.zeros((total_records, N_inp, max_height, max_width))*np.nan
    all_output_data = np.zeros((total_records, N_out, max_height, max_width))*np.nan
    current_index = 0
    for ncindex, ncdata in enumerate(nctrains):
        num_recs = ncdata.dimensions['time_counter'].size
        rec_slice = slice(current_index, current_index + num_recs)
        
        for ind, var_name in enumerate(var_input_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            # print('data_slice shape:')
            # print(data_slice.shape)        
            #Turns out to be (time, height, width)
            # print('var_name:')
            # print(var_name)
            # Apply lowpass filter when the field is 'T_xy_ins'
            if var_name == 'T_xy_ins':
                if plot == True:
                    #Plot an image before the filter
                    itime=20        
                    cmapmax=np.max(data_slice[itime,:,:])
                    cmapmin=np.min(data_slice[itime,:,:])
                    figT, axT = plt.subplots(1, 2, figsize=(5, 5))
                    figT.set_dpi(256)   
                    im0=axT[0].pcolor(data_slice[itime,:,:],vmin=cmapmin,vmax=cmapmax)
                    axT[0].set_aspect(1)
                #Lowpass filter
                data_slice=degrade_space_gaussian(data_slice,decayunits)
                if plot == True:
                    axT[1].pcolor(data_slice[itime,:,:],vmin=cmapmin,vmax=cmapmax)
                    axT[1].set_aspect(1)
                    cbar0=plt.colorbar(im0, ax=axT, fraction=0.046, pad=0.04)
            
            #For some variables, the dimensions in (x, y) may be smaller than (max_height, max_width). Changing the code so that it adapts them.
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_input_data[rec_slice, ind, :slice_height, :slice_width] = data_slice
    

        for ind, var_name in enumerate(var_output_names):
            data_slice = np.squeeze(ncdata.variables[var_name])
            if var_name == 'T_xy_ins':
                if plot == True:
                    #Plot an image before the filter
                    itime=20        
                    cmapmax=np.max(data_slice[itime,:,:])
                    cmapmin=np.min(data_slice[itime,:,:])
                    figT, axT = plt.subplots(1, 2, figsize=(5, 5))
                    figT.set_dpi(256)   
                    im0=axT[0].pcolor(data_slice[itime,:,:],vmin=cmapmin,vmax=cmapmax)
                    axT[0].set_aspect(1)
                #Lowpass filter
                
                data_slice=degrade_space_gaussian(data_slice,decayunits)
                if plot == True:
                    axT[1].pcolor(data_slice[itime,:,:],vmin=cmapmin,vmax=cmapmax)
                    axT[1].set_aspect(1)
                    cbar0=plt.colorbar(im0, ax=axT, fraction=0.046, pad=0.04)
            #all_output_data[rec_slice, ind, :, :] = data_slice
            # Get the actual dimensions of data_slice
            slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
            # Place data_slice into the corresponding slice of all_input_data
            all_output_data[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs
        
    return all_input_data, all_output_data


In [14]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
decayunits=25 #The best SST satellite has a resolution of 9 km. It may be safe to set decayunits to be 20 km to be already resolvable by satellites

vi1 = 'ssh_ins'
vi2 = 'T_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

batch_size = 80 #maximizing it so that the GPU memory maxes out. Needs to be divisible by Ntrain. Otherwise there will be size mismatch issues.
lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size

var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
N_inp = len(var_input_names)
N_out = len(var_output_names)

save_fn_prefix  = 'any_{}{}_{}{}_nobatchnorm_degradeT_du_{}'.format(vi1, vi2, vo1, vo2, decayunits)
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_filterT(nctrains, Ntrain, decayunits=decayunits,plot=False)
all_test_input, all_test_output = preload_data_filterT(nctest, Ntest, decayunits=decayunits,plot=False)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    fstr = f'{save_fn_prefix}_rp_{iensemble}' 
    model_filename = f'/{fstr}.pth'
    model_path = model_folder+ model_filename
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.8394 & 0.8432 & 0.8341 & 0.0028 \\
full panel, R2:
, full & 0.7036 & 0.7093 & 0.6957 & 0.0043 \\
midjet panel, correlation:
, mid & 0.7616 & 0.7676 & 0.7521 & 0.0053 \\
midjet panel, R2:
, mid & 0.5797 & 0.5888 & 0.5654 & 0.0080 \\


In [15]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
decayunits=5 #The best SST satellite has a resolution of 9 km. It may be safe to set decayunits to be 20 km to be already resolvable by satellites

vi1 = 'ssh_ins'
vi2 = 'T_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

batch_size = 80 #maximizing it so that the GPU memory maxes out. Needs to be divisible by Ntrain. Otherwise there will be size mismatch issues.
lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size

var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
N_inp = len(var_input_names)
N_out = len(var_output_names)

save_fn_prefix  = 'any_{}{}_{}{}_nobatchnorm_degradeT_du_{}'.format(vi1, vi2, vo1, vo2, decayunits)
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_filterT(nctrains, Ntrain, decayunits=decayunits,plot=False)
all_test_input, all_test_output = preload_data_filterT(nctest, Ntest, decayunits=decayunits,plot=False)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    fstr = f'{save_fn_prefix}_rp_{iensemble}' 
    model_filename = f'/{fstr}.pth'
    model_path = model_folder+ model_filename
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.8556 & 0.8628 & 0.8486 & 0.0042 \\
full panel, R2:
, full & 0.7309 & 0.7434 & 0.7199 & 0.0073 \\
midjet panel, correlation:
, mid & 0.7897 & 0.8012 & 0.7788 & 0.0068 \\
midjet panel, R2:
, mid & 0.6231 & 0.6416 & 0.6064 & 0.0106 \\


In [16]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
decayunits=25 #The best SST satellite has a resolution of 9 km. It may be safe to set decayunits to be 20 km to be already resolvable by satellites

vi1 = 'T_xy_ins'
vi2 = 'u_xy_ins'
vi3 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'


batch_size = 60 #maximizing it so that the GPU memory maxes out. Needs to be divisible by Ntrain. Otherwise there will be size mismatch issues.
lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size


var_input_names = [vi1, vi2, vi3]
var_output_names = [vo1, vo2]
N_inp = len(var_input_names)
N_out = len(var_output_names)

save_fn_prefix  = 'any_{}{}{}_{}{}_nobatchnorm_degradeT_du_{}'.format(vi1, vi2, vi3, vo1, vo2, decayunits)
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_filterT(nctrains, Ntrain, decayunits=decayunits,plot=False)
all_test_input, all_test_output = preload_data_filterT(nctest, Ntest, decayunits=decayunits,plot=False)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    fstr = f'{save_fn_prefix}_rp_{iensemble}' 
    model_filename = f'/{fstr}.pth'
    model_path = model_folder+ model_filename
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.9456 & 0.9464 & 0.9442 & 0.0006 \\
full panel, R2:
, full & 0.8941 & 0.8956 & 0.8914 & 0.0012 \\
midjet panel, correlation:
, mid & 0.9083 & 0.9093 & 0.9057 & 0.0010 \\
midjet panel, R2:
, mid & 0.8249 & 0.8267 & 0.8201 & 0.0018 \\


In [17]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
decayunits=5 #The best SST satellite has a resolution of 9 km. It may be safe to set decayunits to be 20 km to be already resolvable by satellites

vi1 = 'T_xy_ins'
vi2 = 'u_xy_ins'
vi3 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'


batch_size = 60 #maximizing it so that the GPU memory maxes out. Needs to be divisible by Ntrain. Otherwise there will be size mismatch issues.
lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size


var_input_names = [vi1, vi2, vi3]
var_output_names = [vo1, vo2]
N_inp = len(var_input_names)
N_out = len(var_output_names)

save_fn_prefix  = 'any_{}{}{}_{}{}_nobatchnorm_degradeT_du_{}'.format(vi1, vi2, vi3, vo1, vo2, decayunits)
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_filterT(nctrains, Ntrain, decayunits=decayunits,plot=False)
all_test_input, all_test_output = preload_data_filterT(nctest, Ntest, decayunits=decayunits,plot=False)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    fstr = f'{save_fn_prefix}_rp_{iensemble}' 
    model_filename = f'/{fstr}.pth'
    model_path = model_folder+ model_filename
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.9514 & 0.9528 & 0.9492 & 0.0010 \\
full panel, R2:
, full & 0.9050 & 0.9078 & 0.9008 & 0.0020 \\
midjet panel, correlation:
, mid & 0.9183 & 0.9208 & 0.9151 & 0.0017 \\
midjet panel, R2:
, mid & 0.8432 & 0.8477 & 0.8373 & 0.0031 \\


In [18]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
decayunits=25 #The best SST satellite has a resolution of 9 km. It may be safe to set decayunits to be 20 km to be already resolvable by satellites

vi1 = 'ssh_ins'
vi2 = 'T_xy_ins'
vi3 = 'u_xy_ins'
vi4 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

batch_size = 50 #maximizing it so that the GPU memory maxes out. Needs to be divisible by Ntrain. Otherwise there will be size mismatch issues.
lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size

var_input_names = [vi1, vi2, vi3, vi4]
var_output_names = [vo1, vo2]
N_inp = len(var_input_names)
N_out = len(var_output_names)

save_fn_prefix  = 'any_{}{}{}{}_{}{}_nobatchnorm_degradeT_du_{}'.format(vi1, vi2, vi3, vi4, vo1, vo2, decayunits)
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_filterT(nctrains, Ntrain, decayunits=decayunits,plot=False)
all_test_input, all_test_output = preload_data_filterT(nctest, Ntest, decayunits=decayunits,plot=False)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    fstr = f'{save_fn_prefix}_rp_{iensemble}' 
    model_filename = f'/{fstr}.pth'
    model_path = model_folder+ model_filename
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.9653 & 0.9738 & 0.9510 & 0.0081 \\
full panel, R2:
, full & 0.9315 & 0.9482 & 0.9038 & 0.0155 \\
midjet panel, correlation:
, mid & 0.9438 & 0.9589 & 0.9194 & 0.0138 \\
midjet panel, R2:
, mid & 0.8903 & 0.9192 & 0.8451 & 0.0257 \\


In [19]:
# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ Change below for each Configuration ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
decayunits=5 #The best SST satellite has a resolution of 9 km. It may be safe to set decayunits to be 20 km to be already resolvable by satellites

vi1 = 'ssh_ins'
vi2 = 'T_xy_ins'
vi3 = 'u_xy_ins'
vi4 = 'v_xy_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

batch_size = 50 #maximizing it so that the GPU memory maxes out. Needs to be divisible by Ntrain. Otherwise there will be size mismatch issues.
lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size

var_input_names = [vi1, vi2, vi3, vi4]
var_output_names = [vo1, vo2]
N_inp = len(var_input_names)
N_out = len(var_output_names)

save_fn_prefix  = 'any_{}{}{}{}_{}{}_nobatchnorm_degradeT_du_{}'.format(vi1, vi2, vi3, vi4, vo1, vo2, decayunits)
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data_filterT(nctrains, Ntrain, decayunits=decayunits,plot=False)
all_test_input, all_test_output = preload_data_filterT(nctest, Ntest, decayunits=decayunits,plot=False)

#Normalize data
#Compute mean and variance for normalization
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the data with their means
all_train_input=all_train_input-mean_input[None, :, None, None]
all_train_output=all_train_output-mean_output[None, :, None, None]
all_test_input=all_test_input-mean_input[None, :, None, None]
all_test_output=all_test_output-mean_output[None, :, None, None]
#Compute the variances
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(var_input[None, :, None, None])
all_train_output=all_train_output/np.sqrt(var_output[None, :, None, None])
all_test_input=all_test_input/np.sqrt(var_input[None, :, None, None])
all_test_output=all_test_output/np.sqrt(var_output[None, :, None, None])

inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
out_test = out_test*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

truth_bot = out_test[:, :, bottom_slice, :]
truth_mid = out_test[:, :, mid_slice, :]
truth_top = out_test[:, :, top_slice, :]

combined_names = ''.join(var_input_names)

#Array to record performance metrics
corr_ensemble_full = np.zeros(nensemble)
R2_ensemble_full = np.zeros(nensemble)
corr_ensemble_top = np.zeros(nensemble)
R2_ensemble_top = np.zeros(nensemble)
corr_ensemble_mid = np.zeros(nensemble)
R2_ensemble_mid = np.zeros(nensemble)
corr_ensemble_bot = np.zeros(nensemble)
R2_ensemble_bot = np.zeros(nensemble)
for iensemble in np.arange(nensemble):
    fstr = f'{save_fn_prefix}_rp_{iensemble}' 
    model_filename = f'/{fstr}.pth'
    model_path = model_folder+ model_filename
    state_dict = torch.load(model_path)
    # Create a new instance of the model
    model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = Nbase)
    # Load the state_dict into the model
    model.load_state_dict(state_dict)
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        out_mod = model(totorch(inp_test)).detach().cpu().numpy()
    #Renormalize
    out_mod = out_mod*np.sqrt(var_output[None, :, None, None])+mean_output[None, :, None, None]

    mod_bot = out_mod[:, :, bottom_slice, :]
    mod_mid = out_mod[:, :, mid_slice, :]
    mod_top = out_mod[:, :, top_slice, :]

    corr_ensemble_full[iensemble] = corr(out_test, out_mod)
    R2_ensemble_full[iensemble] = L2_R(out_test, out_mod)    
    corr_ensemble_top[iensemble] = corr(truth_top, mod_top)
    R2_ensemble_top[iensemble] = L2_R(truth_top, mod_top)  
    corr_ensemble_mid[iensemble] = corr(truth_mid, mod_mid)
    R2_ensemble_mid[iensemble] = L2_R(truth_mid, mod_mid)
    corr_ensemble_bot[iensemble] = corr(truth_bot, mod_bot)
    R2_ensemble_bot[iensemble] = L2_R(truth_bot, mod_bot)

print('full panel, correlation:')
mean_val = np.mean(corr_ensemble_full)
max_val = np.max(corr_ensemble_full)
min_val = np.min(corr_ensemble_full)
std_val = np.std(corr_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('full panel, R2:')
mean_val = np.mean(R2_ensemble_full)
max_val = np.max(R2_ensemble_full)
min_val = np.min(R2_ensemble_full)
std_val = np.std(R2_ensemble_full)
latex_table_row=f', full & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, correlation:')
mean_val = np.mean(corr_ensemble_mid)
max_val = np.max(corr_ensemble_mid)
min_val = np.min(corr_ensemble_mid)
std_val = np.std(corr_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)
print('midjet panel, R2:')
mean_val = np.mean(R2_ensemble_mid)
max_val = np.max(R2_ensemble_mid)
min_val = np.min(R2_ensemble_mid)
std_val = np.std(R2_ensemble_mid)
latex_table_row=f', mid & {mean_val:.4f} & {max_val:.4f} & {min_val:.4f} & {std_val:.4f} \\\\'
print(latex_table_row)

  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)
  state_dict = torch.load(model_path)


full panel, correlation:
, full & 0.9671 & 0.9738 & 0.9604 & 0.0042 \\
full panel, R2:
, full & 0.9350 & 0.9484 & 0.9220 & 0.0082 \\
midjet panel, correlation:
, mid & 0.9471 & 0.9587 & 0.9349 & 0.0075 \\
midjet panel, R2:
, mid & 0.8964 & 0.9189 & 0.8736 & 0.0141 \\
