<a href="https://colab.research.google.com/github/ehsan-lari/pyro101/blob/main/pyro_02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

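## The variational model: A Gaussian model with unknown mean and variance

Observations are modelled as $x_i \sim \mathcal{N}(\mu, \gamma^{-1})$ with independent priors $\mu \sim \mathcal{N}(\mu_0, \tau_0^{-1})$ and $\gamma \sim \mathrm{Gamma}(\alpha_0, \beta_0)$ (shape–rate parameterisation). The posterior is approximated by the mean-field family $q(\mu)\,q(\gamma) = \mathcal{N}(\mu; \nu, \tau^{-1})\,\mathrm{Gamma}(\gamma; \alpha, \beta)$. The fit uses coordinate-ascent updates while monitoring the evidence lower bound (ELBO).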

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import special, stats

In [None]:
# Broad priors: Gamma(ALPHA_PRIOR, BETA_PRIOR) (shape, rate) on the precision,
# and N(MU_PRIOR, 1/TAU_PRIOR) on the mean (prior variance 1e6).
ALPHA_PRIOR = 1e-2
BETA_PRIOR = 1e-2
MU_PRIOR = 0
TAU_PRIOR = 1e-6

# Synthetic observations: N draws from a Gaussian with known mean and precision.
# Seeding immediately before the draw keeps the data reproducible.
N = 100
TRUE_MEAN = 5
TRUE_PRECISION = 1
np.random.seed(42)
x = np.random.normal(TRUE_MEAN, 1. / np.sqrt(TRUE_PRECISION), N)

In [None]:
def calculate_ELBO(data, tau, alpha, beta, nu_p, tau_p, alpha_p, beta_p, mu=0.0):
  """Evidence lower bound for the mean-field Gaussian model with unknown mean and precision.

  Model:        x_i ~ N(m, 1/g),   m ~ N(mu, 1/tau),   g ~ Gamma(alpha, beta)  (shape, rate)
  Variational:  q(m) = N(nu_p, 1/tau_p),   q(g) = Gamma(alpha_p, beta_p)

  The bound is E_q[log p(data, m, g)] + H[q(m)] + H[q(g)].

  Args:
    data: array-like of observations (summed over all elements).
    tau: prior precision of the mean m.
    alpha, beta: prior shape and rate of the precision g.
    nu_p, tau_p: mean and precision of q(m).
    alpha_p, beta_p: shape and rate of q(g).
    mu: prior mean of m. Defaults to 0, the value previously hard-coded here.

  Returns:
    The ELBO as a scalar float.
  """
  data = np.asarray(data, dtype=float)
  n = data.size

  # Expectations under q that appear in several terms.
  e_gamma = alpha_p / beta_p                                  # E[g]
  e_log_gamma = special.digamma(alpha_p) - np.log(beta_p)     # E[log g]
  e_mu_prior_sq = 1. / tau_p + (nu_p - mu) ** 2               # E[(m - mu)^2]

  log_prior_mean = -0.5 * np.log(2 * np.pi) + 0.5 * np.log(tau) - 0.5 * tau * e_mu_prior_sq

  # The -gammaln(alpha) normaliser of the Gamma prior must be included,
  # otherwise the reported ELBO is shifted by +log Gamma(alpha).
  log_prior_precision = alpha * np.log(beta) - special.gammaln(alpha) \
                        + (alpha - 1) * e_log_gamma - beta * e_gamma

  # sum_i E[(x_i - m)^2] = sum x^2 - 2 nu sum x + n (1/tau_p + nu^2), via sufficient statistics.
  e_sq_resid = np.sum(data * data) - 2 * nu_p * np.sum(data) + n * (1. / tau_p + nu_p * nu_p)
  log_lik = -0.5 * n * np.log(2 * np.pi) + 0.5 * n * e_log_gamma - 0.5 * e_gamma * e_sq_resid

  # Entropies of the Gaussian q(m) and the Gamma q(g).
  entropy_mean = 0.5 * np.log(2 * np.pi * np.e / tau_p)
  entropy_precision = alpha_p - np.log(beta_p) + special.gammaln(alpha_p) \
                      + (1 - alpha_p) * special.digamma(alpha_p)

  return log_prior_mean + log_prior_precision + log_lik + entropy_mean + entropy_precision

In [None]:
# Coordinate-ascent variational inference: alternately update q(gamma) and q(mu)
# until the ELBO stops increasing (or ITERATION sweeps have run).
alpha_q = ALPHA_PRIOR
beta_q = BETA_PRIOR
nu_q = MU_PRIOR
tau_q = TAU_PRIOR
previous_ELBO = -np.inf
ITERATION = 50           # maximum number of sweeps
ELBO_TOLERANCE = 1e-10   # stop once a sweep improves the ELBO by less than this

# Sufficient statistics of the data; they do not change between sweeps.
sum_x = np.sum(x)
sum_x2 = np.sum(x * x)

for iteration in range(ITERATION):
  # q(gamma) = Gamma(alpha_q, beta_q) given the current q(mu).
  alpha_q = ALPHA_PRIOR + 0.5 * N
  beta_q = BETA_PRIOR + 0.5 * sum_x2 - nu_q * sum_x + 0.5 * N * ( 1. / tau_q + nu_q * nu_q )

  # q(mu) = N(nu_q, 1/tau_q) given the updated q(gamma).
  expected_gamma = alpha_q / beta_q
  tau_q = TAU_PRIOR + N * expected_gamma
  nu_q = (TAU_PRIOR * MU_PRIOR + expected_gamma * sum_x) / tau_q

  current_ELBO = calculate_ELBO(data=x,
                                tau=TAU_PRIOR,
                                alpha=ALPHA_PRIOR,
                                beta=BETA_PRIOR,
                                nu_p=nu_q,
                                tau_p=tau_q,
                                alpha_p=alpha_q,
                                beta_p=beta_q)

  print("Iteration: {} ELBO: {}".format(iteration, current_ELBO))

  # Each coordinate update maximises the ELBO in its factor, so the bound must not
  # decrease beyond round-off; a drop indicates an error in the update equations.
  assert current_ELBO >= previous_ELBO - ELBO_TOLERANCE, "ELBO decreased between iterations"

  converged = current_ELBO - previous_ELBO < ELBO_TOLERANCE
  previous_ELBO = current_ELBO
  if converged:
    break

Iteration: 0 ELBO: -786.1880330567017
Iteration: 1 ELBO: -557.6868575549023
Iteration: 2 ELBO: -330.4821015062628
Iteration: 3 ELBO: -154.18289252127906
Iteration: 4 ELBO: -142.16324839011187
Iteration: 5 ELBO: -142.15974543444332
Iteration: 6 ELBO: -142.1597450787318
Iteration: 7 ELBO: -142.1597450786965
Iteration: 8 ELBO: -142.1597450786966
Iteration: 9 ELBO: -142.15974507869643
Iteration: 10 ELBO: -142.15974507869635
Iteration: 11 ELBO: -142.1597450786962
Iteration: 12 ELBO: -142.1597450786962
Iteration: 13 ELBO: -142.1597450786962
Iteration: 14 ELBO: -142.1597450786962
Iteration: 15 ELBO: -142.1597450786962
Iteration: 16 ELBO: -142.1597450786962
Iteration: 17 ELBO: -142.1597450786962
Iteration: 18 ELBO: -142.1597450786962
Iteration: 19 ELBO: -142.1597450786962
Iteration: 20 ELBO: -142.1597450786962
Iteration: 21 ELBO: -142.1597450786962
Iteration: 22 ELBO: -142.1597450786962
Iteration: 23 ELBO: -142.1597450786962
Iteration: 24 ELBO: -142.1597450786962
Iteration: 25 ELBO: -142.15974