In [78]:
import torch
import gpytorch

def compute_quadratic_form(z_new, z_data, kernel):
    """Return the quadratic form ``z_new^T K^{-1} z_new`` with ``K = kernel(z_data, z_data)``.

    Uses the Cholesky factor ``K = L L^T``: solving the lower-triangular system
    ``L x = z_new`` gives ``x = L^{-1} z_new``, hence
    ``x . x = z_new^T L^{-T} L^{-1} z_new = z_new^T K^{-1} z_new``.

    Args:
        z_new: Vector in the *data index* space. Its last dimension must equal
            the number of rows of ``z_data`` (not the feature dimension). If it
            has leading batch dimensions, the per-row forms are summed, because
            the solution is flattened before the dot product.
        z_data: Points used to build the Gram matrix ``K``.
        kernel: Callable ``kernel(a, b)`` returning an object with
            ``.to_dense()`` (e.g. a gpytorch kernel).

    Returns:
        A 0-d tensor holding the (summed) quadratic form.

    Raises:
        torch.linalg.LinAlgError: if ``K`` is not numerically positive definite.
    """
    K = kernel(z_data, z_data).to_dense()

    print(K.shape)

    # torch.cholesky is deprecated; torch.linalg.cholesky returns the same lower factor.
    L = torch.linalg.cholesky(K)

    # Forward substitution: solve L x = z_new (lower-triangular), NOT L^T x = z_new.
    x = torch.linalg.solve_triangular(L, z_new.unsqueeze(-1), upper=False).squeeze(-1)

    # x . x == z_new^T K^{-1} z_new (summed over any batch rows).
    flat = x.reshape(-1)
    result = torch.dot(flat, flat)

    return result

def compute_quadratic_form_chol(z_new, z_data, kernel):
    """Compute ``k(z_new, z_new) - k(z_new, Z) K^{-1} k(Z, z_new)`` via Cholesky.

    ``K = kernel(Z, Z)`` with ``Z = z_data``. The Cholesky factor of ``K`` is
    used to apply ``K^{-1}`` through triangular solves, so ``K^{-1}`` is never
    formed explicitly. Intermediate shapes are printed for inspection.

    Args:
        z_new: Query points (one per row).
        z_data: Stored points (one per row).
        kernel: Callable ``kernel(a, b)`` whose results support ``.to_dense()``
            (e.g. a gpytorch kernel).

    Returns:
        Dense tensor of shape ``(len(z_new), len(z_new))``.

    Raises:
        torch.linalg.LinAlgError: if ``K`` is not numerically positive definite.
    """
    K = kernel(z_data, z_data).to_dense()   # (batch_size, batch_size)

    print("kernel matrix shape ->", K.shape)

    # Factor K = L L^T once; reused below instead of computing an explicit inverse.
    L = torch.linalg.cholesky(K)

    k_vec = kernel(z_new, z_data).to_dense()    # (vec_batch, batch_size)

    print("kernel vector shape ->", k_vec.shape)

    # cholesky_solve(B, L) returns K^{-1} B. With B = k_vec^T and K symmetric,
    # transposing back yields k_vec K^{-1}: shape (vec_batch, batch_size).
    res = torch.cholesky_solve(k_vec.t(), L).t()

    print("res ->", res.shape)

    res2 = torch.matmul(res, k_vec.t())

    print("res2 ->", res2.shape)

    res_final = (kernel(z_new, z_new) - res2).to_dense()

    print("res final ->", res_final.shape)

    return res_final

def compute_quadratic_form2(z_new, z_data, kernel):
    """Compute ``k(z_new, z_new) - k(z_new, Z) K^{-1} k(Z, z_new)`` with an explicit inverse.

    ``K = kernel(Z, Z)`` with ``Z = z_data``. This variant inverts ``K``
    directly, which is simple but less numerically robust than solving a
    linear system. Intermediate shapes are printed for inspection.

    Args:
        z_new: Query points (one per row).
        z_data: Stored points (one per row).
        kernel: Callable ``kernel(a, b)`` whose results support ``.to_dense()``
            (e.g. a gpytorch kernel).

    Returns:
        Dense tensor of shape ``(len(z_new), len(z_new))``.
    """
    gram = kernel(z_data, z_data).to_dense()   # (batch_size, batch_size)
    print("kernel matrix shape ->", gram.shape)
    gram_inv = torch.inverse(gram)

    cross = kernel(z_new, z_data).to_dense()   # (vec_batch, batch_size)
    print("kernel vector shape ->", cross.shape)

    # Right-multiply the cross-covariance by K^{-1}, then close the form with its transpose.
    weighted = cross @ gram_inv
    print("res ->", weighted.shape)

    reduction = weighted @ cross.t()
    print("res2 ->", reduction.shape)

    out = (kernel(z_new, z_new) - reduction).to_dense()
    print("res final ->", out.shape)
    return out

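# More numerically stable than compute_quadratic_form2: solves a linear system with K instead of forming K^{-1} explicitly. See https://pytorch.org/docs/stable/generated/torch.linalg.inv.html#torch.linalg.inv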
def compute_quadratic_form3(z_new, z_data, kernel):
    """Compute ``k(z_new, z_new) - k(z_new, Z) K^{-1} k(Z, z_new)`` via a linear solve.

    ``K = kernel(Z, Z)`` with ``Z = z_data``. ``K^{-1}`` is applied with
    ``torch.linalg.solve`` rather than an explicit inverse. Works for any
    number of query points (previously only a single query row was supported).
    Intermediate shapes are printed for inspection.

    Args:
        z_new: Query points (one per row).
        z_data: Stored points (one per row).
        kernel: Callable ``kernel(a, b)`` whose results support ``.to_dense()``
            (e.g. a gpytorch kernel).

    Returns:
        Dense tensor of shape ``(len(z_new), len(z_new))``.
    """
    K = kernel(z_data, z_data).to_dense()   # (batch_size, batch_size)

    print("kernel matrix shape ->", K.shape)

    k_vec = kernel(z_new, z_data).to_dense()    # (vec_batch, batch_size)

    print("kernel vector shape ->", k_vec.shape)

    # Solve K X = k_vec^T for all query rows at once (X: (batch_size, vec_batch)).
    # K is symmetric, so X^T == k_vec K^{-1}. The old squeeze(0)/unsqueeze(0)
    # form only handled vec_batch == 1.
    res = torch.linalg.solve(K, k_vec.t()).t()

    print("res ->", res.shape)

    res2 = torch.matmul(res, k_vec.t())

    print("res2 ->", res2.shape)

    res_final = (kernel(z_new, z_new) - res2).to_dense()

    print("res final ->", res_final.shape)

    return res_final



# Example usage
z_new = torch.rand(1, 5)     # 1 query point with 5 features
z_data = torch.rand(32, 5)   # 32 stored points with 5 features each
# Create an RBF kernel (constructed with no arguments, i.e. gpytorch defaults)
rbf_kernel = gpytorch.kernels.RBFKernel()

# Both calls compute k(z_new, z_new) - k(z_new, Z) K^{-1} k(Z, z_new); they differ
# only in explicit inverse (form2) vs. linear solve (form3), so the printed values
# should agree up to floating-point error.
result = compute_quadratic_form2(z_new, z_data, rbf_kernel)
print(result)
result2 = compute_quadratic_form3(z_new, z_data, rbf_kernel)
print(result2)

kernel matrix shape -> torch.Size([32, 32])
kernel vector shape -> torch.Size([1, 32])
res -> torch.Size([1, 32])
res2 -> torch.Size([1, 1])
res final -> torch.Size([1, 1])
tensor([[0.0650]], grad_fn=<AddBackward0>)
kernel matrix shape -> torch.Size([32, 32])
kernel vector shape -> torch.Size([1, 32])
res -> torch.Size([1, 32])
res2 -> torch.Size([1, 1])
res final -> torch.Size([1, 1])
tensor([[0.0650]], grad_fn=<AddBackward0>)
