In [1]:
import math
import argparse
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from torchmeta.datasets.helpers import omniglot, miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader

import higher

import hypergrad as hg

In [2]:
from dataclasses import dataclass


@dataclass
class Args:
    seed: int=0
    dataset: str='omniglot'
    hg_mode: str='CG'
    no_cuda: bool=False    

args = Args()
args

Args(seed=0, dataset='omniglot', hg_mode='CG', no_cuda=False)

In [3]:
log_interval = 100
eval_interval = 500
inner_log_interval = None
ways = 5
inner_log_interval_test = None
batch_size = 16
n_tasks_test = 1000  # usually 1000 tasks are used for testing

In [4]:
reg_param = 2  # reg_param = 2
T, K = 16, 5  # T, K = 16, 5

In [5]:
T_test = T
inner_lr = .1

In [6]:
cuda = not args.no_cuda and torch.cuda.is_available()
cuda

True

In [7]:
device = torch.device('cuda' if cuda else 'cpu')
device

device(type='cuda')

In [8]:
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
kwargs

{'num_workers': 1, 'pin_memory': True}

In [9]:
# the following are for reproducibility on GPU,
# see https://pytorch.org/docs/master/notes/randomness.html
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.random.manual_seed(args.seed)
np.random.seed(args.seed)

In [10]:
# 5-ways 1-shot
dataset = omniglot(
    "data", ways=ways, shots=1, test_shots=15, meta_train=True, download=True)
test_dataset = omniglot(
    "data", ways=ways, shots=1, test_shots=15, meta_test=True, download=True)

In [11]:
def conv_layer(ic, oc, ):
    return nn.Sequential(
        nn.Conv2d(ic, oc, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
        nn.BatchNorm2d(oc, momentum=1., affine=True,
                       track_running_stats=True # When this is true is called the "transfuctive setting"
                       )
    )

meta_model = nn.Sequential(
    conv_layer(1, 64),
    conv_layer(64, 64),
    conv_layer(64, 64),
    conv_layer(64, 64),
    nn.Flatten(),
    nn.Linear(64, 5) # hidden_size, ways
)

In [12]:
meta_model.to(device)

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=True)
  )
  (3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [14]:
list(meta_model.parameters())[0].device

device(type='cuda', index=0)

In [15]:
# initialize weights properly
for m in meta_model.modules():
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        #m.weight.data.normal_(0, 0.01)
        #m.bias.data = torch.ones(m.bias.data.size())
        m.weight.data.zero_()
        m.bias.data.zero_()

In [16]:
dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, **kwargs)
test_dataloader = BatchMetaDataLoader(test_dataset, batch_size=batch_size, **kwargs)

In [17]:
outer_opt = torch.optim.Adam(params=meta_model.parameters())
outer_opt

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

In [18]:
inner_opt_class = hg.GradientDescent
inner_opt_kwargs = {'step_size': inner_lr}

In [19]:
def get_inner_opt(train_loss):
    return inner_opt_class(train_loss, **inner_opt_kwargs)

In [20]:
class Task:
    """
    Handles the train and valdation loss for a single task
    """
    def __init__(self, reg_param, meta_model, data, batch_size=None):
        device = next(meta_model.parameters()).device

        # stateless version of meta_model
        self.fmodel = higher.monkeypatch(meta_model, device=device, copy_initial_weights=True)

        self.n_params = len(list(meta_model.parameters()))
        self.train_input, self.train_target, self.test_input, self.test_target = data
        self.reg_param = reg_param
        self.batch_size = 1 if not batch_size else batch_size
        self.val_loss, self.val_acc = None, None

    def bias_reg_f(self, bias, params):
        # l2 biased regularization
        return sum([((b - p) ** 2).sum() for b, p in zip(bias, params)])

    def train_loss_f(self, params, hparams):
        # biased regularized cross-entropy loss where the bias are the meta-parameters in hparams
        out = self.fmodel(self.train_input, params=params)
        return F.cross_entropy(out, self.train_target) + 0.5 * self.reg_param * self.bias_reg_f(hparams, params)

    def val_loss_f(self, params, hparams):
        # cross-entropy loss (uses only the task-specific weights in params
        out = self.fmodel(self.test_input, params=params)
        val_loss = F.cross_entropy(out, self.test_target)/self.batch_size
        self.val_loss = val_loss.item()  # avoid memory leaks

        pred = out.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        self.val_acc = pred.eq(self.test_target.view_as(pred)).sum().item() / len(self.test_target)

        return val_loss

In [21]:
def inner_loop(hparams, params, optim, n_steps, log_interval, create_graph=False):
    params_history = [optim.get_opt_params(params)]

    for t in range(n_steps):
        params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))

        if log_interval and (t % log_interval == 0 or t == n_steps-1):
            print('t={}, Loss: {:.6f}'.format(t, optim.curr_loss.item()))

    return params_history

In [22]:
def evaluate(n_tasks, dataloader, meta_model, n_steps, get_inner_opt, reg_param, log_interval=None):
    meta_model.train()
    device = next(meta_model.parameters()).device

    val_losses, val_accs = [], []
    for k, batch in enumerate(dataloader):
        tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(device)
        tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(device)

        for t_idx, (tr_x, tr_y, tst_x, tst_y) in enumerate(zip(tr_xs, tr_ys, tst_xs, tst_ys)):
            task = Task(reg_param, meta_model, (tr_x, tr_y, tst_x, tst_y))
            inner_opt = get_inner_opt(task.train_loss_f)

            params = [p.detach().clone().requires_grad_(True) for p in meta_model.parameters()]
            last_param = inner_loop(meta_model.parameters(), params, inner_opt, n_steps, log_interval=log_interval)[-1]

            task.val_loss_f(last_param, meta_model.parameters())

            val_losses.append(task.val_loss)
            val_accs.append(task.val_acc)

            if len(val_accs) >= n_tasks:
                return np.array(val_losses), np.array(val_accs)

In [None]:
for k, batch in enumerate(dataloader):
    start_time = time.time()
    meta_model.train()

    tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(device)
    tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(device)

    outer_opt.zero_grad()

    val_loss, val_acc = 0, 0
    forward_time, backward_time = 0, 0
    for t_idx, (tr_x, tr_y, tst_x, tst_y) in enumerate(
            zip(tr_xs, tr_ys, tst_xs, tst_ys)
    ):
        start_time_task = time.time()

        # single task set up
        task = Task(
            reg_param, meta_model, (tr_x, tr_y, tst_x, tst_y), 
            batch_size=tr_xs.shape[0]
        )
        inner_opt = get_inner_opt(task.train_loss_f)

        # single task inner loop
        params = [
            p.detach().clone().requires_grad_(True) 
            for p in meta_model.parameters()
        ]
        last_param = inner_loop(
            meta_model.parameters(), params, inner_opt, T, 
            log_interval=inner_log_interval)[-1]
        forward_time_task = time.time() - start_time_task

        # single task hypergradient computation
        if args.hg_mode == 'CG':
            # This is the approximation used in the paper CG stands for conjugate gradient
            cg_fp_map = hg.GradientDescent(loss_f=task.train_loss_f, step_size=1.)
            hg.CG(last_param, list(meta_model.parameters()), K=K, fp_map=cg_fp_map, outer_loss=task.val_loss_f)
        elif args.hg_mode == 'fixed_point':
            hg.fixed_point(last_param, list(meta_model.parameters()), K=K, fp_map=inner_opt,
                           outer_loss=task.val_loss_f)

        backward_time_task = time.time() - start_time_task - forward_time_task

        val_loss += task.val_loss
        val_acc += task.val_acc/task.batch_size

        forward_time += forward_time_task
        backward_time += backward_time_task

    outer_opt.step()
    step_time = time.time() - start_time

    if k % log_interval == 0:
        print('MT k={} ({:.3f}s F: {:.3f}s, B: {:.3f}s) Val Loss: {:.2e}, Val Acc: {:.2f}.'
              .format(k, step_time, forward_time, backward_time, val_loss, 100. * val_acc))

    if k % eval_interval == 0:
        test_losses, test_accs = evaluate(n_tasks_test, test_dataloader, meta_model, T_test, get_inner_opt,
                                      reg_param, log_interval=inner_log_interval_test)

        print("Test loss {:.2e} +- {:.2e}: Test acc: {:.2f} +- {:.2e} (mean +- std over {} tasks)."
              .format(test_losses.mean(), test_losses.std(), 100. * test_accs.mean(),
                      100.*test_accs.std(), len(test_losses)))

MT k=0 (4.441s F: 2.535s, B: 1.900s) Val Loss: 8.89e-01, Val Acc: 69.58.
Test loss 9.38e-01 +- 1.92e-01: Test acc: 67.63 +- 1.09e+01 (mean +- std over 1000 tasks).
MT k=100 (3.357s F: 1.280s, B: 2.073s) Val Loss: 2.96e-01, Val Acc: 95.33.
MT k=200 (3.216s F: 1.230s, B: 1.981s) Val Loss: 2.16e-01, Val Acc: 97.75.
MT k=300 (3.248s F: 1.230s, B: 2.015s) Val Loss: 1.73e-01, Val Acc: 98.42.
MT k=400 (3.196s F: 1.216s, B: 1.975s) Val Loss: 1.23e-01, Val Acc: 99.17.
MT k=500 (3.248s F: 1.231s, B: 2.013s) Val Loss: 1.60e-01, Val Acc: 98.08.
Test loss 2.01e-01 +- 9.18e-02: Test acc: 97.10 +- 3.93e+00 (mean +- std over 1000 tasks).
MT k=600 (3.229s F: 1.224s, B: 2.001s) Val Loss: 1.53e-01, Val Acc: 97.42.
MT k=700 (3.236s F: 1.243s, B: 1.990s) Val Loss: 1.46e-01, Val Acc: 98.00.
MT k=800 (3.518s F: 1.277s, B: 2.236s) Val Loss: 1.01e-01, Val Acc: 97.42.
MT k=900 (3.613s F: 1.304s, B: 2.303s) Val Loss: 1.75e-01, Val Acc: 95.42.
MT k=1000 (3.575s F: 1.281s, B: 2.290s) Val Loss: 1.24e-01, Val Acc: 9

MT k=8700 (3.407s F: 1.234s, B: 2.158s) Val Loss: 5.52e-02, Val Acc: 98.50.
MT k=8800 (3.486s F: 1.285s, B: 2.195s) Val Loss: 3.17e-02, Val Acc: 99.00.
MT k=8900 (3.503s F: 1.282s, B: 2.216s) Val Loss: 9.59e-02, Val Acc: 97.92.
MT k=9000 (3.253s F: 1.277s, B: 1.969s) Val Loss: 4.36e-02, Val Acc: 98.92.
Test loss 1.18e-01 +- 1.35e-01: Test acc: 96.24 +- 4.24e+00 (mean +- std over 1000 tasks).
MT k=9100 (3.438s F: 1.284s, B: 2.149s) Val Loss: 4.67e-02, Val Acc: 98.33.
MT k=9200 (3.475s F: 1.348s, B: 2.120s) Val Loss: 6.06e-02, Val Acc: 97.50.
