In [1]:
import logging
import os
from collections import OrderedDict
from dataclasses import dataclass
from itertools import islice
from typing import Optional

import higher  # tested with higher v0.2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from tqdm import tqdm

In [2]:
@dataclass
class Args:
    """Run configuration for the few-shot meta-learning notebook.

    Every setting has a default, so ``Args()`` is a valid configuration and
    individual values can be overridden by keyword.
    """
    folder: str = 'data'                    # dataset root handed to the Omniglot helper
    num_shots: int = 5                      # `shots` argument of the dataset helper
    num_ways: int = 5                       # `ways` of the dataset helper; also the model's output size
    step_size: float = 0.4                  # learning rate of the inner-loop SGD optimizer
    hidden_size: int = 64                   # channel width of the convolutional blocks
    output_folder: Optional[str] = None     # not read by the code in this notebook
    batch_size: int = 16                    # tasks per meta-batch
    num_batches: int = 100                  # number of meta-batches to train on
    num_workers: int = 1                    # worker processes for the data loader
    download: bool = True                   # forwarded to the dataset helper's `download`
    use_cuda: bool = True                   # use the GPU when one is available
    seed: int = 0                           # random seed for the run
    hg_mode: str = 'CG'                     # not read by the code in this notebook
    # Filled in once the runtime device is known; declared here so the
    # attribute is part of the dataclass instead of being attached ad hoc.
    device: Optional[torch.device] = None

args = Args(num_shots=1, num_ways=5, num_batches=2000)

In [3]:
# Apply the configured seed so weight initialisation is repeatable between
# runs; without this call `args.seed` has no effect.
# NOTE(review): this seeds PyTorch's generators only. If task sampling draws
# from Python's `random` or NumPy, those need seeding too - confirm.
torch.manual_seed(args.seed)

# Use the GPU only when it was requested AND is actually available.
args.device = torch.device(
    'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')
args.device

device(type='cuda')

In [4]:
# Omniglot meta-training split; each item is one `num_ways`-way,
# `num_shots`-shot task.
# NOTE(review): `test_shots=15` presumably sets the number of query examples
# per class - confirm against the torchmeta helper.
dataset = omniglot(
    args.folder,
    ways=args.num_ways,
    shots=args.num_shots,
    test_shots=15,
    meta_train=True,
    shuffle=True,
    download=args.download,
)

# Collate `batch_size` tasks into a single meta-batch.
dataloader = BatchMetaDataLoader(
    dataset,
    num_workers=args.num_workers,
    batch_size=args.batch_size,
    shuffle=True,
)

In [5]:
def conv3x3(in_channels, out_channels, **kwargs):
    """Build one block: 3x3 convolution -> batch norm -> ReLU -> 2x2 max-pool.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input feature map.
    out_channels : int
        Number of channels produced by the convolution.
    **kwargs
        Extra keyword arguments forwarded to `nn.Conv2d`.

    Returns
    -------
    block : `nn.Sequential` instance
        The four layers, in the order listed above.
    """
    layers = [
        # padding=1 with a 3x3 kernel keeps the spatial size at the default stride.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        # No running statistics are tracked for this normalisation layer.
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)

class ConvolutionalNeuralNetwork(nn.Module):
    """Four stacked `conv3x3` blocks followed by a linear classifier.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input images.
    out_features : int
        Number of logits produced per example.
    hidden_size : int
        Channel width of every convolutional block. The classifier takes
        `hidden_size` inputs, so the flattened feature map must contain
        exactly `hidden_size` values per example.
    """

    def __init__(self, in_channels, out_features, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        # The first block maps the input channels to `hidden_size`; the
        # remaining three keep that width.
        widths = [in_channels] + [hidden_size] * 4
        self.features = nn.Sequential(
            *(conv3x3(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:]))
        )

        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        """Return the class logits for `inputs`.

        `params` is accepted for interface compatibility but is not used.
        """
        embeddings = self.features(inputs)
        # Collapse everything except the batch dimension before the linear layer.
        flattened = embeddings.view(embeddings.size(0), -1)
        return self.classifier(flattened)

In [6]:
def get_accuracy(logits, targets):
    """Fraction of query points whose top-scoring class equals the target.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Model outputs on the query points, with shape
        `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        Ground-truth class indices of the query points, with shape
        `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Scalar tensor holding the mean accuracy on the query points.
    """
    # The predicted class is the index of the largest logit along the last axis.
    predictions = torch.max(logits, dim=-1).indices
    hits = predictions.eq(targets).float()
    return hits.mean()

In [7]:
# Single-channel input, one logit per class of the `num_ways`-way task.
model = ConvolutionalNeuralNetwork(
    in_channels=1,
    out_features=args.num_ways,
    hidden_size=args.hidden_size,
).to(device=args.device)
model.train()

ConvolutionalNeuralNetwork(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=False)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=False)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=1.0, affine=True, track_running_stats=False)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, 

In [8]:
# Both optimizers are built over the same model parameters:
#   - `outer_opt` applies the meta-update after each meta-batch,
#   - `inner_opt` supplies the per-task adaptation step (learning rate `step_size`).
outer_opt = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
)
inner_opt = torch.optim.SGD(
    model.parameters(),
    lr=args.step_size,
)

In [9]:
# Training loop
# For every meta-batch:
#   1. take one differentiable SGD step per task on its 'train' split,
#   2. evaluate the adapted model on the task's 'test' split,
#   3. average the test losses over the tasks and update the shared weights.
#
# `islice` caps the run at exactly `args.num_batches` meta-batches. A stop
# check placed after the update would let one extra batch through.
with tqdm(islice(dataloader, args.num_batches), total=args.num_batches) as pbar:
    for batch in pbar:
        outer_opt.zero_grad()
        model.train()

        tr_xs  = batch['train'][0].to(args.device)
        tr_ys  = batch['train'][1].to(args.device)
        tst_xs = batch['test'][0].to(args.device)
        tst_ys = batch['test'][1].to(args.device)

        # The loop below visits one task per entry of the leading dimension,
        # so this is the right normaliser even for a short batch.
        num_tasks = tr_xs.size(0)

        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)

        for tr_x, tr_y, tst_x, tst_y in zip(tr_xs, tr_ys, tst_xs, tst_ys):
            # copy_initial_weights=False: the patched model starts from the
            # real `model` weights, so the outer backward pass reaches them.
            with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
                # Inner step: adapt on the task's training examples.
                tr_logit = fmodel(tr_x)
                inner_loss = F.cross_entropy(tr_logit, tr_y)

                diffopt.step(inner_loss)

                # Outer objective: loss of the adapted model on the test examples.
                tst_logit = fmodel(tst_x)
                outer_loss += F.cross_entropy(tst_logit, tst_y)

                # Accuracy is only reported, so keep it off the autograd graph.
                with torch.no_grad():
                    accuracy += get_accuracy(tst_logit, tst_y)

        outer_loss.div_(num_tasks)
        accuracy.div_(num_tasks)

        outer_loss.backward()
        outer_opt.step()

        pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

100%|█████████████████████| 2000/2000 [15:24<00:00,  2.16it/s, accuracy=0.9833]
