In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import r2_score as R2
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score as r2
from copy import deepcopy
import utils
from unet import UNet_nobatchnorm
from scipy.stats import pearsonr
#JU's addition to automate inputs and outputs
import helper_functions as hf
torch.cuda.set_device(0)
import os
def save_fn(var_input_list, var_output_list):
    """Build a file-name stem of the form ``<inputs>_to_<outputs>``.

    The names inside each list are glued together with ``_and_`` before the
    two halves are combined.
    """
    joined_halves = ['_and_'.join(names) for names in (var_input_list, var_output_list)]
    return '{}_to_{}'.format(*joined_halves)

In [2]:
# Show the PyTorch version followed by the CUDA version reported by torch.version.cuda.
print(torch.__version__, torch.version.cuda, sep='\n')

1.13.0
11.7


In [3]:
# Training configuration read by run_model and by the experiment driver cell.
maxEpochs =  300  # number of epochs per training run (use a small number when debugging)
nensemble = 10  # how many training sessions are run for each configuration
Nbase = 16  # forwarded to the UNet constructor as its `Nbase` argument

In [4]:
!nvidia-smi #GPU usage should be maxed out during training; tune batch_size according to that

Thu Dec 26 11:25:21 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:03:00.0 Off |                    0 |
| N/A   48C    P0             69W /  500W |       5MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          Off |   00

In [5]:
# Open the training and testing NetCDF datasets; both results are iterated as
# sequences of datasets further down.
root_dir = '/work/uo0780/u241359/project_tide_synergy/data/'
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Running on ', device)

# Total number of time records, summed over every file of each split.
Ntrain = np.asarray([dataset.dimensions['time_counter'].size for dataset in nctrains]).sum(axis=0)
Ntest = np.asarray([dataset.dimensions['time_counter'].size for dataset in nctest]).sum(axis=0)

print('number of training records:', Ntrain)
print('number of testing records:', Ntest)

numTrainFiles = len(nctrains)
# Number of time snapshots stored in the first training file.
numRecsFile = nctrains[0].dimensions['time_counter'].size
print(numRecsFile)


def preload_data(nctrains, total_records, var_input_names=None, var_output_names=None,
                 max_height=722, max_width=258):
    """Copy the requested variables of several NetCDF-like datasets into memory.

    Parameters
    ----------
    nctrains : iterable
        Datasets exposing ``dimensions['time_counter'].size`` and
        ``variables[name]``. They are consumed in order and their records are
        stacked consecutively along the first axis of the result.
    total_records : int
        Length of the first (record) axis of the returned arrays.
    var_input_names, var_output_names : sequence of str, optional
        Variable names that become the channels of the input / output array.
        When omitted, the notebook-level globals of the same name are used,
        which keeps the original two-argument call working.
    max_height, max_width : int, optional
        Size of the two trailing axes of the returned arrays.

    Returns
    -------
    (all_input_data, all_output_data) : tuple of numpy.ndarray
        float64 arrays of shape ``(total_records, n_channels, max_height,
        max_width)``. Every cell not covered by a variable stays NaN, so
        variables whose grid is smaller than ``(max_height, max_width)`` are
        NaN-padded at the high-index end.

    Raises
    ------
    NameError
        If a name list is omitted and the corresponding global is undefined.
    """
    if var_input_names is None:
        var_input_names = _notebook_global('var_input_names')
    if var_output_names is None:
        var_output_names = _notebook_global('var_output_names')

    # np.full allocates the NaN-filled buffer directly instead of building a
    # zero array and multiplying it by NaN.
    all_input_data = np.full(
        (total_records, len(var_input_names), max_height, max_width), np.nan)
    all_output_data = np.full(
        (total_records, len(var_output_names), max_height, max_width), np.nan)

    current_index = 0
    for ncdata in nctrains:
        num_recs = ncdata.dimensions['time_counter'].size
        rec_slice = slice(current_index, current_index + num_recs)
        _fill_channels(all_input_data, rec_slice, ncdata, var_input_names)
        _fill_channels(all_output_data, rec_slice, ncdata, var_output_names)
        current_index += num_recs

    return all_input_data, all_output_data


def _fill_channels(target, rec_slice, ncdata, var_names):
    """Write each named variable of `ncdata` into its own channel of `target`."""
    for channel, var_name in enumerate(var_names):
        # NOTE(review): np.squeeze drops *every* length-1 axis, so a file with a
        # single time record also loses its time axis (the assignment below then
        # relies on broadcasting) - confirm this is acceptable for all inputs.
        data_slice = np.squeeze(ncdata.variables[var_name])
        # Some variables live on a smaller (y, x) grid than the buffer; only
        # the leading corner is overwritten and the remainder stays NaN.
        slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
        target[rec_slice, channel, :slice_height, :slice_width] = data_slice


def _notebook_global(name):
    """Return the module-level variable `name`, raising NameError when it is missing."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name '{name}' is not defined; pass it to preload_data explicitly "
            "or define it before the call") from None
    
# Modify the loadtrain function to pull data from preloaded memory
def loaddata_preloaded_train(index, batch_size, all_input_data, all_output_data):
    """Cut one mini-batch out of the preloaded input / output arrays.

    Records ``index`` up to (not including) ``index + batch_size`` are taken
    along the first axis, every channel is kept, and the two trailing axes are
    cropped to their first 720 and 256 entries respectively.

    Returns the tuple ``(input_batch, output_batch)``.
    """
    batch_window = (
        slice(index, index + batch_size),  # records of this batch
        slice(None),                       # all channels
        slice(0, 720),                     # crop of the third axis
        slice(0, 256),                     # crop of the fourth axis
    )
    return all_input_data[batch_window], all_output_data[batch_window]
#Load test data as one single batch
def loaddata_preloaded_test(all_input_data, all_output_data):
    """Return the complete preloaded arrays as a single cropped batch.

    All records and channels are kept; the two trailing axes are cropped to
    their first 720 and 256 entries respectively.

    Returns the tuple ``(input_batch, output_batch)``.
    """
    def _crop(array):
        return array[:, :, :720, :256]

    return _crop(all_input_data), _crop(all_output_data)


def load_variable(ncdata, ncindex, variable, rec_slice, yslice, xslice):
    """Read one variable from dataset ``ncdata[ncindex]`` and return a sub-block.

    The variable is passed through ``np.squeeze`` first and the result is then
    indexed with ``[rec_slice, yslice, xslice]``.
    """
    dataset = ncdata[ncindex]
    return np.squeeze(dataset.variables[variable])[rec_slice, yslice, xslice]


# def loadtest():
#     var_input = np.ones([150, N_inp, 720, 256])
#     var_output = np.ones([150, N_out, 720, 256])

#     for ind, var_name in enumerate(var_input_names):
#         data_squeezed = np.squeeze(nctest.variables[var_name])
#         var_input[:, ind, :, :] = data_squeezed[rectest_slice, ytest_slice, xtest_slice]
#     for ind, var_name in enumerate(var_output_names):
#         data_squeezed = np.squeeze(nctest.variables[var_name])
#         var_output[:, ind, :, :] = data_squeezed[rectest_slice, ytest_slice, xtest_slice]
#     return var_input, var_output


Running on  cuda:0
number of training records: 600
number of testing records: 150
150


In [6]:
def run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all):
    """Train one ensemble member, keep its best checkpoint and record its scores.

    A fresh ``UNet_nobatchnorm`` is trained for ``maxEpochs`` epochs with an L1
    loss and AdamW. After every epoch the whole test set is evaluated; whenever
    the test R2 improves, a deep copy of the model is kept as the best model.
    The best model is finally re-evaluated on the CPU, its scores are written
    into ``R2_all[iensemble]`` / ``corr_all[iensemble]`` and its ``state_dict``
    is saved to disk.

    Besides its arguments the function reads these notebook-level names:
    ``Nbase``, ``device``, ``lr0``, ``maxEpochs``, ``Ntrain``, ``batch_size``,
    ``all_train_input``, ``all_train_output``, ``all_test_input`` and
    ``all_test_output``.

    NOTE(review): the best epoch is chosen using the test set itself, so the
    recorded scores are optimistic estimates - confirm that this is intended.

    Parameters
    ----------
    var_input_names, var_output_names : list of str
        Not used inside the function; kept so existing calls stay valid.
    save_fn_prefix : str
        Prefix of the checkpoint file name.
    N_inp, N_out : int
        Number of input / output channels of the network.
    iensemble : int
        Index of this run; selects the slot in ``R2_all`` / ``corr_all`` and
        is appended to the checkpoint file name.
    R2_all, corr_all : numpy.ndarray
        Modified in place with the best model's R2 and Pearson correlation.

    Returns
    -------
    numpy.ndarray
        Test R2 after each epoch (length ``maxEpochs``). Earlier versions
        returned ``None``; callers that ignore the result are unaffected.
    """
    def totorch(x):
        # float32 copy of `x` on the compute device
        return torch.tensor(x, dtype=torch.float).to(device)

    model = UNet_nobatchnorm(N_inp, N_out, bilinear=True, Nbase=Nbase).to(device)
    #model = torch.compile(UNet(N_inp, N_out, bilinear = True, Nbase = Nbase).cuda())

    if iensemble == 0:
        # One-off smoke test of the forward pass plus a parameter count.
        # NOTE(review): the dummy spatial size (256, 720) is transposed relative
        # to the (720, 256) crops used for the real data.
        with torch.no_grad():
            model(torch.randn(1, N_inp, 256, 720).to(device))
        print('Model has ', utils.nparams(model)/1e6, ' million params')

    # The test set is evaluated as one batch. Inputs stay on the device for the
    # whole run; targets stay in NumPy and are flattened once for the metrics.
    inp_test, out_test = loaddata_preloaded_test(all_test_input, all_test_output)
    inp_test = totorch(inp_test)
    out_test_flat = out_test.flatten()

    Tcycle = 10
    criterion_train = nn.L1Loss()
    optim = torch.optim.AdamW(model.parameters(), lr=lr0, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5*100) #increase weight_decay ***

    r2_test = np.zeros(maxEpochs)
    maxR2 = -1000
    perm = False  # set to True to visit the batches in random order each epoch

    model_best = deepcopy(model)  # Initialize before the loop for safety

    for epoch in tqdm(range(maxEpochs)):
        # Called for its side effect on the optimizer's learning rate; the
        # returned value is not needed here.
        # Author's open question: commenting this call out seemed to make the
        # printed corr larger.
        utils.cosineSGDR(optim, epoch, T0=Tcycle, eta_min=0, eta_max=lr0, scheme='constant')
        model.train()
        index_perm = np.arange(0, Ntrain, batch_size)

        if perm:
            index_perm = np.random.permutation(index_perm)

        for index in index_perm:
            inp, out = loaddata_preloaded_train(index, batch_size, all_train_input, all_train_output)
            inp, out = totorch(inp), totorch(out)
            out_mod = model(inp)
            # Conventional (prediction, target) order; L1 is symmetric, so the
            # loss value and its gradient are the same as with swapped arguments.
            loss = criterion_train(out_mod.squeeze(), out.squeeze())
            optim.zero_grad()
            loss.backward()
            optim.step()

        model.eval()
        with torch.no_grad():
            pred_flat = model(inp_test).cpu().numpy().flatten()

        r2 = R2(out_test_flat, pred_flat)
        r2_test[epoch] = r2
        # Keep the model with the highest test R2 seen so far.
        if maxR2 < r2:
            maxR2 = r2
            model_best = deepcopy(model)
            corr, pval = pearsonr(out_test_flat, pred_flat)
            print('R2:', r2, ' corr: ', corr, ' pval: ', pval)

    # Final evaluation of the best model, done on the CPU.
    model_best.eval()
    with torch.no_grad():
        model_best.to('cpu') #added by HW
        out_mod = model_best(inp_test.to('cpu')).detach().cpu().numpy()

    best_flat = out_mod.flatten()
    R2_all[iensemble] = R2(out_test_flat, best_flat)
    corr_all[iensemble] = pearsonr(out_test_flat, best_flat)[0]

    dr = '/work/uo0780/u241359/project_tide_synergy/trainedmodels' #'./models/to_vel'
    os.makedirs(dr, exist_ok=True)  # no error if the directory already exists
    fstr = f'{save_fn_prefix}_rp_{iensemble}'
    PATH = dr + f'/{fstr}.pth'
    torch.save(model_best.state_dict(), PATH)

    return r2_test

In [7]:
def _standardize_jointly(train, test):
    """Centre and scale `train` and `test` per channel with pooled statistics.

    The mean and the variance are computed over the concatenation of both
    arrays (NaNs ignored) along every axis except the channel axis.

    Returns ``(train_scaled, test_scaled, mean, variance)``.
    """
    def per_channel(stat):
        # reshape a (channels,) vector so it broadcasts over (record, channel, y, x)
        return stat[None, :, None, None]

    mean = np.nanmean(np.concatenate((train, test), axis=0), axis=(0, 2, 3))
    train = train - per_channel(mean)
    test = test - per_channel(mean)
    variance = np.nanmean((np.concatenate((train, test), axis=0))**2, axis=(0, 2, 3))
    train = train/np.sqrt(per_channel(variance))
    test = test/np.sqrt(per_channel(variance))
    return train, test, mean, variance


# Variables of this experiment: one input field, two output fields.
vi1 = 'ssh_ins'

vo1 = 'ssh_cos'
vo2 = 'ssh_sin'

save_fn_prefix = 'any_{}_{}{}_nobatchnorm'.format(vi1, vo1, vo2)
var_input_names = [vi1]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Chosen as large as GPU memory allows. Ntrain should be a multiple of
# batch_size; otherwise the last batch of every epoch is smaller than the rest.
batch_size = 100

lr0 = 0.005*10/batch_size  # roughly scales inversely with batch_size

# Test-set metrics of the best model of each training run.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)

all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# Normalise every channel to zero mean and unit variance.
# NOTE(review): the statistics are pooled over training AND test records, so
# test data influences the normalisation - confirm that this is intended.
all_train_input, all_test_input, mean_input, var_input = _standardize_jointly(
    all_train_input, all_test_input)
all_train_output, all_test_output, mean_output, var_output = _standardize_jointly(
    all_train_output, all_test_output)
print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Train `nensemble` independent models on the same data.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)
print('R2 from the best models in each run are:')
print(R2_all)
print('corr from the best models in each run are:')
print(corr_all)

mean and variance of all input data:
[0.03307104] [0.3119807]
mean and variance of all output data:
[-5.16228102e-04 -9.83592627e-05] [9.36516511e-05 1.01456128e-04]
Model has  1.124418  million params


  0%|          | 1/300 [00:06<30:50,  6.19s/it]

R2: -0.029354206544735417  corr:  -0.00015833893508326424  pval:  0.23902452625417986


  1%|          | 3/300 [00:16<27:20,  5.53s/it]

R2: -0.02548615220791306  corr:  0.0068413514470485855  pval:  0.0


  2%|▏         | 5/300 [00:27<26:38,  5.42s/it]

R2: -0.01908674050545245  corr:  0.03217524886297779  pval:  0.0


  2%|▏         | 6/300 [00:32<27:15,  5.56s/it]

R2: -0.015319691030644567  corr:  0.056989737340779445  pval:  0.0


  3%|▎         | 10/300 [00:52<25:07,  5.20s/it]

R2: -0.00427693130976059  corr:  0.140927971543464  pval:  0.0


  4%|▎         | 11/300 [00:58<25:59,  5.40s/it]

R2: 0.05292530383968275  corr:  0.3268503264145011  pval:  0.0


  4%|▍         | 13/300 [01:09<25:45,  5.38s/it]

R2: 0.27046446478741526  corr:  0.531661943211757  pval:  0.0


  5%|▍         | 14/300 [01:14<26:18,  5.52s/it]

R2: 0.27571699600528243  corr:  0.55323188878646  pval:  0.0


  5%|▌         | 15/300 [01:20<26:40,  5.62s/it]

R2: 0.32069059499755725  corr:  0.5896303320068393  pval:  0.0


  6%|▌         | 17/300 [01:31<26:08,  5.54s/it]

R2: 0.3653013419049621  corr:  0.6367492917284542  pval:  0.0


  6%|▌         | 18/300 [01:37<26:25,  5.62s/it]

R2: 0.3961490583872884  corr:  0.652874720524883  pval:  0.0


  6%|▋         | 19/300 [01:43<26:35,  5.68s/it]

R2: 0.4175883053546401  corr:  0.6719862145023792  pval:  0.0


  7%|▋         | 20/300 [01:48<26:43,  5.73s/it]

R2: 0.4188742220047973  corr:  0.6707803993664927  pval:  0.0


  8%|▊         | 23/300 [02:03<24:48,  5.37s/it]

R2: 0.4424669734277177  corr:  0.6747509177782312  pval:  0.0


  8%|▊         | 25/300 [02:14<24:29,  5.34s/it]

R2: 0.47041687509822316  corr:  0.6937337372322558  pval:  0.0


  9%|▊         | 26/300 [02:20<25:03,  5.49s/it]

R2: 0.48116178726764747  corr:  0.7087722757593589  pval:  0.0


  9%|▉         | 27/300 [02:26<25:26,  5.59s/it]

R2: 0.49003969834544114  corr:  0.7121773355740906  pval:  0.0


  9%|▉         | 28/300 [02:31<25:45,  5.68s/it]

R2: 0.5049952180368423  corr:  0.7171819107188605  pval:  0.0


 10%|▉         | 29/300 [02:37<25:56,  5.74s/it]

R2: 0.5117473759932276  corr:  0.7250514044723405  pval:  0.0


 11%|█         | 33/300 [02:57<23:26,  5.27s/it]

R2: 0.5360264933086032  corr:  0.7354293986912189  pval:  0.0


 12%|█▏        | 36/300 [03:12<22:58,  5.22s/it]

R2: 0.541132295870149  corr:  0.7384816290483924  pval:  0.0


 12%|█▏        | 37/300 [03:18<23:41,  5.41s/it]

R2: 0.5559028166252771  corr:  0.7470083754503097  pval:  0.0


 13%|█▎        | 38/300 [03:24<24:12,  5.54s/it]

R2: 0.5569992314086443  corr:  0.7537228353025901  pval:  0.0


 13%|█▎        | 39/300 [03:30<24:29,  5.63s/it]

R2: 0.5748445378005418  corr:  0.7623797463786001  pval:  0.0


 13%|█▎        | 40/300 [03:36<24:40,  5.69s/it]

R2: 0.5761927028724052  corr:  0.763085243180912  pval:  0.0


 14%|█▎        | 41/300 [03:41<24:49,  5.75s/it]

R2: 0.5779675648044087  corr:  0.7618454191744404  pval:  0.0


 16%|█▌        | 48/300 [04:16<22:49,  5.44s/it]

R2: 0.5793404268238873  corr:  0.7639932122561269  pval:  0.0


 16%|█▋        | 49/300 [04:23<23:50,  5.70s/it]

R2: 0.5915208853847176  corr:  0.7727605387751078  pval:  0.0


 17%|█▋        | 50/300 [04:29<24:57,  5.99s/it]

R2: 0.595443522651375  corr:  0.7747195294472993  pval:  0.0


 19%|█▉        | 58/300 [05:11<22:09,  5.49s/it]

R2: 0.6138591920476673  corr:  0.7842210868754791  pval:  0.0


 20%|██        | 60/300 [05:22<22:34,  5.64s/it]

R2: 0.615928574672998  corr:  0.7855300914346366  pval:  0.0


 23%|██▎       | 70/300 [06:14<20:46,  5.42s/it]

R2: 0.6168614520604905  corr:  0.7864857587252547  pval:  0.0


 30%|███       | 90/300 [07:55<18:45,  5.36s/it]

R2: 0.62603009842409  corr:  0.791939831057504  pval:  0.0


 33%|███▎      | 99/300 [08:41<18:19,  5.47s/it]

R2: 0.6265239166741159  corr:  0.7928462568046536  pval:  0.0


 33%|███▎      | 100/300 [08:48<19:17,  5.79s/it]

R2: 0.6267753915869922  corr:  0.7929042844343935  pval:  0.0


100%|██████████| 300/300 [25:44<00:00,  5.15s/it]
  0%|          | 1/300 [00:06<32:54,  6.60s/it]

R2: -0.0015437997857112862  corr:  0.015433761993059566  pval:  0.0


  2%|▏         | 5/300 [00:28<27:46,  5.65s/it]

R2: 0.054140378496096764  corr:  0.2536907882763861  pval:  0.0


  3%|▎         | 8/300 [00:44<27:48,  5.72s/it]

R2: 0.12727907913645398  corr:  0.43781204405217267  pval:  0.0


  3%|▎         | 9/300 [00:50<28:39,  5.91s/it]

R2: 0.21153590093182717  corr:  0.5077614819601091  pval:  0.0


  3%|▎         | 10/300 [00:57<29:37,  6.13s/it]

R2: 0.31163042776858585  corr:  0.5851223712102469  pval:  0.0


  4%|▎         | 11/300 [01:03<29:59,  6.23s/it]

R2: 0.3273904423468029  corr:  0.5817716872424362  pval:  0.0


  4%|▍         | 12/300 [01:10<30:16,  6.31s/it]

R2: 0.3524429695517389  corr:  0.6038484821016638  pval:  0.0


  4%|▍         | 13/300 [01:17<30:30,  6.38s/it]

R2: 0.3603517356764937  corr:  0.6281264778676472  pval:  0.0


  5%|▍         | 14/300 [01:23<30:27,  6.39s/it]

R2: 0.40250697811873315  corr:  0.6517768598485707  pval:  0.0


  5%|▌         | 15/300 [01:29<30:14,  6.37s/it]

R2: 0.4523490544884746  corr:  0.6789603410147186  pval:  0.0


  5%|▌         | 16/300 [01:36<30:44,  6.49s/it]

R2: 0.47366974229197456  corr:  0.6953659595438956  pval:  0.0


  6%|▌         | 17/300 [01:42<30:24,  6.45s/it]

R2: 0.4775200617605093  corr:  0.6954347013893194  pval:  0.0


  6%|▌         | 18/300 [01:49<30:07,  6.41s/it]

R2: 0.4952531067330521  corr:  0.7045846973798738  pval:  0.0


  6%|▋         | 19/300 [01:55<29:53,  6.38s/it]

R2: 0.5022596001041982  corr:  0.7105436171822119  pval:  0.0


  7%|▋         | 20/300 [02:01<29:39,  6.36s/it]

R2: 0.507388352981013  corr:  0.7139739361924753  pval:  0.0


  8%|▊         | 25/300 [02:28<26:19,  5.74s/it]

R2: 0.5222369334352299  corr:  0.7317120568041994  pval:  0.0


  9%|▊         | 26/300 [02:35<27:31,  6.03s/it]

R2: 0.5285716016290332  corr:  0.7314063741732876  pval:  0.0


  9%|▉         | 27/300 [02:41<27:59,  6.15s/it]

R2: 0.5410306130997906  corr:  0.7386514131227219  pval:  0.0


  9%|▉         | 28/300 [02:48<27:52,  6.15s/it]

R2: 0.5490226504440673  corr:  0.7435315890884846  pval:  0.0


 10%|▉         | 29/300 [02:54<28:42,  6.36s/it]

R2: 0.5586101650748441  corr:  0.7486356740326895  pval:  0.0


 10%|█         | 30/300 [03:01<28:41,  6.38s/it]

R2: 0.561270540884182  corr:  0.7498078923676486  pval:  0.0


 10%|█         | 31/300 [03:07<28:56,  6.46s/it]

R2: 0.5644126721640679  corr:  0.7526367169545998  pval:  0.0


 12%|█▏        | 36/300 [03:34<25:06,  5.71s/it]

R2: 0.5651813046759513  corr:  0.7561497235762936  pval:  0.0


 12%|█▏        | 37/300 [03:41<26:13,  5.98s/it]

R2: 0.5661336419218104  corr:  0.7545022142591359  pval:  0.0


 13%|█▎        | 38/300 [03:47<26:35,  6.09s/it]

R2: 0.5774355832742308  corr:  0.7605583515803909  pval:  0.0


 13%|█▎        | 39/300 [03:54<27:05,  6.23s/it]

R2: 0.5819990330972011  corr:  0.7641333476727517  pval:  0.0


 13%|█▎        | 40/300 [04:00<26:56,  6.22s/it]

R2: 0.5875245726522915  corr:  0.767246907996118  pval:  0.0


 16%|█▌        | 47/300 [04:37<23:37,  5.60s/it]

R2: 0.5911812270045432  corr:  0.7720984806083034  pval:  0.0


 16%|█▌        | 48/300 [04:43<24:32,  5.84s/it]

R2: 0.5998107880430761  corr:  0.7757004473264952  pval:  0.0


 16%|█▋        | 49/300 [04:50<25:08,  6.01s/it]

R2: 0.6055087199052924  corr:  0.7788460637506526  pval:  0.0


 20%|█▉        | 59/300 [05:44<24:00,  5.98s/it]

R2: 0.6106140106942608  corr:  0.7828277181226039  pval:  0.0


 20%|██        | 60/300 [05:50<24:43,  6.18s/it]

R2: 0.6128043767242366  corr:  0.7838670751551141  pval:  0.0


 23%|██▎       | 68/300 [06:33<22:02,  5.70s/it]

R2: 0.6138071186636151  corr:  0.7839118480915368  pval:  0.0


 23%|██▎       | 69/300 [06:39<22:49,  5.93s/it]

R2: 0.614473299241495  corr:  0.7845107046803144  pval:  0.0


 23%|██▎       | 70/300 [06:46<23:11,  6.05s/it]

R2: 0.6154836827322648  corr:  0.7855298644730316  pval:  0.0


 25%|██▌       | 76/300 [07:18<21:14,  5.69s/it]

R2: 0.6159174074915408  corr:  0.7863848512358603  pval:  0.0


 26%|██▌       | 78/300 [07:30<21:30,  5.81s/it]

R2: 0.6159897814746194  corr:  0.7862562211316659  pval:  0.0


 26%|██▋       | 79/300 [07:36<22:16,  6.05s/it]

R2: 0.619671166855716  corr:  0.7886212617804316  pval:  0.0


 27%|██▋       | 80/300 [07:43<22:24,  6.11s/it]

R2: 0.6218342892488216  corr:  0.7897377329291776  pval:  0.0


 30%|███       | 90/300 [08:35<19:25,  5.55s/it]

R2: 0.6220358857044279  corr:  0.7896741456893929  pval:  0.0


 36%|███▋      | 109/300 [10:13<17:50,  5.60s/it]

R2: 0.6240336190605127  corr:  0.790543636162392  pval:  0.0


100%|██████████| 300/300 [26:29<00:00,  5.30s/it]
  0%|          | 1/300 [00:06<33:40,  6.76s/it]

R2: -0.03316714712932778  corr:  0.00026478458457687134  pval:  0.048956306830173504


  2%|▏         | 5/300 [00:28<28:36,  5.82s/it]

R2: 0.017158029898281635  corr:  0.19160264773416358  pval:  0.0


  2%|▏         | 6/300 [00:35<29:59,  6.12s/it]

R2: 0.06601494100476213  corr:  0.2971917808393958  pval:  0.0


  2%|▏         | 7/300 [00:41<30:24,  6.23s/it]

R2: 0.12445799935720381  corr:  0.40555012071678975  pval:  0.0


  3%|▎         | 8/300 [00:48<30:14,  6.21s/it]

R2: 0.17107344022538196  corr:  0.4516404005797627  pval:  0.0


  3%|▎         | 9/300 [00:54<30:01,  6.19s/it]

R2: 0.1892471734107437  corr:  0.47022978454170616  pval:  0.0


  3%|▎         | 10/300 [01:00<30:19,  6.27s/it]

R2: 0.21646888063491632  corr:  0.5026264095047551  pval:  0.0


  4%|▍         | 12/300 [01:12<29:13,  6.09s/it]

R2: 0.32990448355496527  corr:  0.6074879474430622  pval:  0.0


  4%|▍         | 13/300 [01:18<29:50,  6.24s/it]

R2: 0.3787113654202945  corr:  0.6382372661977299  pval:  0.0


  5%|▍         | 14/300 [01:25<30:13,  6.34s/it]

R2: 0.4197871304459877  corr:  0.6554681523453335  pval:  0.0


  5%|▌         | 16/300 [01:37<29:20,  6.20s/it]

R2: 0.4359400991065421  corr:  0.6811896916414312  pval:  0.0


  6%|▌         | 17/300 [01:43<29:21,  6.22s/it]

R2: 0.45790375774126324  corr:  0.6868785152748417  pval:  0.0


  6%|▋         | 19/300 [01:55<28:11,  6.02s/it]

R2: 0.46656206053970806  corr:  0.6983889936370677  pval:  0.0


  8%|▊         | 23/300 [02:16<25:48,  5.59s/it]

R2: 0.48177399218467276  corr:  0.7054724644650922  pval:  0.0


  8%|▊         | 24/300 [02:22<27:29,  5.98s/it]

R2: 0.4904143356143196  corr:  0.7050949928781053  pval:  0.0


  9%|▊         | 26/300 [02:34<27:18,  5.98s/it]

R2: 0.49181253378604217  corr:  0.7037176911363677  pval:  0.0


  9%|▉         | 27/300 [02:41<27:50,  6.12s/it]

R2: 0.5006115061267931  corr:  0.7162748854224544  pval:  0.0


  9%|▉         | 28/300 [02:47<28:01,  6.18s/it]

R2: 0.527728371171437  corr:  0.7319771271678166  pval:  0.0


 11%|█▏        | 34/300 [03:19<25:44,  5.80s/it]

R2: 0.5402732585481673  corr:  0.739459739878305  pval:  0.0


 12%|█▏        | 35/300 [03:26<26:51,  6.08s/it]

R2: 0.5582706557715269  corr:  0.7487899170465994  pval:  0.0


 12%|█▏        | 36/300 [03:33<27:19,  6.21s/it]

R2: 0.5610474367964701  corr:  0.7523646711488564  pval:  0.0


 12%|█▏        | 37/300 [03:39<27:16,  6.22s/it]

R2: 0.573240214995822  corr:  0.7577599762149645  pval:  0.0


 13%|█▎        | 39/300 [03:51<26:50,  6.17s/it]

R2: 0.5783247458437056  corr:  0.7609620400004239  pval:  0.0


 16%|█▌        | 47/300 [04:33<23:41,  5.62s/it]

R2: 0.5890626123168721  corr:  0.7698049941856276  pval:  0.0


 16%|█▋        | 49/300 [04:44<23:58,  5.73s/it]

R2: 0.5931589967888203  corr:  0.7711439006154254  pval:  0.0


 19%|█▊        | 56/300 [05:21<22:52,  5.62s/it]

R2: 0.597347695501647  corr:  0.7762098622266265  pval:  0.0


 19%|█▉        | 58/300 [05:33<23:34,  5.84s/it]

R2: 0.5989996989021396  corr:  0.7749487430597766  pval:  0.0


 20%|█▉        | 59/300 [05:40<24:44,  6.16s/it]

R2: 0.6033273801367713  corr:  0.777420797489799  pval:  0.0


 20%|██        | 60/300 [05:46<25:03,  6.27s/it]

R2: 0.6057508468559365  corr:  0.7790132601292957  pval:  0.0


 23%|██▎       | 69/300 [06:36<21:51,  5.68s/it]

R2: 0.6092370606827795  corr:  0.7811090783867461  pval:  0.0


 23%|██▎       | 70/300 [06:42<22:42,  5.92s/it]

R2: 0.6109674953013615  corr:  0.7823116038329149  pval:  0.0


 26%|██▋       | 79/300 [07:30<20:22,  5.53s/it]

R2: 0.6127260132603567  corr:  0.7838227633539826  pval:  0.0


 27%|██▋       | 80/300 [07:37<20:57,  5.72s/it]

R2: 0.6129137530455857  corr:  0.7836993103650363  pval:  0.0


 32%|███▏      | 97/300 [09:05<18:48,  5.56s/it]

R2: 0.6197695073264466  corr:  0.7886248604137899  pval:  0.0


100%|██████████| 300/300 [26:19<00:00,  5.26s/it]
  0%|          | 1/300 [00:06<31:26,  6.31s/it]

R2: -0.0006282975095324161  corr:  0.0026424926584389095  pval:  5.787260503359797e-86


  2%|▏         | 5/300 [00:28<28:07,  5.72s/it]

R2: 0.021075705417660795  corr:  0.1461637473170458  pval:  0.0


  2%|▏         | 6/300 [00:35<30:05,  6.14s/it]

R2: 0.07466245391276582  corr:  0.27802272612575946  pval:  0.0


  2%|▏         | 7/300 [00:41<30:24,  6.23s/it]

R2: 0.17160447635864262  corr:  0.429006754627126  pval:  0.0


  3%|▎         | 8/300 [00:47<30:24,  6.25s/it]

R2: 0.21863420902352315  corr:  0.47232374268478056  pval:  0.0


  3%|▎         | 9/300 [00:54<30:31,  6.29s/it]

R2: 0.2849484684445247  corr:  0.5398141003206105  pval:  0.0


  3%|▎         | 10/300 [01:00<30:32,  6.32s/it]

R2: 0.36212559175019754  corr:  0.6080289484959691  pval:  0.0


  4%|▍         | 12/300 [01:12<30:11,  6.29s/it]

R2: 0.40960718381653594  corr:  0.64992451866408  pval:  0.0


  5%|▍         | 14/300 [01:24<29:12,  6.13s/it]

R2: 0.41345812920153946  corr:  0.6533377527914935  pval:  0.0


  5%|▌         | 15/300 [01:30<29:37,  6.24s/it]

R2: 0.47279474987633263  corr:  0.6964908762392664  pval:  0.0


  6%|▌         | 17/300 [01:42<28:44,  6.09s/it]

R2: 0.4816172099893643  corr:  0.7058349680594962  pval:  0.0


  6%|▌         | 18/300 [01:48<29:04,  6.19s/it]

R2: 0.506794235400245  corr:  0.7120915902519299  pval:  0.0


  6%|▋         | 19/300 [01:55<29:57,  6.40s/it]

R2: 0.5212007952642161  corr:  0.7242695379992574  pval:  0.0


  7%|▋         | 20/300 [02:02<30:04,  6.44s/it]

R2: 0.5243397611077788  corr:  0.7256766869263607  pval:  0.0


  9%|▉         | 27/300 [02:38<25:28,  5.60s/it]

R2: 0.5325925469822734  corr:  0.739117068183782  pval:  0.0


  9%|▉         | 28/300 [02:45<26:23,  5.82s/it]

R2: 0.5507344804129664  corr:  0.7445206716900759  pval:  0.0


 10%|▉         | 29/300 [02:51<27:07,  6.01s/it]

R2: 0.5599863638579505  corr:  0.7495711508562309  pval:  0.0


 10%|█         | 30/300 [02:58<28:02,  6.23s/it]

R2: 0.5673353168931713  corr:  0.7540534951785383  pval:  0.0


 12%|█▏        | 35/300 [03:24<24:41,  5.59s/it]

R2: 0.5840225683011717  corr:  0.7666093352745545  pval:  0.0


 12%|█▏        | 37/300 [03:36<24:56,  5.69s/it]

R2: 0.584277380461091  corr:  0.767687782997532  pval:  0.0


 13%|█▎        | 38/300 [03:42<26:03,  5.97s/it]

R2: 0.5924123450201199  corr:  0.7725825913383597  pval:  0.0


 13%|█▎        | 39/300 [03:48<26:18,  6.05s/it]

R2: 0.602581759809633  corr:  0.7771622385519175  pval:  0.0


 13%|█▎        | 40/300 [03:55<26:40,  6.16s/it]

R2: 0.6030695631260418  corr:  0.7773840210637939  pval:  0.0


 16%|█▋        | 49/300 [04:42<23:12,  5.55s/it]

R2: 0.6073919757501071  corr:  0.7805450343125985  pval:  0.0


 17%|█▋        | 50/300 [04:48<24:05,  5.78s/it]

R2: 0.6131053219260855  corr:  0.7843568293637257  pval:  0.0


 20%|█▉        | 59/300 [05:35<22:15,  5.54s/it]

R2: 0.6188747064383308  corr:  0.7879177333303278  pval:  0.0


 20%|██        | 60/300 [05:42<23:03,  5.76s/it]

R2: 0.6202192895419973  corr:  0.7886597401168431  pval:  0.0


 23%|██▎       | 70/300 [06:33<21:00,  5.48s/it]

R2: 0.620934528881445  corr:  0.7899899546434245  pval:  0.0


 26%|██▋       | 79/300 [07:22<22:03,  5.99s/it]

R2: 0.6217368034993609  corr:  0.7897502767188778  pval:  0.0


 27%|██▋       | 80/300 [07:29<22:44,  6.20s/it]

R2: 0.6240789568205893  corr:  0.7915214034535415  pval:  0.0


 30%|██▉       | 89/300 [08:16<19:47,  5.63s/it]

R2: 0.6256538666179954  corr:  0.7941150901057075  pval:  0.0


 30%|███       | 90/300 [08:23<21:00,  6.00s/it]

R2: 0.6264261278058786  corr:  0.7940102752374032  pval:  0.0


100%|██████████| 300/300 [26:13<00:00,  5.25s/it]
  0%|          | 1/300 [00:06<33:31,  6.73s/it]

R2: -0.018229460217154658  corr:  0.00556433940467576  pval:  0.0


  2%|▏         | 7/300 [00:38<27:20,  5.60s/it]

R2: 0.05751096944536804  corr:  0.28740069693210557  pval:  0.0


  3%|▎         | 8/300 [00:45<28:43,  5.90s/it]

R2: 0.16486853667226187  corr:  0.44415704604923406  pval:  0.0


  3%|▎         | 9/300 [00:51<29:22,  6.06s/it]

R2: 0.25669342682842167  corr:  0.5427629243890321  pval:  0.0


  3%|▎         | 10/300 [00:58<30:24,  6.29s/it]

R2: 0.346208814863293  corr:  0.6114339343845758  pval:  0.0


  4%|▍         | 13/300 [01:15<28:18,  5.92s/it]

R2: 0.40642429284596904  corr:  0.6566840099944266  pval:  0.0


  5%|▍         | 14/300 [01:21<29:05,  6.10s/it]

R2: 0.4492485658308216  corr:  0.6825976702938619  pval:  0.0


  5%|▌         | 15/300 [01:28<29:23,  6.19s/it]

R2: 0.4880535570753063  corr:  0.7033267872163506  pval:  0.0


  5%|▌         | 16/300 [01:34<29:43,  6.28s/it]

R2: 0.4926857234223082  corr:  0.718553475698194  pval:  0.0


  6%|▌         | 17/300 [01:41<29:49,  6.32s/it]

R2: 0.5155658959978272  corr:  0.7195878654598475  pval:  0.0


  6%|▌         | 18/300 [01:47<30:09,  6.42s/it]

R2: 0.526059914148702  corr:  0.7269355870807951  pval:  0.0


  7%|▋         | 20/300 [01:59<28:44,  6.16s/it]

R2: 0.5322066327359942  corr:  0.7303958964208621  pval:  0.0


  9%|▊         | 26/300 [02:31<26:07,  5.72s/it]

R2: 0.54544250539532  corr:  0.7438646911006821  pval:  0.0


  9%|▉         | 28/300 [02:42<26:22,  5.82s/it]

R2: 0.5679101083919891  corr:  0.7539841526577954  pval:  0.0


 12%|█▏        | 36/300 [03:23<23:49,  5.41s/it]

R2: 0.5786710866750011  corr:  0.7625090236179486  pval:  0.0


 12%|█▏        | 37/300 [03:30<24:53,  5.68s/it]

R2: 0.5803364967253681  corr:  0.7618856577845442  pval:  0.0


 13%|█▎        | 38/300 [03:36<25:48,  5.91s/it]

R2: 0.5887553305014943  corr:  0.7684704952363218  pval:  0.0


 13%|█▎        | 39/300 [03:43<26:17,  6.04s/it]

R2: 0.5904348805644098  corr:  0.7686322321848731  pval:  0.0


 13%|█▎        | 40/300 [03:49<27:03,  6.24s/it]

R2: 0.5945785529179832  corr:  0.7714572059930984  pval:  0.0


 16%|█▌        | 48/300 [04:31<23:08,  5.51s/it]

R2: 0.5980559096486104  corr:  0.7744976167981404  pval:  0.0


 16%|█▋        | 49/300 [04:37<24:24,  5.83s/it]

R2: 0.6086055009084821  corr:  0.7812597836531738  pval:  0.0


 17%|█▋        | 50/300 [04:44<25:21,  6.08s/it]

R2: 0.6115583368827363  corr:  0.7827419337693097  pval:  0.0


 20%|██        | 60/300 [05:36<21:58,  5.49s/it]

R2: 0.6131989641098979  corr:  0.7840404116928135  pval:  0.0


 23%|██▎       | 69/300 [06:23<20:53,  5.43s/it]

R2: 0.6167169135907209  corr:  0.7856360282420907  pval:  0.0


 23%|██▎       | 70/300 [06:29<22:00,  5.74s/it]

R2: 0.6183759831428652  corr:  0.7869659727734233  pval:  0.0


 26%|██▋       | 79/300 [07:17<20:47,  5.65s/it]

R2: 0.6197396085321369  corr:  0.7877339260176767  pval:  0.0


 27%|██▋       | 80/300 [07:24<21:35,  5.89s/it]

R2: 0.6217943507650505  corr:  0.7890262137386943  pval:  0.0


 29%|██▉       | 87/300 [08:01<19:37,  5.53s/it]

R2: 0.6231486471687127  corr:  0.7914339182754748  pval:  0.0


 32%|███▏      | 96/300 [08:48<18:28,  5.44s/it]

R2: 0.6233934342378226  corr:  0.7914692115432793  pval:  0.0


100%|██████████| 300/300 [26:12<00:00,  5.24s/it]
  0%|          | 1/300 [00:06<32:17,  6.48s/it]

R2: -0.0007789802671600565  corr:  0.0010498166034659897  pval:  5.876411530896112e-15


  2%|▏         | 5/300 [00:28<28:39,  5.83s/it]

R2: 0.0359552689968764  corr:  0.18994818520896273  pval:  0.0


  2%|▏         | 6/300 [00:34<29:20,  5.99s/it]

R2: 0.06815262998189209  corr:  0.2620691165403036  pval:  0.0


  2%|▏         | 7/300 [00:41<29:43,  6.09s/it]

R2: 0.13981643780819974  corr:  0.38861120565173357  pval:  0.0


  3%|▎         | 8/300 [00:47<30:39,  6.30s/it]

R2: 0.26374901142557416  corr:  0.5320139691426169  pval:  0.0


  3%|▎         | 9/300 [00:54<30:58,  6.38s/it]

R2: 0.33524335158597685  corr:  0.5856947727082902  pval:  0.0


  3%|▎         | 10/300 [01:01<31:02,  6.42s/it]

R2: 0.35922007706863335  corr:  0.6001333985808912  pval:  0.0


  5%|▍         | 14/300 [01:23<28:30,  5.98s/it]

R2: 0.3953623004051613  corr:  0.6318157629526158  pval:  0.0


  5%|▌         | 15/300 [01:30<29:26,  6.20s/it]

R2: 0.42193172028310133  corr:  0.6537097352486497  pval:  0.0


  5%|▌         | 16/300 [01:36<30:03,  6.35s/it]

R2: 0.44738681380140777  corr:  0.6717939298477538  pval:  0.0


  6%|▌         | 17/300 [01:43<30:29,  6.46s/it]

R2: 0.4684338325295986  corr:  0.685787045166421  pval:  0.0


  6%|▋         | 19/300 [01:55<29:19,  6.26s/it]

R2: 0.49178752102397794  corr:  0.7035271345744654  pval:  0.0


  7%|▋         | 20/300 [02:01<29:22,  6.29s/it]

R2: 0.49342545105507485  corr:  0.7031643972646906  pval:  0.0


  8%|▊         | 25/300 [02:28<25:52,  5.65s/it]

R2: 0.5083684616875751  corr:  0.7170974794793952  pval:  0.0


  9%|▉         | 27/300 [02:40<26:33,  5.84s/it]

R2: 0.5321721910736049  corr:  0.7321810310286805  pval:  0.0


  9%|▉         | 28/300 [02:46<27:01,  5.96s/it]

R2: 0.5382960986392074  corr:  0.737167838595496  pval:  0.0


 10%|▉         | 29/300 [02:52<27:28,  6.08s/it]

R2: 0.551343681466733  corr:  0.7440704004904314  pval:  0.0


 10%|█         | 30/300 [02:59<28:11,  6.27s/it]

R2: 0.5533797625056854  corr:  0.7453211819446206  pval:  0.0


 12%|█▏        | 37/300 [03:36<24:45,  5.65s/it]

R2: 0.5818793731911912  corr:  0.763495575506687  pval:  0.0


 13%|█▎        | 40/300 [03:53<24:53,  5.74s/it]

R2: 0.5840096093250249  corr:  0.7654105811114822  pval:  0.0


 16%|█▌        | 48/300 [04:35<23:10,  5.52s/it]

R2: 0.5871627346597227  corr:  0.7677400089588543  pval:  0.0


 16%|█▋        | 49/300 [04:42<24:10,  5.78s/it]

R2: 0.5908168115250296  corr:  0.7700137040578662  pval:  0.0


 17%|█▋        | 50/300 [04:48<24:48,  5.95s/it]

R2: 0.5938365372908213  corr:  0.7725108542548687  pval:  0.0


 19%|█▉        | 58/300 [05:29<21:40,  5.37s/it]

R2: 0.5954004049967399  corr:  0.7727432918841005  pval:  0.0


 23%|██▎       | 68/300 [06:21<21:18,  5.51s/it]

R2: 0.61007410107071  corr:  0.7821630023731243  pval:  0.0


 30%|██▉       | 89/300 [08:10<19:30,  5.55s/it]

R2: 0.6133731820258185  corr:  0.7849559696900829  pval:  0.0


 30%|███       | 90/300 [08:16<20:25,  5.84s/it]

R2: 0.6139456073811544  corr:  0.7855636021435151  pval:  0.0


 69%|██████▊   | 206/300 [18:09<08:17,  5.29s/it]


KeyboardInterrupt: 

In [None]:
# Experiment: one input field (T_xy_ins) mapped to two output fields (ssh_cos, ssh_sin).
vi1 = 'T_xy_ins'
vo1, vo2 = 'ssh_cos', 'ssh_sin'

var_input_names = [vi1]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Prefix passed to run_model: "any_<inputs>_<outputs>_nobatchnorm", with the field names
# of each group concatenated without a separator.
save_fn_prefix = 'any_{}_{}_nobatchnorm'.format(''.join(var_input_names), ''.join(var_output_names))

# Chosen so that GPU memory is (nearly) maxed out.
# NOTE(review): the original remark said this "needs to be divisible by Ntrain"; presumably the
# intent is that batch_size must divide Ntrain evenly to avoid size mismatches -- confirm.
batch_size = 100
lr0 = 0.005*10/batch_size  # should scale roughly inversely with batch_size

# One slot per ensemble member for the test-data metrics recorded after each training cycle.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

# Open the datasets and load them into dense arrays.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# ---- Normalise every channel to zero mean and unit variance ----
# Statistics are reduced over axes (0, 2, 3) of the train and test arrays stacked along axis 0.
# NOTE(review): the test split contributes to these statistics -- confirm that this is intended.

# Inputs: per-channel mean, centring, then the mean square of the centred data (the variance).
mean_input = np.nanmean(np.concatenate([all_train_input, all_test_input]), axis=(0, 2, 3))
all_train_input = all_train_input - mean_input.reshape(1, -1, 1, 1)
all_test_input = all_test_input - mean_input.reshape(1, -1, 1, 1)
var_input = np.nanmean(np.concatenate([all_train_input, all_test_input])**2, axis=(0, 2, 3))

# Outputs: same three steps.
mean_output = np.nanmean(np.concatenate([all_train_output, all_test_output]), axis=(0, 2, 3))
all_train_output = all_train_output - mean_output.reshape(1, -1, 1, 1)
all_test_output = all_test_output - mean_output.reshape(1, -1, 1, 1)
var_output = np.nanmean(np.concatenate([all_train_output, all_test_output])**2, axis=(0, 2, 3))

print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Divide by the per-channel standard deviation so that each channel has variance 1.
# (The original author notes having checked that the result is zero mean / unit variance.)
all_train_input = all_train_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_test_input = all_test_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_train_output = all_train_output/np.sqrt(var_output).reshape(1, -1, 1, 1)
all_test_output = all_test_output/np.sqrt(var_output).reshape(1, -1, 1, 1)

# Train the ensemble: one run_model call per member, sharing the metric arrays.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)

print('R2 from the best models in each run are:', R2_all, sep='\n')
print('corr from the best models in each run are:', corr_all, sep='\n')

In [None]:
# Experiment: two input fields (u_xy_ins, v_xy_ins) mapped to two output fields (ssh_cos, ssh_sin).
vi1, vi2 = 'u_xy_ins', 'v_xy_ins'
vo1, vo2 = 'ssh_cos', 'ssh_sin'

var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Prefix passed to run_model: "any_<inputs>_<outputs>_nobatchnorm", with the field names
# of each group concatenated without a separator.
save_fn_prefix = 'any_{}_{}_nobatchnorm'.format(''.join(var_input_names), ''.join(var_output_names))

# Chosen so that GPU memory is (nearly) maxed out.
# NOTE(review): the original remark said this "needs to be divisible by Ntrain"; presumably the
# intent is that batch_size must divide Ntrain evenly to avoid size mismatches -- confirm.
batch_size = 80
lr0 = 0.005*10/batch_size  # should scale roughly inversely with batch_size

# One slot per ensemble member for the test-data metrics recorded after each training cycle.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

# Open the datasets and load them into dense arrays.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# ---- Normalise every channel to zero mean and unit variance ----
# Statistics are reduced over axes (0, 2, 3) of the train and test arrays stacked along axis 0.
# NOTE(review): the test split contributes to these statistics -- confirm that this is intended.

# Inputs: per-channel mean, centring, then the mean square of the centred data (the variance).
mean_input = np.nanmean(np.concatenate([all_train_input, all_test_input]), axis=(0, 2, 3))
all_train_input = all_train_input - mean_input.reshape(1, -1, 1, 1)
all_test_input = all_test_input - mean_input.reshape(1, -1, 1, 1)
var_input = np.nanmean(np.concatenate([all_train_input, all_test_input])**2, axis=(0, 2, 3))

# Outputs: same three steps.
mean_output = np.nanmean(np.concatenate([all_train_output, all_test_output]), axis=(0, 2, 3))
all_train_output = all_train_output - mean_output.reshape(1, -1, 1, 1)
all_test_output = all_test_output - mean_output.reshape(1, -1, 1, 1)
var_output = np.nanmean(np.concatenate([all_train_output, all_test_output])**2, axis=(0, 2, 3))

print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Divide by the per-channel standard deviation so that each channel has variance 1.
# (The original author notes having checked that the result is zero mean / unit variance.)
all_train_input = all_train_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_test_input = all_test_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_train_output = all_train_output/np.sqrt(var_output).reshape(1, -1, 1, 1)
all_test_output = all_test_output/np.sqrt(var_output).reshape(1, -1, 1, 1)

# Train the ensemble: one run_model call per member, sharing the metric arrays.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)

print('R2 from the best models in each run are:', R2_all, sep='\n')
print('corr from the best models in each run are:', corr_all, sep='\n')

In [None]:
# Experiment: two input fields (ssh_ins, T_xy_ins) mapped to two output fields (ssh_cos, ssh_sin).
vi1, vi2 = 'ssh_ins', 'T_xy_ins'
vo1, vo2 = 'ssh_cos', 'ssh_sin'

var_input_names = [vi1, vi2]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Prefix passed to run_model: "any_<inputs>_<outputs>_nobatchnorm", with the field names
# of each group concatenated without a separator.
save_fn_prefix = 'any_{}_{}_nobatchnorm'.format(''.join(var_input_names), ''.join(var_output_names))

# Chosen so that GPU memory is (nearly) maxed out.
# NOTE(review): the original remark said this "needs to be divisible by Ntrain"; presumably the
# intent is that batch_size must divide Ntrain evenly to avoid size mismatches -- confirm.
batch_size = 80
lr0 = 0.005*10/batch_size  # should scale roughly inversely with batch_size

# One slot per ensemble member for the test-data metrics recorded after each training cycle.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

# Open the datasets and load them into dense arrays.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# ---- Normalise every channel to zero mean and unit variance ----
# Statistics are reduced over axes (0, 2, 3) of the train and test arrays stacked along axis 0.
# NOTE(review): the test split contributes to these statistics -- confirm that this is intended.

# Inputs: per-channel mean, centring, then the mean square of the centred data (the variance).
mean_input = np.nanmean(np.concatenate([all_train_input, all_test_input]), axis=(0, 2, 3))
all_train_input = all_train_input - mean_input.reshape(1, -1, 1, 1)
all_test_input = all_test_input - mean_input.reshape(1, -1, 1, 1)
var_input = np.nanmean(np.concatenate([all_train_input, all_test_input])**2, axis=(0, 2, 3))

# Outputs: same three steps.
mean_output = np.nanmean(np.concatenate([all_train_output, all_test_output]), axis=(0, 2, 3))
all_train_output = all_train_output - mean_output.reshape(1, -1, 1, 1)
all_test_output = all_test_output - mean_output.reshape(1, -1, 1, 1)
var_output = np.nanmean(np.concatenate([all_train_output, all_test_output])**2, axis=(0, 2, 3))

print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Divide by the per-channel standard deviation so that each channel has variance 1.
# (The original author notes having checked that the result is zero mean / unit variance.)
all_train_input = all_train_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_test_input = all_test_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_train_output = all_train_output/np.sqrt(var_output).reshape(1, -1, 1, 1)
all_test_output = all_test_output/np.sqrt(var_output).reshape(1, -1, 1, 1)

# Train the ensemble: one run_model call per member, sharing the metric arrays.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)

print('R2 from the best models in each run are:', R2_all, sep='\n')
print('corr from the best models in each run are:', corr_all, sep='\n')

In [None]:
# Experiment: three input fields (ssh_ins, u_xy_ins, v_xy_ins) mapped to two output fields
# (ssh_cos, ssh_sin).
vi1, vi2, vi3 = 'ssh_ins', 'u_xy_ins', 'v_xy_ins'
vo1, vo2 = 'ssh_cos', 'ssh_sin'

var_input_names = [vi1, vi2, vi3]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Prefix later passed to run_model: "any_<inputs>_<outputs>_nobatchnorm", with the field names
# of each group concatenated without a separator.
save_fn_prefix = 'any_{}_{}_nobatchnorm'.format(''.join(var_input_names), ''.join(var_output_names))

# Chosen so that GPU memory is (nearly) maxed out.
# NOTE(review): the original remark said this "needs to be divisible by Ntrain"; presumably the
# intent is that batch_size must divide Ntrain evenly to avoid size mismatches -- confirm.
batch_size = 60
lr0 = 0.005*10/batch_size  # should scale roughly inversely with batch_size
# NOTE(review): an identical preload_data is defined in more than one cell of this notebook;
# consider defining it once, ahead of the experiment cells.
def preload_data(nctrains, total_records, max_height=722, max_width=258,
                 input_names=None, output_names=None):
    """Copy every record of the given datasets into two dense, zero-padded arrays.

    Parameters
    ----------
    nctrains : iterable
        Open datasets. Each must expose ``dimensions['time_counter'].size`` (its number
        of records) and ``variables[name]`` for every requested field name.
    total_records : int
        Length of the record axis of the returned arrays. Rows beyond the records
        actually read are left at zero.
    max_height, max_width : int, optional
        Size of the last two axes of the returned arrays (defaults: 722 and 258, the
        values previously hard-coded here). A field whose last two axes are smaller is
        written into the leading corner and the remainder stays zero.
    input_names, output_names : sequence of str, optional
        Field names for the input / output channels, in channel order. When omitted, the
        notebook-level ``var_input_names`` / ``var_output_names`` are used.

    Returns
    -------
    tuple of numpy.ndarray
        ``(all_input_data, all_output_data)``, float64 arrays of shape
        ``(total_records, n_channels, max_height, max_width)``.

    Raises
    ------
    ValueError
        If the datasets together hold more records than ``total_records``.
    """
    # Fall back to the notebook-level configuration when no names are passed explicitly.
    if input_names is None:
        input_names = var_input_names
    if output_names is None:
        output_names = var_output_names

    # Channel counts come from the name lists themselves so they cannot drift out of sync.
    all_input_data = np.zeros((total_records, len(input_names), max_height, max_width))
    all_output_data = np.zeros((total_records, len(output_names), max_height, max_width))

    current_index = 0
    for ncdata in nctrains:
        num_recs = ncdata.dimensions['time_counter'].size
        if current_index + num_recs > total_records:
            # Fail with a clear message instead of a broadcast error (or a silent drop
            # when a squeezed single-record field broadcasts into an empty slice).
            raise ValueError(
                'datasets hold at least {} records but only {} were allocated'.format(
                    current_index + num_recs, total_records))
        rec_slice = slice(current_index, current_index + num_recs)

        # Inputs and outputs are filled identically, each into its own array.
        for target, names in ((all_input_data, input_names), (all_output_data, output_names)):
            for ind, var_name in enumerate(names):
                data_slice = np.squeeze(ncdata.variables[var_name])
                # Some fields may be smaller than (max_height, max_width) in their last
                # two axes; place them in the matching leading sub-block.
                slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
                target[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs

    return all_input_data, all_output_data
# One slot per ensemble member for the test-data metrics recorded after each training cycle.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

# Open the datasets and load them into dense arrays.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# ---- Normalise every channel to zero mean and unit variance ----
# Statistics are reduced over axes (0, 2, 3) of the train and test arrays stacked along axis 0.
# NOTE(review): the test split contributes to these statistics -- confirm that this is intended.

# Inputs: per-channel mean, centring, then the mean square of the centred data (the variance).
mean_input = np.nanmean(np.concatenate([all_train_input, all_test_input]), axis=(0, 2, 3))
all_train_input = all_train_input - mean_input.reshape(1, -1, 1, 1)
all_test_input = all_test_input - mean_input.reshape(1, -1, 1, 1)
var_input = np.nanmean(np.concatenate([all_train_input, all_test_input])**2, axis=(0, 2, 3))

# Outputs: same three steps.
mean_output = np.nanmean(np.concatenate([all_train_output, all_test_output]), axis=(0, 2, 3))
all_train_output = all_train_output - mean_output.reshape(1, -1, 1, 1)
all_test_output = all_test_output - mean_output.reshape(1, -1, 1, 1)
var_output = np.nanmean(np.concatenate([all_train_output, all_test_output])**2, axis=(0, 2, 3))

print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Divide by the per-channel standard deviation so that each channel has variance 1.
# (The original author notes having checked that the result is zero mean / unit variance.)
all_train_input = all_train_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_test_input = all_test_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_train_output = all_train_output/np.sqrt(var_output).reshape(1, -1, 1, 1)
all_test_output = all_test_output/np.sqrt(var_output).reshape(1, -1, 1, 1)

# Train the ensemble: one run_model call per member, sharing the metric arrays.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)

print('R2 from the best models in each run are:', R2_all, sep='\n')
print('corr from the best models in each run are:', corr_all, sep='\n')

In [None]:
# Experiment: three input fields (T_xy_ins, u_xy_ins, v_xy_ins) mapped to two output fields
# (ssh_cos, ssh_sin).
vi1, vi2, vi3 = 'T_xy_ins', 'u_xy_ins', 'v_xy_ins'
vo1, vo2 = 'ssh_cos', 'ssh_sin'

var_input_names = [vi1, vi2, vi3]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Prefix later passed to run_model: "any_<inputs>_<outputs>_nobatchnorm", with the field names
# of each group concatenated without a separator.
save_fn_prefix = 'any_{}_{}_nobatchnorm'.format(''.join(var_input_names), ''.join(var_output_names))

# Chosen so that GPU memory is (nearly) maxed out.
# NOTE(review): the original remark said this "needs to be divisible by Ntrain"; presumably the
# intent is that batch_size must divide Ntrain evenly to avoid size mismatches -- confirm.
batch_size = 60
lr0 = 0.005*10/batch_size  # should scale roughly inversely with batch_size
# NOTE(review): an identical preload_data is defined in more than one cell of this notebook;
# consider defining it once, ahead of the experiment cells.
def preload_data(nctrains, total_records, max_height=722, max_width=258,
                 input_names=None, output_names=None):
    """Copy every record of the given datasets into two dense, zero-padded arrays.

    Parameters
    ----------
    nctrains : iterable
        Open datasets. Each must expose ``dimensions['time_counter'].size`` (its number
        of records) and ``variables[name]`` for every requested field name.
    total_records : int
        Length of the record axis of the returned arrays. Rows beyond the records
        actually read are left at zero.
    max_height, max_width : int, optional
        Size of the last two axes of the returned arrays (defaults: 722 and 258, the
        values previously hard-coded here). A field whose last two axes are smaller is
        written into the leading corner and the remainder stays zero.
    input_names, output_names : sequence of str, optional
        Field names for the input / output channels, in channel order. When omitted, the
        notebook-level ``var_input_names`` / ``var_output_names`` are used.

    Returns
    -------
    tuple of numpy.ndarray
        ``(all_input_data, all_output_data)``, float64 arrays of shape
        ``(total_records, n_channels, max_height, max_width)``.

    Raises
    ------
    ValueError
        If the datasets together hold more records than ``total_records``.
    """
    # Fall back to the notebook-level configuration when no names are passed explicitly.
    if input_names is None:
        input_names = var_input_names
    if output_names is None:
        output_names = var_output_names

    # Channel counts come from the name lists themselves so they cannot drift out of sync.
    all_input_data = np.zeros((total_records, len(input_names), max_height, max_width))
    all_output_data = np.zeros((total_records, len(output_names), max_height, max_width))

    current_index = 0
    for ncdata in nctrains:
        num_recs = ncdata.dimensions['time_counter'].size
        if current_index + num_recs > total_records:
            # Fail with a clear message instead of a broadcast error (or a silent drop
            # when a squeezed single-record field broadcasts into an empty slice).
            raise ValueError(
                'datasets hold at least {} records but only {} were allocated'.format(
                    current_index + num_recs, total_records))
        rec_slice = slice(current_index, current_index + num_recs)

        # Inputs and outputs are filled identically, each into its own array.
        for target, names in ((all_input_data, input_names), (all_output_data, output_names)):
            for ind, var_name in enumerate(names):
                data_slice = np.squeeze(ncdata.variables[var_name])
                # Some fields may be smaller than (max_height, max_width) in their last
                # two axes; place them in the matching leading sub-block.
                slice_height, slice_width = data_slice.shape[-2], data_slice.shape[-1]
                target[rec_slice, ind, :slice_height, :slice_width] = data_slice

        current_index += num_recs

    return all_input_data, all_output_data
# One slot per ensemble member for the test-data metrics recorded after each training cycle.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

# Open the datasets and load them into dense arrays.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# ---- Normalise every channel to zero mean and unit variance ----
# Statistics are reduced over axes (0, 2, 3) of the train and test arrays stacked along axis 0.
# NOTE(review): the test split contributes to these statistics -- confirm that this is intended.

# Inputs: per-channel mean, centring, then the mean square of the centred data (the variance).
mean_input = np.nanmean(np.concatenate([all_train_input, all_test_input]), axis=(0, 2, 3))
all_train_input = all_train_input - mean_input.reshape(1, -1, 1, 1)
all_test_input = all_test_input - mean_input.reshape(1, -1, 1, 1)
var_input = np.nanmean(np.concatenate([all_train_input, all_test_input])**2, axis=(0, 2, 3))

# Outputs: same three steps.
mean_output = np.nanmean(np.concatenate([all_train_output, all_test_output]), axis=(0, 2, 3))
all_train_output = all_train_output - mean_output.reshape(1, -1, 1, 1)
all_test_output = all_test_output - mean_output.reshape(1, -1, 1, 1)
var_output = np.nanmean(np.concatenate([all_train_output, all_test_output])**2, axis=(0, 2, 3))

print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Divide by the per-channel standard deviation so that each channel has variance 1.
# (The original author notes having checked that the result is zero mean / unit variance.)
all_train_input = all_train_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_test_input = all_test_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_train_output = all_train_output/np.sqrt(var_output).reshape(1, -1, 1, 1)
all_test_output = all_test_output/np.sqrt(var_output).reshape(1, -1, 1, 1)

# Train the ensemble: one run_model call per member, sharing the metric arrays.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)

print('R2 from the best models in each run are:', R2_all, sep='\n')
print('corr from the best models in each run are:', corr_all, sep='\n')

In [None]:
# Experiment: four input fields (ssh_ins, T_xy_ins, u_xy_ins, v_xy_ins) mapped to two output
# fields (ssh_cos, ssh_sin).
vi1, vi2, vi3, vi4 = 'ssh_ins', 'T_xy_ins', 'u_xy_ins', 'v_xy_ins'
vo1, vo2 = 'ssh_cos', 'ssh_sin'

var_input_names = [vi1, vi2, vi3, vi4]
var_output_names = [vo1, vo2]
N_inp, N_out = len(var_input_names), len(var_output_names)

# Prefix passed to run_model: "any_<inputs>_<outputs>_nobatchnorm", with the field names
# of each group concatenated without a separator.
save_fn_prefix = 'any_{}_{}_nobatchnorm'.format(''.join(var_input_names), ''.join(var_output_names))

# Chosen so that GPU memory is (nearly) maxed out.
# NOTE(review): the original remark said this "needs to be divisible by Ntrain"; presumably the
# intent is that batch_size must divide Ntrain evenly to avoid size mismatches -- confirm.
batch_size = 50
lr0 = 0.005*10/batch_size  # should scale roughly inversely with batch_size

# One slot per ensemble member for the test-data metrics recorded after each training cycle.
R2_all, corr_all = np.zeros(nensemble), np.zeros(nensemble)

# Open the datasets and load them into dense arrays.
nctrains, nctest = hf.load_data_from_nc_as_lists(root_dir)
all_train_input, all_train_output = preload_data(nctrains, Ntrain)
all_test_input, all_test_output = preload_data(nctest, Ntest)

# ---- Normalise every channel to zero mean and unit variance ----
# Statistics are reduced over axes (0, 2, 3) of the train and test arrays stacked along axis 0.
# NOTE(review): the test split contributes to these statistics -- confirm that this is intended.

# Inputs: per-channel mean, centring, then the mean square of the centred data (the variance).
mean_input = np.nanmean(np.concatenate([all_train_input, all_test_input]), axis=(0, 2, 3))
all_train_input = all_train_input - mean_input.reshape(1, -1, 1, 1)
all_test_input = all_test_input - mean_input.reshape(1, -1, 1, 1)
var_input = np.nanmean(np.concatenate([all_train_input, all_test_input])**2, axis=(0, 2, 3))

# Outputs: same three steps.
mean_output = np.nanmean(np.concatenate([all_train_output, all_test_output]), axis=(0, 2, 3))
all_train_output = all_train_output - mean_output.reshape(1, -1, 1, 1)
all_test_output = all_test_output - mean_output.reshape(1, -1, 1, 1)
var_output = np.nanmean(np.concatenate([all_train_output, all_test_output])**2, axis=(0, 2, 3))

print("mean and variance of all input data:")
print(mean_input,var_input)
print("mean and variance of all output data:")
print(mean_output,var_output)

# Divide by the per-channel standard deviation so that each channel has variance 1.
# (The original author notes having checked that the result is zero mean / unit variance.)
all_train_input = all_train_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_test_input = all_test_input/np.sqrt(var_input).reshape(1, -1, 1, 1)
all_train_output = all_train_output/np.sqrt(var_output).reshape(1, -1, 1, 1)
all_test_output = all_test_output/np.sqrt(var_output).reshape(1, -1, 1, 1)

# Train the ensemble: one run_model call per member, sharing the metric arrays.
for iensemble in np.arange(nensemble):
    run_model(var_input_names, var_output_names, save_fn_prefix, N_inp, N_out, iensemble, R2_all, corr_all)

print('R2 from the best models in each run are:', R2_all, sep='\n')
print('corr from the best models in each run are:', corr_all, sep='\n')