In [None]:
import mlflow
import xarray as xr
import matplotlib.pyplot as plt
import os,sys
sys.path.insert(1, os.path.join(os.getcwd()  , '../../src/gz21_ocean_momentum'))
from utils import select_experiment, select_run

mlruns_path = os.path.join(os.getcwd(), '../../mlruns')
# Point MLflow at the local mlruns store. Plain-Python equivalent of the IPython
# `%env MLFLOW_TRACKING_URI $mlruns_path` magic, so the cell also works outside IPython.
os.environ['MLFLOW_TRACKING_URI'] = mlruns_path

In [None]:
def load_data_from_run(i_run: int, runs_df=None, sort_files: bool = False):
    """Open every artifact file of one MLflow run as an xarray Dataset.

    Parameters
    ----------
    i_run : int
        Positional index (``.iloc``) of the run in ``runs_df``.
    runs_df : pandas.DataFrame, optional
        Runs table, e.g. from ``mlflow.search_runs``. It must have an
        ``artifact_uri`` column holding a local directory path. Defaults to the
        notebook-level ``runs`` table, which was the only behaviour before.
    sort_files : bool, default False
        If True, open the files in sorted filename order. The default keeps
        ``os.listdir`` order, which is arbitrary. Callers that index the
        returned list by position should check what they got.

    Returns
    -------
    list of xarray.Dataset
        One dataset per file, in the order the files were opened.
    """
    if runs_df is None:
        runs_df = runs  # notebook global set by the experiment-selection cell
    run = runs_df.iloc[i_run]
    print(run)
    artifact_dir = run['artifact_uri']
    filenames = os.listdir(artifact_dir)
    if sort_files:
        filenames = sorted(filenames)
    datasets = []
    for fn in filenames:
        print(f'Loading {fn}')
        datasets.append(xr.open_dataset(os.path.join(artifact_dir, fn)))
    return datasets

In [None]:
# Pick the MLflow experiment whose run artifacts (u, v, eta datasets) are loaded next,
# then list its runs.
exp_id, _ = select_experiment(default_selection='22')
runs = mlflow.search_runs(experiment_ids=[exp_id])

In [None]:
# Open every artifact of the first run in the table.
datasets = load_data_from_run(i_run=0)

In [None]:
# The positions below depend on os.listdir order, which is arbitrary.
# Check that the picks hold the expected variables instead of silently mislabelling them.
u = datasets[0]
v = datasets[4]
eta = datasets[3]
assert 'u' in u, f"datasets[0] has no 'u' variable: {list(u.variables)}"
assert 'v' in v, f"datasets[4] has no 'v' variable: {list(v.variables)}"
# Interpolate u and v onto the x/y points of the eta dataset.
u = u.interp(dict(x=eta.x, y=eta.y))
v = v.interp(dict(x=eta.x, y=eta.y))

In [None]:
# Merge u and v into one dataset, then rename dimensions and variables to the names used
# downstream (presumably those expected by data.coarse.eddy_forcing).
uv_high_rez = (
    xr.merge([u, v])
    .rename({'x': 'xu_ocean', 'y': 'yu_ocean', 't': 'time', 'u': 'usurf', 'v': 'vsurf'})
)

In [None]:
# Inspect the merged high-resolution velocity dataset (usurf, vsurf on the eta grid).
uv_high_rez

In [None]:
from data.coarse import eddy_forcing
import numpy as np

In [None]:
# Uniform grid spacing, presumably in metres: 1e4 m = 10 km, so 384 points span the
# 3840 km extent used for the R^2 maps.
grid_spacing = 1e4
# Take the grid shape from the data instead of hard-coding 384 x 384.
grid_shape = (uv_high_rez.sizes['xu_ocean'], uv_high_rez.sizes['yu_ocean'])
grid_coords = dict(xu_ocean=uv_high_rez.xu_ocean, yu_ocean=uv_high_rez.yu_ocean)
dxu = xr.DataArray(dims=('xu_ocean', 'yu_ocean'), data=np.full(grid_shape, grid_spacing),
                   coords=grid_coords)
dyu = xr.DataArray(dims=('xu_ocean', 'yu_ocean'), data=np.full(grid_shape, grid_spacing),
                   coords=grid_coords)
grid_data = xr.Dataset(dict(dxu=dxu, dyu=dyu))
grid_data

In [None]:
# Chunk along time so map_blocks works on 500 time steps per block.
uv_high_rez = uv_high_rez.chunk({'time': 500})

In [None]:
# Output template for map_blocks: coarse-grained velocities (factor 4) plus two forcing
# variables with the same layout.
template = uv_high_rez.coarsen({'xu_ocean': 4, 'yu_ocean': 4}).mean()
forcing_template = template.rename({'usurf': 'S_x', 'vsurf': 'S_y'})
template = xr.merge([template, forcing_template])
# Diagnose the eddy forcing independently on each time chunk (lazy until computed).
low_rez = xr.map_blocks(
    lambda block: eddy_forcing(block, grid_data, 4),
    uv_high_rez,
    template=template,
)

In [None]:
# Size in GB of the coarse-grained result, before it is computed below.
low_rez.nbytes / 1e9

In [None]:
from dask.diagnostics import ProgressBar
# Evaluate the dask graph built by map_blocks into memory, showing a progress bar.
with ProgressBar():
    low_rez = low_rez.compute()

In [None]:
# Inspect the computed coarse-grained velocities and forcing (usurf, vsurf, S_x, S_y).
low_rez

In [None]:
%matplotlib notebook
low_rez['S_x'].isel(time=1100).plot(vmin=-1e-7, vmax=1e-7, cmap='Spectral')

In [None]:
from models.models1 import FullyCNN
# Instantiate the network architecture; its trained weights are loaded in a later cell.
net = FullyCNN(padding='same')


In [None]:
# Display the freshly constructed network.
net

In [None]:
import pickle
def pickle_artifact(run_id: str, path: str, client=None):
    """Download an MLflow run artifact and unpickle it.

    Parameters
    ----------
    run_id : str
        MLflow run id that owns the artifact.
    path : str
        Artifact path relative to the run's artifact root.
    client : mlflow.tracking.MlflowClient, optional
        Client to download with. A new one is created when omitted, which is
        the previous behaviour.

    Returns
    -------
    object
        The unpickled artifact.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code. Only use this on artifacts
    from trusted runs.
    """
    if client is None:
        client = mlflow.tracking.MlflowClient()
    local_file = client.download_artifacts(run_id, path)
    # `with` closes the file handle; previously it was never closed.
    with open(local_file, 'rb') as f:
        return pickle.load(f)

client = mlflow.tracking.MlflowClient()

# Choose the experiment holding trained models, then one of its runs
# (listed by start time, with these columns shown).
models_experiment_id, _ = select_experiment()
cols = [
    'metrics.test loss',
    'start_time',
    'params.time_indices',
    'params.model_cls_name',
    'params.source.run_id',
    'params.submodel',
]
model_run = select_run(
    sort_by='start_time',
    cols=cols,
    experiment_ids=[models_experiment_id],
)
# Fetch the trained weights file, and attach the pickled output transformation to the net.
model_file = client.download_artifacts(model_run.run_id, 'models/trained_model.pth')
transformation = pickle_artifact(model_run.run_id, 'models/transformation')
net.final_transformation = transformation

In [None]:
import torch
# map_location='cpu' lets a checkpoint saved from a GPU be read on a CPU-only machine.
# load_state_dict copies the values into net's existing (CPU) parameters,
# so the loaded weights are unchanged.
net.load_state_dict(torch.load(model_file, map_location='cpu'))

In [None]:
# Display the network after its trained weights have been loaded.
net

In [None]:
# Use the GPU when there is one; otherwise fall back to CPU so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from train.losses import HeteroskedasticGaussianLossV2

In [None]:
# Loss passed to create_large_test_dataset; n_target_channels=2 presumably matches the
# two targets S_x and S_y added to the dataset further down.
criterion = HeteroskedasticGaussianLossV2(n_target_channels=2)

In [None]:
from testing.utils import create_large_test_dataset
from torch.utils.data import DataLoader

In [None]:
from data.datasets import (RawDataFromXrDataset, DatasetTransformer,
                           Subset_, DatasetWithTransform, ComposeTransforms,
                           MultipleTimeIndices, DatasetPartitioner)
# Wrap the coarse-grained data as a torch-style dataset indexed by time.
# NOTE(review): inputs AND targets are multiplied by 10 here, while later comparisons
# use low_rez * 1e7 — confirm the intended scaling of network inputs/outputs.
dataset = RawDataFromXrDataset(low_rez * 10.)
dataset.index = 'time'
# Order of add_input/add_output presumably fixes channel order — keep it.
dataset.add_input('usurf')
dataset.add_input('vsurf')
dataset.add_output('S_x')
dataset.add_output('S_y')
# ComposeTransforms() with no arguments — presumably identity transforms; confirm.
features_transform_ = ComposeTransforms()
targets_transform_ = ComposeTransforms()
transform = DatasetTransformer(features_transform_, targets_transform_)
transform.fit(dataset)
dataset = DatasetWithTransform(dataset, transform)
# Run the net over the whole dataset. net.to(...) moves the module itself to `device`.
# DataLoader is used with its default settings (batch size 1).
test = create_large_test_dataset(net.to(device=device), criterion, [dataset, ], [DataLoader(dataset)], device)
# Rename coords so `test` aligns with low_rez in the comparisons below.
test = test.rename(dict(longitude='xu_ocean', latitude='yu_ocean'))

In [None]:
# Materialise the network predictions into memory, with a progress bar.
with ProgressBar():
    test = test.compute()

In [None]:
# Inspect the predictions (S_x, S_y and their S_*scale companions).
test

In [None]:
plt.figure()
# Predicted mean of S_x at time index 1000.
test['S_x'].isel(time=1000).plot(cmap='Spectral', vmin=-1, vmax=1)

In [None]:
plt.figure()
# Diagnosed S_x at time index 1000, scaled by 1e7 to the units used for `test`.
(low_rez['S_x'].isel(time=1000) * 1e7).plot(cmap='Spectral', vmin=-1, vmax=1)

In [None]:
from scipy.stats import norm
support = np.linspace(-4, 4, 100)
plt.figure()
# Histogram of standardised S_x residuals, overlaid with the standard normal density.
standardised_sx = (low_rez['S_x']*1e7 - test['S_x']) * test['S_xscale']
standardised_sx.plot.hist(bins=support, density=True)
plt.plot(support, norm.pdf(support))

In [None]:
# Plotting positions strictly inside (0, 1): probabilities 0 and 1 map to -inf/+inf
# under norm.ppf, which cannot be plotted. They do not depend on the variable,
# so they are computed once.
n_quantiles = 100
quantiles = (np.arange(n_quantiles) + 0.5) / n_quantiles
norm_quantiles = norm.ppf(quantiles)
plt.figure()
for i, var in enumerate(['S_x', 'S_y']):
    # Residuals times S_*scale (used as an inverse standard deviation for the 1.96 bands
    # elsewhere in this notebook) should be ~N(0, 1) if the predicted spread is calibrated.
    residuals_t = ((low_rez[var]*1e7 - test[var]) * (test[f'{var}scale']))
    # Every 20th time step, selected by dimension name rather than assuming time is axis 0.
    obs_quantiles = np.nanquantile(residuals_t.isel(time=slice(None, None, 20)).values,
                                   quantiles)
    plt.subplot(1, 2, i + 1)
    plt.plot(norm_quantiles, norm_quantiles, 'gray')
    plt.plot(norm_quantiles, obs_quantiles, 'k*')
    plt.ylim(-4, 4)
    plt.yticks(np.arange(-4, 5, 2))

In [None]:
# Save the current figure (the Q-Q plots). NOTE(review): relies on the figure staying
# open across cells, presumably true with the `%matplotlib notebook` backend; with an
# inline backend this may write an empty image.
plt.savefig('offline_test_swm1.jpg', dpi=400)

In [None]:
mse, variance, r_squared, correlation = {}, {}, {}, {}
for var in ['S_x', 'S_y']:
    # Diagnosed forcing rescaled by 1e7 to the units the predictions are compared in.
    reference = low_rez[var] * 1e7
    mse[var] = ((test[var] - reference) ** 2).mean(dim='time')
    # Uncentred second moment (no mean removed), so R^2 is relative to predicting zero.
    variance[var] = (reference ** 2).mean(dim='time')
    r_squared[var] = 1 - mse[var] / variance[var]
    correlation[var] = xr.corr(test[var], reference, dim='time')

In [None]:
from matplotlib.patches import Arrow, Circle
fig = plt.figure()
# Axis extent in km: 384 high-resolution points x 10 km spacing (dxu = dyu = 1e4).
extent = (0, 3840, 0, 3840)
# One coarse cell = 4 high-resolution cells (coarsen factor 4) = 40 km.
coarse_km = 40
# Coarse-grid (x, y) indices to mark; (80, 47) is the point used for the time-series plot.
markers = [((25, 47), 'white'), ((80, 47), 'green')]
for i, var in enumerate(['S_x', 'S_y']):
    plt.subplot(1, 2, i + 1)
    im = plt.imshow(r_squared[var].values, vmin=0.75, vmax=1, cmap='inferno',
                    origin='lower', extent=extent)
    ax = im.axes
    ax.set_xticks([0, 2000])
    if i == 0:
        ax.set_yticks([0, 1000, 2000, 3000])
        ax.set_xlabel('km')
        ax.set_ylabel('km')
        for (cx, cy), colour in markers:
            ax.add_patch(Circle((cx * coarse_km, cy * coarse_km),
                                radius=2 * coarse_km, color=colour))
    else:
        # Right panel shares the left panel's y range; hide its y ticks.
        ax.set_yticks([])

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.3, 0.025, 0.4])
# Both panels use the same vmin/vmax, so the last image can drive the shared colour bar.
cbar = fig.colorbar(im, cax=cbar_ax, label=r'$R^2$', ticks=[0.8, 0.9, 1])

In [None]:
# Save the current figure (the R^2 maps). NOTE(review): relies on the figure staying
# open across cells, presumably true with the `%matplotlib notebook` backend.
plt.savefig('offline_test_swm2.jpg', dpi=400)

In [None]:
%matplotlib notebook
x = 80
y = 47
print(x, ' ', y)
plt.figure()
for i, var in enumerate(['S_x', 'S_y']):
    ax = plt.subplot(2, 1, i + 1)
    plt.plot(low_rez[var].isel(xu_ocean=x, yu_ocean=y, time=slice(2000, 2100))*1e7)
    plt.plot(test[var].isel(xu_ocean=x, yu_ocean=y, time=slice(2000, 2100)))
    ub = (test[var].isel(xu_ocean=x, yu_ocean=y, time=slice(2000, 2100))
          + 1.96 / test[f'{var}scale'].isel(xu_ocean=x, yu_ocean=y, time=slice(2000, 2100)))
    lb = (test[var].isel(xu_ocean=x, yu_ocean=y, time=slice(2000, 2100))
          - 1.96 / test[f'{var}scale'].isel(xu_ocean=x, yu_ocean=y, time=slice(2000, 2100)))
    plt.plot(ub, '--g')
    plt.plot(lb, '--g')
    if i == 0:
        ax.set_xticks([])
    if i == 1:
        plt.xlabel('time (days)')
    plt.ylabel(fr'${var}$' + '  ' +  r'$(1e^{-7}ms^{-2}$)')

In [None]:
plt.savefig(f"offline_test_swm3x={x}-y={y}.jpg", dpi=400)

In [None]:
plt.figure()
# |time-mean S_x residual| divided by the temporal std of the diagnosed S_x, at each point.
bias = (low_rez['S_x']*1e7 - test['S_x']).mean(dim='time')
spread = (low_rez['S_x']*1e7).std(dim='time')
abs(bias / spread).plot(vmin=0, vmax=0.5)

In [None]:
plt.figure()
# Natural log of the temporal std of S_y (scaled by 1e7) at each grid point.
sy_std = (low_rez['S_y']*1e7).std(dim='time')
np.log(sy_std).plot(vmin=-5, vmax=5)

In [None]:
# Overall std of S_y (unscaled) over the sub-domain with xu_ocean index >= 45.
low_rez['S_y'].isel(xu_ocean=slice(45, None)).std()