In [None]:
import mlflow
from mlflow.tracking import client
import xarray as xr
import numpy as np
import dask.array as da
import matplotlib.pyplot as plt
import os,sys
sys.path.insert(1, os.path.join(os.getcwd()  , '../../src/gz21_ocean_momentum'))
from utils import select_experiment, select_run
from analysis.utils import plot_dataset, GlobalPlotter, anomalies
from data.pangeo_catalog import get_whole_data
from data.xrtransforms import SeasonalStdizer, TargetedTransform, ScalingTransform
from dask.diagnostics import ProgressBar
from models.submodels import transform3

import cartopy.crs as ccrs
!pip install cmocean
import cmocean

cmap = cmocean.cm.balance
cmap_balance = cmocean.cm.balance
cmap_balance_r=cmocean.cm.balance_r

cmap_amp = cmocean.cm.amp

plt.rcParams["figure.figsize"] = (4, 4 / 1.618)

uv_plotter = GlobalPlotter() 
uv_plotter.x_ticks = np.arange(-150., 151., 50)
uv_plotter.y_ticks = np.arange(-80., 81., 20)


%matplotlib notebook


In [None]:
# Master intake catalog of the Pangeo data store (implicit string
# concatenation instead of a backslash continuation inside the literal).
CATALOG_URL = ('https://raw.githubusercontent.com/pangeo-data/pangeo-datastore'
               '/master/intake-catalogs/master.yaml')
# get_whole_data's result is indexable; its second element is the grid info.
# NOTE(review): meaning of the second argument (0) not visible here — confirm.
data = get_whole_data(CATALOG_URL, 0)
grid_info = data[1]
# Disabled: coarsened wet mask, kept for reference.
# mask = grid_info['wet'].coarsen(dict(xt_ocean=4, yt_ocean=4))
# mask_ = mask.max()
# mask_ = mask_.where(mask_ > 0.1)

In [None]:
# Interactively choose the test experiment, then one of its runs. The
# `merge` argument presumably joins in the data run and model run referenced
# by the selected run's params — TODO confirm against utils.select_run.
test_exp_name = select_experiment()
test_exp = mlflow.get_experiment_by_name(test_exp_name)
test_exp_id = test_exp.experiment_id
run = select_run(
    experiment_ids=test_exp_id,
    cols=['status', 'start_time', 'params.CO2', 'params.factor', 'params.submodel'],
    merge=[('data-global', 'params.data_run_id', 'run_id'),
           ('modelsv1', 'params.model_run_id', 'run_id')],
)
client_ = client.MlflowClient()

# True forcing: the 'forcing' zarr artifact of the data run, with the
# u-grid coordinates renamed to longitude/latitude and values scaled by 1e7.
data_file_name = client_.download_artifacts(run['params.data_run_id'], 'forcing')
print('Data path:', data_file_name)
data = xr.open_zarr(data_file_name).rename(
    {'xu_ocean': 'longitude', 'yu_ocean': 'latitude'}) * 1e7

# Model predictions stored as the 'test_output_0' artifact of the selected run.
pred_file_name = client_.download_artifacts(run.run_id, 'test_output_0')
pred = xr.open_zarr(pred_file_name)

# Align the true forcing with the time span and latitude band of pred.
time_span = slice(pred.time[0], pred.time[-1])
lat_span = slice(pred.latitude[0], pred.latitude[-1])
data = data.sel(time=time_span).sel(latitude=lat_span)

In [None]:
# Analysis window applied to every field below.
lon = slice(None, None, 1)
lat = slice(-80, 80, 1)
time_slice = slice(None, None, 1)
window = dict(longitude=lon, latitude=lat)

# Channels '0' and '1' are treated as the logits of the two mixture weights.
logit0 = pred['0'].sel(**window).isel(time=time_slice)
logit1 = pred['1'].sel(**window).isel(time=time_slice)
# Two-way softmax written as a logistic of the logit difference:
# exp(l0) / (exp(l0) + exp(l1)) == 1 / (1 + exp(l1 - l0)). The original form
# overflows to inf / inf = NaN for large logits; this one saturates to 0 or 1.
p0 = 1 / (1 + np.exp(logit1 - logit0))
p1 = 1 - p0
# Means of the two components.
mu0 = pred['4'].sel(**window).isel(time=time_slice)
mu1 = pred['8'].sel(**window).isel(time=time_slice)
# True forcing component S_x on the same window.
true = data['S_x'].sel(**window).isel(time=time_slice)
# "Precisions": beta multiplies (x - mu) in the pdf/cdf used below, i.e. it
# acts as the inverse standard deviation of each component.
beta0 = pred['6'].sel(**window).isel(time=time_slice)
beta1 = pred['10'].sel(**window).isel(time=time_slice)

In [None]:
def apply_complete_mask(array):
    """Mask `array` with the plotter's masks and crop it to pred's latitudes.

    `uv_plotter.borders` and `uv_plotter.mask` are both interpolated onto the
    longitude/latitude coordinates of `array`. A point is kept where the
    interpolated borders are NaN and the interpolated mask is not NaN; all
    other points become NaN (presumably: ocean points away from coastlines —
    TODO confirm against GlobalPlotter). The result is then restricted to the
    latitude range spanned by the global `pred`.

    Parameters
    ----------
    array : xarray object with 'longitude' and 'latitude' coordinates.

    Returns
    -------
    The masked, latitude-cropped array.
    """
    target = {name: array.coords[name] for name in ('longitude', 'latitude')}
    borders = uv_plotter.borders.interp(target)
    plot_mask = uv_plotter.mask.interp(target)
    keep = np.isnan(borders) & ~np.isnan(plot_mask)
    first_lat, last_lat = pred['latitude'][0], pred['latitude'][-1]
    return array.where(keep).sel(latitude=slice(first_lat, last_lat))

## Time series analysis

In [None]:
# Probe the grid point nearest to (lon, lat) over time.
lon = -129
lat = -55

point = dict(longitude=lon, latitude=lat, method='nearest')
p0_ = p0.sel(**point)
mu0_ = mu0.sel(**point)
mu1_ = mu1.sel(**point)
beta0_ = beta0.sel(**point)
beta1_ = beta1.sel(**point)
true_ = true.sel(**point)

plt.figure()
plt.plot(p0_)              # weight of component 0
plt.plot(true_, '-x')      # true forcing
plt.plot(mu0_, 'r')
plt.plot(mu1_, 'k')
# 95% bands of each component: mean -/+ 1.96 / beta (beta acts as 1 / std).
for mean, beta, color in ((mu0_, beta0_, 'r'), (mu1_, beta1_, 'k')):
    for sign in (-1, 1):
        plt.plot(mean + sign * 1.96 / beta, '--' + color)

In [None]:
# Sample the predicted mixture at one time step of the probed point so its
# histogram can be compared with the true value.
time = 796
n_samples = 1000000
sample_rng = np.random.default_rng(0)  # fixed seed: reproducible histogram

p0_t = float(p0_.isel(time=time))
mu0_t, beta0_t = float(mu0_.isel(time=time)), float(beta0_.isel(time=time))
mu1_t, beta1_t = float(mu1_.isel(time=time)), float(beta1_.isel(time=time))

# A draw comes from component 1 when a uniform exceeds p0 (probability 1 - p0).
from_component1 = sample_rng.random(n_samples) > p0_t
samples0 = sample_rng.standard_normal(n_samples) / beta0_t + mu0_t
samples1 = sample_rng.standard_normal(n_samples) / beta1_t + mu1_t
# Elementwise pick instead of concatenating and fancy-indexing a 2N array.
final_samples = np.where(from_component1, samples1, samples0)

In [None]:
# Histogram of the mixture samples; the true forcing at that time step is
# printed for comparison.
plt.figure()
true_value = float(true_.isel(time=time))
print(true_value)
_ = plt.hist(final_samples, bins=200)

## Global distribution of true and simulated forcing

We assess how well the predicted two-component Gaussian mixture fits the true forcing.

### Likelihood map

In [None]:
from scipy.stats import norm

In [None]:
def pdf(x, p0, mu0, beta0, mu1, beta1):
    """Density of the two-component Gaussian mixture at x.

    Component k has mean mu_k and standard deviation 1 / beta_k (beta_k
    multiplies (x - mu_k), matching the mixture CDF used for the PIT).
    Component 0 has weight p0 and component 1 has weight 1 - p0. All
    arguments may be arrays; evaluation is elementwise.

    Returns
    -------
    p0 * N(x; mu0, beta0**-2) + (1 - p0) * N(x; mu1, beta1**-2)
    """
    # Substituting z = (x - mu) * beta into the standard normal density needs
    # the Jacobian factor beta; without it the result does not integrate to 1.
    return (p0 * beta0 * norm.pdf((x - mu0) * beta0)
            + (1 - p0) * beta1 * norm.pdf((x - mu1) * beta1))

# Mixture density evaluated at the true forcing, per grid point and time.
lkh = pdf(true, p0=p0, mu0=mu0, beta0=beta0, mu1=mu1, beta1=beta1)

In [None]:
# Time-mean likelihood map with the colour range fixed to [0, 0.5].
plt.figure()
mean_lkh_map = lkh.mean(dim='time')
plt.imshow(mean_lkh_map, vmin=0, vmax=0.5, origin='lower')

### Mean likelihood

In [None]:
# Scalar mean likelihood over the -40..40 latitude band, on the points kept
# by apply_complete_mask.
lat = slice(-40, 40)
band_lkh = apply_complete_mask(lkh.sel(latitude=lat))
with ProgressBar():
    mean_lkh = band_lkh.mean().compute()

In [None]:
# Display the scalar mean likelihood computed in the previous cell.
mean_lkh

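### Goodness of fit

If the predicted distribution is calibrated, its CDF evaluated at the true forcing (the probability integral transform, PIT) is uniformly distributed on $[0, 1]$. The histogram and quantile plots below check this.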

In [None]:
from scipy.stats import norm
def my_transform(x, p0, mu0, beta0, mu1, beta1):
    """Probability integral transform: the mixture CDF evaluated at x.

    Component k is Gaussian with mean mu_k and inverse standard deviation
    beta_k; component 0 has weight p0, component 1 weight 1 - p0.

    Returns
    -------
    p0 * Phi((x - mu0) * beta0) + (1 - p0) * Phi((x - mu1) * beta1), where
    Phi is the standard normal CDF. Evaluated elementwise.
    """
    weight1 = 1 - p0
    cdf0 = norm.cdf((x - mu0) * beta0)
    cdf1 = norm.cdf((x - mu1) * beta1)
    return p0 * cdf0 + weight1 * cdf1

# PIT values of the true forcing, cropped to the -40..40 latitude band and masked.
v = apply_complete_mask(
    my_transform(true, p0, mu0, beta0, mu1, beta1).sel(latitude=slice(-40, 40))
)

In [None]:
# Inspect the masked PIT values.
v

In [None]:
# Normalized histogram of the PIT values (flat at height 1 when calibrated).
plt.figure()
pit_values = v.values.flatten()
_ = plt.hist(pit_values, density=True, bins=200)

In [None]:
# Shape of the masked PIT array.
v.shape

In [None]:
# 100 probability levels spaced uniformly in logit, so they crowd toward 0 and 1.
logit_levels = np.linspace(-10, 10, 100)
quantiles = np.exp(logit_levels) / (1 + np.exp(logit_levels))

# Empirical quantiles of the PIT values, ignoring masked (NaN) points.
q = np.nanquantile(v, quantiles)

In [None]:
# PIT quantiles against their probability levels; the identity line is the
# calibrated reference.
levels = quantiles
plt.figure()
plt.plot(levels, q, 'x')
plt.plot(levels, levels)

In [None]:
# Scratch check: standard-normal quantile at probability 0.001 (about -3.09).
norm.ppf(0.001)

In [None]:
from scipy.stats import norm
# Same comparison in Gaussian space: both the probability levels and the PIT
# quantiles are mapped through the standard-normal inverse CDF.
s = norm.ppf(q)
gauss_levels = norm.ppf(quantiles)
plt.figure()
plt.plot(gauss_levels, s, 'x')
plt.plot(gauss_levels, gauss_levels)
# NOTE(review): fixed lower y-limit of -18; purpose not stated — confirm.
plt.axis([None, None, -18, None])


## Distribution of simulated vs true forcing

In [None]:
# Keep the unmasked true forcing; the next cell rebinds `true` to a masked copy.
saved_true = true

In [None]:
# apply_complete_mask is defined earlier in the notebook with an identical
# body, so it is not redefined here (a second definition silently shadows the first).
true = apply_complete_mask(saved_true)

# Draw one sample per grid point and time from the predicted mixture.
# A point must come from component 0 with probability p0, so the uniform draw
# is compared with `< p0`; `> p0` would give component 0 weight 1 - p0.
# Each component gets its own noise draw.
sim_rng = np.random.default_rng(1)  # fixed seed: reproducible simulation
epsilon = sim_rng.standard_normal(true.shape)
epsilon2 = sim_rng.standard_normal(true.shape)
component0 = sim_rng.random(true.shape) < p0
simulated = (component0 * (mu0 + epsilon / beta0)
             + (1 - component0) * (mu1 + epsilon2 / beta1))

simulated = apply_complete_mask(simulated)


In [None]:
# Log-scale normalized histograms of the true and simulated forcing over the
# same bins. Labels and transparency make the two overlaid distributions
# distinguishable; otherwise the second histogram hides the first.
bins = np.arange(-20, 21, 0.5)
plt.figure()
_ = plt.hist(true.values.ravel(), log=True, bins=bins, density=True,
             alpha=0.5, label='true')
_ = plt.hist(simulated.values.ravel(), log=True, bins=bins, density=True,
             alpha=0.5, label='simulated')
plt.legend()


In [None]:
# 100 probability levels spaced uniformly in logit over [-3, 3].
logit_levels = np.linspace(-3, 3, 100)
quantiles = np.exp(logit_levels) / (1 + np.exp(logit_levels))
quantiles

In [None]:
# QQ plot of simulated against true forcing; the identity line is the reference.
q_true, q_simu = (np.nanquantile(field.values.ravel(), quantiles)
                  for field in (true, simulated))
plt.figure()
plt.plot(q_true, q_simu, 'x')
plt.plot(q_true, q_true)

In [None]:
# Same QQ plot, with the reference drawn as markers instead of a line.
plt.figure()
for y_values, fmt in ((q_simu, 'x'), (q_true, 'o')):
    plt.plot(q_true, y_values, fmt)

In [None]:
# Shape of the masked simulated forcing.
simulated.shape

### Quantile analysis

In [None]:
import matplotlib
import cmocean
# Mean of the predicted mixture.
mean_pred = p0 * mu0 + p1 * mu1

# Absolute time-mean bias of the predicted mean, in units of the temporal
# standard deviation of the true forcing, on a log colour scale.
normalized_bias = (np.abs(mean_pred.mean(dim='time') - true.mean(dim='time'))
                   / true.std(dim='time'))
plt.figure()
# vmin/vmax go into the LogNorm itself: Matplotlib deprecated, and now
# rejects, passing them to imshow together with a norm instance.
plt.imshow(normalized_bias, origin='lower', cmap=cmocean.cm.delta,
           norm=matplotlib.colors.LogNorm(vmin=0.01, vmax=1))
plt.colorbar()