Testing higher order gradient calculations using `pytorch` and the `higher` library. Specifically, the gradient of model parameters with respect to their earlier version during gradient descent.

In [None]:
%reload_ext autoreload
%autoreload 2

import torch
from torch import nn
from torch import optim
from torch.autograd import grad
import higher
from higher import innerloop_ctx
from higher.optim import DifferentiableSGD
import numpy as np

import notebook_setup
import ppo, utils, meta

Take this function, where `a_[0], b_[0]` are the current parameters (t=0), `x` is the input and `y_` is the predicted model output. `a, b` and `y` are the true parameters and the true output value, respectively.

    y_ = (a_[0]**2 + b_[0]**2) * x

The loss is 

    L = 0.5 * (y_ - y) ** 2

Here, I consider just `a`. The gradients are:

    d L  / d y_  = y_ - y                   # loss w.r.t prediction
    d y_ / d a_[0] = 2 a_[0] x              # prediction w.r.t parameter
    d L  / d a_[0] = (y_ - y)(2 a_[0] x)    # loss w.r.t parameter via chain rule

I am using SGD with learning rate=1 for simplicity. So, the next version of parameter `a`, `a_[1]` will be:

    a_[1] = a_[0] -  d L  / d a_[0]        # update opposite direction of gradient
    a_[1] = a_[0] - (y_ - y)(2 a_[0] x)    # substituting
    a_[1] = a_[0] * (1 - (y_ - y)(2 x) )   # factoring out a_[0]
    a_[1] = a_[0] * (1- 2 x y_ + 2 x y)
    a_[1] = a_[0] * (1 - 2 x ((a_[0]**2 + b_[0]**2) * x) + 2 x y)
    a_[1] = a_[0] * (1 - 2 x**2 a_[0]**2 - 2 x**2 b_[0]**2 + 2 x y)
    a_[1] = a_[0] - 2 x**2 a_[0]**3 - 2 x**2 a_[0] b_[0]**2 + 2 a_[0] x y

Which gives the gradient between different versions of `a_` as:

    d a_[1] / d a_[0] = 1 - 6 x**2 a_[0]**2 - 2 x**2 b_[0]**2 + 2 x y

Assuming true `a=3, b=4`, and current parameters `a_[0]=2, b_[0]=3`, and input `x=0.1`, then

    y =  (3**2 + 4**2) * 0.1 = 2.5
    y_ = (2**2 + 3**2) * 0.1 = 1.3
    L = 0.5 (1.3 - 2.5)**2 = 0.72

    d L  / d y_ = -1.2
    d y_ / d a_[0] = 2 * 2 * 0.1 = 0.4
    d L  / d a_[0] = (1.3 - 2.5)(2 * 2 * 0.1) = -1.2 * 0.4 = -0.48

    a_[1] = 2 - (-0.48) = 2.48
    d a_[1] / d a_[0] = 1 - 6 * 0.01 * 4 - 2 * 0.01 * 9 + 2 * 0.1 * 2.5 = 1.08

Similarly, `b_[1] = 3.72` using the same method. For the meta-update step, the test loss `TL` is calculated using these updated parameters. The loss gradient is then backpropagated from time=1 to time=0 to get `d TL / d a_[0]`. Since `TL` depends on both `a_[1]` and `b_[1]`, and each of them depends on `a_[0]`, the chain rule sums the contributions through both:

    d TL / d a_[0] = (d TL / d a_[1] * d a_[1] / d a_[0]) + (d TL / d b_[1] * d b_[1] / d a_[0])

In [None]:
class Model(nn.Module):
    """Toy model computing ``(a**2 + b**2) * x`` with two scalar weights.

    Each weight lives in a bias-free ``nn.Linear(1, 1)``; applying the same
    layer twice multiplies the input by the weight twice, i.e. by its square.

    Args:
        a: Initial value written into the weight of ``self.a``.
        b: Initial value written into the weight of ``self.b``.
    """

    def __init__(self, a=2., b=3.):
        super().__init__()
        # Created in this order so that ``parameters()`` yields a, then b.
        self.a = nn.Linear(1, 1, bias=False)
        self.a.weight.data.fill_(a)
        self.b = nn.Linear(1, 1, bias=False)
        self.b.weight.data.fill_(b)
        # Alternative with bare parameters instead of modules (the forward
        # pass would then be ``(self.a ** 2 + self.b ** 2) * x``):
        #   self.a = nn.Parameter(torch.ones(1, requires_grad=True) * a)
        #   self.b = nn.Parameter(torch.ones(1, requires_grad=True) * b)

    def forward(self, x):
        """Return ``a*a*x + b*b*x`` for input ``x``."""
        a_term = self.a(self.a(x))    # a * (a * x)
        b_term = self.b(self.b(x))    # b * (b * x)
        return a_term + b_term

In [None]:
class Opt(DifferentiableSGD):
    """``DifferentiableSGD`` that prints the gradients it is about to apply.

    Used purely for inspection: the update itself is delegated unchanged to
    the parent class.
    """

    def _update(self, grouped_grads, **kwargs) -> None:
        """Print the gradients of the first group, then run the parent update.

        Args:
            grouped_grads: Indexable collection of gradient groups. Only
                ``grouped_grads[0]`` is printed, and the format string has
                two slots (labelled ``a_[0]`` and ``b_[0]``), so the first
                group needs at least two entries, each supporting ``.item()``.
            **kwargs: Forwarded untouched to ``DifferentiableSGD._update``.
        """
        # Label reads 'b_[0]:' (the colon was missing) so this line matches
        # the 'name:value' layout of the other printouts in the notebook.
        print('d Loss / d params(t=0): a_[0]:{:.3f} \t b_[0]:{:.3f}'.format(*[g.item() for g in grouped_grads[0]]))
        return super()._update(grouped_grads, **kwargs)


# Register Opt as the differentiable optimizer paired with optim.SGD.
# NOTE(review): presumably this is what makes innerloop_ctx use Opt for an
# optim.SGD instance - confirm against the `higher` documentation.
higher.register_optim(optim.SGD, Opt)

In [None]:
# Higher-order gradients: one inner SGD step (t=0 ==> t=1) followed by one
# outer (meta) update, printing each quantity from the hand derivation.
x = torch.ones(1) * 0.1
a, b = 3, 4    # true parameter values
m = Model()
o = optim.SGD(m.parameters(), lr=1.)
with higher.innerloop_ctx(m, o, copy_initial_weights=False) as (hm, ho):
    # ---- Inner step: loss at the t=0 parameters ----
    y = (a**2 + b**2) * x    # target produced by the true parameters
    y_ = hm(x)               # prediction from the current parameters
    loss = 0.5 * (y_ - y) ** 2

    print('Inner optimization')
    print('Loss:\t\t{:.3f}'.format(loss.item()))
    dloss_dpred = torch.autograd.grad(loss, y_, retain_graph=True)[0]
    print('d Loss / d y_: {:.3f}'.format(dloss_dpred.item()))
    values_t0 = [p.item() for p in hm.parameters(time=0)]
    print('Current params: \ta_[0]:{:.3f} \t b_[0]:{:.3f}'.format(*values_t0))

    # Differentiable update producing the t=1 parameters.
    ho.step(loss)

    values_t1 = [p.item() for p in hm.parameters(time=1)]
    print('Updated params: \ta_[1]:{:.3f} \t b_[1]:{:.3f}'.format(*values_t1))

    # Gradient of the first t=1 parameter (a_[1]) w.r.t. all t=0 parameters.
    first_param_t1 = list(hm.parameters(time=1))[0]
    da1_dparams_t0 = torch.autograd.grad(first_param_t1,
                                         hm.parameters(time=0),
                                         retain_graph=True)
    print('d a_[1] / d param(t=0): a_[0]:{:.3f} \t b_[0]:{:.3f}'.format(
          *[g.item() for g in da1_dparams_t0]))

    # ---- Outer step: test loss at the current (updated) parameters ----
    print()
    y_ = hm(x)
    y = (a**2 + b**2) * x
    testloss = 0.5 * (y_ - y) ** 2
    # `m`'s .grad is read right below and consumed by o.step().
    # NOTE(review): presumably this backward call is what fills it, because
    # copy_initial_weights=False keeps `m`'s own tensors as the t=0
    # parameters - confirm against the `higher` documentation.
    testloss.backward(retain_graph=True)

    print('Outer optimization using test loss')
    print('TL:\t\t{:.3f}'.format(testloss.item()))
    print('d TL / d a_[0] = [\td TL / d a_[1] * d a_[1] / d a_[0] +\n\t\t\td TL / d b_[1] * d b_[1] / d a_[0] ]')

    # Evaluate every factor of the chain rule separately, in the order printed.
    dTL_da0 = list(m.parameters())[0].grad.item()
    dTL_da1 = grad(testloss, list(hm.parameters(time=1))[0], retain_graph=True)[0].item()
    da1_da0 = grad(list(hm.parameters(time=1))[0], list(hm.parameters(time=0))[0], retain_graph=True)[0].item()
    dTL_db1 = grad(testloss, list(hm.parameters(time=1))[1], retain_graph=True)[0].item()
    # Last query through this part of the graph, so retain_graph is not passed.
    db1_da0 = grad(list(hm.parameters(time=1))[1], list(hm.parameters(time=0))[0])[0].item()
    print('%.3f' % dTL_da0, '=',
          '%.3f' % dTL_da1, '*',
          '%.3f' % da1_da0,
          ' + ',
          '%.3f' % dTL_db1, '*',
          '%.3f' % db1_da0)

    o.step()  # Outer update (meta-update) using d TL / d params
    print('Final parameters:')
    print(*m.parameters())

In [None]:
# Higher order gradient tracking with multiple models.
#
# `M` is the model that receives the meta-update; `m` is a second model
# loaded with the same weights and used for the inner loop. The gradient of
# the test loss w.r.t. `m`'s t=0 parameters is copied onto `M` by hand and
# applied with `M`'s own optimizer `O`.
# Relies on `x`, `a`, `b` and `Model` defined in earlier cells.

M, m = Model(), Model()
m.load_state_dict(M.state_dict())
O, o = optim.SGD(M.parameters(), lr=1.), optim.SGD(m.parameters(), lr=1.)

with higher.innerloop_ctx(m, o, copy_initial_weights=False) as (hm, ho):
    # Inner step t=0 ==> t=1.
    y_ = hm(x)
    y = (a**2 + b**2) * x
    loss = 0.5 * (y_ - y) ** 2
    ho.step(loss)

    # Test loss evaluated after the inner step.
    y_ = hm(x)
    y = (a**2 + b**2) * x
    testloss = 0.5 * (y_ - y) ** 2
    # Kept from the original; retain_graph is needed because the graph is
    # queried again just below.
    testloss.backward(retain_graph=True)

    # d TL / d params(t=0) for every parameter in a single autograd query.
    # The t=1 parameters were zipped in before but never used, so they are
    # no longer iterated.
    meta_grads = grad(testloss, list(hm.parameters(time=0)), retain_graph=True)
    for P, meta_grad in zip(M.parameters(), meta_grads):
        if P.grad is None:
            P.grad = torch.zeros_like(P)
        P.grad.add_(meta_grad)

    O.step()
    print('Final parameters:')
    print(*M.parameters())