In [None]:
import os,sys
sys.path.insert(1, os.path.join(os.getcwd()  , '../../src/gz21_ocean_momentum'))
from utils import select_experiment, select_run
from analysis.utils import plot_dataset, GlobalPlotter
import mlflow.tracking
import xarray as xr
import cmocean
import matplotlib.pyplot as plt
from dask.diagnostics import ProgressBar
from data.xrtransforms import SeasonalStdizer
import cartopy

# Colormap passed as `cmap=` to the map plots below (used there with symmetric vmin/vmax).
cmap_balance = cmocean.cm.balance
# The projection *class*, not an instance: it is handed to the plotter as `projection_cls`.
proj_robinson = cartopy.crs.Robinson
# Single plotting helper instance shared by the map cells further down.
plotter = GlobalPlotter()
%matplotlib notebook


In [None]:
run = select_run(cols = ['params.CO2', 'params.factor'], experiment_ids=('19',))

In [None]:
# Download the 'forcing' artifact of the selected run; the returned value is
# what gets passed to `xr.open_zarr` when the dataset is opened.
client = mlflow.tracking.MlflowClient()
data_file = client.download_artifacts(run.run_id, 'forcing')

In [None]:
# Open the downloaded store and give the horizontal coordinates the
# longitude/latitude names that the rest of the notebook selects on.
data = xr.open_zarr(data_file).rename({'xu_ocean': 'longitude', 'yu_ocean': 'latitude'})

In [None]:
data

In [None]:
data.isel(time=slice(1, 10)).max().compute()

In [None]:
import numpy as np
from scipy.stats import norm

# Probability levels strictly inside (0, 1), step 1e-5. A grid that touches
# (or, through float rounding, exceeds) 1.0 makes `norm.ppf` return inf/NaN,
# which would sit just before the +20 sentinel and break the monotonic
# ordering that the interpolation in `_transform` relies on.
quantile_levels = np.linspace(1e-5, 1 - 1e-5, 99999)

# Empirical quantiles of the first 10 time steps of `usurf`, padded with
# sentinel end points.
# NOTE(review): the -5 / 5 sentinels assume |usurf| stays below 5 — confirm.
quantiles = data['usurf'].isel(time=slice(0, 10)).quantile(quantile_levels).compute()
quantiles = np.concatenate((np.array([-5,]), quantiles, np.array([5, ])))

# Matching standard-normal quantiles on the same levels, padded the same way
# so both tables have identical length.
normal_quantiles = norm.ppf(quantile_levels)
normal_quantiles = np.concatenate((np.array([-20,]), normal_quantiles, np.array([20, ])))


def _transform(value):
    value[np.isnan(value)] = 0
    quantile_index = np.searchsorted(quantiles, value) - 1
    v1 = quantiles[quantile_index]
    v2 = quantiles[quantile_index + 1]
    r = (value - v1) / (v2 - v1)
    v1 = normal_quantiles[quantile_index]
    v2 = normal_quantiles[quantile_index + 1]
    result =  v1 + r * (v2 - v1)
    result[np.isnan(value)] = np.nan
    return result

from sklearn.preprocessing import QuantileTransformer
# scikit-learn transformer configured to output a normal distribution; it is
# fitted on `usurf` in a later cell.
# NOTE(review): the name `t` is rebound to a SeasonalStdizer further down the notebook.
t = QuantileTransformer(output_distribution='normal')
    

In [None]:
len(quantiles)

In [None]:
new = t.fit_transform(data['usurf'].isel(time=slice(0, 100)).compute().data.reshape((-1, 1)))

In [None]:
%matplotlib notebook
plt.figure()
# NOTE(review): `new` comes from fit_transform on an input reshaped to (-1, 1),
# so `new[600, ...]` is presumably a length-1 row rather than a 2-D map —
# confirm that imshow receives the slice that was intended.
plt.imshow(new[600, ...], origin='lower', vmin=-1.96, vmax=1.96, cmap=cmocean.cm.delta)

In [None]:
plt.figure()
_ = plt.hist(new.ravel(), bins=np.arange(-10, 10, 0.1), density=True)

In [None]:
%matplotlib notebook
plt.figure()
# Q-Q style curve: empirical `usurf` quantiles against the standard-normal
# quantiles at the same probability levels. Both tables are plain NumPy
# arrays of equal length at this point (so they are indexed positionally,
# not by variable name); the first and last entries are the artificial
# sentinel end points and are left out of the plot.
plt.plot(normal_quantiles[1:-1], quantiles[1:-1])

In [None]:
plotter.plot(data['vsurf'].isel(time=0), lon=0, projection_cls=proj_robinson, cmap=cmap_balance, vmin=-0.5, vmax=0.5)

In [None]:
data['usurf'].encoding

In [None]:
data['S_x'].encoding

In [None]:
from dask.diagnostics import ProgressBar, ResourceProfiler
import dask

# Two regional subsets sharing the 30..50 latitude band, over adjacent
# longitude ranges (-40..-20 and -70..-40).
d1 = data.sel(longitude=slice(-40, -20), latitude=slice(30, 50))
d2 = data.sel(longitude=slice(-70, -40), latitude=slice(30, 50))

# Size of the first subset, in units of 1e9 bytes.
print(d1.nbytes / 1e9)

In [None]:
d1

In [None]:
type(d1['usurf'].data)

In [None]:
dask.__version__

In [None]:
# Materialise the first regional subset while showing a progress bar and
# recording resource usage; the profiler is kept as `prof` so that
# `prof.results` can be inspected afterwards.
with ProgressBar(), ResourceProfiler() as prof:
    d1_val = d1.compute()

In [None]:
dask.__version__

In [None]:
prof.results

In [None]:
plotter.plot(d1_val.isel(time=5000)['usurf'], cmap=cmap_balance, vmin=-1, vmax=1, projection_cls=proj_robinson, lon=0.)

In [None]:
plotter.plot(data.vsurf.isel(time=0), lon=0., cmap=cmap_balance, vmin=-1, vmax=1)

In [None]:
t = SeasonalStdizer(std=True)
with ProgressBar():
    t.fit(data)

In [None]:
data_n = t(data)

In [None]:
data_v = data.chunk(dict(time=-1, longitude=28, latitude=17))

In [None]:
# Rechunk in stages: every step enlarges the time chunk and shrinks the
# horizontal chunks, ending with a single chunk along time (time=-1) and
# 28 x 17 horizontal tiles.
# NOTE(review): presumably staged to avoid one large direct rechunk from the
# on-disk layout — confirm this is actually cheaper than a single `.chunk` call.
data_v = data.chunk(dict(time=100, longitude=450, latitude=275))
data_v = data_v.chunk(dict(time=400, longitude=225, latitude=137))
data_v = data_v.chunk(dict(time=1600, longitude=112, latitude=68))
data_v = data_v.chunk(dict(time=3200, longitude=56, latitude=34))
data_v = data_v.chunk(dict(time=-1, longitude=28, latitude=17))

In [None]:
import dask.array as da

In [None]:
v = da.random.randint(0, 10, (100, 1000), chunks=(10, -1))

In [None]:
v

In [None]:
d = v.rechunk((5, -1))

In [None]:
d

In [None]:
from dask import delayed
import numpy as np
def func(x):
    """Toy workload for the dask.delayed experiment: return ``x`` doubled."""
    doubled = 2 * x
    return doubled
# Lazy wrapper around `func`, applied to consecutive 5-row slices of `v`.
# Each delayed result is turned back into a dask array with a declared
# (5, 1000) int64 layout, and the pieces are concatenated into `new_v`.
delayed_func = delayed(func)
new_v = da.concatenate([da.from_delayed(delayed_func(v[start:min(start+5, 100), :]), shape=(5, 1000), dtype=np.int64)
                       for start in range(0, 100, 5)])

In [None]:
import dask
z = dask.delayed(2)

In [None]:
z.compute()

In [None]:
with ProgressBar():
    test = data_v.isel(longitude=slice(0, 50), latitude=slice(0, 50)).compute()

In [None]:
plt.savefig('histtemp.jpg', dpi=300)

In [None]:
t = SeasonalStdizer()
test = t.fit_transform(data)

In [None]:
u = t.transform(data)

In [None]:
u.isel(time=slice(25, 50)).compute()

In [None]:
data['time'].dt.month.dtype

In [None]:
from dask import delayed
import dask.array as da
import numpy as np
@delayed
def get_months(times):
    """Return ``times.dt.month``; wrapped in ``dask.delayed`` so a call only builds a lazy task."""
    month_numbers = times.dt.month
    return month_numbers

def all_months(times, chunk_size=100):
    """Build a lazy month-number array for a time coordinate, chunk by chunk.

    Parameters
    ----------
    times : xarray.DataArray
        One-dimensional time coordinate.
    chunk_size : int, optional
        Number of time steps handled by each delayed ``get_months`` call
        (default 100, the value previously hard-coded).

    Returns
    -------
    xarray.DataArray
        Month numbers along ``time``, backed by dask arrays built from the
        delayed calls.

    Raises
    ------
    ValueError
        If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be a positive integer')
    # Derive the extent from the input instead of assuming 4300 steps, so
    # longer time axes are not truncated. `max(..., 1)` keeps one (empty)
    # piece for an empty axis so that the concat below still has an input.
    n_times = len(times)
    pieces = []
    for start in range(0, max(n_times, 1), chunk_size):
        chunk = times[start:start + chunk_size]
        # The declared shape must match the real chunk length; the last chunk
        # is shorter whenever n_times is not a multiple of chunk_size.
        lazy_months = da.from_delayed(get_months(chunk), shape=(len(chunk),), dtype=np.int64)
        pieces.append(xr.DataArray(data=lazy_months, coords=dict(time=chunk), dims=('time',)))
    return xr.concat(pieces, 'time')

months = all_months(data['time'])

In [None]:
for k,v in data.items():
    print(k)
    print(v)

In [None]:
@delayed
def get_transformed(data, var_name):
    """Subtract the matching monthly mean from every time step of one variable.

    ``data`` is a single variable with a ``time`` coordinate and ``var_name``
    is its key in the module-level ``monthly_means``. Returns the raw values
    of the result; the whole call is lazy because of ``dask.delayed``.
    """
    month_of_step = data.time.dt.month
    anomaly = data - monthly_means[var_name].sel(month=month_of_step)
    # Remove the 'month' entry from the result (presumably carried over from
    # the month-based selection) before extracting the values.
    del anomaly['month']
    return anomaly.values

def all_data(data, chunk_size=100):
    """Lazily apply ``get_transformed`` to every variable of ``data`` in time chunks.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset with a ``time`` dimension.
    chunk_size : int, optional
        Number of time steps per delayed task (default 100, the value
        previously hard-coded).

    Returns
    -------
    xarray.Dataset
        Dataset with the same variables, each backed by dask arrays built
        from the delayed ``get_transformed`` calls and concatenated along
        ``time``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be a positive integer')
    # Cover the actual length of the time axis rather than a fixed 4300
    # steps, which silently dropped everything beyond that index.
    # `max(..., 1)` keeps one (empty) piece for an empty time axis.
    n_times = data.sizes['time']
    sub_datasets = []
    for start in range(0, max(n_times, 1), chunk_size):
        sub_data = data.isel(time=slice(start, start + chunk_size))
        new_xr_arrays = {}
        for k, val in sub_data.items():
            transformed = get_transformed(val, k)
            # The shape comes from the sliced variable itself, so a shorter
            # final chunk is declared correctly.
            dask_array = da.from_delayed(transformed, shape=val.shape, dtype=np.float64)
            # Attach only the coordinates that belong to this variable.
            new_xr_arrays[k] = xr.DataArray(data=dask_array, coords=val.coords, dims=val.dims)
        sub_datasets.append(xr.Dataset(new_xr_arrays))
    return xr.concat(sub_datasets, dim='time')

data_n = all_data(data)
data_n

In [None]:
data_n['usurf'].isel(time=2000).std().compute()

In [None]:
data_n['usurf'].isel(time=0)

In [None]:
def remove_seasonal_means(data):
    """Unfinished stub: builds the month index for ``data.time`` and stops.

    Nothing is subtracted and nothing is returned (the result is ``None``).
    TODO: finish the implementation or remove this cell.
    """
    months = all_months(data.time)
    

In [None]:
data = data.assign_coords(dict(month=months))

In [None]:
e = data.groupby('time.month') - monthly_grouped.mean()

In [None]:
del e['month']

In [None]:
data

In [None]:
data

In [None]:
test = test.drop_vars('month')

In [None]:
def standardize(x):
    """Centre ``x`` on its mean over ``time`` and divide by its standard deviation over ``time``."""
    centered = x - x.mean(dim='time')
    spread = x.std(dim='time')
    return centered / spread
monthly_grouped = data.groupby('time.month')
monthly_means = monthly_grouped.mean()

In [None]:
with ProgressBar():
    monthly_means = monthly_means.compute()

In [None]:
monthly_means

In [None]:
values = monthly_means.sel(month=months[1:10])

In [None]:
values

In [None]:
# Month number of every time step, used to pick the matching monthly mean.
months = data['time'].dt.month
months_means = monthly_means.sel(month=months)
# Anomalies with respect to the monthly means, loaded into memory with a
# progress bar.
with ProgressBar():
    v = (data - months_means).compute()

In [None]:
data.nbytes

In [None]:
data.dims

In [None]:
plotter.plot(v.isel(time=0)['usurf'], lon=0., cmap=cmap_balance)

In [None]:
monthly_grouped.groups

In [None]:
with ProgressBar():
    rr = monthly_grouped.compute()

In [None]:
plt.figure()
plt.plot(monthly_means['usurf'].sel(longitude=-161, latitude=0, method='nearest'))

In [None]:
# Per-calendar-month standard deviation, built with the same 'time.month'
# grouping that `monthly_means` uses. It is computed here because no earlier
# cell defines `monthly_stds`.
monthly_stds = data.groupby('time.month').std()
plotter.plot(monthly_stds['usurf'].sel(month=10), lon=0., cmap=cmap_balance, vmax=0.5)

In [None]:
plotter.plot(data['usurf'].isel(time=0) - monthly_means['usurf'].sel(month=1), lon=0, cmap=cmap_balance)

In [None]:
plotter.plot(data_n['usurf'].isel(time=100), lon=0., cmap=cmap_balance)

In [None]:
var = t.apply(data)