In [2]:
import torch
import gpytorch
from gpytorch.kernels import RBFKernel

covar_module = RBFKernel()
theta_t = torch.rand(1,2,requires_grad=False)
X_hat = torch.rand(5,2,requires_grad=False)

def get_KxX_dx( x, X) :
    '''Computes the analytic derivative of the kernel K(x,X) w.r.t. x.

    Args:
        x: (n x D) Test points.

    Returns:
        (n x D) The derivative of K(x,X) w.r.t. x.
    '''
    N = X.shape[0]
    n = x.shape[0]
    D = x.shape[-1]
    
    
    K_xX = covar_module(x, X).evaluate()
    lengthscale = covar_module.lengthscale.detach()
    return (
        -torch.eye(D, device=X.device)
        / lengthscale**2
        @ (
            (x.view(n, 1, D) - X.view(1, N, D))
            * K_xX.view(n, N, 1)
        ).transpose(1, 2)
    )

def get_Kxx_dx2(x):
        """Computes the analytic second derivative of the kernel K(x,x) w.r.t. x.

        Args:
            x: (n x D) Test points.

        Returns:
            (n x D x D) The second derivative of K(x,x) w.r.t. x.
        """
        
        D = x.shape[-1]
        lengthscale = covar_module.lengthscale.detach()
        sigma_f = 1
        return (
            torch.eye(D, device=lengthscale.device) / lengthscale ** 2
        ) * sigma_f

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def K_θX(theta_t,X_hat):
    
    rslt = covar_module(theta_t,X_hat).evaluate()
    
    return rslt

def get_K_θX_dθ(theta_t,X_hat):
        
    jacobs = torch.autograd.functional.jacobian(func=lambda theta : K_θX(theta,X_hat),inputs=(theta_t))
    K_θX_dθ = jacobs.sum(dim=2).transpose(1,2)
    
    
    
    return K_θX_dθ

In [4]:
a = get_K_θX_dθ(theta_t,X_hat)
b = get_KxX_dx(theta_t,X_hat)

print(a.shape,b.shape)

assert ( (a-b) < 1e-5).all()

K_θX_dθtorch.Size([1, 2, 5])
torch.Size([1, 2, 5]) torch.Size([1, 2, 5])


In [5]:
a = get_K_θX_dθ(theta_t,theta_t)
b = get_KxX_dx(theta_t,theta_t)

print(a.shape,b.shape)

assert ( (a-b) < 1e-5).all()

K_θX_dθtorch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])


In [6]:
def get_K_θX_dθ2(theta_t,X_hat):
    
    jacobs = torch.autograd.functional.jacobian(func= lambda theta_t: get_K_θX_dθ(theta_t,X_hat),inputs=(theta_t))
    print(f'jacobs[0]{jacobs[0].shape}')
    ### we must put it in the right shape
    K_θθ_dθ2 = jacobs[0].sum(dim=2).transpose(1,0) 
    print(f'K_θθ_dθ2 {K_θθ_dθ2.shape}')
    return K_θθ_dθ2

a = get_K_θX_dθ2(theta_t,X_hat)
b = get_Kxx_dx2(theta_t)


print(a.shape,b.shape)

print(a.squeeze())
print(b.squeeze())

#assert ( (a-b) < 1e-5).all()

K_θX_dθtorch.Size([1, 2, 5])
jacobs[0]torch.Size([2, 5, 1, 2])
K_θθ_dθ2 torch.Size([5, 2, 2])
torch.Size([5, 2, 2]) torch.Size([2, 2])
tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]])
tensor([[2.0814, 0.0000],
        [0.0000, 2.0814]])


In [21]:
def last_hope(theta_t,X_hat):
    
    hessian = torch.autograd.functional.hessian(func=lambda theta : K_θX(theta,X_hat),inputs=(theta_t))
    
    return -hessian.squeeze()
    

hessian = last_hope(theta_t,theta_t)
hessian.squeeze()

tensor([[2.0814, -0.0000],
        [-0.0000, 2.0814]])

In [18]:
def get_Kxx_dx2(theta_t,X_hat):
    
    hessian = torch.autograd.functional.hessian(func=lambda theta : K_θX(theta,X_hat),inputs=(theta_t))
    print(hessian.shape)
    return -hessian


a = last_hope(theta_t,theta_t)
b = get_Kxx_dx2(theta_t)




torch.Size([1, 2, 1, 2])


tensor([[[[2.0814, -0.0000]],

         [[-0.0000, 2.0814]]]])

tensor([[2.0814, 0.0000],
        [0.0000, 2.0814]])

In [153]:
import torch

delta = torch.tensor([1.0,2,3],requires_grad=True)
g = torch.diag(delta)
a = torch.ones(3,10,3)

rslt = torch.sum(a@g@a.T)

torch.autograd.grad(outputs=rslt,inputs=delta)


RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

In [169]:
delta = torch.tensor([1.0,2,3],requires_grad=True)
g = torch.diag(delta)
a = torch.ones(1,10,3)

tmp = a@g
print(tmp.shape)
print(a.T.shape)



torch.Size([1, 10, 3])
torch.Size([3, 10, 1])


In [170]:
a@g@a.T

RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)