In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext watermark

import os

os.environ["PREFECT__FLOWS__CHECKPOINTING"] = "True"

import fsspec
import pandas as pd
import xarray as xr
import numpy as np
import cmip6_downscaling
import numpy as np
from xarray_schema import DataArraySchema, DatasetSchema

from cmip6_downscaling.data.observations import open_era5
from cmip6_downscaling.data.cmip import get_gcm, load_cmip
from dask.distributed import Client, LocalCluster
import warnings

warnings.filterwarnings("ignore")
import pytest
from cmip6_downscaling.analysis import analysis, metrics
from cmip6_downscaling.analysis.analysis import (
    qaqc_checks,
    grab_top_city_data,
    load_top_cities,
    get_seasonal,
    change_ds,
)
from cmip6_downscaling.analysis.plot import (
    plot_cdfs,
    plot_values_and_difference,
    plot_seasonal,
)  # , plot_each_step_bcsd
from cmip6_downscaling.analysis.qaqc import make_qaqc_ds
from cmip6_downscaling.methods.common.containers import BBox
from cmip6_downscaling.methods.common.utils import validate_zarr_store
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from carbonplan import styles
from cmip6_downscaling import config
from upath import UPath
import json

styles.mpl.set_theme(style='carbonplan_dark')

In [None]:
watermark -d -n -t -u -v -p cmip6_downscaling -h -m -g -r -b

In [None]:
# Start a local Dask cluster and attach a client so the lazy xarray work below runs on it.
# NOTE(review): Client/LocalCluster are already imported in the first cell; this re-import is redundant.
from dask.distributed import Client, LocalCluster

cluster = LocalCluster()
client = Client(cluster)
# Ask for 32 workers — TODO confirm this fits the machine the notebook runs on.
cluster.scale(32)

In [None]:
# Display the Dask client summary.
client

In [None]:
def open_store(d, dataset_nickname: str, chunking_method: str = 'full_time', **open_kwargs):
    """Open one of a flow run's Zarr stores as an xarray Dataset.

    The store is resolved in two steps. First, the downscaling method
    (``d['parameters']['method']``) and ``chunking_method`` select a table that
    maps ``dataset_nickname`` to a key of ``d['datasets']``. Then the path
    stored under that key is opened with ``xr.open_zarr``.

    Parameters
    ----------
    d : dict
        Parsed run metadata. Must contain ``'datasets'`` (store key -> path)
        and ``'parameters'`` with a ``'method'`` entry (``'bcsd'`` or ``'gard'``).
    dataset_nickname : str
        Short dataset name: ``'obs'``, ``'gcm_train'``, ``'gcm_predict'``,
        ``'output_daily'``, ``'output_monthly'`` or ``'output_annual'``.
        Only ``'obs'`` is available for ``'unchunked'``.
    chunking_method : str, default 'full_time'
        One of ``'unchunked'``, ``'full_time'`` or ``'full_space'``.
    **open_kwargs
        Extra keyword arguments forwarded to ``xr.open_zarr``.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    KeyError
        If the method, chunking method or nickname is not in the lookup table,
        or the run metadata has no entry for the resolved store key. The
        message lists the valid choices.
    """
    # method -> chunking layout -> dataset nickname -> key in d['datasets']
    analysis_path_store_names = {
        'bcsd': {
            'unchunked': {'obs': 'obs_path'},
            'full_time': {
                'obs': 'obs_full_time_path',
                'gcm_train': 'experiment_train_full_time_path',
                'gcm_predict': 'experiment_predict_full_time_path',
                'output_daily': 'final_bcsd_full_time_path',
                'output_monthly': 'monthly_summary_path',
                'output_annual': 'annual_summary_path',
            },
            'full_space': {
                'obs': 'obs_full_space_path',
                'gcm_train': 'experiment_train_path',
                'gcm_predict': 'experiment_predict_path',
                'output_daily': 'final_bcsd_full_space_path',
                'output_monthly': 'monthly_summary_full_space_path',
                'output_annual': 'annual_summary_full_space_path',
            },
        },
        'gard': {
            'unchunked': {'obs': 'obs_path'},
            'full_time': {
                'obs': 'obs_full_time_path',
                'gcm_train': 'experiment_train_full_time_path',
                'gcm_predict': 'experiment_predict_path',
                'output_daily': 'model_output_path',
                'output_monthly': 'monthly_summary_path',
                'output_annual': 'annual_summary_path',
            },
            'full_space': {
                'obs': 'obs_full_space_path',
                'gcm_train': 'experiment_train_path',
                'gcm_predict': 'experiment_predict_path',
                'output_daily': 'full_space_model_output_path',
                'output_monthly': 'monthly_summary_full_space_path',
                'output_annual': 'annual_summary_full_space_path',
            },
        },
    }

    def _lookup(mapping, key, what):
        # Same as mapping[key], but the KeyError names the valid choices
        # instead of only echoing the bad key.
        try:
            return mapping[key]
        except KeyError:
            raise KeyError(
                f"unknown {what} {key!r}; expected one of {sorted(mapping)}"
            ) from None

    stores = d['datasets']
    parameters = d['parameters']
    downscaling_method = parameters['method']
    by_chunking = _lookup(analysis_path_store_names, downscaling_method, 'downscaling method')
    by_nickname = _lookup(
        by_chunking, chunking_method, f'chunking method for {downscaling_method!r}'
    )
    store_name = _lookup(
        by_nickname,
        dataset_nickname,
        f'dataset nickname for {downscaling_method!r}/{chunking_method!r}',
    )
    store_path = _lookup(stores, store_name, 'store key in run metadata')
    return xr.open_zarr(store_path, **open_kwargs)

In [None]:
# Azure Blob filesystem used to read the run metadata.
# If the environment variable is unset, connection_string is None.
connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
fs = fsspec.filesystem(
    'az',
    connection_string=connection_string,
)

In [None]:
# Metadata file describing the flow run to evaluate (GARD / ERA5 / CanESM5 / tasmax, per the path).
latest_json_path = 'flow-outputs/results/0.1.7/runs/gard_ERA5_CanESM5_r1i1p1f1_historical_tasmax_-90_90_-180_180_1981_2010_1950_2014/latest.json'
# NOTE(review): unused in this notebook; presumably an injected parameter — confirm whether it arrives as a string.
scheduler_address = None  # does it come in as string

In [None]:
# Run metadata: 'datasets' maps store keys to paths, 'parameters' holds the run settings.
d = json.loads(fs.cat(latest_json_path))

In [None]:
# Split the metadata into store paths and run settings.
stores, parameters = d['datasets'], d['parameters']

In [None]:
# Time windows (as slices for .sel) and the run's variable and downscaling method.
train_period, predict_period = (
    slice(*parameters[dates_key]) for dates_key in ('train_dates', 'predict_dates')
)
var, method = parameters['variable'], parameters['method']

# Check that final zarr stores are valid


In [None]:
# Validate every final output store written under flow-outputs/results.
result_paths = [p for p in stores.values() if 'flow-outputs/results' in p]
for path in result_paths:
    print(path)
    validate_zarr_store(path)

# Load in your data


Load in downscaled run


In [None]:
# Daily downscaled output, chunked along the full time axis.
output_daily = open_store(d, dataset_nickname='output_daily', chunking_method='full_time')

Load in observational dataset for evaluation below.


In [None]:
# Observational (ERA5) dataset used for evaluation, chunked along the full time axis.
obs = open_store(d, dataset_nickname='obs', chunking_method='full_time')

Load in raw GCM (not downscaled or bias-corrected).


In [None]:
# Raw GCM data. Only the BCSD path also loads a separate training-period GCM store.
gcm_predict = open_store(d, dataset_nickname='gcm_predict', chunking_method='full_time')
if method == 'bcsd':
    gcm_train = open_store(d, dataset_nickname='gcm_train', chunking_method='full_time')

# Let's first just look at the data at a location about which we have intuition

In [None]:
# Approximate Seattle coordinates; the cells below pick the nearest grid cell (method='nearest').
seattle = {'lat': 47.5, 'lon': -121.75}

In [None]:
# Full downscaled time series at the grid cell nearest Seattle.
output_daily.sel(**seattle, method='nearest')[var].plot()

In [None]:
# Monthly climatology at Seattle: downscaled output vs ERA5 on the same axes.
fig, axarr = plt.subplots()
for series_label, series_ds in (('downscaled', output_daily), ('ERA5', obs)):
    monthly_climatology = series_ds.sel(**seattle, method='nearest')[var].groupby('time.month').mean()
    monthly_climatology.plot(ax=axarr, label=series_label)
plt.legend()

# QAQC Routines

Search for NaNs and aphysical quantities. This step can take a while (so there is a flag to set it to `False`), but it is useful to run if you suspect something is wrong or have made major changes to the implementation.


In [None]:
# Set to True to run the (slow) QAQC checks in the next two cells.
run_qaqc = False

In [None]:
if run_qaqc:
    # Checks per variable: the two temperature variables share the same checks.
    checks = {
        temperature_var: ['nulls', 'aphysical_high_temp', 'aphysical_low_temp']
        for temperature_var in ('tasmax', 'tasmin')
    }
    checks['pr'] = ['nulls', 'aphysical_high_precip', 'aphysical_low_precip']
    selected_checks = checks[parameters['variable']]
    annual_qaqc_ts, qaqc_maps = qaqc_checks(output_daily, checks=selected_checks)

What years, if any, include QAQC issues? Where, if anywhere, do those QAQC issues happen?

In [None]:
if run_qaqc:
    # Only report when at least one check flagged something.
    if annual_qaqc_ts.sum().values.any():
        print(annual_qaqc_ts)
        # Use a dedicated loop name: looping with `var` would overwrite the
        # global `var` (the run's variable) that every later cell relies on.
        for qaqc_var in qaqc_maps:
            if qaqc_maps[qaqc_var].sum().values.any():
                qaqc_maps[qaqc_var].plot(col_wrap=1, col="qaqc_check")

# Evaluation over training period

How closely do the statistics of the downscaled GCM data match the observations? In other
words, did the model perform as expected, and are there any other side effects?


First, let's look at some individual locations around the world (we'll look at a
sampling of the biggest cities). We'll compare the statistics of the historical
downscaled data with the observational dataset and see how well they match. Depending
on the metric, they should align very well (if that metric was used in training),
but other metrics may still differ.


Load in the training dataset (ERA5)


We'll do our analyses across the 100 biggest cities (all in different countries
so as to provide some geographic diversity). We also add a few cities in the Western US
to cover that part of the world.


In [None]:
# The 100 biggest cities plus the additional Western US ones; plot=True presumably maps them — confirm.
top_cities = load_top_cities(num_cities=100, add_additional_cities=True, plot=True)

In [None]:
# Pick the raw (not downscaled) GCM to compare against for each method.
if method == 'bcsd':
    raw_gcm = gcm_train
elif method == 'gard':
    # For a historical run, gcm_predict serves as the raw GCM.
    raw_gcm = gcm_predict
else:
    # Fail loudly instead of leaving obs_cities/downscaled_cities/gcm_cities undefined.
    raise ValueError(f"Unsupported downscaling method: {method!r}")

[obs_cities, downscaled_cities, gcm_cities] = grab_top_city_data(
    [obs[var], output_daily[var], raw_gcm[var]], top_cities
)

In [None]:
# CDFs at each city for both the training and prediction periods.
# NOTE(review): positional call. The keyword call in the next cell suggests the order is
# obs, top_cities, train_period, predict_period, historical/future downscaled,
# historical/future GCM — confirm against plot_cdfs' signature.
plot_cdfs(
    obs_cities,
    top_cities,
    train_period,
    predict_period,
    downscaled_cities.sel(time=train_period),
    downscaled_cities.sel(time=predict_period),
    gcm_cities.sel(time=train_period),
    gcm_cities.sel(time=predict_period),
    sharex=False,
)

In [None]:
# Training-period-only CDFs at each city, sharing the x axis across panels.
plot_cdfs(
    obs=obs_cities,
    top_cities=top_cities,
    historical_downscaled=downscaled_cities.sel(time=train_period),
    historical_gcm=gcm_cities.sel(time=train_period),
    future_downscaled=None,
    future_gcm=None,
    train_period=train_period,
    predict_period=predict_period,
    sharex=True,
)

# Performance of standard statistics


In [None]:
# some sample regions to see finer scale details than global
regions = {
    'US': {'lat': slice(25, 50), 'lon': slice(-120, -70)},
    'tiny central US': {'lat': slice(35, 40), 'lon': slice(-100, -90)},
    'Brazil': {'lat': slice(-30, 10), 'lon': slice(-70, -30)},
}

In [None]:
# Training-period summary statistics (per grid cell) for observations and downscaled output.
obs_train = obs.sel(time=train_period)
downscaled_train = output_daily.sel(time=train_period)
metric_dict = {'obs': {}, 'downscaled': {}}
for metric in ["mean", "std", 'percentile99', 'percentile1']:
    for source_name, source_ds in (('obs', obs_train), ('downscaled', downscaled_train)):
        metric_dict[source_name][metric] = metrics.metric_calc(source_ds, metric)[var].compute()

In [None]:
# Colorbar ranges for the value panels, per variable and metric.
# NOTE(review): only 'tasmax' and 'pr' are configured; a 'tasmin' run will KeyError in the plotting cells.
var_limits = {
    'tasmax': {
        # presumably Kelvin, matching the 273.15 offsets used for the threshold metrics — confirm
        'general': {
            'mean': (280, 300),
            'std': (0, 20),
            'percentile1': (250, 280),
            'percentile99': (290, 320),
        },
        # ranges for the days-over-threshold count maps
        'over 30c': (0, 365),
        'over 40c': (0, 50),
    },
    'pr': {
        # TODO confirm the precipitation units these ranges assume
        'general': {
            'mean': (0, 10),
            'std': (0, 5),
            'percentile1': (0, 0.1),
            'percentile99': (10, 25),
        }
    },
}
# Symmetric limits for the difference panels; interpreted per diff_method below.
diff_limits = {'tasmax': {'overall': 5, 'over 30c': 50, 'over 40c': 25}, 'pr': {'overall': 50}}
# Temperature differences are shown as absolute values, precipitation as percent.
diff_method = {'tasmax': 'absolute', 'pr': 'percent'}

In [None]:
# Global maps: observed vs downscaled training-period statistics and their difference.
for metric in ["mean", "std", 'percentile99', 'percentile1']:
    plot_values_and_difference(
        metric_dict['obs'][metric],
        metric_dict['downscaled'][metric],
        title1=f"Observed {metric} {var}",
        title2=f"Downscaled {metric} {var}",
        # The reference panel is observations, not the raw GCM.
        title3="Difference downscaled-obs",
        variable=var,
        metric=metric,
        var_limits=var_limits[var]['general'][metric],
        diff_limit=diff_limits[var]['overall'],
        diff_method=diff_method[var],
    )

In [None]:
# Same comparison zoomed into one region; the region name is added to the titles.
region_name = 'US'
for metric in ["mean", "std", 'percentile99', 'percentile1']:
    plot_values_and_difference(
        metric_dict['obs'][metric].sel(**regions[region_name]),
        metric_dict['downscaled'][metric].sel(**regions[region_name]),
        title1=f"Observed {metric} {var} ({region_name})",
        title2=f"Downscaled {metric} {var} ({region_name})",
        # The reference panel is observations, not the raw GCM.
        title3="Difference downscaled-obs",
        variable=var,
        metric=metric,
        var_limits=var_limits[var]['general'][metric],
        diff_limit=diff_limits[var]['overall'],
        diff_method=diff_method[var],
    )

In [None]:
# Same comparison zoomed into one region; the region name is added to the titles.
region_name = 'tiny central US'
for metric in ["mean", "std", 'percentile99', 'percentile1']:
    plot_values_and_difference(
        metric_dict['obs'][metric].sel(**regions[region_name]),
        metric_dict['downscaled'][metric].sel(**regions[region_name]),
        title1=f"Observed {metric} {var} ({region_name})",
        title2=f"Downscaled {metric} {var} ({region_name})",
        # The reference panel is observations, not the raw GCM.
        title3="Difference downscaled-obs",
        variable=var,
        metric=metric,
        var_limits=var_limits[var]['general'][metric],
        diff_limit=diff_limits[var]['overall'],
        diff_method=diff_method[var],
    )

In [None]:
# Same comparison zoomed into one region; the region name is added to the titles.
region_name = 'Brazil'
for metric in ["mean", "std", 'percentile99', 'percentile1']:
    plot_values_and_difference(
        metric_dict['obs'][metric].sel(**regions[region_name]),
        metric_dict['downscaled'][metric].sel(**regions[region_name]),
        title1=f"Observed {metric} {var} ({region_name})",
        title2=f"Downscaled {metric} {var} ({region_name})",
        # The reference panel is observations, not the raw GCM.
        title3="Difference downscaled-obs",
        variable=var,
        metric=metric,
        var_limits=var_limits[var]['general'][metric],
        diff_limit=diff_limits[var]['overall'],
        diff_method=diff_method[var],
    )

# Hot days

Calculate the average number of hot days (over 30°C) in the observations and the
downscaled model.


In [None]:
if var == "tasmax":
    # 30C threshold expressed in Kelvin.
    hot_threshold_k = 273.15 + 30
    # Restrict both datasets to the training period so the counts are comparable.
    days_over_30c_era5 = metrics.days_temperature_threshold(
        obs.sel(time=train_period), "over", hot_threshold_k
    ).compute()
    days_over_30c_ds = metrics.days_temperature_threshold(
        output_daily.sel(time=train_period), "over", hot_threshold_k
    ).compute()

In [None]:
if var == "tasmax":
    # Titled like the 40C comparison so the panels can be told apart.
    plot_values_and_difference(
        days_over_30c_era5["tasmax"],
        days_over_30c_ds["tasmax"],
        title1="Observed days per year over 30C",
        title2="Downscaled days per year over 30C",
        cbar_kwargs={"label": "Days over 30C"},
        var_limits=var_limits[var]['over 30c'],
        diff_limit=diff_limits[var]['over 30c'],
        variable=var,
        metric='daysover30',
    )

# Very hot days

Performance: the average number of very hot days (over 40°C).


In [None]:
if var == "tasmax":
    # 40C in Kelvin, using the same 273.15 offset as the 30C cell (was 273 + 40).
    very_hot_threshold_k = 273.15 + 40
    # Restrict both datasets to the training period so the counts are comparable.
    days_over_40c_era5 = metrics.days_temperature_threshold(
        obs.sel(time=train_period), "over", very_hot_threshold_k
    ).compute()
    days_over_40c_ds = metrics.days_temperature_threshold(
        output_daily.sel(time=train_period), "over", very_hot_threshold_k
    ).compute()

In [None]:
if var == "tasmax":
    # Observed vs downscaled yearly count of days over 40C.
    plot_values_and_difference(
        days_over_40c_era5["tasmax"],
        days_over_40c_ds["tasmax"],
        variable=var,
        metric='daysover40',
        var_limits=var_limits[var]['over 40c'],
        diff_limit=diff_limits[var]['over 40c'],
        title1="Observed days per year over 40C",
        title2="Downscaled days per year over 40C",
        cbar_kwargs={"label": "Days over 40C"},
    )

In [None]:
if var == "tasmax":
    # US-only view of the days-over-40C comparison. The data arguments are
    # required: calling plot_values_and_difference without them raises TypeError.
    plot_values_and_difference(
        days_over_40c_era5["tasmax"].sel(**regions['US']),
        days_over_40c_ds["tasmax"].sel(**regions['US']),
        title1="Observed days per year over 40C (US)",
        title2="Downscaled days per year over 40C (US)",
        cbar_kwargs={"label": "Days over 40C"},
        variable=var,
        metric='daysover40',
        var_limits=var_limits[var]['over 40c'],
        diff_limit=diff_limits[var]['over 40c'],
    )

# Is the change seen in the downscaled dataset the same as the change projected by the raw GCM?


In [None]:
# Set to True to compare the projected change (prediction vs. training period)
# in the downscaled data against the change in the raw GCM.
change_analyses = False
if change_analyses:
    # predict_period is the prediction window taken from the run parameters.
    downscaled_change_cities = change_ds(
        downscaled_cities.sel(time=train_period),
        downscaled_cities.sel(time=predict_period),
    )
    gcm_change_cities = change_ds(
        gcm_cities.sel(time=train_period), gcm_cities.sel(time=predict_period)
    )

In [None]:
if change_analyses:
    # One comparison per metric returned by change_ds: GCM change vs downscaled change at each city.
    for metric in gcm_change_cities:
        plot_values_and_difference(
            gcm_change_cities[metric],
            downscaled_change_cities[metric],
            # NOTE(review): the whole gcm_change_cities object is passed as a third positional
            # argument; other calls pass nothing here — confirm against plot_values_and_difference.
            gcm_change_cities,
            city_coords=obs_cities,
            title1="GCM change in {}".format(metric),
            title2="Downscaled change in {}".format(metric),
            title3="Difference downscaled-GCM",
            variable=var,
            metric=metric,
        )

# Seasonal statistics


### Assess the mean value over the season


In [None]:
aggregator = "mean"
# Restrict both datasets to the training period so the seasonal statistics are comparable.
obs_seasonal = get_seasonal(obs.sel(time=train_period), aggregator=aggregator)
downscaled_seasonal = get_seasonal(output_daily.sel(time=train_period), aggregator=aggregator)

In [None]:
# Colormap for the difference panel, per variable. The temperature entries use the '_r'
# variant, presumably reversing the map so temperature and precipitation biases get opposite colors — confirm.
cmap_diff_dict = {
    'pr': 'orangeblue_light',
    'tasmax': 'orangeblue_light_r',
    'tasmin': 'orangeblue_light_r',
}

In [None]:
# Seasonal means: observed vs downscaled, with the variable-specific colormap for the difference.
plot_seasonal(obs_seasonal[var], downscaled_seasonal[var], cmap_diff=cmap_diff_dict[var])

### Assess the max value over the season


In [None]:
aggregator = "max"
# Restrict both datasets to the training period so the seasonal statistics are comparable.
obs_seasonal = get_seasonal(obs.sel(time=train_period), aggregator=aggregator)
downscaled_seasonal = get_seasonal(output_daily.sel(time=train_period), aggregator=aggregator)
# Use the same per-variable difference colormap as the seasonal-mean plot.
plot_seasonal(obs_seasonal[var], downscaled_seasonal[var], cmap_diff=cmap_diff_dict[var])

### Assess the variability over the season


In [None]:
aggregator = "std"
# Restrict both datasets to the training period so the seasonal statistics are comparable.
obs_seasonal = get_seasonal(obs.sel(time=train_period), aggregator=aggregator)
downscaled_seasonal = get_seasonal(output_daily.sel(time=train_period), aggregator=aggregator)
# Use the same per-variable difference colormap as the seasonal-mean plot.
plot_seasonal(obs_seasonal[var], downscaled_seasonal[var], cmap_diff=cmap_diff_dict[var])

# Precipitation-specific metrics


If the variable is precipitation, calculate the precipitation indices across seasons and plot the
same seasonal comparison maps. The metrics are taken from Wilby (1998).


### If it was wet, how wet was it?


In [None]:
if var == "pr":
    # Wet-day amount statistics over the training period for both datasets.
    obs_train = obs.sel(time=train_period)
    downscaled_train = output_daily.sel(time=train_period)
    wet_day_dict = {'obs': {}, 'downscaled': {}}
    for metric in ["mean", "median", "std", "percentile95"]:
        wet_day_dict['obs'][metric] = metrics.wet_day_amount(obs_train, method=metric)['pr'].compute()
        wet_day_dict['downscaled'][metric] = metrics.wet_day_amount(
            downscaled_train, method=metric
        )['pr'].compute()

In [None]:
if var == "pr":
    # Shared styling: percent differences with a fixed diverging colormap.
    wet_day_plot_kwargs = dict(
        diff_limit=50,
        diff_method='percent',
        cmap_diff='orangeblue_light',
        variable=var,
    )
    for metric in ["mean", "median", "std", "percentile95"]:
        plot_values_and_difference(
            wet_day_dict['obs'][metric],
            wet_day_dict['downscaled'][metric],
            cbar_kwargs={"label": f"{metric} wet day amount"},
            **wet_day_plot_kwargs,
        )

In [None]:
if var == "pr":
    brazil = regions['Brazil']
    # Shared styling: percent differences with a fixed diverging colormap.
    wet_day_plot_kwargs = dict(
        diff_limit=50,
        diff_method='percent',
        cmap_diff='orangeblue_light',
        variable=var,
    )
    for metric in ["mean", "median", "std", "percentile95"]:
        plot_values_and_difference(
            wet_day_dict['obs'][metric].sel(**brazil),
            wet_day_dict['downscaled'][metric].sel(**brazil),
            cbar_kwargs={"label": f"{metric} wet day amount"},
            **wet_day_plot_kwargs,
        )

In [None]:
if var == "pr":
    # Small 5x5-degree box for a close-up view of the wet-day amounts.
    small_box = {'lat': slice(30, 35), 'lon': slice(-105, -100)}
    for metric in ["mean", "median", "std", "percentile95"]:
        plot_values_and_difference(
            wet_day_dict['obs'][metric].sel(**small_box),
            wet_day_dict['downscaled'][metric].sel(**small_box),
            cbar_kwargs={"label": f"{metric} wet day amount"},
            diff_limit=50,
            diff_method='percent',
            cmap_diff='orangeblue_light',
            # Passed like in the other wet-day cells.
            variable=var,
        )

Calculate boolean masks of wet and dry days (they are inverses of each other)
based on a threshold. We'll then use these to compute a variety of statistics.


In [None]:
if var == "pr":
    # Boolean wet-day masks over the training period (threshold 0.0); dry days are the complement.
    wet_days_obs = metrics.is_wet_day(obs.sel(time=train_period), threshold=0.0).compute()
    dry_days_obs = ~wet_days_obs
    wet_days_downscaled = metrics.is_wet_day(
        output_daily.sel(time=train_period), threshold=0.0
    ).compute()
    dry_days_downscaled = ~wet_days_downscaled

In [None]:
# The wet-day masks only exist for precipitation runs; without this guard the
# cell raises NameError for temperature variables.
if var == "pr":
    # Average number of wet days per year.
    wet_day_obs_count = wet_days_obs.groupby('time.year').sum().mean(dim='year').compute()
    wet_day_downscaled_count = (
        wet_days_downscaled.groupby('time.year').sum().mean(dim='year').compute()
    )

In [None]:
# Precipitation-only: the wet-day counts are computed only for pr runs.
if var == "pr":
    plot_values_and_difference(
        wet_day_obs_count.pr,
        wet_day_downscaled_count.pr,
        cbar_kwargs={"label": "wet day count"},
        diff_limit=200,
        var_limits=(0, 350),
        diff_method='absolute',
        cmap_diff='orangeblue_light',
        variable=var,
    )

In [None]:
# Precipitation-only: the wet-day counts are computed only for pr runs.
if var == "pr":
    plot_values_and_difference(
        wet_day_obs_count.pr.sel(**regions['Brazil']),
        wet_day_downscaled_count.pr.sel(**regions['Brazil']),
        cbar_kwargs={"label": "wet day count"},
        var_limits=(0, 350),
        diff_limit=100,
        diff_method='absolute',
        cmap_diff='orangeblue_light',
        variable=var,
    )

### Length and variability of wet spells


In [None]:
if var == "pr":
    # Wet-spell length statistics: ERA5 vs downscaled.
    # NOTE(review): `apply_spell_length` is not imported or defined in the visible notebook, so this
    # raises NameError as written; presumably it lives in cmip6_downscaling.analysis.metrics — confirm and import.
    for metric in ["mean", "std", "percentile90", "percentile99"]:
        plot_values_and_difference(
            apply_spell_length(wet_days_obs, metric),
            apply_spell_length(wet_days_downscaled, metric),
            title1="ERA5 wet spell length",
            title2="Downscaled wet spell length",
            cbar_kwargs={"label": f"{metric} days"},
            variable=var,
            metric='wet spell length',
            var_limits=(0, 5),
            diff_limit=5,
        )

### Length and variability of dry spells


In [None]:
if var == "pr":
    # Dry-spell length statistics: ERA5 vs downscaled.
    # NOTE(review): `apply_spell_length` is not imported or defined in the visible notebook, so this
    # raises NameError as written; presumably it lives in cmip6_downscaling.analysis.metrics — confirm and import.
    for metric in ["mean", "std", "percentile90", "percentile99"]:
        plot_values_and_difference(
            apply_spell_length(dry_days_obs, metric),
            apply_spell_length(dry_days_downscaled, metric),
            title1="ERA5 dry spell length",
            title2="Downscaled dry spell length",
            cbar_kwargs={"label": f"{metric} days"},
            variable=var,
            metric='dry spell length',
            var_limits=(0, 5),
            diff_limit=5,
        )

### If today was wet, what are odds tomorrow will be wet?


In [None]:
if var == "pr":
    # Probability that a wet day is followed by another wet day, over the training period.
    # `obs` / `output_daily` are this notebook's datasets (the undefined `obs_ds` / `ds` were stale names).
    plot_values_and_difference(
        metrics.probability_two_consecutive_days(obs.sel(time=train_period), kind_of_days="wet")["pr"],
        metrics.probability_two_consecutive_days(
            output_daily.sel(time=train_period), kind_of_days="wet"
        )["pr"],
        title1="ERA5",
        title2="Downscaled",
        cbar_kwargs={"label": "prob of sequential wet day"},
        variable=var,
        metric='probability',
        var_limits=(0, 1),
        diff_limit=0.5,
    )

### If today was dry, what are odds tomorrow will be dry?


In [None]:
if var == "pr":
    # Probability that a dry day is followed by another dry day, over the training period.
    # `obs` / `output_daily` are this notebook's datasets (the undefined `obs_ds` / `ds` were stale names).
    plot_values_and_difference(
        metrics.probability_two_consecutive_days(obs.sel(time=train_period), kind_of_days="dry")["pr"],
        metrics.probability_two_consecutive_days(
            output_daily.sel(time=train_period), kind_of_days="dry"
        )["pr"],
        title1="ERA5",
        title2="Downscaled",
        # Fixed label: an f-string here would pick up whatever `metric` an earlier loop left behind.
        cbar_kwargs={"label": "prob of sequential dry day"},
        variable=var,
        metric='probability',
        var_limits=(0, 1),
        diff_limit=0.5,
    )

### Low frequency variability - standard deviation of monthly precipitation


In [None]:
if var == "pr":
    # Standard deviation of monthly precipitation totals over the training period.
    # `obs` / `output_daily` are this notebook's datasets (the undefined `obs_ds` / `ds` were stale names).
    # NOTE(review): `monthly_variability` is not imported or defined in the visible notebook —
    # confirm where it lives and import it, otherwise this raises NameError.
    plot_values_and_difference(
        monthly_variability(obs.sel(time=train_period).pr, method="sum"),
        monthly_variability(output_daily.sel(time=train_period).pr, method="sum"),
        title1="ERA5 monthly stdev",
        title2="Downscaled monthly stdev",
        cbar_kwargs={"label": "mm"},
        variable=var,
        metric='stdev monthly precip',
        var_limits=(0, 50),
        diff_limit=10,
    )