In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

# 1. Does backpropping through a flatten work?

In [7]:
# Leaf tensor whose gradient we inspect after pushing it through a reshape.
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)

# Reinterpret the flat vector as a 2x2 matrix.
z = x.reshape(2, 2)

# Squared norm of z, written as norm * norm; its gradient w.r.t. x is 2 * x.
z_norm = torch.norm(z)
loss = z_norm * z_norm
loss.backward()

In [8]:
# Gradient of the loss with respect to the flat leaf x, populated by loss.backward().
x.grad

tensor([2., 4., 6., 8.])

In [10]:
# Inspect the reshaped tensor; its grad_fn shows it is still attached to x's graph.
z

tensor([[1., 2.],
        [3., 4.]], grad_fn=<ViewBackward>)

In [11]:
# z is derived from x rather than being a leaf, so no gradient is stored on it
# (the cell displays nothing).
z.grad

In [12]:
# x keeps its original flat shape; reshape returned a new tensor z.
x.shape

torch.Size([4])

# 2. Computing Hessian-vector products

In [35]:
# setup for now:
# f maps [n], [m] to [1], so Hessian is n x n or n x m or m x m depending on what's being done
# Have a vector v, want to compute Hv

def f(x, y):
    """Return sum(x**3) * sum(y**3) as a scalar tensor."""
    cubes_x = torch.pow(x, 3).sum()
    cubes_y = torch.pow(y, 3).sum()
    return cubes_x * cubes_y

def g(x, y, theta):
    """Return theta**2 * sum(x**3) * sum(y**3).

    `theta` may be a Python number or a tensor; when it is a tensor that
    requires grad, the result is differentiable with respect to it.
    """
    scale = theta ** 2
    cubes_x = torch.pow(x, 3).sum()
    cubes_y = torch.pow(y, 3).sum()
    # Same left-to-right multiplication order as the one-line formula.
    return scale * cubes_x * cubes_y

# Problem sizes: x has n entries and y has m entries.
n, m = 4, 3
# Test inputs x = [0, ..., n-1] and y = [0, ..., m-1] as float tensors.
x, y = (torch.Tensor(range(size)) for size in (n, m))
# Sanity check: g with theta = 1 should agree with f.
print(f(x, y))
g(x, y, 1)

tensor(324.)


tensor(324.)

In [25]:
def hvp(g, x, y, v):
    """Compute the mixed Hessian-vector product d/dx [ (dg/dy) . v ].

    Args:
        g: callable taking (x, y) tensors and returning a scalar tensor.
        x: tensor of inputs to differentiate with respect to last.
        y: tensor of inputs to differentiate with respect to first.
        v: tensor with the same shape as y, contracted against dg/dy.

    Returns:
        A tensor shaped like x holding (d^2 g / dx dy) v. The result carries
        no autograd history, so it cannot be differentiated further; use
        hvp2 when gradients through the product are needed. If the product
        does not depend on x at all, zeros shaped like x are returned.
    """
    # Fresh leaves so the caller's tensors are never touched.
    xvar = x.detach().requires_grad_(True)
    yvar = y.detach().requires_grad_(True)
    # v is tracked as well so that `total` always has a grad_fn, even when
    # dg/dy turns out to be constant and carries no graph of its own.
    vvar = v.detach().requires_grad_(True)

    score = g(xvar, yvar)

    # create_graph=True keeps dg/dy differentiable for the second pass.
    grad, = torch.autograd.grad(score, yvar, create_graph=True)
    total = torch.sum(grad * vvar)

    # autograd.grad (rather than total.backward()) returns the derivative
    # directly and does not accumulate into .grad of unrelated leaves such as
    # parameters captured in g's closure.
    mixed, = torch.autograd.grad(total, xvar, allow_unused=True)
    if mixed is None:
        # total is independent of x: the mixed Hessian block is zero.
        return torch.zeros_like(x)
    return mixed

In [24]:
# hvp2 keeps its result attached to the autograd graph, so we can backpropagate
# through the Hessian-vector product; create_graph=True on both grad calls is what makes this possible.

def hvp2(g, x, y, v):
    """Differentiable mixed Hessian-vector product d/dx [ (dg/dy) . v ].

    Args:
        g: callable taking (x, y) tensors and returning a scalar tensor.
        x: tensor of inputs to differentiate with respect to last.
        y: tensor of inputs to differentiate with respect to first.
        v: tensor with the same shape as y, contracted against dg/dy.

    Returns:
        A tensor shaped like x holding (d^2 g / dx dy) v. Both derivative
        passes use create_graph=True, so the result stays attached to the
        autograd graph and can be backpropagated through (for example to
        parameters captured in g's closure). If the product does not depend
        on x at all, zeros shaped like x are returned instead of None.
    """
    # Fresh leaves so the caller's tensors are never touched.
    xvar = x.detach().requires_grad_(True)
    yvar = y.detach().requires_grad_(True)
    # v is tracked as well so that `total` always has a grad_fn, even when
    # dg/dy turns out to be constant and carries no graph of its own.
    vvar = v.detach().requires_grad_(True)

    score = g(xvar, yvar)

    grad, = torch.autograd.grad(score, yvar, create_graph=True)
    total = torch.sum(grad * vvar)

    grad2, = torch.autograd.grad(total, xvar, create_graph=True, allow_unused=True)
    if grad2 is None:
        # allow_unused=True yields None when total is independent of x;
        # report the zero mixed Hessian block explicitly.
        return torch.zeros_like(xvar)
    return grad2

In [15]:
# Direction vector contracted against df/dy (same length as y).
v = torch.tensor([2.0, 1.0, 0.0])
# Both implementations should report the same product.
for product_fn in (hvp, hvp2):
    print("result", product_fn(f, x, y, v))

tensor([  0., 108., 432.], grad_fn=<MulBackward0>)
tensor(108., grad_fn=<SumBackward0>)
result tensor([ 0.,  9., 36., 81.])
tensor([  0., 108., 432.], grad_fn=<MulBackward0>)
tensor(108., grad_fn=<SumBackward0>)
result tensor([ 0.,  9., 36., 81.], grad_fn=<MulBackward0>)


In [None]:
# so this breaks when the Hessian doesn't survive somehow

In [47]:
# This will not work: hvp returns a tensor with no autograd history, so the
# product cannot be differentiated with respect to theta.

# theta is created fresh here, so there is no stale gradient to reset first.
theta = torch.tensor([1.0], requires_grad=True)
x = torch.Tensor(range(n))
y = torch.Tensor(range(m))
v = torch.Tensor([2, 1, 0])

score = hvp(lambda x, y: g(x, y, theta), x, y, v)
print(score)
try:
    torch.sum(score).backward()
    print(theta.grad)
except RuntimeError as err:
    # The failure is the point of this cell; show it without stopping a
    # Restart & Run All.
    print("RuntimeError:", err)

tensor([ 0.,  9., 36., 81.])


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [49]:
# But this will: hvp2 keeps the autograd graph, so the product can be
# differentiated with respect to theta.

# theta is created fresh here, so there is no stale gradient to reset first.
theta = torch.tensor([1.0], requires_grad=True)
x = torch.Tensor(range(n))
y = torch.Tensor(range(m))
v = torch.Tensor([2, 1, 0])

score = hvp2(lambda x, y: g(x, y, theta), x, y, v)
print(score)
torch.sum(score).backward()
print(theta.grad)

tensor([ 0.,  9., 36., 81.], grad_fn=<MulBackward0>)
tensor([252.])
