```
Derivation:

k([x, t], [x', t']) = k0(x, x') k1(t, t')
                    ≈ [b0(x)^T b0(x')] [b1(t)^T b1(t')]
                    = tr[b0(x) b0(x')^T] tr[b1(t) b1(t')^T]
                    = tr[(b0(x) b0(x')^T) ⊗ (b1(t) b1(t')^T)]
                    = tr[(b0(x) ⊗ b1(t)) (b0(x') ⊗ b1(t'))^T]
                    = [b0(x) ⊗ b1(t)]^T [b0(x') ⊗ b1(t')],

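where ⊗ denotes the Kronecker product while b0 and b1 are (approximate) feature maps for kernels k0 and k1. The derivation uses tr(A) tr(B) = tr(A ⊗ B) and the mixed-product property (A ⊗ B)(C ⊗ D) = (AC) ⊗ (BD).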
```

In [1]:
from functools import partial

import torch
from botorch.sampling.pathwise import (
    fourier_feature_initializer,
    GeneralizedLinearBasis,
)
from gpytorch.kernels import MaternKernel, PolynomialKernel

In [2]:
# Seed the global RNG so the random draws in the cells below are reproducible.
torch.manual_seed(0)
# Use double precision by default for every floating-point tensor created afterwards.
torch.set_default_dtype(torch.float64)

In [3]:
# Two 3x3 batches of uniform random inputs (drawn in order: A first, then B).
# Later cells feed all but the last column to the Matern kernel and the last
# column to the quadratic kernel.
A, B = torch.rand(3, 3), torch.rand(3, 3)

In [4]:
# Matern-5/2 kernel together with a 1024-dimensional basis whose weights are
# produced by the Fourier-feature initializer for that same kernel.
matern_kernel = MaternKernel(nu=2.5)
matern_basis = GeneralizedLinearBasis(
    output_shape=torch.Size([1024]),
    initializer=partial(fourier_feature_initializer, kernel=matern_kernel),
)

# Sanity check on all but the last input column: the exact kernel matrix
# followed by the inner product of the basis features, which should be close.
display(
    matern_kernel(A[..., :-1], B[..., :-1]).to_dense(),
    matern_basis(A[..., :-1]) @ matern_basis(B[..., :-1]).T,
)

tensor([[0.4955, 0.5858, 0.5964],
        [0.5510, 0.6342, 0.6615],
        [0.9629, 0.8477, 0.8903]], grad_fn=<MaternCovarianceBackward>)

tensor([[0.4970, 0.5889, 0.5946],
        [0.5530, 0.6367, 0.6592],
        [0.9622, 0.8400, 0.8912]], grad_fn=<MmBackward0>)

In [5]:
quad_kernel = PolynomialKernel(power=2)

def quad_basis(x, kernel=quad_kernel):
    """Exact feature map for the quadratic kernel ``(x^T x' + c)^2``.

    Expanding the square gives ``(x^T x')^2 + 2 c x^T x' + c^2``, which is the
    inner product of the features ``[x ⊗ x, sqrt(2 c) x, c]``.

    Args:
        x: Tensor whose last dimension indexes the input features; any leading
            dimensions are treated as batch dimensions.
        kernel: Object exposing the offset ``c`` as ``kernel.offset``. The
            feature map is only valid for a polynomial kernel of power 2.

    Returns:
        Tensor of shape ``(*x.shape[:-1], d * d + d + 1)`` with ``d = x.shape[-1]``.
        For ``d == 1`` this is ``[x^2, sqrt(2 c) x, c]``.
    """
    c = kernel.offset
    # (x^T x')^2 = <x ⊗ x, x' ⊗ x'>: the flattened outer product keeps the
    # cross terms that an elementwise x * x would drop when d > 1.
    quadratic = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2)
    linear = (2 * c).sqrt() * x
    # Exactly one constant feature per point, so it contributes c^2 (not d * c^2).
    constant = c.expand(*x.shape[:-1], 1)
    return torch.concat([quadratic, linear, constant], dim=-1)

# Sanity check on the last input column: exact kernel matrix vs. feature inner product.
display(quad_kernel(A[..., -1:], B[..., -1:]).to_dense())
display(quad_basis(A[..., -1:]) @ quad_basis(B[..., -1:]).T)

tensor([[0.5996, 0.8363, 0.6202],
        [0.6939, 1.1543, 0.7322],
        [0.6335, 0.9472, 0.6603]], grad_fn=<PowBackward0>)

tensor([[0.5996, 0.8363, 0.6202],
        [0.6939, 1.1543, 0.7322],
        [0.6335, 0.9472, 0.6603]], grad_fn=<MmBackward0>)

In [6]:
# Exact product kernel: Matern on all but the last column times the quadratic
# kernel on the last column, multiplied lazily and then materialised.
K = (
    matern_kernel(A[..., :-1], B[..., :-1])
    * quad_kernel(A[..., -1:], B[..., -1:])
).to_dense()
K

tensor([[0.2971, 0.4899, 0.3699],
        [0.3824, 0.7320, 0.4843],
        [0.6100, 0.8030, 0.5878]], grad_fn=<MulBackward0>)

In [7]:
def product_basis(x):
    """Feature map for the product kernel, built from the two factor bases.

    All but the last column of ``x`` are passed to ``matern_basis`` and the
    last column to ``quad_basis``. The two feature vectors of each point are
    combined by an outer product that is flattened into a single feature axis,
    i.e. the Kronecker product from the derivation at the top of the notebook.
    """
    leading_cols, last_col = x[..., :-1], x[..., -1:]
    left = matern_basis(leading_cols)
    right = quad_basis(last_col)
    # outer[..., i, j] = left[..., i] * right[..., j]
    outer = left[..., :, None] * right[..., None, :]
    return outer.view(*x.shape[:-1], -1)

# Approximate product-kernel matrix from the combined features; compare with K.
K_est = torch.matmul(product_basis(A), product_basis(B).T)
K_est

tensor([[0.2980, 0.4925, 0.3688],
        [0.3837, 0.7349, 0.4826],
        [0.6096, 0.7957, 0.5884]], grad_fn=<MmBackward0>)