In [None]:
%reload_ext autoreload
%autoreload 2

import torch
from torch import nn
from torch import optim
from higher import innerloop_ctx

import notebook_setup

## Gradient persistence/accumulation after `load_state_dict`
* `load_state_dict` does not reset `.grad`, so a subsequent `backward()` accumulates onto the gradients computed before the new state dict was loaded.

In [None]:
# Demonstrates that `load_state_dict` leaves `.grad` untouched, so the next
# `backward()` adds the new gradient on top of the old one.
torch.manual_seed(0)  # make the random Linear initialisations reproducible


def _grad_snapshot(model):
    """Return a list with a copy of each parameter's ``.grad`` of ``model``.

    The copy matters: a plain ``backward()`` may accumulate into an existing
    ``.grad`` tensor in place, so keeping bare references would let later
    backward calls silently rewrite an earlier "snapshot".
    """
    return [p.grad.clone() for p in model.parameters()]


m = nn.Sequential(nn.Linear(1,1,False))
(m(torch.ones(1,1))**2).backward()
g_before = _grad_snapshot(m)
print('M grads:', g_before)

m2 = nn.Sequential(nn.Linear(1,1,False))
m.load_state_dict(m2.state_dict())
# Only parameter values are copied from m2; m's .grad is left as it was.
g_after = _grad_snapshot(m)
print('After loading M2 params', g_after)

# m now holds m2's weights, so this backward adds m2's gradient to g_before.
(m(torch.ones(1,1))**2).backward()
(m2(torch.ones(1,1))**2).backward()
g2 = _grad_snapshot(m2)
print('M2 grads:', g2)

g_final = _grad_snapshot(m)
print('M grads after backward w/ new state dict:', g_final)

## `higher` context w/ `copy_initial_weights` & loading multiple `state_dict`s

`m` is converted to a functional model `fm` using `higher`, and `fm` then loads the state dicts of `m2` and `m3` in turn.

* Original weights referenced with `copy_initial_weights=False` are **not** used when new `state_dict` is loaded.
* **However** gradients are still accumulated when `copy_initial_weights=False`.

In [None]:
m = nn.Sequential(nn.Linear(1,1,False))
m2 = nn.Sequential(nn.Linear(1,1,False))
m3 = nn.Sequential(nn.Linear(1,1,False))
with torch.no_grad():
    # Distinct constant weights make it easy to see which state dict is active.
    for p in m.parameters(): p.fill_(1.)
    for p in m2.parameters(): p.fill_(2.)
    for p in m3.parameters(): p.fill_(3.)
o = optim.SGD(m.parameters(), lr=0.1)

# One entry per loaded state dict:
#   diff  -> fast-weight change p(t=1) - p(t=0)
#   grads -> d p(t=1).sum() / d p(t=0)
diff, grads = [], []
for name, m_ in zip(('m2', 'm3'), (m2, m3)):
    # With copy_initial_weights=False, higher documents that m's own parameters
    # serve as fm's initial (t=0) weights instead of detached copies.
    with innerloop_ctx(m, o, copy_initial_weights=False) as (fm, fo):
        fm.load_state_dict(m_.state_dict())
        print('m', m.state_dict())
        print(name, m_.state_dict())
        print('fm', fm.state_dict())

        loss = (fm(torch.ones(1,1))**2)
        print('dL/dm', torch.autograd.grad(loss, m.parameters(), retain_graph=True))
        fo.step(loss)
        print('After update to fm')
        print('m', m.state_dict())
        print(name, m_.state_dict())
        print('fm', fm.state_dict())
        print('Test loss:')
        loss = fm(torch.ones(1,1))**2
        loss.backward(retain_graph=True)
        print('dL/dfm', torch.autograd.grad(loss, fm.parameters(), retain_graph=True))
        print('dfm/dm', torch.autograd.grad(list(fm.parameters())[0], m.parameters(), retain_graph=True))
        print('Gradients on m', torch.autograd.grad(loss, m.parameters(), retain_graph=True))

        pdiff = []
        grad = []
        for p0, p1 in zip(fm.parameters(time=0), fm.parameters(time=1)):
            pdiff.append(p1 - p0)
            # retain_graph so models with several parameters can reuse the shared graph.
            grad.append(torch.autograd.grad(p1.sum(), p0, retain_graph=True)[0])
        diff.append(pdiff)
        grads.append(grad)
        print('=' * 10 + '\n')


print('Gradients on m: sum of test gradients on m2 m3')
print(*map(lambda p: p.grad, m.parameters()))
print('Parameter differences from m, for each state_dict loaded and updated:')
print(*diff)
print('Gradients w.r.t m, for each of updated m2, m3:')
print(*grads)

## Second order grads w/o `higher` using `create_graph` and `retain_graph`

In [None]:
def _two_sgd_steps(model, opt, create_graph):
    """Run one SGD step on ``(model(1)**2).sum()``, recompute that loss's gradient
    at the updated weights, and print each parameter next to its gradient.

    ``create_graph=True`` records a graph for the gradients (enabling
    higher-order derivatives). ``backward``'s ``retain_graph`` defaults to the
    value of ``create_graph``, so it is not passed separately.
    NOTE: PyTorch warns that ``backward(create_graph=True)`` on parameters can
    create a parameter<->grad reference cycle; ``torch.autograd.grad`` avoids it.
    """
    loss = (model(torch.ones(1))**2).sum()
    loss.backward(create_graph=create_graph)
    opt.step()

    opt.zero_grad()
    loss = (model(torch.ones(1))**2).sum()
    loss.backward(create_graph=create_graph)
    for p in model.parameters():
        print(p.detach(), p.grad.detach())


m = nn.Sequential(nn.Linear(1,1,False))
m2 = nn.Sequential(nn.Linear(1,1,False))
o = optim.SGD(m.parameters(), lr=0.1)
o2 = optim.SGD(m2.parameters(), lr=0.1)
with torch.no_grad():
    for p in m.parameters(): p.fill_(1.)
    for p in m2.parameters(): p.fill_(1.)

# Graph-building backward: the printed values should match the plain run below.
_two_sgd_steps(m, o, create_graph=True)
# Plain first-order backward for comparison.
_two_sgd_steps(m2, o2, create_graph=False)

## Accessing higher-order grads outside `innerloop_ctx`
* Fast weights remain accessible, and differentiation across their time steps is possible, outside `innerloop_ctx`.

In [None]:
m = nn.Sequential(nn.Linear(1,1,False))
o = optim.SGD(m.parameters(), lr=0.1)
with torch.no_grad():
    for p in m.parameters(): p.fill_(1.)

m2 = None  # keeps a handle on the functional model after the context exits

with innerloop_ctx(m, o, track_higher_grads=True) as (fm, fo):
    loss = (fm(torch.ones(1))**2).sum()
    fo.step(loss)  # fast weights advance from time=0 to time=1
    m2 = fm
    loss = (fm(torch.ones(1))**2).sum()
    loss.backward(retain_graph=True)
    print('Inside context')
    print('dL/dp(t=1):', torch.autograd.grad(loss, m2.parameters(time=1), retain_graph=True))
    # Reduce each parameter to a scalar so this also works for multi-element parameters.
    print('dp(t=1)/dp(t=0):', torch.autograd.grad(sum(p.sum() for p in m2.parameters()), m2.parameters(time=0), retain_graph=True))
    print('dL/dp(t=0):', torch.autograd.grad(loss, m2.parameters(time=0), retain_graph=True))

print('Outside context')
print('dp(t=1)/dp(t=0):', torch.autograd.grad(sum(p.sum() for p in m2.parameters()), m2.parameters(time=0)))
for t in (0, 1):
    print(f'p(t={t})')
    for p in m2.parameters(time=t):
        print(p.detach())