In [1]:
import math
import argparse
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from torchmeta.datasets.helpers import omniglot, miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader

import higher

from dataclasses import dataclass


@dataclass
class Args:
    seed: int=0
    dataset: str='omniglot'
    hg_mode: str='CG'
    no_cuda: bool=False    

args = Args()

log_interval = 100
eval_interval = 500
inner_log_interval = None
ways = 5
inner_log_interval_test = None
batch_size = 16
n_tasks_test = 1000  # usually 1000 tasks are used for testing


reg_param = 2  # reg_param = 2
T, K = 16, 5  # T, K = 16, 5

T_test = T
inner_lr = .1

cuda = not args.no_cuda and torch.cuda.is_available()

device = torch.device('cuda' if cuda else 'cpu')
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

In [2]:
args

Args(seed=0, dataset='omniglot', hg_mode='CG', no_cuda=False)

In [3]:
device = torch.device('cuda' if cuda else 'cpu')
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

# 5-ways 1-shot
dataset = omniglot(
    "data", ways=ways, shots=1, test_shots=15, meta_train=True, download=True)
test_dataset = omniglot(
    "data", ways=ways, shots=1, test_shots=15, meta_test=True, download=True)

In [4]:
def conv_layer(ic, oc, ):
    return nn.Sequential(
        nn.Conv2d(ic, oc, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
        nn.BatchNorm2d(oc, momentum=1., affine=True,
                       track_running_stats=True # When this is true is called the "transfuctive setting"
                       )
    )

meta_model = nn.Sequential(
    conv_layer(1, 64),
    conv_layer(64, 64),
    conv_layer(64, 64),
    conv_layer(64, 64),
    nn.Flatten(),
    nn.Linear(64, 5) # hidden_size, ways
)

for m in meta_model.modules():
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        #m.weight.data.normal_(0, 0.01)
        #m.bias.data = torch.ones(m.bias.data.size())
        m.weight.data.zero_()
        m.bias.data.zero_()
        
meta_model

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [5]:
dataloader = BatchMetaDataLoader(
    dataset, batch_size=batch_size, **kwargs)
test_dataloader = BatchMetaDataLoader(
    test_dataset, batch_size=batch_size, **kwargs)

In [6]:
outer_opt = torch.optim.Adam(params=meta_model.parameters())
outer_opt

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

In [7]:
for k, batch in enumerate(dataloader):
    break
    
for i in batch['train']:
    print(i.size())

print()

for i in batch['test']:
    print(i.size())

train_input  = tr_x = tr_xs = batch['train'][0][0]
train_target = tr_y = tr_ys = batch['train'][1][0]
test_input   = ts_x = ts_xs = batch['test'][0][0]
test_target  = ts_y = ts_ys = batch['test'][1][0]

torch.Size([16, 5, 1, 28, 28])
torch.Size([16, 5])

torch.Size([16, 75, 1, 28, 28])
torch.Size([16, 75])


In [8]:
fmodel = higher.monkeypatch(meta_model, copy_initial_weights=True)
fmodel

FunctionalSequential(
  (0): InnerFunctionalSequential(
    (0): InnerFunctionalConv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): InnerFunctionalReLU(inplace=True)
    (2): InnerFunctionalMaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): InnerFunctionalBatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (1): InnerFunctionalSequential(
    (0): InnerFunctionalConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): InnerFunctionalReLU(inplace=True)
    (2): InnerFunctionalMaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): InnerFunctionalBatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (2): InnerFunctionalSequential(
    (0): InnerFunctionalConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): InnerFunctionalReLU(inplace=True)
    (2): InnerFunctionalMaxPool2d(kernel_size=2, stride=2, paddi

In [9]:
import torch
from itertools import repeat

In [10]:
n_params = len(list(meta_model.parameters()))

batch_size = 16

val_loss = None
val_acc = None

inner_opt_kwargs = {'step_size': inner_lr}

In [11]:
n_params

18

In [12]:
class DifferentiableOptimizer:
    def __init__(self, loss_f, dim_mult, data_or_iter=None):
        """
        Args:
            loss_f: callable with signature (params, hparams, [data optional]) -> loss tensor
            data_or_iter: (x, y) or iterator over the data needed for loss_f
        """
        self.data_iterator = None
        if data_or_iter:
            self.data_iterator = data_or_iter if hasattr(data_or_iter, '__next__') else repeat(data_or_iter)

        self.loss_f = loss_f
        self.dim_mult = dim_mult
        self.curr_loss = None

    def get_opt_params(self, params):
        opt_params = [p for p in params]
        opt_params.extend([torch.zeros_like(p) for p in params for _ in range(self.dim_mult-1) ])
        return opt_params

    def step(self, params, hparams, create_graph):
        raise NotImplementedError

    def __call__(self, params, hparams, create_graph=True):
        with torch.enable_grad():
            return self.step(params, hparams, create_graph)

    def get_loss(self, params, hparams):
        if self.data_iterator:
            data = next(self.data_iterator)
            self.curr_loss = self.loss_f(params, hparams, data)
        else:
            self.curr_loss = self.loss_f(params, hparams)
        return self.curr_loss

In [13]:
class GradientDescent(DifferentiableOptimizer):
    def __init__(self, loss_f, step_size, data_or_iter=None):
        super(GradientDescent, self).__init__(loss_f, dim_mult=1, data_or_iter=data_or_iter)
        self.step_size_f = step_size if callable(step_size) else lambda x: step_size

    def step(self, params, hparams, create_graph):
        loss = self.get_loss(params, hparams)
        sz = self.step_size_f(hparams)
        return gd_step(params, loss, sz, create_graph=create_graph)

In [14]:
def gd_step(params, loss, step_size, create_graph=True):
    grads = torch.autograd.grad(loss, params, create_graph=create_graph)
    return [w - step_size * g for w, g in zip(params, grads)]

In [15]:
inner_opt_class = GradientDescent

In [16]:
def get_inner_opt(train_loss):
    return inner_opt_class(train_loss, **inner_opt_kwargs)

In [17]:
def bias_reg_f(bias, params):
    # L2 biasd regularization
    return sum([((b-p)**2).sum() for b, p in zip(bias, params)])


def train_loss_f(params, hparams):
    out = fmodel(train_input, params=params)
    return F.cross_entropy(out, train_target) + \
           0.5*reg_param*bias_reg_f(hparams, params)

In [18]:
inner_opt = inner_opt_class(train_loss_f, **inner_opt_kwargs)
inner_opt

<__main__.GradientDescent at 0x225546edef0>

$$\theta_0=\theta_{meta}$$

In [19]:
# single task inner loop
params = [
    p.detach().clone().requires_grad_(True)
    for p in meta_model.parameters()
]

In [20]:
from typing import Generator, List

In [21]:
def inner_loop(
    hparams: Generator[torch.Tensor, None, None],
    params: Generator[torch.Tensor, None, None],
    optim: GradientDescent,
    n_steps: int,
    log_interval: bool,
    create_graph=False,
) -> List[List[torch.Tensor]]:
    params_history = [optim.get_opt_params(params)]
    for t in range(n_steps):
        params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))
        if log_interval and (t % log_interval == 0 or t == n_steps-1):
            print(f't={t}, Loss: {optim.curr_loss.item():.6f}')
    return params_history

In [22]:
hparams = meta_model.parameters()
optim = inner_opt
n_steps = T # 16
log_interval = inner_log_interval

In [23]:
last_param = inner_loop(
    meta_model.parameters(), params, inner_opt, T, log_interval=True
)[-1]

t=0, Loss: 1.609438
t=1, Loss: 0.602081
t=2, Loss: 0.430455
t=3, Loss: 0.574521
t=4, Loss: 0.267405
t=5, Loss: 0.140030
t=6, Loss: 0.083962
t=7, Loss: 0.066514
t=8, Loss: 0.056141
t=9, Loss: 0.049419
t=10, Loss: 0.044100
t=11, Loss: 0.040031
t=12, Loss: 0.036698
t=13, Loss: 0.033906
t=14, Loss: 0.031534
t=15, Loss: 0.029502


In [24]:
cg_fp_map = GradientDescent(loss_f=train_loss_f, step_size=1.)

In [25]:
# Computes the hypergradient by applying K steps of the
# conjugate gradient method (CG).
# It can end earlier when tol is reached

params = [w.detach().requires_grad_(True) for w in last_param]
hparams = list(meta_model.parameters())
stochastic = False
set_grad = True
tol = 1e-10

In [26]:
# outer_loss via task.val_loss_f
out = fmodel(test_input, params=params)
val_loss = F.cross_entropy(out, test_target) / batch_size
o_loss = val_loss

In [27]:
# get outer gradients
def grad_unused_zero(
    output,
    inputs,
    grad_outputs=None,
    retain_graph=False,
    create_graph=False,
):
    grads = torch.autograd.grad(
        output, inputs, grad_outputs, allow_unused=True,
        retain_graph=retain_graph, create_graph=create_graph)
    return tuple(
        torch.zeros_like(v) if g is None else g
        for g, v in zip(grads, inputs)
    )

In [28]:
grad_outer_w = grad_unused_zero(o_loss, params)

In [29]:
len(grad_outer_w)

18

In [30]:
grad_outer_hparams = grad_unused_zero(o_loss, hparams)

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

In [None]:
fp_map = cg_fp_map

In [48]:


def dfp_map_dw(xs):
    if stochastic:
        w_mapped_in = fp_map(params, hparams)
        Jfp_mapTv = torch.autograd.grad(w_mapped_in, params, grad_outputs=xs, retain_graph=False)
    else:
        Jfp_mapTv = torch.autograd.grad(w_mapped, params, grad_outputs=xs, retain_graph=True)
    return [v - j for v, j in zip(xs, Jfp_mapTv)]

# Conjugate Gradient
Ax = dfp_map_dw
b = grad_outer_w
max_iter = 100
epsilon = 1.0e-5

x_last = [torch.zeros_like(bb) for bb in b]
r_last = [torch.zeros_like(bb).copy_(bb) for bb in b]
p_last = [torch.zeros_like(rr).copy_(rr) for rr in r_last]

for ii in range(max_iter):
    Ap = Ax(p_last)
    Ap_vec = torch.cat([xx.view(-1) for xx in Ap])
    p_last_vec = torch.cat([xx.view(-1) for xx in p_last])
    r_last_vec = torch.cat([xx.view(-1) for xx in r_last])
    rTr = torch.sum(r_last_vec * r_last_vec)
    pAp = torch.sum(p_last_vec * Ap_vec)
    alpha = rTr / pAp
    
    x = [xx + alpha * pp for xx, pp in zip(x_last, p_last)]    
    r = [rr - alpha * pp for rr, pp in zip(r_last, Ap)]
    r_vec = torch.cat([xx.view(-1) for xx in r_last])
    
    if float(torch.norm(r_vec)) < epsilon:
        break
        
    beta = torch.sum(r_vec * r_vec) / rTr
    p = [rr + beta * pp for rr, pp in zip(r, p_last)]
    
    x_last = x
    p_last = p
    r_last = r
    
vs = x_last

if stochastic:
    w_mapped = fp_map(params, hparams)
    
grads = torch.autograd.grad(w_mapped, hparams, grad_outputs=vs)
grads = [g + v for g, v in zip(grads, grad_outer_hparams)]

if set_grad:
    for l, g in zip(hparams, grads):
        if l.grad is None:
            l.grad = torch.zeros_like(l)
        if g is not None:
            l.grad += g

TypeError: forward() takes 2 positional arguments but 3 were given

In [None]:
hg.CG(
    last_param, 
    list(meta_model.parameters()), 
    K=K, 
    fp_map=cg_fp_map, 
    outer_loss=task.val_loss_f)