In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time

os.environ['PREFECT__FLOWS__CHECKPOINTING'] = 'true'

from funnel import CacheStore
from funnel.prefect.result import FunnelResult
from prefect import task, Flow, Parameter
from prefect.executors import DaskExecutor
import xarray as xr

In [3]:
os.environ.get("AZURE_STORAGE_CONNECTION_STRING")

'DefaultEndpointsProtocol=https;AccountName=cmip6downscaling;AccountKey=Z86Ca7uttNw4smgPikZFgsB1Lgf8GFpL+lMiViWRaovBSsOWK5JFXAvlk5rO18R/61wqSfUDolvd8olXLt0law==;EndpointSuffix=core.windows.net'

In [4]:
import fsspec
import os 
connection_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")

fs = fsspec.filesystem('az', connection_string=connection_string)

# # # fs.ls('flow-outputs/')

# # fs.rm('flow-outputs/intermediate/bias_corrected_gcm',recursive=True)
# # fs.rm('flow-outputs/intermediate/funnel_metadata_store/bias_corrected_gcm',recursive=True)
# # fs.rm('flow-outputs/intermediate/bias_corrected_obs',recursive=True)
# # fs.rm('flow-outputs/intermediate/funnel_metadata_store/bias_corrected_obs',recursive=True)
# fs.rm('flow-outputs/intermediate/interpolated_obs',recursive=True)
# fs.rm('flow-outputs/intermediate/epoch_adjusted_gcm', recursive=True)

In [5]:
# import gstools as gs
# import intake
# import os
# import zarr
# import pandas as pd
# import xarray as xr
# import intake_esm
# import numpy as np
# from dask.distributed import Client
# from cmip6_downscaling import CLIMATE_NORMAL_PERIOD
# from cmip6_downscaling.constants import KELVIN, PERCENT, SEC_PER_DAY
# import rioxarray
# from rasterio.enums import Resampling
# from cmip6_downscaling.workflows.share import (
#     chunks,
#     future_time,
#     get_cmip_runs,
#     hist_time,
#     xy_region,
# )
# from cmip6_downscaling.workflows.utils import get_store
# import matplotlib.pyplot as plt
# intake_esm.__version__

In [6]:
run_hyperparameters = {
    "OBS": "ERA5",
    "GCM": "MIROC6",
    "SCENARIO": "ssp370",
    "LABEL": "tasmax",
    "TRAIN_PERIOD_START": "1991",
    "TRAIN_PERIOD_END": "1995",
    "PREDICT_PERIOD_START": "2071",
    "PREDICT_PERIOD_END": "2075",
    # could also just put the default kwargs here 
    "EPOCH_ADJUSTMENT_DAY_ROLLING_WINDOW": 21, 
    "EPOCH_ADJUSTMENT_YEAR_ROLLING_WINDOW": 3,
    "BIAS_CORRECTION_BATCH_SIZE": 15,
    "BIAS_CORRECTION_BUFFER_SIZE": 15,
}

In [7]:
from cmip6_downscaling.tasks.common_tasks import (
    path_builder_task,
    get_obs_task,
    get_coarse_obs_task,
    get_gcm_task,
)

from cmip6_downscaling.workflows.maca_flow import (
    calc_epoch_trend_task,
    remove_epoch_trend_task,
    maca_bias_correction_task,
)

In [8]:
with Flow(name='maca-flow') as maca_flow:
    obs = Parameter("OBS")
    gcm = Parameter("GCM")
    scenario = Parameter("SCENARIO")
    label = Parameter("LABEL")

    train_period_start = Parameter("TRAIN_PERIOD_START")
    train_period_end = Parameter("TRAIN_PERIOD_END")
    predict_period_start = Parameter("PREDICT_PERIOD_START")
    predict_period_end = Parameter("PREDICT_PERIOD_END")

    epoch_adjustment_day_rolling_window = Parameter("EPOCH_ADJUSTMENT_DAY_ROLLING_WINDOW")
    epoch_adjustment_year_rolling_window = Parameter("EPOCH_ADJUSTMENT_YEAR_ROLLING_WINDOW")
    bias_correction_batch_size = Parameter("BIAS_CORRECTION_BATCH_SIZE")
    bias_correction_buffer_size = Parameter("BIAS_CORRECTION_BUFFER_SIZE")


    # dictionary with information to build appropriate paths for caching
    gcm_grid_spec, obs_identifier, gcm_identifier = path_builder_task(
        obs=obs,
        gcm=gcm,
        scenario=scenario,
        train_period_start=train_period_start,
        train_period_end=train_period_end,
        predict_period_start=predict_period_start,
        predict_period_end=predict_period_end,
        variables=[label],
    )
    
    # get original resolution observations 
    ds_obs_full_space = get_obs_task(
        obs=obs,
        train_period_start=train_period_start,
        train_period_end=train_period_end,
        variables=[label],
        chunking_approach='full_space',
        cache_within_rechunk=True,
    )

    # get coarsened resolution observations 
    # TODO: this coarse obs is going to be used in bias correction, need to figure out how it should be chunked 
    ds_obs_coarse = get_coarse_obs_task(
        ds_obs=ds_obs_full_space, 
        gcm=gcm, 
        chunking_approach='full_space', 
        gcm_grid_spec=gcm_grid_spec,
        obs_identifier=obs_identifier,
    )
    
    # get gcm 
    ds_gcm_full_time = get_gcm_task(
        gcm=gcm,
        scenario=scenario,
        variables=[label],
        train_period_start=train_period_start,
        train_period_end=train_period_end,
        predict_period_start=predict_period_start,
        predict_period_end=predict_period_end,
        chunking_approach='full_time',
        cache_within_rechunk=True,
    )
    
    # epoch adjustment 
    coarse_epoch_trend = calc_epoch_trend_task(
        data=ds_gcm_full_time,
        train_period_start=train_period_start,
        train_period_end=train_period_end,
        day_rolling_window=epoch_adjustment_day_rolling_window,
        year_rolling_window=epoch_adjustment_year_rolling_window,
        gcm_identifier=gcm_identifier,
    )
    
    epoch_adjusted_gcm = remove_epoch_trend_task(
        data=ds_gcm_full_time,
        trend=coarse_epoch_trend,
        day_rolling_window=epoch_adjustment_day_rolling_window,
        year_rolling_window=epoch_adjustment_year_rolling_window,
        gcm_identifier=gcm_identifier,
    )
    
    # bias correction 
    bias_corrected_gcm = maca_bias_correction_task(
        ds_gcm=epoch_adjusted_gcm,
        ds_obs=ds_obs_coarse,
        train_period_start=train_period_start,
        train_period_end=train_period_end,
        variables=[label],
        batch_size=bias_correction_batch_size,
        buffer_size=bias_correction_buffer_size,
        method='maca_edcdfm',
        gcm_identifier=gcm_identifier,
        chunking_approach='matched',
    )

In [None]:
maca_flow.run(parameters=run_hyperparameters)

[2022-01-03 16:56:56+0000] INFO - prefect.FlowRunner | Beginning Flow run for 'maca-flow'
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'TRAIN_PERIOD_END': Starting task run...
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'TRAIN_PERIOD_END': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'PREDICT_PERIOD_START': Starting task run...
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'PREDICT_PERIOD_START': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'PREDICT_PERIOD_END': Starting task run...
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'PREDICT_PERIOD_END': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'LABEL': Starting task run...
[2022-01-03 16:56:56+0000] INFO - prefect.TaskRunner | Task 'LABEL': Finished task run for task with

1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  output = xr.open_zarr(target_store)


[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'get_obs': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'EPOCH_ADJUSTMENT_YEAR_ROLLING_WINDOW': Starting task run...
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'EPOCH_ADJUSTMENT_YEAR_ROLLING_WINDOW': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'BIAS_CORRECTION_BATCH_SIZE': Starting task run...
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'BIAS_CORRECTION_BATCH_SIZE': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'GCM': Starting task run...
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'GCM': Finished task run for task with final state: 'Success'
[2022-01-03 16:56:58+0000] INFO - prefect.TaskRunner | Task 'get_gcm': Starting task run...


    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  output = xr.open_zarr(target_store)


target path is az://flow-outputs/intermediate/rechunked_gcm/MIROC6_ssp370_1991_1995_2071_2075_tasmax_full_time.zarr
checking the cache
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'get_gcm': Finished task run for task with final state: 'Success'
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'BIAS_CORRECTION_BUFFER_SIZE': Starting task run...
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'BIAS_CORRECTION_BUFFER_SIZE': Finished task run for task with final state: 'Success'
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'List': Starting task run...
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'List': Finished task run for task with final state: 'Success'
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'path_builder_task': Starting task run...
[2022-01-03 16:57:00+0000] INFO - prefect.TaskRunner | Task 'path_builder_task': Finished task run for task with final state: 'Success'
[2022-01-03 16:57:00+0000] INFO

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


[2022-01-03 16:57:02+0000] INFO - prefect.TaskRunner | Task 'calc_epoch_trend_task': Finished task run for task with final state: 'Cached'
[2022-01-03 16:57:02+0000] INFO - prefect.TaskRunner | Task 'get_coarse_obs_task': Starting task run...
[2022-01-03 16:57:02+0000] INFO - prefect.TaskRunner | Task 'get_coarse_obs_task': Finished task run for task with final state: 'Cached'
[2022-01-03 16:57:02+0000] INFO - prefect.TaskRunner | Task 'remove_epoch_trend': Starting task run...
[2022-01-03 16:57:03+0000] INFO - prefect.TaskRunner | Task 'remove_epoch_trend': Finished task run for task with final state: 'Cached'
[2022-01-03 16:57:03+0000] INFO - prefect.TaskRunner | Task 'maca_bias_correction_task': Starting task run...
target path is az://cmip6/temp/qnzsklezdv.zarr
checking the chunk
rechunking
_copy_chunk((slice(0, 1, None), slice(0, 30, None), slice(0, 30, None)))
_copy_chunk((slice(59, 60, None), slice(90, 120, None), slice(60, 90, None)))
_copy_chunk((slice(202, 203, None), slice(3

In [10]:
print('done')

done


# specify spatial regional subset and time periods


In [None]:
from cmip6_downscaling.data.cmip import convert_to_360

# parameters
historical_start = "2010"
historical_end = "2014"
future_start = "2015"
future_end = "2019"
min_lat = 19
max_lat = 55
min_lon = 227
max_lon = 299

# chunk shape for dask execution (time must be contiguous, ie -1)
chunks = {"lat": 10, "lon": 10, "time": -1}

In [None]:
# buffer = 3
# buffer_slice_lat = slice(max_lat + buffer, min_lat - buffer)
# buffer_slice_lon = slice(convert_to_360(min_lon) - buffer, convert_to_360(max_lon) + buffer)
# full_obs = full_obs.rio.write_crs('EPSG:4326')
# obs_buffer = full_obs.sel(lat=buffer_slice_lat, lon=buffer_slice_lon)
# obs_buffer = obs_buffer.resample(time='1D').reduce(np.max).rename({variable_name_dict[variable]:variable})
# obs_buffer = obs_buffer.chunk({'lat': 10, 'lon': 10, 'time': 1000})
# for v in obs_buffer:
#     print(v)
#     if 'chunks' in obs_buffer[v].encoding:
#         del obs_buffer[v].encoding['chunks']
# obs_buffer.to_zarr('obs_buffer.zarr', mode='w')

In [None]:
obs = xr.open_zarr("obs_buffer.zarr")
obs

# start of workflow


In [None]:
historical_period = slice(historical_start, historical_end)
future_period = slice(future_start, future_end)

In [None]:
from cmip6_downscaling.workflows.maca_flow import preprocess_maca

In [None]:
full_gcm, coarse_obs = preprocess_maca(
    historical_gcm=historical_gcm.sel(time=historical_period),
    future_gcm=future_gcm.sel(time=future_period),
    obs=obs,
    min_lon=min_lon,
    max_lon=max_lon,
    min_lat=min_lat,
    max_lat=max_lat,
)

In [None]:
full_gcm.compute()

In [None]:
full_gcm.isel(time=0)[variables].plot()

In [None]:
coarse_obs.compute()

In [None]:
coarse_obs.isel(time=0)[variables].plot()

In [None]:
obs.isel(time=0)[variables].plot()

## Epoch Adjustment


In [None]:
from cmip6_downscaling.methods.detrend import epoch_adjustment

In [None]:
epoch_adjustment_kwargs = None
epoch_adjustment_kws = {"day_rolling_window": 21, "year_rolling_window": 3}
epoch_adjustment_kws.update(
    {} if not epoch_adjustment_kwargs else epoch_adjustment_kwargs
)

# here, the time dimension of ea_gcm needs to be in 1 chunk
ea_gcm, trend = epoch_adjustment(
    data=full_gcm, historical_period=historical_period, **epoch_adjustment_kws
)

In [None]:
i = int(len(ea_gcm.lat) / 2)
j = int(len(ea_gcm.lon) / 2)
plt.figure(figsize=(25, 5))
ea_gcm.isel(lat=i, lon=j)[variables].plot(ax=plt.gca(), label="epoch adjusted")
full_gcm.isel(lat=i, lon=j)[variables].plot(ax=plt.gca(), label="original")

In [None]:
plt.figure(figsize=(25, 5))
trend.isel(lat=i, lon=j)[variables].plot(ax=plt.gca())

## coarse scale bias correction


In [None]:
from cmip6_downscaling.workflows.maca_flow import maca_bias_correction

In [None]:
bias_correction_kwargs = None
bias_correction_kws = {"batch_size": 15, "buffer_size": 15}
bias_correction_kws.update(
    {} if not bias_correction_kwargs else bias_correction_kwargs
)
bc_ea_gcm = maca_bias_correction(
    ds_gcm=ea_gcm,
    ds_obs=coarse_obs,
    historical_period=historical_period,
    variables=variables,
    **bias_correction_kws
)

In [None]:
# plot cdf
plt.hist(
    coarse_obs[variables].values.flatten(),
    bins=500,
    density=True,
    cumulative=True,
    label="observation",
    histtype="step",
    alpha=0.55,
    color="k",
)

plt.hist(
    ea_gcm[variables].sel(time=historical_period).values.flatten(),
    label="epoch adjusted (hist)",
    bins=500,
    density=True,
    cumulative=True,
    histtype="step",
    alpha=0.55,
)
plt.hist(
    ea_gcm[variables].sel(time=future_period).values.flatten(),
    label="epoch adjusted (future)",
    bins=500,
    density=True,
    cumulative=True,
    histtype="step",
    alpha=0.55,
)

plt.hist(
    bc_ea_gcm[variables].sel(time=historical_period).values.flatten(),
    label="bias corrected (hist)",
    bins=500,
    density=True,
    cumulative=True,
    histtype="step",
    alpha=0.55,
)
plt.hist(
    bc_ea_gcm[variables].sel(time=future_period).values.flatten(),
    label="bias corrected (future)",
    bins=500,
    density=True,
    cumulative=True,
    histtype="step",
    alpha=0.55,
)

plt.legend(loc="upper left")
plt.xlabel("value")
plt.ylabel("cumulative prob")
plt.show()
plt.close()

## constructed analogs


In [None]:
from cmip6_downscaling.workflows.maca_flow import maca_constructed_analogs

In [None]:
X = coarse_obs.rename({"time": "ndays_in_obs"})  # coarse obs
y = bc_ea_gcm.rename({"time": "ndays_in_gcm"})  # coarse gcm

# get rmse between each GCM slices to be downscaled and each observation slices
# will have the shape ndays_in_gcm x ndays_in_obs
rmse = np.sqrt(((X - y) ** 2).sum(dim=["lat", "lon"]))  # / n_pixel_coarse

In [None]:
rmse

In [None]:
# %debug
constructed_analogs_kwargs = None
constructed_analogs_kws = {"n_analogs": 10, "doy_range": 45}
constructed_analogs_kws.update(
    {} if not constructed_analogs_kwargs else constructed_analogs_kwargs
)

downscaled_gcm = maca_constructed_analogs(
    ds_gcm=bc_ea_gcm[variables],
    ds_obs_coarse=coarse_obs[variables],
    ds_obs_fine=obs[variables],
    **constructed_analogs_kws
)

In [None]:
downscaled_gcm.isel(time=slice(0, 10)).plot(col="time", col_wrap=5)

## epoch replacement


In [None]:
from cmip6_downscaling.workflows.maca_flow import maca_epoch_replacement

In [None]:
downscaled_bc_gcm = maca_epoch_replacement(
    ds_gcm_fine=downscaled_gcm,
    trend_coarse=trend,
)

In [None]:
downscaled_bc_gcm.isel(time=slice(0, 10)).plot(col="time", col_wrap=5)

## fine scale bias correction


In [None]:
final_gcm = maca_bias_correction(
    ds_gcm=downscaled_bc_gcm,
    ds_obs=obs,
    historical_period=historical_period,
    variables=variables,
    **bias_correction_kws
)