In [321]:
import pandas as pd
import numpy as np
from tsdata_to_var_torch import tsdata_to_var_torch
from var_to_autocov_torch import var_to_autocov_torch
from autocov_to_pwcgc_torch import autocov_to_pwcgc_torch
import seaborn as sns
import torch
import matplotlib.pyplot as plt

X = pd.read_csv('X.csv',header=None)
X  = np.array(X) # X must be numpy array in shape of (num var, num rows)
# X = X[:96,:]



# window_size = 6

# # Convert to DataFrame
# df = pd.DataFrame(X)

# # Apply a centered rolling mean
# centered_ma_df = df.rolling(window=window_size, min_periods=1, center=True).mean()

# # Convert back to NumPy
# X = centered_ma_df.to_numpy()






X = torch.tensor(X[:100,:], requires_grad=True)
X.shape




torch.Size([100, 7])

In [302]:
import torch
from typing import Dict, Tuple

def ridge_var(
    Y: torch.Tensor,
    p: int = 1,
    alpha: float = 1e-2,
    standardize: bool = False,
    add_intercept: bool = False,
    eps: float = 1e-8,
) -> Dict[str, torch.Tensor]:
    """
    Ridge-regularized VAR(p) fit on a sequence window.

    Args
    ----
    Y : (H, d) tensor
        Time-by-feature matrix for the window (rows: time T+1..T+H, cols: d series).
    p : int
        VAR order (number of lags).
    alpha : float
        Ridge penalty (L2) on the stacked coefficient matrix A. Must be > 0 for stability.
    standardize : bool
        If True, z-score each column of Y over the window before building lags.
        (This helps scale-invariance and conditioning; the model remains differentiable.)
    add_intercept : bool
        If True, include an intercept term (column of ones) in the regressors.
        The returned 'c' is the intercept vector (d,). Otherwise 'c' is zeros.
    eps : float
        Numerical floor for standard deviations when standardizing.

    Returns
    -------
    out : dict with keys
        A        : (p, d, d) stacked coefficient blocks; A[k-1] maps Y_{t-k} to Y_t
        Sigma    : (d, d) innovation covariance (unbiased: divide by H - p)
        R        : (H-p, d) residuals on the window
        X_reg    : (H-p, p*d [+1]) regressor matrix actually used
        Y_resp   : (H-p, d) response matrix actually used
        c        : (d,) intercept (zeros if add_intercept=False)
        mu, std  : (d,), (d,) used for standardization (if enabled), else None
    Notes
    -----
    - Row t of X_reg (0-indexed) corresponds to predicting Y[p + t] using
      [Y[p + t - 1], Y[p + t - 2], ..., Y[p + t - p]] concatenated (most recent first).
      Hence A[0] = A_1 multiplies Y_{t-1}, ..., A[p-1] = A_p multiplies Y_{t-p}.
    - Uses normal equations with Cholesky solve:
        G = X^T X + alpha I;  solve G A = X^T Y
      Everything is differentiable in Y (no detach).
    """
    assert Y.dim() == 2, "Y must be (H, d)"
    H, d = Y.shape
    assert H > p, "Need H > p to fit a VAR(p)"
    device, dtype = Y.device, Y.dtype

    # Optional standardization (per feature over the window)
    if standardize:
        mu = Y.mean(dim=0, keepdim=True)
        std = Y.std(dim=0, unbiased=False, keepdim=True).clamp_min(eps)
        Yz = (Y - mu) / std
    else:
        mu, std = None, None
        Yz = Y

    # Build responses Y_resp = [Y[p], Y[p+1], ..., Y[H-1]] (length H-p)
    Y_resp = Yz[p:, :]  # (H-p, d)

    # Build lagged design X_reg with columns [Y_{t-1}, Y_{t-2}, ..., Y_{t-p}] (most recent first)
    # Each block has shape (H-p, d); we then concat along last dim -> (H-p, p*d)
    X_lags = [Yz[p - k : H - k, :] for k in range(1, p + 1)]  # k=1 -> Y_{t-1}
    X_reg = torch.cat(X_lags, dim=1)  # (H-p, p*d)

    # Optionally add intercept column
    if add_intercept:
        ones = torch.ones((H - p, 1), device=device, dtype=dtype)
        X_reg = torch.cat([X_reg, ones], dim=1)  # (H-p, p*d + 1)

    # Normal equations components
    # G = X^T X + alpha I, B = X^T Y
    G = X_reg.T @ X_reg
    # Ridge stabilizer on all parameters (including intercept if present)
    G = G + alpha * torch.eye(G.shape[0], device=device, dtype=dtype)
    B = X_reg.T @ Y_resp  # ((p*d [+1]), d)

    # Solve G * A_full = B with Cholesky (stable + differentiable)
    # torch.cholesky is deprecated; use linalg.cholesky + triangular_solve
    L = torch.linalg.cholesky(G)                       # G = L L^T
    # Solve L Z = B  -> Z = L^{-1} B
    Z = torch.cholesky_solve(B, L)                     # Solves G X = B directly (preferred)
    A_full = Z                                         # ((p*d [+1]), d)

    # Split intercept (if any) and VAR blocks
    if add_intercept:
        A_flat = A_full[:-1, :]                        # (p*d, d)
        c = A_full[-1, :]                              # (d,)
    else:
        A_flat = A_full                                # (p*d, d)
        c = torch.zeros(d, device=device, dtype=dtype)

    # Reshape A_flat -> (p, d, d); block k-1 corresponds to lag k
    A = A_flat.reshape(p, d, d)  # (p, d, d)

    # Residuals on the (possibly standardized) scale:
    Y_hat = X_reg @ A_full                    # (H-p, d)
    R = Y_resp - Y_hat                        # (H-p, d)

    # Innovation covariance (unbiased)
    Sigma = (R.T @ R) / (H - p)               # (d, d)

    # If we standardized inputs/outputs for fitting, A maps standardized lags to standardized responses.
    # You can map A back to the original scale if needed:
    #   If Y_std = (Y - mu)/std, then
    #   Y_std_t = sum_k A_k Y_std_{t-k} + ...  => on original scale:
    #   A_k_orig = diag(std.squeeze()) @ A_k @ diag(1/std.squeeze())
    #   c_orig   = mu.squeeze() + diag(std.squeeze()) @ c - sum_k A_k_orig @ mu.squeeze()
    # We leave them in standardized coordinates to keep gradients simple; return mu,std for optional re-scaling.
    out = {
        "A": A,                    # (p, d, d)
        "Sigma": Sigma,            # (d, d)
        "R": R,                    # (H-p, d)
        "X_reg": X_reg,            # (H-p, p*d [+1])
        "Y_resp": Y_resp,          # (H-p, d)
        "c": c,                    # (d,)
        "mu": None if not standardize else mu.squeeze(0),   # (d,)
        "std": None if not standardize else std.squeeze(0), # (d,)
    }
    return out

import torch
from typing import Dict

def build_companion(
    A: torch.Tensor,           # (p, d, d), with A[0]=A1, ..., A[p-1]=Ap
    c: torch.Tensor = None,    # (d,), optional intercept
    augmented: bool = False,   # if True, return augmented companion (pd+1 x pd+1)
) -> Dict[str, torch.Tensor]:
    """
    Construct the companion matrix for a VAR(p), with optional intercept handling.

    Theory:
      VAR(p): X_t = sum_{k=1}^p A_k X_{t-k} + c + eps_t
      Companion state Z_t = [X_t; X_{t-1}; ...; X_{t-p+1}] ∈ R^{pd}

      Standard companion (no state augmentation):
        Z_t = C Z_{t-1} + d + W_t
        C = [[A1 A2 ... Ap],
             [I  0  ... 0 ],
             [0  I  ... 0 ],
             ...
             [0  0  ... I 0]] ∈ R^{pd x pd}
        d = [c; 0; ...; 0] ∈ R^{pd}

      Augmented companion (embed intercept as linear map on a constant 1):
        \tilde Z_t = [X_t; X_{t-1}; ...; X_{t-p+1}; 1]
        \tilde Z_t = C_aug \tilde Z_{t-1} + \tilde W_t
        C_aug = [[A1 A2 ... Ap | c],
                 [I  0  ... 0  | 0],
                 [0  I  ... 0  | 0],
                 ...
                 [0  0  ... I  0],
                 [0  0  ... 0  | 1]] ∈ R^{(pd+1) x (pd+1)}

    Returns:
      If augmented=False:
        {
          "C": (pd, pd) companion matrix,
          "d": (pd,)   constant input vector (zeros if c is None),
        }
      If augmented=True:
        {
          "C_aug": (pd+1, pd+1) augmented companion,
        }

    Notes:
      - Stability is governed by the spectral radius of the non-augmented C: rho(C) < 1.
      - Works with any floating dtype (float32/float64) and on any device.
    """
    assert A.dim() == 3, "A must be (p, d, d)"
    p, d, d2 = A.shape
    assert d == d2, "A blocks must be square (d x d)"
    device, dtype = A.device, A.dtype

    pdim = p * d

    # Top block: [A1 A2 ... Ap] of shape (d, p*d)
    top = A.reshape(p, d, d).permute(1, 0, 2).reshape(d, p * d)
    # Lower block: shift identity rows to move lags down
    if p > 1:
        # Build the (p-1) * d x p * d block that stacks [I, 0, ..., 0] down the subdiagonal
        I = torch.eye(d, device=device, dtype=dtype)
        zeros = torch.zeros((d, d), device=device, dtype=dtype)
        rows = []
        for r in range(p - 1):
            # row r: [0 ... 0, I, 0 ... 0] placing I at block column r
            blocks = [zeros] * p
            blocks[r] = I
            rows.append(torch.cat(blocks, dim=1))
        lower = torch.cat(rows, dim=0)  # ((p-1)*d, p*d)
    else:
        lower = torch.zeros((0, p * d), device=device, dtype=dtype)

    # Assemble standard companion C (pd x pd)
    C = torch.zeros((pdim, pdim), device=device, dtype=dtype)
    C[:d, :] = top
    if p > 1:
        C[d:, :] = lower

    if not augmented:
        # Constant input d (top block is c, rest zeros)
        if c is None:
            d_vec = torch.zeros((pdim,), device=device, dtype=dtype)
        else:
            assert c.shape == (d,), "c must be (d,)"
            d_vec = torch.zeros((pdim,), device=device, dtype=dtype)
            d_vec[:d] = c
        return {"C": C, "d": d_vec}

    # Augmented companion C_aug ((pd+1) x (pd+1))
    if c is None:
        c = torch.zeros((d,), device=device, dtype=dtype)
    assert c.shape == (d,), "c must be (d,)"

    C_aug = torch.zeros((pdim + 1, pdim + 1), device=device, dtype=dtype)
    # Fill the left (pd x pd) block with C
    C_aug[:pdim, :pdim] = C
    # Put intercept column (only in the top d rows)
    C_aug[:d, -1] = c
    # Keep the constant 1 invariant
    C_aug[-1, -1] = torch.tensor(1.0, device=device, dtype=dtype)

    return {"C_aug": C_aug}



def spectral_radius_power(C: torch.Tensor, iters: int = 20) -> torch.Tensor:
    """
    Estimate spectral radius rho(C) via power iteration on C^T C.

    Returns
    -------
    rho : scalar tensor, approximate max singular value of C, which upper bounds spectral radius.
    For non-normal C this is an upper bound. For stability penalty it is sufficient and smooth.
    """
    # Compute largest singular value of C as sqrt(lambda_max(C^T C))
    # This upper bounds the spectral radius and is differentiable.
    v = torch.randn(C.shape[1], device=C.device, dtype=C.dtype)
    v = v / (v.norm() + 1e-12)
    for _ in range(iters):
        v = (C.T @ (C @ v))
        v = v / (v.norm() + 1e-12)
    # Rayleigh quotient
    lam = v @ (C.T @ (C @ v))
    sigma_max = torch.sqrt(torch.clamp(lam, min=0.0))
    return sigma_max

import torch

def gc_strengths(A: torch.Tensor) -> torch.Tensor:
    """
    A: (p, d, d) with A[0]=A1,...,A[p-1]=Ap
    returns g: (d, d), where g[i,j] is strength j->i
    """
    # group-l2 over lag dimension
    g = torch.sqrt((A ** 2).sum(dim=0))  # (d, d)
    mask = 1.0 - torch.eye(A.shape[1], device=A.device, dtype=A.dtype)
    g_masked = g * mask
    
    return g_masked

def soft_mask(g: torch.Tensor, tau: float, beta: float) -> torch.Tensor:
    # S in [0,1], differentiable
    return torch.sigmoid(beta * (g - tau))

# def choose_tau(g_real: torch.Tensor, quantile: float = 0.3) -> float:
#     off = g_real[~torch.eye(g_real.shape[0], dtype=torch.bool, device=g_real.device)]
#     return torch.quantile(off, quantile).item()


def choose_tau(g_real: torch.Tensor, quantile: float = 0.3) -> float:
    off = g_real[~torch.eye(g_real.shape[0], dtype=torch.bool, device=g_real.device)]
    return torch.quantile(off, torch.tensor(quantile,dtype=g_real.dtype, device=g_real.device)).item()


def gc_mask_l2_loss(S_pred: torch.Tensor, S_real: torch.Tensor) -> torch.Tensor:
    # sum over i != j (diagonal is already ~0)
    return ((S_pred - S_real) ** 2).sum()



import torch
import torch.fft as fft

def spectral_density_fft(A: torch.Tensor, Sigma: torch.Tensor, M: int = 256, eps = 1e-7) -> torch.Tensor:
    """
    Compute spectral density matrices S(ω) for a VAR(p) using FFT.

    Parameters
    ----------
    A : (p, d, d)
        VAR coefficient blocks.
    Sigma : (d, d)
        Innovation covariance (positive definite).
    M : int
        Number of frequency grid points (default 256).

    Returns
    -------
    S : (M, d, d) complex tensor
        Spectral density matrices S(ω_m) for ω_m = 2πm/M.
        Hermitian and positive definite for all m.
    """
    p, d, _ = A.shape
    device, dtype = A.device, A.dtype

    # -----------------------
    # 1. Pad A_k to length M for FFT
    # -----------------------
    # Create sequence [I, -A1, -A2, ..., -Ap, 0, 0, ..., 0] of shape (M, d, d)
    A_seq = torch.zeros((M, d, d), dtype=torch.complex64 if dtype == torch.float32 else torch.complex128, device=device)
    A_seq[0] = torch.eye(d, device=device, dtype=dtype)        # coefficient for lag 0
    for k in range(p):
        A_seq[k+1] = -A[k].to(dtype=torch.complex128)

    # -----------------------
    # 2. FFT along lag dimension to get H(ω)
    # -----------------------
    # torch.fft.fft applies over axis=0 (the lag dimension)
    H_w = fft.fft(A_seq, dim=0)   # (M, d, d) complex

    # -----------------------
    # 3. Invert each H(ω) and compute S(ω)
    # -----------------------
    # Σ_ε might be real; convert to complex
    Sigma_c = Sigma.to(dtype=torch.complex128)

    # Compute S(ω) = H^{-1} Σ H^{-H}
    # We'll use batched linear solves for efficiency
    # Solve H_w @ X = I to get H_w^{-1}
    I_d = torch.eye(d, device=device, dtype=torch.complex128).expand(M, d, d)
    H_inv = torch.linalg.solve(H_w, I_d)          # (M, d, d)
    S_w = H_inv @ Sigma_c @ H_inv.conj().transpose(-1, -2)  # (M, d, d)
    if eps > 0:
        S_w = S_w + eps * I_d 
    return S_w



def inverse_spectrum(S_w: torch.Tensor) -> torch.Tensor:
    """
    Compute inverse spectrum Θ(ω) = S(ω)^{-1} for each frequency.
    """
    return torch.linalg.inv(S_w)

def partial_coherence_from_S(S_w: torch.Tensor):
    # Θ = (S + eps_inv I)^{-1}
    M, d, _ = S_w.shape
    I = torch.eye(d, dtype=S_w.dtype, device=S_w.device).expand(M, d, d)
    Theta = torch.linalg.inv(S_w )
    # γ^2_ij = |Θ_ij|^2 / (Θ_ii Θ_jj)
    num = (Theta.abs() ** 2)
    diag = Theta.diagonal(dim1=-2, dim2=-1).real.clamp_min(1e-12)  # (M,d)
    denom = diag.unsqueeze(-1) * diag.unsqueeze(-2)                 # (M,d,d)
    gamma2 = (num / denom).clamp(0, 1)                              # (M,d,d) in [0,1]
    return gamma2




In [None]:


# 2. Convert to a pandas DataFrame
# .rolling() operates on DataFrames
# df = pd.DataFrame(X.detach().numpy())
# ma_df = df.rolling(window=5, min_periods=1).mean()
# X2 = torch.tensor(ma_df.to_numpy())
X2 = X[:,:2]

# for i in range(7):
#     plt.plot(X2[:,i])
#     plt.plot(X.detach().numpy()[:,i])
#     plt.show()

for p in range(1,20):
    out = ridge_var(X2,p=p,alpha=0.01,standardize=True,add_intercept=True,eps = 1e-12)
    companion = build_companion(A = out['A'], c=out['c'], augmented=False)
    C = companion['C']
    rho_sur = spectral_radius_power(C)     # ≈ largest singular value
    print(p,rho_sur)


    

1 tensor(0.8913, dtype=torch.float64, grad_fn=<SqrtBackward0>)
2 tensor(1.6311, dtype=torch.float64, grad_fn=<SqrtBackward0>)
3 tensor(1.5746, dtype=torch.float64, grad_fn=<SqrtBackward0>)
4 tensor(1.6388, dtype=torch.float64, grad_fn=<SqrtBackward0>)
5 tensor(1.7146, dtype=torch.float64, grad_fn=<SqrtBackward0>)
6 tensor(1.6996, dtype=torch.float64, grad_fn=<SqrtBackward0>)
7 tensor(1.7247, dtype=torch.float64, grad_fn=<SqrtBackward0>)
8 tensor(1.7358, dtype=torch.float64, grad_fn=<SqrtBackward0>)
9 tensor(1.8010, dtype=torch.float64, grad_fn=<SqrtBackward0>)
10 tensor(1.8025, dtype=torch.float64, grad_fn=<SqrtBackward0>)
11 tensor(1.7804, dtype=torch.float64, grad_fn=<SqrtBackward0>)
12 tensor(1.7778, dtype=torch.float64, grad_fn=<SqrtBackward0>)
13 tensor(1.7972, dtype=torch.float64, grad_fn=<SqrtBackward0>)
14 tensor(1.9452, dtype=torch.float64, grad_fn=<SqrtBackward0>)
15 tensor(1.9141, dtype=torch.float64, grad_fn=<SqrtBackward0>)
16 tensor(1.9402, dtype=torch.float64, grad_fn=<S

In [414]:
import pandas as pd
import numpy as np
from statsmodels.tsa.api import VAR
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.vector_ar.var_model import VAR

# --- 1. Load & Prepare Data ---
# Let's create some sample data. In a real case, you'd load this
# from a file (e.g., pd.read_csv(...)).
# We create two time series that depend on each other's past.

df = pd.read_csv('data/ETT-small/ETTh1.csv', parse_dates=['date']).iloc[0:100,:]
df = df.set_index('date')
df = df.asfreq('h')  # ‘H’ = hourly
df = (df - df.mean()) / df.std()



# model = VAR(endog=df, dates=df.index, freq='h')  # pass freq explicitly


# model
# # 1. Load the data 

# w
# # 2. Set the date column as index

df.shape

(100, 7)

In [369]:
import pandas as pd
import numpy as np
from statsmodels.tsa.stattools import adfuller, kpss, grangercausalitytests
from statsmodels.tsa.vector_ar.var_model import VAR

# assume df is already loaded, set index, standardized, etc (as you did)
# e.g., df = df.set_index('date').asfreq('H').dropna()

# 1. Stationarity tests for each column
def adf_test(series, signif=0.05, name='', verbose=True):
    result = adfuller(series.dropna(), autolag='AIC')
    p_value = result[1]
    if verbose:
        print(f'ADF test for {name}: p-value = {p_value:.4f}')
    return p_value < signif

def kpss_test(series, signif=0.05, name='', verbose=True):
    # Note: statsmodels.kpss may warn about trend='c' or 'ct'
    from statsmodels.tsa.stattools import kpss
    result = kpss(series.dropna(), regression='c', nlags="auto")
    p_value = result[1]
    if verbose:
        print(f'KPSS test for {name}: p-value = {p_value:.4f}')
    return p_value > signif

stationary = {}
for col in df.columns:
    print(f'Column: {col}')
    adf_ok = adf_test(df[col], name=col)
    kpss_ok = kpss_test(df[col], name=col)
    stationary[col] = adf_ok and kpss_ok

print('Stationary status:', stationary)

# If any series is not stationary, difference it:
if not all(stationary.values()):
    df_diff = df.diff().dropna()
    print('Used first difference for non-stationary series')
    # You might repeat the tests until acceptable.
    df = df_diff

# 2. Find lag order for VAR
model = VAR(df)
order_results = model.select_order(maxlags=8)
print(order_results.summary())

best_lag = order_results.selected_orders['aic']  # you could also inspect bic
print('Best lag (AIC):', best_lag)

# 3. Fit VAR
var_res = model.fit(best_lag)
print(var_res.summary())

# Stability check
print('Is the VAR stable?', var_res.is_stable())

# 4. Granger causality matrix of all combinations
def granger_matrix(data, variables, maxlag, verbose=False):
    import pandas as pd
    from statsmodels.tsa.stattools import grangercausalitytests

    df_pvals = pd.DataFrame(np.zeros((len(variables), len(variables))), columns=variables, index=variables)
    for y in variables:
        for x in variables:
            if x == y:
                df_pvals.loc[y, x] = np.nan
                continue
            test_result = grangercausalitytests(data[[y, x]], maxlag=maxlag)
            p_values = [round(test_result[i+1][0]['ssr_chi2test'][1],4) for i in range(maxlag)]
            min_p = np.min(p_values)
            df_pvals.loc[y, x] = min_p
    return df_pvals

variables = df.columns.tolist()
maxlag = best_lag
print('Granger causality p-value matrix (rows = Y, cols = X):')
pval_matrix = granger_matrix(df, variables, maxlag=maxlag)
print(pval_matrix)

# Interpretation:
# If pval_matrix.loc[Y, X] < 0.05 → X Granger-causes Y at that lag.


Column: HUFL
ADF test for HUFL: p-value = 0.0000
KPSS test for HUFL: p-value = 0.1000
Column: HULL
ADF test for HULL: p-value = 0.0000
KPSS test for HULL: p-value = 0.1000
Column: MUFL
ADF test for MUFL: p-value = 0.0000
KPSS test for MUFL: p-value = 0.1000
Column: MULL
ADF test for MULL: p-value = 0.0000
KPSS test for MULL: p-value = 0.1000
Column: LUFL
ADF test for LUFL: p-value = 0.0000
KPSS test for LUFL: p-value = 0.1000
Column: LULL
ADF test for LULL: p-value = 0.0000
KPSS test for LULL: p-value = 0.1000
Column: OT
ADF test for OT: p-value = 0.0000
KPSS test for OT: p-value = 0.0417
Stationary status: {'HUFL': True, 'HULL': True, 'MUFL': True, 'MULL': True, 'LUFL': True, 'LULL': True, 'OT': False}
Used first difference for non-stationary series
 VAR Order Selection (* highlights the minimums) 
      AIC         BIC         FPE         HQIC   
-------------------------------------------------
0       8.234       8.431       3766.       8.313
1       2.309       3.885       10.10  

look-up table. The actual p-value is greater than the p-value returned.

  result = kpss(series.dropna(), regression='c', nlags="auto")
look-up table. The actual p-value is greater than the p-value returned.

  result = kpss(series.dropna(), regression='c', nlags="auto")
look-up table. The actual p-value is greater than the p-value returned.

  result = kpss(series.dropna(), regression='c', nlags="auto")
look-up table. The actual p-value is greater than the p-value returned.

  result = kpss(series.dropna(), regression='c', nlags="auto")
look-up table. The actual p-value is greater than the p-value returned.

  result = kpss(series.dropna(), regression='c', nlags="auto")
look-up table. The actual p-value is greater than the p-value returned.

  result = kpss(series.dropna(), regression='c', nlags="auto")


parameter F test:         F=0.0416  , p=0.9967  , df_denom=83, df_num=4

Granger Causality
number of lags (no zero) 5
ssr based F test:         F=0.2119  , p=0.9565  , df_denom=80, df_num=5
ssr based chi2 test:   chi2=1.2054  , p=0.9444  , df=5
likelihood ratio test: chi2=1.1975  , p=0.9451  , df=5
parameter F test:         F=0.2119  , p=0.9565  , df_denom=80, df_num=5

Granger Causality
number of lags (no zero) 6
ssr based F test:         F=0.1697  , p=0.9842  , df_denom=77, df_num=6
ssr based chi2 test:   chi2=1.1898  , p=0.9774  , df=6
likelihood ratio test: chi2=1.1820  , p=0.9778  , df=6
parameter F test:         F=0.1697  , p=0.9842  , df_denom=77, df_num=6

Granger Causality
number of lags (no zero) 7
ssr based F test:         F=0.2730  , p=0.9626  , df_denom=74, df_num=7
ssr based chi2 test:   chi2=2.2988  , p=0.9415  , df=7
likelihood ratio test: chi2=2.2696  , p=0.9434  , df=7
parameter F test:         F=0.2730  , p=0.9626  , df_denom=74, df_num=7

Granger Causality
number of

In [376]:
from statsmodels.tsa.stattools import grangercausalitytests
import pandas as pd
import numpy as np

def simple_granger_matrix(data, maxlag, alpha=0.05):
    cols = data.columns
    mat = pd.DataFrame(np.zeros((len(cols), len(cols))), index=cols, columns=cols)
    for y in cols:
        for x in cols:
            if x == y:
                mat.loc[y, x] = np.nan
            else:
                test = grangercausalitytests(data[[y, x]].dropna(), maxlag=maxlag, verbose=False)
                # get the smallest p-value among lags 1..maxlag for the F-test
                p_vals = [ test[l][0]['ssr_ftest'][1] for l in range(1, maxlag+1) ]
                min_p = min(p_vals)
                mat.loc[y, x] = 1 if min_p < alpha else 0
    return mat

# Usage:
maxlag = 5  # adjust based on your lag-order selection
gr_mat = simple_granger_matrix(df, maxlag=maxlag, alpha=0.05)
print("Matrix (1 = X Granger-causes Y):\n", gr_mat)


Matrix (1 = X Granger-causes Y):
       HUFL  HULL  MUFL  MULL  LUFL  LULL   OT
HUFL   NaN   0.0   0.0   0.0   0.0   0.0  0.0
HULL   0.0   NaN   0.0   0.0   0.0   0.0  0.0
MUFL   0.0   0.0   NaN   1.0   0.0   0.0  0.0
MULL   0.0   0.0   0.0   NaN   0.0   1.0  0.0
LUFL   1.0   0.0   0.0   0.0   NaN   0.0  0.0
LULL   0.0   0.0   0.0   0.0   0.0   NaN  1.0
OT     0.0   1.0   0.0   0.0   1.0   1.0  NaN




In [None]:

from statsmodels.tools.tools import add_constant
from statsmodels.regression.linear_model import OLS, yule_walker
from statsmodels.tsa.tsatools import lagmat2ds
from scipy import stats




def grangercausalitytests2(x, maxlag):

    lags = np.arange(1, maxlag + 1)
    resli = {}

    for mlg in lags:
        result = {}
        mxlg = mlg

        # create lagmat of both time series
        dta = lagmat2ds(x, mxlg, trim="both", dropex=1)

        # add constant
        dtaown = add_constant(dta[:, 1 : (mxlg + 1)], prepend=False)
        dtajoint = add_constant(dta[:, 1:], prepend=False)

        # Run ols on both models without and with lags of second variable
        res2down = OLS(dta[:, 0], dtaown).fit()
        res2djoint = OLS(dta[:, 0], dtajoint).fit()


        # Granger Causality test using ssr (F statistic)
        if res2djoint.model.k_constant:
            tss = res2djoint.centered_tss
        else:
            tss = res2djoint.uncentered_tss

        fgc1 = (
            (res2down.ssr - res2djoint.ssr)
            / res2djoint.ssr
            / mxlg
            * res2djoint.df_resid
        )

        result["ssr_ftest"] = (
            stats.f.sf(fgc1, mxlg, res2djoint.df_resid),
        )



        resli[mxlg] = (result)

    return resli






def simple_granger_matrix(data, maxlag, alpha=0.05):
    cols = data.columns
    mat = pd.DataFrame(np.zeros((len(cols), len(cols))), index=cols, columns=cols)

    for y in cols:
        for x in cols:
            if x == y:
                mat.loc[y, x] = np.nan
            else:
                test = grangercausalitytests2(data[[y, x]].dropna(), maxlag=maxlag)
                # get the smallest p-value among lags 1..maxlag for the F-test
                p_vals = [ test[l]['ssr_ftest'][0] for l in range(1, maxlag+1) ]
                min_p = min(p_vals)
                mat.loc[y, x] = 1 if min_p < alpha else 0
    return mat




# Usage:
maxlag = 5  # adjust based on your lag-order selection
gr_mat = simple_granger_matrix(df, maxlag=maxlag, alpha=0.05)
gr_mat




gr_mat.sum().sum()

17.0

In [None]:

X2 = X[1:193:2,:]
X1 = X[:96,:]
 
p =4

out1 = ridge_var(X1,p=p,alpha=0.01,standardize=True,add_intercept=True,eps = 1e-12)
# companion = build_companion(A = out['A'], c=out['c'], augmented=False)
# C = companion['C']
# rho_sur1 = spectral_radius_power(C)     # ≈ largest singular value
# stab_loss = torch.nn.functional.softplus(rho_sur - 0.98)
G1 = gc_strengths(out1['A'])
S1 = soft_mask(G1, tau=choose_tau(G1,quantile=0.4), beta=1.5)
F1 = spectral_density_fft(out1['A'], out1['Sigma'], M=256)  # (256, 3, 3)
Theta1 = inverse_spectrum(F1)
Theta1 = partial_coherence_from_S(Theta1)


# # for p in range(1,50):
out2 = ridge_var(X2,p=p,alpha=0.01,standardize=True,add_intercept=True,eps = 1e-12)
# companion = build_companion(A = out['A'], c=out['c'], augmented=False)
# C = companion['C']
# rho_sur2 = spectral_radius_power(C)     # ≈ largest singular value
# stab_loss = torch.nn.functional.softplus(rho_sur - 0.98)
G2 = gc_strengths(out2['A'])
S2 = soft_mask(G2, tau=choose_tau(G2,quantile=0.4), beta=1.5)
F2 = spectral_density_fft(out2['A'], out2['Sigma'], M=256)  # (256, 3, 3)
Theta2 = inverse_spectrum(F2)
Theta2 = partial_coherence_from_S(Theta2)


# print(p, gc_mask_l2_loss(S1,S2).item(),rho_sur1.item(), rho_sur2.item())
off = ~torch.eye(Theta1.shape[-1], dtype=torch.bool, device=Theta1.device)


L_CIG = (Theta1 - Theta2).abs()[..., off].mean()
L_Struct = gc_mask_l2_loss(S1,S2)



L = L_Struct + L_CIG


L.backward()


In [283]:
def choose_tau(g_real: torch.Tensor, quantile: float = 0.3) -> float:
    off = g_real[~torch.eye(g_real.shape[0], dtype=g_real.type, device=g_real.device)]
    return torch.quantile(off, torch.tensor(quantile, device=g_real.device)).item()

g_real = G1
quantile = 0.3
off = g_real[~torch.eye(g_real.shape[0], dtype=torch.bool, device=g_real.device)]
torch.quantile(off, torch.tensor(quantile,dtype=g_real.dtype, device=g_real.device)).item()

0.25273539157991975

In [None]:
# model.py
# ---------------------------------------------------------
# Training with your ETT-style data factory + loaders.
# Uses:
#   - data_provider(args, flag) from your data factory
#   - Dataset_ETT_* shapes: (seq_x, seq_y, seq_x_mark, seq_y_mark)
#   - MLP forecaster (feed-forward) for simplicity
#   - Differentiable ridge-VAR extractor E on FUTURE window
#   - GC-skeleton loss (time-domain) + optional CIG loss (frequency-domain)
# ---------------------------------------------------------


# =========================
# ---- import your data factory
# =========================


import argparse
from typing import Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.fft as fft
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from typing import List
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset


class TimeFeature:
    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"


class SecondOfMinute(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.isocalendar().week - 1) / 52.0 - 0.5

def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.
    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])

class Dataset_ETT_hour(Dataset):
    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size [seq_len, label_len, pred_len]
        # info
        if size == None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
        border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
            data_stamp = df_stamp.drop(['date'], axis=1).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)




data_dict = {
    'ETTh1': Dataset_ETT_hour,
    'ETTh2': Dataset_ETT_hour,

}

class Dataset_Pred(Dataset):
    def __init__(self, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None):
        # size [seq_len, label_len, pred_len]
        # info
        if size == None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['pred']

        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.freq = freq
        self.cols = cols
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))
        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        if self.cols:
            cols = self.cols.copy()
            cols.remove(self.target)
        else:
            cols = list(df_raw.columns)
            cols.remove(self.target)
            cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        border1 = len(df_raw) - self.seq_len
        border2 = len(df_raw)

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            self.scaler.fit(df_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        tmp_stamp = df_raw[['date']][border1:border2]
        tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date)
        pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq)

        df_stamp = pd.DataFrame(columns=['date'])
        df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:])
        if self.timeenc == 0:
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
            df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)
            df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
            data_stamp = df_stamp.drop(['date'], axis=1).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        if self.inverse:
            self.data_y = df_data.values[border1:border2]
        else:
            self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        if self.inverse:
            seq_y = self.data_x[r_begin:r_begin + self.label_len]
        else:
            seq_y = self.data_y[r_begin:r_begin + self.label_len]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)

def data_provider(args, flag):
    Data = data_dict[args.data]
    timeenc = 0 if args.embed != 'timeF' else 1

    if flag == 'test':
        shuffle_flag = False
        drop_last = True
        batch_size = args.batch_size
        freq = args.freq
    elif flag == 'pred':
        shuffle_flag = False
        drop_last = False
        batch_size = 1
        freq = args.freq
        Data = Dataset_Pred
    else:
        shuffle_flag = True
        drop_last = True
        batch_size = args.batch_size
        freq = args.freq

    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last)
    return data_set, data_loader


# =========================
# ---- Model: simple feed-forward MLP
# =========================
class MLPForecaster(nn.Module):
    """
    Maps flattened context (B, C*d) -> flattened future (B, H*d).
    """
    def __init__(self, d: int, context_len: int, horizon: int,
                 hidden: int = 512, depth: int = 2, dropout: float = 0.0):
        super().__init__()
        self.d = d
        self.C = context_len
        self.H = horizon

        in_dim = self.C * d
        out_dim = self.H * d
        dims = [in_dim] + [hidden] * max(0, depth - 1) + [out_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU(inplace=True))
                if dropout > 0.0:
                    layers.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*layers)

    def forward(self, ctx: torch.Tensor) -> torch.Tensor:
        # ctx: (B, C, d) -> pred: (B, H, d)
        B, C, d = ctx.shape
        y = self.net(ctx.reshape(B, C * d))
        return y.view(B, -1, d)


# =========================
# ---- Differentiable ridge-VAR extractor E
# =========================
def ridge_var(
    Y: torch.Tensor,
    p: int = 2,
    alpha: float = 1e-2,
    standardize: bool = True,
    add_intercept: bool = True,
    eps: float = 1e-8,
) -> Dict[str, torch.Tensor]:
    """
    Ridge-regularized VAR(p) fit on a future window Y (H, d). Differentiable.
    Returns A (p,d,d), Sigma (d,d), residuals, etc.
    """
    assert Y.dim() == 2, "Y must be (H, d)"
    H, d = Y.shape
    assert H > p, "Need H > p to fit a VAR(p)"
    device, dtype = Y.device, Y.dtype

    if standardize:
        mu = Y.mean(dim=0, keepdim=True)
        std = Y.std(dim=0, unbiased=False, keepdim=True).clamp_min(eps)
        Yz = (Y - mu) / std
    else:
        mu = std = None
        Yz = Y

    Y_resp = Yz[p:, :]  # (H-p, d)
    X_lags = [Yz[p - k: H - k, :] for k in range(1, p + 1)]
    X_reg = torch.cat(X_lags, dim=1)  # (H-p, p*d)

    if add_intercept:
        ones = torch.ones((H - p, 1), device=device, dtype=dtype)
        X_reg = torch.cat([X_reg, ones], dim=1)  # (H-p, p*d + 1)

    G = X_reg.T @ X_reg + alpha * torch.eye(X_reg.shape[1], device=device, dtype=dtype)
    B = X_reg.T @ Y_resp

    L = torch.linalg.cholesky(G)
    A_full = torch.cholesky_solve(B, L)  # solves G A = B

    if add_intercept:
        A_flat = A_full[:-1, :]
        c = A_full[-1, :]
    else:
        A_flat = A_full
        c = torch.zeros(d, device=device, dtype=dtype)

    A = A_flat.reshape(p, d, d)  # (p, d, d)

    Y_hat = X_reg @ A_full
    R = Y_resp - Y_hat
    Sigma = (R.T @ R) / (H - p)

    return {"A": A, "Sigma": Sigma, "R": R, "X_reg": X_reg, "Y_resp": Y_resp, "c": c,
            "mu": None if not standardize else mu.squeeze(0),
            "std": None if not standardize else std.squeeze(0)}


# =========================
# ---- GC-skeleton helper functions
# =========================
def gc_strengths(A: torch.Tensor) -> torch.Tensor:
    g = torch.sqrt((A ** 2).sum(dim=0))  # (d,d)
    g.fill_diagonal_(0.0)
    return g

def soft_mask(g: torch.Tensor, tau: float, beta: float) -> torch.Tensor:
    return torch.sigmoid(beta * (g - tau))

def choose_tau(g_real: torch.Tensor, quantile: float = 0.3) -> float:
    d = g_real.shape[0]
    off = g_real[~torch.eye(d, dtype=torch.bool, device=g_real.device)]
    q = torch.tensor(quantile, device=g_real.device, dtype=g_real.dtype)
    return torch.quantile(off, q).item()

def gc_mask_l2_loss(S_pred: torch.Tensor, S_real: torch.Tensor) -> torch.Tensor:
    return ((S_pred - S_real) ** 2).sum()


# =========================
# ---- Spectral/CIG helpers (batched)
# =========================
def batched_spectrum_inverse_partial_coherence(
    A: torch.Tensor,
    Sigma: torch.Tensor,
    M: int = 256,
    eps_s: float = 1e-8,
    eps_inv: float = 1e-5,
):
    """
    A: (B,p,d,d), Sigma: (B,d,d)
    Returns: S_w (B,M,d,d), Theta_w (B,M,d,d), gamma2 (B,M,d,d) with γ^2 in [0,1].
    """
    if A.dim() == 3:   A = A.unsqueeze(0)
    if Sigma.dim() == 2: Sigma = Sigma.unsqueeze(0)

    B, p, d, _ = A.shape
    device = A.device
    ctype = torch.complex128 if A.dtype == torch.float64 else torch.complex64

    A_seq = torch.zeros((B, M, d, d), dtype=ctype, device=device)
    eye_c = torch.eye(d, dtype=ctype, device=device).expand(B, d, d)
    A_seq[:, 0] = eye_c
    A_c = A.to(dtype=ctype)
    for k in range(p):
        A_seq[:, k + 1] = -A_c[:, k]

    H_w = fft.fft(A_seq, dim=1)                                # (B,M,d,d)
    I_w = torch.eye(d, dtype=ctype, device=device).expand(B, M, d, d)
    H_inv = torch.linalg.solve(H_w, I_w)                       # (B,M,d,d)

    Sigma_c = Sigma.to(dtype=ctype)
    S_w = H_inv @ Sigma_c.unsqueeze(1) @ H_inv.conj().transpose(-1, -2)
    if eps_s > 0:
        S_w = S_w + eps_s * I_w
    Theta_w = torch.linalg.inv(S_w + eps_inv * I_w)

    num = (Theta_w.abs() ** 2)
    diag = Theta_w.diagonal(dim1=-2, dim2=-1).real.clamp_min(1e-12)
    denom = diag.unsqueeze(-1) * diag.unsqueeze(-2)
    gamma2 = (num / denom).real.clamp_(0.0, 1.0)
    return S_w, Theta_w, gamma2


# =========================
# ---- Combined GC + CIG loss on FUTURE windows
# =========================
def gc_and_cig_loss_from_windows(
    Y_real: torch.Tensor,        # (B, H, d)
    Y_pred: torch.Tensor,        # (B, H, d)
    p: int = 2,
    alpha: float = 1e-2,
    beta: float = 4.0,
    tau_quantile: float = 0.4,
    use_cig: bool = True,
    cig_lambda: float = 0.1,
    M: int = 256,
    eps_s: float = 1e-8,
    eps_inv: float = 1e-5,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    End-to-end loss combining:
      - GC-skeleton soft-mask L2 (time-domain) averaged over batch
      - optional CIG partial-coherence L1 (frequency-domain), batched
    """
    B, H, d = Y_real.shape
    device = Y_real.device

    gc_total = 0.0
    diag = {}
    A_real_list, A_pred_list, Sig_real_list, Sig_pred_list = [], [], [], []

    for b in range(B):
        Er = ridge_var(Y_real[b], p=p, alpha=alpha, standardize=True, add_intercept=True)
        Ep = ridge_var(Y_pred[b], p=p, alpha=alpha, standardize=True, add_intercept=True)

        A_r, A_p = Er["A"], Ep["A"]
        g_r = gc_strengths(A_r)
        g_p = gc_strengths(A_p)

        tau = choose_tau(g_r, quantile=tau_quantile)
        S_r = soft_mask(g_r, tau=tau, beta=beta)
        S_p = soft_mask(g_p, tau=tau, beta=beta)

        gc_total = gc_total + gc_mask_l2_loss(S_p, S_r)

        A_real_list.append(A_r)
        A_pred_list.append(A_p)
        Sig_real_list.append(Er["Sigma"])
        Sig_pred_list.append(Ep["Sigma"])

        if b == 0:
            diag.update({
                "g_real": g_r.detach(),
                "g_pred": g_p.detach(),
                "S_real": S_r.detach(),
                "S_pred": S_p.detach(),
                "tau": torch.tensor(tau, device=device),
            })

    gc_loss = gc_total / B

    if use_cig:
        A1 = torch.stack(A_real_list, dim=0)   # (B,p,d,d)
        A2 = torch.stack(A_pred_list, dim=0)   # (B,p,d,d)
        S1 = torch.stack(Sig_real_list, dim=0) # (B,d,d)
        S2 = torch.stack(Sig_pred_list, dim=0) # (B,d,d)

        _, _, gamma1 = batched_spectrum_inverse_partial_coherence(A1, S1, M, eps_s, eps_inv)
        _, _, gamma2 = batched_spectrum_inverse_partial_coherence(A2, S2, M, eps_s, eps_inv)

        off = ~torch.eye(d, dtype=torch.bool, device=device)
        cig_loss = (gamma1 - gamma2).abs()[..., off].mean()
        diag["cig_loss"] = cig_loss.detach()
    else:
        cig_loss = torch.tensor(0.0, device=device)

    total_loss = gc_loss + cig_lambda * cig_loss
    diag.update({"gc_loss": gc_loss.detach(), "total_loss": total_loss.detach()})
    return total_loss, diag


# =========================
# ---- Training loop integrated with your data_provider
# =========================
def train_with_factory(
    args,
    hidden: int = 512,
    depth: int = 2,
    dropout: float = 0.0,
    ridge_p: int = 2,
    ridge_alpha: float = 1e-2,
    gc_lambda: float = 0.1,
    gc_beta: float = 4.0,
    gc_tau_q: float = 0.3,
    use_cig: bool = True,
    cig_lambda: float = 0.05,
    cig_M: int = 256,
    cig_eps_s: float = 1e-8,
    cig_eps_inv: float = 1e-5,
    device: Optional[str] = None,
):
    """
    Uses your data_provider(args, flag) to get loaders for train/val/test.
    Expects args to contain fields used by data_provider (root_path, data_path, etc.)
    """
    dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    # get loaders
    train_data, train_loader = data_provider(args, flag='train')
    vali_data,  vali_loader  = data_provider(args, flag='val')
    test_data,  test_loader  = data_provider(args, flag='test')

    # infer dims from one batch
    sample = next(iter(train_loader))
    seq_x, seq_y, seq_x_mark, seq_y_mark = sample
    B, C, d = seq_x.shape
    H = args.pred_len
    assert seq_y.shape[1] == args.label_len + args.pred_len, "seq_y layout mismatch."

    model = MLPForecaster(d=d, context_len=C, horizon=H,
                          hidden=hidden, depth=depth, dropout=dropout).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scaler = None  # (you can wrap amp if you want)

    def run_epoch(loader, training: bool):
        if training:
            model.train()
        else:
            model.eval()
        mse_running, gc_running, cig_running = 0.0, 0.0, 0.0
        n_batches = 0

        for seq_x, seq_y, seq_x_mark, seq_y_mark in loader:
            n_batches += 1
            ctx = seq_x.to(dev).float()                   # (B,C,d)
            fut_full = seq_y.to(dev).float()              # (B,label_len+H,d)
            fut = fut_full[:, -H:, :]                     # (B,H,d) -> ground truth future

            pred = model(ctx)                             # (B,H,d)

            mse = ((pred - fut) ** 2).mean()
            struct_loss, diag = gc_and_cig_loss_from_windows(
                Y_real=fut, Y_pred=pred,
                p=ridge_p, alpha=ridge_alpha,
                beta=gc_beta, tau_quantile=gc_tau_q,
                use_cig=use_cig, cig_lambda=cig_lambda,
                M=cig_M, eps_s=cig_eps_s, eps_inv=cig_eps_inv
            )
            # struct_loss already = gc + λ*cig; keep separate logs by reading diag
            total = mse + struct_loss

            if training:
                opt.zero_grad()
                total.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

            mse_running += mse.item()
            gc_running  += diag["gc_loss"].item()
            cig_running += diag.get("cig_loss", torch.tensor(0.0)).item()

        return mse_running / n_batches, gc_running / n_batches, cig_running / n_batches

    for epoch in range(1, args.train_epochs + 1):
        tr_mse, tr_gc, tr_cig = run_epoch(train_loader, training=True)
        va_mse, va_gc, va_cig = run_epoch(vali_loader,  training=False)
        print(f"[epoch {epoch:03d}] "
              f"train MSE {tr_mse:.6f} | GC {tr_gc:.6f} | CIG {tr_cig:.6f}  ||  "
              f"val MSE {va_mse:.6f} | GC {va_gc:.6f} | CIG {va_cig:.6f}")

    # optional: evaluate on test set
    te_mse, te_gc, te_cig = run_epoch(test_loader, training=False)
    print(f"[test] MSE {te_mse:.6f} | GC {te_gc:.6f} | CIG {te_cig:.6f}")

    return model


# =========================
# ---- CLI wiring to your args
# =========================
def build_arg_parser():
    p = argparse.ArgumentParser()
    # ---- dataset args (must match your data_provider expectations)
    p.add_argument('--root_path', type=str, default='./dataset/ETT-small/')
    p.add_argument('--data_path', type=str, default='ETTh1.csv')
    p.add_argument('--data', type=str, default='ETTh1', choices=['ETTh1','ETTh2','ETTm1','ETTm2','custom'])
    p.add_argument('--features', type=str, default='M', choices=['M','S','MS'])
    p.add_argument('--target', type=str, default='OT')
    p.add_argument('--freq', type=str, default='h')
    p.add_argument('--embed', type=str, default='timeF')  # controls timeenc in data_provider
    p.add_argument('--seq_len', type=int, default=96)     # context length (C)
    p.add_argument('--label_len', type=int, default=48)   # decoder warmup, unused by MLP but part of loader
    p.add_argument('--pred_len', type=int, default=96)    # horizon (H)
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--num_workers', type=int, default=0)
    # ---- training args
    p.add_argument('--learning_rate', type=float, default=1e-3)
    p.add_argument('--train_epochs', type=int, default=20)
    # ---- model/extractor/loss args
    p.add_argument('--hidden', type=int, default=512)
    p.add_argument('--depth', type=int, default=2)
    p.add_argument('--dropout', type=float, default=0.0)
    p.add_argument('--ridge_p', type=int, default=2)
    p.add_argument('--ridge_alpha', type=float, default=1e-2)
    p.add_argument('--gc_lambda', type=float, default=0.1)   # (informational; struct has own λ for CIG)
    p.add_argument('--gc_beta', type=float, default=4.0)
    p.add_argument('--gc_tau_q', type=float, default=0.3)
    p.add_argument('--use_cig', action='store_true')
    p.add_argument('--cig_lambda', type=float, default=0.05)
    p.add_argument('--cig_M', type=int, default=256)
    p.add_argument('--cig_eps_s', type=float, default=1e-8)
    p.add_argument('--cig_eps_inv', type=float, default=1e-5)
    p.add_argument('--device', type=str, default=None)
    return p


def main():
    parser = build_arg_parser()
    args = parser.parse_args()
    _ = train_with_factory(
        args=args,
        hidden=args.hidden,
        depth=args.depth,
        dropout=args.dropout,
        ridge_p=args.ridge_p,
        ridge_alpha=args.ridge_alpha,
        gc_lambda=args.gc_lambda,
        gc_beta=args.gc_beta,
        gc_tau_q=args.gc_tau_q,
        use_cig=args.use_cig,
        cig_lambda=args.cig_lambda,
        cig_M=args.cig_M,
        cig_eps_s=args.cig_eps_s,
        cig_eps_inv=args.cig_eps_inv,
        device=args.device,
    )


if __name__ == "__main__":
    main()
