In [2]:
import torch
import time
import numpy as np

import sigkernel

import matplotlib.pyplot as plt

In [3]:
dyadic_order = 5
_naive_solver = False

In [4]:
# specify static kernel
static_kernel = sigkernel.LinearKernel()
# static_kernel = sigkernel.RBFKernel(sigma=.5)

# initialize signature kernel
signature_kernel = sigkernel.SigKernel(static_kernel, dyadic_order, _naive_solver)

# Sig Loss gradients

In [5]:
A = 1
M = 3
N = 2
D = 2

X = np.random.randn(A,M,D).cumsum(axis=1)
Y = np.random.randn(A,N,D).cumsum(axis=1)

X /= np.max(X)
Y /= np.max(Y)

In [6]:
X_naive = torch.tensor(X, dtype=torch.float32)
Y_naive = torch.tensor(Y, dtype=torch.float32)

X_cpu = X_naive.clone()
Y_cpu = Y_naive.clone()

# X_gpu = X_naive.clone().cuda()
# Y_gpu = Y_naive.clone().cuda()

X_naive.requires_grad = True
X_cpu.requires_grad = True
# X_gpu.requires_grad = True

In [8]:
t = time.time()
l_naive = sigkernel.SigLoss_naive(static_kernel, dyadic_order, _naive_solver).forward(X_naive,Y_naive)
print('time:', np.round(time.time()-t,3), 's')
print(l_naive)

time: 3.62 s
tensor(4.9382, grad_fn=<SubBackward0>)


In [9]:
t = time.time()
l_cpu = signature_kernel.compute_distance(X_cpu,Y_cpu)
print('time:', np.round(time.time()-t,3), 's')
print(l_cpu)

ValueError: Buffer dtype mismatch, expected 'double' but got 'float'

In [37]:
# t = time.time()
# l_gpu = signature_kernel.compute_distance(X_gpu,Y_gpu)
# print('time:', np.round(time.time()-t,3), 's')
# print(l_gpu)

In [38]:
t = time.time()
l_naive.backward()
print('time:', np.round(time.time()-t,3), 's')

time: 1.673 s


In [39]:
t = time.time()
l_cpu.backward()
print('time:', np.round(time.time()-t,3), 's')

time: 0.004 s


In [40]:
# t = time.time()
# l_gpu.backward()
# print('time:', np.round(time.time()-t,3), 's')

In [41]:
X_naive.grad

tensor([[[-10.4859,  23.4210],
         [ -0.8974,  -0.4398],
         [ 11.3833, -22.9812]]], dtype=torch.float64)

In [42]:
X_cpu.grad

tensor([[[-10.0129,  22.4212],
         [ -0.7487,  -0.6734],
         [ 10.7616, -21.7478]]], dtype=torch.float64)

In [43]:
# X_gpu.grad.cpu()

# Sig MMD gradients

In [44]:
A = 2
B = 3
M = 4
N = 3
D = 2

X = np.random.randn(A,M,D).cumsum(axis=1)
Y = np.random.randn(B,N,D).cumsum(axis=1)

X /= np.max(X)
Y /= np.max(Y)

In [45]:
X_naive = torch.tensor(X, dtype=torch.float64)
Y_naive = torch.tensor(Y, dtype=torch.float64)

X_cpu = X_naive.clone()
Y_cpu = Y_naive.clone()

# X_gpu = X_naive.clone().cuda()
# Y_gpu = Y_naive.clone().cuda()

X_naive.requires_grad = True
X_cpu.requires_grad = True
# X_gpu.requires_grad = True

In [46]:
t = time.time()
mmd_naive = sigkernel.SigMMD_naive(static_kernel, dyadic_order, _naive_solver).forward(X_naive,Y_naive)
print('time:', np.round(time.time()-t,3), 's')
print(mmd_naive)

time: 3.11 s
tensor(114.2561, dtype=torch.float64, grad_fn=<SubBackward0>)


In [47]:
t = time.time()
mmd_cpu = signature_kernel.compute_mmd(X_cpu,Y_cpu)
print('time:', np.round(time.time()-t,3), 's')
print(mmd_cpu)

time: 0.008 s
tensor(114.2561, dtype=torch.float64, grad_fn=<SubBackward0>)


In [48]:
# t = time.time()
# mmd_gpu = signature_kernel.compute_mmd(X_gpu,Y_gpu)
# print('time:', np.round(time.time()-t,3), 's')
# print(mmd_gpu)

In [49]:
t = time.time()
mmd_naive.backward()
print('time:', np.round(time.time()-t,3), 's')

time: 15.692 s


In [50]:
t = time.time()
mmd_cpu.backward()
print('time:', np.round(time.time()-t,3), 's')

time: 0.026 s


In [51]:
# t = time.time()
# mmd_gpu.backward()
# print('time:', np.round(time.time()-t,3), 's')

In [52]:
X_naive.grad

tensor([[[  0.1502,  -3.0669],
         [ -0.2185,   0.3906],
         [ -0.0877,   0.0856],
         [  0.1559,   2.5907]],

        [[  2.2565, -20.2559],
         [  0.3197,  -0.8927],
         [  3.9567,   1.1642],
         [ -6.5329,  19.9844]]], dtype=torch.float64)

In [53]:
X_cpu.grad

tensor([[[  0.0916,  -2.9034],
         [ -0.1113,   0.1997],
         [ -0.1646,   0.2093],
         [  0.1842,   2.4944]],

        [[  2.2402, -19.5380],
         [  0.4181,  -0.4812],
         [  3.4843,   1.8208],
         [ -6.1426,  18.1984]]], dtype=torch.float64)