In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=1
import os, sys, time
sys.path.insert(0, '..')
import lib

import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

import random

# One seed shared by the Python, NumPy and PyTorch global RNGs (seeded in that order).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_fn(SEED)

import time  # NOTE(review): `time` is already imported at the top of this cell
from resnet import ResNet18

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

env: CUDA_VISIBLE_DEVICES=1
Files already downloaded and verified
Files already downloaded and verified


In [None]:
from torchvision import transforms, datasets

# NOTE(review): no visible cell below reads trainloader / testloader / X_test, and the model built
# later only accepts 40x40 inputs (CIFAR images are 32x32) — confirm this cell is still needed.
# NOTE(review): (0.2023, 0.1994, 0.2010) are widely copied CIFAR-10 "std" constants; the per-channel
# std over the training set is usually quoted as roughly (0.247, 0.243, 0.261) — confirm which is intended.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop with 4px padding + horizontal flip, then tensor + normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Materialize the whole (unshuffled) test split as one image tensor and one label tensor.
test_image_batches, test_label_batches = zip(*testloader)
X_test, y_test = torch.cat(test_image_batches), torch.cat(test_label_batches)
del test_image_batches, test_label_batches

In [2]:
from torch.utils.checkpoint import checkpoint

class Print(nn.Module):
    """Identity layer that prints its name and a running call count on every forward pass.

    Used here to observe how many times the wrapped model is executed (e.g. recomputation
    performed by gradient checkpointing during backward).
    """

    def __init__(self, name):
        super().__init__()
        # label printed on each call, and the number of forward calls seen so far
        self.name, self.times = name, 0

    def forward(self, x):
        self.times = self.times + 1
        print(self.name, "#", self.times)
        return x  # input passes through unchanged
        
class CheckpointedModule(nn.Sequential):
    """nn.Sequential whose forward pass is wrapped in torch.utils.checkpoint.checkpoint.

    Activations inside the sequence are not stored; they are recomputed during backward.
    """

    def forward(self, x):
        if not x.requires_grad:
            # With the reentrant checkpoint implementation, the output only requires grad if some
            # input does, so parameter gradients would otherwise be lost. Use a detached alias
            # rather than flipping requires_grad in place on the caller's tensor (which would make
            # the caller's batch permanently require grad and accumulate .grad).
            x = x.detach().requires_grad_(True)
        return checkpoint(super().forward, x)
        

In [None]:
def _conv_bn_relu(in_channels, out_channels):
    """Return [unpadded 3x3 Conv2d without bias, non-affine BatchNorm2d, ReLU] as a list of layers."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), bias=False),
        nn.BatchNorm2d(out_channels, affine=False),
        nn.ReLU(),
    ]


# Channel widths of the nine unpadded 3x3 convolutions that follow the max-pool.
# Each such conv shrinks the feature map by 2 pixels, so a 40x40 input goes
# 40 -> 38 -> (pool) 19 -> 17 -> ... -> 1, leaving exactly 32 features for nn.Linear(32, 10).
# NOTE: a 32x32 (CIFAR-sized) input would shrink below 1x1 inside this stack and fail.
_post_pool_widths = [32, 128, 128, 32, 128, 128, 32, 128, 128, 32]

# Layers are created in the same order as a hand-written list would, so parameter init is unchanged.
_layers = [Print('bar'), *_conv_bn_relu(3, 32), nn.MaxPool2d((2, 2))]
for _c_in, _c_out in zip(_post_pool_widths[:-1], _post_pool_widths[1:]):
    _layers += _conv_bn_relu(_c_in, _c_out)
_layers += [nn.Flatten(), nn.Linear(32, 10)]

model = nn.Sequential(*_layers).to(device)

model = lib.MAML(model).to(device)


In [None]:
# Synthetic batch of random 3-channel 40x40 images with random labels in [0, 10).
# 40x40 is the input size the unpadded conv stack reduces to 1x1 before nn.Linear(32, 10).
_batch_size, _num_classes = 128, 10
x_batch = torch.randn(_batch_size, 3, 40, 40, device=device)
y_batch = torch.randint(0, _num_classes, size=[_batch_size], device=device)

In [7]:
from itertools import chain

def maml(model: nn.Module, x_batch, y_batch, total_steps, checkpoint_every_steps=None,
         get_parameters=nn.Module.parameters, learning_rate=0.01):
    """
    Run `total_steps` differentiable SGD steps on `model` under gradient checkpointing.

    The steps are split into segments of `checkpoint_every_steps` steps, each run through
    torch.utils.checkpoint.checkpoint. Only segment-boundary parameters are kept for backward;
    the steps inside a segment are recomputed when gradients are requested.

    :param model: module whose parameters are the starting point of the inner loop
    :param x_batch: inputs fed to the model at every step
    :param y_batch: class targets for F.cross_entropy
    :param total_steps: number of update steps; must be >= 1
    :param checkpoint_every_steps: steps per checkpoint segment (must be >= 1);
        defaults to ceil(sqrt(total_steps))
    :param get_parameters: callable(module) -> iterable of the parameters to update; every other
        parameter and buffer of `model` is passed to copy_and_replace / lib.do_not_copy as not-to-copy
    :param learning_rate: step size forwarded to get_updated_model
    :returns: (edited_model, loss), where loss was computed right before the final update
    :raises ValueError: if total_steps < 1 or checkpoint_every_steps < 1
    """
    # NOTE(review): copy_and_replace and get_updated_model are not defined in any visible cell —
    # presumably they come from a deleted cell or from `lib`; confirm before a fresh Run All.
    # NOTE(review): the recompute during backward runs the model again, so batchnorm running
    # statistics / random layers presumably see every step twice (see the disabled assert below).
    if total_steps < 1:
        raise ValueError(f"total_steps must be a positive integer, got {total_steps}")
    if checkpoint_every_steps is None:
        checkpoint_every_steps = max(1, int(np.ceil(np.sqrt(total_steps))))
    if checkpoint_every_steps < 1:
        raise ValueError(f"checkpoint_every_steps must be a positive integer, got {checkpoint_every_steps}")

    print("MODEL UNIQUE ID", id(model))
    #     assert not model.training, "randomness and batchnorm-like layers not yet supported"

    # TODO make sure get_parameters returns parameters in the same order each time it is called
    parameters_to_copy = list(get_parameters(model))
    parameters_to_copy_set = set(parameters_to_copy)
    parameters_not_to_copy = [
        param for param in chain(model.parameters(), model.buffers()) if param not in parameters_to_copy_set]

    # WARNING: _maml_internal reads model, parameters_to_copy and parameters_not_to_copy from this
    # enclosing scope. Do NOT rebind or mutate them anywhere in this function.

    def _maml_internal(i, steps, *trainable_parameters):
        """Run `steps` updates starting from `trainable_parameters`; return (i, loss, *new_parameters)."""
        # version of model with the specified initial parameters
        editable = copy_and_replace(
            model, dict(zip(parameters_to_copy, trainable_parameters)), parameters_not_to_copy)

        # The saved outputs show grad mode disabled during checkpoint's first forward and enabled
        # during the backward recompute; this flag relies on that (reentrant) behavior.
        i_am_inside_checkpoint_forward = not torch.is_grad_enabled()

        for _ in range(steps.item()):
            i = i + 1
            print("MAML INTERNAL STEP:", int(i), "i_am_inside_checkpoint_forward:", i_am_inside_checkpoint_forward)

            with torch.enable_grad():
                preds = editable(x_batch)
                loss = F.cross_entropy(preds, y_batch)
                # TODO use ingraph_update optimizers
                with lib.do_not_copy(*parameters_not_to_copy):
                    editable = get_updated_model(
                        editable, loss=loss, learning_rate=learning_rate, detach=i_am_inside_checkpoint_forward,
                        parameters=list(get_parameters(editable)))
        return (i, loss, *get_parameters(editable))

    # `i` counts performed steps and is threaded through every segment.
    i = torch.zeros(1, requires_grad=True)
    trainable_parameters = parameters_to_copy  # rebound (not mutated) by the unpacking below
    num_full_segments, last_segment_steps = divmod(total_steps, checkpoint_every_steps)
    for _ in range(num_full_segments):
        i, loss, *trainable_parameters = checkpoint(
            _maml_internal, i, torch.as_tensor(checkpoint_every_steps), *trainable_parameters)

    if last_segment_steps != 0:
        i, loss, *trainable_parameters = checkpoint(
            _maml_internal, i, torch.as_tensor(last_segment_steps), *trainable_parameters)
    assert i == total_steps
    # Reuse the same parameter list as _maml_internal so the parameter -> value mapping is consistent.
    edited_model = copy_and_replace(
        model, dict(zip(parameters_to_copy, trainable_parameters)), parameters_not_to_copy)
    return edited_model, loss

In [8]:
def naive_maml(model, x_batch, y_batch, total_steps, get_parameters=nn.Module.parameters, learning_rate=0.01):
    """
    Non-checkpointed reference for `maml`: run `total_steps` differentiable SGD steps on `model`.

    Every step calls get_updated_model with detach=False and no checkpointing, so the graph of all
    steps is presumably retained until backward (this is the memory-hungry baseline).

    :param model: module whose parameters are the starting point of the inner loop
    :param x_batch: inputs fed to the model at every step
    :param y_batch: class targets for F.cross_entropy
    :param total_steps: number of update steps; must be >= 1
    :param get_parameters: callable(module) -> iterable of the parameters to update; every other
        parameter and buffer of `model` is wrapped in lib.do_not_copy
    :param learning_rate: step size forwarded to get_updated_model
    :returns: (edited_model, loss), where loss was computed right before the final update
    :raises ValueError: if total_steps < 1 (otherwise `loss` would be unbound)
    """
    # NOTE(review): get_updated_model is not defined in any visible cell — confirm where it comes from.
    if total_steps < 1:
        raise ValueError(f"total_steps must be a positive integer, got {total_steps}")
    print("MODEL UNIQUE ID", id(model))

    parameters_to_copy_set = set(get_parameters(model))
    parameters_not_to_copy = [
        param for param in chain(model.parameters(), model.buffers()) if param not in parameters_to_copy_set]

    editable = model
    for _ in range(total_steps):
        preds = editable(x_batch)
        loss = F.cross_entropy(preds, y_batch)
        # TODO use ingraph_update optimizers
        with lib.do_not_copy(*parameters_not_to_copy):
            editable = get_updated_model(
                editable, loss=loss, learning_rate=learning_rate, detach=False,
                parameters=list(get_parameters(editable)))

    return editable, loss

In [9]:
# Baseline: 25 differentiable inner steps without checkpointing.
res_model, loss = naive_maml(model, x_batch, y_batch, total_steps=25)
last_param = list(res_model.parameters())[-1]
print(loss, last_param)

MODEL UNIQUE ID 140671300441680
bar # 1
bar # 2
bar # 3
bar # 4
bar # 5
bar # 6
bar # 7
bar # 8
bar # 9
bar # 10
bar # 11
bar # 12
bar # 13
bar # 14
bar # 15
bar # 16
bar # 17
bar # 18
bar # 19
bar # 20
bar # 21
bar # 22
bar # 23
bar # 24
bar # 25
tensor(1.0415, device='cuda:0', grad_fn=<NllLossBackward>) tensor([-0.1516, -0.1224, -0.1589, -0.1736,  0.0486,  0.0528, -0.0095,  0.1595,
        -0.1024, -0.0525], device='cuda:0', grad_fn=<SubBackward0>)


In [9]:
# Checkpointed version: 100 inner steps, one checkpoint segment every 5 steps.
res_model, loss = maml(model, x_batch, y_batch, total_steps=100, checkpoint_every_steps=5)
print(loss, [*res_model.parameters()][-1])

MODEL UNIQUE ID 140496043821264
MAML INTERNAL STEP: 1 i_am_inside_checkpoint_forward: True
bar # 1
MAML INTERNAL STEP: 2 i_am_inside_checkpoint_forward: True
bar # 2
MAML INTERNAL STEP: 3 i_am_inside_checkpoint_forward: True
bar # 3
MAML INTERNAL STEP: 4 i_am_inside_checkpoint_forward: True
bar # 4
MAML INTERNAL STEP: 5 i_am_inside_checkpoint_forward: True
bar # 5
MAML INTERNAL STEP: 6 i_am_inside_checkpoint_forward: True
bar # 1
MAML INTERNAL STEP: 7 i_am_inside_checkpoint_forward: True
bar # 2
MAML INTERNAL STEP: 8 i_am_inside_checkpoint_forward: True
bar # 3
MAML INTERNAL STEP: 9 i_am_inside_checkpoint_forward: True
bar # 4
MAML INTERNAL STEP: 10 i_am_inside_checkpoint_forward: True
bar # 5
MAML INTERNAL STEP: 11 i_am_inside_checkpoint_forward: True
bar # 1
MAML INTERNAL STEP: 12 i_am_inside_checkpoint_forward: True
bar # 2
MAML INTERNAL STEP: 13 i_am_inside_checkpoint_forward: True
bar # 3
MAML INTERNAL STEP: 14 i_am_inside_checkpoint_forward: True
bar # 4
MAML INTERNAL STEP: 15 i_

In [11]:
# Backward re-runs each checkpoint segment (the log then prints i_am_inside_checkpoint_forward: False).
# NOTE(review): the saved output below shows steps up to 1000 and a layer named 'foo', which does not
# match the visible 100-step call / Print('bar') — it appears to be stale; re-run to refresh.
torch.autograd.backward(loss)

MAML INTERNAL STEP: 996 i_am_inside_checkpoint_forward: False
foo # 1
MAML INTERNAL STEP: 997 i_am_inside_checkpoint_forward: False
foo # 2
MAML INTERNAL STEP: 998 i_am_inside_checkpoint_forward: False
foo # 3
MAML INTERNAL STEP: 999 i_am_inside_checkpoint_forward: False
foo # 4
MAML INTERNAL STEP: 1000 i_am_inside_checkpoint_forward: False
foo # 5
MAML INTERNAL STEP: 991 i_am_inside_checkpoint_forward: False
foo # 1
MAML INTERNAL STEP: 992 i_am_inside_checkpoint_forward: False
foo # 2
MAML INTERNAL STEP: 993 i_am_inside_checkpoint_forward: False
foo # 3
MAML INTERNAL STEP: 994 i_am_inside_checkpoint_forward: False
foo # 4
MAML INTERNAL STEP: 995 i_am_inside_checkpoint_forward: False
foo # 5
MAML INTERNAL STEP: 986 i_am_inside_checkpoint_forward: False
foo # 1
MAML INTERNAL STEP: 987 i_am_inside_checkpoint_forward: False
foo # 2
MAML INTERNAL STEP: 988 i_am_inside_checkpoint_forward: False
foo # 3
MAML INTERNAL STEP: 989 i_am_inside_checkpoint_forward: False
foo # 4
MAML INTERNAL STEP: