# Analysis of the posterior distribution of $\mu$ for a dataset with known $\sigma^2$ and unknown $\mu$

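We assume the observations of a dataset $X = \{x_1, \dots, x_N\}$ are drawn independently from $\mathcal{N}(\mu, \sigma^2)$, where the variance $\sigma^2$ is known and the mean $\mu$ is unknown. The goal is to estimate $\mu$ or, more precisely, its distribution given the data.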

Since $\mu$ is unknown, we encode our beliefs about it in a _prior_ distribution, chosen with the known form of the likelihood of $X$ in mind. For mathematical convenience we choose a conjugate prior, meaning that the resulting posterior belongs to the same distribution family as the prior. For a Gaussian likelihood with known variance, the conjugate prior on $\mu$ is itself Gaussian, so we assume $\mu \sim \mathcal{N}(\mu_0, \sigma^2_0)$.
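Conjugacy can also be checked numerically: the log of prior × likelihood should be a quadratic function of $\mu$, which is exactly the shape of a Gaussian log-density. The sketch below is a minimal illustration under assumed values (prior $\mathcal{N}(0, 1)$, known variance $0.8$, ten synthetic observations with true mean $2$, a fixed seed); the names `grid`, `log_post` and `implied_var` are illustrative and not part of the original notebook.

```python
import numpy as np
from scipy.stats import norm

# Assumed illustrative values, not taken from the notebook's data
mu_0, sigma2_0 = 0.0, 1.0   # prior mean and prior variance of mu
sigma2 = 0.8                # known data variance
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=np.sqrt(sigma2), size=10)

# Candidate values of mu at which to evaluate the unnormalized log-posterior
grid = np.linspace(-2.0, 5.0, 2001)

# log prior + log likelihood (summed over observations) at every grid point
log_post = norm.logpdf(grid, loc=mu_0, scale=np.sqrt(sigma2_0))
log_post = log_post + norm.logpdf(
    x[:, None], loc=grid[None, :], scale=np.sqrt(sigma2)
).sum(axis=0)

# A Gaussian log-density is quadratic in mu, so a degree-2 fit should be exact
a, b, c = np.polyfit(grid, log_post, deg=2)
max_residual = np.abs(log_post - np.polyval([a, b, c], grid)).max()

# Read the Gaussian parameters off the quadratic a*mu^2 + b*mu + c:
# the variance is -1/(2a) and the mean is -b/(2a)
implied_var = -1.0 / (2.0 * a)
implied_mean = -b / (2.0 * a)
print(f"max residual {max_residual:.1e}, mean {implied_mean:.3f}, var {implied_var:.3f}")
```

A residual at the level of floating-point noise means the unnormalized posterior is Gaussian, as conjugacy predicts.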

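After observing $N$ data points with sample mean $\bar x = \frac{1}{N}\sum_{i=1}^{N} x_i$, it can be shown that the posterior distribution of $\mu$ is again Gaussian, $\mu \mid X \sim \mathcal{N}(\mu_N, \sigma_N^2)$, with

$$
    \mu_N = \frac{\sigma^2\mu_0 + N\sigma_0^2\bar x}{\sigma^2 + N\sigma_0^2}, \qquad \sigma_N^2 = \left(\frac{1}{\sigma_0^2} + \frac{N}{\sigma^2} \right)^{-1}
$$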
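The posterior mean and variance formulas translate directly into a small helper. This is a minimal sketch: `posterior_params` is a hypothetical name introduced here, and it takes the prior and data *variances* (not standard deviations) as inputs.

```python
import numpy as np

def posterior_params(x, sigma2, mu_0, sigma2_0):
    """Posterior mean and variance of mu, given i.i.d. N(mu, sigma2) data x
    with known variance sigma2 and a N(mu_0, sigma2_0) prior on mu."""
    x = np.asarray(x, dtype=float)
    n = x.size
    xbar = x.mean() if n > 0 else 0.0  # with no data the posterior is the prior
    mu_n = (sigma2 * mu_0 + n * sigma2_0 * xbar) / (sigma2 + n * sigma2_0)
    sigma2_n = 1.0 / (1.0 / sigma2_0 + n / sigma2)
    return mu_n, sigma2_n

# Illustrative call: prior N(0, 1), known variance 0.8, three made-up observations
mu_n, sigma2_n = posterior_params([1.9, 2.3, 1.7], sigma2=0.8, mu_0=0.0, sigma2_0=1.0)
print(f"posterior mean {mu_n:.3f}, posterior variance {sigma2_n:.3f}")
```

With a single observation ($N = 1$, $\bar x$ equal to that observation) the helper reduces to the update applied inside the streaming loop further down. The mean can equivalently be written as a precision-weighted average, $\left(\frac{1}{\sigma_0^2} + \frac{N}{\sigma^2}\right)^{-1}\left(\frac{\mu_0}{\sigma_0^2} + \frac{N\bar x}{\sigma^2}\right)$, which makes it clear that the data dominate the prior as $N$ grows.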

```python
import plotly.graph_objs as go
from plotly.offline import iplot
import numpy as np
from scipy.stats import norm
```

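```python
# Prior on mu: N(mu_0, sigma_0), where sigma_0 holds the prior *variance*
mu_0, sigma_0 = 0, 1
xrange = np.linspace(-3, 3, 300)
# norm.pdf takes the standard deviation as `scale`
yrange = norm.pdf(xrange, loc=mu_0, scale=np.sqrt(sigma_0))
```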

```python
# Plot the prior density of mu
data = [
    {"type": "scatter", "x": xrange, "y": yrange}
]

fig = go.Figure(data=data)
iplot(fig)
```

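Now suppose the data reach our system one observation at a time. Each observation can be incorporated by applying the posterior formula with $N = 1$ and $\bar x$ equal to that observation; the resulting posterior then serves as the prior for the next observation.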
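This streaming scheme is legitimate because, for this conjugate model, applying the $N = 1$ update once per observation yields the same posterior as applying the batch formula to all observations at once. A minimal sketch checking this numerically, with assumed settings (a fixed seed, prior $\mathcal{N}(0, 1)$, known variance $0.8$, true mean $2$) chosen to mirror the notebook cell that follows; names such as `m_batch` are illustrative.

```python
import numpy as np

# Assumed settings: prior N(0, 1), known data variance 0.8, true mean 2
mu_0, sigma2_0, sigma2 = 0.0, 1.0, 0.8
rng = np.random.default_rng(42)
x = rng.normal(loc=2.0, scale=np.sqrt(sigma2), size=15)

# Sequential: apply the N = 1 update once per observation,
# feeding each posterior back in as the next prior
m, v = mu_0, sigma2_0
for obs in x:
    # the right-hand side is evaluated with the old m and v
    m, v = (sigma2 * m + v * obs) / (sigma2 + v), 1.0 / (1.0 / v + 1.0 / sigma2)

# Batch: the closed-form posterior using all N observations at once
n, xbar = x.size, x.mean()
m_batch = (sigma2 * mu_0 + n * sigma2_0 * xbar) / (sigma2 + n * sigma2_0)
v_batch = 1.0 / (1.0 / sigma2_0 + n / sigma2)

print(np.isclose(m, m_batch), np.isclose(v, v_batch))
```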

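```python
mu_0, sigma_0 = 0, 1  # prior mean and prior *variance* of mu
sigma = 0.8           # known data *variance* (sigma^2)
mu = 2                # true mean, *unknown* to the estimator
# randn draws have unit variance, so scale by the standard deviation sqrt(sigma)
stream = np.random.randn(15) * np.sqrt(sigma) + mu

priors = {
    "mu": [mu_0],
    "sigma": [sigma_0],  # stores variances, not standard deviations
}

print(f"mu={mu_0:0.2f}, sigma2={sigma_0:0.2f}")
for obs in stream:
    # Posterior mean with N = 1 (a single observation, so x_bar = obs)
    mu_0 = (sigma * mu_0 + 1 * sigma_0 * obs) / (sigma + 1 * sigma_0)
    # Posterior variance with N = 1; this posterior is the prior for the next step
    sigma_0 = 1 / (1 / sigma_0 + 1 / sigma)

    priors["mu"].append(mu_0)
    priors["sigma"].append(sigma_0)
    print(f"mu={mu_0:0.2f}, sigma2={sigma_0:0.2f}")
```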

```
mu=0.00, sigma2=1.00
mu=0.51, sigma2=0.44
mu=1.28, sigma2=0.29
mu=1.39, sigma2=0.21
mu=1.47, sigma2=0.17
mu=1.50, sigma2=0.14
mu=1.56, sigma2=0.12
mu=1.61, sigma2=0.10
mu=1.76, sigma2=0.09
mu=1.75, sigma2=0.08
mu=1.77, sigma2=0.07
mu=1.82, sigma2=0.07
mu=1.83, sigma2=0.06
mu=1.83, sigma2=0.06
mu=1.95, sigma2=0.05
mu=2.00, sigma2=0.05
```
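The `sigma2` column above does not depend on the observed values: after $n$ observations the posterior variance is $\left(1/\sigma_0^2 + n/\sigma^2\right)^{-1}$ whatever the data were. A short sketch, using the same $\sigma_0^2 = 1$ and $\sigma^2 = 0.8$ as the loop, computes that closed form for $n = 0, \dots, 15$ and compares it with the one-step recursion used in the loop; the variable names are illustrative.

```python
import numpy as np

sigma2_0, sigma2 = 1.0, 0.8   # prior variance and known data variance, as in the loop
n = np.arange(16)             # number of observations seen so far: 0, 1, ..., 15
closed_form = 1.0 / (1.0 / sigma2_0 + n / sigma2)

# The same values via the one-step recursion used inside the loop
recursive = [sigma2_0]
for _ in range(15):
    recursive.append(1.0 / (1.0 / recursive[-1] + 1.0 / sigma2))

print(np.round(closed_form, 2))
print(np.allclose(closed_form, recursive))
```

Rounded to two decimals, the closed-form values agree with the printed `sigma2` column; only the `mu` column depends on the random draws. The variance shrinks roughly like $\sigma^2/n$ once $n$ is large.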


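```python
# One curve per step: the prior, then the posterior after each observation
data = [
    {"type": "scatter", "x": xrange, "y": norm.pdf(xrange, loc=m, scale=np.sqrt(s2))}
    for m, s2 in zip(priors["mu"], priors["sigma"])  # "sigma" entries are variances
]

fig = go.Figure(data=data)
iplot(fig)
```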
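Once the stream has been processed, the final posterior can be summarized by a central credible interval: the range that contains $\mu$ with a given posterior probability. A minimal sketch, assuming a Gaussian posterior with the mean and variance computed above; `credible_interval` is a hypothetical helper built on `scipy.stats.norm.interval`.

```python
from scipy.stats import norm

def credible_interval(mu_n, sigma2_n, level=0.95):
    """Central credible interval for mu under a N(mu_n, sigma2_n) posterior."""
    # norm.interval expects a standard deviation as `scale`, not a variance
    return norm.interval(level, loc=mu_n, scale=sigma2_n ** 0.5)

# Example with the rounded final values printed above: mu=2.00, sigma2=0.05
low, high = credible_interval(2.00, 0.05)
print(f"95% credible interval for mu: [{low:.2f}, {high:.2f}]")
```

The printed final values are rounded and come from one random run, so this particular interval is only illustrative.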