# Bias correct model output with respect to observations

In this notebook, we go over the steps required to bias correct model data to a weather station, <b>thereby providing localized data which is appropriate for making projections at a given station</b>. We bias correct with a method called <i>quantile delta mapping</i> (QDM). QDM is chosen here because it preserves the model-projected change in each quantile of the distribution, whereas a correction applied only to the mean shifts every quantile by the same amount.
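For temperature, this notebook uses the additive form of QDM (`kind="+"` in Step 4). Written out for a single raw model value $x$ from the series being adjusted, with $F_m$ the empirical CDF of that series, $F_h$ the CDF of the historical model data, and $F_o$ the CDF of the observations, a standard statement of additive QDM is:

$$
\tau = F_m(x), \qquad
\Delta(\tau) = x - F_h^{-1}(\tau), \qquad
\hat{x} = F_o^{-1}(\tau) + \Delta(\tau) = x + \left[F_o^{-1}(\tau) - F_h^{-1}(\tau)\right].
$$

Here $\Delta(\tau)$ is the model's own change at quantile $\tau$ relative to its historical climate, and it is carried into the corrected value $\hat{x}$ unchanged; the bracketed term is the adjustment factor for that quantile. A mean-only correction, $\hat{x} = x + (\bar{o} - \bar{h})$, would instead shift every quantile by the same amount. This is the textbook formulation; xclim's implementation adds details (quantile interpolation, seasonal grouping) that are not shown here.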

**Intended Application**: As a user, I want to **<span style="color:#FF0000">understand how projections data is localized to a weather station</span>**.

**Runtime**: With the default settings, this notebook takes approximately **2 minutes** to run from start to finish. Modifications to selections may increase the runtime. 

## Step 0: Set-up
Load the libraries needed for data retrieval, bias correction, and plotting.

In [1]:
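import climakitae as ck
import climakitaegui as ckg
import xarray as xr
import pandas as pd
import panel as pn
import hvplot.xarray  # noqa: F401 (registers the .hvplot accessor used for the plots below)
pn.extension()

from climakitae.util.utils import read_csv_file

# convert_calendar is not available from xclim.core.calendar in recent xclim releases;
# calendar conversion is done with xarray's built-in .convert_calendar() method instead
from xclim.core.units import convert_units_to
from xclim.sdba.adjustment import QuantileDeltaMapping
from xclim.sdba import Grouper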

## Step 1: Read in the station data
Open the HadISD dataset for the Sacramento Executive Airport (KSAC) station, convert its air temperature to degF, switch to a 365-day (no-leap) calendar, and load it into memory.
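Converting to a `"noleap"` calendar removes every 29 February, so each year has exactly 365 × 24 = 8760 hourly values. xarray's `convert_calendar` also switches the time coordinate to cftime objects; the pandas sketch below only mimics the dropping step on a made-up hourly series, to show which timestamps disappear.

```python
import numpy as np
import pandas as pd

# Synthetic hourly timestamps covering the 2012 leap day (standard Gregorian calendar)
times = pd.date_range("2012-02-28", "2012-03-01 23:00", freq="h")
tas = pd.Series(np.arange(times.size, dtype=float), index=times)

# Mimic a "noleap" conversion: keep every timestamp except those on 29 February
is_leap_day = (tas.index.month == 2) & (tas.index.day == 29)
tas_noleap = tas[~is_leap_day]

print(len(tas), len(tas_noleap))  # 72 48
```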

In [2]:
# Import stations names and coordinates file
stations = "data/hadisd_stations.csv"
stations_df = read_csv_file(stations)
my_station = 'KSAC'
station_id = str(stations_df[stations_df['ID'] == my_station]['station id'].values[0])

filepaths = [
    "s3://cadcat/hadisd/HadISD_{}.zarr".format(s_id)
    for s_id in [station_id]
]

obs_ds = xr.open_mfdataset(
    filepaths,
    engine="zarr",
    consolidated=False,
    parallel=True,
    backend_kwargs=dict(storage_options={"anon": True}),
)

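obs_ds = obs_ds.tas
obs_ds = convert_units_to(obs_ds, "degF")
# drop leap days (Feb 29) using xarray's calendar conversion
obs_ds = obs_ds.convert_calendar("noleap")
# load the full record as a single chunk along time; the bias correction needs time unchunked
obs_ds = obs_ds.chunk(dict(time=-1)).compute()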

# extract coordinates
lat0 = obs_ds.latitude.values
lon0 = obs_ds.longitude.values

In [None]:
obs_ds

In [None]:
# keep 1981-2014, the observation period used to train the bias correction
obs_ds = obs_ds.sel(time=slice("1981", "2014"))

## Step 2: Read in the model output
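Retrieve hourly 3 km WRF air temperature at 2 m for Sacramento County, with the historical run appended to SSP 3-7.0 to cover 1981-2060. We specifically keep only the CESM2-driven simulation because CESM2 is known to have a strong warm bias.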

In [None]:
selections = ckg.Select()

In [None]:
selections.scenario_historical=['Historical Climate']
selections.scenario_ssp=['SSP 3-7.0']
selections.append_historical = True
selections.area_average = 'No'
selections.time_slice = (1981, 2060)
selections.resolution = '3 km'
selections.timescale = 'hourly'
selections.variable = 'Air Temperature at 2m'
selections.units = 'degF'
selections.area_subset = 'CA counties'
selections.cached_area = ['Sacramento County']

wrf_ds = selections.retrieve()

# Select just the CESM2 simulation
cesm_sim_name = [sim for sim in wrf_ds.simulation.values if "cesm2" in sim.lower()]
wrf_ds = wrf_ds.sel(simulation=cesm_sim_name).squeeze()

Convert the model data to the same no-leap calendar as the observations, extract the WRF grid cell closest to the station, and load it into memory.
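The idea behind picking the closest grid cell can be sketched in plain NumPy. WRF grids are curvilinear, so latitude and longitude are 2D arrays; the sketch computes the great-circle distance from the station to every cell centre and takes the minimum. `nearest_cell` is a hypothetical helper for illustration, not climakitae's `get_closest_gridcell`, whose internals may differ, and the toy grid and station coordinates are made up.

```python
import numpy as np

def nearest_cell(lat2d, lon2d, lat0, lon0):
    """Return (row, col) of the cell centre closest to (lat0, lon0).

    Hypothetical helper: uses the haversine (great-circle) distance.
    """
    lat1, lon1 = np.radians(lat2d), np.radians(lon2d)
    lat2, lon2 = np.radians(lat0), np.radians(lon0)
    a = (np.sin((lat1 - lat2) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon1 - lon2) / 2) ** 2)
    dist = 2 * np.arcsin(np.sqrt(a))  # central angle in radians
    # argmin works on the flattened array, so map it back to 2D indices
    return np.unravel_index(np.argmin(dist), dist.shape)

# Toy 2D grid: rows vary in latitude, columns vary in longitude
lon2d, lat2d = np.meshgrid(np.linspace(-122.0, -121.0, 11), np.linspace(38.0, 39.0, 11))
row, col = nearest_cell(lat2d, lon2d, 38.51, -121.49)
print(row, col, lat2d[row, col], lon2d[row, col])
```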

In [None]:
from climakitae.util.utils import get_closest_gridcell

# convert to the same no-leap calendar as the observations
wrf_ds = wrf_ds.convert_calendar("noleap")
# extract the grid cell closest to the station
wrf_ds = get_closest_gridcell(wrf_ds, lat0, lon0)
# load the full record as a single chunk along time; the bias correction needs time unchunked
wrf_ds = wrf_ds.chunk(dict(time=-1)).compute()
# store the variable name, then rename the array for plotting ease later
wrf_ds.attrs['physical_variable'] = wrf_ds.name
wrf_ds.name = 'Raw'

## Step 3: Inspect the model data and observations
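Here we show the record-mean daily-mean temperature (for each day of the year, the average of all hourly values on that day across every year in the period) for the observations and the raw WRF data, to get a sense of the bias in the WRF model. The historical period matches the observations (1981-2014), and the projected period is the 30 years ending in 2060 (2031-2060).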
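As a sketch of what "record-mean daily-mean" means, the pandas example below groups two synthetic, non-leap years of hourly temperatures by day of year and averages them, which is the same operation as `groupby('time.dayofyear').mean()` in `group_ds`. The data are invented for illustration.

```python
import numpy as np
import pandas as pd

# Two years of synthetic hourly temperatures (2013 and 2014 are not leap years)
times = pd.date_range("2013-01-01", "2014-12-31 23:00", freq="h")
doy = times.dayofyear.to_numpy()
# Seasonal cycle, plus a +1 degF offset in the second year
tas = 60 + 15 * np.sin(2 * np.pi * (doy - 100) / 365) + (times.year == 2014)
series = pd.Series(tas, index=times)

# Average every hour that falls on a given day of year, across all years in the record
record_mean = series.groupby(series.index.dayofyear).mean()
print(record_mean.shape)  # (365,)
print(record_mean.loc[100])  # 60.5: mean of 60 (2013) and 61 (2014)
```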

In [None]:
def group_ds(ds, obs_ds=obs_ds, projected_ceil=2060):
    
    proj_floor = str(projected_ceil-29)
    proj_ceil = str(projected_ceil)
    
    hist_ds = ds.sel(time=slice(str(obs_ds.time.values[0]),
            str(obs_ds.time.values[-1]))).groupby(
            'time.dayofyear').mean()    
    ssp_ds = ds.sel(time=slice(proj_floor,proj_ceil)).groupby(
            'time.dayofyear').mean()    
    obs_ds = obs_ds.groupby(
        'time.dayofyear').mean()
    
    hist_ds = hist_ds.rename(dict(dayofyear = 'Day of Year'))
    ssp_ds = ssp_ds.rename(dict(dayofyear = 'Day of Year'))
    obs_ds = obs_ds.rename(dict(dayofyear = 'Day of Year'))
    
    return hist_ds, ssp_ds, obs_ds

def compare_raw_and_obs(ds, obs_ds=obs_ds, ylim=(None,None), 
                        width=700, height=300): 
    
    hist_gp, ssp_gp, obs_gp = group_ds(ds, obs_ds)
    
    hist_pl = hist_gp.hvplot(label='Historical raw',c='royalblue')    
    ssp_pl = ssp_gp.hvplot(label='Projected raw',c='goldenrod')    
    obs_pl = obs_gp.hvplot(label='Observations',c='k')
   
    pl = obs_pl * hist_pl * ssp_pl
    pl.opts(ylim=ylim, width=width, height=height, legend_position='right',
           toolbar='below', ylabel=obs_ds.units,title='Record-mean Daily-mean '
            +ds.attrs['physical_variable'])
    
    return pl

compare_raw_and_obs(wrf_ds, obs_ds=obs_ds)

## Step 4: Perform the bias correction procedure
The next cells define and perform the two operations needed for bias correction:
1. Create the training set, which finds the adjustment factors between the observations and the raw model historical data. Note that the raw model output used for training needs to cover the same time period as the observations.
2. Apply these adjustment factors to the raw model historical and projected data.

All training and adjustment is performed on rolling 90-day windows centered on each day of the year (i.e., +/- 45 days) to allow for seasonal adjustment.
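The train-then-adjust steps above can be sketched for the additive case (`kind="+"`) in plain NumPy. This is a simplified illustration, not xclim's implementation: it omits the day-of-year grouping and moving window, places the quantile nodes at bin centres (an assumption), and extrapolates beyond the outermost nodes with a constant factor. `qdm_train` and `qdm_adjust` are hypothetical names, and the data are synthetic.

```python
import numpy as np

def qdm_train(obs, hist, nquantiles=20):
    """Additive factors per quantile: observed quantile minus historical model quantile."""
    q = (np.arange(nquantiles) + 0.5) / nquantiles  # quantile nodes at bin centres
    af = np.quantile(obs, q) - np.quantile(hist, q)
    return q, af

def qdm_adjust(sim, q, af):
    """Rank each value within `sim`, then add the factor interpolated at that quantile."""
    ranks = np.argsort(np.argsort(sim))       # 0 = smallest value
    tau = (ranks + 0.5) / sim.size            # empirical non-exceedance probability
    return sim + np.interp(tau, q, af)        # constant beyond the outermost nodes

# Toy data (degF): the model runs 4 degrees warm in the historical period
rng = np.random.default_rng(0)
obs = rng.normal(60.0, 10.0, 5000)
hist = rng.normal(64.0, 10.0, 5000)
fut = rng.normal(68.0, 12.0, 5000)            # model future: warmer and more variable

q, af = qdm_train(obs, hist)
hist_adj = qdm_adjust(hist, q, af)
fut_adj = qdm_adjust(fut, q, af)

print(af.round(1))  # mostly negative, reflecting the warm bias
# Rank-by-rank projected change (future minus historical) survives the correction
raw_change = np.sort(fut) - np.sort(hist)
adj_change = fut_adj[np.argsort(fut)] - hist_adj[np.argsort(hist)]
print(np.allclose(raw_change, adj_change))  # True
```

Because the historical and future values at the same rank receive the same factor, the model's change at every quantile is carried through, which is the property that motivates QDM in this notebook.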

In [None]:
window = 90
def do_QDM(obs, ds, nquantiles=20, 
           group='time.dayofyear', window=window, 
           kind="+"):
    
    group = Grouper(group, window=window)

    ds.attrs['variable'] = ds.name
    ds.name = 'Raw' 
    
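    # train on the model years that overlap the observations
    QDM = QuantileDeltaMapping.train(obs, ds.sel(
        time=slice(str(obs.time.values[0]),
        str(obs.time.values[-1]))),
        nquantiles=nquantiles, group=group, kind=kind)

    # apply the adjustment factors to the full (historical + projected) series
    ds_adj = QDM.adjust(ds).compute()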
    
    QDM_ds = QDM.ds.rename(dict(
        dayofyear = 'Day of Year', 
        quantiles='Quantile'))    
    
    ds_adj.name = 'Adjusted' 
    ds_adj = xr.merge([ds, ds_adj])
    
    return QDM_ds,ds_adj

In [None]:
%%time
adj_factors, adj_ds = do_QDM(obs_ds,wrf_ds)

## Step 5: Visualize the bias correction results
### Step 5a: Inspect the raw historical WRF quantiles and the adjustment factor for each quantile

In [None]:
from bokeh.models import HoverTool
from climakitae.util.colormap import read_ae_colormap

raw_cmap = read_ae_colormap(cmap='ae_orange', cmap_hex=True)
af_cmap = read_ae_colormap(cmap='ae_diverging', cmap_hex=True)

hover_temp = HoverTool(description='Custom Tooltip', 
        tooltips=[('Quantile', '@Quantile'), 
        ('Day of Year', '@{Day_of_Year}'),
                 ('Temperature (degF)', '@hist_q')])
hover_adj = HoverTool(description='Custom Tooltip', 
        tooltips=[('Quantile', '@Quantile'), 
        ('Day of Year', '@{Day_of_Year}'),
                 ('Adjustment (degF)', '@af')])

raw_temp_qs = adj_factors['hist_q'].hvplot.quadmesh(
    x='Quantile',y='Day of Year',z='hist_q').opts(
    tools=[hover_temp], width=425, height=300,
    cmap=raw_cmap, clabel="degF", 
    title = "Temperature quantiles by day of year")
adjf_temp = adj_factors['af'].hvplot.quadmesh(
    x='Quantile',y='Day of Year',z='af').opts(
    tools=[hover_adj], width=425, height=300,
    cmap=af_cmap, clabel="degF",
    title="Adjustment factors for each quantile")

raw_and_af = raw_temp_qs + adjf_temp
raw_and_af.opts(title="Raw historical quantiles"
                + " and computed adjustment factors",
               toolbar='below')

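As expected, the adjustment factors here tend to be negative, which is consistent with the known warm bias in the CESM2 global climate model: with the additive method (`kind="+"`), each factor is the observed quantile minus the raw historical model quantile, so a model that runs too warm receives negative factors. Now we will compare the raw and adjusted results for the historical and projected model data.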

### Step 5b: Directly compare the raw and adjusted data

First we define a function to make comparisons easy:

In [None]:
def make_comparison_plot(hist_ds, ssp_ds, obs_ds=None, 
                         width=475, height=300, title="",ylim=(None,None)):
    
    y_str = hist_ds.physical_variable+' ('+hist_ds.attrs['units']+')'
    
    pl_historical = hist_ds.Raw.hvplot(
        color="royalblue", label='Historical '+hist_ds.Raw.name) 
    pl_historical *= hist_ds.Adjusted.hvplot(
        color="goldenrod", line_width=2,
        label='Historical '+hist_ds.Adjusted.name)    
    if obs_ds is not None:
        pl_historical *= obs_ds.hvplot(
            color='k',line_width=1, label="Observations")        
    pl_historical.opts(
        legend_position='top_left', width=width, 
        height=height, title=title, ylabel=y_str, ylim=ylim)
    
    pl_ssp = ssp_ds.Raw.hvplot(
        color="royalblue", label='Projected '+ssp_ds.Raw.name) 
    pl_ssp *= ssp_ds.Adjusted.hvplot(
        color="goldenrod", line_width=2,label='Projected ' 
        +ssp_ds.Adjusted.name)    
    pl_ssp.opts(
        legend_position='top_left', width=width, 
        height=height, title=title, ylabel=y_str,
        ylim=ylim)
    
    return pl_historical + pl_ssp

#### 1. Record-mean daily-mean raw and adjusted data.

In [None]:
hist_gp, ssp_gp, obs_gp = group_ds(adj_ds, obs_ds) 

make_comparison_plot(hist_gp, ssp_gp, obs_ds=obs_gp
                    ).opts(title="Record-mean Daily-mean "
                          + "Raw and Adjusted Output "
                          + "for Historical and Projected "
                          + "Time Periods")

#### 2. Annual mean data

In [None]:
def ann_gp(ds, obs_ds=obs_ds, projected_ceil=2060):
    
    proj_floor = str(projected_ceil-29)
    proj_ceil = str(projected_ceil)    
    hist_ds = ds.sel(time=slice(str(
            obs_ds.time.values[0]),
            str(obs_ds.time.values[-1]))).groupby(
            'time.year').mean()    
    ssp_ds = ds.sel(time=slice(proj_floor,
            proj_ceil)).groupby('time.year').mean()    
    obs_ds = obs_ds.groupby('time.year').mean()    
    hist_ds = hist_ds.rename(dict(year = 'Year'))
    ssp_ds = ssp_ds.rename(dict(year = 'Year'))
    obs_ds = obs_ds.rename(dict(year = 'Year'))    
    return hist_ds, ssp_ds, obs_ds

ann_hist, ann_ssp, ann_obs = ann_gp(adj_ds, obs_ds=obs_ds)
make_comparison_plot(ann_hist, ann_ssp, obs_ds=ann_obs,
                    ylim=(57,68)).opts(shared_axes=False, 
                    title="Annual Mean Raw and Adjusted Output "
                    + "for Historical and Projected Time Periods")

#### 3. Zoom in on the hourly time series by sampling some weather extremes
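Define a function that finds the hottest and coldest raw-model hours in the historical and projected periods and returns the surrounding +/- 72 hours of data.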
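The selection logic amounts to "take `window` samples on either side of the index of the extreme". A minimal NumPy sketch of that idea, using a hypothetical helper `window_around_extreme`, is below. The bounds are clipped so that an extreme within `window` samples of either end of the record still yields a valid, shorter slice; without clipping, a negative start would count from the end of the array and the slice would come back empty.

```python
import numpy as np

def window_around_extreme(values, window=72, use_max=True):
    """Return (start, stop) indices spanning +/- `window` samples around the max (or min).

    Hypothetical helper; stop is exclusive, matching Python slice semantics.
    """
    center = int(np.argmax(values) if use_max else np.argmin(values))
    start = max(center - window, 0)            # clip at the start of the record
    stop = min(center + window, len(values))   # clip at the end of the record
    return start, stop

hourly = np.zeros(1000)
hourly[10] = 105.0    # hottest hour sits near the start of the record
hourly[900] = -3.0    # coldest hour sits near the end
print(window_around_extreme(hourly, 72))          # (0, 82)
print(window_around_extreme(hourly, 72, False))   # (828, 972)
```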

In [None]:
def sel_extremes(ds, obs_ds, projected_ceil=2060, window=72):
    """Return +/- `window` hours around the raw-model minimum and maximum
    in the historical (observation) period and in the projected period."""

    def around(sub, idx):
        # clip the window to the record so the slice never wraps or comes back empty
        idx = int(idx)
        start = max(idx - window, 0)
        stop = min(idx + window, sub.sizes['time'])
        return sub.isel(time=slice(start, stop))

    proj_floor = str(projected_ceil-29)
    proj_ceil = str(projected_ceil)
    hist_ds = ds.sel(time=slice(str(
        obs_ds.time.values[0]),
        str(obs_ds.time.values[-1])))
    hist_max = around(hist_ds, hist_ds.Raw.argmax().values)
    hist_min = around(hist_ds, hist_ds.Raw.argmin().values)

    ssp_ds = ds.sel(time=slice(proj_floor, proj_ceil))
    ssp_max = around(ssp_ds, ssp_ds.Raw.argmax().values)
    ssp_min = around(ssp_ds, ssp_ds.Raw.argmin().values)

    return hist_min, hist_max, ssp_min, ssp_max

hist_min, hist_max, ssp_min, ssp_max = sel_extremes(adj_ds, obs_ds)

##### Plot record maxima for raw and adjusted historical and projected data
You will see a repeated warning regarding "non-standard calendar". Don't worry about this! It only means that leap days were removed from the data (the conversion to a `noleap` calendar in Steps 1 and 2) before bias correction. Removing leap days is standard practice in bias correction, so that every year has the same 365 days.

In [None]:
make_comparison_plot(hist_max, ssp_max, ylim=(55, 130),
                    ).opts(shared_axes=False, 
                    title="Maximum hourly temperature +/- 72 hours")

##### Plot record minima for raw and adjusted historical and projected data
You'll see the same "non-standard calendar" warning again - you can safely ignore this. 

In [None]:
make_comparison_plot(hist_min, ssp_min, ylim=(-5, 53),
                    ).opts(shared_axes=False, 
                    title="Minimum hourly temperature +/- 72 hours")