In [19]:
import sys

sys.path.append('/home/debo/learn2learn/')
sys.path.append('../../learn2learn/')

In [20]:
!git branch

  master[m
  omniglot[m
* [32mprototypical_network[m


In [32]:
from torch.autograd import grad
from utils import clone_module
from torch import nn


class BaseLearner(nn.Module):

    def __init__(self, module=None):
        super(BaseLearner, self).__init__()
        self.module = module

    def __getattr__(self, attr):
        try:
            return super(BaseLearner, self).__getattr__(attr)
        except AttributeError:
            return getattr(self.__dict__['_modules']['module'], attr)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

In [33]:
def maml_update(model, lr, grads):
    """
    Performs a MAML update on model using grads and lr.
    The function re-routes the Python object, thus avoiding in-place
    operations. However, it seems like PyTorch handles in-place operations
    fairly well.
    
    NOTE: The model itself is updated in-place (no deepcopy), but the
          parameters' tensors are not.

    NOTE: grads is None -> Don't set the gradients.
    """
    if grads is not None:
        params = list(model.parameters())
        if not len(grads) == len(list(params)):
            msg = 'WARNING:maml_update(): Parameters and gradients have different length. ('
            msg += str(len(params)) + ' vs ' + str(len(grads)) + ')'
            print(msg)
        for p, g in zip(params, grads):
            p.grad = g

    # Update the params
    for param_key in model._parameters:
        p = model._parameters[param_key]
        if p is not None and p.grad is not None:
            model._parameters[param_key] = p - lr * p.grad

    # Second, handle the buffers if necessary
    for buffer_key in model._buffers:
        buff = model._buffers[buffer_key]
        if buff is not None and buff.grad is not None:
            model._buffers[buffer_key] = buff - lr * buff.grad

    # Then, recurse for each submodule
    for module_key in model._modules:
        model._modules[module_key] = maml_update(model._modules[module_key],
                                                 lr=lr,
                                                 grads=None)
    return model


class MAML(BaseLearner):

    def __init__(self, model, lr, first_order=False):
        super(MAML, self).__init__()
        self.module = model
        self.lr = lr
        self.first_order = first_order

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def adapt(self, loss, first_order=None):
        if first_order is None:
            first_order = self.first_order
        second_order = not first_order
        gradients = grad(loss,
                         self.module.parameters(),
                         retain_graph=second_order,
                         create_graph=second_order)
        self.module = maml_update(self.module, self.lr, gradients)

    def clone(self, first_order=None):
        if first_order is None:
            first_order = self.first_order
        return MAML(clone_module(self.module),
                    lr=self.lr,
                    first_order=first_order)


In [34]:
#!/usr/bin/env python3

import random
import numpy as np

import torch as th
from torch import nn
from torch import optim

import learn2learn as l2l

from scipy.stats import truncnorm


def truncated_normal_(tensor, mean=0.0, std=1.0):
    # PT doesn't have truncated normal.
    # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/18
    values = truncnorm.rvs(-2, 2, size=tensor.shape)
    values = mean + std * values
    tensor.copy_(th.from_numpy(values))
    return tensor


def maml_fc_init_(module):
    if hasattr(module, 'weight') and module.weight is not None:
        truncated_normal_(module.weight.data, mean=0.0, std=0.01)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias.data, 0.0)
    return module


class MAMLLinearBlock(nn.Module):

    def __init__(self, input_size, output_size):
        super(MAMLLinearBlock, self).__init__()
        self.relu = nn.ReLU()
        self.normalize = nn.BatchNorm1d(output_size,
                                        affine=True,
                                        momentum=0.999,
                                        eps=1e-3,
                                        track_running_stats=False,
                                        )
        # TODO: Remove affine and use AddBias
        # self.bias = AddBias(output_size)
        self.linear = nn.Linear(input_size, output_size)
        maml_fc_init_(self.linear)

    def forward(self, x):
        x = self.linear(x)
        # x = self.bias(x)
        x = self.normalize(x)
        x = self.relu(x)
        return x

class MAMLFC(nn.Sequential):

    def __init__(self, input_size, output_size, sizes=None):
        if sizes is None:
            sizes = [256, 128, 64, 64]
        layers = [MAMLLinearBlock(input_size, sizes[0]), ]
        for s_i, s_o in zip(sizes[:-1], sizes[1:]):
            layers.append(MAMLLinearBlock(s_i, s_o))
        layers.append(maml_fc_init_(nn.Linear(sizes[-1], output_size)))
        super(MAMLFC, self).__init__(*layers)
#        super(MAMLFC, self).__init__(
#            MAMLLinearBlock(input_size, 256),
#            MAMLLinearBlock(256, 128),
#            MAMLLinearBlock(128, 64),
#            MAMLLinearBlock(64, 64),
#            maml_fc_init_(nn.Linear(64, output_size)),
#        )
        self.input_size = input_size

    def forward(self, x):
        return super(MAMLFC, self).forward(x.view(-1, self.input_size))


def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(adaptation_data, evaluation_data, learner, loss, adaptation_steps, device):
    for step in range(adaptation_steps):
        data = [d for d in adaptation_data]
        X = th.cat([d[0] for d in data], dim=0).to(device)
        y = th.cat([th.tensor(d[1]).view(-1) for d in data], dim=0).to(device)
        train_error = loss(learner(X), y)
        train_error /= len(adaptation_data)
        learner.adapt(train_error)
    data = [d for d in evaluation_data]
    X = th.cat([d[0] for d in data], dim=0).to(device)
    y = th.cat([th.tensor(d[1]).view(-1) for d in data], dim=0).to(device)
    predictions = learner(X)
    valid_error = loss(predictions, y)
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, y)
    return valid_error, valid_accuracy



def main(
        ways=5,
        shots=1,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
    ):
    from PIL.Image import LANCZOS
    from torchvision.datasets import Omniglot
    from torchvision import transforms
    from torch.utils.data import ConcatDataset

    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    device = th.device('cpu')
    if cuda:
        th.cuda.manual_seed(seed)
        device = th.device('cuda')

    # Create Dataset
    # TODO: Create l2l.data.vision.FullOmniglot, which merges background and evaluation sets.
    omni_background = Omniglot(root='./data',
                               background=True,
                               transform=transforms.Compose([
                                   transforms.Resize(28, interpolation=LANCZOS),
                                   transforms.ToTensor(),
                                   # TODO: Add DiscreteRotations([0, 90, 180, 270])
                                   lambda x: 1.0 - x,
                               ]),
                               download=True)
#    max_y = 1 + max([y for X, y in omni_background])
    max_y = 964
    omni_evaluation = Omniglot(root='./data',
                               background=False,
                               transform=transforms.Compose([
                                   transforms.Resize(28, interpolation=LANCZOS),
                                   transforms.ToTensor(),
                                   # TODO: Add DiscreteRotations([0, 90, 180, 270])
                                   lambda x: 1.0 - x,
                               ]),
                               target_transform=transforms.Compose([
                                   lambda x: max_y + x,
                               ]),
                               download=True)
    omniglot = ConcatDataset((omni_background, omni_evaluation))
    train_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways)
    valid_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways)
    test_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways)
    # TODO: Implement an easy way to split one dataset into splits, based on classes.

    # Create model
    model = MAMLFC(28**2, ways)
    model.to(device)
    maml = MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(size_average=True, reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            adaptation_data = train_generator.sample(shots=shots)
            evaluation_data = train_generator.sample(shots=shots,
                                                     classes_to_sample=adaptation_data.sampled_classes)
            evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
                                                               evaluation_data,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            adaptation_data = valid_generator.sample(shots=shots)
            evaluation_data = valid_generator.sample(shots=shots,
                                                     classes_to_sample=adaptation_data.sampled_classes)
            evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
                                                               evaluation_data,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = maml.clone()
            adaptation_data = test_generator.sample(shots=shots)
            evaluation_data = test_generator.sample(shots=shots,
                                                    classes_to_sample=adaptation_data.sampled_classes)
            evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
                                                               evaluation_data,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Meta Test Error', meta_test_error / meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()



In [35]:
main()

Files already downloaded and verified
Files already downloaded and verified






Iteration 0
Meta Train Error 0.31229730788618326
Meta Train Accuracy 0.44375001126900315
Meta Valid Error 0.31193888559937477
Meta Valid Accuracy 0.41875001043081284
Meta Test Error 0.3159230723977089
Meta Test Accuracy 0.33750000689178705


Iteration 1
Meta Train Error 0.3043672563508153
Meta Train Accuracy 0.48125001322478056
Meta Valid Error 0.30012034717947245
Meta Valid Accuracy 0.5937500135041773
Meta Test Error 0.2989203268662095
Meta Test Accuracy 0.5562500152736902


Iteration 2
Meta Train Error 0.2911027930676937
Meta Train Accuracy 0.5625000116415322
Meta Valid Error 0.2865180401131511
Meta Valid Accuracy 0.568750015925616
Meta Test Error 0.2849751953035593
Meta Test Accuracy 0.5875000138767064


Iteration 3
Meta Train Error 0.27604627050459385
Meta Train Accuracy 0.637500012293458
Meta Valid Error 0.28138114139437675
Meta Valid Accuracy 0.6187500120140612
Meta Test Error 0.28452222887426615
Meta Test Accuracy 0.5562500110827386


Iteration 4
Meta Train Error 0.27506551891



Iteration 35
Meta Train Error 0.16446618421468884
Meta Train Accuracy 0.7312500108964741
Meta Valid Error 0.1774138929322362
Meta Valid Accuracy 0.6937500140629709
Meta Test Error 0.2030633909162134
Meta Test Accuracy 0.5812500128522515


Iteration 36
Meta Train Error 0.16824059654027224
Meta Train Accuracy 0.7125000115483999
Meta Valid Error 0.1805078205652535
Meta Valid Accuracy 0.7187500121071935
Meta Test Error 0.1669217178132385
Meta Test Accuracy 0.731250009033829


Iteration 37
Meta Train Error 0.16904267820063978
Meta Train Accuracy 0.7125000110827386
Meta Valid Error 0.19665839697699994
Meta Valid Accuracy 0.6812500110827386
Meta Test Error 0.16468996845651418
Meta Test Accuracy 0.7187500149011612


Iteration 38
Meta Train Error 0.17568306589964777
Meta Train Accuracy 0.7000000118277967
Meta Valid Error 0.18122114962898195
Meta Valid Accuracy 0.6875000107102096
Meta Test Error 0.19467550492845476
Meta Test Accuracy 0.6625000098720193


Iteration 39
Meta Train Error 0.1688691

KeyboardInterrupt: 