In [1]:
import math, os, sys, torch, pyro, pyro.optim, pyro.infer
import numpy as np
from torch.autograd import Variable, grad, Function
from pyro.distributions import Distribution
from matplotlib import pyplot
%matplotlib inline
torch.set_default_tensor_type('torch.DoubleTensor')

In [2]:
class BivariateNormal(Distribution):
    reparameterized = True
    
    def __init__(self, loc, scale_triu, batch_size=None):
        self.loc = loc
        self.scale_triu = scale_triu
        self.batch_size = 1 if batch_size is None else batch_size

    def batch_shape(self, x=None):
        loc = self.loc.expand(self.batch_size, *self.loc.size()).squeeze(0)
        if x is not None:
            if x.size()[-1] != loc.size()[-1]:
                raise ValueError("The event size for the data and distribution parameters must match.\n"
                                 "Expected x.size()[-1] == self.loc.size()[0], but got {} vs {}".format(
                                     x.size(-1), loc.size(-1)))
            try:
                loc = loc.expand_as(x)
            except RuntimeError as e:
                raise ValueError("Parameter `loc` with shape {} is not broadcastable to "
                                 "the data shape {}. \nError: {}".format(loc.size(), x.size(), str(e)))

        return loc.size()[:-1]

    def event_shape(self):
        return self.loc.size()[-1:]

    def sample(self):
        return self.loc + torch.mv(self.scale_triu.t(), Variable(torch.randn(self.loc.size())), )

    def batch_log_pdf(self, x):
        delta = x - self.loc
        z0 = delta[..., 0] / self.scale_triu[..., 0, 0]
        z1 = (delta[..., 1] - self.scale_triu[..., 0, 1] * z0) / self.scale_triu[..., 1, 1]
        z = torch.stack([z0, z1], dim=-1)
        mahalanobis_squared = (z ** 2).sum(-1)
        normalization_constant = self.scale_triu.diag().log().sum(-1) + np.log(2 * np.pi)
        return -(normalization_constant + 0.5 * mahalanobis_squared).unsqueeze(-1)

    def entropy(self):
        return self.scale_triu.diag().log().sum() + (1 + math.log(2 * math.pi))
    
def _BVN_backward_reptrick(white, scale_triu, grad_output):
    grad = (grad_output.unsqueeze(-2) * white.unsqueeze(-1)).squeeze(0)  
    return grad_output, torch.triu(grad)
        
def _BVN_backward_symm(white, scale_triu, grad_output):
    grad = (grad_output.unsqueeze(-1) * white.unsqueeze(-2)).squeeze(0)
    x = torch.trtrs(white.t(), scale_triu, transpose=False)[0].t()
    y = torch.mm(scale_triu, grad_output.t()).t()
    grad += (x.unsqueeze(-1) * y.unsqueeze(-2)).squeeze(0)
    grad *= 0.5
    return grad_output, torch.triu(grad.t())
    
class _RepTrickSample(Function):
    @staticmethod
    def forward(ctx, loc, scale_triu):
        ctx.save_for_backward(scale_triu)
        ctx.white = loc.new(loc.size()).normal_()
        return loc + torch.mm(ctx.white, scale_triu)

    @staticmethod
    def backward(ctx, grad_output):
        scale_triu, = ctx.saved_variables
        return _BVN_backward_reptrick(Variable(ctx.white), scale_triu, grad_output)    

class _SymmetricSample(Function):
    @staticmethod
    def forward(ctx, loc, scale_triu):
        ctx.save_for_backward(scale_triu)
        ctx.white = loc.new(loc.size()).normal_()
        return loc + torch.mm(ctx.white, scale_triu)

    @staticmethod
    def backward(ctx, grad_output):
        scale_triu, = ctx.saved_variables
        return _BVN_backward_symm(Variable(ctx.white), scale_triu, grad_output)    


class BivariateNormalRepTrick(BivariateNormal):
    def sample(self):
        loc = self.loc.expand(self.batch_size, *self.loc.size())
        return _RepTrickSample.apply(loc, self.scale_triu)

class BivariateNormalSymmetric(BivariateNormal):
    def sample(self):
        loc = self.loc.expand(self.batch_size, *self.loc.size())
        return _SymmetricSample.apply(loc, self.scale_triu)

In [4]:
mu = Variable(torch.zeros(2))
scale_triu = torch.triu(Variable(torch.Tensor([[1, 0], [0, 1]]), requires_grad=True))
distrt = BivariateNormalRepTrick(mu, scale_triu)
distsym = BivariateNormalSymmetric(mu, scale_triu)

In [5]:
z = distrt.sample()
torch.pow(z,3.0).sum().backward()
print(scale_triu.grad)

backward rt
None


In [6]:
z = distsym.sample()
torch.pow(z,3.0).sum().backward()
print(scale_triu.grad)

RuntimeError: the derivative for 'A' is not implemented