### Attempting MALA with PyTorch's autograd

#### Model: 

\begin{eqnarray}
\text{Latent:} \quad X_t & = & A X_{t-1}  + \nu_t, 
\\
\text{Observed:} \quad Y_t & = & C X_t + B Z_t + \omega_t, \quad Z_t ~ \text{are covariates}
\\
\nu_t & \sim & \text{N}(0, Q ),
\\
\omega_t & \sim & \text{N}(0, R).
\end{eqnarray}

In [1]:
from __future__ import division
%matplotlib inline
from pykalman import KalmanFilter
import numpy as np, numpy.random as npr, matplotlib.pyplot as plt, copy, multiprocessing as mp, torch, pandas
from scipy.stats import *
from pylab import plot, show, legend
from tqdm import trange
from ozone_functions import *
from torch.distributions import multivariate_normal
from time import time

In [2]:
data = pandas.read_csv("data.csv").values
data = data[:,1::]

In [3]:
T = np.shape(data)[0]
Y = torch.FloatTensor(data[:,0:3])
Z = torch.FloatTensor(data[:,3::])

obs_dim = np.shape(Y)[-1]
lat_dim = 1
cov_dim = np.shape(Z)[-1]

In [4]:
A = torch.zeros(lat_dim,lat_dim, requires_grad=True)
C = torch.randn(obs_dim,lat_dim, requires_grad=True)
log_sigmay2 = torch.tensor(0., requires_grad=True)

Q = torch.eye(lat_dim)
R = torch.exp(log_sigmay2)*torch.eye(obs_dim)

B = torch.randn(obs_dim,cov_dim, requires_grad=True)
b = (torch.matmul(B,Z.transpose(0,1))).transpose(0,1)

In [5]:
mu0 = torch.zeros(lat_dim)
Sigma0 = torch.eye(lat_dim)

In [6]:
lat_dim, obs_dim, np.shape(C)

(1, 3, torch.Size([3, 1]))

In [7]:
def get_lpdf(Y, Z, A, C, B, log_sigmay2, mu0, Sigma0) :
    T = np.shape(Y)[0]
    filtering_mean = torch.clone(mu0.detach())
    filtering_cov = torch.clone(Sigma0.detach())
    
    lat_dim = np.shape(A)[0]
    obs_din = np.shape(C)[0]

    lpdf = torch.tensor(0.)
    Q = torch.eye(lat_dim)
    R = torch.exp(log_sigmay2)*torch.eye(obs_dim)

    for t in range(T) :
        predictive_mean = torch.matmul(A,filtering_mean)
        predictive_cov = torch.matmul(A,torch.matmul(filtering_cov,A.transpose(0,1)))

        K = torch.matmul(torch.matmul(predictive_cov,C.transpose(0,1)),\
                         (torch.inverse(torch.matmul(torch.matmul(C,predictive_cov),C.transpose(0,1))+R)))
        filtering_mean = predictive_mean + torch.matmul(K,(Y[t]- torch.matmul(C,predictive_mean)))
        filtering_cov = torch.matmul(torch.eye(lat_dim) - torch.matmul(K,C), predictive_cov) + Q

        mean = torch.matmul(C,filtering_mean) + b[t] 
        cov = torch.matmul(C,torch.matmul(filtering_cov,C.transpose(0,1))) + R

        dist = multivariate_normal.MultivariateNormal(loc=mean,covariance_matrix=cov)
        lpdf += dist.log_prob(Y[t])
    
    lpdf.backward(retain_graph=True)
    
    return lpdf, A.grad, C.grad, B.grad, log_sigmay2.grad

In [8]:
start = time()
lpdf, A_grad, C_grad, B_grad, log_sigmay2_grad = get_lpdf(Y[:100], Z[:100], A, C, B, log_sigmay2, mu0, Sigma0)
print(time()-start)

0.07656693458557129


In [9]:
A_grad.numpy(), C_grad, B_grad, log_sigmay2_grad

(array([[0.]], dtype=float32), tensor([[-219742.5938],
         [ 172460.3906],
         [  79737.4141]]), tensor([[-243049.7656, -472102.0000],
         [ 182992.8750,  377115.3750],
         [  85311.6641,  173620.4844]]), tensor(333977.8125))

In [12]:
def adaptive_MALA(Y, Z, A, C, B, log_sigmay2, mu0, Sigma0, 
                  n_mcmc, tauA, tauC, tausy, adapt=True, start_adapt=0.2, power=1, kappa=1) :
    
    npr.seed()
    scipy.random.seed()
    
    log_sigmay2_chain = torch.zeros(n_mcmc+1, requires_grad=False)
    log_sigmay2_chain[0] = log_sigmay2
    A_chain = torch.zeros((n_mcmc+1, *np.shape(A)), requires_grad=False)
    C_chain = torch.zeros((n_mcmc+1, *np.shape(C)), requires_grad=False)
    B_chain = torch.zeros((n_mcmc+1, *np.shape(B)), requires_grad=False)
    A_chain[0], C_chain[0], B_chain[0] = A, C, B
    
    accepted = 0
    last_accepted = 0
    
    start = time()
    A_current = torch.clone(A.detach())
    C_current = torch.clone(C.detach())
    B_current = torch.clone(B.detach())
    log_sigmay2_current = torch.clone(log_sigmay2.detach()) 
    
    A_current.requires_grad = True
    C_current.requires_grad = True
    B_current.requires_grad = True
    log_sigmay2_current.requires_grad = True
    
    for n in trange(n_mcmc) :
        
        ll_current, A_grad, C_grad, B_grad, log_sigmay2_grad \
        = get_lpdf(Y, Z, A_current, C_current, B_current, log_sigmay2_current, mu0, Sigma0)
    
        log_sigmay2_proposed = torch.tensor((log_sigmay2_current + tausy*log_sigmay2_grad \
                                             + torch.sqrt(2*tausy)*torch.randn(1)).detach().numpy(),
                                            requires_grad=True)
        A_proposed = torch.tensor((A_current + tauA*A_grad + torch.sqrt(2*tauA)*torch.randn(*np.shape(A))).detach().numpy(), 
                                  requires_grad=True)
        C_proposed = torch.tensor((C_current + tauC*C_grad + torch.sqrt(2*tauC)*torch.randn(*np.shape(C))).detach().numpy(), 
                                  requires_grad=True)
        B_proposed = torch.clone(B)
        
        if np.abs(A_proposed.detach().numpy()) < 1 :
        
            ll_proposed, A_grad_proposed, C_grad_proposed, B_grad_proposed, log_sigmay2_grad_proposed \
            = get_lpdf(Y, Z, A_proposed, C_proposed, B_proposed,
                       log_sigmay2_proposed, mu0, Sigma0)

            log_accept_ratio = power*(ll_proposed - ll_current)
            bottom = -power/(4*tauA)*torch.sum(A_proposed-A_current-tauA*A_grad).detach()**2 \
                     -power/(4*tauC)*torch.sum(C_proposed-C_current-tauC*C_grad).detach()**2 \
                     -power/(4*tausy)*torch.sum(log_sigmay2_proposed-log_sigmay2_current-tausy*log_sigmay2_grad).detach()**2
            top = -power/(4*tauA)*torch.sum(A_current-A_proposed-tauA*A_grad_proposed).detach()**2 \
                  -power/(4*tauC)*torch.sum(C_current-C_proposed-tauC*C_grad_proposed).detach()**2 \
                  -power/(4*tausy)*torch.sum(log_sigmay2_current-log_sigmay2_proposed-\
                                              tausy*log_sigmay2_grad_proposed).detach()**2

            log_accept_ratio = log_accept_ratio + top-bottom
        
            if np.log(npr.rand()) < log_accept_ratio.detach().numpy() :
                log_sigmay2_chain[n+1] = log_sigmay2_proposed.detach()
                A_chain[n+1] = A_proposed.detach()
                C_chain[n+1] = C_proposed.detach()
                B_chain[n+1] = B_proposed.detach()
                A_grad = A_grad_proposed.detach()
                C_grad = C_grad_proposed.detach()
                log_sigmay2_grad = log_sigmay2_grad_proposed.detach()
                accepted += 1
                last_accepted = n
        else :
            log_sigmay2_chain[n+1] = log_sigmay2_current.detach()
            A_chain[n+1] = A_current.detach()
            C_chain[n+1] = C_current.detach()
            B_chain[n+1] = B_current.detach()

    
    print(100*accepted/n_mcmc, "% acceptance rate")
    return log_sigmay2_chain.detach().numpy(), A_chain.detach().numpy(), C_chain.detach().numpy(), \
            B_chain.detach().numpy(), accepted

In [16]:
n_mcmc = 500
tauA = torch.tensor(1e-8)
tauC = torch.tensor(1e-8)
tausy = torch.tensor(1e-8)

log_sigmay2_chain, A_chain, C_chain, B_chain, accepted = \
adaptive_MALA(Y[:100], Z[:100], A, C, B, log_sigmay2, mu0, Sigma0, n_mcmc, tauA, tauC, tausy, adapt=False, power=1)

100%|██████████| 500/500 [01:01<00:00,  8.18it/s]

0.8 % acceptance rate





In [19]:
A_chain[:,0,0]

array([ 0.0000000e+00,  5.0922812e-05, -3.1224961e-05,  2.1016075e-04,
        8.2552287e-06,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
      