In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
from qgsw.fields.variables.tuples import UVH
import numpy as np
import torch
from traitlets import default
from qgsw.masks import Masks
from qgsw.models.qg.stretching_matrix import compute_A, compute_layers_to_mode_decomposition
from qgsw.models.qg.uvh.projectors.core import QGProjector
from qgsw.output import RunOutput
from qgsw.spatial.core.discretization import SpaceDiscretization3D
from qgsw.spatial.core.grid_conversion import points_to_surfaces
from qgsw.specs import defaults
from qgsw.utils import covphys
from scipy import signal
from collections.abc import Iterable, Iterator

from qgsw.output import OutputFileUVH
from qgsw.plots.heatmaps import AnimatedHeatmaps
from qgsw.filters.base import _Filter

class NoFilter:
    """Identity filter: a stand-in for a real `_Filter` that leaves data unchanged."""

    def __call__(self, to_filter: torch.Tensor) -> torch.Tensor:
        # Return the very same object (no copy), so callers that assign the
        # result back in place see no change at all.
        unchanged = to_filter
        return unchanged


# Run whose outputs are analysed below (path relative to this notebook).
run = RunOutput("../output/g5k/sw_double_gyre_long_hr")

# Layer thicknesses (model.h) and f0 read from the run configuration.
H = run.summary.configuration.model.h
f0 = run.summary.configuration.physics.f0
# QG projector used by compute_sf / compute_modes to recover the stream function
# from the saved (u, v, h) snapshots.
# NOTE(review): g_prime is hard-coded to three values instead of being read from
# the run configuration — confirm it matches the simulated model (3 layers assumed).
# NOTE(review): f0 below re-reads the configuration instead of reusing `f0` above;
# same value, but the duplication could drift if one of them is edited.
P = QGProjector(
    A =compute_A(
        H = H,
        g_prime = torch.tensor([9.81,0.025, 0.0125],**defaults.get())
    ),
    H = H.unsqueeze(-1).unsqueeze(-1),
    space=SpaceDiscretization3D.from_config(
        run.summary.configuration.space,
        run.summary.configuration.model
    ),
    f0 = run.summary.configuration.physics.f0,
    masks = Masks.empty(nx=run.summary.configuration.space.nx,ny=run.summary.configuration.space.ny)
)
space=P.space

# y coordinate of the ny cell centres (from dy/2 to ly - dy/2), shape (1, ny).
# NOTE(review): later cells reuse `y` as a scratch name, which overwrites this
# global — avoid relying on it after those cells have run.
y = torch.linspace(
    0.5 * space.dy,
    space.ly - 0.5 * space.dy,
    space.ny,
    **defaults.get()
).unsqueeze(0)
# Reference latitude: the middle of the domain.
y0 = 0.5 * space.ly

# Beta-effect: beta * (y - y0), relative to the domain centre.
# NOTE(review): not used by the visible cells (compute_pv does not include it).
beta_effect = run.summary.configuration.physics.beta * (y - y0)

# NOTE(review): unused below — every analysis cell calls run.outputs() afresh.
outputs = run.outputs()
# Grid spacings and sizes from the space configuration.
dx,dy = run.summary.configuration.space.dx, run.summary.configuration.space.dy
nx,ny = run.summary.configuration.space.nx, run.summary.configuration.space.ny

def compute_sf(uvh:UVH, P:QGProjector, dx:float, dy:float, filt:_Filter)-> torch.Tensor:
    """Compute the (optionally filtered) layer-wise stream function of one snapshot.

    The snapshot is converted with ``covphys.to_cov`` and then passed to
    ``P.compute_p``. The second element of its result (presumably the pressure
    — confirm) is divided by ``P._f0``.

    Args:
        uvh: snapshot read from an output file.
        P: QG projector providing ``compute_p`` and ``_f0``.
        dx: grid spacing along x, used for the covariant conversion.
        dy: grid spacing along y, used for the covariant conversion.
        filt: callable applied independently to every layer.

    Returns:
        Stream function with the layer dimension first. Only the first entry of
        the leading (presumably batch) dimension is kept.
    """
    sf = P.compute_p(covphys.to_cov(uvh,dx,dy))[1][0]/P._f0
    # Filter each layer independently, whatever the number of layers.
    for layer in range(sf.shape[0]):
        sf[layer] = filt(sf[layer])
    return sf

# Layer-to-mode matrix: third output of the decomposition of the global
# projector's stretching matrix P.A; used by compute_modes below.
_, _, Cl2m = compute_layers_to_mode_decomposition(P.A)

def compute_modes(uvh:UVH, P:QGProjector, dx:float, dy:float, filt:_Filter)-> torch.Tensor:
    """Compute the (optionally filtered) modal stream function of one snapshot.

    The layer-wise stream function (as in ``compute_sf``) is projected with the
    module-level matrix ``Cl2m``. Each resulting component along the first
    dimension is then passed through ``filt``, matching how ``compute_sf`` and
    ``compute_pv`` apply their filter.

    NOTE(review): ``Cl2m`` is derived from the global ``P``, not from the ``P``
    argument; the two must describe the same stratification.

    Args:
        uvh: snapshot read from an output file.
        P: QG projector providing ``compute_p`` and ``_f0``.
        dx: grid spacing along x, used for the covariant conversion.
        dy: grid spacing along y, used for the covariant conversion.
        filt: callable applied independently to every mode.

    Returns:
        Projected stream function, mode dimension first.
    """
    sf = P.compute_p(covphys.to_cov(uvh,dx,dy))[1][0]/P._f0
    modes = torch.einsum("...lm,...mxy->...lxy", Cl2m, sf)
    # `filt` is part of the signature (and passed by load_modes), so honour it.
    for mode in range(modes.shape[0]):
        modes[mode] = filt(modes[mode])
    return modes


def compute_pv(uvh:UVH, H:torch.Tensor, f0:float, dx:float, dy:float, filt:_Filter)-> torch.Tensor:
    """Compute the (optionally filtered) layer-wise potential vorticity of one snapshot.

    pv = omega - f0 * h / H, where omega = dv/dx - du/dy is built from finite
    differences of the velocity components with their outer row/column dropped,
    and h is first passed through ``points_to_surfaces`` (presumably to bring it
    onto omega's grid — confirm).

    Args:
        uvh: snapshot with ``u``, ``v`` and ``h`` fields.
        H: reference layer thicknesses, broadcastable against ``h``; callers in
            this notebook pass ``H.unsqueeze(-1).unsqueeze(-1)``.
        f0: value multiplying h / H.
        dx: grid spacing along x.
        dy: grid spacing along y.
        filt: callable applied independently to every layer.

    Returns:
        Potential vorticity with the layer dimension first. Only the first entry
        of the leading (presumably batch) dimension is kept.
    """
    u,v,h = uvh.u, uvh.v, uvh.h
    # Relative vorticity dv/dx - du/dy from finite differences of the interior
    # velocity components.
    omega = torch.diff(v[..., 1:-1], dim=-2) / dx - torch.diff(u[..., 1:-1, :], dim=-1) / dy
    h = points_to_surfaces(h)
    pv =  (omega - f0 * h / H)[0]
    # Filter each layer independently, whatever the number of layers.
    for layer in range(pv.shape[0]):
        pv[layer] = filt(pv[layer])
    return pv

def load_sf(outputs: Iterable[OutputFileUVH], P:QGProjector, dx:float, dy:float, filt:_Filter = NoFilter()) -> Iterator[torch.Tensor]:
    """Lazily map every output file to its (optionally filtered) stream function.

    A file is read and processed only when the returned iterator is advanced.
    """
    def _read_sf(output: OutputFileUVH) -> torch.Tensor:
        snapshot = output.read()
        return compute_sf(snapshot, P, dx, dy, filt)

    return map(_read_sf, outputs)

def load_modes(outputs: Iterable[OutputFileUVH], P:QGProjector, dx:float, dy:float, filt:_Filter = NoFilter()) -> Iterator[torch.Tensor]:
    """Lazily map every output file to its modal stream function (see compute_modes).

    A file is read and processed only when the returned iterator is advanced.
    """
    def _read_modes(output: OutputFileUVH) -> torch.Tensor:
        snapshot = output.read()
        return compute_modes(snapshot, P, dx, dy, filt)

    return map(_read_modes, outputs)


def load_pv(outputs: Iterable[OutputFileUVH], H:torch.Tensor, f0:float,dx:float, dy:float, filt:_Filter = NoFilter()) -> Iterator[torch.Tensor]:
    """Lazily map every output file to its (optionally filtered) potential vorticity.

    A file is read and processed only when the returned iterator is advanced.
    """
    def _read_pv(output: OutputFileUVH) -> torch.Tensor:
        snapshot = output.read()
        return compute_pv(snapshot, H, f0, dx, dy, filt)

    return map(_read_pv, outputs)

def compute_correlations(profiles: torch.Tensor, normalize_mean:bool = False) -> torch.Tensor:
    nt, nx, ny = profiles.shape
    stacked = profiles.reshape((nt, nx*ny))

    correlation_nt = 2*nt-1
    correlations = torch.zeros((correlation_nt,nx*ny))

    Ns = np.array([366/(366-abs(i)) for i in range(-365,366)])

    for i in range(nx*ny):
        mean = np.mean(stacked[:,i])
        timeseries = stacked[:,i] - mean
        correlations_np = signal.correlate(timeseries,timeseries,mode="full")
        if normalize_mean:
            correlations_np*=Ns
        correlations[:,i]= torch.tensor(correlations_np)

    normalized = correlations / correlations[correlation_nt//2]
    sliced = normalized[correlation_nt//2:]
    return torch.reshape(sliced, (nt, nx, ny))

In [None]:
def plot_correlations(correlations:torch.Tensor, suptitle:str|None=None) -> None:
    """Plot per-layer autocorrelation diagnostics.

    Three figures are drawn, each with one panel per layer:

    * the integral time scale (sum of the autocorrelation over all lags) at
      every grid point, with the probe points marked;
    * that integral time scale averaged over an 8 x 16 grid of tiles, with the
      rounded tile mean written at the centre of each tile;
    * the autocorrelation against lag at the probe points.

    Probe points sit at fixed fractions (1/4, 1/2, 3/4) of the grid, so any
    grid size works. On a 256 x 512 grid they are (64, 128), (64, 384),
    (128, 256), (192, 128) and (192, 384).

    Args:
        correlations: autocorrelations of shape (nt, nl, nx, ny), lag first.
        suptitle: optional title shared by the three figures.

    Note:
        Labels express lags in days, assuming one output per day — confirm.
    """
    nt, nl, nx, ny = correlations.shape

    # Inclusive tile bounds along each axis; neighbouring tiles share an edge index.
    nx_nb_slice, ny_nb_slice = 8, 16
    x_tiles = [
        (max(0, int((i - 1) * nx / nx_nb_slice)), min(nx - 1, int(i * nx / nx_nb_slice)))
        for i in range(1, nx_nb_slice + 1)
    ]
    y_tiles = [
        (max(0, int((i - 1) * ny / ny_nb_slice)), min(ny - 1, int(i * ny / ny_nb_slice)))
        for i in range(1, ny_nb_slice + 1)
    ]

    # Probe points as (x index, y index, colour), placed relative to the grid size.
    probes = [
        (int(fx * nx), int(fy * ny), color)
        for fx, fy, color in (
            (0.25, 0.25, "b"),
            (0.25, 0.75, "r"),
            (0.5, 0.5, "k"),
            (0.75, 0.25, "brown"),
            (0.75, 0.75, "orange"),
        )
    ]

    # Integral time scale: sum over lags, shape (nl, nx, ny).
    decorrelation_times = torch.sum(correlations, dim=0)
    decorrelation_means = torch.zeros(decorrelation_times.shape)

    fig_time_decorr, axs_time_decorr = plt.subplots(1, nl, squeeze=False, constrained_layout=True, figsize=(15, 8))
    fig_mean_decorr, axs_mean_decorr = plt.subplots(1, nl, squeeze=False, constrained_layout=True, figsize=(15, 8))
    fig_profiles, axs_profiles = plt.subplots(1, nl, squeeze=False, constrained_layout=True, figsize=(15, 5))

    for xmin, xmax in x_tiles:
        for ymin, ymax in y_tiles:
            tile_mean = torch.mean(
                decorrelation_times[..., xmin:xmax + 1, ymin:ymax + 1], dim=[-2, -1], keepdim=True
            )
            decorrelation_means[:, xmin:xmax + 1, ymin:ymax + 1] = tile_mean
            for layer in range(nl):
                value = tile_mean[layer].item()
                # round() raises on NaN (e.g. constant time series), so label those explicitly.
                label = "nan" if np.isnan(value) else str(round(value))
                axs_mean_decorr[0, layer].text(
                    (xmin + xmax) // 2, (ymin + ymax) // 2, label, ha="center", va="center"
                )

    if suptitle is not None:
        for fig in (fig_time_decorr, fig_mean_decorr, fig_profiles):
            fig.suptitle(suptitle)

    lags = np.arange(nt)
    for layer in range(nl):
        corr = correlations[:, layer, ...]

        ax_time = axs_time_decorr[0, layer]
        ax_time.set_title(f"Layer {layer}")
        image = ax_time.imshow(decorrelation_times[layer].T, cmap="jet")
        fig_time_decorr.colorbar(image, ax=ax_time, label="Integral time scale [days]")

        ax_profile = axs_profiles[0, layer]
        ax_profile.set_title(f"Layer {layer}")
        for px, py, color in probes:
            probe_label = f"x={px}, y={py}"
            ax_time.scatter(px, py, label=probe_label, c=color)
            ax_profile.plot(lags, corr[:, px, py], label=probe_label, color=color)
        ax_profile.hlines(y=0, xmin=0, xmax=nt, linestyles="--", colors="k", alpha=0.25)
        # Reference lag at 20 days.
        ax_profile.vlines(x=20, ymin=-1, ymax=1, linestyles="--", colors="k", alpha=0.25)

        ax_mean = axs_mean_decorr[0, layer]
        ax_mean.set_title(f"Layer {layer}")
        ax_mean.imshow(decorrelation_means[layer].T, cmap="jet")

    # One legend for the probe colours, on the last profile panel.
    axs_profiles[0, -1].legend()
    plt.show()
    plt.close(fig_time_decorr)
    plt.close(fig_mean_decorr)
    plt.close(fig_profiles)

In [None]:
# Number of leading lags (rows of the autocorrelation tensors) kept for plotting.
# Plot labels treat one lag as one day — confirm against the output frequency.
threshold: int = 100

In [None]:
# Stream function of every snapshot, computed in a single pass over the outputs:
# each snapshot needs an elliptic solve, so it is not recomputed per layer.
# All layers are kept on the CPU, shape (time, layer, x, y); this trades memory
# for a single pass.
sf_snapshots = torch.stack(
    [sf.cpu() for sf in load_sf(run.outputs(), P, dx, dy)],
    dim=0,
)

# Autocorrelation of each layer, stacked back along dim 1: (lag, layer, x, y).
sf_correlations_full = torch.stack(
    [
        compute_correlations(sf_snapshots[:, layer].numpy())
        for layer in range(sf_snapshots.shape[1])
    ],
    dim=1,
)
sf_correlations = sf_correlations_full[:threshold]

plot_correlations(sf_correlations, "Stream function")

In [None]:
from qgsw.filters.high_pass import GaussianHighPass2D

# High-pass filter applied to each stream-function layer before correlating.
sf_highpass = GaussianHighPass2D(sigma=5)

# Filtered stream function of every snapshot, computed in a single pass and
# kept on the CPU, shape (time, layer, x, y).
sf_filt_snapshots = torch.stack(
    [sf.cpu() for sf in load_sf(run.outputs(), P, dx, dy, sf_highpass)],
    dim=0,
)

# Autocorrelation of each layer, stacked back along dim 1: (lag, layer, x, y).
sf_filt_correlations_full = torch.stack(
    [
        compute_correlations(sf_filt_snapshots[:, layer].numpy())
        for layer in range(sf_filt_snapshots.shape[1])
    ],
    dim=1,
)
sf_filt_correlations = sf_filt_correlations_full[:threshold]

plot_correlations(sf_filt_correlations, f"Filtered stream function: σ = {sf_highpass.sigma}")

In [None]:
# Potential vorticity of every snapshot, computed in a single pass over the
# outputs (instead of once per layer) and kept on the CPU, shape (time, layer, x, y).
pv_snapshots = torch.stack(
    [pv.cpu() for pv in load_pv(run.outputs(), H.unsqueeze(-1).unsqueeze(-1), f0, dx, dy)],
    dim=0,
)

# Autocorrelation of each layer, stacked back along dim 1: (lag, layer, x, y).
pv_correlations_full = torch.stack(
    [
        compute_correlations(pv_snapshots[:, layer].numpy())
        for layer in range(pv_snapshots.shape[1])
    ],
    dim=1,
)
pv_correlations = pv_correlations_full[:threshold]

# Generic names read by the diagnostic cells further down: the zero-crossing
# analysis uses `correlations`; the manual lag check uses the last layer's
# stacked time series `data_stacked`.
correlations = pv_correlations
data_stacked = pv_snapshots[:, -1].numpy()

plot_correlations(pv_correlations, "Potential vorticity")

In [None]:
from qgsw.filters.high_pass import GaussianHighPass2D

# High-pass filter applied to each potential-vorticity layer before correlating.
pv_highpass = GaussianHighPass2D(sigma=30)

# Filtered potential vorticity of every snapshot, computed in a single pass and
# kept on the CPU, shape (time, layer, x, y).
pv_filt_snapshots = torch.stack(
    [
        pv.cpu()
        for pv in load_pv(run.outputs(), H.unsqueeze(-1).unsqueeze(-1), f0, dx, dy, pv_highpass)
    ],
    dim=0,
)

# Autocorrelation of each layer, stacked back along dim 1: (lag, layer, x, y).
pv_filt_correlations_full = torch.stack(
    [
        compute_correlations(pv_filt_snapshots[:, layer].numpy())
        for layer in range(pv_filt_snapshots.shape[1])
    ],
    dim=1,
)
pv_filt_correlations = pv_filt_correlations_full[:threshold]

plot_correlations(pv_filt_correlations, f"Filtered potential vorticity: σ = {pv_highpass.sigma}")

In [None]:
from qgsw.plots import plt_wrapper


def first_negative_lag(correlations: torch.Tensor) -> torch.Tensor:
    """Return, for every point, the first lag at which the autocorrelation is negative.

    Args:
        correlations: autocorrelation tensor with the lag along dim 0.

    Returns:
        Integer tensor shaped like ``correlations[0]``. Points whose
        autocorrelation never becomes negative within the available lags
        (including all-NaN series, since NaN < 0 is False) get
        ``correlations.shape[0]``, i.e. they are censored just past the last lag.
    """
    n_lags = correlations.shape[0]
    negative = correlations < 0
    # argmax returns the index of the first maximal value, i.e. the first True.
    first = torch.argmax(negative.int(), dim=0)
    return torch.where(negative.any(dim=0), first, torch.full_like(first, n_lags))


n_lags = correlations.shape[0]
layer = 0
zero_crossing = first_negative_lag(correlations)[layer].cpu().numpy()
# Normalise by the actual number of points (e.g. the PV grid is smaller than nx * ny).
n_points = zero_crossing.size

fig, ax = plt.subplots(figsize=(5, 8))
image = ax.imshow(zero_crossing.T, vmin=0, cmap=plt_wrapper.DEFAULT_CMAP)
fig.colorbar(image, ax=ax, label="Time offset [days]")
ax.set_title(f"Layer {layer}: first 0 correlation lag")
plt.show()
plt.close(fig)

fig, ax = plt.subplots()
# One bin per lag (plus the censored value n_lags), weighted so bar heights are proportions.
ax.hist(
    zero_crossing.ravel(),
    bins=np.arange(n_lags + 2) - 0.5,
    weights=np.full(n_points, 1 / n_points),
)
ax.set_title("'first 0 correlation reached' timestep")
ax.set_ylabel("Proportion")
ax.set_xlabel("Time offset [days]")
plt.show()
plt.close(fig)

lags = np.arange(n_lags)
# Censored points (value n_lags) are counted in the denominator but never in the
# cumulative sum, so the curve tops out at the fraction that decorrelated.
counts = np.bincount(zero_crossing.ravel(), minlength=n_lags + 1)[:n_lags]
cumulative_pct = np.cumsum(counts) / n_points * 100

fig, ax = plt.subplots()
ax.plot(lags, cumulative_pct)
ax.set_ylabel("% of data")
ax.set_xlabel("Time offset [days]")
ax.vlines(x=20, color='red', linestyles='--', ymin=0, ymax=100, alpha=0.25, label="20 days")
ax.set_title("Cumulative proportion of 'first 0 correlation reached' timestep")
ax.legend()
plt.show()
plt.close(fig)

In [None]:
# Manual estimate of one normalised autocorrelation value from `data_stacked`:
# lag `check_lag` at grid point (check_i, check_j). Dedicated names keep the
# global grid coordinate `y` (defined near the top of the notebook) intact.
check_lag, check_i, check_j = 50, 10, 20
check_series = data_stacked[:, check_i, check_j]
check_anomaly = check_series - np.mean(check_series)
np.sum(check_anomaly[check_lag:] * check_anomaly[:-check_lag]) / np.sum(check_anomaly * check_anomaly)

In [None]:
# compute_correlations' value at lag 50, layer 2, grid point (10, 20); compare
# with the manual estimate in the previous cell (small float32 differences expected).
correlations[50,2,10,20]

In [None]:
# plot = AnimatedHeatmaps([[sf_correlations[i,0].T for i in range(threshold)]])
# plot.set_zbounds(-1,1)
# ts = [f"Offset: {o.read().t.item() / 3600 / 24:.1f} day(s)" for o in run.outputs()][:threshold]
# plot.set_frame_labels(ts)
# plot.save_video("../output/video/autocorrelation_sf.mp4",width=512,height=1024,fps=5)