In [1]:
import numpy as np
import torch 
import matplotlib.pyplot as plt

import sigkernel as ksig_pde
import sigkernel_ as ksig_disc
import utils.data
from generators.generators import *
from generators.ESN import ESNGenerator
from models.ESN_MMD_mvp_trainer import train_ESN_MMD_mvp 
from sigkernel_.loss import compute_mmd_loss

%load_ext autoreload
%autoreload 2

# Use the first CUDA device when available, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64  # double precision, forwarded to the trainer

# Seed the global RNGs so "Restart & Run All" reproduces the same run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices


# Kernels

In [None]:
# --- Discretized (truncated) signature kernel -----------------------------------
# static_kernel_type is one of: 'rbf', 'rbfmix', 'rq', 'rqmix', 'rqlinear'.
static_kernel_type    = 'rq'  # type of static kernel to use
n_levels              = 10    # number of levels in the truncated signature kernel
disc_sig_kernel_sigma = 1e-5  # bandwidth parameter for the static kernel

kwargs_disc_sig = {
    'static_kernel_type': static_kernel_type,
    'n_levels': n_levels,
    'kernel_sigma': disc_sig_kernel_sigma,
}

# Define the discretized signature kernel
sig_disc_kernel = ksig_disc.kernels.get_discretized_signature_kernel(**kwargs_disc_sig)

# --- PDE signature kernel ----------------------------------------------------------
pde_sig_sigma = 1e1  # bandwidth parameter for the static (RBF) kernel
static_kernel = ksig_pde.RBFKernel(sigma=pde_sig_sigma)  # base kernel for the PDE signature kernel

# NOTE(review): per the sigkernel docs the PDE grid is refined by a factor of
# 2**dyadic_order, so 0 means no refinement (fastest, least accurate) — confirm
# against the installed version.
sig_pde_kernel = ksig_pde.SigKernel(static_kernel, dyadic_order=0)

# --- Plain static kernel -------------------------------------------------------------
# Selected in the parameters cell together with kernel_mode = "static".
static_sigma = 1e3  # bandwidth parameter for the static kernel
rbf_kernel = ksig_disc.kernels.RBFKernel(sigma=static_sigma)

# Generator, kernel, and training parameters

In [3]:
# Generator and ESN details ----------------------
T = 200  # length of time series
N = 50   # number of samples
d = 1    # dimension of time series (also used as the ESN output dimension)
p = 2    # AR order
q = 0    # MA order
phi = [0.7, -0.2]  # AR coefficients
theta = None       # MA coefficients (none, since q = 0)

# Fail fast on order/coefficient mismatches instead of deep inside the generator.
assert len(phi) == p, f"expected {p} AR coefficients, got {len(phi)}"
assert theta is None or len(theta) == q, f"expected {q} MA coefficients, got {len(theta)}"

# ESN hyperparameters: reservoir size, input dimension. The output dimension is d
# (not re-assigned here) so the ESN emits series matching the target's dimension.
h, m = 500, 20

# Dedicated seeded generator: re-running this cell yields the same A and C
# without depending on how far the global RNG has advanced.
ESN_SEED = 0
esn_rng = torch.Generator().manual_seed(ESN_SEED)
# Reservoir weights. randn(h, h)/sqrt(h) has spectral radius ~1 for large h
# (circular law), so the 0.9 factor puts it at roughly 0.9.
A = 0.9 * torch.randn(h, h, generator=esn_rng) / (h ** 0.5)
C = torch.randn(h, m, generator=esn_rng) / (m ** 0.5)  # ESN input weight matrix

# Define data generator and esn
x_ar = ARMA(T=T, p=p, q=q, phi=phi, theta=theta)
esn = ESNGenerator(A, C, out_dim=d, xi_scale=1.0, eta_scale=0.05)

target_generator = x_ar

# Kernel details --------------------------------
available_kernels = {
    "disc_sig": sig_disc_kernel,  # discretized signature kernel
    "pde_sig": sig_pde_kernel,    # pde signature kernel
    "static": rbf_kernel,         # static kernel
}
kernel_name = "pde_sig"  # choose one of the keys above
kernel = available_kernels[kernel_name]

# Only the plain static kernel needs the trainer's "static" mode; both
# signature kernels are run in "sequential" mode.
kernel_mode = "static" if kernel_name == "static" else "sequential"

lead_lag = False  # forwarded to the trainer (presumably toggles a lead-lag path transform)
lags = 1          # forwarded to the trainer

# Training details ------------------------------
epochs = 2000
batch_size = 20
lr = 1e-3

# Training

In [4]:
hist = train_ESN_MMD_mvp(
    esn=esn,
    kernel=kernel,
    # "static": the trainer flattens (B,T,d)->(B,T*d); otherwise paths are
    # passed as sequences — confirm against the trainer implementation.
    kernel_mode=kernel_mode,
    T=T,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    lead_lag=lead_lag,
    lags=lags,
    target_generator=target_generator,  # <- oracle target, no dataloader
    device=device,
    dtype=dtype,
)

losses = hist["losses"]
# Guard against an empty history (e.g. epochs=0) instead of raising IndexError.
print("final loss:", losses[-1] if len(losses) else "n/a (no losses recorded)")

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses)
ax.set(title="Training Loss History", xlabel="Epoch", ylabel="Loss")
ax.grid(True)
plt.show()

  0%|          | 0/2000 [00:03<?, ?it/s]


KeyboardInterrupt: 