In [None]:
import math
import torch
import gpytorch 

In [None]:

class ExactGPmodel(gpytorch.models.ExactGP):
    """
    A Gaussian Process model for regression tasks using an RBF kernel with Automatic Relevance Determination (ARD).

    This model is intended for scenarios where each node may have a different impact on the studied node and
    therefore requires its own length scale parameter. ARD lets the model learn the relevance of each input
    feature independently, which can be critical for high-dimensional data or data whose features differ in
    importance.

    The prior is a zero mean function combined with a scaled RBF covariance (ScaleKernel over an ARD RBFKernel),
    i.e. one learnable length scale per input feature plus a single learnable output scale.

    Parameters:
    - train_x (torch.Tensor): The training inputs, expected to be a tensor of shape (T, D),
                              where 'T' is the number of training points and 'D' is the number of features.
                              The feature axis is taken to be the last dimension; a 1-D tensor of shape (T,)
                              is treated as a single feature (D = 1).
                              NOTE: in the current implementation, D = number of nodes N x maximum considered
                              delay Lambda.
    - train_y (torch.Tensor): The training outputs, expected to be a tensor of shape (T,),
                              where 'T' is the number of training points.
    - likelihood (gpytorch.likelihoods.Likelihood): The likelihood model to use for inference.
                                                    This is typically GaussianLikelihood for regression tasks.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # One ARD length scale per input feature. Features live on the last
        # axis (also correct for batched (B, T, D) inputs); a 1-D input is a
        # single feature, matching how ExactGP itself treats 1-D inputs.
        num_features = train_x.size(-1) if train_x.dim() > 1 else 1
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=num_features)
        )

    def forward(self, x):
        """
        Compute the multivariate normal distribution over the targets given the inputs.

        Parameters:
        - x (torch.Tensor): Input features tensor of shape (T, D).

        Returns:
        - gpytorch.distributions.MultivariateNormal: The Gaussian Process prior distribution over the outputs
          at the given inputs, built from the zero mean and the scaled ARD RBF covariance.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)