In [1]:
import numpy as np, numpy.random as npr, torch.nn as nn, copy, timeit, torch
from torch.distributions.bernoulli import Bernoulli 
from HMCfunctions import *
from time import time
import matplotlib.pyplot as plt
%matplotlib inline
from pylab import plot, show, legend

Gradient w.r.t. $\theta$:

$$
\nabla_\theta \log p(\theta \mid \sigma, x_{1:n}, y_{1:n}) 
= 
- \sum_{i=1}^n \frac{\left ( \mu_\theta(x_i) - y_i \right ) \, \nabla_\theta \mu_\theta(x_i) }{\sigma^{2}} + \nabla_\theta \log p_0(\theta) 
$$

#### Set up neural network:

In [2]:
# Layer widths: scalar input, one hidden layer of 5 units, 2-dimensional output.
n_in, n_h1, n_out = 1, 5, 2

# Fully connected network with a single ReLU hidden layer.
nn_model = nn.Sequential(
    nn.Linear(in_features=n_in, out_features=n_h1),
    nn.ReLU(),
    nn.Linear(in_features=n_h1, out_features=n_out),
)

* Total number of parameters:

In [3]:
# Add up the number of scalar entries over every weight and bias tensor.
print(sum(map(torch.numel, nn_model.parameters())))

22


* Randomly initialise model parameters:

In [4]:
# Re-initialise the network parameters. `Module.apply` calls `init_normal` on
# every submodule and returns the model itself, which is why this cell echoes
# the architecture as its output.
# NOTE(review): `init_normal` is not defined in this notebook -- presumably it
# comes from the `HMCfunctions` star import and draws Gaussian values; confirm.
nn_model.apply(init_normal)

Sequential(
  (0): Linear(in_features=1, out_features=5, bias=True)
  (1): ReLU()
  (2): Linear(in_features=5, out_features=2, bias=True)
)

#### Generate some random data:

In [5]:
# Synthetic regression data: inputs drawn uniformly from [0, 1) and mapped onto
# the unit circle, y = (cos(2*pi*x), sin(2*pi*x)).
nobs = 100
x = torch.rand(nobs, n_in)

# Angle built from the first input feature (the only one, since n_in is 1 in
# this notebook).
angle = 2 * np.pi * x[:, 0]

# Fill both target columns once, directly in torch. The earlier version
# recomputed the full columns on every pass of a `for i in range(nobs)` loop
# that never used `i`, and went through NumPy ufuncs / Python lists of tensors
# before converting back to a float32 tensor.
y = torch.zeros(nobs, n_out, dtype=torch.float32)
y[:, 0] = torch.cos(angle)
y[:, 1] = torch.sin(angle)

# Mean-squared-error criterion; handed to the HMC step further down.
criterion = nn.MSELoss()

* Get dimensions of parameters:

In [6]:
# Record the shape of every parameter tensor of the network; `shapes` is reused
# below to size the MCMC chain and is passed to the HMC step and the ESS helper.
# NOTE(review): `get_shapes` is not defined in this notebook (presumably from
# the `HMCfunctions` star import). Judging by the printed output it returns a
# list of torch.Size objects in weight/bias order -- confirm.
shapes = get_shapes(nn_model)
print("Shapes = ", shapes)

Shapes =  [torch.Size([5, 1]), torch.Size([5]), torch.Size([2, 5]), torch.Size([2])]


###  HMC

* First define the MCMC chain and randomly initialise it:

In [7]:
# Number of HMC iterations, i.e. the length of the stored chain.
T = 2_000

# One storage tensor per parameter tensor of the network. The leading axis
# indexes the iteration; entries start out as standard-normal draws.
chain = [torch.randn([T, *shape], requires_grad=False) for shape in shapes]

* Then run HMC:

In [8]:
# Start the sampler from freshly initialised weights and a zero acceptance count.
nn_model.apply(init_normal)
n_accept = 0

# Settings forwarded to HMC_1step.
# NOTE(review): the names suggest the number of leapfrog steps, the leapfrog
# step size and the sigma of the gradient formula -- confirm in HMCfunctions.
n_leapfrog = 10
delta_leapfrog = 0.1
sigma = 1

In [9]:
# Run T HMC transitions, storing the network parameters after each one.
start_time = time()
for t in range(T):
    # NOTE(review): the returned model is never read; the chain is filled from
    # `nn_model`, so this presumably relies on HMC_1step updating `nn_model`
    # in place -- confirm against HMCfunctions.
    updated_nn_model, a = HMC_1step(
        nn_model, n_leapfrog, delta_leapfrog, shapes, x, y, criterion, sigma
    )
    n_accept += a

    # Copy the current value of every parameter tensor into slot t of its chain.
    for samples, param in zip(chain, nn_model.parameters()):
        samples[t] = param.data

    # Report progress every 200 iterations and once more at the very end.
    if (t + 1) == T or (t + 1) % 200 == 0:
        accept_rate = float(n_accept) / float(t + 1)
        print("iter %6d/%d after %7.1f sec | accept_rate %.3f" % (
            t + 1, T, time() - start_time, accept_rate))

iter    200/2000 after     0.9 sec | accept_rate 0.555
iter    400/2000 after     1.8 sec | accept_rate 0.562
iter    600/2000 after     2.7 sec | accept_rate 0.563
iter    800/2000 after     3.5 sec | accept_rate 0.559
iter   1000/2000 after     4.4 sec | accept_rate 0.560
iter   1200/2000 after     5.2 sec | accept_rate 0.557
iter   1400/2000 after     6.1 sec | accept_rate 0.556
iter   1600/2000 after     6.9 sec | accept_rate 0.553
iter   1800/2000 after     7.8 sec | accept_rate 0.554
iter   2000/2000 after     8.6 sec | accept_rate 0.551


#### Effective sample sizes (ESS):

In [10]:
# Report the effective sample size of the stored chain; the recorded output
# shows one "ESS: k/2000" line per scalar parameter (22 lines).
# NOTE(review): `find_ESS` is not defined in this notebook (presumably from the
# `HMCfunctions` star import); the third argument looks like a print/verbose
# switch -- confirm.
# The trailing semicolon keeps Jupyter from echoing the return value.
find_ESS(chain, shapes, True);

ESS: 979/2000
ESS: 833/2000
ESS: 1104/2000
ESS: 1120/2000
ESS: 938/2000
ESS: 934/2000
ESS: 827/2000
ESS: 1142/2000
ESS: 731/2000
ESS: 979/2000
ESS: 1023/2000
ESS: 835/2000
ESS: 1017/2000
ESS: 996/2000
ESS: 952/2000
ESS: 1133/2000
ESS: 871/2000
ESS: 886/2000
ESS: 871/2000
ESS: 1056/2000
ESS: 1166/2000
ESS: 1258/2000
