# In [None]:
# re-run after state reset
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import svd

def ssa_decompose(series, window, center=False, normalize=False):
    """Decompose a 1-D series with Singular Spectrum Analysis (SSA).

    Builds the (window, K) trajectory matrix ``X[i, j] = series[i + j]`` with
    ``K = len(series) - window + 1``, optionally centres and/or scales each
    row, and takes the thin SVD of the result.

    Parameters
    ----------
    series : array_like
        One-dimensional input series of length N.
    window : int
        Number of rows of the trajectory matrix; must satisfy
        ``1 <= window <= N``.
    center : bool, optional
        Subtract each row's mean before the SVD.
    normalize : bool, optional
        Divide each row by its population (ddof=0) standard deviation before
        the SVD. Rows with zero deviation are divided by 1 instead.

    Returns
    -------
    dict
        Keys ``U``, ``s``, ``Vt`` hold the SVD factors. ``means`` and ``stds``
        hold the per-row offsets and scales, shape ``(window, 1)``; they are
        zeros / ones when the corresponding step was skipped. ``window`` holds
        the window length.

    Raises
    ------
    ValueError
        If ``series`` is not one-dimensional or ``window`` is outside
        ``[1, len(series)]``.
    """
    series = np.asarray(series)
    if series.ndim != 1:
        raise ValueError("series must be one-dimensional")
    N = series.shape[0]
    # Outside this range the trajectory matrix is empty or ragged.
    if not 1 <= window <= N:
        raise ValueError(f"window must satisfy 1 <= window <= {N}, got {window}")
    K = N - window + 1
    X = np.vstack([series[i:i + K] for i in range(window)])
    means = np.zeros((window, 1))
    stds = np.ones((window, 1))
    if center:
        means = X.mean(axis=1, keepdims=True)
        X = X - means
    if normalize:
        stds = X.std(axis=1, keepdims=True, ddof=0)
        stds[stds == 0] = 1.0  # constant rows: avoid division by zero
        X = X / stds
    U, s, Vt = svd(X, full_matrices=False)
    return dict(U=U, s=s, Vt=Vt, means=means, stds=stds, window=window)

def reconstruct_trend(obj, idx):
    """Reconstruct a series from a subset of SSA components.

    Rebuilds the rank-reduced trajectory matrix from the selected singular
    triples, undoes the per-row scaling/centering stored in ``obj``, and maps
    the matrix back to a series by averaging along its anti-diagonals
    (``i + j == n``).

    Parameters
    ----------
    obj : dict
        Output of ``ssa_decompose``. It must provide ``U``, ``s``, ``Vt``,
        ``means`` and ``stds``.
    idx : int, sequence of int, or any numpy index
        Component index or indices to keep. A single integer is treated as
        ``[idx]``.

    Returns
    -------
    numpy.ndarray
        Reconstructed series of length ``M + K - 1``, where ``(M, K)`` is the
        trajectory-matrix shape.
    """
    U, s, Vt = obj['U'], obj['s'], obj['Vt']
    means, stds = obj['means'], obj['stds']
    # A bare integer would collapse U[:, idx] @ Vt[idx, :] to a scalar.
    if np.isscalar(idx):
        idx = [idx]
    Xn_hat = (U[:, idx] * s[idx]) @ Vt[idx, :]
    X_hat = Xn_hat * stds + means
    M, K = X_hat.shape
    N = M + K - 1
    recon = np.zeros(N)
    counts = np.zeros(N)
    # Diagonal averaging: row i contributes to positions i .. i + K - 1.
    # Rows are added in increasing i, the same order as the element-wise
    # double loop, so the result is unchanged.
    for i, row in enumerate(X_hat):
        recon[i:i + K] += row
        counts[i:i + K] += 1
    return recon / counts

# Demo: extract the exponential trend from trend * (1 + seasonal) * (1 + noise).
# USE_LOG=True fits SSA to log(series), which turns the multiplicative model
# into an additive one, and maps the reconstructed trends back with exp.
# USE_LOG=False fits SSA directly to the raw series.
USE_LOG = False
PLOT_CORRELATION = False  # also draw the correlation-SSA trend

N = 400
M = 60  # SSA window length
COV_IDX_FULL = list(range(10))   # components kept for the first covariance trend
COV_IDX_SHORT = list(range(5))   # components kept for the second covariance trend
COR_IDX = [0, 1]                 # components kept for the correlation trend

t = np.arange(N)
rng = np.random.default_rng(0)  # fixed seed so the printed RMS values are reproducible
trend = np.exp(0.1 * t)
osc = 0.5 * np.sin(2 * np.pi * t / 20)
noise = 0.1 * rng.standard_normal(N)
series = trend * (1 + osc) * (1 + noise)

if USE_LOG:
    log_series = np.log(series)
    back_transform = np.exp
else:
    log_series = series
    back_transform = np.asarray  # identity: no transform was applied

ssa_cov = ssa_decompose(log_series, M, center=False, normalize=False)
trend_cov = back_transform(reconstruct_trend(ssa_cov, COV_IDX_FULL))
trend_cov2 = back_transform(reconstruct_trend(ssa_cov, COV_IDX_SHORT))

ssa_cor = ssa_decompose(log_series, M, center=True, normalize=True)
trend_cor = back_transform(reconstruct_trend(ssa_cor, COR_IDX))

domain = "log domain, exp back-transform" if USE_LOG else "raw domain, no transform"
plt.figure(figsize=(12, 6))
plt.plot(t, series, color='black', label='Original series')
plt.plot(t, trend_cov, color='tab:blue',
         label=f'SSA trend (covariance, {len(COV_IDX_FULL)} components)')
plt.plot(t, trend_cov2, color='tab:red',
         label=f'SSA trend (covariance, {len(COV_IDX_SHORT)} components)')
if PLOT_CORRELATION:
    plt.plot(t, trend_cor, color='tab:green',
             label=f'SSA trend (correlation, {len(COR_IDX)} components)')
plt.title(f"Exponential trend extraction via SSA ({domain})")
plt.xlabel("t")
plt.ylabel("value")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

res_cov = series - trend_cov
res_cov2 = series - trend_cov2
res_cor = series - trend_cor
print("Residual RMS covariance:", np.sqrt(np.mean(res_cov**2)))
print(f"Residual RMS covariance ({len(COV_IDX_SHORT)} components):",
      np.sqrt(np.mean(res_cov2**2)))
print("Residual RMS correlation:", np.sqrt(np.mean(res_cor**2)))
