In [None]:
%load_ext autoreload
%autoreload 2

import gstools as gs
import intake
import os
import zarr
import pandas as pd
import xarray as xr
import intake_esm
import numpy as np
from dask.distributed import Client
from cmip6_downscaling import CLIMATE_NORMAL_PERIOD
from cmip6_downscaling.constants import KELVIN, PERCENT, SEC_PER_DAY
import rioxarray
from rasterio.enums import Resampling
from cmip6_downscaling.workflows.share import (
    chunks,
    future_time,
    get_cmip_runs,
    hist_time,
    xy_region,
)
from cmip6_downscaling.workflows.utils import get_store
import matplotlib.pyplot as plt
intake_esm.__version__

In [None]:
import skdownscale

In [None]:
# Show the path skdownscale was imported from, i.e. which installation is in use.
skdownscale.__file__

# access GCM data

The cells in this section will be replaced by `load_cmip_dictionary` and `gcm_munge`.


In [None]:
from cmip6_downscaling.data.cmip import gcm_munge

In [None]:
# Search facets for the catalog query in the next cell.
activity_ids = ["CMIP", "ScenarioMIP"]
experiment_ids = ["historical", "ssp370"]  # not enabled here: "ssp126", "ssp245", "ssp585"
member_ids = ["r1i1p1f1"]
source_ids = ["CanESM5"]  # alternative model, not enabled here: "BCC-CSM2-MR"
table_ids = ["day"]
grid_labels = ["gn"]
# NOTE: despite the plural name this is a single variable name (a string);
# later cells use it directly to index datasets, e.g. `ds[variables]`.
variables = "tasmax"
variable_ids = [variables]  # other ids named in this notebook: tasmin, pr

In [None]:
# Open the intake-esm catalog (pangeo-cmip6.json).
col_url = "https://cmip6downscaling.blob.core.windows.net/cmip6/pangeo-cmip6.json"
col = intake.open_esm_datastore(col_url)

# Narrow the catalog down to the facets chosen in the configuration cell.
full_subset = col.search(
    **{
        "activity_id": activity_ids,
        "experiment_id": experiment_ids,
        "member_id": member_ids,
        "table_id": table_ids,
        "grid_label": grid_labels,
        "variable_id": variable_ids,
        "source_id": source_ids,
    }
)

In [None]:
# Load every entry of the search result; the returned dict is looked up by
# experiment name in the next cell.
gcm_ds_dict = full_subset.to_dataset_dict(
    zarr_kwargs=dict(consolidated=True, decode_times=True, use_cftime=True),
    storage_options=dict(
        account_name="cmip6downscaling",
        # Taken from the environment; None when the variable is not set.
        account_key=os.environ.get("AccountKey"),
    ),
)

In [None]:
# For each experiment, take the first dataset whose key mentions it and pass
# it through gcm_munge.
keys = gcm_ds_dict.keys()

historical_key = [name for name in keys if "historical" in name][0]
historical_gcm = gcm_munge(gcm_ds_dict[historical_key])

future_key = [name for name in keys if "ssp" in name][0]
future_gcm = gcm_munge(gcm_ds_dict[future_key])

In [None]:
# Display the historical dataset returned by gcm_munge.
historical_gcm

In [None]:
# Display the future-scenario dataset returned by gcm_munge.
future_gcm

# access obs data

The cells in this section are to be replaced by `open_era5`.


In [None]:
# Lookup table: CMIP variable id -> ERA5 variable name.
variable_name_dict = dict(
    tasmax="air_temperature_at_2_metres_1hour_Maximum",
    tasmin="air_temperature_at_2_metres_1hour_Minimum",
    pr="precipitation_amount_1hour_Accumulation",
)

In [None]:
def get_store(bucket, prefix, account_key=None, account_name="cmip6downscaling"):
    """Create a zarr store backed by Azure Blob Storage.

    NOTE: this definition rebinds the ``get_store`` name that the first cell
    imports from ``cmip6_downscaling.workflows.utils``.

    Parameters
    ----------
    bucket : str
        Container name, passed positionally to ``zarr.storage.ABSStore``.
    prefix : str
        Path prefix inside the container.
    account_key : str, optional
        Storage account key. When None, the ``AccountKey`` environment
        variable is used (which may itself be unset, giving None).
    account_name : str, optional
        Storage account name. Defaults to ``"cmip6downscaling"``, the value
        that used to be hard-coded.

    Returns
    -------
    zarr.storage.ABSStore
    """
    if account_key is None:
        account_key = os.environ.get("AccountKey")

    return zarr.storage.ABSStore(
        bucket,
        prefix=prefix,
        account_name=account_name,
        account_key=account_key,
    )


def open_era5(var, start_year=2010, end_year=2020):
    """Open the ERA5 stores for one CMIP variable as a single dataset.

    Parameters
    ----------
    var : str
        CMIP variable id; must be a key of ``variable_name_dict``.
    start_year, end_year : int, optional
        Half-open range of catalog years to load (``end_year`` is excluded).
        The defaults reproduce the previously hard-coded 2010-2019.

    Returns
    -------
    xarray.Dataset
        The matching stores concatenated along ``time`` with the
        ``time1_bounds`` variable dropped.

    Raises
    ------
    KeyError
        If ``var`` is not a key of ``variable_name_dict``.
    """
    col = intake.open_esm_datastore(
        "https://cmip6downscaling.blob.core.windows.net/cmip6/ERA5_catalog.json"
    )
    subset = col.search(
        variable=variable_name_dict[var], year=np.arange(start_year, end_year)
    )
    # Keep only the part of each store path that follows "az://cmip6/";
    # it is used as the prefix inside the "cmip6" container.
    era5_stores = [
        store.split("az://cmip6/")[1] for store in subset.df.zstore.values
    ]
    store_list = [
        get_store(bucket="cmip6", prefix=prefix) for prefix in era5_stores
    ]
    # `concat_dim` is only honoured with combine="nested"; current xarray
    # raises a ValueError when it is passed with the default
    # combine="by_coords".
    ds = xr.open_mfdataset(
        store_list, engine="zarr", combine="nested", concat_dim="time"
    )
    # Dataset.drop is deprecated in favour of Dataset.drop_vars.
    return ds.drop_vars("time1_bounds")

In [None]:
# full_obs = open_era5(variables)

In [None]:
# full_obs

# specify spatial regional subset and time periods


In [None]:
from cmip6_downscaling.data.cmip import convert_to_360

# parameters
# Time windows (year strings) used below to build the historical and future
# time slices.
historical_start = "2010"
historical_end = "2014"
future_start = "2015"
future_end = "2019"
# Bounding box of the study region, passed to preprocess_maca below.
# Longitudes look like the 0-360 convention (values > 180) — TODO confirm.
min_lat = 19
max_lat = 55
min_lon = 227
max_lon = 299

# chunk shape for dask execution (time must be contiguous, ie -1)
# NOTE: this rebinds `chunks`, a name the first cell also imports from
# cmip6_downscaling.workflows.share.
chunks = {"lat": 10, "lon": 10, "time": -1}

In [None]:
# One-off preprocessing that writes `obs_buffer.zarr`, the store opened in the
# next cell. Kept commented out; it needs `full_obs` from the commented-out
# `open_era5` call above.
# buffer = 3
# buffer_slice_lat = slice(max_lat + buffer, min_lat - buffer)
# buffer_slice_lon = slice(convert_to_360(min_lon) - buffer, convert_to_360(max_lon) + buffer)
# full_obs = full_obs.rio.write_crs('EPSG:4326')
# obs_buffer = full_obs.sel(lat=buffer_slice_lat, lon=buffer_slice_lon)
# obs_buffer = obs_buffer.resample(time='1D').reduce(np.max).rename({variable_name_dict[variables]:variables})
# obs_buffer = obs_buffer.chunk({'lat': 10, 'lon': 10, 'time': 1000})
# for v in obs_buffer:
#     print(v)
#     if 'chunks' in obs_buffer[v].encoding:
#         del obs_buffer[v].encoding['chunks']
# obs_buffer.to_zarr('obs_buffer.zarr', mode='w')

In [None]:
# Load the observation subset from a local zarr store and display it.
# NOTE: `obs_buffer.zarr` must already exist; the only code in this notebook
# that writes it is the commented-out cell above.
obs = xr.open_zarr("obs_buffer.zarr")
obs

# start of workflow


In [None]:
# Label-based time slices built from the year strings in the parameters cell;
# used with `.sel(time=...)` and passed as `historical_period` below.
historical_period = slice(historical_start, historical_end)
future_period = slice(future_start, future_end)

In [None]:
from cmip6_downscaling.workflows.maca_flow import preprocess_maca

In [None]:
# Each GCM run is restricted to its own period before being handed over;
# the call returns two objects, bound here as `full_gcm` and `coarse_obs`.
full_gcm, coarse_obs = preprocess_maca(
    obs=obs,
    historical_gcm=historical_gcm.sel(time=historical_period),
    future_gcm=future_gcm.sel(time=future_period),
    min_lat=min_lat,
    max_lat=max_lat,
    min_lon=min_lon,
    max_lon=max_lon,
)

In [None]:
# Evaluate and display `full_gcm`. The result is not assigned, so the name
# `full_gcm` keeps referring to the object returned by preprocess_maca.
full_gcm.compute()

In [None]:
# Plot the first time step of `variables` from the preprocessed GCM data.
full_gcm.isel(time=0)[variables].plot()

In [None]:
# Evaluate and display `coarse_obs`. The result is not assigned, so the name
# `coarse_obs` keeps referring to the object returned by preprocess_maca.
coarse_obs.compute()

In [None]:
# Plot the first time step of `variables` from the coarsened observations.
coarse_obs.isel(time=0)[variables].plot()

In [None]:
# Plot the first time step of `variables` from the observations loaded from
# obs_buffer.zarr, for comparison with the coarsened version.
obs.isel(time=0)[variables].plot()

## Epoch Adjustment


In [None]:
from cmip6_downscaling.methods.detrend import epoch_adjustment

In [None]:
# Default settings; anything supplied in `epoch_adjustment_kwargs` overrides them.
epoch_adjustment_kwargs = None
epoch_adjustment_kws = dict(day_rolling_window=21, year_rolling_window=3)
if epoch_adjustment_kwargs:
    epoch_adjustment_kws.update(epoch_adjustment_kwargs)

# here, the time dimension of ea_gcm needs to be in 1 chunk
ea_gcm, trend = epoch_adjustment(
    data=full_gcm,
    historical_period=historical_period,
    **epoch_adjustment_kws
)

In [None]:
# Compare the original and epoch-adjusted series at the grid cell in the
# middle of the lat/lon index range. `i` and `j` are reused by the next cell.
i = len(ea_gcm.lat) // 2
j = len(ea_gcm.lon) // 2
fig, ax = plt.subplots(figsize=(25, 5))
ea_gcm.isel(lat=i, lon=j)[variables].plot(ax=ax, label="epoch adjusted")
full_gcm.isel(lat=i, lon=j)[variables].plot(ax=ax, label="original")
# The labels above are only rendered once a legend is drawn; without it the
# two lines cannot be told apart.
ax.legend()
plt.show()

In [None]:
# Plot the `trend` returned by epoch_adjustment at the same grid cell (i, j).
fig, ax = plt.subplots(figsize=(25, 5))
trend.isel(lat=i, lon=j)[variables].plot(ax=ax)

## coarse scale bias correction


In [None]:
from cmip6_downscaling.workflows.maca_flow import maca_bias_correction

In [None]:
# Default settings; anything supplied in `bias_correction_kwargs` overrides them.
bias_correction_kwargs = None
bias_correction_kws = dict(batch_size=15, buffer_size=15)
if bias_correction_kwargs:
    bias_correction_kws.update(bias_correction_kwargs)

# Bias-correct the epoch-adjusted GCM data against the coarsened observations.
bc_ea_gcm = maca_bias_correction(
    ds_gcm=ea_gcm,
    ds_obs=coarse_obs,
    historical_period=historical_period,
    variables=variables,
    **bias_correction_kws
)

In [None]:
# plot cdf
# One cumulative, normalised step histogram per series: the coarse
# observations, then the epoch-adjusted and bias-corrected GCM data for the
# historical and future periods. Only the observation curve gets an explicit
# colour; the others use matplotlib's default colour cycle.
cdf_series = [
    ("observation", coarse_obs[variables], {"color": "k"}),
    ("epoch adjusted (hist)", ea_gcm[variables].sel(time=historical_period), {}),
    ("epoch adjusted (future)", ea_gcm[variables].sel(time=future_period), {}),
    ("bias corrected (hist)", bc_ea_gcm[variables].sel(time=historical_period), {}),
    ("bias corrected (future)", bc_ea_gcm[variables].sel(time=future_period), {}),
]

for cdf_label, cdf_data, cdf_style in cdf_series:
    plt.hist(
        # Flattened inside the loop so only one series is materialised at a time.
        cdf_data.values.flatten(),
        label=cdf_label,
        bins=500,
        density=True,
        cumulative=True,
        histtype="step",
        alpha=0.55,
        **cdf_style
    )

plt.legend(loc="upper left")
plt.xlabel("value")
plt.ylabel("cumulative prob")
plt.show()
plt.close()

## constructed analogs


In [None]:
from cmip6_downscaling.workflows.maca_flow import maca_constructed_analogs

In [None]:
# Give the two time axes different names so that the subtraction below pairs
# every observation day with every GCM day instead of matching them on a
# shared `time` coordinate.
X = coarse_obs.rename({"time": "ndays_in_obs"})  # coarse obs
y = bc_ea_gcm.rename({"time": "ndays_in_gcm"})  # coarse gcm

# Distance between each GCM slice to be downscaled and each observation slice:
# one value per (obs day, GCM day) pair once lat/lon are summed out.
# NOTE: the squared differences are summed, not averaged (the division by the
# pixel count is left commented out), so this is a root-sum-square rather
# than a true RMSE.
rmse = np.sqrt(((X - y) ** 2).sum(dim=["lat", "lon"]))  # / n_pixel_coarse

In [None]:
# Display the pairwise distance computed in the previous cell.
rmse

In [None]:
# Default settings; anything supplied in `constructed_analogs_kwargs` overrides them.
constructed_analogs_kwargs = None
constructed_analogs_kws = dict(n_analogs=10, doy_range=45)
if constructed_analogs_kwargs:
    constructed_analogs_kws.update(constructed_analogs_kwargs)

# Only the selected variable of each dataset is passed on: the bias-corrected
# GCM data, the coarsened observations and the observations from obs_buffer.zarr.
downscaled_gcm = maca_constructed_analogs(
    ds_gcm=bc_ea_gcm[variables],
    ds_obs_coarse=coarse_obs[variables],
    ds_obs_fine=obs[variables],
    **constructed_analogs_kws
)

In [None]:
# Facet plot of the first 10 time steps of the downscaled data, 5 panels per row.
downscaled_gcm.isel(time=slice(0, 10)).plot(col="time", col_wrap=5)

## epoch replacement


In [None]:
from cmip6_downscaling.workflows.maca_flow import maca_epoch_replacement

In [None]:
# Epoch replacement: combine the downscaled data with the coarse `trend`
# that epoch_adjustment returned earlier.
downscaled_bc_gcm = maca_epoch_replacement(ds_gcm_fine=downscaled_gcm, trend_coarse=trend)

In [None]:
# Facet plot of the first 10 time steps after epoch replacement, 5 panels per row.
downscaled_bc_gcm.isel(time=slice(0, 10)).plot(col="time", col_wrap=5)

## fine scale bias correction


In [None]:
# Second bias-correction pass: the downscaled, epoch-replaced data is
# corrected against `obs` (the observations loaded from obs_buffer.zarr),
# reusing `bias_correction_kws` from the coarse-scale bias-correction cell.
final_gcm = maca_bias_correction(
    ds_gcm=downscaled_bc_gcm,
    ds_obs=obs,
    historical_period=historical_period,
    variables=variables,
    **bias_correction_kws
)