In [None]:
import warnings
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import pandas as pd
import os
import scipy.stats

# import cartopy.util

# Import custom modules
from src.XRO import XRO, xcorr
import src.XRO_utils
import src.utils

## plotting defaults: seaborn theme with white axes background and grid lines off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution (dots per inch)
mpl.rc("figure", dpi=100)

## input/output directories are read from environment variables;
## a missing variable raises KeyError here
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

In [None]:
## MPI data: open the T/h file and keep only the 2070-2100 period
mpi_load_fp = DATA_FP / "mpi_Th" / "Th.nc"
Th_mpi = xr.open_dataset(mpi_load_fp).sel(time=slice("2070", "2100"))

In [None]:
## names of the T and h variables to use from the MPI dataset
T_var_mpi, h_var_mpi = "T_34", "h"

## order of the annual cycle passed to the model
ac_order = 3

## parameters whose annual cycle is masked out, given as [(y_idx0, x_idx0), ...]
ac_mask_idx = [(1, 1)]

## set up the model
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## fit the model to the MPI T and h data (ensemble fit)
## NOTE: the optional maskNT=['T2', 'TH'] argument is intentionally left disabled
fit = model.fit_matrix(Th_mpi[[T_var_mpi, h_var_mpi]], ac_mask_idx=ac_mask_idx)

#### Put data in matrix form

In [None]:
## calendar month (1-12) of the initial state; the target state is the following month
month = 6
t_idx_X0 = dict(time=Th_mpi.time.dt.month == month)
t_idx_X1 = dict(time=Th_mpi.time.dt.month == (month + 1))


def stack(Z):
    """Return dataset `Z` as a NumPy array with "member" and "time" merged into one "sample" axis."""
    return Z.to_dataarray().stack(sample=["member", "time"]).values


## put in array form
## (presumably shape (variable, sample) if "member" and "time" are the only dims -- TODO confirm)
X0 = stack(Th_mpi.isel(t_idx_X0)[[T_var_mpi, h_var_mpi]])
X1 = stack(Th_mpi.isel(t_idx_X1)[[T_var_mpi, h_var_mpi]])

## columns of X0 and X1 are paired by position, so both must hold the same samples
## (this fails e.g. for month=12, because no month 13 exists to select)
if X0.shape != X1.shape:
    raise ValueError(
        f"initial and target states are not paired: X0 has shape {X0.shape}, "
        f"X1 has shape {X1.shape} (month={month})"
    )

## finite-difference tendency over the one-month step (dt = 1/12, presumably in years)
dX = X1 - X0
dt = 1 / 12
dXdt = dX / dt

## pseudo-inverse of the initial states; computed once and reused for both estimates
X0_pinv = np.linalg.pinv(X0)

## compute operator: least-squares solution of dXdt = Lhat @ X0
Lhat = dXdt @ X0_pinv

## compute a different way: fit the one-step propagator X1 = G @ X0, then take
## log(G) / dt through the eigendecomposition G = v @ diag(w) @ inv(v)
G = X1 @ X0_pinv
w, v = np.linalg.eig(G)
L = v @ np.diag(1 / dt * np.log(w)) @ np.linalg.inv(v)

## check relationship holds: G = I + dt * Lhat whenever X0 @ X0_pinv is the identity
print(np.allclose(G, dt * Lhat + np.eye(G.shape[0])))
print()

## print out matrices (real parts, rounded to 2 decimals)
print(np.round(L.real, 2))
print()
print(np.round(Lhat.real, 2))