In [None]:
import sys
sys.path.append('../python')
from virtual_stations import get_waterlevel
from misc import get_precipitation, get_pet, get_label_tree, startswith_label, get_mask, get_masks, str2datetime, get_peq_from_df, gcs_get_dir
from models import gr4hh
from mcmc_utils import dist_map, get_likelihood_logp, get_prior_logp

from mcmc import smc, dist
from datetime import timedelta
import random
import subprocess
import pickle
import pandas as pd
from pandas import DataFrame
import numpy as np
import os
from tqdm import tqdm
import xarray as xr
import gcsfs
from dask.distributed import Client

# Environment switch: True when running in the Pangeo binder, False on a laptop.
# It selects which cluster class is bound to `Cluster` and how many workers are requested.
is_pangeo_data = True
if not is_pangeo_data:
    # Laptop: in-process local cluster.
    from dask.distributed import LocalCluster as Cluster
    n_workers = 4
else:
    # Pangeo binder: Kubernetes-backed cluster.
    from dask_kubernetes import KubeCluster as Cluster
    n_workers = 10

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Start a cluster of the class bound to `Cluster` in the import cell, with `n_workers`
# workers, and attach a Dask client to it.
cluster = Cluster(n_workers=n_workers)
client = Client(cluster)
# Bare expression: the notebook shows the cluster object as this cell's output.
cluster

In [None]:
# Execute misc.py in the notebook namespace.
# NOTE(review): later cells call locate_vs, subtract_label, get_path, get_trmm_masks and
# get_gpm_masks, none of which is imported in the import cell — presumably they come from
# this %run; confirm and prefer explicit imports.
%run ../python/misc.py

In [None]:
# Disabled on purpose: the guard is hard-coded to False, so nothing below is executed.
if False:
    def set_worker_env():
        """Append ``<cwd>/python`` to ``sys.path`` and import ``mcmc`` as a global.

        Meant to be executed on each Dask worker through ``client.run``.
        """
        global mcmc
        import os, sys
        python_dir = os.getcwd() + '/python'
        sys.path.append(python_dir)
        import mcmc
    client.run(set_worker_env)

The coordinates of the virtual stations in [Hydroweb](http://hydroweb.theia-land.fr) do not line up with the rivers in [HydroSHEDS](http://www.hydrosheds.org). To find the corresponding coordinates in HydroSHEDS, we search around the original position for the pixel with the largest accumulated flow that exceeds a minimum threshold. If no such pixel is found, we widen the search until one is found; if the search would have to go too far, the virtual station is dropped. The `new_lat`/`new_lon` columns hold the coordinates of this pixel, when one is found.

In [None]:
# Load the located virtual stations from the pickle cache when it exists;
# otherwise compute them with locate_vs and write the cache for the next run.
vs_pickle_path = '../data/amazonas/amazonas.pkl'
if os.path.exists(vs_pickle_path):
    df_vs = pd.read_pickle(vs_pickle_path)
else:
    df_vs = locate_vs('../data/amazonas/amazonas.txt', pix_nb=20, acc_min=1_000_000)
    df_vs.to_pickle(vs_pickle_path)

In [None]:
# Preview the first rows of the virtual-station table.
df_vs.head()

In [None]:
# Keep only the stations for which both new_lat and new_lon are set, as a plain array.
located_latlon = df_vs[['new_lat', 'new_lon']].dropna()
sub_latlon = located_latlon.values
print(f'Out of {len(df_vs)} virtual stations in Hydroweb, {len(sub_latlon)} could be found in HydroSHEDS.')

The following coordinates are duplicated because several virtual stations fall within the same pixel.

In [None]:
# Two (lat, lon) pairs. No other cell visible in this notebook reads this variable.
# NOTE(review): the name suggests coordinates to remove from the station list — confirm
# the intended use, or delete the cell if it is a leftover.
rm_latlon = [(-4.928333333333334, -62.733333333333334), (-3.8666666666666667, -61.6775)]

In [None]:
# Snapped coordinates of the stations that have both new_lat and new_lon set
# (the original index of df_vs is kept).
df_ll = df_vs[['new_lat', 'new_lon']].dropna()
# keep=False marks every member of a duplicated (new_lat, new_lon) pair, not only the repeats.
duplicated = df_ll[df_ll.duplicated(keep=False)]
# Bare expression: displayed as the cell output.
duplicated

In [None]:
# all the subbasins in the hydrologic partition (including virtual stations)
# Labels are the entries whose name starts with '0', listed either from the GCS bucket
# (Pangeo) or from the local ws_mask/amazonas directory.
#gcs_get_dir('pangeo-data/gross/ws_mask/amazonas', 'ws_mask/amazonas', fs)
#gcs_w_token = gcsfs.GCSFileSystem(project='pangeo-data', token='browser')
if is_pangeo_data:
    fs = gcsfs.GCSFileSystem(project='pangeo-data')
    # Strip a trailing '/' if there is one instead of unconditionally dropping the last
    # character: an entry listed without a trailing slash would otherwise lose the last
    # character of its label.
    remote_names = (os.path.basename(path.rstrip('/')) for path in fs.ls('pangeo-data/gross/ws_mask/amazonas'))
    all_labels = [name for name in remote_names if name.startswith('0')]
else:
    all_labels = [fname for fname in os.listdir('ws_mask/amazonas') if fname.startswith('0')]
print('Total number of subbasins:', len(all_labels))

In [None]:
# Split the subbasin labels into those whose outlet coincides with a located virtual
# station (labels_with_vs: label -> list of station names) and the rest
# (labels_without_vs). The result is cached in a pickle next to the station data.
label_pickle_path = '../data/amazonas/labels.pkl'
if not os.path.exists(label_pickle_path):
    # Iterate over all_labels: `labels` is not defined by any earlier cell (it is only
    # bound later, as a loop variable), so using it here fails on a fresh kernel.
    labels_without_vs = list(all_labels)
    labels_with_vs = {}
    if is_pangeo_data:
        gcs_get_dir('pangeo-data/gross/ws_mask', 'ws_mask', fs)
    # Half-width of the lat/lon window used to match a station to an outlet.
    tol = 0.25 / 1200
    for label in tqdm(all_labels):
        ds = xr.open_zarr(f'ws_mask/amazonas/{label}')
        da = ds['mask']
        olat, olon = da.attrs['outlet']
        near_outlet = (
            (olat - tol < df_ll.new_lat.values) & (df_ll.new_lat.values < olat + tol)
            & (olon - tol < df_ll.new_lon.values) & (df_ll.new_lon.values < olon + tol)
        )
        idx = df_ll[near_outlet].index.values
        if len(idx) > 0:
            labels_without_vs.remove(label)
            # idx holds index labels of df_ll (a row subset of df_vs), so the lookup must
            # be label-based; positional .iloc is only right for a default RangeIndex.
            labels_with_vs[label] = list(df_vs.loc[idx].station.values)
    with open(label_pickle_path, 'wb') as f:
        pickle.dump((labels_with_vs, labels_without_vs), f)
else:
    # The pickle is a local cache written by the branch above; do not point this at
    # untrusted files (pickle.load can execute arbitrary code).
    with open(label_pickle_path, 'rb') as f:
        labels_with_vs, labels_without_vs = pickle.load(f)

In [None]:
# Build the label tree from the labels that have a virtual station. The loop cell below
# iterates its keys and reads each entry's 'up' item.
labels_with_vs_tree = get_label_tree(list(labels_with_vs))

In [None]:
# Working directories: cached precipitation / PET series and local copies of the masks.
for work_dir in ('tmp/precipitation', 'tmp/pet', 'ws_mask/amazonas'):
    os.makedirs(work_dir, exist_ok=True)

# Start and end of the period requested from the data helpers.
d0 = '2000-03-01 12:00:00'
d1 = '2018-12-31'
# One (low, high) pair per entry; each pair is expanded into dist.uniform_pdf(*pair)
# when the priors are built.
x_range = (
    (0.1, 1e4),
    (-1, 1),
    (0.1, 1e3),
    (0.1, 1e2),
)
# Passed as `draws` to smc.smc.
draws = 100
# one year in 30min steps (12 months of 30 days, 48 steps per day)
steps_per_day = 24 * 2
warmup = 12 * 30 * steps_per_day
# Passed as `nb` to dist.pdf_from_samples.
n_pdf = 10

# Per-basin calibration loop. For each label that has a virtual station it gathers the
# basin's sub-labels and areas, builds precipitation / PET inputs, runs SMC and stores
# the resulting parameter PDFs in x_pdf[down_label].
# NOTE(review): as written, the body stops at the sys.exit() call during the first
# iteration; everything after that call is currently unreachable.
# NOTE(review): the body reads pickles and x_pdf entries of upstream labels "from previous
# iteration", so it presumably relies on labels_with_vs_tree yielding upstream basins
# before downstream ones — confirm against get_label_tree.
x_pdf = {}
for down_label in labels_with_vs_tree:
    vs = labels_with_vs[down_label]
    for s in vs:
        print(df_vs.query(f"station == '{s}'").iloc[0])
    # get whole basin's labels
    whole_labels = startswith_label(down_label, all_labels)
    # copy basin's masks locally
    for label in whole_labels:
        if not os.path.exists(f'ws_mask/amazonas/{label}'):
            # NOTE(review): a new GCSFileSystem is created for every missing label, and
            # this download is not guarded by is_pangeo_data.
            fs = gcsfs.GCSFileSystem(project='pangeo-data')
            gcs_get_dir(f'pangeo-data/gross/ws_mask/amazonas/{label}', f'ws_mask/amazonas/{label}', fs)
    # get upstream basin's labels and compute its area
    # also compute upstream basins' areas
    up_labels, areas_up = {}, {}
    for label in labels_with_vs_tree[down_label]['up']:
        areas_up[label] = 0
        up_labels[label] = startswith_label(label, all_labels)
        for label2 in up_labels[label]:
            areas_up[label] += xr.open_zarr(f'ws_mask/amazonas/{label2}', auto_chunk=False)['mask'].attrs['area']
    # get downstream basin's labels and compute its area
    down_labels = whole_labels
    # NOTE(review): `labels` is rebound here at notebook scope; subtract_label is not
    # imported in the import cell (presumably provided by `%run ../python/misc.py`).
    for labels in up_labels.values():
        down_labels = subtract_label(labels, down_labels)
    area_down = 0
    for label in down_labels:
        area_down += xr.open_zarr(f'ws_mask/amazonas/{label}', auto_chunk=False)['mask'].attrs['area']
    # Order matters: upstream areas first, downstream area last (same order as `pe` below).
    areas = list(areas_up.values()) + [area_down]
    area_whole = sum(areas)

    if is_pangeo_data:
        trmm_mask_path = 'gs://pangeo-data/gross/ws_mask/amazonas/trmm_mask'
        gpm_mask_path = 'gs://pangeo-data/gross/ws_mask/amazonas/gpm_mask'
    else:
        trmm_mask_path = 'ws_mask/amazonas/trmm_mask'
        gpm_mask_path = 'ws_mask/amazonas/gpm_mask'
    # NOTE(review): none of the next three calls depends on down_label, yet they run inside
    # the loop. get_path is not imported in the import cell (presumably from misc.py).
    da_trmm_mask = xr.open_zarr(get_path(trmm_mask_path))['mask']
    da_gpm_mask = xr.open_zarr(get_path(gpm_mask_path))['mask']
    p = get_precipitation(d0, d1, all_labels, da_trmm_mask, da_gpm_mask, chunk_time=True, zarr_path='ws_precipitation')
    # NOTE(review): raises SystemExit here, so the cell stops in the first iteration and the
    # remainder of the loop body never runs. Looks like a debugging stop — confirm.
    sys.exit()
    
    # get whole basin's mask
    #da = get_mask('ws_mask/amazonas', whole_labels)
    #subprocess.check_call('rm -rf tmp/mask'.split())
    #da.to_dataset(name='mask').to_zarr('tmp/mask')
    # get basin's water level time series at virtual station (basin's outlet)
    he = get_waterlevel(d0, d1, labels_with_vs[down_label][0]) # there might be several stations
    dh0 = he.dropna().index[0] # first date of observation
    dh1 = he.dropna().index[-1] # last date of observation
    # start date which allows to warmup the model, but not more (warmup is in 30min)
    dh0 = max(str2datetime(d0), dh0 - timedelta(minutes=warmup*30))
    # get whole basin's precipitation and PET time series
    # The cache branch is disabled with `if False`; the original condition is kept as a comment.
    if False:#os.path.exists(f'tmp/precipitation/{down_label}.pkl'):
        p_whole = pd.read_pickle(f'tmp/precipitation/{down_label}.pkl')
        e_whole = pd.read_pickle(f'tmp/pet/{down_label}.pkl')
    else:
        print(f'Getting precipitation and PET for {down_label}')
        # NOTE(review): `mask_path` is only assigned in a later cell of this notebook, and
        # this call passes different arguments from the get_precipitation call above.
        p_whole = get_precipitation(d0, d1, mask_path)
        p_whole.to_pickle(f'tmp/precipitation/{down_label}.pkl')
        # NOTE(review): reads 'tmp/mask', but the lines that would write it are commented out above.
        e_whole = get_pet(d0, d1, 'tmp/mask')
        e_whole.to_pickle(f'tmp/pet/{down_label}.pkl')
    p_up, e_up = {}, {}
    # precipitation and PET of upstream basins already exist from previous iteration
    for label in labels_with_vs_tree[down_label]['up']:
        p_up[label] = pd.read_pickle(f'tmp/precipitation/{label}.pkl')
        e_up[label] = pd.read_pickle(f'tmp/pet/{label}.pkl')
    # compute precipitation and PET of downstream basin
    p_down = p_whole * area_whole
    e_down = e_whole * area_whole
    for label in labels_with_vs_tree[down_label]['up']:
        p_down -= p_up[label] * areas_up[label]
        e_down -= e_up[label] * areas_up[label]
    # NOTE(review): the area-weighted remainder is divided by area_whole, not area_down,
    # although area_down is the area paired with this series in `areas` — confirm intent.
    p_down /= area_whole
    e_down /= area_whole

    pe = []
    # upstream basins' precipitation and PET
    for label in labels_with_vs_tree[down_label]['up']:
        df = DataFrame()
        df['p'] = p_up[label].loc[dh0:dh1]
        df['e'] = e_up[label].loc[dh0:dh1]
        pe.append(df)
    # downstream basin's precipitation and PET
    df = DataFrame()
    df['p'] = p_down.loc[dh0:dh1]
    df['e'] = e_down.loc[dh0:dh1]
    pe.append(df)
    # basin's water level
    he = he.reindex(df.index)

    if not up_labels:
        # this is a source basin (no basin flowing into it)
        x = [dist.uniform_pdf(*r) for r in x_range]
        prior_logp = get_prior_logp(x)
        model_logp = get_likelihood_logp(gr4hh, warmup, pe, areas, he=he)
    else:
        # there are basins flowing into this basin
        # NOTE(review): the first `for` clause reads `label` before the second clause binds
        # it, so x_pdf[label] uses whatever `label` holds from an earlier loop. `d_range` is
        # not defined in any cell visible here.
        x = [xp + [dist.uniform_pdf(*d_range)] for xp in x_pdf[label] for label in labels_with_vs_tree[down_label]['up']] + [dist.uniform_pdf(*r) for r in x_range]
        prior_logp = get_prior_logp(x)
        model_logp = get_likelihood_logp(gr4hh, warmup, pe, areas, he=he)
    # run SMC
    posterior, q_sims = smc.smc(x, model_logp, prior_logp, draws=draws, dask_client=client)
    # Overlay every simulated series (after warmup) on the observed water levels.
    plt.figure(figsize=(20, 5))
    for q_sim in q_sims:
        plt.plot(q_sim.index[warmup:], dist_map(q_sim.values[warmup:], he.h.values[warmup:]), alpha=0.005, color='blue')
    plt.scatter(he.index[warmup:], he.h.values[warmup:])
    plt.show()
    # get simulated streamflow's PDF
    if up_labels:
        # reduce model
        # NOTE(review): q_sims is iterated as a sequence of pandas objects above but indexed
        # like a 2-D array here — confirm the type returned by smc.smc.
        q_pdf = np.empty((2, n_pdf, q_sims.shape[1]))
        for i in range(q_pdf.shape[2]):
            q_pdf[:, :, i] = dist.pdf_from_samples(q_sims[:, i], nb=n_pdf, kde=True)
        # Re-fit on the whole basin as a single unit (areas=[1]) against the simulated-flow PDF.
        pe = DataFrame()
        pe['p'] = p_whole.loc[dh0:dh1]
        pe['e'] = e_whole.loc[dh0:dh1]
        x = [dist.uniform_pdf(*r) for r in x_range]
        prior_logp = get_prior_logp(x)
        model_logp = get_likelihood_logp(gr4hh, warmup, pe, [1], q_pdf=q_pdf)
        posterior, _ = smc.smc(x, model_logp, prior_logp, draws=draws, dask_client=client)
    # One PDF per posterior column; the hard-coded 4 matches len(x_range).
    x_pdf[down_label] = [dist.pdf_from_samples(posterior[:, i]) for i in range(4)]

In [None]:
# Plot the label '0' selection of `p`, the get_precipitation result assigned in the loop cell.
p.sel(label='0').plot()

In [None]:
# NOTE(review): import placed mid-notebook rather than in the import cell.
import dask.dataframe as dd
# Lazily read every CSV matching the glob and index the frame by its 'time' column.
# NOTE(review): absolute, user-specific path — this cell only runs on a machine that has
# these files; consider a configurable data directory.
ddf = dd.read_csv('/home/david/Downloads/*.csv').set_index('time')

In [None]:
# Row-wise mean of the first 10 columns, materialised with .compute().
df = ddf[ddf.columns[:10]].mean(axis=1).compute()

In [None]:
# Plot whatever `df` currently holds (the row-wise mean when cells are run in order).
df.plot()

In [None]:
# Materialise the first 10 columns themselves.
# NOTE(review): this rebinds `df`, which the earlier cell used for the row-wise mean.
df = ddf[ddf.columns[:10]].compute()
# Convert the index to datetimes.
df.index = pd.to_datetime(df.index)

In [None]:
# Draw all columns in the same faint blue, without a legend, so overlapping series show up as density.
df.plot(legend=False, alpha=0.1, color='blue')

In [None]:
# Location of the basin masks: the GCS bucket when is_pangeo_data is set, a local directory otherwise.
mask_path = 'gs://pangeo-data/gross/ws_mask/amazonas' if is_pangeo_data else 'ws_mask/amazonas'

In [None]:
# Build the masks returned by get_trmm_masks for all labels, cast them to float32,
# chunk them 10 labels at a time and write them to the 'trmm_mask' zarr store.
# NOTE(review): the store is written relative to the working directory, while the loop cell
# reads 'ws_mask/amazonas/trmm_mask' — confirm the two locations are meant to differ.
trmm_masks_raw = get_trmm_masks(mask_path, all_labels, None)
da_trmm_mask = trmm_masks_raw.astype('float32').chunk({'label': 10})
ds_trmm_mask = da_trmm_mask.to_dataset(name='mask')
ds_trmm_mask.to_zarr('trmm_mask')

In [None]:
# Build the masks returned by get_gpm_masks for all labels, cast them to float32,
# chunk them 10 labels at a time and write them to the 'gpm_mask' zarr store.
# NOTE(review): the store is written relative to the working directory, while the loop cell
# reads 'ws_mask/amazonas/gpm_mask' — confirm the two locations are meant to differ.
gpm_masks_raw = get_gpm_masks(mask_path, all_labels, None)
da_gpm_mask = gpm_masks_raw.astype('float32').chunk({'label': 10})
ds_gpm_mask = da_gpm_mask.to_dataset(name='mask')
ds_gpm_mask.to_zarr('gpm_mask')

In [None]:
def mult(x):
    """Return ``x`` scaled by a single random factor from ``np.random.rand()``.

    A new factor is drawn from NumPy's global random state on every call.
    """
    factor = np.random.rand()
    return x * factor
# Multiply each label's mask by its own random factor (one mult call per group), sum over
# 'label' and plot the result — presumably so that individual subbasins get distinct values.
da_trmm_mask.groupby('label').apply(mult).sum('label').plot()

In [None]:
# Same randomly scaled per-label overlay for the GPM masks; relies on `mult` from the previous cell.
da_gpm_mask.groupby('label').apply(mult).sum('label').plot()