# Try to fit real hindcasts

In [1]:
# Switch the kernel's working directory to the notebook's folder so that the
# relative DATA_DIR path defined below resolves.
# NOTE(review): this is a hard-coded absolute path into one user's project
# space; the notebook will not run elsewhere without editing it.
%cd /g/data/xv83/users/ds0092/active_projects/Squire_2022_correlation/notebooks/exploratory

/g/data/xv83/users/ds0092/active_projects/Squire_2022_correlation/notebooks/exploratory


In [72]:
import xarray as xr

import numpy as np

import pandas as pd

from src import utils, data, stats, ar_model

import warnings

from statsmodels.tsa.api import VAR
from statsmodels.tsa.ar_model import AutoReg

import matplotlib.pyplot as plt

# Silence every warning for the rest of the session.
# NOTE(review): this is a blanket filter, so warnings raised by the libraries
# used below are hidden as well — consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

In [3]:
# Re-import edited modules automatically before each cell is executed, so
# changes to the project's `src` package are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Auto-format code cells with black.
%load_ext lab_black

In [4]:
# Folder containing the processed zarr stores opened below, given relative to
# the working directory.
DATA_DIR = "../../data/processed/"

# Develop/test code with some reanalysis data

The function I developed has since been copied into `src.ar_model`. Here we retain a few tests.

In [57]:
# --- AMV index from the HadISST store ---------------------------------------
HadISST = xr.open_zarr(f"{DATA_DIR}/tos_HadISST.zarr", use_cftime=True)
AMV = utils.calculate_period_AMV_index(HadISST["sst"], [12, 1, 2, 3])
AMV = AMV.rename("AMV").compute()
AMV = utils.round_to_start_of_month(AMV, dim="time")

# --- NAO index from the HadSLP2r store --------------------------------------
# NOTE(review): unlike AMV, NAO is not passed through round_to_start_of_month
# before the alignment below — presumably its timestamps already sit on the
# start of the month; confirm, otherwise the inner alignment drops times.
HadSLP = xr.open_zarr(f"{DATA_DIR}/psl_HadSLP2r.zarr", use_cftime=True)
NAO = utils.calculate_period_NAO_index(HadSLP["slp"], [12, 1, 2, 3])
NAO = NAO.rename("NAO").compute()

# Keep only the index labels that both series share
AMV, NAO = xr.align(AMV, NAO)

# Keep the first 148 times: an even count, so the record can be divided into
# two equal halves for the tests further down
reanalysis = xr.merge([AMV, NAO]).isel(time=slice(0, 148))

### Check my AR model fit

In [196]:
# Number of lags used by the fit checks in the cells that follow.
n_lags = 2

In [197]:
# Fit the project's model to AMV on its own (single-variable Dataset) with
# n_lags lags along "time".
my_params = ar_model.fit(reanalysis[["AMV"]], n_lags=n_lags, dim="time")
# Display the fitted parameters as a table for comparison with statsmodels.
my_params.to_dataframe()

Unnamed: 0_level_0,AMV,model_order
params,Unnamed: 1_level_1,Unnamed: 2_level_1
AMV.lag1,0.456688,2
AMV.lag2,0.254508,2
AMV.noise_var,0.011765,2


In [198]:
# Reference fit with statsmodels' AutoReg: same number of lags, no trend term.
amv_frame = reanalysis[["AMV"]].to_dataframe()
their_fit = AutoReg(amv_frame, lags=n_lags, trend="n").fit()

# Append the residual variance as an extra row so the table has the same
# three rows as the `ar_model.fit` output.
noise_var = pd.Series(their_fit.sigma2, index=["AMV.noise_var"])
their_params = pd.concat([their_fit.params, noise_var])
their_params.to_frame(name="AMV")

Unnamed: 0,AMV
AMV.L1,0.456688
AMV.L2,0.254508
AMV.noise_var,0.011765


### Check my VAR model fit

In [199]:
# Fit the project's model to both variables in `reanalysis` (AMV and NAO)
# jointly, with n_lags lags along "time".
my_params = ar_model.fit(reanalysis, n_lags=n_lags, dim="time")
# Display the fitted parameters as a table for comparison with statsmodels.
my_params.to_dataframe()

Unnamed: 0_level_0,AMV,NAO,model_order
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
AMV.lag1,0.451376,-1.992555,2
NAO.lag1,-0.000845,0.060381,2
AMV.lag2,0.259422,-9.765535,2
NAO.lag2,0.000801,-0.017539,2
AMV.noise_var,0.012051,-0.084668,2
NAO.noise_var,-0.084668,35.364779,2


In [200]:
# Reference fit with statsmodels' VAR: same number of lags, no trend term.
var_model = VAR(reanalysis.to_dataframe())
their_fit = var_model.fit(n_lags, trend="n")

# Stack the coefficient table on top of the residual covariance matrix, whose
# row labels are prefixed with "sigma_u." to keep the index unique.
noise_cov = their_fit.sigma_u.rename(index="sigma_u.{}".format)
their_params = pd.concat([their_fit.params, noise_cov])
their_params

Unnamed: 0,AMV,NAO
L1.AMV,0.451376,-1.992555
L1.NAO,-0.000845,0.060381
L2.AMV,0.259422,-9.765535
L2.NAO,0.000801,-0.017539
sigma_u.AMV,0.012051,-0.084668
sigma_u.NAO,-0.084668,35.364779


### Check that you get the same results when you duplicate data to `bystander` and `stack` dimensions

In [201]:
# Ten identical copies of the reanalysis along a new "member" dimension ...
member_copies = xr.concat([reanalysis] * 10, dim="member")
# ... and two identical copies of that along a new "x" dimension.
reanalysis_stacked = xr.concat([member_copies, member_copies], dim="x")

In [202]:
# Same fit as before, now on the data duplicated along "member" and "x".
my_params = ar_model.fit(reanalysis_stacked, n_lags=n_lags, dim="time")
# Display the fitted parameters as a table.
my_params.to_dataframe()

Unnamed: 0_level_0,Unnamed: 1_level_0,AMV,NAO,model_order
x,params,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,AMV.lag1,0.451376,-1.992555,2
0,NAO.lag1,-0.000845,0.060381,2
0,AMV.lag2,0.259422,-9.765535,2
0,NAO.lag2,0.000801,-0.017539,2
0,AMV.noise_var,0.011753,-0.082575,2
0,NAO.noise_var,-0.082575,34.490375,2
1,AMV.lag1,0.451376,-1.992555,2
1,NAO.lag1,-0.000845,0.060381,2
1,AMV.lag2,0.259422,-9.765535,2
1,NAO.lag2,0.000801,-0.017539,2


### Check that you get the same results when you split `time` into two members
Note that one data point is lost by doing this, so you won't get exactly the same answer.

In [203]:
# Position at which the time axis is cut in two
n_half = reanalysis.sizes["time"] // 2

# Cut the record in two and give each half the same 0-based integer time
# labels, so the halves line up when concatenated as members
first_half, second_half = (
    half.assign_coords(time=range(half.sizes["time"]))
    for half in (
        reanalysis.isel(time=slice(n_half)),
        reanalysis.isel(time=slice(n_half, None)),
    )
)

reanalysis_stacked = xr.concat([first_half, second_half], dim="member")

In [204]:
# Same fit as before, now on the record split into two "member" halves.
my_params = ar_model.fit(reanalysis_stacked, n_lags=n_lags, dim="time")
# Display the fitted parameters as a table.
my_params.to_dataframe()

Unnamed: 0_level_0,AMV,NAO,model_order
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
AMV.lag1,0.451882,-2.171956,2
NAO.lag1,-0.000856,0.059289,2
AMV.lag2,0.256853,-9.726328,2
NAO.lag2,0.000786,-0.018835,2
AMV.noise_var,0.012218,-0.085884,2
NAO.noise_var,-0.085884,35.84808,2


### Check that the new `select_order` function gives consistent results

In [205]:
# Run ar_model.select_order on AMV alone and tabulate what it returns
# (presumably the parameters for the order it selects — see `model_order`).
ar_model.select_order(reanalysis[["AMV"]]).to_dataframe()

[1 2]


Unnamed: 0_level_0,AMV,model_order
params,Unnamed: 1_level_1,Unnamed: 2_level_1
AMV.lag1,0.456688,2
AMV.lag2,0.254508,2
noise_var,0.011765,2


In [206]:
# Explicit two-lag fit of AMV, for comparison with the select_order table.
ar_model.fit(reanalysis[["AMV"]], n_lags=2).to_dataframe()

Unnamed: 0_level_0,AMV,model_order
params,Unnamed: 1_level_1,Unnamed: 2_level_1
AMV.lag1,0.456688,2
AMV.lag2,0.254508,2
AMV.noise_var,0.011765,2


# Now let's try fitting to some real hindcasts

In [None]:
# Models whose DCPP hindcast stores are read below.
models = [
    "CanESM5",
    "CESM1-1-CAM5-CMIP5",
    "CMCC-CM2-SR5",
    "EC-Earth3",
    "HadGEM3-GC31-MM",
    "IPSL-CM6A-LR",
    "MIROC6",
    "MPI-ESM1-2-HR",
    "NorCPM1",
]

# NOTE(review): n_init is not referenced anywhere in this cell — confirm
# whether it is still needed.
n_init = 57

# One Dataset per model is collected here and concatenated along "member"
# afterwards (kept separate from the final `hindcast` Dataset so the name is
# never bound to two different kinds of object).
hindcast_per_model = []
# Highest member number handed out so far; member labels continue from here so
# that they stay unique across models.
prev_member = 0
for model in models:
    tos = xr.open_zarr(f"{DATA_DIR}/tos_Omon_{model}_dcpp.zarr", use_cftime=True)
    AMV = utils.calculate_period_AMV_index(tos["tos"], [12, 1, 2, 3]).to_dataset(
        name="AMV"
    )
    # Relabel the initialisations as consecutive integers starting at 1960
    AMV = AMV.assign_coords({"init": range(1960, 1960 + AMV.sizes["init"])})

    # NOTE(review): the division by 100 is presumably a Pa -> hPa conversion —
    # confirm against the store's units.
    psl = xr.open_zarr(f"{DATA_DIR}/psl_Amon_{model}_dcpp.zarr", use_cftime=True) / 100
    NAO = utils.calculate_period_NAO_index(psl["psl"], [12, 1, 2, 3]).to_dataset(
        name="NAO"
    )
    # Size the labels from NAO itself; this previously read AMV.sizes["init"],
    # which only works when both stores hold the same number of initialisations
    NAO = NAO.assign_coords({"init": range(1960, 1960 + NAO.sizes["init"])})

    ds = xr.merge((AMV.compute(), NAO.compute()))
    # Keep leads 14-120 and relabel them 1-9 (this requires exactly nine lead
    # values to remain after the selection)
    ds = ds.sel(lead=slice(14, 120)).assign_coords({"lead": range(1, 10)})
    # Continue the member numbering from the previous model
    ds = ds.assign_coords(
        {"member": np.array(range(1, ds.sizes["member"] + 1)) + prev_member}
    )
    # Record which model each member belongs to
    ds = ds.assign_coords({"model": ("member", ds.sizes["member"] * [model])})
    ds = utils.round_to_start_of_month(ds, "time")

    prev_member = ds.member.values[-1]

    hindcast_per_model.append(ds)

# join="inner" keeps only the index labels shared by every model on the
# non-concatenated dimensions
hindcast = xr.concat(
    hindcast_per_model,
    dim="member",
    coords="minimal",
    compat="override",
    join="inner",
)
# Replace the time coordinate by its calendar year
hindcast = hindcast.assign_coords({"time": hindcast.time.dt.year})

### Fit the reanalysis to generate the synthetic initial conditions

In [None]:
# Variables to fit and plot in the remainder of the notebook.
variables = ["AMV", "NAO"]
# Number of lags for the fits below. Note this rebinds `n_lags`, which was 2
# in the test section above.
n_lags = 3
# Number of synthetic samples requested from ar_model.generate_samples.
n_samples = 500

In [None]:
# Fit the model to the reanalysis indices
reanalysis_variables = reanalysis[variables]
params_reanalysis = ar_model.fit(reanalysis_variables, n_lags=n_lags, dim="time")

# Generate synthetic series from the fitted parameters, long enough to cover
# every hindcast initialisation plus n_lags - 1 extra times
# NOTE(review): no random seed is set anywhere in this notebook; if
# generate_samples draws random noise the samples will differ between runs —
# confirm and seed if reproducibility matters.
n_init_times = hindcast.sizes["init"] + n_lags - 1
inits = ar_model.generate_samples(
    params_reanalysis, n_times=n_init_times, n_samples=n_samples
)

# Take the last `inits.sizes["time"]` reanalysis times ...
n_skipped = reanalysis.sizes["time"] - inits.sizes["time"]
reanalysis_init_dates = reanalysis.time.isel(time=slice(n_skipped, None))

# ... relabel them with the synthetic time axis, and tag the result as
# sample -1 ...
reanalysis_init = reanalysis_variables.sel(time=reanalysis_init_dates).assign_coords(
    time=inits.time, sample=[-1]
)

# ... so that the reanalysis sits in front of the synthetic samples
inits = xr.concat([reanalysis_init, inits], dim="sample")

In [None]:
# stats.acf (with partial=True) of the reanalysis and of every synthetic
# sample; position 0 along "sample" is the reanalysis itself, so skip it
reanalysis_acf = stats.acf(reanalysis_variables, partial=True)
synthetic_inits = inits.isel(sample=slice(1, None))
inits_acf = stats.acf(synthetic_inits, partial=True)

# Lower and upper quantiles of the shaded band
q = (0.025, 0.975)

# One panel per variable, side by side
fig, axs = plt.subplots(
    1,
    len(variables),
    sharex=True,
    squeeze=False,
    figsize=(7 * len(variables), 5),
)
axs = axs.flatten()

for ax, var in zip(axs, variables):
    band_low = inits_acf[var].quantile(q[0], dim="sample")
    band_high = inits_acf[var].quantile(q[1], dim="sample")
    # Spread across synthetic samples, with the reanalysis drawn on top
    ax.fill_between(inits_acf.lag, band_low, band_high, alpha=0.5)
    ax.plot(reanalysis_acf.lag, reanalysis_acf[var])
    ax.grid()

### Fit the hindcasts to generate the synthetic forecasts

In [None]:
# Fit the same model to the hindcasts, this time along the "lead" dimension
# rather than "time".
hindcast_variables = hindcast[variables]
params_hindcast = ar_model.fit(hindcast_variables, n_lags=n_lags, dim="lead")

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# Overlay the AMV parameters from the two fits on a single axis
for fitted, label in (
    (params_reanalysis, "Reanalysis fit parameters"),
    (params_hindcast, "Hindcast fit parameters"),
):
    fitted["AMV"].plot(ax=ax, label=label)

ax.grid()
ax.set_ylabel("Coefficient")
# Clear whatever title the xarray plot call set
ax.set_title("")
_ = ax.legend()

In [None]:
# Both predictions share the initial conditions and the shape of the real
# hindcasts (number of leads and members); only the parameters differ
predict_kwargs = dict(
    n_steps=hindcast.sizes["lead"],
    n_members=hindcast.sizes["member"],
)

# Forecasts from the parameters fitted to the hindcasts
synthetic_hindcast = ar_model.predict(params_hindcast, inits, **predict_kwargs)

# Forecasts from the parameters fitted to the reanalysis
synthetic_reanalysis = ar_model.predict(params_reanalysis, inits, **predict_kwargs)

### Compare the real hindcasts to the synthetic hindcasts generated using the reanalysis as initial conditions

In [None]:
def _plot_member_spread(ax, time, ensemble, color, label, quantiles):
    """Shade the member spread of `ensemble` and overlay its member mean.

    Parameters
    ----------
    ax : matplotlib Axes to draw on
    time : x values for the forecast
    ensemble : DataArray with a "member" dimension
    color : matplotlib colour used for both the band and the mean line
    label : legend label attached to the mean line
    quantiles : (lower, upper) quantiles bounding the shaded band
    """
    ax.fill_between(
        time,
        ensemble.quantile(quantiles[0], dim="member"),
        ensemble.quantile(quantiles[1], dim="member"),
        alpha=0.5,
        color=color,
    )
    ax.plot(time, ensemble.mean("member"), color=color, label=label)


fig = plt.figure(figsize=(14, 5 * len(variables)))
axs = np.array(fig.subplots(len(variables), 1, sharex=True)).flatten()
q = (0.025, 0.975)

for ax, var in zip(axs, variables):
    ax.plot(inits.time, inits[var].sel(sample=-1), color="k", label="Reanalysis")

    # Every 9th initial condition after the first n_lags, paired with the
    # reanalysis date it was taken from
    for init_idx, init_date in list(
        zip(inits.time[n_lags:], reanalysis_init_dates[n_lags:])
    )[::9]:
        init_year = init_date.dt.year.values
        ax.plot(
            init_idx, inits[var].sel(sample=-1, time=init_idx), marker="o", color="k"
        )

        # Real hindcast, when one was initialised in this year. Only the
        # lookup is guarded, so a KeyError raised while plotting is no longer
        # silently swallowed.
        try:
            hcst = hindcast[var].sel(init=init_year)
        except KeyError:
            hcst = None
        if hcst is not None:
            _plot_member_spread(
                ax, hindcast.lead + init_idx, hcst, "C0", "Hindcast", q
            )

        # Synthetic forecasts started from the reanalysis (sample -1)
        time = synthetic_hindcast.lead + init_idx
        _plot_member_spread(
            ax,
            time,
            synthetic_hindcast[var].sel(sample=-1, init=init_idx),
            "C1",
            "Synthetic, hindcast fit",
            q,
        )
        _plot_member_spread(
            ax,
            time,
            synthetic_reanalysis[var].sel(sample=-1, init=init_idx),
            "C2",
            "Synthetic, reanalysis fit",
            q,
        )

    # Each forecast adds a labelled line; collapse repeated labels so the
    # legend shows one entry per data source
    handles, labels = ax.get_legend_handles_labels()
    unique_entries = dict(zip(labels, handles))
    ax.legend(unique_entries.values(), unique_entries.keys())
    ax.set_ylabel(var)
    ax.grid()

### Compare the autocorrelation functions for one year lead

In [None]:
# Lead ranges handed to utils.get_hindcast_temporal_mean
mean_lead_ranges = [(0, 1), (0, 4), (0, 8)]


def _with_init_plus_lead_time(forecasts):
    """Attach a "time" coordinate equal to init + lead."""
    return forecasts.assign_coords({"time": forecasts.init + forecasts.lead})


def _lead_range_means(forecasts):
    """Apply utils.get_hindcast_temporal_mean, then drop times containing NaN."""
    means = utils.get_hindcast_temporal_mean(
        forecasts, mean_lead_ranges=mean_lead_ranges
    )
    return means.dropna("time", how="any")


# Synthetic forecasts: position 0 along "sample" is the reanalysis-initialised
# case, so it is left out of the means
synthetic_hindcast = _with_init_plus_lead_time(synthetic_hindcast)
synthetic_hindcast_mean = _lead_range_means(
    synthetic_hindcast.isel(sample=slice(1, None))
)

synthetic_reanalysis = _with_init_plus_lead_time(synthetic_reanalysis)
synthetic_reanalysis_mean = _lead_range_means(
    synthetic_reanalysis.isel(sample=slice(1, None))
)

# Real hindcasts
hindcast_mean = _lead_range_means(hindcast)

In [None]:
# Which entry of the "temporal_mean" coordinate to compare
temporal_mean = 8


def _member_mean_acf(means):
    """stats.acf (partial=True) of the member mean at the chosen temporal_mean."""
    member_mean = means.sel(temporal_mean=temporal_mean).mean("member")
    return stats.acf(member_mean, partial=True)


hindcast_acf = _member_mean_acf(hindcast_mean)
synthetic_hindcast_acf = _member_mean_acf(synthetic_hindcast_mean)
synthetic_reanalysis_acf = _member_mean_acf(synthetic_reanalysis_mean)

# Lower and upper quantiles of the shaded bands
q = (0.025, 0.975)

# One panel per variable, side by side
fig, axs = plt.subplots(
    1,
    len(variables),
    sharex=True,
    squeeze=False,
    figsize=(7 * len(variables), 5),
)
axs = axs.flatten()

for ax, var in zip(axs, variables):
    # Spread across synthetic samples: hindcast-fit first, then reanalysis-fit
    for synthetic_acf in (synthetic_hindcast_acf, synthetic_reanalysis_acf):
        ax.fill_between(
            synthetic_acf.lag,
            synthetic_acf[var].quantile(q[0], dim="sample"),
            synthetic_acf[var].quantile(q[1], dim="sample"),
            alpha=0.5,
        )
    # Real hindcasts in black on top
    ax.plot(hindcast_acf.lag, hindcast_acf[var], color="k")
    ax.grid()

In [None]:
# Which entry of the "temporal_mean" coordinate to plot
temporal_mean = 8
# Quantiles of the shaded bands; defined here as well so this cell does not
# depend on `q` being left behind by an earlier plotting cell
q = (0.025, 0.975)

# AMV for a single synthetic sample (sample 10) and for the real hindcasts
synth_hindcast = synthetic_hindcast_mean["AMV"].sel(
    sample=10, temporal_mean=temporal_mean
)
synth_reanalysis = synthetic_reanalysis_mean["AMV"].sel(
    sample=10, temporal_mean=temporal_mean
)
hcst = hindcast_mean["AMV"].sel(temporal_mean=temporal_mean)

fig, ax = plt.subplots()

# (x values, ensemble, legend label) for each data source.
# NOTE(review): the real hindcast is drawn against synth_hindcast.time rather
# than hcst.time, as in the original cell — confirm the two time axes are
# meant to be interchangeable here.
ensembles = [
    (synth_hindcast.time, hcst, "Hindcast"),
    (synth_hindcast.time, synth_hindcast, "Synthetic, hindcast fit"),
    (synth_reanalysis.time, synth_reanalysis, "Synthetic, reanalysis fit"),
]

for time, ensemble, label in ensembles:
    # Member spread as a band, member mean as a line; band and line each take
    # the next colour of the default cycle, so they match per data source
    ax.fill_between(
        time,
        ensemble.quantile(q[0], dim="member"),
        ensemble.quantile(q[1], dim="member"),
        alpha=0.5,
    )
    ax.plot(time, ensemble.mean("member"), label=label)

ax.set_ylabel("AMV")
ax.grid()
_ = ax.legend()