## Generate the ANN parameters

The MOM6 ANN module reads the network parameters from a NetCDF file with a fixed layout.

Three parameters are set:
- USE_ANN - a logical (.true. or .false.) that switches the ANN on or off. It must be set in MOM_parameter/MOM_override.
- ANN_num_layers - the number of layers in the network, including the input and output layers.
- ANN_PARAMS_FILE - the name of the NetCDF file that contains the weights, biases and (once implemented) normalization constants.

##### The ANN_PARAMS_FILE

This NetCDF file contains the following variables:
- layer_sizes - an integer array of length ANN_num_layers. Each entry is the number of nodes in the corresponding layer.
- An, bn - the weight matrix and bias vector that map layer n to layer n+1, for n = 0, 1, ..., ANN_num_layers - 2. There is one pair per pair of consecutive layers, so ANN_num_layers - 1 pairs in total. In the examples below, An has dimensions (input, output) with shape (layer_sizes[n], layer_sizes[n+1]), and bn has length layer_sizes[n+1].
- normalizations (not implemented yet). 
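A mismatch between ANN_num_layers, layer_sizes and the shapes of An, bn is easy to introduce by hand, so it can help to check a parameter set before running MOM6. The sketch below is a hypothetical helper (check_ann_params is not part of MOM6); it assumes the (input, output) dimension order used in the examples in this notebook, which is worth confirming against what the Fortran reader expects.

```python
import numpy as np


def check_ann_params(params, num_layers):
    """Check that `params` follows the ANN_PARAMS_FILE layout (hypothetical helper).

    `params` may be an xarray.Dataset (e.g. from xr.open_dataset) or a dict of
    NumPy arrays. Assumes A<n> has shape (layer_sizes[n], layer_sizes[n+1])
    and b<n> has shape (layer_sizes[n+1],), as in the examples in this notebook.
    A missing A<n> or b<n> raises KeyError.
    """
    sizes = np.asarray(params["layer_sizes"])
    if sizes.shape != (num_layers,):
        raise ValueError(f"layer_sizes has shape {sizes.shape}, expected ({num_layers},)")

    # One (A, b) pair per transition between consecutive layers.
    for n in range(num_layers - 1):
        n_in, n_out = int(sizes[n]), int(sizes[n + 1])
        a_shape = tuple(np.shape(params[f"A{n}"]))
        b_shape = tuple(np.shape(params[f"b{n}"]))
        if a_shape != (n_in, n_out):
            raise ValueError(f"A{n} has shape {a_shape}, expected {(n_in, n_out)}")
        if b_shape != (n_out,):
            raise ValueError(f"b{n} has shape {b_shape}, expected {(n_out,)}")

    # An extra weight set beyond the last transition suggests num_layers is too small.
    if f"A{num_layers - 1}" in params:
        raise ValueError(f"found A{num_layers - 1}: more than num_layers - 1 weight sets")
    return True
```

For a file on disk, something like `check_ann_params(xr.open_dataset(path), 4)` would apply the same checks.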

In [None]:
import numpy as np
import xarray as xr

### Diffusive repro case

In [None]:
num_layers = 2 # input slopes and output KS

In [None]:
ds = xr.Dataset()

In [None]:
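# Two layers (input and output) with 2 nodes each
ds['layer_sizes'] = xr.DataArray(np.array([2,2]).astype('int32'), dims=['num_layers'])

# Diagonal weights: each output is a scaled copy of the corresponding input.
# The extra factor 20e3 presumably matches the 20 km grid spacing of this test case.
ds['A0'] = xr.DataArray(np.array([[1000,0], [0, 1000]]).astype('float32'), dims=['input', 'output'])
ds['A0'] = ds.A0 * 20e3
# Zero bias
ds['b0'] = xr.DataArray(np.array([0,0]).astype('float32'), dims=['output'])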

In [None]:
ds
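With only an input and an output layer, this network is a single dense map, y = x A0 + b0. Since A0 is diagonal with entries 1000 * 20e3 = 2e7 and b0 is zero, each output is simply 2e7 times the corresponding input slope. The standalone NumPy check below rebuilds the same arrays; the slope values are arbitrary, and it assumes no activation is applied to the output layer, which the text above does not state.

```python
import numpy as np

# Same parameters as the diffusive repro case above
A0 = np.array([[1000, 0], [0, 1000]], dtype="float32") * 20e3  # diag(2e7)
b0 = np.zeros(2, dtype="float32")

# Illustrative input slopes (arbitrary values, not from a model run)
x = np.array([1e-3, 2e-3], dtype="float32")

# Row-vector dense layer, y = x A0 + b0 (no output activation assumed)
y = x @ A0 + b0
print(y)  # each entry is 2e7 times the matching slope
```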

In [17]:
ds.to_netcdf('/scratch/db194/mom6/tests_Phillips_2layer/compare_GM_2_ANN/Phillips_2layer_20km_ANN/ann_params.nc', mode='w')

### Arbitrary test case 

In [27]:
ds = xr.Dataset()

In [39]:
ds['layer_sizes'] = xr.DataArray(np.array([2,4,4,2]).astype('int32'), dims=['num_layers'])

In [90]:
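np.random.seed(0)  # fixed seed so the random weights are reproducible
# A<n> maps layer n to layer n+1, with shape (layer_sizes[n], layer_sizes[n+1])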
ds['A0'] = xr.DataArray(np.random.randn(ds.layer_sizes[0].data, ds.layer_sizes[1].data), dims=['input', 'layer1'] )
ds['b0'] = xr.DataArray(np.random.randn(ds.layer_sizes[1].data), dims=['layer1'] )

ds['A1'] = xr.DataArray(np.random.randn(ds.layer_sizes[1].data, ds.layer_sizes[2].data), dims=['layer1', 'layer2'] )
ds['b1'] = xr.DataArray(np.random.randn(ds.layer_sizes[2].data), dims=['layer2'] )

ds['A2'] = xr.DataArray(np.random.randn(ds.layer_sizes[2].data, ds.layer_sizes[3].data), dims=['layer2', 'output'] )
ds['b2'] = xr.DataArray(np.random.randn(ds.layer_sizes[3].data), dims=['output'] )
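The cell above spells out every layer by hand. A loop over layer_sizes generalizes it to any depth. make_random_ann_vars is a hypothetical helper: it returns the {name: (dims, array)} mapping that xr.Dataset accepts, names dimensions 'input', 'layer1', ..., 'output' as above, and draws from np.random.RandomState(seed) in the same order as the cell (A0, b0, A1, b1, ...), so with seed=0 it should reproduce the same values.

```python
import numpy as np


def make_random_ann_vars(layer_sizes, seed=0):
    """Random weights/biases in the ANN_PARAMS_FILE layout (hypothetical helper).

    Returns {name: (dims, array)}, which xr.Dataset(...) accepts directly.
    Draws A0, b0, A1, b1, ... in that order from RandomState(seed), the same
    order as the hand-written cell, so seed=0 gives the same values.
    """
    rs = np.random.RandomState(seed)
    sizes = [int(s) for s in layer_sizes]
    n_layers = len(sizes)
    # Dimension names: 'input', 'layer1', ..., 'layer{n_layers-2}', 'output'
    dims = ["input"] + [f"layer{k}" for k in range(1, n_layers - 1)] + ["output"]

    data_vars = {"layer_sizes": (("num_layers",), np.array(sizes, dtype="int32"))}
    for n in range(n_layers - 1):
        # Weights map layer n to layer n+1; bias lives on layer n+1
        data_vars[f"A{n}"] = ((dims[n], dims[n + 1]), rs.randn(sizes[n], sizes[n + 1]))
        data_vars[f"b{n}"] = ((dims[n + 1],), rs.randn(sizes[n + 1]))
    return data_vars
```

Usage: `ds = xr.Dataset(make_random_ann_vars([2, 4, 4, 2], seed=0))`.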

In [91]:
ds.A0

In [92]:
ds.b2

In [94]:
ds.to_netcdf('/scratch/db194/mom6/tests_Phillips_2layer/Phillips_2layer_20km_ANN/ann_params.nc', mode='w')

In [95]:
xr.open_dataset('/scratch/db194/mom6/tests_Phillips_2layer/Phillips_2layer_20km_ANN/ann_params.nc')

In [96]:
# As in Flax (https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.Dense.html), a dense layer has two parts:
# - evaluating the layer (the dense() function below)
# - setting up its parameters (weights and biases); here they come from the dataset above
def dense(A, b, x):
    ''' Dense layer: y = xA + b
        Note: x is a row vector, so A has shape (n_in, n_out)
    '''

    y = np.matmul(x, A) + b

    return y

def relu(x):
    ''' ReLU (rectified linear unit) activation: elementwise max(x, 0)
    '''
    y = np.maximum(x,0.)
    
    return y
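The nested calls in the cells below can be folded into one loop over the layers. ann_forward is a hypothetical helper, not the MOM6 implementation: it applies a dense layer per (An, bn) pair, with ReLU after hidden layers by default. Whether MOM6 applies an activation to the output layer is not stated here, so both choices are exposed as flags.

```python
import numpy as np


def ann_forward(params, x, relu_hidden=True, relu_output=False):
    """Forward pass through the network stored in `params` (hypothetical helper).

    `params` is an xarray.Dataset or dict holding layer_sizes, A0, b0, A1, b1, ...;
    x is a row vector of length layer_sizes[0].
    """
    n_layers = len(np.asarray(params["layer_sizes"]))
    y = np.asarray(x)
    for n in range(n_layers - 1):
        A = np.asarray(params[f"A{n}"])
        b = np.asarray(params[f"b{n}"])
        y = np.matmul(y, A) + b  # dense layer, row-vector form y = yA + b
        is_output = n == n_layers - 2
        if (relu_output if is_output else relu_hidden):
            y = np.maximum(y, 0.0)  # ReLU
    return y
```

With the dataset above, `ann_forward(ds, x, relu_hidden=False)` corresponds to the purely linear stack evaluated below, and `ann_forward(ds, x, relu_output=True)` to the version with ReLU after every layer.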

In [97]:
x = np.array([6, 7]).astype('float32')

In [98]:
dense(ds.A0.data, ds.b0.data, x)

array([23.55400116, -4.02940341, 12.6670904 , 13.84013224])

In [99]:
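# Same computation as dense() with explicit loops: y[j] = sum_i x[i] * A0[i, j] + b0[j]
y = np.zeros(4)

for j in range(4):          # output nodes
    for i in range(2):      # input nodes
        y[j] = y[j] + x[i] * ds.A0.data[i,j]

    y[j] = y[j] + ds.b0.data[j]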

In [100]:
y

array([23.55400116, -4.02940341, 12.6670904 , 13.84013224])
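The loop indexes A0[i, j] with i running over inputs and j over outputs, i.e. the row-vector convention y = xA + b. Code written in the column-vector convention y = Wx + b needs W = A^T instead. This is worth keeping in mind on the Fortran side: the netCDF Fortran interface lists dimensions in the reverse order of the C/Python interfaces, so a variable written here with dims (input, output) appears there as (output, input). How the MOM6 module handles this is not shown in this notebook. A small NumPy check (random values, not the file's) that the forms agree:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((2, 4))  # (input, output), as stored in the file
b = rng.standard_normal(4)
x = np.array([6.0, 7.0])

y_row = x @ A + b   # row-vector form used by dense()
W = A.T             # (output, input): column-vector weight matrix
y_col = W @ x + b   # column-vector form y = W x + b

# Explicit loops in the column-vector form: y[j] = sum_i W[j, i] * x[i] + b[j]
y_loop = np.zeros(W.shape[0])
for j in range(W.shape[0]):
    for i in range(W.shape[1]):
        y_loop[j] += W[j, i] * x[i]
    y_loop[j] += b[j]

print(np.allclose(y_row, y_col), np.allclose(y_row, y_loop))  # True True
```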

In [57]:
# Linear stack: three dense layers with no activation function
dense(ds.A2.data, ds.b2.data, dense(ds.A1.data, ds.b1.data, dense(ds.A0.data, ds.b0.data, x)))

array([38.11837683, 49.53725901])
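Without activations, the three dense layers collapse to a single affine map: expanding y = ((x A0 + b0) A1 + b1) A2 + b2 gives y = x (A0 A1 A2) + (b0 A1 A2 + b1 A2 + b2). The linear stack above therefore only tests how layers are chained; the ReLU version that follows is what exercises the nonlinearity. A NumPy check with random weights of the same shapes (not the values in the file):

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [2, 4, 4, 2]
A = [rng.standard_normal((sizes[n], sizes[n + 1])) for n in range(3)]
b = [rng.standard_normal(sizes[n + 1]) for n in range(3)]
x = np.array([6.0, 7.0])

# Layer-by-layer evaluation with no activation
y = x
for An, bn in zip(A, b):
    y = y @ An + bn

# Equivalent single affine layer
A_eff = A[0] @ A[1] @ A[2]
b_eff = b[0] @ A[1] @ A[2] + b[1] @ A[2] + b[2]
y_eff = x @ A_eff + b_eff

print(np.allclose(y, y_eff))  # True
```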

In [101]:
# ReLU after every dense layer, including the output layer
relu(dense(ds.A2.data, ds.b2.data, relu(dense(ds.A1.data, ds.b1.data, relu(dense(ds.A0.data, ds.b0.data, x))))))

array([9.8018168, 0.       ])