In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Batch_SMKernel(nn.Module):
    """Mixture kernel with Gaussian-times-cosine components, evaluated on batches.

    Both ``weights`` and ``log_covariances`` are stored in log space; ``forward``
    applies ``exp`` to them before use.
    """

    def __init__(self, Q, D):
        """
        Initialize the Batch_SMKernel.

        Parameters:
        Q (int): Number of mixture components.
        D (int): Dimensionality of the input space.
        """
        super().__init__()
        self.Q = Q
        self.D = D

        # Log-weights: forward() uses w_q = exp(weights[q]), so this initialisation
        # gives w_q = exp(1/Q), not 1/Q.
        # NOTE(review): confirm whether torch.log(torch.ones(Q) / Q) was intended.
        self.weights = nn.Parameter(torch.ones(Q) / Q)

        # Means (mu_q): one D-dimensional frequency vector per component.
        self.means = nn.Parameter(torch.randn(Q, D))

        # Log of the diagonal covariance entries (v_qd); zeros give unit variances.
        self.log_covariances = nn.Parameter(torch.zeros(Q, D))

    def forward(self, x, x_prime):
        """
        Compute the kernel matrix for batched inputs.

        Parameters:
        x (torch.Tensor): Input tensor of shape (n, D)
        x_prime (torch.Tensor): Input tensor of shape (m, D); m may differ from n

        Returns:
        torch.Tensor: The kernel matrix of shape (n, m)
        """
        # Pairwise differences tau_ij = x_i - x'_j via broadcasting: (n, m, D).
        diff = x.unsqueeze(1) - x_prime.unsqueeze(0)

        # Allocate (n, m) on the inputs' device so non-square batches and
        # non-CPU inputs work; accumulate out of place below.
        kernel_matrix = diff.new_zeros(diff.shape[:2])
        two_pi_pow = (2 * torch.pi) ** (self.D / 2)

        for q in range(self.Q):
            w_q = torch.exp(self.weights[q])
            mu_q = self.means[q]
            log_sigma_q = self.log_covariances[q]
            sigma_q_diag = torch.exp(log_sigma_q)

            # 1 / (sqrt(det Sigma_q) * (2 pi)^{D/2}); det Sigma_q is the product of
            # the diagonal, taken in log space to avoid overflow/underflow.
            norm_factor = torch.exp(-0.5 * log_sigma_q.sum()) / two_pi_pow

            # NOTE(review): the exponent weights tau by Sigma_q while the normaliser
            # divides by sqrt(det Sigma_q); a Gaussian density with covariance Sigma_q
            # would use Sigma_q^{-1} here (as compute_kernel_matrix_batch does).
            # Kept as-is so SMKernel and this class agree — confirm intended form.
            exponent = -0.5 * torch.einsum('ijd,d,ijd->ij', diff, sigma_q_diag, diff)  # (n, m)

            # cos(2 pi tau^T mu_q)
            cosine_term = torch.cos(2 * torch.pi * torch.einsum('ijd,d->ij', diff, mu_q))  # (n, m)

            kernel_matrix = kernel_matrix + w_q * norm_factor * torch.exp(exponent) * cosine_term

        return kernel_matrix

In [66]:
class SMKernel(nn.Module):
    """Mixture kernel with Gaussian-times-cosine components, for one pair of inputs.

    Evaluates the same expression as ``Batch_SMKernel.forward`` for a single
    (x, x') pair. ``weights`` and ``log_covariances`` are stored in log space.
    """

    def __init__(self, Q, D):
        """
        Initialize the SMKernel.

        Parameters:
        Q (int): Number of mixture components.
        D (int): Dimensionality of the input space.
        """
        super().__init__()
        self.Q = Q
        self.D = D

        # Log-weights: forward() uses w_q = exp(weights[q]), so this initialisation
        # gives w_q = exp(1/Q), not 1/Q.
        # NOTE(review): confirm whether torch.log(torch.ones(Q) / Q) was intended.
        self.weights = nn.Parameter(torch.ones(Q) / Q)

        # Means (mu_q): one D-dimensional frequency vector per component.
        self.means = nn.Parameter(torch.randn(Q, D))

        # Log of the diagonal covariance entries (v_qd); zeros give unit variances.
        self.log_covariances = nn.Parameter(torch.zeros(Q, D))

    def forward(self, x, x_prime):
        """
        Compute the kernel function for a single pair of inputs.

        Parameters:
        x (torch.Tensor): Input vector of shape (D,) or (1, D)
        x_prime (torch.Tensor): Input vector of shape (D,) or (1, D)

        Returns:
        torch.Tensor: The kernel value as a (1, 1) tensor (the float 0.0 when Q == 0)
        """
        # The difference does not depend on q: compute it once, as a (1, D) row so
        # 1-D and row-vector inputs behave identically (and a wrong size errors).
        diff = (x - x_prime).reshape(1, self.D)
        # reshape(-1), unlike squeeze(), stays 1-D when D == 1, as torch.dot requires.
        diff_vec = diff.reshape(-1)
        two_pi_pow = (2 * torch.pi) ** (self.D / 2)

        kernel_value = 0.0
        for q in range(self.Q):
            w_q = torch.exp(self.weights[q])
            mu_q = self.means[q]
            log_sigma_q = self.log_covariances[q]
            sigma_q_diag = torch.exp(log_sigma_q)

            # 1 / (sqrt(det Sigma_q) * (2 pi)^{D/2}); det Sigma_q is the product of
            # the diagonal, taken in log space to avoid overflow/underflow.
            norm_factor = torch.exp(-0.5 * log_sigma_q.sum()) / two_pi_pow

            # diff @ diag(sigma) @ diff.T without materialising the D x D matrix;
            # keepdim preserves the (1, 1) result shape.
            # NOTE(review): Sigma_q (not Sigma_q^{-1}) is used here, matching
            # Batch_SMKernel — confirm this is the intended kernel form.
            quad = (diff * sigma_q_diag * diff).sum(dim=1, keepdim=True)
            exponent = -0.5 * quad

            # cos(2 pi tau^T mu_q)
            cosine_term = torch.cos(2 * torch.pi * torch.dot(diff_vec, mu_q))

            kernel_value = kernel_value + w_q * norm_factor * torch.exp(exponent) * cosine_term

        return kernel_value

In [67]:
# Fixed seed: the test data below and later random draws (e.g. kernel means)
# are then reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

Q = 2  # Number of mixture components
D = 3  # Dimensionality of the input space

# Generate some test data: n points with integer coordinates in [0, 10)
n = 5  # Number of samples
x = torch.randint(0, 10, (n, D)).float()

In [68]:
# Create one kernel of each implementation
b_kernel = Batch_SMKernel(Q, D)
kernel = SMKernel(Q, D)
# Give both the same parameters: the two modules register identically named
# parameters (weights, means, log_covariances), so the state_dict transfers directly.
b_kernel.load_state_dict(kernel.state_dict())

# This is a consistency check, not training: don't build autograd graphs.
with torch.no_grad():
    # Compute the kernel matrix in batch
    kernel_matrix_batch = b_kernel(x, x)

    # Compute the kernel matrix element-wise. Call the module itself (not
    # .forward) so any registered nn.Module hooks run as in normal use.
    kernel_matrix_elementwise = torch.zeros(n, n)
    for i in range(n):
        for j in range(n):
            kernel_matrix_elementwise[i, j] = kernel(x[i].unsqueeze(0), x[j].unsqueeze(0))

# Check if the results are the same
print("Kernel matrix (batch computation):")
print(kernel_matrix_batch)
print("Kernel matrix (element-wise computation):")
print(kernel_matrix_elementwise)

# Check for equality, and fail loudly on Run All if the implementations diverge
results_match = torch.allclose(kernel_matrix_batch, kernel_matrix_elementwise)
print("Are the results the same? ", results_match)
assert results_match, "batch and element-wise kernel matrices differ"

Kernel matrix (batch computation):
tensor([[ 2.0937e-01,  1.8155e-06, -1.7399e-14, -8.8167e-04,  3.3351e-04],
        [ 1.8155e-06,  2.0937e-01, -2.1007e-06, -6.5085e-03, -2.4619e-06],
        [-1.7399e-14, -2.1007e-06,  2.0937e-01,  4.4194e-09,  2.4133e-13],
        [-8.8167e-04, -6.5085e-03,  4.4194e-09,  2.0937e-01, -1.5748e-05],
        [ 3.3351e-04, -2.4619e-06,  2.4133e-13, -1.5748e-05,  2.0937e-01]],
       grad_fn=<AddBackward0>)
Kernel matrix (element-wise computation):
tensor([[ 2.0937e-01,  1.8155e-06, -1.7399e-14, -8.8167e-04,  3.3351e-04],
        [ 1.8155e-06,  2.0937e-01, -2.1007e-06, -6.5085e-03, -2.4619e-06],
        [-1.7399e-14, -2.1007e-06,  2.0937e-01,  4.4194e-09,  2.4133e-13],
        [-8.8167e-04, -6.5085e-03,  4.4194e-09,  2.0937e-01, -1.5748e-05],
        [ 3.3351e-04, -2.4619e-06,  2.4133e-13, -1.5748e-05,  2.0937e-01]],
       grad_fn=<CopySlices>)
Are the results the same?  True


In [51]:
# Build a small kernel for the standalone batch-function check below.
# NOTE(review): `MixtureKernel` is not defined in any visible cell (it would raise
# NameError on a fresh kernel); Batch_SMKernel, whose docstring still carries that
# name, provides every attribute compute_kernel_matrix_batch reads
# (Q, D, weights, means, log_covariances).
kernel = Batch_SMKernel(Q, D)

# Generate some test data
n = 5  # Number of samples
x = torch.randn(n, D)
x_prime = torch.randn(n, D)
# Compute the kernel matrix in batch
def compute_kernel_matrix_batch(kernel, x, x_prime):
    """Evaluate a mixture kernel on all pairs of rows of ``x`` and ``x_prime``.

    Unlike ``Batch_SMKernel.forward``, the quadratic form here uses the inverse
    diagonal covariance Sigma_q^{-1}, so each component is a normalised Gaussian
    density in tau = x - x' multiplied by cos(2 pi tau^T mu_q).

    Parameters:
    kernel: object exposing ``Q``, ``D``, ``weights`` (log-weights; w_q = exp(weights[q])),
        ``means`` and ``log_covariances``; ``means[q]`` and ``log_covariances[q]``
        must be length-D vectors (they are contracted with the last axis of tau).
    x (torch.Tensor): shape (n, D)
    x_prime (torch.Tensor): shape (m, D); m may differ from n

    Returns:
    torch.Tensor: kernel matrix of shape (n, m)
    """
    # Pairwise differences tau_ij = x_i - x'_j via broadcasting: (n, m, D).
    diff = x.unsqueeze(1) - x_prime.unsqueeze(0)

    # (n, m) on the inputs' device, so non-square batches work; accumulate out of place.
    kernel_matrix = diff.new_zeros(diff.shape[:2])
    two_pi_pow = (2 * torch.pi) ** (kernel.D / 2)

    for q in range(kernel.Q):
        w_q = torch.exp(kernel.weights[q])
        mu_q = kernel.means[q]
        log_sigma_q = kernel.log_covariances[q]
        # exp(-log v) == 1 / v: the diagonal of Sigma_q^{-1}.
        inv_sigma_q_diag = torch.exp(-log_sigma_q)
        # 1 / (sqrt(det Sigma_q) * (2 pi)^{D/2}) with det taken in log space.
        norm_factor = torch.exp(-0.5 * log_sigma_q.sum()) / two_pi_pow
        exponent = -0.5 * torch.einsum('ijd,d,ijd->ij', diff, inv_sigma_q_diag, diff)
        cosine_term = torch.cos(2 * torch.pi * torch.einsum('ijd,d->ij', diff, mu_q))
        kernel_matrix = kernel_matrix + w_q * norm_factor * torch.exp(exponent) * cosine_term

    return kernel_matrix

# Evaluate the standalone batch function on the random inputs defined above.
kernel_matrix_batch = compute_kernel_matrix_batch(kernel=kernel, x=x, x_prime=x_prime)