# MPI–ORAS validation
Comparison between the MPI climate model and the ORAS5 reanalysis.

## Imports

In [None]:
import warnings
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import pandas as pd
import os
import scipy.stats
import calendar
import copy

# import cartopy.util

# Import custom modules
from src.XRO import XRO, xcorr
import src.XRO_utils
import src.utils
import src.lim

## set plotting specs: seaborn defaults with a white axes background and no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set the DPI used for figures
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables; indexing os.environ raises
## KeyError if DATA_FP or SAVE_FP is not set
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## $T, h$ metrics

### Load data

In [None]:
## MPI data
mpi_load_fp = pathlib.Path(DATA_FP, "mpi_Th", "Th.nc")
Th_mpi = xr.open_dataset(mpi_load_fp)

## ORAS5 reanalysis (use as benchmark)
oras_load_fp = pathlib.Path(DATA_FP, "XRO_indices_oras5.nc")
Th_oras = xr.open_dataset(oras_load_fp)

## trim MPI to 1979-2024 and rename the ORAS5 variables so both datasets use
## the names "T_34" and "h"
## NOTE(review): only the MPI data are trimmed here; presumably the ORAS5 file
## already covers the same period -- confirm
Th_mpi = Th_mpi.sel(time=slice("1979", "2024"))
Th_oras = Th_oras[["Nino34", "WWV"]].rename({"Nino34": "T_34", "WWV": "h"})

### Diagnostics

#### Seasonal synchronization

In [None]:
## specify which "T" variable to use for MPI
T_var_mpi = "T_34"

## func to compute std dev as a function of calendar month
get_std = lambda x: x.groupby("time.month").std("time")

## compute std for each dataset
oras_std = get_std(Th_oras["T_34"])
mpi_std = get_std(Th_mpi[T_var_mpi])

## summarize the MPI result; the output is indexed below by
## posn = "center" / "upper" / "lower"
mpi_std_plot = src.utils.get_ensemble_stats(mpi_std)

## months (x-coordinate for plotting)
months = np.arange(1, 13)

### Set up plot
fig, ax = plt.subplots(figsize=(4, 3))

## plot for ORAS5
oras_plot = ax.plot(months, oras_std, label="ORAS5")

## plot MPI "center" statistic
mpi_plot = ax.plot(months, mpi_std_plot.sel(posn="center"), label="MPI")

## plot MPI bounds (dashed, same color as the center line)
kwargs = dict(c=mpi_plot[0].get_color(), ls="--", lw=1)
for bound in ["upper", "lower"]:
    ax.plot(months, mpi_std_plot.sel(posn=bound), **kwargs)

## adjust limits and label
ax.set_ylim([0.4, None])
ax.set_yticks([0.5, 1])
ax.set_xticks([1, 5, 12], labels=["Jan", "May", "Dec"])
ax.set_xlabel("Month")
ax.set_ylabel(f"$\\sigma(T)$")
ax.set_title("Seasonal synchronization (Niño 3.4)")
ax.legend()
plt.show()

#### $T,h$ cross-correlation

Plotting funcs

In [None]:
def format_axs(axs):
    """Apply shared formatting to a grid of cross-correlation panels.

    Every panel gets thin zero lines and a fixed y-range. Row 1 gets the lag
    label and ticks, row 0 has its x-ticks removed, column 0 gets the
    correlation label and column 1 has its y-ticks removed.

    Args:
        axs: 2-D array of matplotlib axes indexed as ``axs[row, col]``; rows
            0-1 and columns 0-1 are the ones formatted by position.

    Returns:
        The same ``axs`` array.
    """

    ## zero lines and common y-range on every panel
    guide_style = {"c": "k", "lw": 0.5, "alpha": 0.5}
    for panel in axs.flatten():
        panel.axhline(0, **guide_style)
        panel.axvline(0, **guide_style)
        panel.set_ylim([-0.7, 1.1])

    ## x-axis: label/ticks on the bottom row, no ticks on the top row
    ## (ticks at +/-12 are labelled +/-1 to match the "years" label)
    for upper, lower in zip(axs[0], axs[1]):
        lower.set_xlabel("Lag (years)")
        lower.set_xticks([-12, 0, 12], labels=[-1, 0, 1])
        upper.set_xticks([])

    ## y-axis: label on the left column, no ticks on the right column
    for left, right in zip(axs[:, 0], axs[:, 1]):
        left.set_ylabel("Correlation")
        right.set_yticks([])

    return axs


def plot_oras(ax, data, color, label=None):
    """Draw the ORAS curve as a dashed line against ``data.lag``.

    Args:
        ax: matplotlib axes to draw on.
        data: array-like with a ``lag`` attribute used as the x-coordinate.
        color: line color.
        label: optional legend label.
    """

    lags = data.lag
    ax.plot(lags, data, ls="--", c=color, label=label)

    return


def plot_mpi(ax, data, color, label=None):
    """Draw the MPI "center" curve with a shaded band between its bounds.

    Args:
        ax: matplotlib axes to draw on.
        data: array with a ``lag`` coordinate and a ``posn`` dimension
            holding the entries "center", "upper" and "lower".
        color: color used for both the line and the shading.
        label: optional legend label (attached to the center line only).
    """

    lags = data.lag
    center, upper, lower = (
        data.sel(posn=name) for name in ("center", "upper", "lower")
    )

    ## center line
    ax.plot(lags, center, c=color, label=label)

    ## translucent band between the bounds
    ax.fill_between(lags, upper, lower, color=color, alpha=0.2)

    return

Compute stats

In [None]:
## specify which "T" and "h" variables to use for MPI
T_var_mpi = "T_34"
h_var_mpi = "h"

## compute cross-correlation of every variable in each dataset against that
## dataset's T index, out to a maximum lag of 18 (time steps of the data)
xcorr_oras = xcorr(Th_oras, Th_oras["T_34"], maxlags=18)
xcorr_mpi = xcorr(Th_mpi, Th_mpi[T_var_mpi], maxlags=18)

## compute MPI stats (indexed by posn = "center"/"upper"/"lower" when plotted)
xcorr_mpi_stats = src.utils.get_ensemble_stats(xcorr_mpi)

Make plot

In [None]:
## specify plot properties for legend (shared by all four panels)
legend_prop = dict(size=7)

fig, axs = plt.subplots(2, 2, figsize=(6, 5), layout="constrained")

## top-left: <T,T> for ORAS5 (dashed) and MPI (solid + shading)
axs[0, 0].set_title(r"$<T, T>$")
plot_oras(axs[0, 0], xcorr_oras["T_34"], color="r", label="ORAS5")
plot_mpi(axs[0, 0], xcorr_mpi_stats[T_var_mpi], color="r", label="MPI")

## top-right: <T,h> for ORAS5 and MPI
axs[0, 1].set_title(r"$<T, h>$")
plot_oras(axs[0, 1], xcorr_oras["h"], color="k", label="ORAS5")
plot_mpi(axs[0, 1], xcorr_mpi_stats[h_var_mpi], color="k", label="MPI")

## bottom-left: both correlations for ORAS5
axs[1, 0].set_title("ORAS5")
plot_oras(axs[1, 0], xcorr_oras["T_34"], color="r", label="$<T, T>$")
plot_oras(axs[1, 0], xcorr_oras["h"], color="k", label="$<T, h>$")

## bottom-right: both correlations for MPI
axs[1, 1].set_title("MPI")
plot_mpi(axs[1, 1], xcorr_mpi_stats[T_var_mpi], color="r", label="$<T, T>$")
plot_mpi(axs[1, 1], xcorr_mpi_stats[h_var_mpi], color="k", label="$<T, h>$")

## clean up axes
axs = format_axs(axs)

## add legends, using the shared font properties defined above
for ax in axs.flatten():
    ax.legend(prop=legend_prop)

plt.show()

#### Power spectrum

Compute

In [None]:
## specify which variable to use
varname = "T_34"

## specify args for psd: spectrum is taken along "time" with a sample spacing
## of 1/12 (monthly data expressed in years); nw is passed through to pmtm
psd_kwargs = dict(dim="time", dt=1 / 12, nw=5)

## compute PSD of `varname` for each dataset
compute_psd = lambda x: src.XRO_utils.pmtm(x[varname], **psd_kwargs)
psd_oras = compute_psd(Th_oras)
psd_mpi = compute_psd(Th_mpi)

Make plot

In [None]:
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

## plot data
src.utils.plot_psd(ax, psd_oras, label="ORAS5")
src.utils.plot_psd(ax, psd_mpi, label="MPI")

## label
## NOTE(review): the label says "cpm" but the PSD was computed with dt=1/12,
## which suggests frequency in cycles per year -- confirm units against pmtm
ax.set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
ax.legend(prop=dict(size=6))


plt.show()

#### PDF / Skewness

In [None]:
## select variables to compare (time_idx is a full slice, i.e. all times)
time_idx = dict(time=slice(None, None, None))
mpi_var = "T_34"
oras_var = "T_34"

## bin edges for PDF
edges = np.arange(-3.4, 3.8, 0.4)

## extract relevant data as flat arrays (all remaining dims are pooled)
x_mpi = Th_mpi[mpi_var].isel(time_idx).values.flatten()
x_oras = Th_oras[oras_var].isel(time_idx).values.flatten()

## compute empirical pdf on the shared bin edges
pdf_mpi, _ = src.utils.get_empirical_pdf(x_mpi, edges=edges)
pdf_oras, _ = src.utils.get_empirical_pdf(x_oras, edges=edges)

## compute skewness
## NOTE(review): scipy.stats.skew propagates NaN by default; assumes the
## flattened samples contain no NaNs -- confirm
skew_mpi = scipy.stats.skew(x_mpi)
skew_oras = scipy.stats.skew(x_oras)

## get gaussian best fit to obs.
pdf_gauss, pdf_gauss_pts = src.utils.get_gaussian_best_fit(x_oras)

## set up plot
fig, ax = plt.subplots(figsize=(4, 3))

## plot pdf (MPI filled, ORAS as a red outline)
ax.stairs(pdf_mpi, edges, fill=True, alpha=0.3, label=f"MPI (skew = {skew_mpi:.2f})")
ax.stairs(pdf_oras, edges, color="r", lw=1.5, label=f"ORAS (skew = {skew_oras:.2f})")

## plot gaussian best fit (unlabelled, so it does not appear in the legend)
ax.plot(pdf_gauss_pts, pdf_gauss, c="k", lw=1)

## label
ax.legend(prop=dict(size=8))
ax.set_xlabel(r"$^{\circ}C$")
ax.set_ylabel("Prob. density")

plt.show()

### RO-based diagnostics
Note: $T$ and $h$ must be non-dimensionalized so that the $F_1$ and $F_2$ parameters can be compared to each other and to $R$ and $\epsilon$.

In [None]:
## get non-dimensional scale for each variable (specific to each dataset)
scale_oras = Th_oras.std()
scale_mpi = Th_mpi.std()

## non-dimensionalize
Th_oras_nondim = Th_oras / scale_oras
Th_mpi_nondim = Th_mpi / scale_mpi

In [None]:
def plot_param_ensemble(ax, param, label=None):
    """Plot the ensemble mean (thick) and every member (thin) of a parameter.

    Args:
        ax: matplotlib axes to draw on.
        param: array with ``cycle`` and ``member`` coordinates.
        label: optional legend label, attached to the ensemble-mean line.
    """

    cycle = param.cycle

    ## ensemble mean: thick black line
    ax.plot(cycle, param.mean("member"), c="k", lw=3, label=label)

    ## individual members: thin, translucent black lines
    for member_id in param.member:
        ax.plot(cycle, param.sel(member=member_id), c="k", lw=0.5, alpha=0.2)

    return


def plot_param(ax, param, **plot_kwargs):
    """Plot a single parameter curve against its ``cycle`` coordinate.

    Extra keyword arguments are forwarded unchanged to ``ax.plot``.
    """

    x_cycle = param.cycle
    ax.plot(x_cycle, param, **plot_kwargs)


def label_ac_ax(ax):
    """Label an annual-cycle panel: y-label, zero line and month ticks."""

    zero_line_style = {"c": "k", "ls": "-", "lw": 0.5}
    tick_positions = [1, 8, 12]
    tick_names = ["Jan", "Aug", "Dec"]

    ax.set_ylabel(r"Growth rate (yr$^{-1}$)")
    ax.axhline(0, **zero_line_style)
    ax.set_xticks(tick_positions, labels=tick_names)

    return

#### Growth rate / periodicity

Do the computation

In [None]:
## specify T and h variables to use for MPI
T_var_mpi = "T_34"
h_var_mpi = "h"

## specify order of annual cycle
ac_order = 3

## specify which parameters to mask annual cycle out for [(y_idx0, x_idx0), ...]
ac_mask_idx = [(1, 1)]
# ac_mask_idx = None

## initialize model
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## get fit for reanalysis
fit_oras = model.fit_matrix(Th_oras_nondim[["T_34", "h"]], ac_mask_idx=ac_mask_idx)

## get fit for MPI (ensemble-of-RO method); the first return value is unused
_, fit_mpi = src.utils.get_RO_ensemble(
    Th_mpi_nondim,
    model=model,
    T_var=T_var_mpi,
    h_var=h_var_mpi,
    ac_mask_idx=ac_mask_idx,
)

## get fit for MPI (ensemble fit): a single fit to the full MPI dataset
fit_mpi_all = model.fit_matrix(
    Th_mpi_nondim[[T_var_mpi, h_var_mpi]],
    ac_mask_idx=ac_mask_idx,  # , maskNT=['T2', 'TH'],
)

## compute timescales (growth rate and period) for each fit
bj_oras, period_oras = src.utils.get_timescales(model, fit_oras)
bj_mpi, period_mpi = src.utils.get_timescales_ensemble(model, fit_mpi)
bj_mpi_all, period_mpi_all = src.utils.get_timescales(model, fit_mpi_all)

## get PDF for period (bins from 2 to 8 in steps of 1/3)
## note: this rebinds the notebook-level names `pdf` and `edges`
pdf, edges = src.utils.get_empirical_pdf(period_mpi, edges=np.arange(2, 8, 1 / 3))

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(7, 3), layout="constrained")

### left panel: seasonal cycle of growth rate
# MPI: ensemble and ensemble mean
plot_param_ensemble(axs[0], bj_mpi, label="MPI (ensemble mean)")

# MPI: fit to all ensemble members
plot_kwargs = dict(c="r", ls="-", lw=2, label="MPI (ensemble fit)")
plot_param(axs[0], bj_mpi_all, **plot_kwargs)

# ORAS5
plot_kwargs = dict(c="k", ls="--", lw=2, label="ORAS5")
plot_param(axs[0], bj_oras, **plot_kwargs)
axs[0].legend(prop=dict(size=8))

## label axes
label_ac_ax(axs[0])

## right panel: PDF of the MPI periods, with vertical lines marking the ORAS5
## period, the MPI ensemble-fit period and the mean of the MPI periods
axs[1].stairs(pdf, edges, color="k", label="MPI (PDF)")
axs[1].axvline(period_oras, c="k", ls="--", lw=2, label="ORAS5")
axs[1].axvline(period_mpi_all, c="r", ls="-", lw=2, label="MPI (ensemble fit)")
axs[1].axvline(period_mpi.mean(), c="k", lw=2, label="MPI (PDF mean)")
axs[1].set_xlim([2, None])
axs[1].set_yticks([])
axs[1].set_xlabel("Period (yrs)")
axs[1].legend(prop=dict(size=8))

plt.show()

#### Parameters

In [None]:
## Get parameters for ORAS and MPI
params_oras = model.get_RO_parameters(fit_oras)
params_mpi = model.get_RO_parameters(fit_mpi)
params_mpi_all = model.get_RO_parameters(fit_mpi_all)

## get labels of parameters for plotting (titles) and the matching variable names
labels = [r"$R$", r"$F_1$", r"$F_2$", r"$\epsilon$"]
var_names = ["R", "F1", "F2", "epsilon"]

## set up plot
fig, axs = plt.subplots(2, 2, figsize=(7, 6), layout="constrained")

## separate plot for each parameter
## note: the loop variables (i, ax, label, n) remain defined at notebook level
## after this cell runs
for i, (ax, label, n) in enumerate(zip(axs.flatten(), labels, var_names)):

    ### plot seasonal cycle of the parameter for each dataset/fit
    plot_param_ensemble(ax, params_mpi[n], label="MPI (ensemble mean)")
    plot_param(ax, params_mpi_all[n], c="r", lw=2, label="MPI (ensemble fit)")
    plot_param(ax, params_oras[n], c="k", ls="--", lw=2, label="ORAS")
    ax.set_title(label)
    ax.set_ylim([-3, 5])
    label_ac_ax(ax)

## make plot less cluttered: x labels/ticks only on the bottom row, y
## labels/ticks only on the left column
for ax in axs[0, :]:
    ax.set_xlabel(None)
    ax.set_xticks([])
for ax in axs[:, 1]:
    ax.set_ylabel(None)
    ax.set_yticks([])
for ax in axs[:, 0]:
    ax.set_yticks([-2, 0, 2, 4])

axs[0, 0].legend(prop=dict(size=8))
plt.show()

## Spatial data (SST, SSH)

#### Load ORAS5 data and compute anomalies

In [None]:
## load data
## NOTE(review): later cells index this data by "sst", so the loader
## presumably renames "tos" -- confirm
data_oras = src.utils.load_oras_spatial_extended(
    DATA_FP / "oras5", varnames=["tos", "ssh", "taux", "d20"]
)

## Convert ssh in ORAS from m to cm (in-place scaling of the values)
data_oras["ssh"].values *= 100

## split each calendar month into anomaly + forced signal: the anomaly is what
## detrend_dim returns for a 2nd-order polynomial fit in time, and the forced
## signal is the remainder (data minus anomaly)
detrend_fn = lambda x: src.utils.detrend_dim(x, dim="time", deg=2)
oras_anom = data_oras.groupby("time.month").map(detrend_fn)
oras_forced = data_oras - oras_anom

### Load MPI data and compute anomalies

In [None]:
## path to EOF data
eofs_fp = pathlib.Path(DATA_FP, "mpi", "eofs300")

## variables to load (and how to rename them)
names = ["ts", "ssh", "taux", "nhf"]
newnames = ["sst", "ssh", "taux", "nhf"]

## load the EOFs into a dict keyed by the new variable name
load_var = lambda x: src.utils.load_eofs(pathlib.Path(eofs_fp, f"{x}.nc"))
eofs = {y: load_var(x) for (y, x) in zip(newnames, names)}

## for convenience, put spatial patterns / components in single dataset
## (components = spatial patterns, pc_data = scores)
components = xr.merge([eofs_.components().rename(y) for (y, eofs_) in eofs.items()])
pc_data = xr.merge([eofs_.scores().rename(y) for (y, eofs_) in eofs.items()])

## convert unit of taux to something reasonable (check with Yann about this...)
pc_data["taux"].values *= 1e-3

## get forced signal (ensemble mean) and anomalies (departure from it)
forced_all = pc_data.mean("member")
anom_all = pc_data - forced_all

## first and last year covered by ORAS, as strings
yr_range = [str(y) for y in data_oras.time.dt.year.values[[0, -1]]]

## trim forced signal and anomalies to same period as ORAS
forced = forced_all.sel(time=slice(*yr_range))
anom = anom_all.sel(time=slice(*yr_range))

### Mean

#### Compute

In [None]:
## compute time-mean for ORAS (of the forced signal)
mean_oras = oras_forced.mean("time")

## compute time-mean for MPI in EOF-score space, then map sst and ssh back to
## physical space with each variable's own EOF model
mean_mpi = forced.mean("time")
mean_mpi_spatial = xr.merge(
    [
        eofs["sst"].inverse_transform(mean_mpi["sst"]).rename("sst"),
        eofs["ssh"].inverse_transform(mean_mpi["ssh"]).rename("ssh"),
    ]
)

## interpolate to match ORAS grid
mean_mpi_spatial = mean_mpi_spatial.interp_like(mean_oras)

## compute bias as MPI minus ORAS (signed difference)
bias = mean_mpi_spatial - mean_oras

#### Plot

In [None]:
## get lon and lat for convenience
lon = mean_oras.longitude
lat = mean_oras.latitude

## specify plotting styles for bias (filled contours) and ground truth
## (black line contours)
bias_kwargs = dict(
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(3, 0.3),
    transform=ccrs.PlateCarree(),
    extend="both",
    alpha=1,
)

truth_kwargs = dict(
    levels=10,
    transform=ccrs.PlateCarree(),
    extend="both",
    linewidths=0.7,
    colors="k",
    alpha=0.5,
)

## set up the plot
fig = plt.figure(figsize=(6, 1.2), layout="constrained")
ax = src.utils.plot_setup(fig, lon_range=[100, 300], lat_range=[-20, 20])
ax.set_title("SST bias")

## plot bias
plot_data = ax.contourf(lon, lat, bias["sst"], **bias_kwargs)
cb = fig.colorbar(plot_data, ticks=[-3, 0, 3], label=r"$^{\circ}$C")

## plot "true" mean state (ORAS) as contours on top of the bias
plot_data = ax.contour(lon, lat, mean_oras["sst"], **truth_kwargs)

plt.show()

### Spatial variance

#### Compute

In [None]:
## subset for month if desired
## note: this rebinds the notebook-level name `months` (used earlier as the
## x-coordinate of the seasonal synchronization plot)
months = None

## subset data
oras_anom_ = src.utils.sel_month(oras_anom, months)
anom_ = src.utils.sel_month(anom, months)

## compute for ORAS (standard deviation in time, despite the `var_` name)
var_oras = oras_anom_.std("time")

## compute for MPI (reconstructed from EOF scores and spatial patterns)
var_mpi = src.utils.reconstruct_std(scores=anom_, components=components)

## interpolate to same grid
var_mpi = var_mpi.interp_like(var_oras)

## compute bias (MPI minus ORAS)
bias = var_mpi - var_oras

#### Plot (large scale)

In [None]:
def plot_var(x0, x1, amp, amp_diff, label, amp_min=0):
    """Show a three-panel map comparison of two fields.

    Builds a 3x1 figure of map axes and hands the fields and color-scale
    settings to ``src.utils.make_variance_subplots``, then shows the figure.

    Args:
        x0: first field (passed as ``var0``).
        x1: second field (passed as ``var1``).
        amp: passed as ``amp``.
        amp_diff: passed as ``amp_diff``.
        label: colorbar label.
        amp_min: passed as ``amp_min``.
    """

    def _setup_pacific(ax):
        """Format one map panel with the shared Pacific setup."""
        return src.utils.plot_setup_pac(ax, max_lat=20)

    ## figure with three stacked map panels
    fig = plt.figure(figsize=(6, 3.6), layout="constrained")
    axs = src.utils.subplots_with_proj(
        fig, nrows=3, ncols=1, format_func=_setup_pacific
    )

    ## draw the panels
    fig, axs = src.utils.make_variance_subplots(
        fig,
        axs,
        var0=x0,
        var1=x1,
        amp=amp,
        amp_diff=amp_diff,
        amp_min=amp_min,
        show_colorbars=True,
        cbar_label=label,
    )

    plt.show()

SST

In [None]:
## compare SST spread: ORAS (x0) vs. MPI (x1)
plot_var(
    x0=var_oras["sst"],
    x1=var_mpi["sst"],
    amp=1.5,
    amp_diff=0.5,
    amp_min=0.5,
    label=r"$^{\circ}$C",
)

SSH

In [None]:
## compare SSH spread: ORAS (x0) vs. MPI (x1)
plot_var(
    x0=var_oras["ssh"], x1=var_mpi["ssh"], amp=10, amp_diff=5, amp_min=3, label=r"$cm$"
)

#### Plot (zoom in on Niño regions)

In [None]:
def plot_setup_zoom(ax):
    """Add coastlines and restrict the map to 160-285 longitude, +/-10 latitude.

    Returns the same axes object.
    """

    zoom_extent = [160, 285, -10, 10]

    ax.coastlines(linewidth=0.3)
    ax.set_extent(zoom_extent, crs=ccrs.PlateCarree())

    return ax


## set up paneled subplot
fig = plt.figure(figsize=(5, 3.5), layout="constrained")
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=plot_setup_zoom)

## plot data
kwargs = dict(
    var0=var_oras["sst"],
    var1=var_mpi["sst"],
    amp=1.5,
    amp_diff=0.5,
    amp_min=0.5,
    show_colorbars=True,
)
fig, axs = src.utils.make_variance_subplots(fig, axs, **kwargs)

## plot Niño 3.4 (white) and Niño 3 (black, dashed) boxes on every panel
## note: `kwargs` is rebound here to the box line style
for ax in axs.flatten():
    kwargs = dict(linewidth=0.9, alpha=0.8)
    src.utils.plot_nino34_box(ax, c="w", **kwargs)
    src.utils.plot_nino3_box(ax, c="k", ls="--", **kwargs)

## NOTE(review): thick yellow vertical line at x=275 on the middle panel only;
## looks like a leftover debugging marker -- confirm whether it is intended
axs[1, 0].axvline(275, c="yellow", lw=5)

plt.show()

### Seasonal cycle

#### SST mean

In [None]:
def plot_cycle_hov(ax, data, **kwargs):
    """Plot a month-vs-longitude Hovmöller diagram of an annual cycle.

    The month axis is extended by one cyclic point so the plot runs from
    January through the following January.

    Args:
        ax: matplotlib axes to draw on.
        data: array with a ``month`` dimension/coordinate and a ``longitude``
            coordinate.
        **kwargs: forwarded to ``ax.contourf``.

    Returns:
        The contour set returned by ``ax.contourf``.
    """

    ## `import cartopy.crs as ccrs` at the top of the file binds only `ccrs`,
    ## so the `cartopy` name must be imported here before using cartopy.util
    import cartopy.util

    ## make sure month is first
    data = data.transpose("month", ...)

    ## Get cyclic point for plotting (appends a wrap-around month)
    data_cyclic, month_cyclic = cartopy.util.add_cyclic_point(data, data.month, axis=0)

    ## plot data
    plot_data = ax.contourf(data.longitude, month_cyclic, data_cyclic, **kwargs)

    ## mark the Niño 3.4 longitude bounds (kept separate from the contourf kwargs)
    box_kwargs = dict(ls="--", c="w", lw=0.85)
    ax.axvline(190, **box_kwargs)
    ax.axvline(240, **box_kwargs)

    ## labels/style
    ax.set_yticks([1, 5, 9, 13], labels=["Jan", "May", "Sep", "Jan"])
    ax.set_ylabel("Month")
    ax.set_xticks([])
    ax.set_xlim([130, 280])

    return plot_data

Compute

In [None]:
## function to compute equatorial mean (average over latitudes -2 to 2)
equatorial_mean = lambda x: x.sel(latitude=slice(-2, 2)).mean("latitude")

## Get monthly climatology for reanalysis (from the forced signal)
clim_oras = equatorial_mean(oras_forced["sst"].groupby("time.month").mean())

## get reconstruction for MPI: monthly climatology in EOF-score space, mapped
## to the equatorial mean through the SST spatial patterns
monthly_clim = forced.groupby("time.month").mean()
recon = src.utils.reconstruct_fn(
    components["sst"], monthly_clim["sst"], fn=equatorial_mean
)

## treat exact zeros in the reconstruction as missing
## NOTE(review): presumably zeros mark masked (e.g. land) points -- confirm
recon.values[recon.values == 0] = np.nan

## compute bias (MPI minus ORAS, on the ORAS grid)
bias = recon.interp_like(clim_oras) - clim_oras

Plot

In [None]:
## make hövmöllers (uses the src.utils version of plot_cycle_hov)
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## plot reanalysis
cp0 = src.utils.plot_cycle_hov(
    axs[0], clim_oras, cmap="cmo.thermal", levels=np.arange(22, 31)
)

## plot model data
cp1 = src.utils.plot_cycle_hov(
    axs[1], recon, cmap="cmo.thermal", levels=np.arange(22, 31)
)

## plot bias
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    bias,
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(2, 0.4),
    extend="both",
)

## label (longitude label/ticks on the bottom panel only)
axs[0].set_title("ORAS5")
axs[1].set_title("MPI")
axs[2].set_title("Bias")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## add one colorbar per panel
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[22, 26, 30], label=r"$^{\circ}C$")
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[22, 26, 30], label=r"$^{\circ}C$")
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-2, 0, 2], label=r"$^{\circ}C$")

plt.show()

#### SST variance

Compute

In [None]:
## compute for ORAS: variance by calendar month, then equatorial mean
## note: this rebinds `var_oras` / `var_mpi` from the spatial-variance section
var_oras = equatorial_mean(oras_anom["sst"].groupby("time.month").var("time"))

## compute for MPI: reconstruct variance from EOF scores for each calendar
## month, then take the equatorial mean
get_var = lambda x: equatorial_mean(
    src.utils.reconstruct_var(scores=x, components=components["sst"])
)
var_mpi = anom["sst"].groupby("time.month").map(get_var)

## get difference (MPI minus ORAS, on the ORAS grid)
diff = var_mpi.interp_like(var_oras) - var_oras

In [None]:
## shared args for plotting the two variance panels
plot_kwargs = dict(cmap="cmo.amp", extend="max")

## make hövmöllers (uses the src.utils version of plot_cycle_hov)
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## plot reanalysis
cp0 = src.utils.plot_cycle_hov(
    axs[0], var_oras, levels=np.arange(0, 3.3, 0.3), **plot_kwargs
)

## plot model data
cp1 = src.utils.plot_cycle_hov(
    axs[1], var_mpi, levels=np.arange(0, 3.3, 0.3), **plot_kwargs
)

## plot difference (diverging colormap)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    diff,
    levels=src.utils.make_cb_range(2, 0.2),
    cmap="cmo.balance",
    extend="both",
)

## label (longitude label/ticks on the bottom panel only)
axs[0].set_title("ORAS5")
axs[1].set_title("MPI")
axs[2].set_title("Bias")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## add colorbars; the difference panel reuses the label but overrides the ticks
kwargs = dict(ticks=[0, 1, 2, 3], label=r"$^{\circ}\text{C}^2$")
cb0 = fig.colorbar(cp0, ax=axs[0], **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **kwargs)
cb2 = fig.colorbar(cp2, ax=axs[2], **dict(kwargs, ticks=[-2, 0, 2]))

plt.show()

## ENSO composites

Get indices for composite

In [None]:
## comp on el niño or la niña?
is_warm = True

## which month to composite on?
## NOTE(review): in the visible cells peak_month is only used for axis labels;
## it is not passed to get_composites -- confirm the composite uses this month
peak_month = 12

## specify function used to compute the index for each dataset
idx_fn_oras = src.utils.get_nino34
idx_fn_mpi = src.utils.get_nino34

## quantile used as the event threshold
q = 0.85

## warm events: index above the q-quantile; cold events: index below the
## (1 - q)-quantile
if is_warm:
    kwargs = dict(q=q, check_cutoff=lambda x, cut: x > cut)
else:
    kwargs = dict(q=1 - q, check_cutoff=lambda x, cut: x < cut)

## get ORAS index
idx_oras = idx_fn_oras(oras_anom["sst"])

## get MPI index (index function applied through the EOF reconstruction)
idx_mpi = src.utils.reconstruct_fn(
    components=components["sst"], scores=anom["sst"], fn=idx_fn_mpi
)

Compute composites

In [None]:
## get composites; the MPI call additionally passes the spatial patterns
## because `anom` holds EOF scores rather than gridded fields
comps_oras = src.utils.get_composites(idx_oras, oras_anom, **kwargs)
comps_mpi = src.utils.get_composites(idx_mpi, anom, components=components, **kwargs)

Plot

In [None]:
## specify beta values (one per column; passed to plot_hov as `beta`)
betas = np.array([1, 1, 0.5])

## set up plot: rows = dataset (ORAS, MPI), columns = composite component
fig, axs = plt.subplots(2, 3, figsize=(6, 5), layout="constrained")

for row_i, comps in enumerate([comps_oras, comps_mpi]):
    for col_i, (comp, beta) in enumerate(zip(comps, betas)):

        ## get ax object
        ax = axs[row_i, col_i]

        ## draw the hovmöller and start with no ticks (re-added selectively below)
        cf, _ = src.utils.plot_hov(ax=ax, x=comp, beta=beta)
        ax.set_xticks([])
        ax.set_yticks([])

## label: column titles on the top row, dataset names on the right-hand side
for i, l in enumerate(["Total", "'In-phase'", "'Out-of-phase'"]):
    axs[0, i].set_title(l)
for j, l in zip([0, 1], ["ORAS", "MPI"]):
    axs[j, -1].set_ylabel(l)
    axs[j, -1].yaxis.set_label_position("right")

## label y axis on the left column
for ax in axs[:, 0]:
    src.utils.label_hov_yaxis(ax, peak_mon=peak_month)

## label x axis
for ax in axs[1, :]:
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])

plt.show()

## LIM

### Funcs

In [None]:
def get_Xi(Z, varphi, month_idx):
    """Fit one Koopman-mode matrix per calendar month.

    For each month ``m`` in 0..11, the samples (columns) with
    ``month_idx == m`` are selected and ``Z`` is regressed on ``varphi`` via
    the pseudo-inverse: ``Xi_m = Z_m @ pinv(varphi_m)``.

    Args:
        Z: 2-D array with samples along the second axis.
        varphi: 2-D array with samples along the second axis.
        month_idx: 1-D array of zero-based month indices, one per sample.

    Returns:
        Array of the 12 monthly matrices stacked along a new leading axis.
    """

    def _fit_month(m):
        """Least-squares map from varphi to Z using only samples in month m."""
        in_month = month_idx == m
        return Z[:, in_month] @ np.linalg.pinv(varphi[:, in_month])

    return np.stack([_fit_month(m) for m in range(12)], axis=0)


def reconstruct(scores, n_modes, other_coord=None, fn=None):
    """Reconstruct T and h from stacked EOF scores.

    The first ``n_modes`` rows of ``scores`` are reconstructed with the "sst"
    EOF model and the remaining rows with the "ssh" EOF model (both taken from
    the notebook-level ``eofs`` dict).

    Args:
        scores: 2-D array of scores, SST rows first and SSH rows after.
        n_modes: number of rows belonging to SST.
        other_coord: optional index replacing the time coordinate.
        fn: optional function applied through the reconstruction.

    Returns:
        Real part of a dataset with variables "T" and "h".
    """

    shared = {"n_modes": n_modes, "other_coord": other_coord, "fn": fn}
    sst_scores, ssh_scores = scores[:n_modes], scores[n_modes:]

    T_part = reconstruct_helper(sst_scores, model=eofs["sst"], **shared).rename("T")
    h_part = reconstruct_helper(ssh_scores, model=eofs["ssh"], **shared).rename("h")

    return xr.merge([T_part, h_part]).real


def reconstruct_helper(scores, model, n_modes, other_coord=None, fn=None):
    """Reconstruct raw scores with a single EOF model.

    The scores are written into the first ``n_modes`` rows of a zero-filled
    array shaped like the model's own scores (first member, first ``n`` time
    steps), so modes beyond ``n_modes`` contribute nothing.

    Args:
        scores: 2-D array, modes along the first axis, samples along the second.
        model: EOF model providing ``scores()``, ``components()`` and
            ``inverse_transform()``.
        n_modes: number of leading modes filled from ``scores``.
        other_coord: optional pandas Index replacing the "time" coordinate;
            defaults to an integer index named "n".
        fn: optional function; when given, reconstruction goes through
            ``src.utils.reconstruct_fn`` instead of ``inverse_transform``.

    Returns:
        The reconstructed data with "time" renamed to ``other_coord.name``.
    """

    n_samples = scores.shape[1]

    ## zero template with the model's own layout, modes first
    template = model.scores().isel(member=0, time=slice(None, n_samples))
    padded = xr.zeros_like(template).transpose("mode", ...)
    padded.values[:n_modes] = scores

    ## map back from score space
    if fn is not None:
        data = src.utils.reconstruct_fn(
            scores=padded, components=model.components(), fn=fn
        )
    else:
        data = model.inverse_transform(padded)

    ## swap the placeholder time axis for the requested coordinate
    if other_coord is None:
        other_coord = pd.Index(np.arange(n_samples), name="n")
    new_name = other_coord.name
    data = data.rename({"time": new_name})
    data[new_name] = other_coord.values

    return data


def plot_setup_spectrum(xlim=(-1, 0)):
    """Set up the canvas for plotting a spectrum.

    Creates a small figure with dashed horizontal guide lines at -1/4, 0 and
    1/4 and labelled sigma/omega axes.

    Args:
        xlim: (lower, upper) limits for the x-axis. The default is an
            immutable tuple so it cannot be shared and mutated between calls.

    Returns:
        Tuple ``(fig, ax)``.
    """

    ## create figure
    fig, ax = plt.subplots(figsize=(3, 2.5))

    ## label / guidelines
    kwargs = dict(c="k", zorder=0.5, lw=1, ls="--")
    for yt in [-0.25, 0, 0.25]:
        ax.axhline(yt, **kwargs)

    ## NOTE(review): the tick at -1/4 is labelled "1/4" (no minus sign);
    ## confirm whether that is intentional
    ax.set_yticks([-1 / 4, 0, 1 / 4], labels=[r"1/4", "0", "1/4"])
    ax.set_xticks([-1, -0.5, 0])
    ax.set_xlim(xlim)
    ax.set_xlabel(r"$\sigma$ (year$^{-1}$)")
    ax.set_ylabel(r"$\omega$ (year$^{-1}$)")

    return fig, ax


def flatten_features(X):
    """Collapse variable (and mode, if present) into one feature dimension "m".

    Args:
        X: dataset; converted to a single array with a "variable" dimension.

    Returns:
        Array with the feature dimension "m" first.
    """

    stacked = X.to_dataarray()

    if "mode" in X.dims:
        ## one feature per (variable, mode) pair
        stacked = stacked.stack(m=["variable", "mode"])
    else:
        ## one feature per variable
        stacked = stacked.rename({"variable": "m"})

    return stacked.transpose("m", ...)


def flatten_samples(X):
    """Collapse member and time into one trailing "sample" dimension."""

    return X.stack(sample=["member", "time"]).transpose(..., "sample")


def flatten(X):
    """Flatten features first, then samples."""

    by_feature = flatten_features(X)
    return flatten_samples(by_feature)


def get_merimean(x):
    """Average over latitudes between -5 and 5."""
    equatorial_band = x.sel(latitude=slice(-5, 5))
    return equatorial_band.mean("latitude")


def preprocess(XY, t_idx, n_modes=20):
    """Build one-step-ahead feature/target arrays and month labels.

    Args:
        XY: dataset with "member" and "time" dimensions and optionally "mode".
        t_idx: dict of positional indexers passed to ``isel`` to pick the period.
        n_modes: number of leading modes kept when a "mode" dimension exists.

    Returns:
        Tuple ``(X, Y, month_idx)`` of numpy arrays: features (all but the
        last time step), targets (all but the first time step), and the
        zero-based calendar month of each feature sample.
    """

    ## keep only the leading modes, if there is a mode dimension
    if "mode" in XY.dims:
        XY = XY.isel(mode=slice(None, n_modes))

    ## restrict to the requested period
    window = XY.isel(t_idx)

    ## features are each state, targets are the state one step later
    features = flatten(window.isel(time=slice(None, -1)))
    targets = flatten(window.isel(time=slice(1, None)))

    ## zero-based month of every feature sample
    month_idx = features.time.dt.month - 1

    return features.values, targets.values, month_idx.values


class lim_RO(src.lim.LIM_CS):
    """LIM fitted to low-dimensional data, with monthly Koopman-mode reconstruction.

    The parent ``src.lim.LIM_CS`` is fitted to one-step-ahead pairs built from
    ``XY``. The covariates ``Z`` are then regressed onto the evaluated
    eigenfunctions separately for each calendar month (see ``get_Xi``), which
    lets eigenfunction values be mapped back to ``Z``-space.

    The code relies on the parent providing ``sigma``, ``omega``, ``V``, ``X``
    and ``month_labels`` after initialization.
    """

    def __init__(self, XY, Z, t_idx, n_modes=20):
        """Preprocess the data, fit the parent model and fit the Koopman modes.

        Args:
            XY: dataset of state variables the LIM is fitted to.
            Z: dataset of covariates to reconstruct.
            t_idx: dict of positional indexers (for ``isel``) selecting the
                fitting period.
            n_modes: number of leading modes kept wherever a "mode" dimension
                is present.
        """

        ## do data pre-processing
        prep_kwargs = dict(n_modes=n_modes, t_idx=t_idx)
        X, Y, month_labels = preprocess(XY, **prep_kwargs)

        ## do covariate pre-processing (same samples as X)
        self.Z, _, _ = preprocess(Z, **prep_kwargs)

        ## initialize super
        super().__init__(X=X, Y=Y, month_labels=month_labels)

        ### fix units on spectrum
        ## NOTE(review): presumably converts per-month rates to per-year, and
        ## angular frequency to cycles per year -- confirm against LIM_CS
        self.sigma = 12 * self.sigma
        self.omega = 12 / (2 * np.pi) * self.omega

        ## evaluate ENSO eigenfunction on the training data
        self.varphi_eval = self.varphi(X=self.X, month_idx=self.month_labels)

        ## Fit Koopman modes (one matrix per calendar month)
        self.Xi = get_Xi(Z=self.Z, varphi=self.varphi_eval, month_idx=self.month_labels)

        return

    def varphi(self, X, month_idx):
        """Evaluate the eigenfunctions for each sample.

        Args:
            X: 2-D array (features, samples).
            month_idx: 1-D array of zero-based month indices, one per sample.

        Returns:
            2-D array (eigenfunctions, samples).
        """

        return np.einsum("nmk,mn->kn", self.V[month_idx], X)

    def recon_from_varphi(self, varphi_eval, month_idx):
        """Reconstruct covariates from eigenfunction values.

        Args:
            varphi_eval: 2-D array (eigenfunctions, samples).
            month_idx: either a single zero-based month applied to every
                sample, or a 1-D array giving the month of each sample.

        Returns:
            Real 2-D array (covariate features, samples).
        """

        ## single month: one Koopman-mode matrix applies to all samples
        if np.ndim(month_idx) == 0:
            return np.real(self.Xi[month_idx] @ varphi_eval)

        ## per-sample months: pair each sample with its own month's matrix
        ## (a plain matmul would broadcast over every month/sample combination)
        return np.real(np.einsum("nmk,kn->mn", self.Xi[month_idx], varphi_eval))

    def recon_from_X(self, X, month_idx):
        """Reconstruct the covariates from given T, h data.

        Args:
            X: 2-D array (features, samples).
            month_idx: 1-D array of zero-based month indices, one per sample.

        Returns:
            Real 2-D array (covariate features, samples).
        """

        ## evaluate eigenfunction
        varphi_eval = self.varphi(X, month_idx)

        return self.recon_from_varphi(varphi_eval, month_idx)

    def find_theta_max(self, peak_month_idx, fn=src.utils.get_nino34):
        """Find the phase angle that maximizes ``fn`` of the reconstructed T.

        Args:
            peak_month_idx: zero-based month used for the reconstruction.
            fn: function applied through the reconstruction.

        Returns:
            The angle (from a grid over [0, 2*pi) with step pi/32) at which
            the reconstructed "T" is largest.
        """

        ## get test pts: conjugate pair on a circle in the eigenfunction plane
        theta_test = np.arange(0, 2 * np.pi, np.pi / 32)
        varphi_test = self.varphi_eval.std() * np.exp(1j * theta_test)
        VtX = np.stack([varphi_test, np.conj(varphi_test)], axis=0)

        ## Get recon (EOF space)
        recon = self.recon_from_varphi(varphi_eval=VtX, month_idx=peak_month_idx)

        ## map from EOF space with `fn`; the rows are split evenly between the
        ## two reconstructed variables
        n_modes = recon.shape[0] // 2
        theta_coord = pd.Index(theta_test, name="theta")
        kwargs = dict(n_modes=n_modes, fn=fn, other_coord=theta_coord)
        recon_nino = reconstruct(scores=recon, **kwargs)

        ## get theta for maximum Niño in peak month
        idx_max = recon_nino["T"].argmax("theta").values.item()
        theta_max = theta_test[idx_max]

        return theta_max

    def get_enso_comp(self, peak_month_idx, lags_months=np.arange(-12, 13)):
        """Get the ENSO composite around a peak month.

        Args:
            peak_month_idx: zero-based month of the peak.
            lags_months: lags (in months) relative to the peak.

        Returns:
            Dataset with variables "T" and "h" and "lag" as the first dimension.
        """

        ## convert lags from months to years, then to phase angle using the
        ## frequency of the first eigenvalue
        lags_years = lags_months / 12
        delta_theta = 2 * np.pi * self.omega[0] * lags_years

        ## get angle for peak El Niño
        theta_max = self.find_theta_max(peak_month_idx=peak_month_idx)

        ## get test points (month and phase at every lag)
        month_idx_test = np.mod(peak_month_idx + lags_months, 12)
        theta_test = np.mod(theta_max + delta_theta, 2 * np.pi)
        varphi_test = 1.5 * self.varphi_eval.std() * np.exp(1j * theta_test)

        ## get complex conjugate
        varphi_test = np.stack([varphi_test, np.conj(varphi_test)], axis=0)

        ## Get recon (EOF space)
        recon = np.einsum("nmk,kn->mn", self.Xi[month_idx_test], varphi_test)

        ## Get recon (real space); the mode count comes from the reconstruction
        ## itself (rows split evenly between the two variables) instead of
        ## the notebook-level `lim_kwargs`
        recon_xr = reconstruct(
            scores=recon.real,
            n_modes=recon.shape[0] // 2,
            other_coord=pd.Index(lags_months, name="lag"),
            fn=get_merimean,
        )

        return recon_xr.transpose("lag", ...)

### Prep data

#### Do preprocessing

In [None]:
## load T, h data
Th = xr.open_dataset(mpi_load_fp)
Th["time"] = anom_all["time"]

## specify kwargs
lim_kwargs = dict(XY=Th[["T_34", "h"]], Z=anom_all[["sst", "ssh"]], n_modes=20)

## fit LIMs
lim_earl = lim_RO(t_idx=dict(time=slice(None, 360)), **lim_kwargs)
lim_late = lim_RO(t_idx=dict(time=slice(-360, None)), **lim_kwargs)

### Fit model
(and plot spectrum)

In [None]:
## Plot
fig, ax = plot_setup_spectrum()
ax.scatter(lim_earl.sigma, lim_earl.omega, label="Early")
ax.scatter(lim_late.sigma, lim_late.omega, label="Late")
ax.legend()
plt.show()

Find peak angle for El Niño

In [None]:
## specify month to look at
peak_month_idx = 7

## get reconstructions
recon_earl = lim_earl.get_enso_comp(peak_month_idx=peak_month_idx)
recon_late = lim_late.get_enso_comp(peak_month_idx=peak_month_idx)

In [None]:
## scale passed to plot_hov as `beta` for each panel (difference panel uses half)
scales = np.array([0.95, 0.95, 0.475])

## set up plot
fig, axs = plt.subplots(1, 3, figsize=(6, 3), layout="constrained")

## panels: early period, late period, and late minus early
for ax, x, scale in zip(axs, [recon_earl, recon_late, recon_late - recon_earl], scales):

    ## Plot (plot_hov is called with "sst"/"ssh" variable names)
    x_ = x.rename({"T": "sst", "h": "ssh"})
    cf, _ = src.utils.plot_hov(ax=ax, x=x_, beta=scale)

    ## label
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])

## label (title previously contained a stray closing parenthesis)
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
src.utils.label_hov_yaxis(axs[0], peak_mon=peak_month_idx + 1)

plt.show()

## Look at couplings in Bjerknes feedback

In [None]:
def regress(X, Y, dim="time"):
    """Compute the regression coefficient of ``Y`` on ``X`` along ``dim``.

    The coefficient is cov(X, Y) / var(X), both taken along ``dim``.

    Args:
        X: predictor array.
        Y: predictand array.
        dim: dimension to reduce over.

    Returns:
        Array of regression coefficients with ``dim`` removed.
    """

    ## small offset keeps the denominator non-zero where X has no variance
    eps = np.finfo(np.float32).eps
    denominator = X.var(dim=dim) + eps

    numerator = xr.cov(X, Y, dim=dim, ddof=0)

    return numerator / denominator


def regress_local_SST_SSH(data):
    """Regress SST on SSH along time (SSH is the predictor)."""
    ssh, sst = data["ssh"], data["sst"]
    return regress(X=ssh, Y=sst, dim="time")


def regress_local_SST_SSH_proj(data):
    """Regress SST on SSH using EOF-projected data.

    Args:
        data: dataset holding scores ("sst", "ssh") and the matching spatial
            patterns ("sst_comp", "ssh_comp").

    Returns:
        Reconstructed covariance of SST and SSH divided by the reconstructed
        variance of SSH.
    """

    ssh_scores, ssh_patterns = data["ssh"], data["ssh_comp"]

    ## covariance between SST (y) and SSH (x)
    cov = src.utils.reconstruct_cov_da(
        V_y=data["sst"],
        V_x=ssh_scores,
        U_y=data["sst_comp"],
        U_x=ssh_patterns,
    )

    ## variance of the predictor (SSH)
    var_X = src.utils.reconstruct_var(scores=ssh_scores, components=ssh_patterns)

    return cov / var_X

#### Add components to MPI data (useful for reconstruction)

In [None]:
## attach each variable's spatial pattern to `anom` as "<var>_comp", unless it
## is already there (keeps the cell safe to re-run)
for v in list(components):
    if f"{v}_comp" not in list(anom):
        anom[f"{v}_comp"] = components[v]

### SST - SSH

In [None]:
## compute local regression by month
m_sst_ssh_oras = oras_anom.groupby("time.month").map(regress_local_SST_SSH)

## then, reconstruct regression coefficient
m_sst_ssh_mpi_proj = anom.groupby("time.month").map(regress_local_SST_SSH_proj)

#### Spatial plot

In [None]:
## set up plot
fig = plt.figure(figsize=(7, 2.8), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## plot ORAS
cp0 = axs[0, 0].pcolormesh(
    m_sst_ssh_oras.longitude,
    m_sst_ssh_oras.latitude,
    m_sst_ssh_oras.mean("month"),
    vmax=0.3,
    vmin=-0.3,
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
)

## plot data
cp1 = axs[1, 0].pcolormesh(
    m_sst_ssh_mpi_proj.longitude,
    m_sst_ssh_mpi_proj.latitude,
    m_sst_ssh_mpi_proj.mean("month"),
    vmax=0.3,
    vmin=-0.3,
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
)

cb_kwargs = dict(ticks=[-0.3, 0, 0.3], label=r"$K ~cm^{-1}$")
cb0 = fig.colorbar(cp0, ax=axs, **cb_kwargs)
plt.show()

#### Scatter plot

In [None]:
def make_scatter(ax, x, y, dim="time", label=None):
    """scatter plot data on axis, and overlay a best-fit line

    The slope comes from ``regress(X=x, Y=y, dim=dim)`` and must reduce to a
    single value (it is extracted with ``.values.item()``). The line drawn is
    ``slope * x`` with no intercept term, so it passes through the origin.

    Args:
        ax: axis to draw on (``scatter`` and ``plot`` are called on it).
        x: predictor; must expose ``.values`` (used for the line's x-range).
        y: response, plotted against ``x``.
        dim: name of the dimension passed to ``regress``.
        label: optional legend label for the best-fit line. Previously the
            body read an undefined name ``label``, which only worked if an
            unrelated global of that name happened to exist.

    Returns:
        the regression slope as a Python scalar.
    """

    ## compute slope for best fit line
    slope = regress(X=x, Y=y, dim=dim).values.item()

    ## plot data
    ax.scatter(x, y, s=0.5)

    ## plot best fit over the observed range of x
    xtest = np.linspace(x.values.min(), x.values.max())
    ax.plot(xtest, slope * xtest, c="k", lw=1, label=label)

    return slope

In [None]:
def get_label(slope):
    """format a slope value as the panel title"""
    return f"slope$=${slope:.2f} " + r"$K ~cm^{-1}$"


## two stacked panels: first for ORAS, second for MPI
fig, axs = plt.subplots(2, 1, figsize=(2, 4), layout="constrained")
ax_oras, ax_mpi = axs

## ORAS: index from src.utils.get_nino3, then scatter sst against ssh
nino3_idx = src.utils.get_nino3(oras_anom)
slope = make_scatter(ax_oras, x=nino3_idx["ssh"], y=nino3_idx["sst"])
ax_oras.set_title(get_label(slope))

## MPI: apply get_nino3 via reconstruct_fn, then merge "time" and "member"
## into a single "sample" dimension
nino3_mpi = src.utils.reconstruct_fn(
    scores=anom,
    components=components,
    fn=src.utils.get_nino3,
).stack(sample=["time", "member"])
slope = make_scatter(ax_mpi, x=nino3_mpi["ssh"], y=nino3_mpi["sst"], dim="sample")
ax_mpi.set_title(get_label(slope))

## identical limits, zero lines and y-label on both panels
zero_line = dict(ls="--", lw=0.8, c="k")
for ax in axs:
    ax.set_xlim([-15, 22])
    ax.set_ylim([-4, 4])
    ax.axhline(0, **zero_line)
    ax.axvline(0, **zero_line)
    ax.set_ylabel(r"Niño 3 SST ($K$)")

## x ticks and x label only on the second panel
ax_oras.set_xticks([])
ax_mpi.set_xticks([-10, 0, 10, 20])
ax_mpi.set_xlabel(r"Niño 3 SSH ($cm$)")


plt.show()

### $SST-\tau_x$ coupling

Functions

In [None]:
def regress_taux_sst_nino(data, sst_fn=src.utils.get_nino3):
    """regress taux on an SST index along the "time" dimension

    ``sst_fn`` (default ``src.utils.get_nino3``) is applied to
    ``data["sst"]`` to build the predictor; ``data["taux"]`` is the response.
    """
    sst_index = sst_fn(data["sst"])
    return regress(X=sst_index, Y=data["taux"], dim="time")


def regress_taux_sst_nino_proj(data, fn_x=src.utils.get_nino3):
    """regression coefficient of taux on an SST index, from projected data

    ``fn_x`` (default ``src.utils.get_nino3``) is forwarded to both
    reconstruction helpers for the SST side; the taux side uses the identity.
    Returns reconstructed covariance divided by reconstructed SST variance.
    """

    ## numerator: covariance between taux (y) and the SST index (x)
    numerator = src.utils.reconstruct_cov_da(
        V_y=data["taux"],
        V_x=data["sst"],
        U_y=data["taux_comp"],
        U_x=data["sst_comp"],
        fn_x=fn_x,
        fn_y=lambda x: x,
    )

    ## denominator: variance of the SST index
    denominator = src.utils.reconstruct_var(
        scores=data["sst"],
        components=data["sst_comp"],
        fn=fn_x,
    )

    return numerator / denominator

Compute

In [None]:
## ORAS: taux-on-SST-index regression, one result per calendar month
oras_by_month = oras_anom.groupby("time.month")
m_taux_nino_oras = oras_by_month.map(regress_taux_sst_nino)

## MPI: same, from projected data, with get_nino3 forwarded as fn_x
mpi_by_month = anom.groupby("time.month")
m_taux_nino_mpi = mpi_by_month.map(
    regress_taux_sst_nino_proj,
    fn_x=src.utils.get_nino3,
)

Spatial plot

In [None]:
## figure with a 2 x 1 grid of map panels
fig = plt.figure(figsize=(7, 2.6), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## color limits, colormap and transform shared by both panels
mesh_kwargs = dict(
    vmax=2e-2,
    vmin=-2e-2,
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
)

## panel [0, 0]: ORAS coefficient, averaged over "month"
cp0 = axs[0, 0].pcolormesh(
    m_taux_nino_oras.longitude,
    m_taux_nino_oras.latitude,
    m_taux_nino_oras.mean("month"),
    **mesh_kwargs,
)

## panel [1, 0]: MPI coefficient, averaged over "month"
cp1 = axs[1, 0].pcolormesh(
    m_taux_nino_mpi.longitude,
    m_taux_nino_mpi.latitude,
    m_taux_nino_mpi.mean("month"),
    **mesh_kwargs,
)

## single colorbar (built from the ORAS mesh) attached to all panels
cb_kwargs = {"ticks": [-2e-2, 0, 2e-2], "label": r"$Pa~K^{-1}$"}
cb0 = fig.colorbar(cp0, ax=axs, **cb_kwargs)

## outline the Niño 3 box on every panel
box_kwargs = {"c": "k", "linewidth": 0.9, "alpha": 0.5}
for ax in axs.flatten():
    src.utils.plot_nino3_box(ax, **box_kwargs)

plt.show()

Scatter plot

In [None]:
## factor applied to taux before plotting (the labels below read mPa)
scale = 1000


def get_label(slope):
    """format a slope value as the panel title"""
    return f"slope$=${slope:.1f} " + r"$mPa ~ K^{-1}$"


## two stacked panels: first for ORAS, second for MPI
fig, axs = plt.subplots(2, 1, figsize=(2, 4), layout="constrained")
ax_oras, ax_mpi = axs

## ORAS: get_nino3 applied to sst (x), get_nino4 applied to taux (y)
sst_plot = src.utils.get_nino3(oras_anom["sst"])
tau_plot = src.utils.get_nino4(oras_anom["taux"])
slope = make_scatter(ax_oras, x=sst_plot, y=tau_plot * scale)
ax_oras.set_title(get_label(slope))

## MPI: same two indices via reconstruct_fn, with "time" and "member"
## merged into a single "sample" dimension
sst_plot = src.utils.reconstruct_fn(
    scores=anom["sst"],
    components=components["sst"],
    fn=src.utils.get_nino3,
)
sst_plot = sst_plot.stack(sample=["time", "member"])
tau_plot = src.utils.reconstruct_fn(
    scores=anom["taux"],
    components=components["taux"],
    fn=src.utils.get_nino4,
)
tau_plot = tau_plot.stack(sample=["time", "member"])

slope = make_scatter(ax_mpi, x=sst_plot, y=tau_plot * scale, dim="sample")
ax_mpi.set_title(get_label(slope))

## identical limits, zero lines and y-label on both panels
zero_line = dict(ls="--", lw=0.8, c="k")
for ax in axs:
    ax.set_xlim([-4, 4.5])
    ax.set_ylim([-50, 70])
    ax.axhline(0, **zero_line)
    ax.axvline(0, **zero_line)
    ax.set_ylabel(r"Niño 4 $\tau_x$ ($mPa$)")

## x ticks and x label only on the second panel
ax_oras.set_xticks([])
ax_mpi.set_xticks([-3, 0, 3])
ax_mpi.set_xlabel(r"Niño 3 SST ($K$)")


plt.show()

### $\tau_x-SSH$ coupling

In [None]:
def regress_ssh_taux_nino(data, taux_fn=src.utils.get_nino4):
    """regress ssh on a taux index along the "time" dimension

    ``taux_fn`` (default ``src.utils.get_nino4``) is applied to
    ``data["taux"]`` to build the predictor; ``data["ssh"]`` is the response.
    """
    taux_index = taux_fn(data["taux"])
    return regress(X=taux_index, Y=data["ssh"], dim="time")


def regress_ssh_taux_nino_proj(data, fn_x=src.utils.get_nino4):
    """regression coefficient of ssh on a taux index, from projected data

    ``fn_x`` (default ``src.utils.get_nino4``) is forwarded to both
    reconstruction helpers for the taux side; the ssh side uses the identity.
    Returns reconstructed covariance divided by reconstructed taux variance.
    """

    ## numerator: covariance between ssh (y) and the taux index (x)
    numerator = src.utils.reconstruct_cov_da(
        V_y=data["ssh"],
        V_x=data["taux"],
        U_y=data["ssh_comp"],
        U_x=data["taux_comp"],
        fn_x=fn_x,
        fn_y=lambda x: x,
    )

    ## denominator: variance of the taux index
    denominator = src.utils.reconstruct_var(
        scores=data["taux"],
        components=data["taux_comp"],
        fn=fn_x,
    )

    return numerator / denominator

Compute

In [None]:
## ORAS: ssh-on-taux-index regression, one result per calendar month
oras_by_month = oras_anom.groupby("time.month")
m_ssh_nino_oras = oras_by_month.map(regress_ssh_taux_nino)

## MPI: same, from projected data, with get_nino4 forwarded as fn_x
mpi_by_month = anom.groupby("time.month")
m_ssh_nino_mpi = mpi_by_month.map(
    regress_ssh_taux_nino_proj,
    fn_x=src.utils.get_nino4,
)

Plot

In [None]:
## figure with a 2 x 1 grid of map panels
fig = plt.figure(figsize=(7, 2.6), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## color limits, colormap and transform shared by both panels
mesh_kwargs = dict(
    vmax=5e2,
    vmin=-5e2,
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
)

## panel [0, 0]: ORAS coefficient, averaged over "month"
cp0 = axs[0, 0].pcolormesh(
    m_ssh_nino_oras.longitude,
    m_ssh_nino_oras.latitude,
    m_ssh_nino_oras.mean("month"),
    **mesh_kwargs,
)

## panel [1, 0]: MPI coefficient, averaged over "month"
cp1 = axs[1, 0].pcolormesh(
    m_ssh_nino_mpi.longitude,
    m_ssh_nino_mpi.latitude,
    m_ssh_nino_mpi.mean("month"),
    **mesh_kwargs,
)

## single colorbar (built from the ORAS mesh) attached to all panels
cb_kwargs = {"ticks": [-500, 0, 500], "label": r"$cm~Pa^{-1}$"}
cb0 = fig.colorbar(cp0, ax=axs, **cb_kwargs)

## outline the Niño 4 box on every panel
box_kwargs = {"c": "k", "linewidth": 0.9, "alpha": 0.5}
for ax in axs.flatten():
    src.utils.plot_nino4_box(ax, **box_kwargs)

plt.show()

Scatter plot

In [None]:
## factor applied to ssh before plotting
## NOTE(review): other cells in this file label SSH in cm (axis "SSH (cm)",
## colorbar "cm Pa^-1"). If ssh is in cm here too, ssh * scale is not in
## metres and the slope is in cm per mPa -- confirm the unit labels below.
scale = 1 / 1000


def get_label(slope):
    """format a slope value as the panel title"""
    return f"slope$=${slope:.2f} " + r"$m ~ mPa^{-1}$"


## two stacked panels: first for ORAS, second for MPI
fig, axs = plt.subplots(2, 1, figsize=(2, 4), layout="constrained")
ax_oras, ax_mpi = axs

## ORAS: get_nino4 applied to taux (x), get_nino3 applied to ssh (y)
tau_plot = src.utils.get_nino4(oras_anom["taux"])
ssh_plot = src.utils.get_nino3(oras_anom["ssh"])
slope = make_scatter(ax_oras, x=tau_plot, y=ssh_plot * scale)
ax_oras.set_title(get_label(slope))

## MPI: same two indices via reconstruct_fn, with "time" and "member"
## merged into a single "sample" dimension
tau_plot = src.utils.reconstruct_fn(
    scores=anom["taux"],
    components=components["taux"],
    fn=src.utils.get_nino4,
)
tau_plot = tau_plot.stack(sample=["time", "member"])
ssh_plot = src.utils.reconstruct_fn(
    scores=anom["ssh"],
    components=components["ssh"],
    fn=src.utils.get_nino3,
)
ssh_plot = ssh_plot.stack(sample=["time", "member"])

slope = make_scatter(ax_mpi, x=tau_plot, y=ssh_plot * scale, dim="sample")
ax_mpi.set_title(get_label(slope))

## identical limits, zero lines and y-label on both panels
zero_line = dict(ls="--", lw=0.8, c="k")
for ax in axs:
    ax.set_xlim([-5e-2, 8e-2])
    ax.set_ylim([-0.015, 0.025])
    ax.axhline(0, **zero_line)
    ax.axvline(0, **zero_line)
    ax.set_ylabel(r"Niño 3 SSH ($m$)")

## x ticks and x label only on the second panel
ax_oras.set_xticks([])
ax_mpi.set_xticks([-0.05, 0, 0.05])
ax_mpi.set_xlabel(r"Niño 4 $\tau_x$ ($Pa$)")


plt.show()

## Scratch
Look at mean state

In [None]:
## time-mean of the forced taux entry, with a "member" dimension added
taux_time_mean = forced["taux"].mean("time")
taux_clim_proj = taux_time_mean.expand_dims("member")

## pass through eofs["taux"].inverse_transform, then apply get_merimean
taux_recon = eofs["taux"].inverse_transform(taux_clim_proj)
taux_clim = get_merimean(taux_recon)

## line plot of the result against longitude
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(taux_clim.longitude, taux_clim.squeeze())
plt.show()