In [1]:
import torch
from torch.distributions import Normal
import math

### Bayesian Inferencing of both Mean and Precision of Gaussian likelihood

We will now study the case where both the mean and the variance are unknown.

We will start, as usual, with the likelihood term, expressing it in terms of the precision $\lambda$ (which is related to the standard deviation $\sigma$, and hence to the variance $\sigma^{2}$, as $\lambda = \frac{ 1 } { \sigma^{2} }$).
The likelihood term is a Gaussian with mean $\mu$ and precision $\lambda$

$$p\left( X \middle\vert \mu, \lambda\right) \propto \lambda^{ \frac{ n }{ 2 } }  e^{ -\frac{ \lambda } { 2 } \sum_{i=1}^{n} \left( {x^{ \left( i \right) } - \mu } \right)^2 }$$

We will make the prior for the mean a Gaussian with mean $\mu_{0}$ and precision $\lambda_{0} \lambda$

$$p\left( \mu \middle\vert \lambda \right) \propto   \lambda^{ \frac{1}{2} }  e^{ -\frac{ \lambda_{0} \lambda } { 2 }  \left(  \mu - \mu_{0}  \right)^{2} }$$

We will make the prior for the precision a Gamma distribution 
$$p\left( \lambda \right)  \propto  \lambda ^{ \left( \alpha_{0} - 1 \right)} e^{ - \beta_{0} \lambda  }$$

The overall prior function is a Normal-Gamma distribution

$$p\left( \mu, \lambda \right)  \propto  
\lambda^{ \frac{1}{2} }  e^{ -\frac{ \lambda_{0} \lambda } { 2 }  \left(  \mu - \mu_{0}  \right)^{2} } 
\;\;\; 
\lambda ^{ \left( \alpha_{0} - 1 \right)} e^{ - \beta_{0} \lambda  }$$

The corresponding posterior is also a Normal-Gamma distribution, such that 
$$p\left( \mu, \lambda \middle\vert X \right) \propto
e^{ -\frac{\lambda}{2} \lambda_{n}  \left( \mu - \mu_{n} \right)^{2} } \lambda^{ \alpha_{n} - \frac{1}{2} } e^{ -\beta_{n} \lambda }$$

where 
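$$\mu_{n} = \frac{ \left( n \bar{x} + \mu_{0} \lambda_{0} \right) }{ n + \lambda_{0} } \\
\lambda_{n} = n + \lambda_{0} \\
\alpha_{n} = \frac{n}{2} + \alpha_{0} \\
\beta_{n} = \frac{ ns }{ 2 } + \beta_{ 0 } + \frac{ n \lambda_{0} } { 2 \left( n + \lambda_{0} \right) } \left( \bar{x} - \mu_{0} \right)^{ 2 }$$

Here $\bar{x} = \frac{1}{n} \sum_{i=1}^{n} x^{ \left( i \right) }$ is the sample mean and $s = \frac{1}{n} \sum_{i=1}^{n} \left( x^{ \left( i \right) } - \bar{x} \right)^{2}$ is the maximum likelihood (divide-by-$n$) sample variance.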


In [2]:
# Torch does not support NormalGamma distributions yet 
# So let us implement a bare bones implementation of the Normal Gamma distribution

class NormalGamma:
    """Bare-bones Normal-Gamma distribution over a (mean, precision) pair.

    The object only stores its four parameters and exposes the closed-form
    mean and mode; it implements neither sampling nor density evaluation.
    """

    def __init__(self, mu_, lambda_, alpha_, beta_):
        # mu_ / lambda_ parameterise the Gaussian factor,
        # alpha_ / beta_ the Gamma factor (shape and rate).
        self.mu_, self.lambda_ = mu_, lambda_
        self.alpha_, self.beta_ = alpha_, beta_

    @property
    def mean(self):
        """Return the tuple ``(mu_, alpha_ / beta_)``."""
        expected_precision = self.alpha_ / self.beta_
        return self.mu_, expected_precision

    @property
    def mode(self):
        """Return the tuple ``(mu_, (alpha_ - 0.5) / beta_)``."""
        shifted_shape = self.alpha_ - 0.5
        return self.mu_, shifted_shape / self.beta_

In [3]:
def inference_unknown_mean_variance(X, prior_dist):
    """Compute the Normal-Gamma posterior for a Gaussian with unknown mean and precision.

    Args:
        X: torch tensor of observations. ``X.shape[0]`` is used as the sample
            count ``n`` while the statistics are taken over all elements, so
            this assumes one scalar observation per row, e.g. shape ``(n,)``
            or ``(n, 1)``.
        prior_dist: object exposing the prior parameters as the attributes
            ``mu_``, ``lambda_``, ``alpha_`` and ``beta_``.

    Returns:
        A ``NormalGamma`` holding the posterior parameters
        ``(mu_n, lambda_n, alpha_n, beta_n)``.
    """
    n = X.shape[0]
    mu_mle = X.mean()
    # Maximum likelihood (divide-by-n) variance. Tensor.std() applies Bessel's
    # correction by default, which would inflate the n * s / 2 term below by a
    # factor of n / (n - 1) and produce NaN when n == 1.
    var_mle = ((X - mu_mle) ** 2).mean()

    # Parameters of the prior
    mu_0 = prior_dist.mu_
    lambda_0 = prior_dist.lambda_
    alpha_0 = prior_dist.alpha_
    beta_0 = prior_dist.beta_

    # Parameters of the posterior
    lambda_n = n + lambda_0
    mu_n = (n * mu_mle + mu_0 * lambda_0) / lambda_n
    alpha_n = n / 2 + alpha_0
    # beta_n = n*s/2 + beta_0 + n*lambda_0 / (2*(n + lambda_0)) * (x_bar - mu_0)^2
    beta_n = (
        n / 2 * var_mle
        + beta_0
        + 0.5 * n * lambda_0 / lambda_n * (mu_mle - mu_0) ** 2
    )

    return NormalGamma(mu_n, lambda_n, alpha_n, beta_n)

In [4]:
# Ground truth used throughout: the data for a single class is assumed to come
# from one Gaussian with mean 20 and standard deviation 5.

true_dist = Normal(loc=20, scale=5)

In [5]:
# Case 1: Low data
# Let us assume our prior is a Normal-Gamma distribution (mu_0=19, lambda_0=10, alpha_0=1, beta_0=40)
# with a good estimate of the variance
prior_dist = NormalGamma(19, 10, 1, 40)

# Let us set a seed for reproducibility
torch.manual_seed(42)

# Number of samples is low.
n = 3
X = true_dist.sample((n, 1))
posterior_dist_low_n = inference_unknown_mean_variance(X, prior_dist)

true_distribution_mu, true_distribution_std = true_dist.mean, true_dist.scale
# NOTE(review): X.std() applies Bessel's correction by default, so this is the unbiased
# sample std rather than the strict divide-by-n MLE - confirm which one is intended.
mle_mu, mle_std = X.mean(), X.std()
# The MAP estimate is the mode of the posterior; it returns (mean, precision).
map_mu, map_precision =  posterior_dist_low_n.mode
# Convert precision to standard deviation: sigma = sqrt(1 / lambda)
map_std = math.sqrt(1 / map_precision)

# When n is low, the posterior is dominated by the prior. Thus, a good prior can help offset the lack of data.
# We can see this in the following case.

# With a small sample (n=3), the MLE estimate of the standard deviation is 0.52, which is way off from the true value of 5.0
# Using a good prior here helps offset it.

print(f"True distribution: mu {true_distribution_mu:0.2f} std {true_distribution_std:0.2f}")
print(f"MLE: mu {mle_mu:0.2f} std {mle_std:0.2f}")
print(f"MAP: mu {map_mu:0.2f} std {map_std:0.2f}")

True distribution: mu 20.00 std 5.00
MLE: mu 21.17 std 0.52
MAP: mu 19.50 std 4.79


In [6]:
# Case 2: High data
# Let us assume our prior is a Normal-Gamma distribution (mu_0=19, lambda_0=10, alpha_0=1, beta_0=40)
# with a good estimate of the variance
prior_dist = NormalGamma(19, 10, 1, 40)

# Let us set a seed for reproducibility
torch.manual_seed(42)

# Number of samples is high.
n = 1000
X = true_dist.sample((n, 1))
posterior_dist_high_n = inference_unknown_mean_variance(X, prior_dist)

true_distribution_mu, true_distribution_std = true_dist.mean, true_dist.scale
# NOTE(review): X.std() applies Bessel's correction by default, so this is the unbiased
# sample std rather than the strict divide-by-n MLE - confirm which one is intended.
mle_mu, mle_std = X.mean(), X.std()
# The MAP estimate is the mode of the posterior; it returns (mean, precision).
map_mu, map_precision =  posterior_dist_high_n.mode
# Convert precision to standard deviation: sigma = sqrt(1 / lambda)
map_std = math.sqrt(1 / map_precision)

# When n is high, the MLE converges to the true distribution. The MAP also converges to the MLE, and in turn
# converges to the true distribution

print(f"True distribution: mu {true_distribution_mu:0.2f} std {true_distribution_std:0.2f}")
print(f"MLE: mu {mle_mu:0.2f} std {mle_std:0.2f}")
print(f"MAP: mu {map_mu:0.2f} std {map_std:0.2f}")

True distribution: mu 20.00 std 5.00
MLE: mu 20.02 std 5.02
MAP: mu 20.01 std 5.02


### How to use the  estimated mean and variance parameters?

As usual, we will obtain the maximum of the posterior probability density function $p\left( \mu, \lambda \middle\vert X \right) = \text{Normal-Gamma}\left(  \mu, \lambda ; \;\; \mu_{n}, \lambda_{n}, \alpha_{n}, \beta_{n} \right) $.

This function attains its maximum when

$$\mu = \mu_{n} \\
\lambda = \frac{ \alpha_{n} - \frac{1}{2} } { \beta_{n} }$$

Thus, the probability density function for data instance $x$ belonging to the same class as the training data $X$ is $\mathcal{N} \left( x; \mu_{n} ,  \sigma_{n}  \right)$ where $\frac{1}{ \sigma_{n}^{2} } = \frac{ \alpha_{n} - \frac{1}{2} } { \beta_{n} }$.


In [7]:
# Build a Gaussian from the MAP (posterior mode) estimates of the high-data posterior.
map_mu, map_precision =  posterior_dist_high_n.mode
# Convert precision to standard deviation: sigma = sqrt(1 / lambda)
map_std = math.sqrt(1 / map_precision)
map_dist = Normal(map_mu, map_std)
print(f"MAP distribution mu: {map_dist.mean:0.2f} std:{map_dist.scale:0.2f}")

MAP distribution mu: 20.01 std:5.02
