# Training main text models

Here we train the 3 models that will be used in the main text. 

Differences: 3 stencil sizes: 1X1, 3X3, 5X5.

Common features:
- Trained on DG and P2L simulataneously.
- Trained on all filter scales at once.
- vel grads and thickness grads as inputs
- Rotated in and out
- 0 - 2048 training.
- Model shape [48,48,2] - will make slightly different model sizes, but generally parameter number increases with stencil size, which is what we would like.
- Non dim inputs and outputs
- MAE loss function
- Adam loss function with LR 0.01
- Num epochs chosen by stopping at 0.001 stabalize in loss.

In [1]:
import warnings

# Ignore all warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import sys
sys.path.append('../../modules/')

%reload_ext autoreload
%autoreload 2
import datasets
import ML_classes
import evaluation

2025-02-22 21:29:58.507040: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-22 21:29:58.521684: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-22 21:29:58.526008: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Setup experiment

We use the non-dim training here. This is so that the model can easily focus on data from both experiments.

The tests done in folder comparing dim and non-dim showed that non-dim and dim models can have similar skill. 

In [2]:
# We need one place to save all the experiment relevant info.  
common_config= {'simulation_names':['DG','P2L'], 
                    'filter_scales':['50','100','200','400'],
                    #'filter_scales':['100'],
                    #'window_size':3, 
                    'all_ml_variables' : ['dudx_widened_rotated_nondim',  # must include all variables, including those used as coefficients
                                          'dvdx_widened_rotated_nondim', 
                                          'dudy_widened_rotated_nondim',
                                          'dvdy_widened_rotated_nondim',
                                          'dhdx_widened_rotated_nondim',
                                          'dhdy_widened_rotated_nondim',
                                          'uphp_rotated_nondim',  # this non-dim has taken a particular form (see in paper, flux/L^2/|grad u|)
                                          'vphp_rotated_nondim'], 
                    
                    'input_channels' :   ['dudx_widened_rotated_nondim',  
                                          'dvdx_widened_rotated_nondim', 
                                          'dudy_widened_rotated_nondim',
                                          'dvdy_widened_rotated_nondim',
                                          'dhdx_widened_rotated_nondim',
                                          'dhdy_widened_rotated_nondim'],
                    
                    'output_channels' :  ['uphp_rotated_nondim',
                                          'vphp_rotated_nondim'],
                    
                    'coeff_channels'  : [], 

                    'extra_channels'   : [  'uphp_rotated',
                                            'vphp_rotated', 
                                            'mag_nabla_h_widened',
                                            'mag_nabla_u_widened',
                                            'filter_scale'],

                    'use_coeff_channels': False,
                    'single_layer_mask': True,

                    'all_time_range': slice(0, 3600),
                    'train_time_range': slice(0, 2048),
                    'test_time_range' : slice(-128, None),
                    'eval_time_range' : slice(-256, -128),
                    'num_train_batches': 128, 
                    'num_test_batches' : 8, 

                    #'num_inputs': 55, 
                    'network_shape': [48, 48, 2],

                    'ckpt_save_dir': '/home/jovyan/mesoscale_buoyancy_param_ML/ML_checkpoints/main_models/window_3/shape_48_48_2/'
                    
            }


In [3]:
experiment_configs = {'1point':{'window_size':1, 'num_inputs':6, 'exp_ckpt_save_dir': common_config['ckpt_save_dir']+'1point'},
                      '3point':{'window_size':3, 'num_inputs':3*3*6, 'exp_ckpt_save_dir': common_config['ckpt_save_dir']+'3point'},
                      '5point':{'window_size':5, 'num_inputs':5*5*6, 'exp_ckpt_save_dir': common_config['ckpt_save_dir']+'5point'}    
                        }

In [None]:
#DT = datasets.SimulationData(simulation_names=['P2L', 'DG'], filter_scales=['50','100','200','400'])
for key in experiment_configs.keys():

    print('Starting to load in DT for: ' + key)
    DT = datasets.SimulationData(simulation_names=common_config['simulation_names'], 
                         filter_scales=common_config['filter_scales'], 
                         window_size = experiment_configs[key]['window_size'], 
                         time_sel = common_config['all_time_range'],
                         single_layer_mask_flag=common_config['single_layer_mask']
                         )

    print('Starting to load in ML-DT for: ' + key)
    ML_DT_train = datasets.MLXarrayDataset(simulation_data=DT, 
                                       all_ml_variables=common_config['all_ml_variables'],
                                       time_range=common_config['train_time_range'],
                                       num_batches = common_config['num_train_batches'],
                                       choose_experiment=common_config['simulation_names'])

    ML_DT_test = datasets.MLXarrayDataset(simulation_data=DT, 
                                       all_ml_variables=common_config['all_ml_variables'],
                                       time_range=common_config['test_time_range'],
                                       num_batches = common_config['num_test_batches'],
                                       choose_experiment=common_config['simulation_names'])

    train_ML_data = datasets.MLJAXDataset(ML_DT_train, 
                                      input_channels=common_config['input_channels'], 
                                      output_channels=common_config['output_channels'], 
                                      coeff_channels=common_config['coeff_channels'], 
                                      use_coeff_channels=common_config['use_coeff_channels'],
                                      do_normalize=True)

    test_ML_data = datasets.MLJAXDataset(ML_DT_test, 
                                      input_channels=common_config['input_channels'], 
                                      output_channels=common_config['output_channels'], 
                                      coeff_channels=common_config['coeff_channels'], 
                                      use_coeff_channels=common_config['use_coeff_channels'],
                                      do_normalize=True)

    ML_data_combo = {'train_data':train_ML_data, 'test_data':test_ML_data}

    ANN_model = ML_classes.PointwiseANN(num_in = experiment_configs[key]['num_inputs'],
                                        shape  = common_config['network_shape'],
                                        random_key=1) 

    print('Num parameters to train ' + str(ANN_model.count_parameters()) + ' for ' + key)
          
    regress_sys = ML_classes.AnnRegressionSystem(ANN_model, loss_type='mae')

    print('Start training: ', key)
               
    regress_sys.train_system(ML_data_combo, num_epoch=501, print_freq=20, min_relative_improvement=1e-3)

    experiment_configs[key]['regress_sys'] = regress_sys

    regress_sys.save_checkpoint(experiment_configs[key]['exp_ckpt_save_dir'])
    

Starting to load in DT for: 1point
Starting to load in ML-DT for: 1point
Will load : 1.70606592 gb into memory.
load took: 46.3920 seconds
Will load : 0.10662912 gb into memory.
load took: 4.5018 seconds
Num parameters to train 2786 for 1point
Start training:  1point
At epoch 1. Train loss :  0.4628345638047904 , Test loss: 0.45763134583830833 , Test R2: -1.119167536497116
At epoch 21. Train loss :  0.45518867764621973 , Test loss: 0.4552800618112087 , Test R2: -1.1178016662597656
Early stopping at epoch 25. No improvement in 10 epochs.
Restored best model with smoothed test loss 0.455523
Starting to load in DT for: 3point
Starting to load in ML-DT for: 3point
Will load : 8.00538624 gb into memory.
load took: 134.1020 seconds
Will load : 0.50033664 gb into memory.
load took: 10.8898 seconds
Num parameters to train 5090 for 3point
Start training:  3point
At epoch 1. Train loss :  0.24254556372761726 , Test loss: 0.21503141149878502 , Test R2: 0.0999240055680275
At epoch 21. Train loss :

In [None]:
#DT = datasets.SimulationData(simulation_names=['P2L', 'DG'], filter_scales=['50','100','200','400'])
for key in ['5point']:

    print('Starting to load in DT for: ' + key)
    DT = datasets.SimulationData(simulation_names=common_config['simulation_names'], 
                         filter_scales=common_config['filter_scales'], 
                         window_size = experiment_configs[key]['window_size'], 
                         time_sel = common_config['all_time_range'],
                         single_layer_mask_flag=common_config['single_layer_mask']
                         )

    print('Starting to load in ML-DT for: ' + key)
    ML_DT_train = datasets.MLXarrayDataset(simulation_data=DT, 
                                       all_ml_variables=common_config['all_ml_variables'],
                                       time_range=common_config['train_time_range'],
                                       num_batches = common_config['num_train_batches'],
                                       choose_experiment=common_config['simulation_names'])

    ML_DT_test = datasets.MLXarrayDataset(simulation_data=DT, 
                                       all_ml_variables=common_config['all_ml_variables'],
                                       time_range=common_config['test_time_range'],
                                       num_batches = common_config['num_test_batches'],
                                       choose_experiment=common_config['simulation_names'])

    train_ML_data = datasets.MLJAXDataset(ML_DT_train, 
                                      input_channels=common_config['input_channels'], 
                                      output_channels=common_config['output_channels'], 
                                      coeff_channels=common_config['coeff_channels'], 
                                      use_coeff_channels=common_config['use_coeff_channels'],
                                      do_normalize=True)

    test_ML_data = datasets.MLJAXDataset(ML_DT_test, 
                                      input_channels=common_config['input_channels'], 
                                      output_channels=common_config['output_channels'], 
                                      coeff_channels=common_config['coeff_channels'], 
                                      use_coeff_channels=common_config['use_coeff_channels'],
                                      do_normalize=True)

    ML_data_combo = {'train_data':train_ML_data, 'test_data':test_ML_data}

    ANN_model = ML_classes.PointwiseANN(num_in = experiment_configs[key]['num_inputs'],
                                        shape  = common_config['network_shape'],
                                        random_key=1) 

    print('Num parameters to train ' + str(ANN_model.count_parameters()) + ' for ' + key)
          
    regress_sys = ML_classes.AnnRegressionSystem(ANN_model, loss_type='mae')

    print('Start training: ', key)
               
    regress_sys.train_system(ML_data_combo, num_epoch=501, print_freq=20, min_relative_improvement=1e-3)

    experiment_configs[key]['regress_sys'] = regress_sys

    regress_sys.save_checkpoint(experiment_configs[key]['exp_ckpt_save_dir'])
    

Starting to load in DT for: 5point


In [17]:
## Loss plot
num_models = len(exp_sets.keys())

fig = plt.figure(figsize=(5, 4*num_models))  

for i, set_keys in enumerate(exp_sets.keys()):
    regress_sys = exp_sets[set_keys]['regress_sys'] 
    plt.subplot(num_models, 1, i+1)
    plt.plot(regress_sys.train_loss, label='Training loss, '+str(regress_sys.train_loss[-1]))
    plt.plot(regress_sys.test_loss, label='Test loss'+str(regress_sys.test_loss[-1]))

    plt.grid()
    plt.yscale('log')
    plt.title(exp_sets[set_keys]['sel_sim'])
    plt.legend()

KeyError: 'regress_sys'

<Figure size 500x1200 with 0 Axes>