<a href="https://colab.research.google.com/github/lakshmirnair/applet/blob/master/maml_mini_imagenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
import numpy as np

import torch
from torch import nn, optim
!pip install learn2learn
import learn2learn as l2l
from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)


def accuracy(predictions, targets):
    """Return the fraction of samples whose arg-max prediction equals the target.

    Args:
        predictions: Score/logit tensor; the predicted class is the arg-max
            along dim 1.
        targets: Integer class labels. The predicted classes are reshaped to
            ``targets.shape`` before the element-wise comparison.

    Returns:
        A 0-dim float tensor: number of correct predictions divided by
        ``targets.size(0)``.
    """
    predicted_classes = predictions.argmax(dim=1)
    hits = predicted_classes.view(targets.shape).eq(targets)
    num_correct = hits.sum().float()
    return num_correct / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt ``learner`` on one task's support split and score it on the rest.

    Args:
        batch: ``(data, labels)`` pair for a single task.
        learner: Model clone exposing ``__call__`` and ``adapt(loss)``.
        loss: Criterion called as ``loss(predictions, labels)``.
        adaptation_steps: Number of ``learner.adapt`` calls on the support split.
        shots: Support samples per class.
        ways: Number of classes in the task.
        device: Device that data and labels are moved to.

    Returns:
        ``(evaluation_error, evaluation_accuracy)`` tensors computed on the
        query (evaluation) split after adaptation.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Even positions 0, 2, ..., 2*(shots*ways - 1) form the adaptation
    # (support) split; all remaining positions form the evaluation split.
    # NOTE(review): assumes the batch holds at least 2*shots*ways samples
    # (otherwise the mask assignment raises IndexError) and that the sampler
    # interleaves samples so even positions cover every class — TODO confirm.
    support_mask_np = np.zeros(inputs.size(0), dtype=bool)
    support_mask_np[2 * np.arange(shots * ways)] = True
    query_mask = torch.from_numpy(~support_mask_np)
    support_mask = torch.from_numpy(support_mask_np)
    support_x, support_y = inputs[support_mask], targets[support_mask]
    query_x, query_y = inputs[query_mask], targets[query_mask]

    # NOTE(review): if ``loss`` already averages over the batch (as the
    # CrossEntropyLoss(reduction='mean') built in main does), dividing by the
    # sample count again scales the loss — and thus the inner-loop step — by
    # an extra 1/N. Kept as-is because fast_lr was chosen under this scaling.
    for _ in range(adaptation_steps):
        support_loss = loss(learner(support_x), support_y) / len(support_x)
        learner.adapt(support_loss)

    # Score the adapted learner on the held-out evaluation split.
    query_logits = learner(query_x)
    query_loss = loss(query_logits, query_y) / len(query_x)
    query_acc = accuracy(query_logits, query_y)
    return query_loss, query_acc


def _adapt_on_task(maml, taskset, loss, adaptation_steps, shots, ways, device):
    """Clone ``maml``, draw one task from ``taskset``, and run ``fast_adapt`` on it.

    Returns the ``(evaluation_error, evaluation_accuracy)`` tensors produced
    by ``fast_adapt``. Cloning happens before sampling, in the same order the
    training loop originally used, so seeded runs draw identical tasks.
    """
    learner = maml.clone()
    batch = taskset.sample()
    return fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device)


def main(
        ways=5,
        shots=3,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
):
    """Meta-train MAML on mini-ImageNet, then report meta-test metrics.

    Args:
        ways: Classes per task; also the output size of ``MiniImagenetCNN``.
        shots: Adaptation samples per class. Tasks are sampled with
            ``2*shots`` samples per class (half adapt, half evaluate).
        meta_lr: Learning rate of the outer Adam optimizer.
        fast_lr: Inner-loop learning rate passed to ``l2l.algorithms.MAML``.
        meta_batch_size: Tasks per meta-update; also the number of tasks
            averaged for the final meta-test report.
        adaptation_steps: Inner-loop steps per task.
        num_iterations: Number of outer (meta) updates.
        cuda: Use a CUDA device when one is available.
        seed: Seed for ``random``, NumPy and PyTorch (CPU and CUDA).

    Raises:
        ValueError: If ``meta_batch_size`` is smaller than 1 (every metric is
            averaged over it).
    """
    if meta_batch_size < 1:
        raise ValueError('meta_batch_size must be at least 1')

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
                                                  train_samples=2*shots,
                                                  train_ways=ways,
                                                  test_samples=2*shots,
                                                  test_ways=ways,
                                                  root='~/data',
    )

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for _ in range(meta_batch_size):
            # Meta-training: backprop the post-adaptation loss into maml's
            # parameters; gradients accumulate across the meta-batch.
            evaluation_error, evaluation_accuracy = _adapt_on_task(
                maml, tasksets.train, loss, adaptation_steps, shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Meta-validation: metrics only, no backward pass.
            evaluation_error, evaluation_accuracy = _adapt_on_task(
                maml, tasksets.validation, loss, adaptation_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize. Parameters that
        # received no gradient (p.grad is None) are skipped instead of crashing.
        with torch.no_grad():
            for p in maml.parameters():
                if p.grad is not None:
                    p.grad.mul_(1.0 / meta_batch_size)
        opt.step()

    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for _ in range(meta_batch_size):
        # Compute meta-testing loss
        evaluation_error, evaluation_accuracy = _adapt_on_task(
            maml, tasksets.test, loss, adaptation_steps, shots, ways, device)
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)


if __name__ == '__main__':
    # Script entry point: launch meta-training with the default hyper-parameters.
    main()

Collecting learn2learn
  Downloading https://files.pythonhosted.org/packages/da/9c/6ac5cf155baee6279a1fe4dcad8ed59131c7c33c5a805523609ec2d85810/learn2learn-0.1.2.tar.gz (321kB)
     |████████████████████████████████| 327kB 1.7MB/s
Building wheels for collected packages: learn2learn
  Building wheel for learn2learn (setup.py) ... done
  Created wheel for learn2learn: filename=learn2learn-0.1.2-cp36-cp36m-linux_x86_64.whl size=836491 sha256=68a951818e324311b292a210fc239b99afbcc81dae059dcd21f74e864e4e5943
  Stored in directory: /root/.cache/pip/wheels/37/ae/e2/fe45cb0d10f64d01ed2d9ff42904d919924da021796b05a575
Successfully built learn2learn
Installing collected packages: learn2learn
Successfully installed learn2learn-0.1.2
Downloading mini-ImageNet -- train
Downloading: /root/data/mini-imagenet-cache-train.pkl
Downloading mini-ImageNet -- validation
Downloading: /root/data/mini-imagenet-cache-validation.pkl
Downloading mini-ImageNet -- test
Downloading: /root/data/mi





Iteration 0
Meta Train Error 0.11082327854819596
Meta Train Accuracy 0.23541667452082038
Meta Valid Error 0.10900857369415462
Meta Valid Accuracy 0.29166667559184134


Iteration 1
Meta Train Error 0.10696659074164927
Meta Train Accuracy 0.2750000087544322
Meta Valid Error 0.10944312950596213
Meta Valid Accuracy 0.23333334177732468


Iteration 2
Meta Train Error 0.10701689939014614
Meta Train Accuracy 0.266666674753651
Meta Valid Error 0.11008032504469156
Meta Valid Accuracy 0.27083334140479565


Iteration 3
Meta Train Error 0.10628528567031026
Meta Train Accuracy 0.2958333413116634
Meta Valid Error 0.1079073476139456
Meta Valid Accuracy 0.258333342615515


Iteration 4
Meta Train Error 0.10433838004246354
Meta Train Accuracy 0.29375000996515155
Meta Valid Error 0.10663084127008915
Meta Valid Accuracy 0.2791666758712381


Iteration 5
Meta Train Error 0.10512840421870351
Meta Train Accuracy 0.27500000805594027
Meta Valid Error 0.10445646056905389
Meta Valid Accuracy 0.29583334061317146
