# Inference of subgrid forcing via neural networks

First, run the following cell. When prompted, choose a model from the list by entering its integer id (the leftmost column). Note that this cell usually takes a long time to run; this seems to be caused by the kernel being slow to start rather than by the code itself.

In [None]:
import matplotlib.pyplot as plt

import numpy as np
import os, sys
sys.path.insert(1, os.path.join(os.getcwd()  , '../../src/gz21_ocean_momentum'))
from analysis.utils import view_predictions, DisplayMode, plot_dataset
from utils import select_run, select_experiment
from analysis.utils import play_movie
import mlflow
from mlflow.tracking import MlflowClient
import xarray as xr
time = 50  # NOTE(review): not read before a later cell rebinds `time` to a slice.

# Use the local MLflow run store located two directories above the notebook.
mlflow.set_tracking_uri(os.path.join(os.getcwd(), '../../mlruns'))

# Interactively pick an experiment, then a run inside it. The sort key and the
# extra columns are forwarded to select_run (presumably to order the listing
# and choose the displayed columns).
exp_id, _ = select_experiment()
cols = ['params.model_cls_name', 'params.loss_cls_name']
run = select_run(
    sort_by='metrics.test loss',
    cols=cols,
    experiment_ids=[exp_id],
)

# Show the record of the selected run.
print(run)

In [None]:
# Print the selected run record again (same output as at the end of the first cell).
print(run)

In [None]:
# Fetch the model's saved outputs on test dataset number `data_id` for the
# selected run, then open the downloaded store with xr.open_zarr.
client = MlflowClient()
run_id = run['run_id']
data_id = 0
artifact_path = f'model_output/test_output{data_id}'
output_file = client.download_artifacts(run_id, artifact_path)
model_output = xr.open_zarr(output_file)

In [None]:
# Display the model output dataset (variables, dimensions, coordinates).
model_output

In [None]:
# `load_training_datasets` was called here without ever being imported, so the
# cell raised a NameError.
# NOTE(review): assumes the helper lives in src/gz21_ocean_momentum/data/utils.py
# (that directory is put on sys.path by the first cell) — confirm.
from data.utils import load_training_datasets

# The training run records, in its 'run_id' parameter, the run whose 'forcing'
# artifact it was trained on (presumably the data-generation run); only the part
# before the first '/' is the run id.
data_run_id = run['params.run_id'].split('/')[0]
raw_data = xr.open_zarr(client.download_artifacts(data_run_id, 'forcing'))
# NOTE(review): 'training_subdomains.yaml' is a relative path, so it is resolved
# from the notebook's working directory if the helper opens it directly.
raw_datasets = load_training_datasets(raw_data, 'training_subdomains.yaml')

In [None]:
# Display the datasets returned by load_training_datasets.
raw_datasets

In [None]:
# Take the first raw dataset and index it by integer position in time.
# The new 'time_index' variable is added to raw_datasets[0] itself; the
# dimension swap then produces a new dataset bound to `data`.
data = raw_datasets[0]
data['time_index'] = ('time', np.arange(data.sizes['time']))
data = data.swap_dims({'time': 'time_index'})

In [None]:
from random import randint

# Plot the surface velocities of `data` at one randomly chosen
# (xu_ocean, yu_ocean) grid point.
# randint includes its upper bound, so the largest valid index is len - 1.
# Passing len(...) directly could select one past the end and make isel fail.
n_xu = len(data['xu_ocean'])
n_yu = len(data['yu_ocean'])
plot_dataset(data[['usurf', 'vsurf']].isel(xu_ocean=randint(0, n_xu - 1),
                                           yu_ocean=randint(0, n_yu - 1)))

In [None]:
# Time-averaged surface velocity field of the first raw dataset.
time_mean_velocity = data[['usurf', 'vsurf']].mean(dim='time_index')
plot_dataset(time_mean_velocity)
_ = plt.suptitle('Average mean flow')

In [None]:
# Learning curves of the selected run, as logged to MLflow.
# NOTE(review): the metrics are logged as 'train loss' / 'test loss' but are
# labelled MSE below — confirm this matches the run's loss class.

# Constant added to the loss before taking the log in the right-hand panel.
# Presumably it keeps the log defined when the loss can be negative — TODO confirm.
LOSS_LOG_OFFSET = 3.5


def _metric_values(key):
    """Return the values logged for metric ``key`` of ``run_id``, ordered by step."""
    history = client.get_metric_history(run_id, key)
    # get_metric_history does not document an ordering, so sort explicitly
    # (by step, then by timestamp for ties).
    history = sorted(history, key=lambda metric: (metric.step, metric.timestamp))
    return np.array([metric.value for metric in history])


train_mse = _metric_values('train loss')
test_mse = _metric_values('test loss')

plt.figure(figsize=(18, 7))
plt.subplot(1, 2, 1)
plt.plot(train_mse)
plt.plot(test_mse)
plt.xlabel('Epoch number')
plt.ylabel('MSE')
plt.legend(('Train MSE', 'Test MSE'))
plt.subplot(1, 2, 2)
plt.plot(10 * np.log10(train_mse + LOSS_LOG_OFFSET))
plt.plot(10 * np.log10(test_mse + LOSS_LOG_OFFSET))
plt.xlabel('Epoch number')
# The label now states the offset that is actually added before the log.
plt.ylabel(rf'$10 \times \log_{{10}} \ (MSE + {LOSS_LOG_OFFSET})$')
plt.legend(('Train MSE', 'Test MSE'))

The train and test losses initially decrease steeply with the number of epochs. (Each epoch has around 600 samples, and the number of parameters is not that high, since the model only uses convolutional layers plus a final locally connected layer.)

The model outputs on the validation data for this run are available as an xarray dataset. Run the following cell to add an integer time index and display the `model_output` dataset. The variables `u_surf` and `v_surf` are the surface velocity components used as model inputs.

In [None]:
# Add an integer time index and make it the time dimension, so later cells can
# pick snapshots by position with isel(time_index=...).
# The guard makes the cell safe to re-run. After the swap, 'time' is no longer a
# dimension, so rebuilding the index along 'time' would fail.
if 'time_index' not in model_output.dims:
    model_output['time_index'] = xr.DataArray(np.arange(len(model_output.coords['time'])),
                                              dims=('time',),
                                              coords={'time': model_output['time']})
    model_output = model_output.swap_dims({'time': 'time_index'})
model_output

We now create a dataset of the prediction errors.

In [None]:
# Scratch: a tiny three-point Dataset used to experiment with DataArray
# construction in the next cell.
test_ds = xr.Dataset(
    {'a': xr.DataArray([1, 2, 3], dims=('x',))},
    coords={'x': [5, 12, 13]},
)

In [None]:
# Scratch: a DataArray that reuses test_ds's dimension names but has different
# x coordinate values.
xr.DataArray(data=[1, 2, 3], dims=test_ds.dims, coords={'x': [1, 2, 5]})

In [None]:
# Invert the stored scales; later cells divide by them and use them as a
# standard deviation (±1.96 * scale bands).
# NOTE(review): this implies the raw outputs are inverse scales — confirm.
# The attrs flag applies the inversion only once, so re-running the cell no
# longer flips the scales back and forth.
if not model_output.attrs.get('scales_inverted', False):
    model_output['S_xscale'] = 1 / (model_output['S_xscale'])
    model_output['S_yscale'] = 1 / (model_output['S_yscale'])
    model_output.attrs['scales_inverted'] = True

# Raw and scale-normalised prediction errors.
errors_x = model_output['S_xpred'] - model_output['S_x']
errors_y = model_output['S_ypred'] - model_output['S_y']
errors_x_n = errors_x / model_output['S_xscale']
errors_y_n = errors_y / model_output['S_yscale']

# Time-averaged squared error at each grid point.
mse_x = (errors_x**2).mean(dim='time_index')
mse_y = (errors_y**2).mean(dim='time_index')

# Spatially averaged squared error of the two-component forcing at each time
# step. The squared error norm is e_x**2 + e_y**2; the previous
# (e_x + e_y)**2 added a spurious 2 * e_x * e_y cross term. Averaging over both
# spatial dimensions in one call weights every valid point equally when the
# fields contain NaNs.
mse_time = (errors_x**2 + errors_y**2).mean(dim=('latitude', 'longitude'))
mse_time_n = (errors_x_n**2 + errors_y_n**2).mean(dim=('latitude', 'longitude'))

errors_ds = xr.Dataset({'S_x (error)': errors_x, 'S_y (error)': errors_y,
                        'S_x (mse)': mse_x, 'S_y (mse)': mse_y,
                        'S_x (normalised error)': errors_x_n,
                        'S_y (normalised error)': errors_y_n,
                        'mse (time)': mse_time,
                        'mse (time, normalized)': mse_time_n})


In [None]:
# Histograms of the inputs, true and predicted forcing, and scales over the
# whole test output, with bin edges np.arange(-5, 5, 0.2).
hist_vars = ['u_surf', 'v_surf', 'S_x', 'S_y', 'S_xpred', 'S_ypred',
             'S_xscale', 'S_yscale']
plot_dataset(model_output[hist_vars], plot_type='hist',
             bins=np.arange(-5, 5, 0.2))

In [None]:
# Largest absolute value of every variable in model_output, evaluated eagerly
# with .compute() (useful for spotting outliers).
(abs(model_output)).max().compute()

We plot a snapshot corresponding to a random day from our test data. The first row corresponds to the two components of the surface velocity field, the second row to the two components of the "true" subgrid forcing, and the third row to the two components of the predicted subgrid forcing. The fourth row shows the two predicted scales (`S_xscale`, `S_yscale`).

In [None]:
from random import randint

# Snapshot of one randomly chosen time step of the test outputs: velocities,
# true forcing, predicted forcing and scales.
n_times = len(model_output['time'])
# Draw from the original window [340, 500], clipped to the last valid index so
# a shorter test set cannot produce an out-of-range time_index.
# NOTE(review): the 340-500 window looks deliberate — confirm its purpose.
last_index = n_times - 1
random_time = randint(min(340, last_index), min(500, last_index))
# To inspect a specific snapshot instead, set e.g. random_time = 301.
plot_dataset(model_output.isel(time_index=random_time)[['u_surf', 'v_surf', 'S_x', 'S_y', 'S_xpred', 'S_ypred',
                                                         'S_xscale', 'S_yscale']],
             # Per-variable limits: [-2, 2] for the six velocity/forcing
             # panels and [0, 1] for the two scale panels.
             vmin=[-2] * 6 + [-0.0, 0.0], vmax=[2] * 6 + [1, 1])
print(random_time)


## Distribution of true vs. predicted forcing

In [None]:
# Residual of the x-forcing prediction, divided by its scale.
residual_x = model_output['S_xpred'] - model_output['S_x']
model_output['rez'] = residual_x / model_output['S_xscale']

In [None]:
from dask.diagnostics import ProgressBar


def func(x):
    """Map x to x**3 / 100; equally spaced x gives bin edges that are finer near 0."""
    return x ** 3 / 100


def func2(x):
    """Identity map (an alternative bin-edge map; not used below)."""
    return x


# Group model_output by the predicted x-forcing, using cubic bin edges from -10 to 10.
bin_edges = func(np.arange(-10, 10.1, 0.25))
with ProgressBar():
    groups = model_output.groupby_bins('S_xpred', bin_edges)
groups

In [None]:
# Per-bin mean, standard deviation and maximum of every variable, where the
# bins are those of the predicted S_x. Each statistic reduces over all points
# in the bin, skipping NaNs.
# GroupBy.map replaces GroupBy.apply, which xarray deprecated in favour of map
# with the same semantics.
with ProgressBar():
    m = groups.map(lambda x: x.mean(skipna=True)).compute()
    s = groups.map(lambda x: x.std(skipna=True)).compute()
    sup = groups.map(lambda x: x.max(skipna=True)).compute()

In [None]:
%matplotlib notebook
plt.figure()
plt.plot(m['S_xpred'], m['S_x'])
plt.plot(m['S_xpred'], m['S_xscale'])
plt.plot(m['S_xpred'], s['S_x'])
plt.plot(m['S_xpred'], sup['S_xscale'])
plt.plot(m['S_xpred'], m['rez'])
plt.plot(np.arange(-15, 15), np.arange(-15, 15))
plt.legend(('m S_x', 'm S_xscale', 's S_x', 'sup S_xscale', 'm rez'))

In [None]:
from random import randint
# Overlay histograms of true and predicted S_x for the first bin whose lower
# edge is not below -4.
plt.figure()
with ProgressBar():
    chosen = next(((interval, group) for interval, group in groups
                   if not (interval.left < -4)), None)
    if chosen is not None:
        interval, group = chosen
        hist_bins = np.arange(-20, 20, 0.25)
        group['S_x'].plot.hist(bins=hist_bins)
        group['S_xpred'].plot.hist(bins=hist_bins, alpha=0.5)
        plt.legend(('truth', 'pred'))
        plt.title(str(interval.left) + ' -> ' + str(interval.right))

In [None]:
# `dataset_to_movie` was called here without being imported (the first cell
# imports only `play_movie` from analysis.utils), which raised a NameError.
# NOTE(review): assumes analysis.utils provides dataset_to_movie — confirm.
from analysis.utils import dataset_to_movie

# Animate the first 200 time steps of the inputs and of the true and predicted
# forcing components. `ani` is converted to an HTML5 video in the next cell.
ani = dataset_to_movie(model_output.isel(time_index=slice(0, 200))[['u_surf', 'v_surf', 'S_x', 'S_xpred', 'S_y', 'S_ypred']],
                       interval=200)

In [None]:
# Render the animation to an HTML5 <video> string (matplotlib Animation API;
# presumably needs a video writer such as ffmpeg installed — TODO confirm).
video = ani.to_html5_video()

In [None]:
# Embed the rendered video in the notebook output.
from IPython.display import HTML
HTML(video)

We do a quick analysis by showing the time-averaged MSE at every spatial point of our domain. We also plot the temporal mean of the velocity components as well as their standard deviation.

In [None]:
# Time-averaged squared error of S_x and S_y at each grid point (vmin=0, vmax=1).
plot_dataset(errors_ds[['S_x (mse)', 'S_y (mse)']], vmin=0, vmax=1)

In [None]:
# Temporal mean of the surface velocity components at each grid point.
plot_dataset(model_output[['u_surf', 'v_surf']].mean(dim='time_index'))

In [None]:
# Temporal standard deviation of the surface velocity components at each grid point.
plot_dataset(model_output[['u_surf', 'v_surf']].std(dim='time_index'))

It does not seem far-fetched to associate the larger errors in the predicted subgrid forcing with the larger variance of the velocity field, at least in the NW area.
We can look at the time series of the predictions at specific locations where the errors are larger.

In [None]:
# Time series of the true and predicted S_y at the grid point nearest
# (longitude=-172, latitude=-34), with bands S_ypred ± 1.96 * S_yscale.
# The '95%' label presumes Gaussian errors with S_yscale as their std.
fig = plt.figure(figsize=(20, 30))
long = -172
lat = -34

plt.subplot(2, 1, 1)
time = slice(0, 400)


def _at_point(da):
    """Restrict ``da`` to the plotted time window and the grid point nearest (long, lat)."""
    return da.isel(time_index=time).sel(longitude=long, latitude=lat, method='nearest')


uB = model_output['S_ypred'] + 1.96 * model_output['S_yscale']
lB = model_output['S_ypred'] - 1.96 * model_output['S_yscale']
_at_point(model_output['S_y']).plot(linewidth=3)
_at_point(model_output['S_ypred']).plot(linewidth=3)
_at_point(uB).plot(linestyle='--', color='gray')
_at_point(lB).plot(linestyle='--', color='gray')
plt.ylim(-1, 1)
# Three labels for four lines: the lower band line shares the band's entry.
plt.legend(('True forcing', 'Inferred forcing', '95% confidence interval'))

We see that the amplitude of the forcing reaches 15 standard deviations at some point, which needs investigation. The same phenomenon is also observed at the NW location where larger errors are seen.

We also look at the aspect of the error through time.

In [None]:
from random import randint

# Map of the normalised x-forcing error at one random time step among the first 501.
# The upper bound was hard-coded to 500, which can exceed the number of test
# time steps, so it is now clipped to the last valid index (randint is
# inclusive). Importing randint here removes the dependency on an earlier cell.
last_index = errors_ds.sizes['time_index'] - 1
(errors_ds['S_x (normalised error)']).isel(time_index=randint(0, min(500, last_index))).plot(vmax=2)

In [None]:
# Time series of the normalised x-forcing error at the grid point nearest
# (longitude=-172, latitude=-34).
errors_ds['S_x (normalised error)'].sel(longitude=-172, latitude=-34, method='nearest').plot()

In [None]:
# Spatial mean of the squared normalised x-forcing error at each time step.
((errors_ds['S_x (normalised error)'])**2).mean(dim=('longitude', 'latitude')).plot()

Again, we look at the spatially averaged velocity field (and forcing), this time as a function of time.

In [None]:
# Spatial mean of the velocities and of the true forcing at each time step.
# A single mean over both spatial dimensions weights every valid grid point
# equally. The previous .mean('latitude').mean('longitude') gave each longitude
# column equal weight instead, which differs once NaNs are present.
plot_dataset(model_output[['u_surf', 'v_surf', 'S_x', 'S_y']].mean(dim=('latitude', 'longitude')))

In [None]:
from analysis.utils import sample
from scipy.stats import norm
from scipy.stats import laplace
from scipy.stats import t

# Distribution of the normalised residuals in a small region (longitude in
# [-175, -167], latitude in [-36, -30]), compared with unit-variance reference
# densities.
# The region is stored under a new name. The previous code reassigned
# `errors_ds`, which broke re-running the earlier cells that use the full dataset.
errors_region = errors_ds.sel(longitude=slice(-175, -167), latitude=slice(-36, -30))
residuals = errors_region[['S_x (normalised error)', 'S_y (normalised error)']].to_array().compute().data
# to_array stacks the variables along a new leading axis; swap it with the next
# axis (presumably time) before sampling.
residuals = residuals.swapaxes(0, 1)
# NOTE(review): the meaning of sample(residuals, 5, 50) is defined in
# analysis.utils and is not visible here.
s0 = sample(residuals, 5, 50)
s1 = s0 / np.std(s0)
pdf_grid = np.arange(-5, 5, 0.01)
plt.hist(s1.reshape((-1, 1)), bins=500, density=True)
plt.plot(pdf_grid, norm.pdf(pdf_grid))
# laplace with the default scale=1 has variance 2. scale=1/sqrt(2) gives the
# unit variance that s1 is normalised to, so both reference curves are comparable.
plt.plot(pdf_grid, laplace.pdf(pdf_grid, scale=1 / np.sqrt(2)))

plt.xlim([-7, 7])

In [None]:
# Overall mean of the regional normalised residuals. np.mean does not skip
# NaNs, so any NaN residual makes this NaN.
np.mean(residuals)

In [None]:
from scipy.stats import norm

# Normal Q-Q plot of the centred, standardised residual samples `s1`.
errors = s1.reshape(-1, 1)
errors = errors - np.mean(errors)
n = errors.shape[0]
# Theoretical quantiles at the Hazen plotting positions (i - 0.5) / n, i = 1..n.
# The previous linspace(1/n, 1 - 1/n, n) gave ppf(1) = inf for n == 1 and two
# identical positions for n == 2.
norm_quantiles = norm.ppf((np.arange(1, n + 1) - 0.5) / n)
sorted_errors = np.sort(errors, axis=None)
plt.plot(norm_quantiles, norm_quantiles)
plt.plot(norm_quantiles, sorted_errors)
_ = plt.title('Quantile-Quantile plot of the errors')