In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import cftime
import cartopy.crs as ccrs

## set plotting style
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## initialize random number generator with a fixed seed, so that anything
## drawn from 'rng' is reproducible after a kernel restart (change SEED to
## get a different realization). Note that 'stochastic_climate_model' builds
## its own generator internally and does not use this one.
SEED = 42
rng = np.random.default_rng(SEED)

#### Define stochastic climate model

In [None]:
def get_T_bar(t, trend, trend_type="exp", exp_rate=3e-2):
    """Get 'background' temperature for stochastic climate model at times 't'.
    Args:
        - 't' is non-empty 1-D array-like of times (in units of years)
        - 'trend' is float representing linear trend in background temperature,
            computed based on initial and final background temperature (units of K/year)
        - 'trend_type' is one of {"exp","linear"}, indicating whether the shape
            of the background temperature curve should be exponential or linear
        - 'exp_rate' is the growth rate of the exponential curve (units of 1/year);
            only used when trend_type is "exp"
    Returns:
        Array with one background temperature per entry of 't'. Both curve
        shapes reach trend * (t[-1] - t[0]) at the final time. The linear curve
        starts at exactly 0; the exponential curve starts at
        trend * (tf - ti) * exp(-exp_rate * (tf - ti)), which is small but nonzero.
    Raises:
        ValueError if 't' is empty or 'trend_type' is not recognized.
    """

    ## accept lists/tuples as well as arrays
    t = np.asarray(t)
    if t.size == 0:
        raise ValueError("'t' must contain at least one time")

    ## get final and initial times
    tf = t[-1]
    ti = t[0]

    ## get "background" temperature
    if trend_type == "exp":
        ## scale factor chosen so that the curve equals trend * (tf - ti) at t = tf
        b = trend * (tf - ti) * np.exp(-exp_rate * (tf - ti))
        return b * np.exp(exp_rate * (t - ti))

    if trend_type == "linear":
        return (t - ti) * trend

    ## fail loudly: a scalar fallback would only break later, in the caller,
    ## with a less informative error
    raise ValueError(
        f"Not a valid trend type: {trend_type!r} (expected 'exp' or 'linear')"
    )


def stochastic_climate_model(
    ti,
    tf,
    dt=1 / 365.25,
    g=-0.3,
    n=0.3,
    n_members=1,
    trend=0,
    trend_type="exp",
    nyears_spinup=10,
    seed=None,
):
    """Minimal version of the 'stochastic climate model' studied
    by Hasselmann (1976) and Frankignoul and Hasselmann (1977).
    (See Eqn 3.6 in Frankignoul and Hasselmann, 1977). The damping rate
    used here is much lower than in the paper, for illustration purposes.
    Args:
        - ti: number representing initial time (units: years). It is also used
            as the calendar year of the first time stamp, so presumably it
            should be an integer year.
        - tf: number representing final time (units: years); must exceed ti
        - dt: timestep (units: years)
        - g: damping rate (equivalent to lambda in the paper; units: 1/year).
            Negative values relax T toward the background temperature.
        - n: noise amplitude (units: K / year^{1/2})
        - n_members: number of ensemble members
        - trend: calculated increase in "background" T, with units of K/year
            (based on (T[tf]-T[ti]) / (tf-ti)
        - trend_type: one of "exp" (exponential) or "linear"
        - nyears_spinup: number of spinup years (discard these)
        - seed: optional seed passed to np.random.default_rng. The default
            (None) gives a different realization on every call. Two calls
            with the same seed and settings share the same noise sequence.
    Returns:
        xarray.DataArray of annual-mean temperature with dimensions
        ("ensemble_member", "year"), with the first 'nyears_spinup' years removed.
    Raises:
        ValueError if the time axis np.arange(ti, tf, dt) is empty.
    """

    ## initialize RNG (local to this call)
    rng = np.random.default_rng(seed)

    ## Get timesteps and dimensions for output
    t = np.arange(ti, tf, dt)
    nt = len(t)
    if nt == 0:
        raise ValueError(
            f"Empty time axis: need tf > ti with dt > 0 (got ti={ti}, tf={tf}, dt={dt})"
        )

    ## Create empty array to hold simulation output
    T = np.zeros([n_members, nt])

    ## initialize with random value
    T[:, 0] = n * rng.normal(size=n_members)

    ## get "background" temperature (leading axis broadcasts over members)
    T_bar = get_T_bar(t, trend, trend_type)[None, :]

    ## Euler-Maruyama integration: damping toward T_bar plus white-noise forcing
    sqrt_dt = np.sqrt(dt)
    for i in tqdm(range(nt - 1)):
        dW = sqrt_dt * rng.normal(size=n_members)
        dT = g * (T[:, i] - T_bar[:, i]) * dt + n * dW
        T[:, i + 1] = T[:, i] + dT

    ##  put in xarray
    ## NOTE(review): each model step (dt years) is stamped as one calendar day.
    ## With dt = 1/365.25 the number of steps can differ slightly from the
    ## number of calendar days in [ti, tf), so the last annual bin may contain
    ## only a day or two -- confirm whether that final year should be dropped.
    time_idx = xr.cftime_range(start=cftime.datetime(ti, 1, 1), periods=nt, freq="1D")
    e_member_idx = pd.Index(np.arange(1, n_members + 1), name="ensemble_member")

    T = xr.DataArray(
        T,
        dims=["ensemble_member", "time"],
        coords={"ensemble_member": e_member_idx, "time": time_idx},
    )

    ## resample to Annual
    T = T.resample({"time": "YS"}).mean()

    ## change time coordinate to year
    year = T.time.dt.year.values
    T = T.rename({"time": "year"})
    T["year"] = year

    ## discard model spinup
    T = T.isel(year=slice(nyears_spinup, None))

    return T

#### Run the model

In [None]:
## experiment settings
warming_trend = 0.005  # warming trend, in deg C / year
n_members = 1000  # number of ensemble members

## run both scenarios with identical settings apart from the trend:
## first the pre-industrial control (no trend), then the warming case
T_PI, T_warming = (
    stochastic_climate_model(
        ti=1850,
        tf=2006,
        n_members=n_members,
        trend=scenario_trend,
        nyears_spinup=5,
    )
    for scenario_trend in (0, warming_trend)
)

## for convenience, keep only the control years that also exist in the warming run
T_PI_hist = T_PI.sel(year=T_warming["year"])

In [None]:
def plot_ensemble_spread(ax, T, color, label=None):
    """Draw the ensemble mean of 'T', plus thin lines one standard
    deviation above and below it, on the given ax object.
    Args:
        - ax: axes-like object providing a 'plot' method
        - T: data with an "ensemble_member" dimension and a 'year' coordinate
        - color: color for the mean line (the spread lines reuse whatever
            color the mean line ends up with)
        - label: optional legend label, attached to the mean line only
    """

    ## statistics across the ensemble
    member_dim = "ensemble_member"
    center = T.mean(member_dim)
    spread = T.std(member_dim)
    years = center.year

    ## mean line; read its color back so the envelope matches it
    line_color = ax.plot(years, center, label=label, color=color)[0].get_color()

    ## upper, then lower, edge of the +/- 1 standard dev. envelope
    for edge in (center + spread, center - spread):
        ax.plot(years, edge, lw=0.5, c=line_color)


## Figure: ensemble mean and spread for both scenarios
fig, ax = plt.subplots(figsize=(4, 3))

## control run in black, warming run in red
plot_ensemble_spread(ax, T_PI_hist, label="P.I. control", color="black")
plot_ensemble_spread(ax, T_warming, label="warming", color="red")

## axis labels and title in one call, then the legend
ax.set(
    xlabel="Year",
    ylabel=r"SST anomaly ($^{\circ}C$)",
    title=r"Ensemble mean $\pm$1 standard dev.",
)
ax.legend(prop={"size": 10})

plt.show()