## Gradient Checkpointing Model-Agnostic Meta-Learning (MAML)

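We demonstrate memory-efficient MAML built on the gradient checkpointing technique. Instead of keeping every intermediate state of the unrolled inner loop in memory, gradient checkpointing stores only some of them and recomputes the rest during the backward pass. This trades extra computation for lower memory usage.
This notebook runs one forward and one backward pass of MAML through a large number of inner-loop steps.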
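The following pure-Python sketch illustrates the idea on a toy problem. It is not the torch_maml implementation. It assumes a single scalar parameter, a hypothetical inner loss f(w) = w**4 / 4 (so one gradient-descent step is w - lr * w**3), and an outer loss w_N**2 / 2, and it computes the meta-gradient d(outer loss)/d(w_0) by hand. The naive version stores all N + 1 intermediate states. The checkpointed version keeps only every `checkpoint_steps`-th state and recomputes each segment during the backward pass.

```python
def step(w, lr):
    """One inner gradient-descent step on the toy loss f(w) = w**4 / 4."""
    return w - lr * w ** 3


def step_derivative(w, lr):
    """d step(w) / dw: the local Jacobian the backward pass needs at state w."""
    return 1.0 - 3.0 * lr * w ** 2


def meta_grad_naive(w0, lr, n_steps):
    """Store every intermediate state, then backpropagate through all of them."""
    states = [w0]
    for _ in range(n_steps):
        states.append(step(states[-1], lr))
    grad = states[-1]  # d(w_N**2 / 2) / d w_N
    for w in reversed(states[:-1]):
        grad *= step_derivative(w, lr)
    return grad, len(states)  # gradient, number of stored states


def meta_grad_checkpointed(w0, lr, n_steps, checkpoint_steps):
    """Keep only every `checkpoint_steps`-th state; recompute segments on the way back."""
    checkpoints = {0: w0}
    w = w0
    for t in range(1, n_steps + 1):
        w = step(w, lr)
        if t % checkpoint_steps == 0 and t < n_steps:
            checkpoints[t] = w
    grad = w  # d(w_N**2 / 2) / d w_N
    peak_states = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + checkpoint_steps, n_steps)
        # Recompute the states w_start, ..., w_{end-1} of this segment from its checkpoint.
        segment = [checkpoints[start]]
        for _ in range(start + 1, end):
            segment.append(step(segment[-1], lr))
        peak_states = max(peak_states, len(checkpoints) + len(segment))
        for s in reversed(segment):
            grad *= step_derivative(s, lr)
    return grad, peak_states


g_naive, mem_naive = meta_grad_naive(1.0, 0.01, 100)
g_ckpt, mem_ckpt = meta_grad_checkpointed(1.0, 0.01, 100, checkpoint_steps=10)
print(g_naive == g_ckpt, mem_naive, mem_ckpt)  # True 101 20
```

The two gradients are bit-for-bit identical because recomputation replays exactly the same arithmetic. The cost is one extra forward pass per segment. The `checkpoint_steps` argument of `GradientCheckpointMAML` used below presumably plays the same role as `checkpoint_steps` here.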

Data: random tensor of shape (batch_size, 3, 224, 224) with random labels for 10 classes

Model: ResNet18

Inner-loop optimizer: SGD (in-graph gradient descent) with learning rate 0.01

Batch size: 16

Max MAML steps: 100

GPU: GeForce RTX 2080Ti

In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=1
import os, sys, time
sys.path.insert(0, '..')  # make the local torch_maml package importable
import torch_maml

import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

device = 'cuda' if torch.cuda.is_available() else 'cpu'

env: CUDA_VISIBLE_DEVICES=1


### Reproducibility

In [2]:
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # disable autotuning, which may pick nondeterministic kernels

#### Generate batch for demonstration

In [3]:
x_batch, y_batch = torch.randn((16, 3, 224, 224)), torch.randint(0, 10, (16, ))

#### Define compute_loss function and create model

In [4]:
# Interface:
# def compute_loss(model, data, **kwargs):
#      <YOUR CODE HERE>
#      return loss

# Our example
def compute_loss(model, data, device='cuda'):
    inputs, targets = data
    preds = model(inputs.to(device=device))
    loss = F.cross_entropy(preds, targets.to(device=device))
    return loss
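The `**kwargs` in the interface presumably lets a MAML wrapper forward extra arguments, such as the `loss_kwargs={'device': device}` passed in the MAML calls later in this notebook. The pure-Python sketch below mirrors that plumbing under stated assumptions. The model is any callable, the loss returns a plain float instead of a torch scalar, and `toy_compute_loss` and `losses_per_step` are hypothetical helpers, not part of torch_maml. The names are chosen so they do not overwrite the notebook's `compute_loss` and `model`.

```python
def toy_compute_loss(model, data, scale=1.0):
    """Toy loss with the same signature: scaled mean squared error of a callable model."""
    inputs, targets = data
    preds = [model(x) for x in inputs]
    return scale * sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def losses_per_step(model, inputs, compute_loss, loss_kwargs=None):
    """Hypothetical helper: evaluate compute_loss on each (inputs, targets) batch,
    forwarding loss_kwargs as keyword arguments."""
    loss_kwargs = loss_kwargs or {}
    return [compute_loss(model, batch, **loss_kwargs) for batch in inputs]


def toy_model(x):
    return 2.0 * x


toy_batch = ([1.0, 2.0], [2.0, 5.0])  # predictions 2.0, 4.0 -> squared errors 0.0, 1.0
print(losses_per_step(toy_model, [toy_batch] * 3, toy_compute_loss, loss_kwargs={'scale': 2.0}))
# [1.0, 1.0, 1.0]
```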

In [5]:
# Model is a torch.nn.Module 
model = models.resnet18(num_classes=10).to(device)
# Optimizer is a custom in-graph MAML optimizer, e.g. gradient descent
optimizer = torch_maml.IngraphGradientDescent(learning_rate=0.01)
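In-graph gradient descent keeps each inner update inside the autograd graph, so the meta-gradient flows back through every step. The derivation uses notation introduced here rather than taken from torch_maml. $\theta_t$ denotes the parameters after $t$ inner steps and $\mathcal{L}_t$ the loss on the $t$-th batch. The learning rate is $\alpha = 0.01$. The backpropagated loss $\mathcal{L}$ is assumed to be evaluated at the final parameters $\theta_T$, and `max_grad_norm` clipping is ignored. Under these assumptions, the chain rule gives:

$$\theta_{t+1} = \theta_t - \alpha \nabla_\theta \mathcal{L}_t(\theta_t), \qquad t = 0, \dots, T-1$$

$$\nabla_{\theta_0} \mathcal{L}(\theta_T) = \left(I - \alpha H_0\right)\left(I - \alpha H_1\right)\cdots\left(I - \alpha H_{T-1}\right)\nabla_\theta \mathcal{L}(\theta_T), \qquad H_t = \nabla^2_\theta \mathcal{L}_t(\theta_t)$$

Every factor depends on an intermediate $\theta_t$ and, for a network, on that step's activations. This is why the memory of naive MAML grows linearly with the number of inner steps.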

#### Create NaiveMAML and GradientCheckpointMAML for comparison

In [6]:
gd_ckpt_maml = torch_maml.GradientCheckpointMAML(model, compute_loss,
                                                 optimizer=optimizer,
                                                 checkpoint_steps=5)
maml = torch_maml.NaiveMAML(model, compute_loss, optimizer=optimizer)

#### Check that NaiveMAML and GradientCheckpointMAML give the same results

In [7]:
# First, use few enough steps for naive MAML to fit in memory, so both implementations can be compared
max_steps = 10
max_grad_norm = 1e2
inputs = [(x_batch, y_batch)] * max_steps
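The notebook does not explain `max_grad_norm`. One common meaning is that each inner-step gradient is rescaled whenever its global L2 norm exceeds the threshold. This is an assumption, not confirmed from torch_maml's source. Here is a minimal pure-Python sketch of that rule, where `clip_by_global_norm` is a hypothetical name:

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient values so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        return [g * scale for g in grads], total_norm
    return list(grads), total_norm


clipped, norm = clip_by_global_norm([300.0, 400.0], max_norm=1e2)
print(norm, clipped)
```

Under this interpretation, with `max_grad_norm = 1e2` a gradient of norm 500 would be scaled down by a factor of 5. Gradients with a norm below the threshold would pass through unchanged.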

In [8]:
%%time
torch.cuda.empty_cache()
updated_model, loss, _ = maml(inputs, loss_kwargs={'device':device},
                              max_grad_norm=max_grad_norm)
loss.backward()
torch.cuda.synchronize()
print("Loss: %.4f" % loss.item())

Loss: 0.6287
CPU times: user 2.1 s, sys: 733 ms, total: 2.84 s
Wall time: 2.85 s


In [9]:
grads1 = [params.grad.clone() for params in model.parameters()]  # copy: the next backward() accumulates in place
model.zero_grad()

In [10]:
%%time
torch.cuda.empty_cache()
updated_model, loss, _ = gd_ckpt_maml(inputs, loss_kwargs={'device':device},
                                      max_grad_norm=max_grad_norm)
loss.backward()
torch.cuda.synchronize()
print("Loss: %.4f" % loss.item())

Loss: 0.6287
CPU times: user 2.43 s, sys: 969 ms, total: 3.4 s
Wall time: 3.4 s


In [11]:
grads2 = [params.grad for params in model.parameters()]

In [12]:
for grad1, grad2 in zip(grads1, grads2):
    assert (grad1 == grad2).all()

The loss values of the updated models and the gradients of all model parameters coincide.
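When you compare gradients across two runs, note that `backward()` accumulates into existing `.grad` tensors in place. A list that only holds references to `params.grad` therefore changes after the next `backward()`. Two such lists can end up pointing at the same tensors and compare equal trivially. Copying with `.clone()` and calling `model.zero_grad()` between runs avoids this. The sketch below shows the same aliasing effect with plain Python lists standing in for gradient tensors. `FakeParam` is a hypothetical stand-in, not a torch class.

```python
class FakeParam:
    """Hypothetical stand-in for a parameter whose .grad is accumulated in place."""

    def __init__(self):
        self.grad = None

    def accumulate(self, new_grad):
        # Mimics in-place accumulation: reuse the existing buffer if there is one.
        if self.grad is None:
            self.grad = list(new_grad)
        else:
            for i, g in enumerate(new_grad):
                self.grad[i] += g


param = FakeParam()
param.accumulate([1.0, 2.0])        # first "backward()"
grads_ref = [param.grad]            # reference only
grads_copy = [list(param.grad)]     # independent copy, like .clone()
param.accumulate([1.0, 2.0])        # second "backward()" without zero_grad()
print(grads_ref[0], grads_copy[0])  # [2.0, 4.0] [1.0, 2.0]
```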

Now increase the number of inner steps to a value that naive MAML cannot fit into GPU memory, e.g. 100.

In [7]:
max_steps = 100
inputs = [(x_batch, y_batch)] * max_steps
gd_ckpt_maml.checkpoint_steps = 10
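A back-of-the-envelope count helps explain the choice of `checkpoint_steps`. With N inner steps and a checkpoint every k steps, the backward pass holds about ceil(N / k) checkpointed states plus the k states of the segment being recomputed. This total is smallest near k ≈ sqrt(N), at the price of roughly one extra forward pass. The count below is an idealization that ignores activation sizes and allocator behaviour. It is not a measurement of torch_maml, and `stored_states` is a hypothetical helper.

```python
import math


def stored_states(n_steps, checkpoint_steps):
    """Idealized count of parameter states held at the peak of the backward pass:
    the checkpoints plus one fully recomputed segment."""
    n_checkpoints = math.ceil(n_steps / checkpoint_steps)
    return n_checkpoints + min(checkpoint_steps, n_steps)


n_steps = 100
for k in (1, 5, 10, 20, 50, 100):
    print(k, stored_states(n_steps, k))
best_k = min(range(1, n_steps + 1), key=lambda k: stored_states(n_steps, k))
print("best checkpoint_steps:", best_k)  # best checkpoint_steps: 10
```

For 100 steps, k = 1 or k = 100 keeps as many states (101) as storing everything, while k = 10 needs only 20.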

In [14]:
%%time
torch.cuda.empty_cache()
updated_model, loss, _ = maml(inputs, loss_kwargs={'device':device},
                              max_grad_norm=max_grad_norm)
loss.backward()
torch.cuda.synchronize()
print("Loss: %.4f" % loss.item())

RuntimeError: CUDA out of memory. Tried to allocate 26.00 MiB (GPU 0; 10.76 GiB total capacity; 9.58 GiB already allocated; 9.06 MiB free; 324.61 MiB cached)

In [8]:
%%time
torch.cuda.empty_cache()
updated_model, loss, _ = gd_ckpt_maml(inputs, loss_kwargs={'device':device},
                                      max_grad_norm=max_grad_norm)
loss.backward()
torch.cuda.synchronize()
print("Loss: %.4f" % loss.item())

Loss: 0.0419
CPU times: user 24.9 s, sys: 8.82 s, total: 33.7 s
Wall time: 33.7 s
