In [2]:
import torch
from torch.autograd import Variable, grad
from pyro.distributions import MultivariateNormal as MVN

In [3]:
# Inspect the pyro MVN docstring: it documents the loc / covariance_matrix /
# scale_tril parameters and the use_inverse_for_batch_log option.
# TODO(review): exploratory lookup; consider removing from the final notebook.
help(MVN)

Help on class MultivariateNormal in module pyro.distributions.multivariate_normal:

class MultivariateNormal(pyro.distributions.distribution.Distribution)
 |  Multivariate normal (Gaussian) distribution.
 |  
 |  A distribution over vectors in which all the elements have a joint Gaussian
 |  density.
 |  
 |  :param torch.autograd.Variable loc: Mean. Must be a vector (Variable
 |      containing a 1d Tensor).
 |  :param torch.autograd.Variable covariance_matrix: Covariance matrix.
 |      Must be symmetric and positive semidefinite.
 |  :param torch.autograd.Variable scale_tril: The Cholesky decomposition of
 |      the covariance matrix. You can pass this instead of `covariance_matrix`.
 |  :param use_inverse_for_batch_log: If this is set to true, the torch.inverse
 |      function will be used to compute log_pdf. This means that the results of
 |      log_pdf can be differentiated with respect to the covariance matrix.
 |      Since the gradient of torch.potri is currently not implem

In [74]:
# Zero-mean 2-D Gaussian parameterised through a lower-triangular factor L.
# L is a leaf that requires grad, so gradients of a sample can later be taken
# with respect to its entries; the covariance is derived as L @ L^T.
loc = Variable(torch.zeros(2))
scale_tril = Variable(torch.Tensor([[2, 0],
                                    [1, 2]]),
                      requires_grad=True)
cov = scale_tril.mm(scale_tril.t())

In [75]:
# Display L @ L^T; with the scale_tril defined above this is [[4, 2], [2, 5]].
cov

Variable containing:
 4  2
 2  5
[torch.FloatTensor of size 2x2]

In [76]:
# Fix the RNG so the sample (and every number derived from it below) is
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Draw one sample from N(loc, cov). squeeze() drops any size-1 dimension the
# sampler adds, leaving x as a vector (length 2 here, per the output).
x = MVN(loc, cov).sample().squeeze()
x

Variable containing:
 1.9858
 1.0031
[torch.FloatTensor of size 2]

In [77]:
# dx_dt[i] = d x[i] / d scale_tril[1, 0], one reverse-mode pass per component
# of the sample. retain_graph=True keeps the sampling graph alive so the next
# component can backprop through it. create_graph is not needed because only
# the numeric value (.data) of each gradient is kept.
# NOTE(review): if the sampler is reparameterised as loc + scale_tril @ eps
# (consistent with the eps cell further down), this equals [0, eps[0]].
dx_dt = torch.Tensor([grad([x[i]], [scale_tril], retain_graph=True)[0].data[1, 0]
                      for i in range(x.size(0))])
dx_dt


 0.0000
 0.9929
[torch.FloatTensor of size 2]

In [89]:
# Whiten the sample by forward substitution: solve scale_tril @ eps = x
# (scale_tril is lower triangular, hence upper=False). The first entry of the
# trtrs result is the 2-D solution; take its first column to get eps as a vector.
eps = torch.trtrs(x, scale_tril, upper=False)[0].select(1, 0)
eps

Variable containing:
 0.9929
 0.0051
[torch.FloatTensor of size 2]

In [87]:
# Sanity check: mapping the whitened noise back through scale_tril must
# reproduce the sample, i.e. scale_tril @ eps == x up to float round-off.
# Assert it rather than relying on eyeballing the printed values.
x_reconstructed = torch.mv(scale_tril, eps)
assert (x_reconstructed - x).data.abs().max() < 1e-4, "scale_tril @ eps != x"
x_reconstructed

Variable containing:
 1.9858
 1.0031
[torch.FloatTensor of size 2]

In [90]:
def compute_v(x, tril=None):
    """Whiten ``x`` by solving the lower-triangular system ``tril @ v = x``.

    Args:
        x: Right-hand side, e.g. a sample drawn from the MVN above. Presumably
            a 1-D Variable of length ``tril.size(0)``; TODO confirm for
            batched input.
        tril: Lower-triangular factor to solve against. Defaults to the
            notebook-global ``scale_tril`` so existing ``compute_v(x)`` calls
            keep working. Pass it explicitly to avoid depending on hidden
            kernel state.

    Returns:
        The first column of the ``torch.trtrs`` solution, i.e. ``v`` as a vector.
    """
    if tril is None:
        tril = scale_tril  # backward-compatible fallback to the notebook global
    return torch.trtrs(x, tril, upper=False)[0][:, 0]