# Properties of reparameterized gradients

In [1]:
import torch
from torch.autograd import Variable, grad
from torch.distributions import Normal

Let's define some helpers for vector calculus.

In [2]:
def reparam_grad(x, a):
    """Computes `R = dx[:]/da` for a vector `x` that depends on a scalar `a`."""
    assert a.size() == (1,)
    return torch.stack([
        grad([x[i]], [a], create_graph=True)[0]
        for i in range(x.size(-1))
    ]).squeeze()

def jacobian(x, e):
    return torch.stack([
        grad([x[i]], [e], create_graph=True)[0]
        for i in range(x.size(-1))
    ], -2).squeeze()

def vector_field_deriv(R, e, x):
    """Computes the matrix of derivatives `dR[i]/dx[j]` for a vector field `R`."""
    dR_de = jacobian(R, e)
    dx_de = jacobian(x, e)
    dR_dx = torch.gesv(dx_de.transpose(0, 1), dR_de.transpose(0, 1))[0].transpose(0, 1)
    return dR_dx

def divergence(dR):
    """Computes divergence of a vector field `R` given its derivatives `dR`."""
    return (dR * Variable(torch.eye(R.size(-1)))).sum(-1).sum(-1)

def antisymmetric_part(dR):
    """Computes generalized curl of a vector field `R` given its derivatives `dR`."""
    return 0.5 * (dR - dR.transpose(-1, -2))

Consdier first a transformed normal distribution.

In [3]:
a = Variable(torch.Tensor([1]), requires_grad=True)
A = Variable(torch.eye(2))
A += a * Variable(torch.Tensor([[0, 1], [-1, 0]]))
e = Variable(torch.Tensor(2).normal_(), requires_grad=True)
x = torch.mv(A, e) * a
x

Variable containing:
 1.8636
 0.4137
[torch.FloatTensor of size 2]

In [4]:
dx_de = jacobian(x, e)
dx_de

Variable containing:
 1  1
-1  1
[torch.FloatTensor of size 2x2]

In [5]:
de_dx = dx_de.inverse()
de_dx

Variable containing:
 0.5000 -0.5000
 0.5000  0.5000
[torch.FloatTensor of size 2x2]

In [6]:
logp = Normal(0, Variable(torch.ones(2))).log_prob(e).sum() - torch.det(A).log()
p = logp.exp()
p

Variable containing:
1.00000e-02 *
  3.1997
[torch.FloatTensor of size 1]

In [7]:
dlogp_da = grad([logp], [a], create_graph=True)
dlogp_da

(Variable containing:
 -1.0000
 [torch.FloatTensor of size 1],)

In [8]:
dlogp_dx = torch.mv(de_dx.transpose(0, 1), grad([logp], [e], create_graph=True)[0])
dlogp_dx

Variable containing:
-0.9318
-0.2069
[torch.FloatTensor of size 2]

In [9]:
R = reparam_grad(x, a)
R

Variable containing:
 3.0023
-0.3112
[torch.FloatTensor of size 2]

In [10]:
dR = vector_field_deriv(R, e, x)
dR

Variable containing:
 0.6000 -0.2000
 0.2000  0.6000
[torch.FloatTensor of size 2x2]

In [11]:
divergence(p * dR)

Variable containing:
1.00000e-02 *
  3.8397
[torch.FloatTensor of size 1]

The following should both be zero:

In [12]:
divergence(p * dR) + dlogp_dx * p

Variable containing:
1.00000e-02 *
  0.8581
  3.1778
[torch.FloatTensor of size 2]

In [13]:
antisymmetric_part(p * dR)

Variable containing:
1.00000e-03 *
  0.0000 -6.3994
  6.3994  0.0000
[torch.FloatTensor of size 2x2]