# Test EOF reconstructions

## Imports

In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import xarray as xr
import pathlib
import os
import src.utils

## plotting theme: white axes background, gridlines turned off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## root data directory, read from the DATA_FP environment variable
## (a KeyError is raised here if the variable is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])

## Load data

In [2]:
## MPI dataset; supplies the "T_34" series used for comparison further down
mpi_load_fp = DATA_FP / "mpi_Th" / "Th.nc"
Th_mpi = xr.open_dataset(mpi_load_fp)

## EOF decomposition -- only the SST ("ts") file is loaded
eofs_fp = DATA_FP / "mpi" / "eofs"
eofs_sst = src.utils.load_eofs(eofs_fp / "ts.nc")

## gather the principal-component time series (scores) into one object
scores = eofs_sst.scores()

## Compute anomalies

In [3]:
## ensemble mean of the PC scores, and each member's deviation from it
emean = scores.mean(dim="member")
anom = scores - emean

## Tests

### Identity reconstruction

In [4]:
## fixed (not random) subset for testing: members 2-3, time steps 24-59
sample = scores.isel(member=slice(2, 4), time=slice(24, 60))


## identity post-processing function: with it, reconstruct_fn's output is
## expected to match the EOF object's own inverse transform
def identity(x):
    return x


## reconstruct 2 different ways
r0 = eofs_sst.inverse_transform(sample)
r1 = src.utils.reconstruct_fn(eofs_sst.components(), sample, fn=identity)

## check the two reconstructions agree to within a tight absolute tolerance
print(f"All close? {np.max(np.abs(r0-r1)).values < 1e-10}")

All close? True


### Niño 3.4 reconstruction

In [5]:
## next, Nino 3.4: the index computed from the reconstructed field (r0)
## should match the index produced directly by reconstruct_fn
n34_r0 = src.utils.get_nino34(r0)
n34_r1 = src.utils.reconstruct_fn(
    eofs_sst.components(),
    sample,
    fn=src.utils.get_nino34,
)
print(f"All close? {np.allclose(n34_r0, n34_r1)}\n")

## skill of the truncated reconstruction on the full anomaly dataset:
## correlate it with Th_mpi's "T_34" at the time steps present in Th_mpi
n34_recon = src.utils.reconstruct_fn(
    eofs_sst.components(),
    anom,
    fn=src.utils.get_nino34,
)
corr = xr.corr(
    n34_recon.sel(time=Th_mpi.time),
    Th_mpi["T_34"],
)
print(f"Corr w/ non-truncated Niño index: {corr.item():.2f}")

All close? True

Corr w/ non-truncated Niño index: 0.96


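### Spatial variance

Check that the variance field from `reconstruct_var` matches the variance of an explicit reconstruction at a single grid point.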

In [6]:
## single grid point at which the two variance calculations are compared
posn_coords = {"latitude": -20.5, "longitude": 200.5}

## method 1: full variance field from reconstruct_var, sampled at the point
var0_all = src.utils.reconstruct_var(anom, eofs_sst.components()).compute()
var0 = var0_all.sel(posn_coords).values.item()


## method 2: reconstruct only at the point, then take the mean of squared
## values over time and members (anom has the ensemble mean removed)
def sel_fn(x):
    return x.sel(posn_coords)


recon = src.utils.reconstruct_fn(eofs_sst.components(), anom, fn=sel_fn)
var1 = (recon ** 2).mean(dim=("time", "member")).values.item()

print(f"All close? {np.allclose(var0, var1)}")

All close? True


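#### Synthetic example

For random data $X = U S V^\top$, verify that $\operatorname{diag}\,\operatorname{cov}(X) = \operatorname{diag}\left(U S \,\operatorname{cov}(V^\top)\, S U^\top\right)$.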

In [7]:
## problem size: X is (m, n); the diag comparison below requires
## get_cov(X) to be (m, m), i.e. rows are treated as variables
m = 100
n = 24

## generate random data (seeded so the check is reproducible run to run)
rng = np.random.default_rng(seed=0)
X = rng.normal(size=(m, n))

## get true variance
varX = np.diag(src.utils.get_cov(X))

## do (thin) SVD on data: X = U @ S @ Vt, with U (m, n), S (n, n), Vt (n, n)
U, s, Vt = np.linalg.svd(X, full_matrices=False)
S = np.diag(s)

## reconstruct variance from the covariance of the "scores" Vt:
## centring commutes with left-multiplication by U @ S, so
## cov(X) = U @ S @ cov(Vt) @ S @ U.T and the diagonals must agree
scores_cov = src.utils.get_cov(Vt)
varXhat = np.diag(U @ S @ scores_cov @ S @ U.T)

print(f"All close? {np.allclose(varX, varXhat)}")

All close? True
