In [8]:
import sys
sys.path.append('/home/debo/learn2learn/')
sys.path.append('../../learn2learn/')
from torch.autograd import grad
from utils import clone_module
from torch import nn
import random
import numpy as np

import torch
from torch import nn
from torch import optim

import learn2learn as l2l

from scipy.stats import truncnorm

In [21]:
ways=5
shots=1
meta_lr=0.003
fast_lr=0.5
meta_batch_size=32
adaptation_steps=1
num_iterations=60000
cuda=True
seed=42

In [14]:

from PIL.Image import LANCZOS
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import ConcatDataset

random.seed(seed)
np.random.seed(seed)
th.manual_seed(seed)
device = th.device('cpu')
if cuda:
    th.cuda.manual_seed(seed)
    device = th.device('cuda')

# Create Dataset
# TODO: Create l2l.data.vision.FullOmniglot, which merges background and evaluation sets.
omni_background = Omniglot(root='./data',
                           background=True,
                           transform=transforms.Compose([
                               transforms.Resize(28, interpolation=LANCZOS),
                               transforms.ToTensor(),
                               # TODO: Add DiscreteRotations([0, 90, 180, 270])
                               lambda x: 1.0 - x,
                           ]),
                           download=True)
#    max_y = 1 + max([y for X, y in omni_background])
max_y = 964
omni_evaluation = Omniglot(root='./data',
                           background=False,
                           transform=transforms.Compose([
                               transforms.Resize(28, interpolation=LANCZOS),
                               transforms.ToTensor(),
                               # TODO: Add DiscreteRotations([0, 90, 180, 270])
                               lambda x: 1.0 - x,
                           ]),
                           target_transform=transforms.Compose([
                               lambda x: max_y + x,
                           ]),
                           download=True)
omniglot = ConcatDataset((omni_background, omni_evaluation))
train_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways)
valid_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways)
test_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways)
# TODO: Implement an easy way to split one dataset into splits, based on classes.


Files already downloaded and verified
Files already downloaded and verified


In [23]:
num_input_channels = 1
def conv_block(in_channels: int, out_channels: int) -> nn.Module:
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
    
model = nn.Sequential(
        conv_block(num_input_channels, 64),
        conv_block(64, 64),
        conv_block(64, 64),
        conv_block(64, 64),
        Flatten(),
    )
# model.to(device, dtype=torch.double)

In [26]:
from torch.autograd import grad
from utils import clone_module
from torch import nn


class BaseLearner(nn.Module):

    def __init__(self, module=None):
        super(BaseLearner, self).__init__()
        self.module = module

    def __getattr__(self, attr):
        try:
            return super(BaseLearner, self).__getattr__(attr)
        except AttributeError:
            return getattr(self.__dict__['_modules']['module'], attr)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
    
def proto_update(model, lr, grads):
    """
    
    The function re-routes the Python object, thus avoiding in-place
    operations. However, it seems like PyTorch handles in-place operations
    fairly well.
    
    NOTE: The model itself is updated in-place (no deepcopy), but the
          parameters' tensors are not.

    NOTE: grads is None -> Don't set the gradients.
    """
    if grads is not None:
        params = list(model.parameters())
        if not len(grads) == len(list(params)):
            msg = 'WARNING:maml_update(): Parameters and gradients have different length. ('
            msg += str(len(params)) + ' vs ' + str(len(grads)) + ')'
            print(msg)
        for p, g in zip(params, grads):
            p.grad = g

    # Update the params
    for param_key in model._parameters:
        p = model._parameters[param_key]
        if p is not None and p.grad is not None:
            model._parameters[param_key] = p - lr * p.grad

    # Second, handle the buffers if necessary
    for buffer_key in model._buffers:
        buff = model._buffers[buffer_key]
        if buff is not None and buff.grad is not None:
            model._buffers[buffer_key] = buff - lr * buff.grad

    # Then, recurse for each submodule
    for module_key in model._modules:
        model._modules[module_key] = proto_update(model._modules[module_key],
                                                 lr=lr,
                                                 grads=None)
    return model


class ProtoNet(BaseLearner):
    def __init__(self, model, lr, first_order=False):
        super(ProtoNet, self).__init__()
        self.module = model
        self.lr = lr
        self.first_order = first_order

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def adapt(self, loss, first_order=None):
        if first_order is None:
            first_order = self.first_order
        second_order = not first_order
        gradients = grad(loss,
                         self.module.parameters(),
                         retain_graph=second_order,
                         create_graph=second_order)
        self.module = proto_update(self.module, self.lr, gradients)

    def clone(self, first_order=None):
        if first_order is None:
            first_order = self.first_order
        return ProtoNet(clone_module(self.module),
                    lr=self.lr,
                    first_order=first_order)

In [27]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)

In [28]:
def fast_adapt(adaptation_data, evaluation_data, learner, loss, adaptation_steps, device):
    for step in range(adaptation_steps):
        data = [d for d in adaptation_data]
        X = th.cat([d[0] for d in data], dim=0).to(device)
        y = th.cat([th.tensor(d[1]).view(-1) for d in data], dim=0).to(device)
        train_error = loss(learner(X), y)
        train_error /= len(adaptation_data)
        learner.adapt(train_error)
    data = [d for d in evaluation_data]
    X = th.cat([d[0] for d in data], dim=0).to(device)
    y = th.cat([th.tensor(d[1]).view(-1) for d in data], dim=0).to(device)
    predictions = learner(X)
    valid_error = loss(predictions, y)
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, y)
    return valid_error, valid_accuracy

In [30]:
model.to(device, dtype=torch.double)
ptnet = ProtoNet(model, lr=fast_lr, first_order=False)
opt = optim.Adam(model.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(size_average=True, reduction='mean')

for iteration in range(num_iterations):
    opt.zero_grad()
    meta_train_error = 0.0
    meta_train_accuracy = 0.0
    meta_valid_error = 0.0
    meta_valid_accuracy = 0.0
    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-training loss
        learner = ptnet.clone()
        adaptation_data = train_generator.sample(shots=shots)
        evaluation_data = train_generator.sample(shots=shots,
                                                 classes_to_sample=adaptation_data.sampled_classes)
        evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
                                                           evaluation_data,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           device)
        evaluation_error.backward()
        meta_train_error += evaluation_error.item()
        meta_train_accuracy += evaluation_accuracy.item()

        # Compute meta-validation loss
        learner = ptnet.clone()
        adaptation_data = valid_generator.sample(shots=shots)
        evaluation_data = valid_generator.sample(shots=shots,
                                                 classes_to_sample=adaptation_data.sampled_classes)
        evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
                                                           evaluation_data,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           device)
        meta_valid_error += evaluation_error.item()
        meta_valid_accuracy += evaluation_accuracy.item()

        # Compute meta-testing loss
        learner = ptnet.clone()
        adaptation_data = test_generator.sample(shots=shots)
        evaluation_data = test_generator.sample(shots=shots,
                                                classes_to_sample=adaptation_data.sampled_classes)
        evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
                                                           evaluation_data,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           device)
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()

    # Print some metrics
    print('\n')
    print('Iteration', iteration)
    print('Meta Train Error', meta_train_error / meta_batch_size)
    print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
    print('Meta Valid Error', meta_valid_error / meta_batch_size)
    print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

    # Average the accumulated gradients and optimize
    for p in ptnet.parameters():
        p.grad.data.mul_(1.0 / meta_batch_size)
    opt.step()





RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 1 3, but got 3-dimensional input of size [5, 28, 28] instead