In [1]:
import torch
import math
import gpytorch
from gpytorch.random_variables import RandomVariable, MixtureRandomVariable

In [2]:
class MixtureRandomVariableWithSampler(gpytorch.random_variables.MixtureRandomVariable):
    """MixtureRandomVariable extended with a two-stage ``sample`` method."""

    def sample(self, n_samples=1):
        """
        Draw n_samples samples from the mixture.

        Each draw first picks a component index from a categorical distribution
        parameterised by the mixture weights, then draws one sample from that
        component.

        Params:
        - n_samples (int): number of samples to draw

        Returns:
        - tensor on the weights' device whose leading dimension is n_samples;
          the trailing dimensions are whatever a single component's
          ``sample(1)`` returns
        """
        # representation() is unpacked as (component random variables, weights)
        rand_vars, weights = self.representation()

        # Choose one component per draw; the weights are used as categorical probabilities
        sample_ids = torch.distributions.categorical.Categorical(probs=weights).sample((n_samples,))

        # One sample from each chosen component
        draws = [rand_vars[i].sample(1) for i in sample_ids.tolist()]
        if not draws:
            # torch.stack rejects an empty list; keep returning an empty tensor
            return torch.tensor([], device=weights.device)

        # torch.stack (instead of torch.tensor on a list of tensors) keeps the
        # samples' dtype and also works when a single sample has several elements
        return torch.stack(draws).to(weights.device)

In [3]:
class BatchRandomVariable(gpytorch.random_variables.RandomVariable):
    """Wraps b independent random variables so they can be treated as one batch."""

    def __init__(self, *rand_vars, **kwargs):
        """
        Batch of random variables

        Params:
        - rand_vars (iterable of RandomVariables with b elements)

        Raises:
        - RuntimeError if any element of rand_vars is not a RandomVariable
        """
        # Validate first so invalid input is rejected before any base-class setup runs
        if not all(isinstance(rand_var, RandomVariable) for rand_var in rand_vars):
            raise RuntimeError("Everything needs to be an instance of a random variable")

        super(BatchRandomVariable, self).__init__(*rand_vars, **kwargs)
        self.rand_vars = rand_vars

    def representation(self):
        """Return the tuple of wrapped random variables."""
        return self.rand_vars

    def mean(self):
        """
        Return the b means stacked along a new leading dimension.

        torch.stack keeps each mean's dtype and device. Unlike torch.tensor on a
        list of tensors, it also works when an individual mean has more than
        one element.
        """
        return torch.stack([rand_var.mean() for rand_var in self.rand_vars])

    def var(self):
        """
        Return the b variances stacked along a new leading dimension.

        The reasoning is the same as for ``mean``: torch.stack also supports
        multi-element variances.
        """
        return torch.stack([rand_var.var() for rand_var in self.rand_vars])

    def sample(self, n_samples=1):
        '''
        Sample n_samples for each of the b rand_vars and return an
        b x (d) x n_samples... object consistent with random variables for which batch mode is enabled

        Note: ``squeeze()`` drops *every* singleton dimension of each per-variable
        sample (for example the sample dimension when n_samples == 1) before the
        batch dimension is added.
        '''
        # b x ... x n_samples layout (copying GaussianRandomVariable).
        # Stacking the squeezed samples is equivalent to concatenating them after unsqueeze(0).
        return torch.stack([rand_var.sample(n_samples).squeeze() for rand_var in self.rand_vars])

#         # n_samples x b x ... Implementation
#         samples = torch.cat([rand_var.sample(n_samples).unsqueeze(0) for rand_var in self.rand_vars])
#         return samples.permute(1,0, *range(2,samples.ndimension()))