In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Batch_SMKernel(nn.Module):
    def __init__(self, Q, D):
        """
        Initialize the MixtureKernel.

        Parameters:
        Q (int): Number of mixture components.
        D (int): Dimensionality of the input space.
        """
        super(Batch_SMKernel, self).__init__()
        self.Q = Q
        self.D = D
        
        # Initialize weights (w_q)
        self.weights = nn.Parameter(torch.ones(Q) / Q)
        
        # Initialize means (mu_q)
        self.means = nn.Parameter(torch.randn(Q, D))
        
        # Initialize diagonal elements of covariance matrices (v_qd)
        self.log_covariances = nn.Parameter(torch.zeros(Q, D))

    def forward(self, x, x_prime):
        """
        Compute the kernel matrix for batched inputs.

        Parameters:
        x (torch.Tensor): Input tensor of shape (n, D)
        x_prime (torch.Tensor): Input tensor of shape (n, D)

        Returns:
        torch.Tensor: The kernel matrix of shape (n, n)
        """
        # Number of samples in the batch
        n = x.shape[0]

        # Reshape x and x_prime for broadcasting
        x = x.unsqueeze(1)  # Shape: (n, 1, D)
        x_prime = x_prime.unsqueeze(0)  # Shape: (1, n, D)

        # Compute pairwise differences
        diff = x - x_prime  # Shape: (n, n, D)
        
        kernel_matrix = torch.zeros(n, n)

        for q in range(self.Q):
            w_q = torch.exp(self.weights[q])
            mu_q = self.means[q]
            sigma_q_diag = torch.exp(self.log_covariances[q])
            Sigma_q = torch.diag(sigma_q_diag)
            
            # Compute determinant of Σ_q (product of diagonal elements)
            det_Sigma_q = torch.prod(sigma_q_diag)
            
            # Compute the normalization factor
            norm_factor = 1 / (det_Sigma_q**0.5 * (2 * torch.pi)**(self.D / 2))
            
            # Compute the exponent term
            exponent = -0.5 * torch.einsum('ijD,D,ijD->ij', diff, sigma_q_diag, diff)  # Shape: (n, n)
            
            # Compute the cosine term
            cosine_term = torch.cos(2 * torch.pi * torch.einsum('ijD,D->ij', diff, mu_q))  # Shape: (n, n)
            exp_term = torch.exp(exponent)
            print(exp_term, cosine_term)
            # Add the weighted component to the kernel matrix
            kernel_matrix +=  exp_term* cosine_term

        return kernel_matrix

In [2]:
class SMKernel(nn.Module):
    def __init__(self, Q, D):
        """
        Initialize the MixtureKernel.

        Parameters:
        Q (int): Number of mixture components.
        D (int): Dimensionality of the input space.
        """
        super(SMKernel, self).__init__()
        self.Q = Q
        self.D = D
        
        # Initialize weights (w_q)
        self.weights = nn.Parameter(torch.ones(Q) / Q)
        
        # Initialize means (mu_q)
        self.means = nn.Parameter(torch.randn(Q, D))
        
        # Initialize diagonal elements of covariance matrices (v_qd)
        self.log_covariances = nn.Parameter(torch.zeros(Q, D))

    def forward(self, x, x_prime):
        """
        Compute the kernel function.

        Parameters:
        x (torch.Tensor): Input vector of shape (D,)
        x_prime (torch.Tensor): Input vector of shape (D,)

        Returns:
        torch.Tensor: The value of the kernel function
        """
        kernel_value = 0.0
        
        for q in range(self.Q):
            w_q = torch.exp(self.weights[q])
            mu_q = self.means[q]
            sigma_q_diag = torch.exp(self.log_covariances[q])
            Sigma_q = torch.diag(sigma_q_diag)
            
            # Compute determinant of Σ_q (product of diagonal elements)
            det_Sigma_q = torch.prod(sigma_q_diag)
            
            # Compute the normalization factor
            norm_factor = 1 / (det_Sigma_q**0.5*((2 * torch.pi)**(self.D / 2)))
            
            # Compute the exponent term
            diff = x - x_prime

            exponent = -0.5 * torch.mm(diff, torch.matmul(Sigma_q, diff.T))
            # Compute the cosine term
            cosine_term = torch.cos(2 * torch.pi * torch.dot(diff.squeeze(), mu_q))
            
            # Add the weighted component to the kernel value
            kernel_value += w_q * norm_factor * torch.exp(exponent) * cosine_term

        return kernel_value

In [18]:
Q = 2  # Number of mixture components
D = 3  # Dimensionality of the input space
# Generate some test data
n = 5  # Number of samples
m = 10
x = torch.randint(0, 10, (n, D)).float()
x_prime = torch.randint(0, 10, (n, D)).float()
# torch.exp(x)*x_prime
torch.exp(x)* x_prime

tensor([[2.7299e+02, 2.4206e+03, 4.3679e+02],
        [0.0000e+00, 1.0000e+00, 8.9048e+02],
        [3.6945e+01, 1.3357e+03, 1.4905e+04],
        [0.0000e+00, 8.9429e+03, 1.8077e+02],
        [2.1746e+01, 2.1839e+02, 4.3865e+03]])

In [15]:
# Create the kernel
b_kernel = Batch_SMKernel(Q, D)
kernel = SMKernel(Q, D)
# Ensure they have the same parameters
b_kernel.weights.data = kernel.weights.data.clone()
b_kernel.means.data = kernel.means.data.clone()
b_kernel.log_covariances.data = kernel.log_covariances.data.clone()


# Compute the kernel matrix in batch
kernel_matrix_batch = b_kernel(x, x_prime)

# Compute the kernel matrix element-wise
kernel_matrix_elementwise = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        kernel_matrix_elementwise[i, j] = kernel.forward(x[i].unsqueeze(0), x_prime[j].unsqueeze(0))

# Check if the results are the same
print("Kernel matrix (batch computation):")
print(kernel_matrix_batch)
print("Kernel matrix (element-wise computation):")
print(kernel_matrix_elementwise)

# Check for equality
print("Are the results the same? ", torch.allclose(kernel_matrix_batch, kernel_matrix_elementwise))

tensor([[4.5991e-10, 1.3888e-11, 6.3051e-16, 4.1399e-08, 1.2502e-09, 8.2085e-02,
         2.5110e-08, 6.1442e-06, 1.2502e-09, 4.1399e-08],
        [1.4069e-16, 3.4425e-14, 3.4425e-14, 6.3051e-16, 8.6441e-22, 6.8256e-08,
         4.1938e-13, 3.6788e-01, 1.0395e-15, 2.0612e-09],
        [2.1151e-19, 3.5786e-29, 3.1800e-22, 5.0435e-07, 2.6103e-23, 1.1254e-07,
         1.8795e-12, 5.0435e-07, 2.3195e-16, 3.0988e-12],
        [6.9144e-13, 7.6812e-15, 1.0395e-15, 2.0347e-04, 5.6028e-09, 3.6788e-01,
         6.1442e-06, 2.5110e-08, 3.0590e-07, 5.0435e-07],
        [5.2429e-22, 5.3111e-27, 2.5768e-18, 8.2085e-02, 8.5330e-17, 1.2341e-04,
         2.2603e-06, 2.2897e-11, 5.6028e-09, 2.5110e-08]],
       grad_fn=<ExpBackward0>) tensor([[-0.5895,  0.1869,  0.8441,  0.8596, -0.2630, -0.4058, -0.9507, -0.8643,
         -0.9655, -0.4796],
        [ 0.6875, -0.0604, -0.9055, -0.7876,  0.3836,  0.5188,  0.9824,  0.7933,
          0.9245,  0.5873],
        [-0.8173, -0.1378,  0.9714,  0.6505, -0.5583, -

RuntimeError: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 1

In [5]:
# Define a small MixtureKernel for testing


# Create the kernel
kernel = MixtureKernel(Q, D)

# Generate some test data
n = 5  # Number of samples
x = torch.randn(n, D)
x_prime = torch.randn(n+5, D)

# Compute the kernel matrix in batch
def compute_kernel_matrix_batch(kernel, x, x_prime):
    n = x.shape[0]
    x = x.unsqueeze(1)  # Shape: (n, 1, D)
    x_prime = x_prime.unsqueeze(0)  # Shape: (1, n, D)
    diff = x - x_prime  # Shape: (n, n, D)
    
    kernel_matrix = torch.zeros(n, n)
    
    for q in range(kernel.Q):
        w_q = torch.exp(kernel.weights[q])
        mu_q = kernel.means[q]
        sigma_q_diag = torch.exp(kernel.log_covariances[q])
        Sigma_q_inv = torch.diag(1 / sigma_q_diag)
        det_Sigma_q = torch.prod(sigma_q_diag)
        norm_factor = 1 / (det_Sigma_q**0.5 * (2 * torch.pi)**(kernel.D / 2))
        exponent = -0.5 * torch.einsum('ijD,D,ijD->ij', diff, 1 / sigma_q_diag, diff)
        cosine_term = torch.cos(2 * torch.pi * torch.einsum('ijD,D->ij', diff, mu_q))
        kernel_matrix += w_q * norm_factor * torch.exp(exponent) * cosine_term
    
    return kernel_matrix

kernel_matrix_batch = compute_kernel_matrix_batch(kernel, x, x_prime)

NameError: name 'MixtureKernel' is not defined