In [44]:
import pyro
import torch
import math
from pyro.optim import SGD, Adam
import pyro.distributions as dist
from torch.distributions import constraints
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import beta
%matplotlib inline

import os
# Set the KMP_DUPLICATE_LIB_OK environment variable for this process.
# NOTE(review): presumably a workaround for the "OMP: Error #15" abort that occurs when
# more than one OpenMP runtime gets loaded (e.g. via torch and numpy) — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

## Introduction

In this notebook we revisit the simple generative model from Slide 18, which you also experimented with in the notebook *student_BBVI.ipynb*:
 * https://www.moodle.aau.dk/mod/resource/view.php?id=1049031

In the previous notebook we derived the required gradients manually. Here we instead rely on the automatic differentiation functionality in Pyro, which, in turn, is based on PyTorch.

## The model in plate notation

<img src="https://www.moodle.aau.dk/pluginfile.php/1695750/mod_folder/content/0/mean_model.png?forcedownload=1" width="600">

## The model defined in Pyro

Here we define the probabilistic model. Notice the close resemblance with the plate specification above.

In [3]:
def mean_model(data):
    """Generative model: an unknown mean ``mu`` with unit-scale Gaussian observations.

    Args:
        data: Observed values; passed as ``obs`` to the sample site ``"X"``.
    """
    # Broad normal prior over the unknown mean.
    mu_prior = dist.Normal(0.0, 1000.0)
    mu = pyro.sample("mu", mu_prior)

    # Likelihood of each observation given the sampled mean.
    likelihood = dist.Normal(mu, 1)

    # The plate holds the observations; no size is given, so the number of
    # observations is determined by the data set supplied to the function.
    with pyro.plate("x_plate"):
        pyro.sample("X", likelihood, obs=data)

## The variational distribution

In Pyro the variational distribution is defined as a so-called guide. In this example our variational distribution is a normal distribution with a learnable mean q_mu and a fixed standard deviation of 1:

$$
q(\mu)= \mathit{Normal}(\mu | q_{mu}, 1)
$$

In [4]:
def mean_guide(data):
    """Variational distribution q(mu) = Normal(q_mu, 1) paired with ``mean_model``.

    Args:
        data: Accepted so the signature matches the model; not used in the body.
    """
    # Learnable location of the variational distribution, initialised to 0.0.
    loc = pyro.param("q_mu", torch.tensor(0.0))

    # The site name must be exactly the name of the corresponding latent
    # variable in the model ("mu").
    variational_dist = dist.Normal(loc, 1.0)
    pyro.sample("mu", variational_dist)

## Learning

Here we encapsulate the learning steps, relying on standard stochastic gradient descent

In [15]:
def learn(data, num_steps=1000, lr=0.0001, print_every=50):
    """Fit ``mean_guide`` to ``mean_model`` with stochastic variational inference.

    The Pyro parameter store is cleared first, so every call starts from the
    initial values declared in the guide. The fitted values are left in the
    parameter store; nothing is returned.

    Args:
        data: Observations forwarded to both model and guide on every step.
        num_steps: Number of SVI gradient steps to take.
        lr: Learning rate of the SGD optimizer.
        print_every: Print the loss every this many steps; a falsy value
            (``0`` or ``None``) disables printing.
    """
    pyro.clear_param_store()

    svi = pyro.infer.SVI(model=mean_model,
                         guide=mean_guide,
                         optim=SGD({'lr': lr}),
                         loss=pyro.infer.Trace_ELBO())

    for step in range(num_steps):
        # One gradient step; the value returned by svi.step is reported as the loss.
        loss = svi.step(data)

        # Guard against print_every == 0 before taking the modulus.
        if print_every and step % print_every == 0:
            print(f"Loss for iteration {step}: {loss}")

In [6]:
# Draw 100 synthetic observations from Normal(10, 1), convert them to a
# float tensor, and fit the variational mean on them.
data = torch.tensor(
    np.random.normal(loc=10.0, scale=1.0, size=100),
    dtype=torch.float,
)
learn(data)

Loss for iteration 0: 5436.3763236403465
Loss for iteration 50: 1937.3406717777252
Loss for iteration 100: 698.8343402147293
Loss for iteration 150: 414.3496378660202
Loss for iteration 200: 319.3915444612503
Loss for iteration 250: 149.88095009326935
Loss for iteration 300: 257.7423424720764
Loss for iteration 350: 189.83737969398499
Loss for iteration 400: 145.53304851055145
Loss for iteration 450: 144.2721302509308
Loss for iteration 500: 153.16068708896637
Loss for iteration 550: 145.11027246713638
Loss for iteration 600: 190.50959634780884
Loss for iteration 650: 148.2471822500229
Loss for iteration 700: 147.65406131744385
Loss for iteration 750: 211.7738790512085
Loss for iteration 800: 288.0651717185974
Loss for iteration 850: 150.1859109401703
Loss for iteration 900: 183.0964126586914
Loss for iteration 950: 144.34351563453674


Get the learned variational parameter

## The learned parameter

In [7]:
# Look up the parameter registered under the name "q_mu" in Pyro's parameter store
# and convert it to a plain Python number with .item().
qmu = pyro.param("q_mu").item()

In [8]:
# Report the learned location of the variational distribution.
print(f"Mean of vaiational distribution: {qmu}")

Mean of vaiational distribution: 10.066452026367188


## Exercise
* Adapt the code above to accommodate a slightly richer variational distribution, where we also have a variational parameter for the standard deviation:
$$
q(\mu)= \mathit{Normal}(\mu | q_{mu}, q_{std})
$$
* Experiment with different data sets and parameter values. Try visualizing the variational posterior distribution.

In [40]:
def mean_std_model(data):
    """Generative model with an unknown mean and an unknown observation scale.

    The latent site named ``"std"`` holds the *logarithm* of the observation
    standard deviation; it is exponentiated before being used as the scale of
    the likelihood, which keeps the scale positive.

    Args:
        data: Observed values; passed as ``obs`` to the sample site ``"X"``.
    """
    # Define the random variable mu having a normal distribution as prior
    mu = pyro.sample("mu", dist.Normal(0.0, 1000.0))

    # Standard-normal prior on the log of the observation standard deviation.
    log_std = pyro.sample("std", dist.Normal(0.0, 1.0))

    # Use torch.exp rather than math.exp: math.exp converts the tensor to a plain
    # Python float, which cuts it out of the autograd graph so no gradient from the
    # likelihood can flow back through this sample.
    std = torch.exp(log_std)

    # and now the plate holding the observations. The number of observations is
    # determined by the data set supplied to the function.
    with pyro.plate("x_plate"):
        pyro.sample("X", dist.Normal(mu, std), obs=data)

In [41]:
def mean_std_guide(data):
    """Variational guide for ``mean_std_model``.

    Declares two learnable parameters, ``q_mu`` and ``q_std``, and uses them as
    the locations of unit-scale normal distributions over the model's latent
    sites ``"mu"`` and ``"std"``.

    Args:
        data: Accepted so the signature matches the model; not used in the body.
    """
    # We initialize the variational parameters to 0.0 and 1.0 respectively.
    q_mu = pyro.param("q_mu", torch.tensor(0.0))
    # NOTE(review): q_std is constrained to be positive but is used below as the
    # location of the "std" site, which the model treats as a log standard
    # deviation — confirm whether the constraint is intended.
    q_std = pyro.param("q_std", torch.tensor(1.0), constraint=constraints.positive)

    # The name of the random variable of the variational distribution must match the name of the corresponding
    # variable in the model exactly.
    # The scale is 1.0; the previous "1,0" passed 0 as a third positional argument
    # (validate_args), silently turning off argument validation.
    pyro.sample("mu", dist.Normal(q_mu, 1.0))
    pyro.sample("std", dist.Normal(q_std, 1.0))

In [42]:
def learn_new(data, num_steps=1000, lr=0.0001, print_every=50):
    """Fit ``mean_std_guide`` to ``mean_std_model`` with stochastic variational inference.

    The Pyro parameter store is cleared first, so every call starts from the
    initial values declared in the guide. The fitted values are left in the
    parameter store; nothing is returned.

    Args:
        data: Observations forwarded to both model and guide on every step.
        num_steps: Number of SVI gradient steps to take.
        lr: Learning rate of the SGD optimizer.
        print_every: Print the loss every this many steps; a falsy value
            (``0`` or ``None``) disables printing.
    """
    pyro.clear_param_store()

    svi = pyro.infer.SVI(model=mean_std_model,
                         guide=mean_std_guide,
                         optim=SGD({'lr': lr}),
                         loss=pyro.infer.Trace_ELBO())

    for step in range(num_steps):
        # One gradient step; the value returned by svi.step is reported as the loss.
        loss = svi.step(data)

        # Guard against print_every == 0 before taking the modulus.
        if print_every and step % print_every == 0:
            print(f"Loss for iteration {step}: {loss}")

In [45]:
# Draw a fresh set of 100 synthetic observations from Normal(10, 1), convert
# them to a float tensor, and fit the mean/std guide on them.
data = torch.tensor(
    np.random.normal(loc=10.0, scale=1.0, size=100),
    dtype=torch.float,
)
learn_new(data)

Loss for iteration 0: 687.7784962654114
Loss for iteration 50: 387.34066212177277
Loss for iteration 100: 381.74849677085876
Loss for iteration 150: 310.92422914505005
Loss for iteration 200: 257.03968846797943
Loss for iteration 250: 223.47096055746078
Loss for iteration 300: 193.379743039608
Loss for iteration 350: 272.50359547138214
Loss for iteration 400: 142.11035233736038
Loss for iteration 450: 163.9282724261284
Loss for iteration 500: 222.60163289308548
Loss for iteration 550: 142.84347289800644
Loss for iteration 600: 260.3253366947174
Loss for iteration 650: 163.1066094636917
Loss for iteration 700: 196.40394812822342
Loss for iteration 750: 146.22518855333328
Loss for iteration 800: 165.56290924549103
Loss for iteration 850: 221.4096195101738
Loss for iteration 900: 276.9619903564453
Loss for iteration 950: 180.16667711734772


In [46]:
# Read the "q_mu" entry left in Pyro's parameter store by the most recent training
# run, convert it to a plain Python number, and report it.
qmu = pyro.param("q_mu").item()
print(f"Mean of vaiational distribution: {qmu}")

Mean of vaiational distribution: 9.750102043151855
