In [None]:
import xarray as xr
import pandas as pd
import fsspec
import dask

# Set the dask config entry for the "distributed" logger to "critical";
# presumably this is to keep routine scheduler/worker log messages out of the cell output.
dask.config.set({"distributed.logging.distributed": "critical"})

In [None]:
# from dask.distributed import Client
# setup a local dask cluster
# client = Client()
# client

# Instead of the local cluster sketched above, request a cluster through Dask Gateway.
from dask_gateway import Gateway

# No arguments are given, so the connection details must come from the environment's
# dask-gateway configuration.
gateway = Gateway()
options = gateway.cluster_options()
# Per-worker resources. NOTE(review): units are defined by the gateway deployment;
# presumably CPU cores and GiB of memory — confirm.
options.worker_cores = 4
options.worker_memory = 16
cluster = gateway.new_cluster(cluster_options=options)
# Let the cluster scale adaptively between 1 and 40 workers.
cluster.adapt(minimum=1, maximum=40)
cluster

In [None]:
# Connect a client to the gateway cluster; the trailing expression displays it as the cell output.
client = cluster.get_client()
client

In [None]:
# Path of the raw TerraClimate raster Zarr store on Google Cloud Storage,
# wrapped in a key-value mapper that xarray can open.
bucket = "gs://carbonplan-data/raw/terraclimate/4000m/raster.zarr"
mapper = fsspec.get_mapper(bucket)

In [None]:
%%time
# Open the Zarr store as an xarray Dataset, reading the store's consolidated metadata.
ds = xr.open_zarr(mapper, consolidated=True)

In [None]:
# Display the dataset summary.
ds

In [None]:
# Dimension names and sizes of the raster dataset.
ds.dims

In [None]:
# Fraction of time steps that are null, per variable and per grid cell
# (the mean of the boolean null mask over "time").
null_fraction = ds.isnull().mean("time")
# Mask out cells whose fraction is exactly 1, i.e. cells that are null at every time step.
# The result keeps the name `null_count` because later cells read it under that name.
null_count = null_fraction.where(null_fraction != 1)

In [None]:
# Bring the null-fraction map of the `pdsi` variable into local memory.
# NOTE(review): later cells read `PDSI` (upper case) from `ds_cond`, which is sampled
# from this same `ds` — confirm which spelling the store uses; one of the two looks stale.
pdsi = null_count.pdsi.load()

# [::50, ::50].plot()

In [None]:
# Average over 2x2 lat/lon windows to shrink the map, then plot it.
# robust=True asks xarray to pick color limits from percentiles rather than the min/max.
pdsi.coarsen(lat=2, lon=2).mean().plot(robust=True)
# pdsi

In [None]:
# Read the FIA table from a CSV located relative to the working directory.
# The trailing commented-out `.iloc` slice would restrict the run to the first 10000 rows.
df = pd.read_csv("FIA_forBill_DroughtRiskV1_08182020.csv")  # .iloc[0:10000]
df.head()

In [None]:
import math

n = 5  # number of batches the table is split into

n_rows = len(df)
# Rows per batch, rounded up so that `n` batches are enough to cover every row.
size = math.ceil(n_rows / n)
assert size * n >= n_rows
print(size)

In [None]:
# Convert the two coordinate columns to an xarray Dataset. Later cells address its
# dimension as "index", so the lat/lon variables can serve as pointwise indexers.
index_ds = df[["lat", "lon"]].to_xarray()
index_ds

In [None]:
# %%time
# ds_cond = ds.sel(lat=index_ds['lat'], lon=index_ds['lon'], method='nearest')
# ds_cond

In [None]:
from tqdm import tqdm

In [None]:
# Split the points into batches of `size` consecutive index values and, for each batch,
# build a nearest-neighbour selection of the raster at the batch's lat/lon pairs.
ds_list = [
    ds.sel(lat=batch["lat"], lon=batch["lon"], method="nearest")
    for _, batch in tqdm(index_ds.groupby(index_ds.index // size))
]

In [None]:
# ds_list[:2]

In [None]:
# Combined size of all batches in units of 1e9 bytes (GB).
sum(d.nbytes for d in ds_list) / 1e9

In [None]:
# Compute the batches one after another (with a progress bar) and keep the results.
parts = [batch.compute() for batch in tqdm(ds_list)]

In [None]:
# Join the computed batches back together along the "index" dimension.
ds_cond = xr.concat(parts, dim="index")
# ds_cond = ds_cond.chunk({'index': 40000, 'time': 240})

In [None]:
# Count null PDSI values across all sampled points and plot the counts
# (presumably one value per time step — confirm the remaining dimension).
ds_cond.PDSI.isnull().sum("index").plot()

In [None]:
# Rebuild the combined dataset from `parts` and split it into dask chunks of
# 40000 points by 240 time steps.
ds_cond = xr.concat(parts, dim="index").chunk({"index": 40000, "time": 240})

In [None]:
from dask.distributed import Client

# Start a local cluster and rebind `client` to it; the name previously held the
# gateway client.
# NOTE(review): the gateway `cluster` created earlier is never closed in the visible
# cells — confirm whether it should be shut down once the batches have been computed.
client = Client()
client

In [None]:
# Scratch location for the per-point dataset. Note that `bucket` is rebound here,
# so it no longer points at the raster store.
bucket = "gs://carbonplan-scratch/terraclimate-fia-cond.zarr"
mapper2 = fsspec.get_mapper(bucket)

# mode="w" replaces any store that already exists at the target.
ds_cond.to_zarr(mapper2, mode="w")

In [None]:
# bucket = "gs://carbonplan-scratch/terraclimate-fia-cond.zarr"
# mapper2 = fsspec.get_mapper(bucket)

# ds_cond = xr.open_zarr(mapper2)

In [None]:
# ds_cond_ann = ds_cond.resample(time='AS').mean()
# ds_cond_ann

In [None]:
def weighted_mean(ds, *args, dim="time", **kwargs):
    """Average `ds` over a datetime dimension, weighting each step by its month length.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Object carrying a datetime-like coordinate named `dim`.
    *args, **kwargs
        Accepted and ignored, so the function can be passed to ``.map()`` with
        extra arguments.
    dim : str, default "time"
        Name of the datetime dimension to reduce. Previously a ``dim`` keyword was
        swallowed by ``**kwargs`` and "time" was always used.

    Returns
    -------
    Same type as `ds`
        `ds` with `dim` reduced by a days-in-month weighted mean.
    """
    # The number of days in each step's calendar month serves as the weight.
    weights = ds[dim].dt.days_in_month
    return ds.weighted(weights).mean(dim=dim)


# Annual (year-start) bins. "YS" is the current spelling of the deprecated "AS" alias.
ds_cond_ann = ds_cond.resample(time="YS").map(weighted_mean, dim="time")

In [None]:
# Collapse "index" and "time" into a single chunk each.
ds_cond_ann = ds_cond_ann.chunk(dict(index=-1, time=-1))
# Load the lon and lat variables eagerly, in the same order as before (lon, then lat).
for coord_name in ("lon", "lat"):
    ds_cond_ann[coord_name] = ds_cond_ann[coord_name].load()
ds_cond_ann

In [None]:
# from zarr.storage import ZipStore

# store = ZipStore('terraclimate-fia-cond-ann.zarr.zip', mode='w')

In [None]:
# One separate, empty encoding dict per data variable; handed to `to_zarr` when the
# annual dataset is written.
encoding = dict((var_name, {}) for var_name in ds_cond_ann.data_vars)
encoding

In [None]:
# Scratch location for the annual per-point dataset (`bucket` is rebound once more).
bucket = "gs://carbonplan-scratch/terraclimate-fia-cond-ann-3.zarr"
# create=True is passed through to fsspec; presumably it creates the target
# location if it is missing — confirm against the fsspec docs.
mapper3 = fsspec.get_mapper(bucket, create=True)

from dask.diagnostics import ProgressBar

# NOTE(review): dask.diagnostics.ProgressBar reports on dask's local schedulers. If the
# distributed Client created earlier is still the active scheduler, this bar may show
# nothing — confirm, or use dask.distributed.progress instead.
with ProgressBar():
    # Overwrite the target store and also write consolidated metadata.
    ds_cond_ann.to_zarr(mapper3, mode="w", consolidated=True, encoding=encoding)

In [None]:
# Pull the annual PDSI variable for all points into memory.
da = ds_cond_ann.PDSI.load()

In [None]:
# Preview the first 20000 positions along "index".
da.isel(index=slice(20000))

In [None]:
# Count null annual PDSI values across all points and plot the counts.
da.isnull().sum("index").plot()