In [74]:
import numpy as np
import torch 
import matplotlib.pyplot as plt

import sigkernel as ksig_pde
import sigkernel_ as ksig_disc
import utils.data
from generators.synthetic_generators import *
from generators.ESN import ESNGenerator
from sigkernel_.loss import compute_mmd_loss
from sigkernel_.kernels import gram

%load_ext autoreload
%autoreload 2

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype=torch.float64

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Kernels

In [75]:
# Define the discretized signature kernel
static_kernel_type_disc    = 'rq' # type of static kernel to use - rbf, rbfmix, rq, rqmix, rqlinear for
n_levels              = 10 # number of levels in the truncated signature kernel
disc_sig_kernel_sigma = 1e-5 # bandwidth parameter for the static kernel

kwargs_disc_sig = {
    'static_kernel_type': static_kernel_type_disc,
    'n_levels': n_levels,
    'kernel_sigma': disc_sig_kernel_sigma,
}

sig_disc_kernel = ksig_disc.kernels.get_discretized_signature_kernel(**kwargs_disc_sig)

#---------------------------------
# Define the pde signature kernel
pde_sig_sigma = 1e1 # bandwidth parameter for the static kernel
static_kernel_pde = ksig_pde.RBFKernel(sigma=pde_sig_sigma) # define static kernel

# Initialize the corresponding signature kernel
sig_pde_kernel = ksig_pde.SigKernel(static_kernel_pde, dyadic_order=0)

#---------------------------------
# Define static kernel
static_sigma = 1e1 # bandwidth parameter for the static kernel
rbf_kernel = ksig_disc.kernels.RBFKernel(sigma=static_sigma)

# Generator and ESN

In [76]:
# Generator and ESN details ----------------------
T = 200 # length of time series
N = 20  # number of samples
d = 1   # dimension of time series
p = 2   # AR order
q = 0   # MA order
phi = [0.7, -0.2] # AR coefficients
theta = None # MA coefficients

h, m, d = 500, 20, 1 # ESN hyperparameters: reservoir size, input dimension, output dimension
A = 0.9 * torch.randn(h, h) / (h ** 0.5) # ESN reservoir weight matrix
C = torch.randn(h, m) / (m ** 0.5) # ESN input weight matrix

# Define data generator and esn
target_generator = ARMA(T=T, p=p, q=q, phi=phi, theta=theta)
esn = ESNGenerator(A, C, out_dim=d, xi_scale=1.0, eta_scale=0.05)

# Generate examples

In [77]:
with torch.no_grad():
    X = target_generator.generate(N=N, noise=Noise("normal")).to(device=device, dtype=dtype)
    Z = esn(T=T, N=N).to(device=device, dtype=dtype)
print(f"Generated data shapes: X: {X.shape}, Z: {Z.shape}")

Generated data shapes: X: torch.Size([20, 200, 1]), Z: torch.Size([20, 200, 1])


# RBF

In [78]:
rbf_gram = gram(rbf_kernel, X, Z)
print("RBF Gram matrix stats:")
print(f"Shape: {rbf_gram.shape}")
print(f"Min: {rbf_gram.min().item():.4f}")
print(f"Max: {rbf_gram.max().item():.4f}")
print(f"Mean: {rbf_gram.mean().item():.4f}")
print(f"Std: {rbf_gram.std().item():.4f}")

print("First 3 rows and columns of RBF Gram matrix:")
print(rbf_gram[:3, :3])

RBF Gram matrix stats:
Shape: torch.Size([20, 200, 200])
Min: 0.3954
Max: 1.0000
Mean: 0.9607
Std: 0.0527
First 3 rows and columns of RBF Gram matrix:
tensor([[[0.9982, 0.9992, 0.9923,  ..., 0.9440, 0.8086, 0.9986],
         [0.9555, 0.9611, 0.9932,  ..., 0.9952, 0.9192, 0.9574],
         [0.9797, 0.9834, 0.9998,  ..., 0.9808, 0.8783, 0.9809]],

        [[0.9199, 0.9875, 0.9975,  ..., 0.8907, 0.9959, 0.9995],
         [0.9471, 0.9968, 0.9888,  ..., 0.8549, 0.9999, 0.9989],
         [0.9886, 0.9952, 0.9477,  ..., 0.7617, 0.9863, 0.9751]],

        [[0.9592, 0.9022, 0.8960,  ..., 0.8252, 0.8837, 0.9811],
         [0.9983, 0.9944, 0.9927,  ..., 0.9636, 0.9888, 0.9885],
         [0.9825, 0.9395, 0.9345,  ..., 0.8739, 0.9243, 0.9955]]],
       dtype=torch.float64)


In [79]:
mmd_rbf = compute_mmd_loss(rbf_kernel, X.reshape(N, -1), Z.reshape(N, -1))
print(f"MMD (RBF kernel): {mmd_rbf.item():.6f}")

MMD (RBF kernel): 0.001602


# SIG-PDE

In [80]:
sig_pde_gram = gram(sig_pde_kernel, X, Z)
print("sig_pde Gram matrix stats:")
print(f"Shape: {sig_pde_gram.shape}")
print(f"Min: {sig_pde_gram.min().item():.4f}")
print(f"Max: {sig_pde_gram.max().item():.4f}")
print(f"Mean: {sig_pde_gram.mean().item():.4f}")
print(f"Std: {sig_pde_gram.std().item():.4f}")

print("First 3 rows and columns of sig_pde Gram matrix:")
print(sig_pde_gram[:3, :3])

sig_pde Gram matrix stats:
Shape: torch.Size([20, 20])
Min: -2.8340
Max: 35.4731
Mean: 7.1872
Std: 4.2923
First 3 rows and columns of sig_pde Gram matrix:
tensor([[ 3.8300,  7.2270,  3.5549],
        [ 8.0353,  7.9639,  6.7139],
        [ 7.8823, 13.6096,  7.5740]], dtype=torch.float64)


In [81]:
mmd_sig_pde = compute_mmd_loss(sig_pde_kernel, X, Z)
print(f"MMD (sig_pde kernel): {mmd_sig_pde.item():.6f}")

MMD (sig_pde kernel): 702.456662


# SIG-DISC

In [82]:
sig_disc_gram = gram(sig_disc_kernel, X, Z)
print("sig_disc Gram matrix stats:")
print(f"Shape: {sig_disc_gram.shape}")
print(f"Min: {sig_disc_gram.min().item():.4f}")
print(f"Max: {sig_disc_gram.max().item():.4f}")
print(f"Mean: {sig_disc_gram.mean().item():.4f}")
print(f"Std: {sig_disc_gram.std().item():.4f}")

print("First 3 rows and columns of sig_disc Gram matrix:")
print(sig_disc_gram[:3, :3])

sig_disc Gram matrix stats:
Shape: torch.Size([20, 20])
Min: 1.0000
Max: 3.2625
Mean: 1.1787
Std: 0.3428
First 3 rows and columns of sig_disc Gram matrix:
tensor([[1.0002, 1.0058, 1.1083],
        [1.0089, 1.0003, 1.0010],
        [1.0165, 1.8155, 1.0381]], dtype=torch.float64)


In [83]:
mmd_sig_disc = compute_mmd_loss(sig_disc_kernel, X, Z)
print(f"MMD (sig_disc kernel): {mmd_sig_disc.item():.6f}")

MMD (sig_disc kernel): 0.024633
