In [1]:
#We try another shallower UNet with a receptive field comparable to the IT wavelength. 
#As the grid resolution is 4 km and the mode-1 tidal wavelength is at most about 230 km, we want the receptive field to be about 60 grid points (230 km / 4 km ≈ 58) so that it can accommodate one full wavelength.
#The code as is requires at least 80 GB GPU and CPU memory. The CPU memory requirement may not be necessary if I do the evaluation steps on a GPU. 
#Print out the mean performance in the midjet in the test set right after training.

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import r2_score as R2
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score as r2
from copy import deepcopy
import utils
from unet import UNet_nobatchnorm
from ShallowUNet_nobatchnorm import TwolayerUNet
from scipy.stats import pearsonr
#JU's addition to automate inputs and outputs
import helper_functions as hf
import os
def save_fn(var_input_list, var_output_list):
    """Build a file-name stem of the form '<inputs>_to_<outputs>'.

    The names inside each list are glued together with '_and_'.
    """
    inputs_part, outputs_part = (
        '_and_'.join(names) for names in (var_input_list, var_output_list)
    )
    return f'{inputs_part}_to_{outputs_part}'

# Select GPU 0 only when CUDA is present. Calling torch.cuda.set_device
# unconditionally raises on a CPU-only machine, which would make the CPU
# fallback on the next line unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print ('Running on ', device)


Running on  cuda:0


In [2]:
# Record the PyTorch build and the CUDA toolkit it was compiled against, one per line.
print(torch.__version__, torch.version.cuda, sep='\n')

2.5.1
12.6


In [3]:
#how many parameters are there in the original UNet used elsewhere in this project
N_inp = 1
N_out = 2

# Random probe batch used only to push one forward pass through each network
# before counting parameters. It is deliberately NOT called `input`, so the
# Python builtin of that name is no longer shadowed at notebook level.
probe_input = torch.randn(1,N_inp,256,720).to(device)

# `.to(device)` instead of `.cuda()` keeps this cell consistent with the
# `device` chosen in the first cell (and lets it run on a CPU-only machine).
model = UNet_nobatchnorm(N_inp, N_out, bilinear = True, Nbase = 16).to(device)
with torch.no_grad():  # the probe pass needs no autograd graph
    output = model(probe_input)
print('number of paramters in the 4-layer UNet with Nbase=16:', utils.nparams(model)/1e6, ' million params')


# Nbase is raised to 50 so that the shallower network ends up with a parameter
# count close to the 4-layer one (compare the two printed numbers).
model = TwolayerUNet(N_inp, N_out, bilinear = True, Nbase = 50).to(device)
with torch.no_grad():
    output = model(probe_input)
print('number of paramters in the 2-layer UNet with an increased Nbase:', utils.nparams(model)/1e6, ' million params')

number of paramters in the 4-layer UNet with Nbase=16: 1.124418  million params
number of paramters in the 2-layer UNet with an increased Nbase: 1.175202  million params


In [4]:
maxEpochs =  300  # Number of training epochs per run. 300 is the default; use a small value only for debugging.
nensemble = 5  # How many independent training runs (ensemble members) are done for each configuration
Nbase = 50  # Base channel count of the two-layer UNet; chosen in the parameter-count cell above

In [1]:
!nvidia-smi #GPU usage should be maxed out during training; tune batch_size according to that

Mon Sep 15 15:52:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:84:00.0 Off |                    0 |
| N/A   43C    P0             68W /  500W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [6]:
# Open the netCDF datasets (lists of training and testing files).
root_dir = '/work/uo0780/u241359/project_tide_synergy/data/'
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print ('Running on ', device)

# Total number of time records, summed over the files of each split.
Ntrain, Ntest = (
    np.sum([ds.dimensions['time_counter'].size for ds in datasets], axis = 0)
    for datasets in (nctrains, nctest)
)

print('number of training records:', Ntrain)
print('number of testing records:', Ntest)

numTrainFiles = len(nctrains)
# Number of time snapshots in the first training file.
numRecsFile = nctrains[0].dimensions['time_counter'].size
print (numRecsFile)


def preload_data(nctrains, total_records, input_names=None, output_names=None,
                 max_height=722, max_width=258):
    """Read every record of the requested variables into two in-memory arrays.

    Parameters
    ----------
    nctrains : iterable of dataset-like objects
        Each item must expose ``dimensions['time_counter'].size`` and
        ``variables[name]`` (as netCDF4 datasets do).
    total_records : int
        Total number of time records over all datasets; sets the length of
        the first axis of the returned arrays.
    input_names, output_names : sequence of str, optional
        Variable names to load as input / output channels. When omitted, the
        notebook-level ``var_input_names`` / ``var_output_names`` are used,
        which is what the original implementation always did.
    max_height, max_width : int, optional
        Size of the last two axes of the returned arrays. Variables whose
        grid is smaller are written into the top-left corner and the
        remainder stays NaN.

    Returns
    -------
    (all_input_data, all_output_data) : tuple of float64 ndarrays
        Shapes ``(total_records, n_input_channels, max_height, max_width)``
        and ``(total_records, n_output_channels, max_height, max_width)``.
        Anything that is not written (padding, or records beyond those
        provided by the datasets) is NaN.

    Notes
    -----
    The channel count is taken from the name lists rather than from the
    globals ``N_inp`` / ``N_out``, so the function no longer depends on those
    globals having been updated before it is called.
    """
    if input_names is None:
        input_names = var_input_names
    if output_names is None:
        output_names = var_output_names

    # np.full gives the NaN-initialised buffers directly, without the
    # temporary array that `np.zeros(...) * np.nan` creates.
    all_input_data = np.full((total_records, len(input_names), max_height, max_width), np.nan)
    all_output_data = np.full((total_records, len(output_names), max_height, max_width), np.nan)

    current_index = 0
    for ncdata in nctrains:
        num_recs = ncdata.dimensions['time_counter'].size
        rec_slice = slice(current_index, current_index + num_recs)

        # Inputs and outputs are copied in exactly the same way.
        for names, target in ((input_names, all_input_data),
                              (output_names, all_output_data)):
            for ind, var_name in enumerate(names):
                data_slice = np.squeeze(ncdata.variables[var_name])
                # Some variables live on a grid smaller than
                # (max_height, max_width); place them in the top-left corner.
                slice_height, slice_width = data_slice.shape[-2:]
                target[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs

    return all_input_data, all_output_data
    
# Modify the loadtrain function to pull data from preloaded memory
def loaddata_preloaded_train(index, batch_size, all_input_data, all_output_data):
    """Cut one mini-batch out of the preloaded arrays.

    Records ``index`` to ``index + batch_size - 1`` are selected along the
    first axis, all channels are kept, and the last two axes are cropped to
    their first 720 and 256 entries respectively.

    Returns the pair ``(input_batch, output_batch)``.
    """
    n_rows, n_cols = 720, 256
    window = (
        slice(index, index + batch_size),  # records of this batch
        slice(None),                       # every channel
        slice(0, n_rows),
        slice(0, n_cols),
    )
    return all_input_data[window], all_output_data[window]
#Load test data as one single batch
def loaddata_preloaded_test(all_input_data, all_output_data):
    """Return all records of both arrays as a single batch.

    Every record and channel is kept; the last two axes are cropped to their
    first 720 and 256 entries respectively. The result is the tuple
    ``(inputs, outputs)``.
    """
    crop = (slice(None), slice(None), slice(0, 720), slice(0, 256))
    return tuple(data[crop] for data in (all_input_data, all_output_data))


def load_variable(ncdata, ncindex, variable, rec_slice, yslice, xslice):
    """Read one variable from dataset number ``ncindex`` and return a sub-block.

    Singleton axes are squeezed out first; the result is then indexed with
    ``(rec_slice, yslice, xslice)``.
    """
    dataset = ncdata[ncindex]
    field = np.squeeze(dataset.variables[variable])
    return field[rec_slice, yslice, xslice]


# def loadtest():
#     var_input = np.ones([150, N_inp, 720, 256])
#     var_output = np.ones([150, N_out, 720, 256])

#     for ind, var_name in enumerate(var_input_names):
#         data_squeezed = np.squeeze(nctest.variables[var_name])
#         var_input[:, ind, :, :] = data_squeezed[rectest_slice, ytest_slice, xtest_slice]
#     for ind, var_name in enumerate(var_output_names):
#         data_squeezed = np.squeeze(nctest.variables[var_name])
#         var_output[:, ind, :, :] = data_squeezed[rectest_slice, ytest_slice, xtest_slice]
#     return var_input, var_output


Running on  cuda:0
number of training records: 600
number of testing records: 150
150


In [7]:
def run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all):
    """Train one ensemble member of the two-layer UNet, score it, and save it.

    Besides its arguments the function reads these notebook-level globals:
    ``Nbase``, ``maxEpochs``, ``lr0``, ``batch_size``, ``Ntrain``, ``device``,
    ``all_train_input``, ``all_train_output``, ``all_test_input`` and
    ``all_test_output``.

    Parameters
    ----------
    var_input_names, var_output_names : list of str
        Not used inside the function (the data come from the preloaded global
        arrays); kept so that existing calls keep working.
    save_fn_prefix : str
        Prefix of the file name under which the best weights are saved.
    N_inp, N_out : int
        Number of input / output channels of the network.
    iensemble : int
        Index of this run; selects the slot written in ``R2_all`` /
        ``corr_all`` and is appended to the saved file name.
    R2_all, corr_all : indexable, modified in place
        On return, element ``iensemble`` holds the MID-JET R2 / Pearson
        correlation of the best model. The all-region scores are only printed.

    Notes
    -----
    * The "best" model is the one with the highest R2 on the test set, i.e.
      the test set is also used for model selection.
    * After training the best model is moved to the CPU, evaluated there on
      the whole test set in one batch, and its ``state_dict`` is saved.
    """

    def totorch(x):
        # numpy array -> float32 tensor on the device used for training
        return torch.tensor(x, dtype = torch.float).to(device)

    model = TwolayerUNet(N_inp, N_out, bilinear = True, Nbase = Nbase).to(device)
    #model = torch.compile(UNet(N_inp, N_out, bilinear = True, Nbase = Nbase).cuda())

    if iensemble == 0:
        # One probe forward pass, then report the model size (first run only).
        probe = torch.randn(1,N_inp,256,720).to(device)
        with torch.no_grad():  # no autograd graph is needed for the probe
            model(probe)
        print('Model has ', utils.nparams(model)/1e6, ' million params')

    # The whole test set is used as a single batch.
    inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
    inp_test = totorch(inp_test)
    # Flattened reference values do not change between epochs, so build them once.
    out_test_flat = out_test.flatten()

    Tcycle = 10
    criterion_train  = nn.L1Loss()
    optim = torch.optim.AdamW(model.parameters(), lr=lr0, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5*100) #increase weight_decay ***

    maxR2 = -1000  # best test R2 seen so far
    perm = False   # set True to visit the mini-batches in random order each epoch

    model_best = deepcopy(model)  # Initialize before the loop for safety

    for epoch in tqdm(range(maxEpochs)):
        # Per-epoch learning-rate update (presumably a cosine schedule with
        # period Tcycle - confirm in utils.cosineSGDR).
        # Author's note: disabling this call seemed to make the printed corr larger (unexplained).
        utils.cosineSGDR(optim, epoch, T0=Tcycle, eta_min=0, eta_max=lr0, scheme = 'constant')
        model.train()
        index_perm = np.arange(0, Ntrain, batch_size)

        if perm:
            index_perm = np.random.permutation(index_perm)

        for index in index_perm:
            inp, out = loaddata_preloaded_train(index, batch_size, all_train_input, all_train_output)
            inp, out = totorch(inp), totorch(out)
            out_mod = model(inp)
            # Conventional (prediction, target) order; L1 is symmetric in its
            # arguments, so loss and gradients are the same as before.
            loss = criterion_train(out_mod.squeeze(), out.squeeze())
            #Set gradient to zero
            optim.zero_grad()
            #Compute gradients
            loss.backward()
            #Update parameters with new gradient
            optim.step()

        model.eval()
        with torch.no_grad():
            # Predict the full test set and bring it to the CPU once per epoch.
            pred_flat = model(inp_test).cpu().numpy().flatten()

            r2_epoch = R2(out_test_flat, pred_flat)
            #record current best model
            if maxR2 < r2_epoch:
                maxR2 = r2_epoch
                model_best = deepcopy(model)
                corr, pval = pearsonr(out_test_flat, pred_flat)
                print('R2:', r2_epoch, ' corr: ', corr, ' pval: ', pval)

    # Final evaluation of the best model, done on the CPU.
    model_best.eval()
    with torch.no_grad():
        model_best.to('cpu') #added by HW
        out_mod = model_best(inp_test.to('cpu')).detach().cpu().numpy()

    out_mod_flat = out_mod.flatten()
    r2_full = R2(out_test_flat, out_mod_flat)
    corr_full = pearsonr(out_test_flat, out_mod_flat)[0]

    print('All regions, best model R2:', r2_full)
    print('All regions, best model corr:', corr_full)

    #Added 2025.9.12: mid-jet statistics
    # Rows 232..487 of the third axis are the band the author calls "mid-jet".
    # These are the values kept in R2_all / corr_all.
    mid_slice = slice(232, 488)
    out_test_mid = out_test[:, :, mid_slice, :].flatten()
    out_mod_mid = out_mod[:, :, mid_slice, :].flatten()
    R2_all[iensemble]=R2(out_test_mid, out_mod_mid)
    corr_all[iensemble]=pearsonr(out_test_mid, out_mod_mid)[0]
    print('Mid-jet, best model R2:', R2_all[iensemble])
    print('Mid-jet, best model corr:', corr_all[iensemble])

    print(out_mod.shape, 'outout model shape')
    dr = '/work/uo0780/u241359/project_tide_synergy/trainedmodels' #'./models/to_vel'
    os.makedirs(dr, exist_ok=True) # no error if the directory already exists
    fstr = f'{save_fn_prefix}_rp_{iensemble}'
    PATH = dr + f'/{fstr}.pth'
    torch.save(model_best.state_dict(), PATH)


In [8]:
# ---- Experiment definition: one input field -> two output fields ----
vi1 = 'ssh_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix  = 'any_{}_{}{}_twolayerUNet_'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]

# Chosen as large as GPU memory allows. Ntrain should be divisible by
# batch_size (600 / 100 here); the author reports size-mismatch issues otherwise.
batch_size = 100

N_inp = len(var_input_names)
N_out = len(var_output_names)

lr0 = 0.005*10/batch_size #Roughly should scale inversely to batch_size

# One score per ensemble member, filled in by run_model. run_model stores the
# MID-JET scores of the best model here (all-region scores are only printed).
R2_all = np.zeros(nensemble)
corr_all = np.zeros(nensemble)

# Re-open the datasets (also done in an earlier cell) and load them into memory.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)


def _per_channel(stat):
    """Reshape a (channel,) statistic so it broadcasts over (record, channel, y, x)."""
    return stat[None, :, None, None]


#Normalize data
# NOTE(review): mean and variance are computed over train AND test records
# together, so test-set statistics enter the normalisation. Left unchanged to
# keep results comparable with earlier runs.
mean_input=np.nanmean(np.concatenate((all_train_input, all_test_input), axis=0),axis=(0, 2, 3))
mean_output=np.nanmean(np.concatenate((all_train_output, all_test_output), axis=0),axis=(0, 2, 3))
#Subtract the per-channel means
all_train_input=all_train_input-_per_channel(mean_input)
all_train_output=all_train_output-_per_channel(mean_output)
all_test_input=all_test_input-_per_channel(mean_input)
all_test_output=all_test_output-_per_channel(mean_output)
#Per-channel variances of the mean-removed data
var_input=np.nanmean((np.concatenate((all_train_input, all_test_input), axis=0))**2,axis=(0, 2, 3))
var_output=np.nanmean((np.concatenate((all_train_output, all_test_output), axis=0))**2,axis=(0, 2, 3))
print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)
#Scale the data so that they have variance of 1
all_train_input=all_train_input/np.sqrt(_per_channel(var_input))
all_train_output=all_train_output/np.sqrt(_per_channel(var_output))
all_test_input=all_test_input/np.sqrt(_per_channel(var_input))
all_test_output=all_test_output/np.sqrt(_per_channel(var_output))
#Have checked that after these operations, the data is scaled to be zero mean and variance 1.

# ---- Train the ensemble ----
for iensemble in range(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)
# The labels say explicitly that these are the mid-jet scores stored by run_model.
print('Mid-jet R2 from the best models in each run are:')
print(R2_all)
print('Mid-jet corr from the best models in each run are:')
print(corr_all)

mean and variance of all input data:
[0.03307104] [0.3119807]
mean and variance of all output data:
[-5.16228102e-04 -9.83592627e-05] [9.36516511e-05 1.01456128e-04]
Model has  1.175202  million params


  0%|          | 1/300 [00:18<1:30:35, 18.18s/it]

R2: -0.058450953332987776  corr:  0.0006509216971262774  pval:  1.2961882302512438e-06


  2%|▏         | 5/300 [01:14<1:12:18, 14.71s/it]

R2: 0.14954278314921088  corr:  0.41364287563696656  pval:  0.0


  3%|▎         | 8/300 [01:57<1:10:47, 14.55s/it]

R2: 0.21566258707385444  corr:  0.47322764299743014  pval:  0.0


  5%|▍         | 14/300 [03:20<1:08:00, 14.27s/it]

R2: 0.2282118547206664  corr:  0.49993421781197733  pval:  0.0


  5%|▌         | 15/300 [03:36<1:09:39, 14.66s/it]

R2: 0.23704007920472292  corr:  0.5187051369469105  pval:  0.0


  6%|▌         | 17/300 [04:05<1:09:20, 14.70s/it]

R2: 0.278307204203994  corr:  0.547358717188046  pval:  0.0


  8%|▊         | 25/300 [05:55<1:05:15, 14.24s/it]

R2: 0.2857502295462163  corr:  0.5596840559916753  pval:  0.0


  9%|▊         | 26/300 [06:11<1:06:53, 14.65s/it]

R2: 0.2908071024517579  corr:  0.5625173824624714  pval:  0.0


  9%|▉         | 27/300 [06:27<1:08:12, 14.99s/it]

R2: 0.2966022502104714  corr:  0.581330265261634  pval:  0.0


 11%|█         | 33/300 [07:50<1:03:45, 14.33s/it]

R2: 0.32614052415221784  corr:  0.594815445046988  pval:  0.0


 12%|█▏        | 36/300 [08:33<1:03:25, 14.42s/it]

R2: 0.3346631767950552  corr:  0.600064568742629  pval:  0.0


 13%|█▎        | 38/300 [09:02<1:03:40, 14.58s/it]

R2: 0.3578431039241018  corr:  0.6113466313209456  pval:  0.0


 14%|█▍        | 42/300 [09:58<1:01:56, 14.40s/it]

R2: 0.3683641703276198  corr:  0.6165132368378484  pval:  0.0


 15%|█▌        | 46/300 [10:54<1:00:42, 14.34s/it]

R2: 0.39000555775944334  corr:  0.6374288086054753  pval:  0.0


 19%|█▊        | 56/300 [13:11<57:29, 14.14s/it]  

R2: 0.3983021927349695  corr:  0.6551278374467496  pval:  0.0


 19%|█▉        | 58/300 [13:40<58:09, 14.42s/it]

R2: 0.40736495771778825  corr:  0.6520901612439384  pval:  0.0


 21%|██        | 63/300 [14:50<56:20, 14.26s/it]

R2: 0.4167532756262631  corr:  0.653009563494479  pval:  0.0


 22%|██▏       | 65/300 [15:19<56:42, 14.48s/it]

R2: 0.4214252286978456  corr:  0.663994630019988  pval:  0.0


 22%|██▏       | 67/300 [15:48<56:37, 14.58s/it]

R2: 0.42751437868191877  corr:  0.6635100501206268  pval:  0.0


 25%|██▌       | 75/300 [17:38<53:08, 14.17s/it]

R2: 0.45029401945106207  corr:  0.6820010002394061  pval:  0.0


 34%|███▍      | 102/300 [23:44<46:37, 14.13s/it]

R2: 0.45927933277805144  corr:  0.6936446641138979  pval:  0.0


 38%|███▊      | 114/300 [26:29<43:49, 14.14s/it]

R2: 0.5099167790093346  corr:  0.7186125685091956  pval:  0.0


 65%|██████▌   | 195/300 [44:46<24:45, 14.15s/it]

R2: 0.5194793804882862  corr:  0.728908541286193  pval:  0.0


 79%|███████▉  | 238/300 [54:29<14:35, 14.11s/it]

R2: 0.5253001101013985  corr:  0.7348879616351545  pval:  0.0


 80%|████████  | 241/300 [55:12<14:04, 14.32s/it]

R2: 0.5337699554243771  corr:  0.7381038596877297  pval:  0.0


 86%|████████▌ | 257/300 [58:50<10:07, 14.12s/it]

R2: 0.5348211272827115  corr:  0.7398183570546446  pval:  0.0


 93%|█████████▎| 278/300 [1:03:36<05:10, 14.11s/it]

R2: 0.5369392123468855  corr:  0.741482703509794  pval:  0.0


 94%|█████████▍| 283/300 [1:04:45<04:01, 14.21s/it]

R2: 0.5437199398974075  corr:  0.7460021043198372  pval:  0.0


100%|██████████| 300/300 [1:08:35<00:00, 13.72s/it]


All regions, best model R2: 0.5436823854386865
All regions, best model corr: 0.7459846738685127
Mid-jet, best model R2: 0.3927480853703813
Mid-jet, best model corr: 0.6447849624579924
(150, 2, 720, 256) outout model shape


  0%|          | 1/300 [00:15<1:17:27, 15.54s/it]

R2: -0.021034029907300633  corr:  0.0048148888247305805  pval:  9.454762226253782e-281


  1%|          | 2/300 [00:31<1:17:12, 15.55s/it]

R2: -0.006762531686080653  corr:  0.0746081972016937  pval:  0.0


  1%|▏         | 4/300 [01:00<1:13:50, 14.97s/it]

R2: 0.08376553829257971  corr:  0.30684688552660694  pval:  0.0


  2%|▏         | 6/300 [01:29<1:12:32, 14.81s/it]

R2: 0.10316009369332468  corr:  0.43667846553670586  pval:  0.0


  2%|▏         | 7/300 [01:44<1:13:27, 15.04s/it]

R2: 0.18594224482345723  corr:  0.46602559323460907  pval:  0.0


  3%|▎         | 9/300 [02:13<1:12:02, 14.85s/it]

R2: 0.2072564727631473  corr:  0.4877571090511132  pval:  0.0


  3%|▎         | 10/300 [02:29<1:12:47, 15.06s/it]

R2: 0.2077669773018861  corr:  0.49005348613580946  pval:  0.0


  4%|▍         | 12/300 [02:58<1:11:23, 14.87s/it]

R2: 0.21859123616678355  corr:  0.5208967443040294  pval:  0.0


  5%|▌         | 15/300 [03:40<1:09:13, 14.57s/it]

R2: 0.23574890395414005  corr:  0.538420287812868  pval:  0.0


  6%|▌         | 17/300 [04:09<1:09:00, 14.63s/it]

R2: 0.2618345609834455  corr:  0.5472332308101591  pval:  0.0


  8%|▊         | 25/300 [05:59<1:04:56, 14.17s/it]

R2: 0.26863800531906457  corr:  0.5488442893242823  pval:  0.0


  9%|▊         | 26/300 [06:15<1:06:34, 14.58s/it]

R2: 0.28146500962407484  corr:  0.5634990635961313  pval:  0.0


  9%|▉         | 28/300 [06:44<1:06:22, 14.64s/it]

R2: 0.28361276712138883  corr:  0.5700579676436419  pval:  0.0


 11%|█▏        | 34/300 [08:07<1:03:09, 14.25s/it]

R2: 0.29421823224796817  corr:  0.5808362336405997  pval:  0.0


 12%|█▏        | 35/300 [08:23<1:04:38, 14.64s/it]

R2: 0.3242612261789828  corr:  0.5935137718267318  pval:  0.0


 12%|█▏        | 36/300 [08:38<1:05:36, 14.91s/it]

R2: 0.3352435561662118  corr:  0.6041508261373592  pval:  0.0


 12%|█▏        | 37/300 [08:54<1:06:11, 15.10s/it]

R2: 0.3617440723517338  corr:  0.6177768370248791  pval:  0.0


 16%|█▌        | 47/300 [11:11<59:40, 14.15s/it]  

R2: 0.3844205287483071  corr:  0.638120964228181  pval:  0.0


 16%|█▋        | 49/300 [11:40<1:00:19, 14.42s/it]

R2: 0.3896623457845151  corr:  0.6406566254712993  pval:  0.0


 17%|█▋        | 51/300 [12:09<1:00:22, 14.55s/it]

R2: 0.3995028935724839  corr:  0.6515028982404835  pval:  0.0


 17%|█▋        | 52/300 [12:24<1:01:20, 14.84s/it]

R2: 0.40874247246121775  corr:  0.6453891399718248  pval:  0.0


 18%|█▊        | 53/300 [12:40<1:01:54, 15.04s/it]

R2: 0.412339935186233  corr:  0.6548710295348797  pval:  0.0


 19%|█▉        | 58/300 [13:49<57:55, 14.36s/it]  

R2: 0.44729631592635144  corr:  0.6761120311183205  pval:  0.0


 22%|██▏       | 67/300 [15:53<54:51, 14.13s/it]

R2: 0.46130691001129975  corr:  0.687684811371642  pval:  0.0


 25%|██▌       | 75/300 [17:43<53:00, 14.13s/it]

R2: 0.47427915803311593  corr:  0.6964830792479413  pval:  0.0


 26%|██▌       | 77/300 [18:12<53:35, 14.42s/it]

R2: 0.4766913464195477  corr:  0.7024919210899441  pval:  0.0


 28%|██▊       | 84/300 [19:48<51:01, 14.17s/it]

R2: 0.4781547642412751  corr:  0.703129669849257  pval:  0.0


 30%|███       | 90/300 [21:11<49:37, 14.18s/it]

R2: 0.4815758133317807  corr:  0.70536278005476  pval:  0.0


 33%|███▎      | 100/300 [23:28<47:05, 14.13s/it]

R2: 0.4861114815242058  corr:  0.7088562030100187  pval:  0.0


 36%|███▌      | 107/300 [25:05<45:31, 14.15s/it]

R2: 0.49221440956576334  corr:  0.7126212233784678  pval:  0.0


 38%|███▊      | 113/300 [26:28<44:11, 14.18s/it]

R2: 0.5068366226296324  corr:  0.7205356515618033  pval:  0.0


 43%|████▎     | 129/300 [30:06<40:12, 14.11s/it]

R2: 0.5075829858470533  corr:  0.7218311969283745  pval:  0.0


 46%|████▌     | 137/300 [31:56<38:24, 14.14s/it]

R2: 0.5170919282874743  corr:  0.7267854002059692  pval:  0.0


 53%|█████▎    | 158/300 [36:42<33:23, 14.11s/it]

R2: 0.5265669057036403  corr:  0.7328653756443182  pval:  0.0


 56%|█████▌    | 168/300 [38:59<31:16, 14.22s/it]

R2: 0.5269922653972816  corr:  0.7348682277164648  pval:  0.0


 56%|█████▋    | 169/300 [39:14<31:54, 14.61s/it]

R2: 0.5289763751312326  corr:  0.7369743314445668  pval:  0.0


 65%|██████▍   | 194/300 [44:54<24:54, 14.10s/it]

R2: 0.5300930525904476  corr:  0.7348509334709565  pval:  0.0


 66%|██████▋   | 199/300 [46:03<23:54, 14.20s/it]

R2: 0.5363684686935475  corr:  0.7416807473979184  pval:  0.0


 67%|██████▋   | 200/300 [46:19<24:19, 14.60s/it]

R2: 0.5388468121080572  corr:  0.7430706655279813  pval:  0.0


 74%|███████▍  | 223/300 [51:32<18:06, 14.11s/it]

R2: 0.5410871930164053  corr:  0.744272907055024  pval:  0.0


 83%|████████▎ | 248/300 [57:11<12:12, 14.10s/it]

R2: 0.5414735313865602  corr:  0.7437655825899617  pval:  0.0


 83%|████████▎ | 249/300 [57:27<12:20, 14.52s/it]

R2: 0.5495839691989728  corr:  0.7493836937688994  pval:  0.0


100%|██████████| 300/300 [1:08:55<00:00, 13.79s/it]


All regions, best model R2: 0.5495826355366232
All regions, best model corr: 0.7493829301304032
Mid-jet, best model R2: 0.4058387001242434
Mid-jet, best model corr: 0.6527015535892591
(150, 2, 720, 256) outout model shape


  0%|          | 1/300 [00:15<1:17:25, 15.54s/it]

R2: -0.0018748417286658103  corr:  0.02724503298441642  pval:  0.0


  1%|          | 2/300 [00:31<1:17:07, 15.53s/it]

R2: 0.0041378860629022185  corr:  0.07770684457102776  pval:  0.0


  1%|          | 3/300 [00:46<1:16:49, 15.52s/it]

R2: 0.008360430041911449  corr:  0.19615324916604734  pval:  0.0


  1%|▏         | 4/300 [01:02<1:16:35, 15.53s/it]

R2: 0.13683217674304726  corr:  0.38875817948696506  pval:  0.0


  2%|▏         | 6/300 [01:31<1:13:43, 15.05s/it]

R2: 0.1579250890327466  corr:  0.4797197312224734  pval:  0.0


  2%|▏         | 7/300 [01:46<1:14:14, 15.20s/it]

R2: 0.25054401969513407  corr:  0.5321842680986724  pval:  0.0


  3%|▎         | 9/300 [02:15<1:12:26, 14.94s/it]

R2: 0.2551976866798439  corr:  0.5384686797859903  pval:  0.0


  5%|▍         | 14/300 [03:25<1:08:22, 14.35s/it]

R2: 0.25818581351556924  corr:  0.5439508034827842  pval:  0.0


  6%|▌         | 17/300 [04:07<1:07:54, 14.40s/it]

R2: 0.27467919150440834  corr:  0.5640120383739177  pval:  0.0


  6%|▌         | 18/300 [04:23<1:09:16, 14.74s/it]

R2: 0.2861799777642815  corr:  0.5666203730687788  pval:  0.0


  6%|▋         | 19/300 [04:38<1:10:09, 14.98s/it]

R2: 0.2944023672171472  corr:  0.5720081091432203  pval:  0.0


  8%|▊         | 25/300 [06:01<1:05:27, 14.28s/it]

R2: 0.31972731266213583  corr:  0.5800209228825465  pval:  0.0


  9%|▉         | 27/300 [06:30<1:05:58, 14.50s/it]

R2: 0.350766457409901  corr:  0.6137634222948977  pval:  0.0


 12%|█▏        | 35/300 [08:20<1:02:31, 14.16s/it]

R2: 0.3740552345046009  corr:  0.6241074597728842  pval:  0.0


 12%|█▏        | 37/300 [08:49<1:03:13, 14.42s/it]

R2: 0.390512390290752  corr:  0.6402752137221458  pval:  0.0


 15%|█▌        | 45/300 [10:39<1:00:10, 14.16s/it]

R2: 0.4209899924419428  corr:  0.655966973858398  pval:  0.0


 19%|█▊        | 56/300 [13:10<57:26, 14.13s/it]  

R2: 0.4411850373939973  corr:  0.6726191972601299  pval:  0.0


 22%|██▏       | 66/300 [15:27<55:04, 14.12s/it]

R2: 0.4416123797474577  corr:  0.6815090577375224  pval:  0.0


 25%|██▌       | 76/300 [17:44<52:49, 14.15s/it]

R2: 0.4423975063842953  corr:  0.6837472625025255  pval:  0.0


 26%|██▌       | 77/300 [18:00<54:08, 14.57s/it]

R2: 0.4663478997732162  corr:  0.694523651262975  pval:  0.0


 31%|███▏      | 94/300 [21:52<48:27, 14.11s/it]

R2: 0.4726359550991812  corr:  0.7005685836358074  pval:  0.0


 36%|███▌      | 107/300 [24:49<45:25, 14.12s/it]

R2: 0.4727253139109864  corr:  0.7032788151652326  pval:  0.0


 36%|███▌      | 108/300 [25:05<46:31, 14.54s/it]

R2: 0.47431078456144027  corr:  0.701678145913382  pval:  0.0


 38%|███▊      | 115/300 [26:41<43:46, 14.20s/it]

R2: 0.4853819193476169  corr:  0.7097966296554941  pval:  0.0


 39%|███▉      | 117/300 [27:10<44:04, 14.45s/it]

R2: 0.5004505473623486  corr:  0.7170869609488206  pval:  0.0


 48%|████▊     | 144/300 [33:17<36:42, 14.12s/it]

R2: 0.5135263859007528  corr:  0.7243892475126171  pval:  0.0


 57%|█████▋    | 171/300 [39:24<30:20, 14.11s/it]

R2: 0.5278546845298848  corr:  0.7360306457957082  pval:  0.0


 75%|███████▍  | 224/300 [51:22<17:53, 14.12s/it]

R2: 0.5366497948581533  corr:  0.7390657992867107  pval:  0.0


 82%|████████▏ | 245/300 [56:08<12:57, 14.14s/it]

R2: 0.5378950412902792  corr:  0.7409087514151149  pval:  0.0


 83%|████████▎ | 249/300 [57:04<12:08, 14.28s/it]

R2: 0.5387243929641684  corr:  0.7428153339664463  pval:  0.0


 86%|████████▋ | 259/300 [59:22<09:40, 14.15s/it]

R2: 0.5414796920192311  corr:  0.7450441074256365  pval:  0.0


 87%|████████▋ | 261/300 [59:51<09:23, 14.44s/it]

R2: 0.5492542712424555  corr:  0.7437084517317413  pval:  0.0


 88%|████████▊ | 263/300 [1:00:20<08:59, 14.59s/it]

R2: 0.5502928708352725  corr:  0.7528328026856097  pval:  0.0


100%|██████████| 300/300 [1:08:40<00:00, 13.74s/it]


All regions, best model R2: 0.5503119947470392
All regions, best model corr: 0.7528445477289804
Mid-jet, best model R2: 0.3749075696594881
Mid-jet, best model corr: 0.6400259328103018
(150, 2, 720, 256) outout model shape


  0%|          | 1/300 [00:15<1:17:32, 15.56s/it]

R2: -0.022375913960195337  corr:  0.005641635895311324  pval:  0.0


  1%|          | 2/300 [00:31<1:17:12, 15.55s/it]

R2: -0.0004139529450861712  corr:  0.10860709465517936  pval:  0.0


  1%|          | 3/300 [00:46<1:16:56, 15.54s/it]

R2: 0.10537600876303499  corr:  0.33508439638376286  pval:  0.0


  1%|▏         | 4/300 [01:02<1:16:41, 15.55s/it]

R2: 0.16789928737206417  corr:  0.4405797359633269  pval:  0.0


  2%|▏         | 7/300 [01:44<1:12:05, 14.76s/it]

R2: 0.24055787264939998  corr:  0.5148184945679796  pval:  0.0


  3%|▎         | 8/300 [02:00<1:13:05, 15.02s/it]

R2: 0.2540449676321894  corr:  0.5240188879194708  pval:  0.0


  3%|▎         | 10/300 [02:29<1:11:50, 14.86s/it]

R2: 0.2681791134686218  corr:  0.5394188736551205  pval:  0.0


  5%|▌         | 15/300 [03:39<1:08:09, 14.35s/it]

R2: 0.3026509454660308  corr:  0.5729045195480325  pval:  0.0


  6%|▌         | 18/300 [04:21<1:07:42, 14.41s/it]

R2: 0.36006502981657285  corr:  0.6101105270449947  pval:  0.0


  9%|▊         | 26/300 [06:11<1:04:41, 14.17s/it]

R2: 0.3631877173997373  corr:  0.6204781682396298  pval:  0.0


  9%|▉         | 27/300 [06:27<1:06:20, 14.58s/it]

R2: 0.36337657434192705  corr:  0.6191002687245724  pval:  0.0


  9%|▉         | 28/300 [06:42<1:07:26, 14.88s/it]

R2: 0.38282001072653105  corr:  0.635335061824255  pval:  0.0


 12%|█▏        | 35/300 [08:19<1:02:54, 14.24s/it]

R2: 0.44935259101975  corr:  0.6783308172936775  pval:  0.0


 15%|█▍        | 44/300 [10:23<1:00:21, 14.15s/it]

R2: 0.45029896933337077  corr:  0.6813745331628036  pval:  0.0


 16%|█▌        | 48/300 [11:19<59:53, 14.26s/it]  

R2: 0.45109661943948454  corr:  0.6821116689496134  pval:  0.0


 17%|█▋        | 52/300 [12:15<59:04, 14.29s/it]

R2: 0.4661257245688001  corr:  0.6894854150516593  pval:  0.0


 21%|██        | 63/300 [14:46<55:48, 14.13s/it]

R2: 0.46791652500465064  corr:  0.689909510315093  pval:  0.0


 21%|██▏       | 64/300 [15:01<57:12, 14.55s/it]

R2: 0.4740056471852372  corr:  0.6993278517366811  pval:  0.0


 23%|██▎       | 68/300 [15:57<55:32, 14.37s/it]

R2: 0.48150160502220696  corr:  0.7020838125454639  pval:  0.0


 25%|██▌       | 76/300 [17:47<52:56, 14.18s/it]

R2: 0.4896746238990658  corr:  0.7091464664601635  pval:  0.0


 30%|██▉       | 89/300 [20:45<49:42, 14.14s/it]

R2: 0.5000340596696191  corr:  0.7168068925627299  pval:  0.0


 31%|███▏      | 94/300 [21:55<48:51, 14.23s/it]

R2: 0.5094744470640542  corr:  0.7209316136176617  pval:  0.0


 38%|███▊      | 115/300 [26:41<43:32, 14.12s/it]

R2: 0.5205369008635254  corr:  0.729311873962967  pval:  0.0


 41%|████▏     | 124/300 [28:45<41:30, 14.15s/it]

R2: 0.525231036772398  corr:  0.73414561559865  pval:  0.0


 44%|████▍     | 132/300 [30:35<39:41, 14.18s/it]

R2: 0.5299362663274996  corr:  0.737389728835891  pval:  0.0


 60%|█████▉    | 179/300 [41:13<28:30, 14.13s/it]

R2: 0.5315822107461452  corr:  0.7397145170273174  pval:  0.0


 63%|██████▎   | 188/300 [43:17<26:26, 14.16s/it]

R2: 0.5436437492726764  corr:  0.7439709620842937  pval:  0.0


 76%|███████▋  | 229/300 [52:34<16:43, 14.13s/it]

R2: 0.5542267117085342  corr:  0.7528049837480215  pval:  0.0


 88%|████████▊ | 264/300 [1:00:29<08:28, 14.13s/it]

R2: 0.5554786228587079  corr:  0.7538420200021293  pval:  0.0


 89%|████████▊ | 266/300 [1:00:58<08:10, 14.43s/it]

R2: 0.559373106127971  corr:  0.7546386682364766  pval:  0.0


100%|██████████| 300/300 [1:08:39<00:00, 13.73s/it]


All regions, best model R2: 0.559437453530454
All regions, best model corr: 0.7546791626783784
Mid-jet, best model R2: 0.4237249232553443
Mid-jet, best model corr: 0.6635798728147231
(150, 2, 720, 256) outout model shape


  0%|          | 1/300 [00:15<1:17:30, 15.55s/it]

R2: -0.002911074709972983  corr:  0.02676736352909647  pval:  0.0


  1%|          | 2/300 [00:31<1:17:11, 15.54s/it]

R2: 0.03328730731907503  corr:  0.18688604552979707  pval:  0.0


  1%|          | 3/300 [00:46<1:16:56, 15.54s/it]

R2: 0.03472514598351362  corr:  0.24690788122906634  pval:  0.0


  1%|▏         | 4/300 [01:02<1:16:42, 15.55s/it]

R2: 0.055385608775110784  corr:  0.3211231585278496  pval:  0.0


  2%|▏         | 5/300 [01:17<1:16:25, 15.54s/it]

R2: 0.09200948161645095  corr:  0.41405320735028983  pval:  0.0


  2%|▏         | 6/300 [01:33<1:16:08, 15.54s/it]

R2: 0.13689521103098745  corr:  0.4357998143814722  pval:  0.0


  2%|▏         | 7/300 [01:48<1:15:51, 15.53s/it]

R2: 0.16886667819871537  corr:  0.46187540383685943  pval:  0.0


  3%|▎         | 9/300 [02:17<1:13:13, 15.10s/it]

R2: 0.18168258018825534  corr:  0.47533028475745637  pval:  0.0


  4%|▍         | 12/300 [03:00<1:10:22, 14.66s/it]

R2: 0.2249502195697226  corr:  0.5043358202363913  pval:  0.0


  5%|▌         | 15/300 [03:43<1:08:58, 14.52s/it]

R2: 0.22744139692707444  corr:  0.5114794902893792  pval:  0.0


  5%|▌         | 16/300 [03:58<1:10:12, 14.83s/it]

R2: 0.22937048970853358  corr:  0.5291094578379698  pval:  0.0


  6%|▌         | 17/300 [04:14<1:10:58, 15.05s/it]

R2: 0.2563237262289546  corr:  0.5460533573048447  pval:  0.0


  6%|▌         | 18/300 [04:29<1:11:27, 15.20s/it]

R2: 0.2565354936072448  corr:  0.5423307838649951  pval:  0.0


  6%|▋         | 19/300 [04:45<1:11:42, 15.31s/it]

R2: 0.2686170110049252  corr:  0.5535848424992909  pval:  0.0


  9%|▊         | 26/300 [06:22<1:05:13, 14.28s/it]

R2: 0.31821885632145075  corr:  0.5792972553895535  pval:  0.0


  9%|▉         | 27/300 [06:37<1:06:45, 14.67s/it]

R2: 0.3195097018641173  corr:  0.5909629813703182  pval:  0.0


  9%|▉         | 28/300 [06:53<1:07:46, 14.95s/it]

R2: 0.3289261250181088  corr:  0.5947309402761491  pval:  0.0


 10%|▉         | 29/300 [07:08<1:08:22, 15.14s/it]

R2: 0.3536169051108311  corr:  0.6141427189167582  pval:  0.0


 12%|█▏        | 36/300 [08:45<1:02:47, 14.27s/it]

R2: 0.37944691162418387  corr:  0.6285813933821996  pval:  0.0


 13%|█▎        | 40/300 [09:41<1:01:58, 14.30s/it]

R2: 0.3806310781404654  corr:  0.6347654237081352  pval:  0.0


 14%|█▍        | 43/300 [10:24<1:01:38, 14.39s/it]

R2: 0.42149079996677874  corr:  0.6581678052016645  pval:  0.0


 15%|█▌        | 45/300 [10:53<1:01:49, 14.55s/it]

R2: 0.4515238448464334  corr:  0.6810237397164232  pval:  0.0


 19%|█▉        | 57/300 [13:37<57:18, 14.15s/it]  

R2: 0.45325007540423456  corr:  0.681173625481167  pval:  0.0


 23%|██▎       | 68/300 [16:08<54:43, 14.15s/it]

R2: 0.47455520267118956  corr:  0.7002437731568004  pval:  0.0


 25%|██▌       | 76/300 [17:58<52:53, 14.17s/it]

R2: 0.476364851390486  corr:  0.7038405020063708  pval:  0.0


 28%|██▊       | 85/300 [20:02<50:44, 14.16s/it]

R2: 0.48450525624678065  corr:  0.7018914606543127  pval:  0.0


 33%|███▎      | 99/300 [23:14<47:21, 14.13s/it]

R2: 0.4852602345394318  corr:  0.7091500670586472  pval:  0.0


 33%|███▎      | 100/300 [23:29<48:30, 14.55s/it]

R2: 0.48615640450817676  corr:  0.7101032925182923  pval:  0.0


 35%|███▌      | 105/300 [24:39<46:28, 14.30s/it]

R2: 0.4895400946599078  corr:  0.7081780576304701  pval:  0.0


 39%|███▊      | 116/300 [27:10<43:24, 14.15s/it]

R2: 0.4984242879134759  corr:  0.7151900744365325  pval:  0.0


 39%|███▉      | 118/300 [27:39<43:48, 14.44s/it]

R2: 0.5137939854853288  corr:  0.7250340847013123  pval:  0.0


 50%|████▉     | 149/300 [34:40<35:32, 14.12s/it]

R2: 0.5175781032600699  corr:  0.7291754831665178  pval:  0.0


 53%|█████▎    | 160/300 [37:11<32:59, 14.14s/it]

R2: 0.5194763437830888  corr:  0.7300954882862833  pval:  0.0


 56%|█████▌    | 168/300 [39:01<31:08, 14.16s/it]

R2: 0.5394530956817296  corr:  0.7419790325010265  pval:  0.0


 73%|███████▎  | 219/300 [50:33<19:05, 14.14s/it]

R2: 0.5468019463958498  corr:  0.7480348868652343  pval:  0.0


 81%|████████  | 243/300 [55:59<13:25, 14.12s/it]

R2: 0.5488251242586311  corr:  0.7430347427058741  pval:  0.0


 83%|████████▎ | 250/300 [57:36<11:49, 14.19s/it]

R2: 0.5497657580822688  corr:  0.750839705749823  pval:  0.0


 86%|████████▌ | 258/300 [59:26<09:55, 14.17s/it]

R2: 0.5517503750035232  corr:  0.7514933996068712  pval:  0.0


 89%|████████▉ | 267/300 [1:01:30<07:47, 14.16s/it]

R2: 0.5597559762459349  corr:  0.7559891731941125  pval:  0.0


100%|██████████| 300/300 [1:08:57<00:00, 13.79s/it]


All regions, best model R2: 0.5598264745121957
All regions, best model corr: 0.7560272799651656
Mid-jet, best model R2: 0.4092451165610914
Mid-jet, best model corr: 0.6577666600306842
(150, 2, 720, 256) outout model shape
R2 from the best models in each run are:
[0.39274809 0.4058387  0.37490757 0.42372492 0.40924512]
corr from the best models in each run are:
[0.64478496 0.65270155 0.64002593 0.66357987 0.65776666]
