In [1]:
import torch
import gpytorch

import preprocRandomVariables
import preprocUtils

In [2]:
from gpytorch.functions import add_diag
from gpytorch.likelihoods import Likelihood
from gpytorch.priors._compatibility import _bounds_to_prior
from gpytorch.random_variables import GaussianRandomVariable, MixtureRandomVariable, RandomVariable
from gpytorch.priors import SmoothedBoxPrior

from torch.distributions import Poisson, Normal

# def probGaussTorch(y, mu, sig2):
#     return torch.exp(-0.5*(y-mu)**2/sig2 - 0.5*torch.log(2*np.pi*sig2))

# def probPoissonTorch(y, rate):
#     return torch.exp(torch.log(rate)*y - rate - torch.lgamma(y+1))

class MyLikelihood(gpytorch.likelihoods.Likelihood): #TODO - rename to #NormalMixtureOverPoissonCountsLikelihood
    r"""
    Prototype likelihood: a Normal mixture over Poisson photon counts.

    A latent value is exponentiated to give a Poisson rate.  Every photon
    count k in 0..19 contributes one Normal component whose mean is
    ``k * exp(log_gain) + offset``; the Poisson probabilities of the counts
    act as the mixture weights.

    Each f(x) is a GaussianRandomVariable (the output of the GP).
    """
    def __init__(self, gain=None, scale=None, offset=None):
        """
        Register the learnable parameters read by log_probability() and forward().

        gain, scale, offset: optional initial values; ``None`` selects the
        defaults 1.0, 2.0 and 0.0 respectively.
        """
        super(MyLikelihood, self).__init__()

        # log_probability() and forward() read self.log_gain, self.log_scale
        # and self.offset, so they have to exist as parameters.
        # NOTE(review): the priors below mirror those used by
        # BasePhotomultiplierLikelihood in this file (log_scale reuses the
        # log_noise prior) - confirm they are appropriate for this model.
        # toTorchParam(..., to_log=True) is presumably storing the log of the
        # given value - TODO confirm against preprocUtils.
        self.register_parameter(name="log_gain", 
                                parameter=preprocUtils.toTorchParam(gain if gain is not None else 1., ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-1, 8, sigma = 0.1))
        self.register_parameter(name="offset", 
                                parameter=preprocUtils.toTorchParam(offset if offset is not None else 0., ndims=1, to_log=False), 
                                prior=SmoothedBoxPrior(-200, 200, sigma = 5.0))
        self.register_parameter(name="log_scale", 
                                parameter=preprocUtils.toTorchParam(scale if scale is not None else 2., ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-5, 5, sigma = 0.1))
        
    def log_probability(self, latent_func, target):
        """
        Return the summed log-likelihood of ``target`` as a scalar tensor.

        latent_func: random variable exposing mean() and covar().
        target: tensor on the same device as the latent mean, 1d (one
            observation per latent point) or 2d (several per latent point).
        """
        input_device = latent_func.mean().device
        assert(input_device == target.device)

        # NOTE(review): this is covar @ mean, not a draw from the latent
        # distribution - confirm this is the intended point estimate.
        sample = latent_func.covar().matmul(latent_func.mean().unsqueeze(-1)).squeeze()
        res = torch.zeros((target.size(0),), device=input_device)

        # Log probability of every photon count 0..19 under the Poisson whose
        # rate is the exponentiated latent value.
        lambda_cur = Poisson(sample.exp())
        photon_counts = torch.arange(20., device=input_device).unsqueeze(1)
        photon_log_probs = lambda_cur.log_prob(photon_counts)

        # One Normal component per photon count.
        p_PM = torch.distributions.Normal(loc=(photon_counts*self.log_gain.exp()).squeeze()+self.offset, 
                                          scale=((photon_counts+1)*(self.log_scale.exp())).squeeze())

        # Sum the component log-probabilities over the observations that
        # belong to one latent point (works for 1d or 2d targets) ...
        log_prob_sum_over_target_samples = p_PM.log_prob(target.unsqueeze(-1)).unsqueeze(1).sum(-2).squeeze()
        # ... then marginalise over the photon count with a numerically stable logsumexp.
        cur_sample_res = (log_prob_sum_over_target_samples.permute(1,0) + photon_log_probs).logsumexp(dim=0)

        res += cur_sample_res

        return res.sum()
    
    def forward(self, latent_func, approx=True):
        r"""
        Computes predictive distributions p(y|x) given a latent distribution
        p(f|x). To do this, we solve the integral:
            p(y|x) = \int p(y|f)p(f|x) df
            
        As the true representation gets storage expensive for large latent_func dimensionalities, 
        we provide an approximate method by default in which we only store the true mean and variance for each output

        approx: when true (default) return a GaussianRandomVariable built from
            the per-output mixture mean and variance; otherwise return a
            BatchRandomVariable holding one mixture per output.
        """
        input_device = latent_func.mean().device
        
        sample = latent_func.covar().matmul(latent_func.mean().unsqueeze(-1)).squeeze()
        
        lambda_cur = Poisson(sample.exp()) # Exponential link function from the mean of the latent
        photon_counts = torch.arange(20., device=input_device).unsqueeze(1)
        photon_log_probs = lambda_cur.log_prob(photon_counts)
        # Ensure the discrete distribution sums to 1, as they are going to be mixture weights
        photon_log_probs -= photon_log_probs.logsumexp(0).unsqueeze(0)
        
        # Define the photomultiplier probability distribution as a weighted mixture over the transformed photon counts
        p_PM_each = [
            GaussianRandomVariable((photon_counts[i]*self.log_gain.exp())+self.offset, 
                                   (photon_counts[i]+1e-6)*self.log_scale.exp().diag())
            for i in range(photon_counts.numel())]

        if not approx: # True representation
            batch_rand_vars = [
                preprocRandomVariables.MixtureRandomVariableWithSampler(
                    *p_PM_each, 
                    weights=photon_log_probs.exp()[:,t]
                )
                for t in range(photon_log_probs.shape[1])
            ]

            return preprocRandomVariables.BatchRandomVariable(*batch_rand_vars)
        else:
            # Each target is a mixture random variable, let's approximate them as gaussian:    
            def tmp_f(x): 
                tmp = MixtureRandomVariable(*p_PM_each, weights=x)
                return torch.stack([tmp.mean().data, tmp.var().data]).to(x.device)
            all_moments = preprocUtils.apply(
                tmp_f,
                photon_log_probs.exp(),
                dim = 1
            ).squeeze()

            return GaussianRandomVariable(all_moments[0,:], gpytorch.lazy.DiagLazyVariable(all_moments[1,:]))
            
        
        #return MixtureRandomVariable(*p_PM_each, weights=photon_log_probs.exp())
    
# #     def forward(self, latent_func):
# #         """
# #         Computes predictive distributions p(y|x) given a latent distribution
# #         p(f|x). To do this, we solve the integral:
# #             p(y|x) = \int p(y|f)p(f|x) df
# #         """
        
# #         if not isinstance(latent_func, GaussianRandomVariable):
# #             raise RuntimeError(
# #                 "MyLikelihood expects a Gaussian distributed latent function to make predictions"
# #             )
        
# #         quadInt = QuadratureIntegratorTorch(30)
        
# #         return self.quadInt.integrate_discrete(
# #                     lambda num_photons : 
# #                         probGauss(x_elem, num_photons*self.params['gain_gauss'], self.params['sig_gauss'])*
# #                                 probPoisson(num_photons,self.params['mu_poiss'])
        

In [3]:
# Instantiate the prototype likelihood with all constructor arguments left at None.
likelihood = MyLikelihood()

In [4]:
# Synthetic smoke-test data: n latent points, m observations per latent point.
n = 50
m = 20
# The diagonal of a covariance holds variances, which must be non-negative,
# so it is drawn from U[0, 1) instead of a standard normal.
latent_func = GaussianRandomVariable(torch.randn(n,), gpytorch.lazy.DiagLazyVariable(torch.rand(n,)))
target = torch.randn(n, m)

In [None]:
# Smoke test: summed log-likelihood of the random targets under the random latent.
likelihood.log_probability(latent_func, target)

In [None]:
# Smoke test: calling the likelihood on the latent; presumably this dispatches to forward() - verify against the Likelihood base class.
likelihood(latent_func)

# Refactoring into multiple levels of completeness

In [5]:
from IPython.core.debugger import set_trace

In [6]:
def getVar(latent_func):
    """
    Return the variance of ``latent_func``.

    When the covariance has more than 2000 rows and is an
    InterpolatedLazyVariable, its ``_approx_diag()`` is returned instead of
    the variance reported by ``latent_func.var()``.
    """
    covariance = latent_func.covar()
    is_large = covariance.size(0) > 2000
    # The type check stays behind the size check, exactly as a short-circuit `and` would do.
    if is_large and isinstance(covariance, gpytorch.lazy.InterpolatedLazyVariable):
        return covariance._approx_diag()
    return latent_func.var()


In [7]:
import quadrature_integrator_torch
#reload(quadrature_integrator_torch)

In [8]:
import torch
import math
import gpytorch

from gpytorch.likelihoods import Likelihood, GaussianLikelihood
from torch.distributions import Poisson, Normal

from gpytorch.random_variables import GaussianRandomVariable, MixtureRandomVariable, RandomVariable
from gpytorch.priors import SmoothedBoxPrior

import preprocUtils
import preprocRandomVariables

from preprocUtils import toTorchParam

from quadrature_integrator_torch import QuadratureIntegratorTorch

In [9]:
class BasePhotomultiplierLikelihood(Likelihood):
    """
    Defines useful functions for various photomultiplier models.

    Registers the shared parameters (log_gain, offset, log_noise), owns a
    quadrature integrator, and implements log_probability() on top of the
    single_log_prob() hook that subclasses provide.
    """
    
    def __init__(self, 
                 gain=None, offset=None, noise=None,
                 gaussQuadratureDegree = int(10)):
        """
        gain, offset, noise: optional initial values; ``None`` selects the
            defaults 1.0, 0.0 and 1.0 respectively.
        gaussQuadratureDegree: passed to QuadratureIntegratorTorch.
        """
        super(BasePhotomultiplierLikelihood, self).__init__()

        #### Integrator used by log_probability() (and by subclasses)
        self.integrator = QuadratureIntegratorTorch(gaussQuadratureDegree)

        #### -----------------------------------
        #### Register photomultiplier parameters
        #### -----------------------------------

        self.register_parameter(name="log_gain", 
                                parameter=toTorchParam(gain if gain is not None else 1., ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-1, 8, sigma = 0.1))
        self.register_parameter(name="offset", 
                                parameter=toTorchParam(offset if offset is not None else 0., ndims=1, to_log=False), 
                                prior=SmoothedBoxPrior(-200, 200, sigma = 5.0))        
        self.register_parameter(name="log_noise", 
                                parameter=toTorchParam(noise if noise is not None else 1., ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-5, 5, sigma = 0.1))


    def single_log_prob(self, inp, cur_target):
        """ Required for log_probability, 
            takes a (N, ) vectors inp [concrete realisation of input] and cur_target and 
            return log p(cur_target | inp)
        """
        raise NotImplementedError

    def log_probability(self, latent_func, target):
        r"""
        Compute the expectation 
                E_f [log p(y|f) ] = \int (log p(y|f)) * p(f | mean, var) df

        The integral is evaluated with self.integrator for every column of
        ``target``; the per-column sums are then averaged over the columns.

        latent_func: random variable exposing mean() (and whatever getVar needs).
        target: 1d tensor, or 2d tensor with one column per repeated
            observation, on the same device as the latent mean.

        Raises ValueError when ``target`` has no columns.
        """
        latent_mean = latent_func.mean()
        input_device = latent_mean.device
        assert(input_device == target.device)
        
        if target.dim()==1:
            target = target.unsqueeze(1)

        n_columns = target.size(1)
        if n_columns == 0:
            # Without this guard the accumulator below would stay a plain int
            # and the final .div() would fail with an unrelated AttributeError.
            raise ValueError("target must contain at least one column of observations")
        
        # Standard deviation of the latent, handed to the integrator together with the mean
        latStd = getVar(latent_func).sqrt()

        # TODO - for some reason there is still a lot of memory leak in the integration process below (maybe due to gradient computations?)
        
        res = 0
        for i in range(n_columns):
            res += self.integrator.batch_integrate_gauss(
                lambda x: self.single_log_prob(x, target[:,i]),
                mu = latent_mean,
                sig = latStd
            ).sum()

            # Release cached GPU blocks between columns to limit peak memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        return res.div(n_columns)
        
    def forward(self, *inputs, **kwargs):
        """
        Compute the expectation
            p(y|x) = E_f [ p(y | f ) * p (f | x)]
        
        Computes a predictive distribution p(y*|x*) given either a posterior
        distribution p(f|D,x) or a prior distribution p(f|x) as input.
        With both exact inference and variational inference, the form of
        p(f|D,x) or p(f|x) should usually be Gaussian. As a result, input
        should usually be a GaussianRandomVariable specified by the mean and
        (co)variance of p(f|...).

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def getPhotonLogProbs(self, poissMeans, max_photon=30., reNormalise = False):
        """
        Given a tensor of Poisson means and a maximum photon count D = max_photon,
        return 
         - the counts 0..D-1 as a tensor viewed as (D, 1, 1), and
         - the Poisson log probabilities of those counts, i.e. the counts
           broadcast against poissMeans (photon count along dimension 0).

        With reNormalise=True the log probabilities are shifted so that, for
        every mean, the probabilities of the represented counts sum to 1.
        """

        # Check if we do not need to represent counts up to max_photon, or we'd need more:
        # TODO - check poissMeans.max() and its CDF value at max_photon, 
            # if too low - warn for more, if too high, just cut from max_photon
            
        lambda_cur = Poisson(poissMeans)
        photon_counts = torch.arange(max_photon, device=poissMeans.device).view(-1,1,1)
        photon_log_probs = lambda_cur.log_prob(photon_counts)

        if reNormalise:
            photon_log_probs -= photon_log_probs.logsumexp(0).unsqueeze(0)

        return photon_counts, photon_log_probs
    
    
    

In [10]:
class LinearGainLikelihood(BasePhotomultiplierLikelihood):
    """
    Linear-gain photomultiplier model with a Gaussian observation density.

    The latent value is the logarithm of a continuous (!) incoming signal f.
    The signal is multiplied by the gain g = log_gain.exp() and shifted by an
    output-offset:

    p(y | f) = N_y ( g*f + offset,  g^2 * f^2 + sigma_y^2 )

        # noise = sigma_y^2
    """
    def __init__(self, gain=None, offset=None, noise=None): 
        super(LinearGainLikelihood, self).__init__(gain = gain, offset = offset, noise=noise)

    def forward(self, latent_func):
        """
        Return a GaussianRandomVariable with a diagonal covariance describing
        the predicted observation for each latent point.
        """
        latent_mean = latent_func.mean()

        pred_mean = (latent_mean + self.log_gain).exp() + self.offset

        # g^2 * (mean^2 + 2 * var^2), written on the exponentiated latent moments.
        # NOTE(review): log_noise does not enter this variance although it
        # does enter single_log_prob - confirm this is intended.
        gain_squared = self.log_gain.exp().pow(2)
        mean_term = latent_mean.exp().pow(2)
        var_term = 2*(getVar(latent_func).exp().pow(2))
        pred_var = gain_squared * (mean_term + var_term)

        return GaussianRandomVariable(pred_mean, gpytorch.lazy.DiagLazyVariable(pred_var))

    def single_log_prob(self, inp, cur_target):
        """
        Gaussian log-density of cur_target given the latent realisation inp.

        inp comes straight from the latent function, so it is exponentiated
        (together with log_gain) before it is used as a signal level.
        """
        amplified = (inp + self.log_gain).exp()
        prob_var = self.log_noise.exp() + amplified**2

        residual = cur_target - (amplified + self.offset)
        quadratic = -0.5 * (residual ** 2) / prob_var
        log_normaliser = -0.5 * prob_var.log() - 0.5 * math.log(2 * math.pi)

        return quadratic + log_normaliser


In [11]:
class PoissonInputPhotomultiplierLikelihood(BasePhotomultiplierLikelihood):
    r"""
    Assumes that the incoming discrete number gets linearly multiplied by log_gain.exp(),
    and the observed value is a Normal distribution with a given input-offset, 
    whose mean and variance comes from the observed count.
    
    There is also an explicit "pedestal", which models the distribution around 0 counts
    
    p(f1 | f) = Poisson(f)
    p(y | f1) = 
        N_y ( g*f1 + offset,  g^2 * sigma_y^2 ) if f1 > 0
        N_y (  offset , sigma_y_0^2)            if f1 == 0
    
        # noise = sigma_y^2
        # noise_pedestal = sigma_y_0^2

    NOTE(review): single_log_prob uses a standard deviation of k*sqrt(noise)
    for k photons whereas forward uses a variance of k*noise; neither matches
    the formula above exactly - confirm which scaling is intended.
    """
    def __init__(self, gain=None, offset=None, noise=None, noise_pedestal=None): 
        super(PoissonInputPhotomultiplierLikelihood, self).__init__(gain = gain, offset = offset, noise=noise)
        
        self.register_parameter(name="log_noise_pedestal", 
                                parameter=toTorchParam(noise_pedestal if noise_pedestal is not None else 5., ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-5, 15, sigma = 0.1))
    
    
    def getLogProbSumOverTargetSamples(self, p_PM, cur_target_slice):
            """ 
            Log probability of every target under every component of p_PM.

            As we are supposedly dealing with truncated normals, targets below 1
            get the log CDF of each component evaluated at 1 instead of the
            log density at the target.
            """
            allLogProbs = p_PM.log_prob(cur_target_slice.view(-1,1))
            
            # Correct for the less than 1 observations with log CDF instead of log_prob
            # NOTE(review): a torch Normal with argument validation enabled
            # only accepts tensors in cdf() - confirm before upgrading torch.
            below_one = cur_target_slice<1.
            if below_one.sum()>0:
                allLogProbs[below_one, :] = p_PM.cdf(1.).log().squeeze()
            
            # Sum over target samples
            return allLogProbs.unsqueeze(1).sum(-2).squeeze() # This should work for 1d or 2d targets
    
    
    def single_log_prob(self, inp, cur_target, batchsize = int(500), max_photon=25.):
        """
        Input is directly from latent_func, means it still needs to be exponentiated before multiplying.
        Input is the log mean of a Poisson distribution
        
        Input shape is [QuadratureWeights, N]
        cur_target is [N,]

        Returns a tensor with the same shape as inp.
        """
        
        # Have to do this in mini-batches as the resulting QuadWeights x MaxPhotons x N array is too big
        all_res = torch.zeros_like(inp)
        
        # Only the counts are needed here; the narrow slice keeps the call cheap
        photon_counts, _ = (
                self.getPhotonLogProbs(inp[:,:2].exp(), max_photon=float(max_photon), reNormalise = False))

        # Build the per-count scale completely (including the pedestal noise of
        # the zero-photon component) BEFORE constructing the Normal, so the
        # distribution is never created with a zero scale.
        pm_loc = (photon_counts*self.log_gain.exp()).squeeze()+self.offset
        pm_scale = (photon_counts*(self.log_noise.mul(0.5).exp())).squeeze()
        pm_scale[0] += self.log_noise_pedestal.mul(0.5).exp().squeeze()

        p_PM = torch.distributions.Normal(loc=pm_loc, scale=pm_scale)
        
        n_inputs = int(inp.size(1))
        # Stepping by batchsize never yields an empty trailing batch
        for start in range(0, n_inputs, batchsize):
            cur_slice = slice(start, min(start+batchsize, n_inputs))
            inp_cur = inp[:,cur_slice]
            if inp_cur.ndimension()==1:
                inp_cur = inp_cur.unsqueeze(1)
            
            _, photon_log_probs = (
                self.getPhotonLogProbs(inp_cur.exp(), max_photon=float(max_photon), reNormalise = False))

            log_prob_sum_over_target_samples = self.getLogProbSumOverTargetSamples(p_PM, cur_target[cur_slice])
            if log_prob_sum_over_target_samples.ndimension()==1: # Correct for if we only get a single input value
                log_prob_sum_over_target_samples = log_prob_sum_over_target_samples.unsqueeze(0)

            # Marginalise over the photon count (dimension 0) in log space
            all_res[:,cur_slice] += (log_prob_sum_over_target_samples.permute(1,0).unsqueeze(1) + photon_log_probs).logsumexp(dim=0)
            
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        return all_res # Same dimensionality as inp [QuadratureWeights, N]
    
      
    def forward(self, latent_func, approx=True, max_photon=int(20), batchsize = int(1000)):
        r"""
        Computes predictive distributions p(y|x) given a latent distribution
        p(f|x). To do this, we solve the integral:
            p(y|x) = \int p(y|f)p(f|x) df
            
        As the true representation gets storage expensive for large latent_func dimensionalities, 
        we provide an approximate method by default in which we only store the true mean and variance for each output

        approx: ``False`` returns a BatchRandomVariable of mixtures; any other
            value returns a GaussianRandomVariable with the mixture moments.
        max_photon: number of photon counts (0..max_photon-1) represented.
        batchsize: number of latent points integrated at once (approx mode).
        """
        input_device = latent_func.mean().device
        
        if approx is False: ## Return the full Batch of MixtureRandomVariables (huge object, use only for few predictions)
            
            # Get the input standard deviation
            latStd = getVar(latent_func).sqrt()
            
            # Get the photon log probabilities
            photon_log_probs = self.integrator.batch_integrate_gauss(
                    lambda x: self.getPhotonLogProbs(x.exp(), max_photon=float(max_photon), reNormalise = False)[1].permute(1,0,2),
                    mu = latent_func.mean(),
                    sig = latStd,
                    viewAs = [-1, 1, 1]
               )
            
            # As these probabilities are a result of non-perfect integration, we need to renormalise them outside of the integration
            photon_log_probs -= photon_log_probs.logsumexp(0).view(1,-1)
            
            # For very low probabilities, the renormalisation may not have been perfect, renormalise once more in probability space
            photon_probs = photon_log_probs.exp()
            photon_probs = photon_probs.div(photon_probs.sum(0).unsqueeze(0))
            
            # Underlying mixture elements
            p_PM_each = [GaussianRandomVariable(self.offset, self.log_noise_pedestal.exp().diag())]
            p_PM_each.extend([
                GaussianRandomVariable(i*self.log_gain.exp()+self.offset, 
                                   i*self.log_noise.exp().diag())
                for i in range(1, max_photon)])
            
            return preprocRandomVariables.BatchRandomVariable(
                *[preprocRandomVariables.MixtureRandomVariableWithSampler(*p_PM_each, weights=photon_probs[:,i])
                  for i in range(photon_probs.size(1))]
                 )
            
        else: # We are just returning the predictive mean and variance
        
            # Get the input standard deviation
            latStd = getVar(latent_func).sqrt()

            # Only the counts are used from this call
            photon_counts, _ = (
                    self.getPhotonLogProbs(latent_func.mean().exp(), max_photon=float(max_photon), reNormalise = False))

            predVariancesPerPhoton = (photon_counts*(self.log_noise.exp())).squeeze()
            predVariancesPerPhoton[0] += self.log_noise_pedestal.exp().squeeze()

            # Get the photon probabilities for all latent inputs
            allPhotonLogProbs = torch.zeros([latent_func.mean().size(0), max_photon], device=input_device)

            # Store the resulting moments
            pred_moments = torch.zeros([latent_func.mean().size(0), 2], device=input_device)

            n_latent = int(latStd.size(0))
            # Do minibatch integration to get expected probabilities by summing samples of log probabilities.
            # Stepping by batchsize never yields an empty trailing batch.
            for start in range(0, n_latent, batchsize):
                cur_slice = slice(start, min(start+batchsize, n_latent))

                tmp=self.integrator.batch_integrate_gauss(
                    lambda x: self.getPhotonLogProbs(x.exp(), max_photon=float(max_photon), reNormalise = False)[1].permute(1,0,2),
                    mu = latent_func.mean()[cur_slice],
                    sig = latStd[cur_slice],
                    viewAs = [-1, 1, 1]
               )

                # NOTE(review): unlike the approx=False branch these weights are
                # not renormalised after integration - confirm this is intended.
                allPhotonLogProbs[cur_slice,:] = tmp.permute(1,0)

                # Set the predictive means
                pred_moments[cur_slice, 0] = allPhotonLogProbs[cur_slice,:].exp().matmul((
                    photon_counts*self.log_gain.exp()+self.offset).squeeze().unsqueeze(1)).squeeze()

                # Set the predictive variances:
                # first the weighted squared distance of every component mean from the global mean ...
                pred_moments[cur_slice, 1] = (allPhotonLogProbs[cur_slice,:].exp().mul(
                    ((photon_counts*self.log_gain.exp()+self.offset).view(1,-1) - pred_moments[cur_slice, 0].view(-1,1)).pow(2)
                        ).sum(1).squeeze())

                # ... then the weighted individual component variances
                pred_moments[cur_slice, 1] += (allPhotonLogProbs[cur_slice,:].exp()
                                               .matmul(predVariancesPerPhoton.view(-1,1)).squeeze())

            return GaussianRandomVariable(pred_moments[:,0], gpytorch.lazy.DiagLazyVariable(pred_moments[:,1]))
        
        #set_trace()

In [12]:
# Plotly 
import plotly
from plotly.offline import iplot as plt
from plotly import graph_objs as plt_type
plotly.offline.init_notebook_mode(connected=True)

import nbimporter
from preprocVisualisationTesting import *

Importing Jupyter notebook from preprocVisualisationTesting.ipynb


In [13]:
import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all

import warnings

def erfcx(x):
    """
    Scaled complementary error function, erfcx(x) = exp(x**2) * erfc(x).

    x: tensor; the result has the same shape.

    Two evaluations are combined element-wise:
      * x < 5:  the definition exp(x**2) * erfc(x), which is accurate there
        and does not overflow for moderate x;
      * x >= 5: the asymptotic series in 1/x**2, which stays finite where
        exp(x**2) overflows and erfc(x) underflows.
    Series taken from
    https://stackoverflow.com/questions/8962542/is-there-a-scaled-complementary-error-function-in-python-available
    """
    switch = 5.

    # Each branch only ever sees arguments from its own region, so the branch
    # that is discarded by torch.where cannot produce inf/NaN.
    x_direct = x.clamp(max=switch)
    direct = x_direct.erfc().mul(x_direct.pow(2.).exp())

    x_series = x.clamp(min=switch)
    y = 1. / x_series
    z = y * y
    s = y*(1.+z*(-0.5+z*(0.75+z*(-1.875+z*(6.5625-29.53125*z)))))
    series = s * 0.564189583547756287  # 1/sqrt(pi)

    return torch.where(x < switch, direct, series)


class ExponentiallyModifiedGaussian(torch.distributions.Distribution):
    """
    Exponentially modified Gaussian with location ``loc``, Gaussian width
    ``scale`` and exponential amplitude ``expamplitude``.

    R. Dossi et al. / Nuclear Instruments and Methods in Physics Research A 451 (2000) 623-637
    Equation 9, with pE = 1.
    
    There's an error in the erf part which should read (x-xp) instead of just xp
    """
    # Declared so that Distribution._validate_sample has a support to check against
    support = constraints.real

    def __init__(self, loc, scale, expamplitude, validate_args=None):
        self.loc, self.scale, self.expamplitude = broadcast_all(loc, scale, expamplitude)
        # broadcast_all always hands back tensors of one common shape, so that
        # shape is the batch shape even when loc itself was a plain number.
        batch_shape = self.loc.size()
        super(ExponentiallyModifiedGaussian, self).__init__(batch_shape, validate_args=validate_args)
        
        # NOTE(review): the degree is passed as a float here but as an int
        # elsewhere in this file - confirm what QuadratureIntegratorTorch expects.
        self.integrator = QuadratureIntegratorTorch(40.)
        
    def log_prob(self, x):
        """
        Log density at ``x``.

        x is flattened and given one trailing singleton dimension per
        parameter dimension, so the result has shape [x.numel()] + batch shape.

        For numerical stability we use multiple implementations as suggested on 
        https://en.wikipedia.org/wiki/Exponentially_modified_Gaussian_distribution
        and pick between them element-wise based on z.
        """
        # as_tensor keeps an incoming tensor attached to its autograd graph
        # (and avoids the copy), while still accepting plain numbers.
        x = torch.as_tensor(x).reshape([-1]+[1]*self.loc.ndimension())
        if self._validate_args:
            self._validate_sample(x)

        sigma = self.scale
        mu = self.loc
        tau = self.expamplitude
        var = sigma.pow(2.)
        
        z = ((sigma/tau)-(x-mu).div(sigma)).mul(math.sqrt(0.5))
        
        # log of the common prefactor 1/(2*tau)
        log_norm = (2.*tau).reciprocal().log()

        # model 1: direct form, erfc keeps precision where erf(z) is close to 1
        exp_part = (var-2.*tau*(x-mu)).div(2.*tau.pow(2))
        ret1 = log_norm + exp_part + z.erfc().log()
        
        # model 2: same density rewritten with the scaled complementary error function
        exp_part_approx = (x-mu).pow(2).div(-2.*var)
        ret2 = log_norm + exp_part_approx + erfcx(z).log()
        
        # model 3: leading term of erfcx for very large z; includes the
        # Gaussian normaliser sigma*sqrt(2*pi) that follows from
        # erfcx(z) ~ 1/(z*sqrt(pi)) applied to model 2.
        div_part = 1.-(x-mu).mul(tau).div(var)
        ret3 = exp_part_approx - div_part.log() - (sigma*math.sqrt(2.*math.pi)).log()
        
        # z < 3 (found experimentally): model 1; z < 8191 (based on
        # single-precision float): model 2; otherwise model 3.
        return torch.where(z < 3., ret1, torch.where(z < 8191., ret2, ret3))
            
    def cdf(self, x):
        """
        CDF at the plain number ``x``, by numerically integrating the density
        from ``loc.min() - 5*scale.max()`` up to ``x``.

        Raises NotImplementedError when x is not a number or when the
        parameters are not 0-dimensional.
        """
        if not isinstance(x, Number) or self.loc.ndimension()>0:
            raise NotImplementedError
        return self.integrator.integrate(
            lambda tmp: self.log_prob(tmp).exp(), 
            a = (self.loc.min()-5.*self.scale.max()),
            b = x,
        )
#         return self.integrator.batch_integrate(
#             lambda tmp: self.log_prob(tmp).exp(), 
#             a = -10.,
#             b = 10.,
# #             a = (self.loc.min()-2*self.scale.max()),
# #             b = x,
#             viewAs = [-1]+[1]*self.loc.ndimension())

In [14]:
# Sanity check with scalar (0-dimensional) parameters.
a = ExponentiallyModifiedGaussian(loc = 0., 
                                  scale = torch.tensor(1.).sqrt(), 
                                  expamplitude=0.7)

# Evaluate the density on a regular grid and plot it, clamped to [0, 1].
# `plot` is presumably provided by the star import from preprocVisualisationTesting - verify.
inp = torch.arange(-20., 30., 0.01)
plot(a.log_prob(inp).exp().clamp(0.,1.).view(-1,1), inp.view(-1))

# cdf() only accepts a plain number together with 0-dimensional parameters.
print(a.cdf(1.))

tensor(0.6192)


In [15]:
# Same check with batched (3 x 4) parameters.
a = ExponentiallyModifiedGaussian(loc = 0.*torch.ones(3,4), 
                                  scale = 2.3*torch.ones(3,4), 
                                  expamplitude=1.4*torch.ones(3,4))

# The grid is given two trailing singleton dimensions so it broadcasts against
# the parameters; only batch element [1, 1] is plotted.
inp = torch.arange(-20., 30., 0.01)
plot(a.log_prob(inp.view(-1,1,1))[:,1,1].exp().clamp(0.,1.).view(-1,1), inp.view(-1))

In [16]:
class SingleElectronResponse(torch.distributions.Distribution):
    """
        Two-component mixture used for the single photo-electron response:

            p(x) = pE * ExpModGauss(x) + (1 - pE) * Normal(x)

        Args:
            loc, scale: parameters of the Normal component.
            pedestal_loc, pedestal_scale: ``loc`` / ``scale`` of the
                ExponentiallyModifiedGaussian component.
            underamplified_amplitude: ``expamplitude`` of the
                ExponentiallyModifiedGaussian component.
            underamplified_probability: mixing weight ``pE`` of the
                ExponentiallyModifiedGaussian component (float or tensor).
    """
    def __init__(self, 
                 loc, scale, 
                 pedestal_loc, pedestal_scale, 
                 underamplified_amplitude, underamplified_probability):
        self.ExpModGauss = ExponentiallyModifiedGaussian(loc=pedestal_loc, 
                                                         scale=pedestal_scale, 
                                                         expamplitude=underamplified_amplitude)
        self.Normal = torch.distributions.Normal(loc = loc, 
                                                 scale=scale)
        self.underamplified_probability = underamplified_probability

        # Initialise the Distribution base class so that batch_shape / event_shape
        # exist; the reported batch shape is that of the Normal component.
        # Argument validation is disabled because no arg_constraints are declared.
        super(SingleElectronResponse, self).__init__(
            batch_shape=self.Normal.batch_shape, validate_args=False)
        
    def log_prob(self, x):
        """
            Log-density of the mixture at ``x``.

            The two weighted components are combined in log space with
            logsumexp, so the result stays finite where both component
            densities would underflow to zero in probability space.
        """
        log_p_under = self.ExpModGauss.log_prob(x)
        log_p_normal = self.Normal.log_prob(x)

        # Accept both Python floats and tensors as mixing weight; an existing
        # tensor of matching dtype/device is passed through unchanged.
        p_under = torch.as_tensor(self.underamplified_probability,
                                  dtype=log_p_normal.dtype,
                                  device=log_p_normal.device)

        weighted = torch.broadcast_tensors(
            p_under.log() + log_p_under,
            torch.log1p(-p_under) + log_p_normal,
        )
        return torch.logsumexp(torch.stack(weighted, dim=0), dim=0)
    
    def cdf(self, x):
        """Mixture CDF at ``x``: the weighted sum of the two component CDFs."""
        return (
            self.underamplified_probability * self.ExpModGauss.cdf(x) +
            (1.-self.underamplified_probability) * self.Normal.cdf(x)
        )

In [17]:
class SingleElectronResponseRandomVariable(gpytorch.random_variables.RandomVariable):
    """
        Random-variable wrapper around a SingleElectronResponse distribution.

        ``mean`` and ``var`` use the moments of an Exponential(A) / Normal(gain, s0)
        mixture with weight pE, i.e. they ignore the convolution of the
        exponential part with the pedestal noise.
    """
    def __init__(self, 
                 loc, scale, 
                 pedestal_loc, pedestal_scale, 
                 underamplified_amplitude, underamplified_probability):
        self.distribution = SingleElectronResponse(loc, scale, 
                 pedestal_loc, pedestal_scale, 
                 underamplified_amplitude, underamplified_probability)                   
     
    def representation(self):
        """Return the wrapped SingleElectronResponse distribution."""
        return self.distribution
        
    def mean(self):
        """Mixture mean: pE * A + (1 - pE) * gain."""
        # as_tensor lets a plain Python float be used as probability
        pE = torch.as_tensor(self.distribution.underamplified_probability).squeeze()
        A = self.distribution.ExpModGauss.expamplitude.squeeze()
        gain = self.distribution.Normal.loc.squeeze()
        
        return (pE*A + (1-pE)*gain)                   
        
    
    def var(self):
        """Mixture variance: weighted second moments minus the squared mean."""
        pE = torch.as_tensor(self.distribution.underamplified_probability).squeeze()
        A = self.distribution.ExpModGauss.expamplitude.squeeze()
        gain = self.distribution.Normal.loc.squeeze()
        s0 = self.distribution.Normal.scale.squeeze()
        
        # Second moment of the exponential is 2*A^2, of the normal gain^2 + s0^2
        return (pE*2.*A.pow(2.) + (1.-pE)*(gain.pow(2.)+s0.pow(2.)) - self.mean().pow(2.))
        
    def sample(self, n_samples=1, n_categories=int(100)):
        """
            Approximates the distribution as a categorical binned distribution.

            ``n_categories`` equally spaced bin centres are laid out between
            the lower edge of the pedestal / Normal components and the upper
            edge of the Normal component. A bin is drawn from a categorical
            distribution over the bin probabilities, and the returned value is
            uniformly jittered within that bin.

            Returns a tensor of shape ``(n_samples,)``.
        """
        dist = self.distribution
        
        lower = float(min((dist.ExpModGauss.loc.min()-4.*dist.ExpModGauss.scale.max()),
                          dist.Normal.loc.min()-2.*dist.Normal.scale.max()))
        upper = float(dist.Normal.loc.max()+2.*dist.Normal.scale.max())
        bin_mids = torch.linspace(lower, upper, n_categories).to(dist.Normal.loc.device)
               
        bin_size = bin_mids[1] - bin_mids[0]
               
        weights = torch.zeros_like(bin_mids)
        # Interior bins: density * bin width gives a probability mass, which is
        # the same unit as the CDF-based tail masses below.
        weights[1:-1] = dist.log_prob(bin_mids[1:-1]).exp() * bin_size
        # Outermost bins collect the mass beyond the outermost bin centres.
        weights[0] = dist.cdf(float(bin_mids[0]))
        weights[-1] = 1.-dist.cdf(float(bin_mids[-1]))
        # The CDF comes from numerical integration; guard against tiny negative masses.
        weights = weights.clamp(min=0.)
               
        # Sample the bin index from a categorical distribution
        sample_ids = torch.distributions.categorical.Categorical(probs=weights).sample((n_samples,))
               
        # Uniform jitter within the selected bin, drawn for all samples at once
        jitter = (torch.rand(sample_ids.shape, device=weights.device) - 0.5) * bin_size

        return bin_mids[sample_ids] + jitter
        

In [18]:
# Compare samples drawn from the binned sampler against the analytic density.
# Positional arguments: loc, scale, pedestal_loc, pedestal_scale,
# underamplified_amplitude, underamplified_probability.
a = SingleElectronResponseRandomVariable(
        3., 1.5, 0., 0.2, 0.3, 0.4)

# Grid on which the analytic density is evaluated
inp = torch.arange(-2., 8., 0.01)

lik_sample = a.sample(n_samples=100000, n_categories=700)
# Histogram of the samples with 100 bins
hist, bins = np.histogram(lik_sample.detach().cpu(), 100)
# Bars: normalised histogram counts scaled by 10 (presumably to match the
# density scale for this bin width - TODO confirm); line: density clamped to [0, 1].
plt([plt_type.Bar(y=hist/hist.sum()*10, x=bins),
    plt_type.Scatter(y=a.distribution.log_prob(inp).exp().clamp(0.,1.).view(-1), x=inp.view(-1))])

In [19]:
lik_hist = np.histogram(lik_sample.detach().cpu(), bins_extended)[0]/n_samples*train_y.size(1)

NameError: name 'bins_extended' is not defined

In [20]:
# Plot the SingleElectronResponse density on a grid.
# Positional arguments: loc, scale, pedestal_loc, pedestal_scale,
# underamplified_amplitude, underamplified_probability.
a = SingleElectronResponse(
        3., 1.5, 0., 0.2, 0.3, 0.4)

# Evaluation grid from -10 (inclusive) to 10 (exclusive) in steps of 0.01
inp = torch.arange(-10., 10., 0.01)
# exp() converts the log-density to a density; clamp limits it to [0, 1] for plotting
plot(a.log_prob(inp).exp().clamp(0.,1.).view(-1,1), inp.view(-1))

In [50]:
plot(logit(logistic(inp).view(-1,1)), inp)

In [21]:
# R. Dossi et al., Nuclear Instruments and Methods in Physics Research A 451 (2000) 623-637
class PoissonInputUnderamplifiedPhotomultiplierLikelihood(BasePhotomultiplierLikelihood):
    r"""
    Assumes that the incoming discrete number gets linearly multiplied by log_gain.exp(),
    and the observed value is a Normal distribution with a given input-offset, 
    whose mean and variance comes from the observed count.
    
    There is also an explicit "pedestal", which models the distribution around 0 counts.
    
    Furthermore the Single PhotoElectron Response (SER) has an underamplified portion described as an exponential.
    For f1 >= 2, the multiple photoelectron response is approximated as a gaussian with the linearly scaled
        mean and variance of the SER.
    
    p(f1 | f) = Poisson(f)
    p(y -offset | f1) = 
        {
            N_y (  0 , sigma_y_0)                       if f1 == 0
            pE * Exp_y(A) + (1-pE)*N_y (gain, s0^2 )    if f1 == 1
            N_y ( f1*x1,  f1^2 * s1^2 )                 if f1 >= 2,
                
                    where x1 and s1^2 are the moments of p(y-offset|f1==1):
                        x1   \approx (1-pE)*gain + pE*A
                        s1^2 \approx (1-pE)*(gain^2+s0^2) + pE*2*A^2 - x1^2
    
        # noise = s0^2
        # noise_pedestal = sigma_y_0
        # underamplified_probability = pE
        # underamplified_amplitude = A
    """
    def __init__(self, gain=None, offset=None, noise=None, noise_pedestal=None, 
                 underamplified_probability=None, underamplified_amplitude=None): 
        super(PoissonInputUnderamplifiedPhotomultiplierLikelihood, self).__init__(
            gain = gain, offset = offset, noise=noise)
        
        # Pedestal noise, stored as a log parameter (default value 5.)
        self.register_parameter(name="log_noise_pedestal", 
                                parameter=toTorchParam(noise_pedestal if noise_pedestal is not None else 5., ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-5, 5, sigma = 0.1))
        
        # Underamplification probability, stored as a logit (default logit -2.)
        self.register_parameter(name="logit_underamplified_probability", 
                                parameter=toTorchParam(
                                    logit(underamplified_probability)
                                        if underamplified_probability is not None 
                                        else -2., ndims=1), 
                                prior=SmoothedBoxPrior(-6, 0, sigma = 0.01)) # between 0 and 0.5
        
        # Amplitude of the exponential part, stored as a log parameter
        # (defaults to half of the current gain)
        self.register_parameter(name="log_underamplified_amplitude", 
                                parameter=toTorchParam(
                                    underamplified_amplitude 
                                        if underamplified_amplitude is not None 
                                        else self.log_gain.clone().exp().div(2), ndims=1, to_log=True), 
                                prior=SmoothedBoxPrior(-5, 5, sigma = 0.1))
    
    
    def getLogProbSumOverTargetSamples(self, p_PM, cur_target_slice):
        """ 
        Log-probability of every target value under every photon-count component.

        As we are supposedly dealing with truncated normals and expmod-normals,
        the log probabilities of targets below 1 are replaced with the log CDF
        at 1 instead of the log_prob.
        
        p_PM is going to be a list here, with:
          p_PM[0] = noise distribution
          p_PM[1] = Single photon (underamplified) distribution
          p_PM[2] = Multi-photon distribution with mean and var linearly amplified from SER

        The columns of the result are ordered as [p_PM[0], p_PM[1], *p_PM[2]].
        """
        targets = cur_target_slice.view(-1,1)

        allLogProbs = torch.cat(
            [p_PM[0].log_prob(targets).view(-1,1),
             p_PM[1].log_prob(targets).view(-1,1),
             p_PM[2].log_prob(targets)],
            dim = 1)
        
        # Correct for the less than 1 observations with log CDF instead of log_prob.
        # The CDF values are probabilities, so they have to be log-transformed
        # before being written into the log-probability tensor.
        below_threshold = cur_target_slice<1.
        if below_threshold.sum()>0:       
            allLogProbs[below_threshold, :] = torch.cat(
                [p_PM[0].cdf(1.).view(-1),
                 p_PM[1].cdf(1.).view(-1),
                 p_PM[2].cdf(1.).view(-1)],
                dim = 0).log()
        
        # Sum over target samples
        return allLogProbs.unsqueeze(1).sum(-2).squeeze() # This should work for 1d or 2d targets
    
        
    def single_log_prob(self, inp, cur_target, batchsize = int(500), max_photon=15.):
        """
        Input is directly from latent_func, means it still needs to be exponentiated before multiplying.
        Input is the log mean of a Poisson distribution
        
        Input shape is [QuadratureWeights, N]
        cur_target is [N,]

        Returns a tensor with the same shape as inp.
        """
        
        # Have to do this in mini-batches as the resulting QuadWeights x MaxPhotons x N array is too big
        all_res = torch.zeros_like(inp)
        
        # Only the photon counts are needed here; a two-column slice keeps this call cheap
        photon_counts, _ = (
                self.getPhotonLogProbs(inp[:,:2].exp(), max_photon=float(max_photon), reNormalise = False))

        
        # p_PM is going to be a list here, with:
        #   p_PM[0] = noise distribution
        #   p_PM[1] = Single photon (underamplified) distribution
        #   p_PM[2] = Multi-photon distribution with mean and var linearly amplified from SER
        
        pE = logistic(self.logit_underamplified_probability).squeeze()
        A = self.log_underamplified_amplitude.exp().squeeze()
        gain = self.log_gain.exp().squeeze()
        s0 = self.log_noise.mul(0.5).exp()
        
        
        p_PM = list()
        p_PM.append(torch.distributions.Normal(loc=self.offset.squeeze(), 
                                               scale = self.log_noise_pedestal.mul(0.5).exp().squeeze()))
        p_PM.append(SingleElectronResponse(loc=(gain+self.offset).squeeze(), 
                                           scale=s0.squeeze(),
                                           pedestal_loc = self.offset.squeeze(), 
                                           pedestal_scale = self.log_noise_pedestal.mul(0.5).exp().squeeze(),
                                           underamplified_amplitude = A,
                                           underamplified_probability = pE
                                          )
                   )
                    
        
        # For higher photon counts we ignore convolution with the noise and compute moments of idealised SER0
        multiphoton_loc_base = pE*A + (1-pE)*gain                         
        # second moment of Exp and Normal - mean^2
        # (squares must use pow: `^` is bitwise XOR in Python, not exponentiation)
        multiphoton_scale_base = (pE*2.*A.pow(2.) + (1.-pE)*(gain.pow(2.)+s0.pow(2.)) - multiphoton_loc_base.pow(2.)).sqrt()
        
        p_PM.append(torch.distributions.Normal(loc=(photon_counts[2:]*multiphoton_loc_base+self.offset).squeeze(), 
                                               scale=(photon_counts[2:]*multiphoton_scale_base).squeeze())
                                         )

        
        n_points = int(inp.size(1))
        # Stepping by batchsize never produces an empty trailing batch
        for start in range(0, n_points, batchsize):
            cur_slice = slice(start, min(start+batchsize, n_points))
            inp_cur = inp[:,cur_slice]
            if inp_cur.ndimension()==1:
                inp_cur = inp_cur.unsqueeze(1)
            
            photon_counts, photon_log_probs = (
                self.getPhotonLogProbs(inp_cur.exp(), max_photon=float(max_photon), reNormalise = False))

            log_prob_sum_over_target_samples = self.getLogProbSumOverTargetSamples(p_PM, cur_target[cur_slice])
            if log_prob_sum_over_target_samples.ndimension()==1: # Correct for if we only get a single input value
                log_prob_sum_over_target_samples = log_prob_sum_over_target_samples.unsqueeze(0)

            # Marginalise over the photon count (dim 0) in log space
            all_res[:,cur_slice] += (log_prob_sum_over_target_samples.permute(1,0).unsqueeze(1) + photon_log_probs).logsumexp(dim=0)
            
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        return all_res # Same dimensionality as inp [QuadratureWeights, N]
    
    
    
    def forward(self, latent_func, approx=False, max_photon=int(20), batchsize = int(1000)):
        r"""
        Computes predictive distributions p(y|x) given a latent distribution
        p(f|x). To do this, we solve the integral:
            p(y|x) = \int p(y|f)p(f|x) df
            
        With approx=False (the default) one mixture random variable per latent
        dimension is returned, which gets storage expensive for large
        latent_func dimensionalities. The approximate variant (approx=True),
        meant to only keep the predictive mean and variance, is not
        implemented and raises NotImplementedError.

        batchsize is accepted but not used by this method.
        """
        
        if approx is False: ## Return the full Batch of MixtureRandomVariables (huge object, use only for few predictions)
            
            # Get the input standard deviation
            latStd = getVar(latent_func).sqrt()
            
            # Get the photon log probabilities
            photon_log_probs = self.integrator.batch_integrate_gauss(
                    lambda x: self.getPhotonLogProbs(x.exp(), max_photon=float(max_photon), reNormalise = False)[1].permute(1,0,2),
                    mu = latent_func.mean(),
                    sig = latStd,
                    viewAs = [-1, 1, 1]
               )
            
            # As these probabilities are a result of non-perfect integration, we need to renormalise them outside of the integration
            photon_log_probs -= photon_log_probs.logsumexp(0).view(1,-1)
            
            # For very low probabilities, the renomalisation may not have been perfect, renormalise once more in probability space
            photon_probs = photon_log_probs.exp()
            photon_probs = photon_probs.div(photon_probs.sum(0).unsqueeze(0))
            
            pE = logistic(self.logit_underamplified_probability).squeeze()
            A = self.log_underamplified_amplitude.exp().squeeze()
            gain = self.log_gain.exp().squeeze()
            s0 = self.log_noise.mul(0.5).exp()
            
            # Underlying mixture elements: pedestal (0 photons) first
            p_PM_each = [GaussianRandomVariable(self.offset, self.log_noise_pedestal.exp().diag())]
            
            # Single photo-electron response (1 photon)
            p_PM_each.append(SingleElectronResponseRandomVariable(loc=(gain+self.offset).squeeze(), 
                                           scale=s0.squeeze(),
                                           pedestal_loc = self.offset.squeeze(), 
                                           pedestal_scale = self.log_noise_pedestal.mul(0.5).exp().squeeze(),
                                           underamplified_amplitude = A,
                                           underamplified_probability = pE
                                          )
                   )
                    
        
            # For higher photon counts we ignore convolution with the noise and compute moments of idealised SER0
            multiphoton_loc_base = pE*A + (1-pE)*gain                         
            multiphoton_scale_base = (pE*2.*A.pow(2.) + (1.-pE)*(gain.pow(2)+s0.pow(2)) - multiphoton_loc_base.pow(2.)).sqrt() # second moment of Exp and Normal - mean^2
            
            # Gaussian components for 2 .. max_photon-1 photons
            p_PM_each.extend([
                GaussianRandomVariable((i*multiphoton_loc_base+self.offset).view(-1), 
                                   (i*multiphoton_scale_base.view(1,1)).pow(2.))
                for i in range(2, max_photon)])
            
            # One mixture per column of photon_probs, weighted by its photon-count probabilities
            return preprocRandomVariables.BatchRandomVariable(
                *[preprocRandomVariables.MixtureRandomVariableWithSampler(*p_PM_each, weights=photon_probs[:,i])
                  for i in range(photon_probs.size(1))]
                 )
            
        else: # We are just returning the predictive mean and variance
        
            raise NotImplementedError

In [22]:
a = PoissonInputUnderamplifiedPhotomultiplierLikelihood()

In [23]:
GaussianRandomVariable(inp, gpytorch.lazy.DiagLazyVariable(0.2*torch.ones_like(inp).view(1,1))).covar().evaluate()

RuntimeError: invalid argument 2: size '[1 x 1]' is invalid for input with 2000 elements at /opt/conda/conda-bld/pytorch_1532584813488/work/aten/src/TH/THStorage.cpp:84

In [175]:
tmp.var()

tensor([ 4.5974, 12.2029, 64.5660])

In [176]:
tmp.rand_vars[1].rand_vars[2].sample(10)

RuntimeError: expand(torch.FloatTensor{[1, 1, 1]}, size=[1, 1]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

In [188]:
a.log_underamplified_amplitude.exp()

tensor([0.5000], grad_fn=<ExpBackward>)

In [190]:
inp = torch.arange(0., 3., 1.)
inp = torch.tensor(0.1).view(-1,1)
tmp = a(GaussianRandomVariable(inp, gpytorch.lazy.DiagLazyVariable(0.2*torch.ones_like(inp))))

lik_sample = tmp.sample(n_samples=10000)
hist, bins = np.histogram(lik_sample.detach().cpu(), 100)
plt([plt_type.Bar(y=hist/hist.sum()*10, x=bins)]),

(None,)

In [108]:
tmp.sample

((((Parameter containing:
tensor([0.], requires_grad=True), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7814039860>), SingleElectronResponse(), (tensor([1.8808], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7814039940>), (tensor([2.8212], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7814039be0>), (tensor([3.7616], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7814039a58>), (tensor([4.7020], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f78140394a8>), (tensor([5.6424], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7814039a90>), (tensor([6.5828], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7814039a20>), (tensor([7.5232], grad_fn=<ThAddBackward>), <gpytorch.lazy.non_lazy_variable.NonLazyVariable object at 0x7f7818

In [39]:
?torch.logsumexp()

In [None]:
import torc

## Testing

In [2]:
import torch
import gpytorch

import preprocRandomVariables
import preprocUtils

In [3]:
from gpytorch.functions import add_diag
from gpytorch.likelihoods import Likelihood
from gpytorch.priors._compatibility import _bounds_to_prior
from gpytorch.random_variables import GaussianRandomVariable, MixtureRandomVariable, RandomVariable
from gpytorch.priors import SmoothedBoxPrior

from torch.distributions import Poisson, Normal

In [5]:
from preprocLikelihoods import *

In [17]:
import nbimporter
from preprocVisualisationTesting import plot

Importing Jupyter notebook from preprocVisualisationTesting.ipynb


In [18]:
tmp = PoissonInputPhotomultiplierLikelihood(gain=3., offset=1.3)

In [19]:
n_inp = int(3)
n_dim = int(2)
cur_target = torch.rand([n_inp, n_dim])
latent_func = GaussianRandomVariable(torch.rand(n_inp), torch.eye(n_inp))
#tmp.log_probability(latent_func, cur_target)

In [21]:
inp = torch.arange(-10., 10., 0.01)
plot(tmp.single_log_prob(inp).exp().clamp(0.,1.).view(-1,1), inp.view(-1))

TypeError: single_log_prob() missing 1 required positional argument: 'cur_target'

In [8]:
tmp(latent_func)

(tensor([1.8418, 2.2513, 2.1908], grad_fn=<SelectBackward>), <gpytorch.lazy.diag_lazy_variable.DiagLazyVariable object at 0x7faf9d533630>)

In [9]:
a = 3

In [10]:
poissMeans = latent_func.mean().exp().unsqueeze(0).expand(10,-1)
photon_counts, photon_log_probs = (
            tmp.getPhotonLogProbs(poissMeans, max_photon=30., reNormalise = False))

In [11]:
photon_counts.shape

torch.Size([30, 1, 1])

In [12]:
photon_log_probs

tensor([[[ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038],
         [ -2.3048,  -1.1776,  -1.5038]],

        [[ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958],
         [ -1.4698,  -1.0141,  -1.0958]],

        [[ -1.3280,  -1.5438,  -1.3810],
         [ -1.3280,  -1.5438,  -1.3810],
         [ -1.3280,  -1.5438,  -1.3810],
         [ -1.3280,  -1.5438,  -1.3810],
         [ -

In [13]:
tmp.getPhotonLogProbs(poissMeans, max_photon=30., reNormalise = False)[1].permute(1,0,2).shape

torch.Size([10, 30, 3])

In [14]:
log_prob_sum_over_target_samples = p_PM.log_prob(cur_target.unsqueeze(-1)).unsqueeze(1).sum(-2).squeeze() # This should work for 1d or 2d targets


NameError: name 'p_PM' is not defined

In [None]:
log_prob_sum_over_target_samples.shape

In [None]:
photon_log_probs.shape

In [None]:
(log_prob_sum_over_target_samples.permute(1,0).unsqueeze(1) + photon_log_probs).logsumexp(dim=0).shape

In [None]:
p_PM.log_prob(cur_target[:,0].unsqueeze(-1)).unsqueeze(1).sum(-2).squeeze().shape

In [None]:
log_prob_sum_over_target_samples.shape

In [None]:
poissMeans = latent_func.mean().exp().unsqueeze(0).expand(10,-1)
photon_counts, photon_log_probs = (
            tmp.getPhotonLogProbs(poissMeans, max_photon=30., reNormalise = False))

In [None]:
photon_log_probs.shape

In [None]:
lambda_cur = Poisson(poissMeans)
# photon_counts = torch.arange(max_photon, device=poissMeans.device).unsqueeze(1)
# photon_log_probs = lambda_cur.log_prob(photon_counts) 

In [None]:
lambda_cur.log_prob(torch.arange(max_photon, device=poissMeans.device).view(-1,1,1)).shape

In [None]:
photon_log_probs.shape

In [None]:
latent_func.covar().evaluate().shape

In [None]:
n_inp = int(3)
n_dim = int(2)
cur_target = torch.rand([n_inp, n_dim])
latent_func = GaussianRandomVariable(torch.rand(n_inp), torch.eye(n_inp))
tmp.log_probability(latent_func, cur_target)


In [None]:
torch.eye(n_inp)

In [None]:
photon_counts.shape

In [None]:
photon_log_probs.shape

In [None]:
log_prob_sum_over_target_samples.shape

In [None]:
log_prob_sum_over_target_samples = p_PM.log_prob(cur_target.unsqueeze(-1)).unsqueeze(1).sum(-2) # This should work for 1d or 2d targets
(log_prob_sum_over_target_samples.permute(1,0) + photon_log_probs).logsumexp(dim=0) # Numerically more stable logsumexp


In [None]:
a.scale[0] += torch.tensor([[2.]]).squeeze()

In [None]:
a.scale[0] += 1.2

In [None]:
a.scale

In [None]:
torch.tensor(7.).exp()

In [None]:
tmp = BasePhotomultiplierLikelihood()

In [None]:
tmp.getPhotonLogProbs(poissMeans)[1].shape