In [3]:
import random
import numpy as np
import torch
import learn2learn as l2l

from torch import nn, optim

In [4]:
def accuracy(predictions, targets):
    """Return the fraction of samples whose highest-scoring class equals the target.

    Args:
        predictions: tensor of per-class scores; the arg-max is taken along dim 1.
        targets: tensor of class indices. The arg-max result is reshaped to
            ``targets.shape`` before the element-wise comparison.

    Returns:
        A 0-dim float tensor: number of matches divided by ``targets.size(0)``.
    """
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    num_correct = predicted_classes.eq(targets).sum().float()
    return num_correct / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt ``learner`` on part of one task and score it on the rest.

    The samples at even positions ``0, 2, ..., 2 * (shots * ways - 1)`` form the
    adaptation set; every other sample in the batch forms the evaluation set.

    Args:
        batch: ``(data, labels)`` pair; both are moved to ``device``.
        learner: callable model exposing ``adapt(loss_value)``.
        loss: callable ``loss(predictions, labels)``.
        adaptation_steps: number of ``learner.adapt`` calls to perform.
        shots, ways: their product is the number of adaptation samples.
        device: target device for the batch tensors.

    Returns:
        ``(valid_error, valid_accuracy)`` computed on the evaluation set after
        adaptation.

    Raises:
        IndexError: if the batch has fewer than ``2 * shots * ways - 1`` samples
            (raised by the index assignment that marks the adaptation set).
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Separate data into adaptation/evaluation sets via complementary boolean masks
    support_mask = np.zeros(inputs.size(0), dtype=bool)
    support_mask[2 * np.arange(shots * ways)] = True
    query_mask = torch.from_numpy(np.logical_not(support_mask))
    support_mask = torch.from_numpy(support_mask)

    support_inputs = inputs[support_mask]
    support_targets = targets[support_mask]
    query_inputs = inputs[query_mask]
    query_targets = targets[query_mask]

    # Adapt the model: one loss evaluation and one adapt() call per step
    for _ in range(adaptation_steps):
        support_error = loss(learner(support_inputs), support_targets)
        learner.adapt(support_error)

    # Evaluate the adapted model on the held-out half of the task
    query_predictions = learner(query_inputs)
    query_error = loss(query_predictions, query_targets)
    query_accuracy = accuracy(query_predictions, query_targets)
    return query_error, query_accuracy

In [5]:
# Task layout: `ways` is passed as train_ways/test_ways; fast_adapt uses
# shots * ways samples of each task for adaptation.
ways = 5
shots = 5

# Learning rates: meta_lr goes to the Adam optimizer, fast_lr to l2l.algorithms.MAML.
meta_lr = 0.003
fast_lr = 0.5

# Loop sizes: tasks per iteration (also the number of meta-test tasks),
# adapt() calls per task, and number of outer iterations.
meta_batch_size = 32
adaptation_steps = 1
num_iterations = 6000  # 60000

# Runtime switches: request the GPU, and the seed shared by all RNGs.
cuda = False
seed = 42

In [None]:
# Seed every random number generator in use so that runs are repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Default to the CPU. Switch to the GPU only when it was requested AND is actually
# available, so that `cuda=True` on a CPU-only machine falls back to the CPU instead
# of failing later when the model and batches are moved to the device.
device = torch.device('cpu')
if cuda and torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    device = torch.device('cuda')
elif cuda:
    print('CUDA was requested but is not available; falling back to CPU')

# Build the Omniglot train/validation/test tasksets through the benchmark helper.
# 2 * shots samples are requested so that fast_adapt can split each task into an
# adaptation part and an evaluation part.
tasksets = l2l.vision.benchmarks.get_tasksets(
    'omniglot',
    train_ways=ways,
    train_samples=2 * shots,
    test_ways=ways,
    test_samples=2 * shots,
    num_tasks=20000,
    root='~/data',
)

# Fully-connected classifier with 28 * 28 inputs and `ways` outputs, moved to `device`
model = l2l.vision.models.OmniglotFC(28 ** 2, ways)
model.to(device)

# Wrap the model in MAML (second-order, since first_order=False); the meta-optimizer
# updates the wrapped parameters and the loss is the mean cross-entropy.
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
opt = optim.Adam(maml.parameters(), lr=meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

# Meta-training: each iteration accumulates gradients over `meta_batch_size` training
# tasks, averages them, and takes one optimizer step. A validation task is also
# evaluated per training task, for reporting only.
for iteration in range(num_iterations):
    opt.zero_grad()
    meta_train_error = 0.0
    meta_train_accuracy = 0.0
    meta_valid_error = 0.0
    meta_valid_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-training loss on a fresh clone and accumulate its gradients
        learner = maml.clone()
        batch = tasksets.train.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        evaluation_error.backward()
        meta_train_error += evaluation_error.item()
        meta_train_accuracy += evaluation_accuracy.item()

        # Compute meta-validation loss (no backward(): it contributes no gradients)
        learner = maml.clone()
        batch = tasksets.validation.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_valid_error += evaluation_error.item()
        meta_valid_accuracy += evaluation_accuracy.item()

    # Print some metrics (averages over the tasks of this iteration)
    print('\n')
    print('Iteration', iteration)
    print('Meta Train Error', meta_train_error / meta_batch_size)
    print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
    print('Meta Valid Error', meta_valid_error / meta_batch_size)
    print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

    # Average the accumulated gradients and optimize.
    # Parameters that received no gradient have `grad is None`; skip them instead
    # of raising AttributeError.
    for p in maml.parameters():
        if p.grad is not None:
            p.grad.data.mul_(1.0 / meta_batch_size)
    opt.step()

# Meta-testing: adapt a fresh clone on each of `meta_batch_size` test tasks and
# report the mean post-adaptation error and accuracy.
meta_test_error, meta_test_accuracy = 0.0, 0.0
for task in range(meta_batch_size):
    learner = maml.clone()
    batch = tasksets.test.sample()
    evaluation_error, evaluation_accuracy = fast_adapt(
        batch, learner, loss, adaptation_steps, shots, ways, device,
    )
    meta_test_error = meta_test_error + evaluation_error.item()
    meta_test_accuracy = meta_test_accuracy + evaluation_accuracy.item()

print('Meta Test Error', meta_test_error / meta_batch_size)
print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

Files already downloaded and verified
Files already downloaded and verified


  return Variable._execution_engine.run_backward(




Iteration 0
Meta Train Error 1.536802627146244
Meta Train Accuracy 0.3575000006239861
Meta Valid Error 1.5398506298661232
Meta Valid Accuracy 0.32874999660998583


Iteration 1
Meta Train Error 1.4320994541049004
Meta Train Accuracy 0.43374999705702066
Meta Valid Error 1.4276456125080585
Meta Valid Accuracy 0.46625000098720193


Iteration 2
Meta Train Error 1.27377974614501
Meta Train Accuracy 0.5412499979138374
Meta Valid Error 1.2977323308587074
Meta Valid Accuracy 0.5150000043213367


Iteration 3
Meta Train Error 1.2362913750112057
Meta Train Accuracy 0.4824999966658652
Meta Valid Error 1.222744420170784
Meta Valid Accuracy 0.4974999953992665


Iteration 4
Meta Train Error 1.1136803310364485
Meta Train Accuracy 0.566249999217689
Meta Valid Error 1.1581104155629873
Meta Valid Accuracy 0.5525000002235174


Iteration 5
Meta Train Error 1.05033278465271
Meta Train Accuracy 0.5875000050291419
Meta Valid Error 1.1074686236679554
Meta Valid Accuracy 0.5574999982491136


Iteration 6
Meta T



Iteration 51
Meta Train Error 0.6029725852422416
Meta Train Accuracy 0.7862499980255961
Meta Valid Error 0.6270178821869195
Meta Valid Accuracy 0.76999999769032


Iteration 52
Meta Train Error 0.5638855006545782
Meta Train Accuracy 0.7975000031292439
Meta Valid Error 0.5920722875744104
Meta Valid Accuracy 0.8049999941140413


Iteration 53
Meta Train Error 0.6192152760922909
Meta Train Accuracy 0.7637500008568168
Meta Valid Error 0.6239639734849334
Meta Valid Accuracy 0.7699999986216426


Iteration 54
Meta Train Error 0.5474966447800398
Meta Train Accuracy 0.8124999981373549
Meta Valid Error 0.673550701700151
Meta Valid Accuracy 0.7500000027939677


Iteration 55
Meta Train Error 0.5796393433120102
Meta Train Accuracy 0.7974999938160181
Meta Valid Error 0.6460838518105447
Meta Valid Accuracy 0.7824999997392297


Iteration 56
Meta Train Error 0.5222367318347096
Meta Train Accuracy 0.8174999970942736
Meta Valid Error 0.6239109169691801
Meta Valid Accuracy 0.7787499967962503


Iteration 5



Iteration 101
Meta Train Error 0.4068642808124423
Meta Train Accuracy 0.8612499963492155
Meta Valid Error 0.4615884420927614
Meta Valid Accuracy 0.8412499986588955


Iteration 102
Meta Train Error 0.37257821252569556
Meta Train Accuracy 0.8774999938905239
Meta Valid Error 0.5567077035084367
Meta Valid Accuracy 0.798750001937151


Iteration 103
Meta Train Error 0.3660126437898725
Meta Train Accuracy 0.8775000013411045
Meta Valid Error 0.4825038560666144
Meta Valid Accuracy 0.8249999973922968


Iteration 104
Meta Train Error 0.42495828215032816
Meta Train Accuracy 0.8474999945610762
Meta Valid Error 0.4416588945314288
Meta Valid Accuracy 0.837500000372529


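### Validation accuracy increases from 29% to 69.4% (note: the log shown above goes from about 33% at iteration 0 to about 84% at iteration 104, so these figures appear to come from a different run)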

In [31]:
# torch.save(model.state_dict(), 'meta_omniglot2.pth')

In [5]:
# Rebuild the 5-way network and restore previously saved weights.
ways = 5
shots = 1

model = l2l.vision.models.OmniglotFC(28 ** 2, ways)
# map_location='cpu' lets a checkpoint saved from CUDA tensors load on a machine
# without a GPU; the freshly built model lives on the CPU anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load('meta_omniglot.pth', map_location='cpu'))
# model.eval()

OmniglotFC(
  (features): Sequential(
    (0): Flatten()
    (1): Sequential(
      (0): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(256, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=784, out_features=256, bias=True)
      )
      (1): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(128, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=256, out_features=128, bias=True)
      )
      (2): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=128, out_features=64, bias=True)
      )
      (3): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=64, out_features=64, bias=True)
      )
   