In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import logging

from collections import OrderedDict

import higher  # tested with higher v0.2

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

from dataclasses import dataclass

In [2]:
@dataclass
class Args:
    """Run configuration for the few-shot Omniglot experiment.

    Instances are plain mutable dataclass objects; later cells attach extra
    attributes (e.g. ``device``) to the instance at runtime.
    """
    folder: str = 'data'          # root folder handed to the Omniglot helper
    num_shots: int = 5            # `shots` argument of the dataset helper
    num_ways: int = 5             # `ways` argument; also the model's output size
    step_size: float = 0.4
    hidden_size: int = 64         # channel width of the convolutional model
    output_folder: str = None
    batch_size: int = 16          # tasks per meta-batch
    num_batches: int = 100        # number of meta-batches to train on
    num_workers: int = 1          # dataloader worker processes
    download: bool = True         # forwarded to the dataset helper
    use_cuda: bool = True         # use the GPU when one is available
    seed: int = 0
    hg_mode: str = 'CG'


# Overrides for this run: num_shots=1, num_ways=5, num_batches=2000.
args = Args(num_shots=1, num_ways=5, num_batches=2000)

In [3]:
# Use the GPU only when it was requested *and* is actually present.
if args.use_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
args.device

device(type='cuda')

In [4]:
# Meta-training split of Omniglot, configured from `args`; `test_shots` is
# fixed at 15 for every task.
dataset = omniglot(
    args.folder,
    ways=args.num_ways,
    shots=args.num_shots,
    test_shots=15,
    meta_train=True,
    shuffle=True,
    download=args.download,
)

# Groups `batch_size` tasks into one meta-batch.
dataloader = BatchMetaDataLoader(
    dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=True,
)

In [5]:
def conv3x3(in_channels, out_channels, **kwargs):
    """Build one block: 3x3 convolution -> batch norm -> ReLU -> 2x2 max-pool.

    Parameters
    ----------
    in_channels : int
        Number of input channels of the convolution.
    out_channels : int
        Number of output channels of the convolution (and batch-norm width).
    **kwargs
        Extra keyword arguments forwarded to `nn.Conv2d`.

    Returns
    -------
    block : `nn.Sequential` instance
        The four layers, in the order listed above.
    """
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs)
    # No running statistics are tracked for this batch norm.
    normalisation = nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False)
    layers = [convolution, normalisation, nn.ReLU(), nn.MaxPool2d(2)]
    return nn.Sequential(*layers)

class ConvolutionalNeuralNetwork(nn.Module):
    """Four stacked `conv3x3` blocks followed by a single linear classifier.

    Parameters
    ----------
    in_channels : int
        Channels of the input images.
    out_features : int
        Number of logits produced by the classifier.
    hidden_size : int
        Channel width of every conv block. The classifier is built as
        `nn.Linear(hidden_size, out_features)`, so the flattened feature map
        must contain exactly `hidden_size` values per example.
    """

    def __init__(self, in_channels, out_features, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        # Channel widths at the boundaries of the four conv blocks.
        widths = [in_channels] + [hidden_size] * 4
        self.features = nn.Sequential(
            *(conv3x3(n_in, n_out) for n_in, n_out in zip(widths, widths[1:]))
        )

        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        """Return the logits for `inputs`.

        `params` is accepted for call-site compatibility but is not used.
        """
        embedding = self.features(inputs)
        # Collapse everything except the batch dimension.
        flat = embedding.view(embedding.size(0), -1)
        return self.classifier(flat)

In [6]:
def get_accuracy(logits, targets):
    """Fraction of query points whose highest-scoring class equals the target.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Model outputs on the query points, of shape
        `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        Ground-truth class indices of the query points, of shape
        `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Mean accuracy over the query points (a scalar tensor).
    """
    # Index of the largest logit along the class dimension.
    predictions = torch.max(logits, dim=-1).indices
    is_correct = (predictions == targets).float()
    return is_correct.mean()

In [7]:
# Single-channel input, one logit per way; moved to the selected device.
model = ConvolutionalNeuralNetwork(
    in_channels=1,
    out_features=args.num_ways,
    hidden_size=args.hidden_size,
).to(device=args.device)
model.train()

ConvolutionalNeuralNetwork(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=False)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=False)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=False)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, 

In [8]:
# Outer (meta) optimizer: Adam over all parameters of `model`.
outer_opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Logging / evaluation settings. `None` leaves the inner-loop logging
# intervals unset.
log_interval = 100
eval_interval = 500
inner_log_interval = None
inner_log_interval_test = None

ways = 5
batch_size = 16
n_tasks_test = 1000  # usually 1000 tasks are used for testing


reg_param = 2  # reg_param = 2
T, K = 16, 5  # T, K = 16, 5

T_test = T
inner_lr = .1
# Keyword arguments that get_inner_opt() unpacks into GradientDescent;
# without this definition get_inner_opt() raises NameError.
inner_opt_kwargs = {'step_size': inner_lr}

# `Args` declares `use_cuda`; it has no `no_cuda` field, so the previous
# `not args.no_cuda` raised AttributeError.
cuda = args.use_cuda and torch.cuda.is_available()

device = torch.device('cuda' if cuda else 'cpu')
# NOTE(review): presumably DataLoader keyword arguments — confirm against the
# code that consumes `kwargs`.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

In [9]:
class Task:
    """
    Bundles the training and validation losses of a single task.

    The meta-model is wrapped with `higher.monkeypatch` so that its forward
    pass can be evaluated with an explicit list of parameters.

    Args:
        reg_param: weight of the biased L2 regularizer in the training loss.
        meta_model: module whose architecture (and device) is reused.
        data: 4-tuple (train_input, train_target, test_input, test_target).
        batch_size: divisor applied to the validation loss; a falsy value
            (None, 0) is treated as 1.
    """
    def __init__(self, reg_param, meta_model, data, batch_size=None):
        model_device = next(meta_model.parameters()).device

        # functional (stateless) version of meta_model
        self.fmodel = higher.monkeypatch(meta_model, device=model_device, copy_initial_weights=True)

        self.n_params = len(list(meta_model.parameters()))
        (self.train_input, self.train_target,
         self.test_input, self.test_target) = data
        self.reg_param = reg_param
        self.batch_size = batch_size if batch_size else 1
        # Filled in by val_loss_f as plain Python numbers.
        self.val_loss, self.val_acc = None, None

    def bias_reg_f(self, bias, params):
        """Sum of squared differences between paired tensors of `bias` and `params`."""
        total = 0
        for b, p in zip(bias, params):
            total = total + ((b - p) ** 2).sum()
        return total

    def train_loss_f(self, params, hparams):
        """Cross-entropy on the training split plus the L2 penalty pulling
        `params` towards the meta-parameters `hparams`."""
        logits = self.fmodel(self.train_input, params=params)
        data_term = F.cross_entropy(logits, self.train_target)
        penalty = 0.5 * self.reg_param * self.bias_reg_f(hparams, params)
        return data_term + penalty

    def val_loss_f(self, params, hparams):
        """Cross-entropy on the test split, divided by `batch_size`.

        Only the task-specific weights in `params` are used. As a side
        effect, records `val_loss` and `val_acc` as Python floats.
        """
        logits = self.fmodel(self.test_input, params=params)
        loss = F.cross_entropy(logits, self.test_target) / self.batch_size
        self.val_loss = loss.item()  # store a float, not the graph-holding tensor

        # index of the highest logit per example
        predicted = logits.argmax(dim=1, keepdim=True)
        n_correct = predicted.eq(self.test_target.view_as(predicted)).sum().item()
        self.val_acc = n_correct / len(self.test_target)

        return loss

In [11]:
import torch
from itertools import repeat

In [12]:
def gd_step(params, loss, step_size, create_graph=True):
    """Apply one gradient-descent update to every tensor in `params`.

    Args:
        params: sequence of tensors that `loss` was computed from.
        loss: scalar tensor to differentiate.
        step_size: multiplier applied to each gradient.
        create_graph: forwarded to `torch.autograd.grad`.

    Returns:
        A new list with `w - step_size * grad_w` for each `w` in `params`.
    """
    gradients = torch.autograd.grad(loss, params, create_graph=create_graph)
    updated = []
    for weight, gradient in zip(params, gradients):
        updated.append(weight - step_size * gradient)
    return updated


class DifferentiableOptimizer:
    """Base class for optimizers whose update step is built from autograd ops.

    Subclasses implement `step`; calling the instance runs `step` with
    gradient tracking enabled.
    """

    def __init__(self, loss_f, dim_mult, data_or_iter=None):
        """
        Args:
            loss_f: callable with signature (params, hparams, [data optional]) -> loss tensor
            dim_mult: size of the optimizer state relative to the parameters;
                `get_opt_params` returns `dim_mult` tensors per parameter.
            data_or_iter: (x, y) or iterator over the data needed for loss_f.
                `None` means `loss_f` is called without a data argument.
        """
        self.data_iterator = None
        # Compare with None explicitly: truthiness raises for a multi-element
        # tensor and depends on __len__ for some iterators.
        if data_or_iter is not None:
            self.data_iterator = data_or_iter if hasattr(data_or_iter, '__next__') else repeat(data_or_iter)

        self.loss_f = loss_f
        self.dim_mult = dim_mult
        self.curr_loss = None

    def get_opt_params(self, params):
        """Return `params` followed by `dim_mult - 1` zero tensors per parameter."""
        # Materialise once so that a generator (e.g. `model.parameters()`)
        # is not exhausted before the auxiliary zeros are created.
        base_params = list(params)
        aux_state = [torch.zeros_like(p) for p in base_params for _ in range(self.dim_mult - 1)]
        return base_params + aux_state

    def step(self, params, hparams, create_graph):
        """Compute the updated parameters; must be provided by subclasses."""
        raise NotImplementedError

    def __call__(self, params, hparams, create_graph=True):
        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            return self.step(params, hparams, create_graph)

    def get_loss(self, params, hparams):
        """Evaluate `loss_f`, store the result in `curr_loss` and return it.

        When a data iterator was configured, its next item is passed as the
        third argument of `loss_f`.
        """
        if self.data_iterator is not None:
            data = next(self.data_iterator)
            self.curr_loss = self.loss_f(params, hparams, data)
        else:
            self.curr_loss = self.loss_f(params, hparams)
        return self.curr_loss
    
class GradientDescent(DifferentiableOptimizer):
    """Differentiable plain gradient descent (no extra optimizer state).

    Args:
        loss_f: loss callable, see `DifferentiableOptimizer`.
        step_size: either a number or a callable mapping the hyperparameters
            to the step size.
        data_or_iter: optional data forwarded to the base class.
    """

    def __init__(self, loss_f, step_size, data_or_iter=None):
        super().__init__(loss_f, dim_mult=1, data_or_iter=data_or_iter)
        # Normalise to a callable so `step` can always call it with hparams.
        self.step_size_f = step_size if callable(step_size) else (lambda x: step_size)

    def step(self, params, hparams, create_graph):
        """Return the parameters after one gradient step on the current loss."""
        current_loss = self.get_loss(params, hparams)
        lr = self.step_size_f(hparams)
        # one update per task-specific parameter tensor
        return gd_step(params, current_loss, lr, create_graph=create_graph)

def get_inner_opt(train_loss):
    """Create the inner-loop `GradientDescent` optimizer for `train_loss`.

    Keyword arguments are taken from the module-level `inner_opt_kwargs`
    mapping, which must be defined before this function is called.
    """
    return GradientDescent(train_loss, **inner_opt_kwargs)

In [None]:
from typing import Callable, Generator, Iterable, List, Optional


def inner_loop(
    hparams: Iterable[torch.Tensor],
    params: Iterable[torch.Tensor],
    optim: "DifferentiableOptimizer",
    n_steps: int,
    log_interval: Optional[int],
    create_graph=False,
) -> List[List[torch.Tensor]]:
    """Run `n_steps` optimizer updates and return every intermediate state.

    Args:
        hparams: hyperparameters passed unchanged to every optimizer step.
        params: initial task parameters; expanded via `optim.get_opt_params`.
        optim: callable optimizer `(params, hparams, create_graph) -> params`
            exposing `get_opt_params` and `curr_loss`.
        n_steps: number of updates to perform.
        log_interval: print the loss every `log_interval` steps (and at the
            last step); a falsy value (None, 0) disables printing.
        create_graph: forwarded to each optimizer step.

    Returns:
        A list of `n_steps + 1` parameter lists: the initial state followed
        by the state after each step.
    """
    # hparams is read again at every step. A one-shot iterator (generator)
    # would be empty from the second step on, so materialise it up front.
    if hasattr(hparams, '__next__'):
        hparams = list(hparams)

    params_history = [optim.get_opt_params(params)]
    for t in range(n_steps):
        params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))
        if log_interval and (t % log_interval == 0 or t == n_steps-1):
            print(f't={t}, Loss: {optim.curr_loss.item():.6f}')
    return params_history

In [None]:
# Training loop
#
# higher.innerloop_ctx wraps a torch.optim optimizer into a differentiable
# one, so the inner optimizer must be a torch.optim instance. The custom
# GradientDescent returned by get_inner_opt() is not one and is not used
# here. The inner step size is args.step_size.
inner_opt = torch.optim.SGD(model.parameters(), lr=args.step_size)

with tqdm(dataloader, total=args.num_batches) as pbar:
    for k, batch in enumerate(pbar):
        # Check before doing any work so exactly `num_batches` meta-batches
        # are processed (checking after the step ran one extra batch).
        if k >= args.num_batches:
            break

        outer_opt.zero_grad()
        model.train()

        # Each tensor stacks one entry per task along dimension 0.
        tr_xs  = batch['train'][0].to(args.device)
        tr_ys  = batch['train'][1].to(args.device)
        tst_xs = batch['test'][0].to(args.device)
        tst_ys = batch['test'][1].to(args.device)
        n_tasks = tr_xs.shape[0]

        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)

        for tr_x, tr_y, tst_x, tst_y in zip(tr_xs, tr_ys, tst_xs, tst_ys):
            # copy_initial_weights=False keeps the adapted weights connected
            # to `model`'s parameters, so the outer loss backpropagates into
            # the meta-parameters.
            with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
                # One adaptation step on the task's training split.
                tr_logit = fmodel(tr_x)
                inner_loss = F.cross_entropy(tr_logit, tr_y)

                diffopt.step(inner_loss)

                # Evaluate the adapted model on the task's test split.
                tst_logit = fmodel(tst_x)
                outer_loss += F.cross_entropy(tst_logit, tst_y)

                with torch.no_grad():
                    accuracy += get_accuracy(tst_logit, tst_y)

        # Average over the tasks actually present in this meta-batch.
        outer_loss.div_(n_tasks)
        accuracy.div_(n_tasks)

        outer_loss.backward()
        outer_opt.step()

        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))