In [14]:
import time
import numpy as np
from numba import njit
from typing import Callable
import sigkernel
from signature import streams_to_sigs
from experiment_code import case_sig_pde
from kernels import linear_kernel_gram


def calc_iisig_kernel(X, Y, order):
    sig_X, sig_Y = streams_to_sigs([X,Y], order, disable_tqdm=True)
    dot = 1 + np.dot(sig_X, sig_Y)
    return dot


def calc_sigpde_kernel(X,Y):
    dyadic_order = 3
    static_kernel = sigkernel.LinearKernel()
    vv, uv = case_sig_pde([X], [Y], dyadic_order, static_kernel)
    return uv[0,0]


# def calc_ksig_kernel(X,Y, order):
#     import ksig
#     static_kernel = ksig.static.kernels.LinearKernel() 
#     sig_kernel = ksig.kernels.SignatureKernel(n_levels=order, order=1, static_kernel=static_kernel, normalize=False)
#     dot = sig_kernel(np.array([X,X]), np.array([Y,Y]))[0,0]
#     return dot


def trunc_sig_kernel(s1:np.ndarray, 
                    s2:np.ndarray, 
                    order:int, #order is truncation level of the signature
                    static_kernel_gram:Callable = linear_kernel_gram,
                    only_last:bool = True,

                    ):
    """s1 and s2 are time series of shape (T_i, d)"""
    K = static_kernel_gram(s1, s2, divide_by_dims=False)
    nabla = K[1:, 1:] + K[:-1, :-1] - K[1:, :-1] - K[:-1, 1:]
    sig_kers = jitted_trunc_sig_kernel(nabla, order)
    if only_last:
        return sig_kers[-1]
    else:
        return sig_kers



@njit
def reverse_cumsum(arr:np.ndarray, axis:int): #ndim=2
    """JITed reverse cumulative sum along the specified axis.
    (np.cumsum with axis is not natively supported by Numba)"""
    A = arr.copy()
    if axis==0:
        for i in np.arange(A.shape[0]-2, -1, -1):
            A[i, :] += A[i+1, :]
    else: #axis==1
        for i in np.arange(A.shape[1]-2, -1, -1):
            A[:,i] += A[:,i+1]
    return A


@njit
def jitted_trunc_sig_kernel(nabla:np.ndarray, # gram matrix (T_1, T_2)
                            order:int,
                            ):
    """Given difference matrix nabla_ij = K[i+1, j+1] + K[i, j] - K[i+1, j] - K[i, j+1],
    computes the truncated signature kernel of all orders up to 'order'."""
    B = np.ones((order+1, order+1, order+1, *nabla.shape))
    for d in np.arange(order):
        for n in np.arange(order-d):
            for m in np.arange(order-d):
                B[d+1,n,m] = 1 + nabla/(n+1)/(m+1)*B[d, n+1, m+1]
                r1 = reverse_cumsum(nabla * B[d, n+1, 1] / (n+1), axis=0)
                B[d+1,n,m, :-1, :] += r1[1:, :]
                r2 = reverse_cumsum(nabla * B[d, 1, m+1] / (m+1), axis=1)
                B[d+1,n,m, :, :-1] += r2[:, 1:]
                rr = reverse_cumsum(nabla * B[d, 1, 1], axis=0)
                rr = reverse_cumsum(rr, axis=1)
                B[d+1,n,m, :-1, :-1] += rr[1:, 1:]

    return B[:,0,0,0,0]
    
    



d = 2
MAX_ORDER = 18
times_iisig = np.zeros( (MAX_ORDER) )
times_sigker  = np.zeros( (MAX_ORDER) )
times_sigpde = np.zeros( (MAX_ORDER) )
np.random.seed(99)
X, Y = np.random.randn(2, 45, d)/np.sqrt(d)
for order in range(1, MAX_ORDER+1):
    print("\norder", order)
    t0= time.time()
    dot1=calc_iisig_kernel(X, Y, order)
    t1 = time.time()
    dot2=trunc_sig_kernel(X, Y, order)
    t2 = time.time()
    dot3=calc_sigpde_kernel(X, Y)
    t3 = time.time()
    times_iisig[order-1] = t1-t0
    times_sigker[order-1] = t2-t1
    times_sigpde[order-1] = t3-t2
    print("dot1", dot1)
    print("dot2", dot2)
    print("dot3", dot3)




print("\ncomparison", times_iisig[1:]/times_sigker[1:])
print("\niisig", times_iisig[1:])
print("\nsigker", times_sigker[1:])
print("\npde", times_sigpde[1:])


order 1
dot1 -0.16533720040547983
dot2 -0.1653372004054815
dot3 -14.176695989907277

order 2
dot1 -4.874092158600572
dot2 -4.874092158600578
dot3 -14.176695989907277

order 3
dot1 -11.842561772383382
dot2 -11.842561772383384
dot3 -14.176695989907277

order 4
dot1 22.627855724494943
dot2 22.627855724494925
dot3 -14.176695989907277

order 5
dot1 -11.258366856587294
dot2 -11.258366856587227
dot3 -14.176695989907277

order 6
dot1 45.8559864678114
dot2 45.85598646781163
dot3 -14.176695989907277

order 7
dot1 35.222585902252995
dot2 35.22258590225276
dot3 -14.176695989907277

order 8
dot1 46.0297297687434
dot2 46.029729768744026
dot3 -14.176695989907277

order 9
dot1 20.91851932368872
dot2 20.91851932368812
dot3 -14.176695989907277

order 10
dot1 85.66399990828846
dot2 85.6639999082895
dot3 -14.176695989907277

order 11
dot1 -38.224053792512734
dot2 -38.224053792513196
dot3 -14.176695989907277

order 12
dot1 77.89729096799529
dot2 77.89729096799503
dot3 -14.176695989907277

order 13
dot1 -1

In [23]:
import time
import numpy as np
import sigkernel
from kernels import sig_kernel_gram, pairwise_kernel_gram, linear_kernel_gram
from conformance import stream_to_torch
from kernels import sig_kernel

d=2
train = np.random.randn(10, 45, d)/np.sqrt(d)
test = np.random.randn(11, 45, d)/np.sqrt(d)



static_kernel = lambda x,y: linear_kernel_gram(x, y, divide_by_dims=False)
ORDER=10
salvi_ker = sigkernel.SigKernel(sigkernel.LinearKernel(), dyadic_order=3)
kernel = lambda s1, s2 : salvi_ker.compute_kernel(
                            stream_to_torch(s1), 
                            stream_to_torch(s2)).numpy()[0]

#experiments
t0 = time.perf_counter()
UV = sig_kernel_gram(test, train, ORDER, static_kernel, sym=False, verbose=True)
t1 = time.perf_counter()
UV2 = pairwise_kernel_gram(test, train, kernel, sym=False, verbose=True)
t2 = time.perf_counter()
UV3 = np.zeros( (test.shape[0], train.shape[0]) )
for i in range(test.shape[0]):
    for j in range(train.shape[0]):
        UV3[i,j] = sig_kernel(test[i], train[j], ORDER, static_kernel)
t3 = time.perf_counter()
UV4 = np.zeros( (test.shape[0], train.shape[0]) )
for i in range(test.shape[0]):
    for j in range(train.shape[0]):
        UV4[i,j] = kernel(test[i], train[j])
t4 = time.perf_counter()
print("sigker", t1-t0, "pde", t2-t1, "brute", t3-t2, "brute2", t4-t3)


Kernel Gram Matrix: 100%|██████████| 110/110 [00:02<00:00, 42.27it/s]
Kernel Gram Matrix: 100%|██████████| 110/110 [00:00<00:00, 427.59it/s]


sigker 2.6064247699996486 pde 0.25984002000041073 brute 2.215484223999738 brute2 0.2430941869997696
