In [3]:
import random
import numpy as np
import torch
import learn2learn as l2l

from torch import nn, optim

In [4]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy

In [5]:
ways=5
shots=5
meta_lr=0.003
fast_lr=0.5
meta_batch_size=32
adaptation_steps=1
num_iterations=6000 # 60000
cuda=False
seed=42

In [None]:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device('cpu')
if cuda:
    torch.cuda.manual_seed(seed)
    device = torch.device('cuda')

# Load train/validation/test tasksets using the benchmark interface
tasksets = l2l.vision.benchmarks.get_tasksets('omniglot',
                                              train_ways=ways,
                                              train_samples=2*shots,
                                              test_ways=ways,
                                              test_samples=2*shots,
                                              num_tasks=20000,
                                              root='~/data',
)

# Create model
model = l2l.vision.models.OmniglotFC(28 ** 2, ways)
model.to(device)
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
opt = optim.Adam(maml.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

for iteration in range(num_iterations):
    opt.zero_grad()
    meta_train_error = 0.0
    meta_train_accuracy = 0.0
    meta_valid_error = 0.0
    meta_valid_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-training loss
        learner = maml.clone()
        batch = tasksets.train.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        evaluation_error.backward()
        meta_train_error += evaluation_error.item()
        meta_train_accuracy += evaluation_accuracy.item()

        # Compute meta-validation loss
        learner = maml.clone()
        batch = tasksets.validation.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_valid_error += evaluation_error.item()
        meta_valid_accuracy += evaluation_accuracy.item()

    # Print some metrics
    print('\n')
    print('Iteration', iteration)
    print('Meta Train Error', meta_train_error / meta_batch_size)
    print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
    print('Meta Valid Error', meta_valid_error / meta_batch_size)
    print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

    # Average the accumulated gradients and optimize
    for p in maml.parameters():
        p.grad.data.mul_(1.0 / meta_batch_size)
    opt.step()

meta_test_error = 0.0
meta_test_accuracy = 0.0
for task in range(meta_batch_size):
    # Compute meta-testing loss
    learner = maml.clone()
    batch = tasksets.test.sample()
    evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                       learner,
                                                       loss,
                                                       adaptation_steps,
                                                       shots,
                                                       ways,
                                                       device)
    meta_test_error += evaluation_error.item()
    meta_test_accuracy += evaluation_accuracy.item()
print('Meta Test Error', meta_test_error / meta_batch_size)
print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

Files already downloaded and verified
Files already downloaded and verified


  return Variable._execution_engine.run_backward(




Iteration 0
Meta Train Error 1.536802627146244
Meta Train Accuracy 0.3575000006239861
Meta Valid Error 1.5398506298661232
Meta Valid Accuracy 0.32874999660998583


Iteration 1
Meta Train Error 1.4320994541049004
Meta Train Accuracy 0.43374999705702066
Meta Valid Error 1.4276456125080585
Meta Valid Accuracy 0.46625000098720193


Iteration 2
Meta Train Error 1.27377974614501
Meta Train Accuracy 0.5412499979138374
Meta Valid Error 1.2977323308587074
Meta Valid Accuracy 0.5150000043213367


Iteration 3
Meta Train Error 1.2362913750112057
Meta Train Accuracy 0.4824999966658652
Meta Valid Error 1.222744420170784
Meta Valid Accuracy 0.4974999953992665


Iteration 4
Meta Train Error 1.1136803310364485
Meta Train Accuracy 0.566249999217689
Meta Valid Error 1.1581104155629873
Meta Valid Accuracy 0.5525000002235174


Iteration 5
Meta Train Error 1.05033278465271
Meta Train Accuracy 0.5875000050291419
Meta Valid Error 1.1074686236679554
Meta Valid Accuracy 0.5574999982491136


Iteration 6
Meta T



Iteration 51
Meta Train Error 0.6029725852422416
Meta Train Accuracy 0.7862499980255961
Meta Valid Error 0.6270178821869195
Meta Valid Accuracy 0.76999999769032


Iteration 52
Meta Train Error 0.5638855006545782
Meta Train Accuracy 0.7975000031292439
Meta Valid Error 0.5920722875744104
Meta Valid Accuracy 0.8049999941140413


Iteration 53
Meta Train Error 0.6192152760922909
Meta Train Accuracy 0.7637500008568168
Meta Valid Error 0.6239639734849334
Meta Valid Accuracy 0.7699999986216426


Iteration 54
Meta Train Error 0.5474966447800398
Meta Train Accuracy 0.8124999981373549
Meta Valid Error 0.673550701700151
Meta Valid Accuracy 0.7500000027939677


Iteration 55
Meta Train Error 0.5796393433120102
Meta Train Accuracy 0.7974999938160181
Meta Valid Error 0.6460838518105447
Meta Valid Accuracy 0.7824999997392297


Iteration 56
Meta Train Error 0.5222367318347096
Meta Train Accuracy 0.8174999970942736
Meta Valid Error 0.6239109169691801
Meta Valid Accuracy 0.7787499967962503


Iteration 5



Iteration 101
Meta Train Error 0.4068642808124423
Meta Train Accuracy 0.8612499963492155
Meta Valid Error 0.4615884420927614
Meta Valid Accuracy 0.8412499986588955


Iteration 102
Meta Train Error 0.37257821252569556
Meta Train Accuracy 0.8774999938905239
Meta Valid Error 0.5567077035084367
Meta Valid Accuracy 0.798750001937151


Iteration 103
Meta Train Error 0.3660126437898725
Meta Train Accuracy 0.8775000013411045
Meta Valid Error 0.4825038560666144
Meta Valid Accuracy 0.8249999973922968


Iteration 104
Meta Train Error 0.42495828215032816
Meta Train Accuracy 0.8474999945610762
Meta Valid Error 0.4416588945314288
Meta Valid Accuracy 0.837500000372529


Iteration 105
Meta Train Error 0.41561083612032235
Meta Train Accuracy 0.8400000017136335
Meta Valid Error 0.4820490009151399
Meta Valid Accuracy 0.8187500014901161


Iteration 106
Meta Train Error 0.44474571268074214
Meta Train Accuracy 0.8475000001490116
Meta Valid Error 0.4874760943930596
Meta Valid Accuracy 0.8349999953061342






Iteration 151
Meta Train Error 0.35462665162049234
Meta Train Accuracy 0.8774999976158142
Meta Valid Error 0.42801580345258117
Meta Valid Accuracy 0.8562499973922968


Iteration 152
Meta Train Error 0.2809527018107474
Meta Train Accuracy 0.9137499984353781
Meta Valid Error 0.34081632690504193
Meta Valid Accuracy 0.878749992698431


Iteration 153
Meta Train Error 0.30456839001271874
Meta Train Accuracy 0.8962499983608723
Meta Valid Error 0.411830335855484
Meta Valid Accuracy 0.8500000014901161


Iteration 154
Meta Train Error 0.2779368464834988
Meta Train Accuracy 0.9037499986588955
Meta Valid Error 0.3696462311781943
Meta Valid Accuracy 0.8774999938905239


Iteration 155
Meta Train Error 0.32053715881193057
Meta Train Accuracy 0.8899999894201756
Meta Valid Error 0.37894187250640243
Meta Valid Accuracy 0.8737499993294477


Iteration 156
Meta Train Error 0.27366888837423176
Meta Train Accuracy 0.8962499983608723
Meta Valid Error 0.3819354916922748
Meta Valid Accuracy 0.8624999988824129



Iteration 201
Meta Train Error 0.24707239301642403
Meta Train Accuracy 0.9137499965727329
Meta Valid Error 0.29126606264617294
Meta Valid Accuracy 0.8962500020861626


Iteration 202
Meta Train Error 0.3089115236653015
Meta Train Accuracy 0.8974999915808439
Meta Valid Error 0.3212393276626244
Meta Valid Accuracy 0.8987499997019768


Iteration 203
Meta Train Error 0.2849200149066746
Meta Train Accuracy 0.90374999307096
Meta Valid Error 0.28353696735575795
Meta Valid Accuracy 0.8887499943375587


Iteration 204
Meta Train Error 0.2736281306715682
Meta Train Accuracy 0.9024999979883432
Meta Valid Error 0.3571910404134542
Meta Valid Accuracy 0.8799999970942736


Iteration 205
Meta Train Error 0.2551861989777535
Meta Train Accuracy 0.9225000031292439
Meta Valid Error 0.3153311505448073
Meta Valid Accuracy 0.8837499972432852


Iteration 206
Meta Train Error 0.27101614652201533
Meta Train Accuracy 0.907499996945262
Meta Valid Error 0.25466814381070435
Meta Valid Accuracy 0.9112499989569187






Iteration 251
Meta Train Error 0.23037701623979956
Meta Train Accuracy 0.919999998062849
Meta Valid Error 0.2405779764521867
Meta Valid Accuracy 0.9162499997764826


Iteration 252
Meta Train Error 0.2969686323776841
Meta Train Accuracy 0.8950000014156103
Meta Valid Error 0.2504047908587381
Meta Valid Accuracy 0.9137500002980232


Iteration 253
Meta Train Error 0.19117928302148357
Meta Train Accuracy 0.938749996945262
Meta Valid Error 0.256315587554127
Meta Valid Accuracy 0.9124999959021807


Iteration 254
Meta Train Error 0.21247798099648207
Meta Train Accuracy 0.9224999956786633
Meta Valid Error 0.29706951859407127
Meta Valid Accuracy 0.8974999934434891


Iteration 255
Meta Train Error 0.18880320701282471
Meta Train Accuracy 0.9387500006705523
Meta Valid Error 0.25661282625515014
Meta Valid Accuracy 0.91499999538064


Iteration 256
Meta Train Error 0.2308747199131176
Meta Train Accuracy 0.929999990388751
Meta Valid Error 0.3061440803576261
Meta Valid Accuracy 0.9037500005215406


It



Iteration 301
Meta Train Error 0.19752800470450893
Meta Train Accuracy 0.9262499958276749
Meta Valid Error 0.23837256786646321
Meta Valid Accuracy 0.9174999985843897


Iteration 302
Meta Train Error 0.2079161443398334
Meta Train Accuracy 0.9287499934434891
Meta Valid Error 0.25505164510104805
Meta Valid Accuracy 0.9099999982863665


Iteration 303
Meta Train Error 0.20277081400854513
Meta Train Accuracy 0.9262499939650297
Meta Valid Error 0.21688472159439698
Meta Valid Accuracy 0.9287500008940697


Iteration 304
Meta Train Error 0.21822233451530337
Meta Train Accuracy 0.9224999900907278
Meta Valid Error 0.2826466680271551
Meta Valid Accuracy 0.9024999961256981


Iteration 305
Meta Train Error 0.19901868072338402
Meta Train Accuracy 0.9437499940395355
Meta Valid Error 0.2389902548165992
Meta Valid Accuracy 0.9112499915063381


Iteration 306
Meta Train Error 0.2820074902847409
Meta Train Accuracy 0.900000000372529
Meta Valid Error 0.2876481598650571
Meta Valid Accuracy 0.897499995306134



Iteration 351
Meta Train Error 0.1614101103041321
Meta Train Accuracy 0.9424999933689833
Meta Valid Error 0.22516273404471576
Meta Valid Accuracy 0.9237499944865704


Iteration 352
Meta Train Error 0.21884868229972199
Meta Train Accuracy 0.9162499941885471
Meta Valid Error 0.20839666348183528
Meta Valid Accuracy 0.9162499960511923


Iteration 353
Meta Train Error 0.18415020301472396
Meta Train Accuracy 0.9349999967962503
Meta Valid Error 0.24673769995570183
Meta Valid Accuracy 0.921249995008111


Iteration 354
Meta Train Error 0.1770447087183129
Meta Train Accuracy 0.932499997317791
Meta Valid Error 0.22490212495904416
Meta Valid Accuracy 0.9199999943375587


Iteration 355
Meta Train Error 0.17481615304131992
Meta Train Accuracy 0.9412499964237213
Meta Valid Error 0.23606489732628688
Meta Valid Accuracy 0.9174999948590994


Iteration 356
Meta Train Error 0.1469544388819486
Meta Train Accuracy 0.9562499932944775
Meta Valid Error 0.2225899116601795
Meta Valid Accuracy 0.917499998584389



Iteration 401
Meta Train Error 0.1628221096470952
Meta Train Accuracy 0.9462499972432852
Meta Valid Error 0.19099367246963084
Meta Valid Accuracy 0.9387499932199717


Iteration 402
Meta Train Error 0.19245153613155708
Meta Train Accuracy 0.9287499971687794
Meta Valid Error 0.2179392963880673
Meta Valid Accuracy 0.9237499944865704


Iteration 403
Meta Train Error 0.17747654597042128
Meta Train Accuracy 0.941249992698431
Meta Valid Error 0.21246909239562228
Meta Valid Accuracy 0.9237499963492155


Iteration 404
Meta Train Error 0.18318187427939847
Meta Train Accuracy 0.9312499947845936
Meta Valid Error 0.2122811548761092
Meta Valid Accuracy 0.9312499985098839


Iteration 405
Meta Train Error 0.15162242844235152
Meta Train Accuracy 0.9524999931454659
Meta Valid Error 0.2504459325573407
Meta Valid Accuracy 0.90625


Iteration 406
Meta Train Error 0.15480412675242405
Meta Train Accuracy 0.9524999987334013
Meta Valid Error 0.2571554478199687
Meta Valid Accuracy 0.9187499992549419


Iterati



Iteration 451
Meta Train Error 0.2445527340460103
Meta Train Accuracy 0.9212499987334013
Meta Valid Error 0.1918119480542373
Meta Valid Accuracy 0.9299999941140413


Iteration 452
Meta Train Error 0.1645031736115925
Meta Train Accuracy 0.9499999918043613
Meta Valid Error 0.17106435509049334
Meta Valid Accuracy 0.9374999925494194


Iteration 453
Meta Train Error 0.178689655585913
Meta Train Accuracy 0.9374999981373549
Meta Valid Error 0.1672688511898741
Meta Valid Accuracy 0.9512499943375587


Iteration 454
Meta Train Error 0.16509550163755193
Meta Train Accuracy 0.9512499943375587
Meta Valid Error 0.22014814140857197
Meta Valid Accuracy 0.9249999951571226


Iteration 455
Meta Train Error 0.15083706588484347
Meta Train Accuracy 0.9449999928474426
Meta Valid Error 0.1886145050957566
Meta Valid Accuracy 0.9399999976158142


Iteration 456
Meta Train Error 0.18041544951847754
Meta Train Accuracy 0.9424999970942736
Meta Valid Error 0.15394959767581895
Meta Valid Accuracy 0.9437499940395355



Iteration 500
Meta Train Error 0.1814212769677397
Meta Train Accuracy 0.9362499937415123
Meta Valid Error 0.17370201071025804
Meta Valid Accuracy 0.9412499964237213


Iteration 501
Meta Train Error 0.14940897715860046
Meta Train Accuracy 0.9487499948590994
Meta Valid Error 0.21982096030842513
Meta Valid Accuracy 0.9237499963492155


Iteration 502
Meta Train Error 0.15463732622447424
Meta Train Accuracy 0.9487499948590994
Meta Valid Error 0.2025276096246671
Meta Valid Accuracy 0.9262499921023846


Iteration 503
Meta Train Error 0.1134947017126251
Meta Train Accuracy 0.9624999910593033
Meta Valid Error 0.17394430887361523
Meta Valid Accuracy 0.94624999538064


Iteration 504
Meta Train Error 0.1401271996437572
Meta Train Accuracy 0.9599999934434891
Meta Valid Error 0.21201236231718212
Meta Valid Accuracy 0.9287499953061342


Iteration 505
Meta Train Error 0.1376265232975129
Meta Train Accuracy 0.9499999955296516
Meta Valid Error 0.21726135135395452
Meta Valid Accuracy 0.9249999951571226



Iteration 549
Meta Train Error 0.1745303625939414
Meta Train Accuracy 0.9437499940395355
Meta Valid Error 0.18132130964659154
Meta Valid Accuracy 0.9412499964237213


Iteration 550
Meta Train Error 0.11823118009488098
Meta Train Accuracy 0.9549999944865704
Meta Valid Error 0.22964353105635382
Meta Valid Accuracy 0.9237499944865704


Iteration 551
Meta Train Error 0.1461891626531724
Meta Train Accuracy 0.9549999963492155
Meta Valid Error 0.17587152210762724
Meta Valid Accuracy 0.9337499979883432


Iteration 552
Meta Train Error 0.1591399202006869
Meta Train Accuracy 0.9449999984353781
Meta Valid Error 0.1985223634983413
Meta Valid Accuracy 0.9299999978393316


Iteration 553
Meta Train Error 0.10801589960465208
Meta Train Accuracy 0.9637499935925007
Meta Valid Error 0.12999779760139063
Meta Valid Accuracy 0.9574999939650297


Iteration 554
Meta Train Error 0.14631950019975193
Meta Train Accuracy 0.9474999997764826
Meta Valid Error 0.1614259929046966
Meta Valid Accuracy 0.93999999575316



Iteration 598
Meta Train Error 0.11827685168827884
Meta Train Accuracy 0.9574999921023846
Meta Valid Error 0.18742793903220445
Meta Valid Accuracy 0.9299999922513962


Iteration 599
Meta Train Error 0.15004396527365316
Meta Train Accuracy 0.9437499959021807
Meta Valid Error 0.18338491322356276
Meta Valid Accuracy 0.9362499974668026


Iteration 600
Meta Train Error 0.12845829434809275
Meta Train Accuracy 0.9624999985098839
Meta Valid Error 0.1628760583116673
Meta Valid Accuracy 0.947499992325902


Iteration 601
Meta Train Error 0.15932728216284886
Meta Train Accuracy 0.9537499938160181
Meta Valid Error 0.19129919007536955
Meta Valid Accuracy 0.9362499956041574


Iteration 602
Meta Train Error 0.1353548419137951
Meta Train Accuracy 0.9537499938160181
Meta Valid Error 0.22312230833631475
Meta Valid Accuracy 0.9237499963492155


Iteration 603
Meta Train Error 0.1709603594354121
Meta Train Accuracy 0.9462499972432852
Meta Valid Error 0.193603589956183
Meta Valid Accuracy 0.941249994561076



Iteration 648
Meta Train Error 0.11211002862546593
Meta Train Accuracy 0.9649999961256981
Meta Valid Error 0.1987654998083599
Meta Valid Accuracy 0.9237499982118607


Iteration 649
Meta Train Error 0.11101201323617715
Meta Train Accuracy 0.9649999961256981
Meta Valid Error 0.1512245779740624
Meta Valid Accuracy 0.952499995008111


Iteration 650
Meta Train Error 0.15388200742017943
Meta Train Accuracy 0.9474999941885471
Meta Valid Error 0.22975535329896957
Meta Valid Accuracy 0.9187499936670065


Iteration 651
Meta Train Error 0.11051281432446558
Meta Train Accuracy 0.9624999947845936
Meta Valid Error 0.20621612333343364
Meta Valid Accuracy 0.9287499971687794


Iteration 652
Meta Train Error 0.09840411102049984
Meta Train Accuracy 0.9674999956041574
Meta Valid Error 0.20724281016737223
Meta Valid Accuracy 0.9374999906867743


Iteration 653
Meta Train Error 0.12257143622264266
Meta Train Accuracy 0.9537499956786633
Meta Valid Error 0.1821189132751897
Meta Valid Accuracy 0.9349999930709

### validation accuracy increases from 29% to 69.4%

In [31]:
# torch.save(model.state_dict(), 'meta_omniglot2.pth')

In [5]:
ways=5
shots=1

model=l2l.vision.models.OmniglotFC(28 ** 2, ways)
model.load_state_dict(torch.load('meta_omniglot.pth'))
# model.eval()

OmniglotFC(
  (features): Sequential(
    (0): Flatten()
    (1): Sequential(
      (0): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(256, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=784, out_features=256, bias=True)
      )
      (1): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(128, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=256, out_features=128, bias=True)
      )
      (2): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=128, out_features=64, bias=True)
      )
      (3): LinearBlock(
        (relu): ReLU()
        (normalize): BatchNorm1d(64, eps=0.001, momentum=0.999, affine=True, track_running_stats=False)
        (linear): Linear(in_features=64, out_features=64, bias=True)
      )
   