In [1]:
from typing import List, Dict, Set, Any, Optional, Tuple, Literal, Callable
import numpy as np
import torch
from torch import Tensor
import sigkernel

from kernels.static_kernels import StaticKernel, AbstractKernel
from kernels.static_kernels import LinearKernel, RBFKernel, PolyKernel


class CrisStaticWrapper:
    """
    Adapter that lets a project ``StaticKernel`` serve as the static kernel
    of Cris Salvi's ``sigkernel`` library, which calls ``batch_kernel`` and
    ``Gram_matrix`` on the object it is given.
    """
    def __init__(
            self, 
            kernel: StaticKernel,
        ):
        """
        Wrapper for static kernels for Cris Salvi's sigkernel library.

        Args:
            kernel (StaticKernel): Static kernel whose ``gram(A, B)`` maps
                inputs of shape (n1, d) and (n2, d) to an (n1, n2) matrix.
        """
        self.kernel = kernel


    def batch_kernel(
            self, 
            X:Tensor, 
            Y:Tensor
        ) -> Tensor:
        """
        Outputs k(X^i_s, Y^i_t): for every batch index i, the Gram matrix
        between the time steps of X^i and the time steps of Y^i.

        Args:
            X (Tensor): Tensor of shape (N, T1, d)
            Y (Tensor): Tensor of shape (N, T2, d)

        Returns:
            Tensor: Tensor of shape (N, T1, T2)

        Raises:
            ValueError: If X and Y do not have the same batch size N.
        """
        N, T1 = X.shape[0], X.shape[1]
        if Y.shape[0] != N:
            raise ValueError(
                f"batch_kernel expects equal batch sizes, got {N} and {Y.shape[0]}"
            )
        if N == 0:
            # torch.stack cannot stack an empty list of tensors.
            return X.new_zeros((0, T1, Y.shape[1]))
        # One 2-D gram call per sample relies only on the
        # (n1, d) x (n2, d) -> (n1, n2) contract that Gram_matrix also uses.
        # The previous version fed time-major 3-D tensors to gram and assumed
        # a (T1, T2, N) result; gram's behaviour on 3-D input is not
        # established here (SigPDEKernel tests raised IndexError, presumably
        # from that path — TODO confirm).
        return torch.stack([self.kernel.gram(x, y) for x, y in zip(X, Y)])


    def Gram_matrix(
            self, 
            X: Tensor, 
            Y: Tensor
        ) -> Tensor:
        """
        Outputs k(X^i_s, Y^j_t) for all pairs of samples (i, j) and all
        pairs of time steps (s, t).
        
        Args:
            X (Tensor): Tensor of shape (N1, T1, d)
            Y (Tensor): Tensor of shape (N2, T2, d)
        
        Returns:
            Tensor: Tensor of shape (N1, N2, T1, T2)
        """
        N1, T1, d = X.shape
        N2, T2, d = Y.shape
        # Flatten samples and time into one axis so a single 2-D gram call
        # covers every (sample, time) pair.
        X = X.reshape(-1, d)
        Y = Y.reshape(-1, d)
        flat_gram = self.kernel.gram(X, Y) # shape (N1 * T1, N2 * T2)
        gram = flat_gram.reshape(N1, T1, N2, T2)
        # Reorder to the (N1, N2, T1, T2) layout sigkernel expects.
        return gram.permute(0, 2, 1, 3)
    
    

class SigPDEKernel(AbstractKernel):
    """
    Signature kernel between time series, evaluated with the PDE solver of
    Cris Salvi's ``sigkernel`` library on top of a project static kernel.
    """
    def __init__(
            self,
            static_kernel: Optional[StaticKernel] = None,
            dyadic_order:int = 1,
            max_batch:int = 10,
        ):
        """
        Signature PDE kernel for timeseries (x_1, ..., x_T) in R^d,
        kernelized with a static kernel k : R^d x R^d -> R.

        Args:
            static_kernel (StaticKernel, optional): Static kernel on R^d.
                Defaults to a new ``RBFKernel()`` created for this instance.
            dyadic_order (int, optional): Dyadic order in PDE solver. Defaults to 1.
            max_batch (int, optional): Max batch size for computations. Defaults to 10.
        """
        if static_kernel is None:
            # Build the default per instance instead of once at class-definition
            # time, so kernels using the default never share one RBFKernel object.
            static_kernel = RBFKernel()
        self.static_wrapper = CrisStaticWrapper(static_kernel)
        self.dyadic_order = dyadic_order
        self.sig_ker = sigkernel.SigKernel(self.static_wrapper, dyadic_order)
        self.max_batch = max_batch


    def gram(
            self, 
            X: Tensor, 
            Y: Tensor, 
            diag: bool = False, 
        ) -> Tensor:
        """
        Computes the Gram matrix K(X_i, Y_j), or the diagonal K(X_i, Y_i) 
        if diag=True. The time series in X are of shape (T1, d), and the
        time series in Y are of shape (T2, d), where d is the path dimension.

        Args:
            X (Tensor): Tensor with shape (N1, T1, d).
            Y (Tensor): Tensor with shape (N2, T2, d).
            diag (bool, optional): If True, only computes the kernel for the 
                pairs K(X_i, Y_i), which requires N1 == N2. Defaults to False.

        Returns:
            Tensor: Tensor with shape (N1, N2), or (N1) if diag=True.
        """
        if diag:
            return self.sig_ker.compute_kernel(X, Y, self.max_batch)
        # sym flags to sigkernel that the Gram matrix is symmetric; this is
        # only asserted when both arguments are literally the same tensor.
        return self.sig_ker.compute_Gram(X, Y, sym=(X is Y), max_batch=self.max_batch)


    def __call__(
            self, 
            X: Tensor, 
            Y: Tensor, 
        )->Tensor:
        """
        Computes the kernel evaluation k(X, Y) of two time series 
        (with batch support). The time series in X are of shape (T1, d), 
        and the time series in Y are of shape (T2, d), where d is the 
        path dimension.

        Args:
            X (Tensor): Tensor with shape (... , T1, d).
            Y (Tensor): Tensor with shape (... , T2, d), with (...) same as X.
        
        Returns:
            Tensor: Tensor with shape (...). A single pair of 2-D series
                yields shape (1,).

        Raises:
            ValueError: If an input has fewer than 2 dims, or the leading
                batch shapes of X and Y differ.
        """
        if X.ndim < 2 or Y.ndim < 2:
            raise ValueError(
                f"expected inputs of shape (..., T, d), got {tuple(X.shape)} and {tuple(Y.shape)}"
            )
        if X.ndim == 2 and Y.ndim == 2:
            # Single pair: add a batch axis of size 1 (output shape (1,)).
            return self.sig_ker.compute_kernel(X.unsqueeze(0), Y.unsqueeze(0), self.max_batch)
        batch_shape = X.shape[:-2]
        if Y.shape[:-2] != batch_shape:
            raise ValueError(
                f"batch shapes differ: {tuple(batch_shape)} vs {tuple(Y.shape[:-2])}"
            )
        # sigkernel handles exactly one batch axis, so flatten any leading
        # dims, evaluate, and restore them afterwards.
        flat_X = X.reshape(-1, *X.shape[-2:])
        flat_Y = Y.reshape(-1, *Y.shape[-2:])
        flat_K = self.sig_ker.compute_kernel(flat_X, flat_Y, self.max_batch)
        return flat_K.reshape(batch_shape)


In [2]:
###########################################
### Playing around with sigkernel
###########################################


# Specify the static kernel (for linear kernel use sigkernel.LinearKernel())
static_kernel = sigkernel.RBFKernel(sigma=0.5)

# Specify dyadic order for PDE solver (int; the higher, the more accurate but slower)
dyadic_order = 3

# Specify maximum batch size of computation; if memory is a concern try reducing max_batch, default=100
max_batch = 11

# Initialize the corresponding signature kernel
signature_kernel = sigkernel.SigKernel(static_kernel, dyadic_order)

# Synthetic data
batch = 5
batch_z = 6
len_x = 7
len_y = 8
dim = 3
device = "cpu"
torch.manual_seed(0)
X = torch.rand((batch,len_x,dim), dtype=torch.float64, device=device) # shape (batch,len_x,dim)
Y = torch.rand((batch,len_y,dim), dtype=torch.float64, device=device) # shape (batch,len_y,dim)
Z = torch.rand((batch_z,len_x,dim), dtype=torch.float64, device=device) # shape (batch_z,len_x,dim)

# Compute signature kernel "batch-wise" (i.e. k(x_1,y_1),...,k(x_batch, y_batch));
# X and Y are paired sample-by-sample, so they share the batch size
K = signature_kernel.compute_kernel(X,Y,max_batch)

# Compute signature kernel Gram matrix (i.e. k(x_i,x_j) for i,j=1,...,batch); sym=True
# because both arguments are X. compute_Gram also works for batch_x != batch_y.
G = signature_kernel.compute_Gram(X,X,sym=True, max_batch=max_batch)

print("X", X.shape)
print("Y", Y.shape)
print("Z", Z.shape)
print("K", K.shape)
print("G", G.shape)

# Sanity checks: output shapes, and symmetry of the X-vs-X Gram matrix
assert K.shape == (batch,)
assert G.shape == (batch, batch)
assert torch.allclose(G, G.T)

X torch.Size([5, 7, 3])
Y torch.Size([5, 8, 3])
Z torch.Size([6, 7, 3])
K torch.Size([5])
G torch.Size([5, 5])


In [3]:
from kernels.static_kernels import StaticKernel, AbstractKernel
from kernels.static_kernels import LinearKernel, RBFKernel, PolyKernel
from kernels.integral_type import IntegralKernel

# Static-kernel experiments on small random point clouds in R^d.

N1, N2, d = 5, 6, 3
torch.manual_seed(0)
# Drawn in X, Y, Z order so the seeded values are reproducible.
X, Y, Z = (torch.rand((n, d)) for n in (N1, N2, N1))
for label, points in zip("XYZ", (X, Y, Z)):
    print(label, points.shape)

def test_kernel(kernel: AbstractKernel):
    print("Testing", kernel)

    gram = kernel.gram(X, Y)
    diag = kernel.gram(X, Z, diag=True)
    batch_call = kernel(X, Z)
    call = kernel(X[0], Z[0])
    print("gram", gram.shape)
    print("diag", diag.shape)
    print("batch_call", batch_call.shape)
    print("call", call.shape, call)
    print("\n")


# Instantiate each static kernel (rbf is reused by the path kernels below)
linear, rbf, poly = LinearKernel(), RBFKernel(), PolyKernel()
for k in (linear, rbf, poly):
    test_kernel(k)

X torch.Size([5, 3])
Y torch.Size([6, 3])
Z torch.Size([5, 3])
Testing <kernels.static_kernels.LinearKernel object at 0x7ffbe06b4210>
gram torch.Size([5, 6])
diag torch.Size([5])
batch_call torch.Size([5])
call torch.Size([1]) tensor([0.2525])


Testing <kernels.static_kernels.RBFKernel object at 0x7ffae0ae53d0>
gram torch.Size([5, 6])
diag torch.Size([5])
batch_call torch.Size([5])
call torch.Size([1]) tensor([0.6467])


Testing <kernels.static_kernels.PolyKernel object at 0x7ffae0ae4f50>
gram torch.Size([5, 6])
diag torch.Size([5])
batch_call torch.Size([5])
call torch.Size([1]) tensor([1.5687])




In [5]:
N1, N2, T, d = 5, 6, 7, 3
torch.manual_seed(0)
# float64 as in the sigkernel example above (TODO confirm whether
# sigkernel's solver strictly requires it). Drawn in X, Y, Z order.
X, Y, Z = [torch.rand((n, T, d), dtype=torch.float64) for n in (N1, N2, N1)]
print("X", X.shape)
print("Y", Y.shape)
print("Z", Z.shape)

# Path kernels built on the RBF static kernel defined earlier
integral, sigpde = IntegralKernel(rbf), SigPDEKernel(rbf)
for path_kernel in (integral, sigpde):
    test_kernel(path_kernel)

X torch.Size([5, 7, 3])
Y torch.Size([6, 7, 3])
Z torch.Size([5, 7, 3])
Testing <kernels.integral_type.IntegralKernel object at 0x7ffbe5077bd0>
gram torch.Size([5, 6])
diag torch.Size([5])
batch_call torch.Size([5])
call torch.Size([1]) tensor([0.7511], dtype=torch.float64)


Testing <__main__.SigPDEKernel object at 0x7ffae0b01f90>


IndexError: tuple index out of range