In [1]:
#!/usr/bin/env python3
"""
Demonstrates how to:
    * use the MAML wrapper for fast-adaptation,
    * use the benchmark interface to load mini-ImageNet, and
    * sample tasks and split them in adaptation and evaluation sets.

To contrast the use of the benchmark interface with directly instantiating mini-ImageNet datasets and tasks, compare with `protonet_miniimagenet.py`.
"""

import random
import numpy as np
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import learn2learn as l2l
from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)
import os
# #Reproducibility
seed = 101

In [2]:
# Meta-Dataset
ways = 5
shots = 1
dataset = "omniglot"
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device('cpu')
if torch.cuda.is_available() and torch.cuda.device_count():
    torch.cuda.manual_seed(seed)
    device = torch.device('cuda')

# Create Tasksets using the benchmark interface
tasksets = l2l.vision.benchmarks.get_tasksets(dataset,
                                              train_samples=shots,#2*shots,
                                              train_ways=ways,
                                              test_samples=shots,#2*shots,
                                              test_ways=ways,
                                              root='~/data',
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
batch = tasksets.train.sample()
x, y = batch
print("X:", x.shape, " Y:", y)
# plt.show(x[0])

X: torch.Size([5, 1, 28, 28])  Y: tensor([1, 3, 2, 4, 0])


In [4]:
from torch.nn import Conv2d, Linear
import torch.nn.functional as F
import torch.nn as nn

# Backbone Network
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv3 = nn.Conv2d(32,64, kernel_size=3)
        self.fc1 = nn.Linear(3*3*64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        #x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x),2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = x.view(-1, 3*3*64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
    
    
class MAML(nn.Module):
    
    def __init__(self, model):
        super(MAML, self).__init__()
        self.model = model
        
    def forward(self, x):
        return self.model(x)
    
    def clone(self):
        
        return None
    
    def fast_adapt(self):
        return None
        

In [5]:
net1 = NN()
net2 = NN()

In [24]:
x = torch.randn(2, 2, requires_grad=True)
y = x.clone()
y.retain_grad()

z = y**2
z.mean().backward()

print(y.grad)
print(x.grad)

tensor([[-0.4503, -0.3298],
        [-0.1728,  0.5032]])
tensor([[-0.4503, -0.3298],
        [-0.1728,  0.5032]])


In [23]:
v = torch.autograd.Variable(torch.randn(3), requires_grad=True)
print("V:",v)
v2 = (v + 12)**2
print("V2:", v2)
v2.retain_grad()
v2.sum().backward()
print("V2.Grad:", v2.grad)
print("V1.Grad:", v.grad)

V: tensor([-0.4931, -1.1959,  0.4018], requires_grad=True)
V2: tensor([132.4094, 116.7291, 153.8048], grad_fn=<PowBackward0>)
V2.Grad: tensor([1., 1., 1.])
V1.Grad: tensor([23.0139, 21.6082, 24.8036])


In [None]:
v2 = v+1
v2.retain_grad()
v2.sum().backward()
v2.grad

In [47]:
for  params1, params2 in zip(net1.parameters(), net2.parameters()):
    params2.data.copy_(params1.data)
    print("Param1:", params1)
    print("Param2:", params2)
    break
#     print(params1.data)
#     break
#     if(params1.data == params2.data):
#         print("True")
    print(params1.shape, " : ", params2.shape)
    

Param1: Parameter containing:
tensor([[[[ 0.0254, -0.0639,  0.2566],
          [-0.0979,  0.1889, -0.2140],
          [-0.1157,  0.1258, -0.1244]]]], requires_grad=True)
Param2: Parameter containing:
tensor([[[[ 0.0254, -0.0639,  0.2566],
          [-0.0979,  0.1889, -0.2140],
          [-0.1157,  0.1258, -0.1244]]]], requires_grad=True)


Param1: Parameter containing:
tensor([[[[ 0.0254, -0.0639,  0.2566],
          [-0.0979,  0.1889, -0.2140],
          [-0.1157,  0.1258, -0.1244]]]], requires_grad=True)
Param2: Parameter containing:
tensor([[[[-0.0911, -0.0321,  0.1369],
          [-0.1179,  0.0861, -0.3058],
          [-0.0653, -0.1240, -0.0681]]]], requires_grad=True)


In [21]:
t1 = torch.randn((2,2))
t2 = torch.randn((2,2))
# t2 = t

print("*" * 50)
print("T1:", t1, " \n", "T2:", t2)
print("*" * 50)

**************************************************
T1: tensor([[ 0.1762, -2.1027],
        [ 0.6261,  1.3108]])  
 T2: tensor([[ 1.2931, -0.3317],
        [-0.6834,  0.1683]])
**************************************************


In [22]:
# print("*" * 50)
# print("T1:", t1, " \n", "T2:", t2)
# print("*" * 50)
print("T1:", t1.data.copy_(t2), " \n", "T2:", t2.data)
print("*" * 50)
print(type(t1.data))

T1: tensor([[ 1.2931, -0.3317],
        [-0.6834,  0.1683]])  
 T2: tensor([[ 1.2931, -0.3317],
        [-0.6834,  0.1683]])
**************************************************
<class 'torch.Tensor'>


In [23]:
print("T1:", t1, " \n ", "T2:", t2)

T1: tensor([[ 1.2931, -0.3317],
        [-0.6834,  0.1683]])  
  T2: tensor([[ 1.2931, -0.3317],
        [-0.6834,  0.1683]])


In [93]:
for param in net.named_parameters():
#     print(param)
    name, parameters = param
#     print("Name:", name, " Parameters:", parameters)
    print(" Parameters:", parameters, " Type:", type(parameters), " Shape:", parameters.shape)
    print(parameters.data)
    print("*" * 100)

 Parameters: Parameter containing:
tensor([[[[ 0.2014, -0.1923, -0.0068],
          [-0.0415,  0.1986, -0.2086],
          [-0.1305,  0.1839,  0.2130]]]], requires_grad=True)  Type: <class 'torch.nn.parameter.Parameter'>  Shape: torch.Size([1, 1, 3, 3])
tensor([[[[ 0.2014, -0.1923, -0.0068],
          [-0.0415,  0.1986, -0.2086],
          [-0.1305,  0.1839,  0.2130]]]])
****************************************************************************************************
 Parameters: Parameter containing:
tensor([0.0719], requires_grad=True)  Type: <class 'torch.nn.parameter.Parameter'>  Shape: torch.Size([1])
tensor([0.0719])
****************************************************************************************************
 Parameters: Parameter containing:
tensor([[[[ 0.0010,  0.0366, -0.0201],
          [ 0.0160, -0.0496,  0.0394],
          [-0.0354,  0.0353,  0.0176]],

         [[-0.0301, -0.0520,  0.0141],
          [ 0.0522,  0.0540, -0.0008],
          [ 0.0385,  0.0444, -0.0

In [56]:
conv = nn.Conv2d(1, 1, kernel_size=3)
x = torch.randn(1, 1, 8, 8)
print("X:",x)
print(x.shape)
print(conv(x).shape)
print("Conv X:", conv(x))

X: tensor([[[[ 1.4350, -0.0418,  0.0662, -1.0033, -0.9563,  0.8041,  0.0131,
           -1.2973],
          [-0.2135,  0.8128, -0.7179,  0.7968, -1.9276,  0.7482, -0.9018,
            0.1769],
          [-0.8309,  1.8309,  0.3510, -0.3548, -1.4007, -0.7809,  0.4636,
            1.3764],
          [-2.7068, -0.7709,  0.3637, -0.4615,  0.2624, -0.5807,  0.8520,
            0.3480],
          [-1.8321,  0.4307, -0.8719, -0.4008,  0.1217,  0.2684, -0.0599,
            0.2308],
          [ 0.3181,  1.5212,  0.0312,  0.9021,  1.1248,  0.2039, -1.1476,
           -1.5148],
          [-1.8644,  0.1639, -1.2798, -0.1055,  1.2482,  0.2642, -1.0575,
           -0.5832],
          [ 0.1748,  0.2098,  0.2814,  0.8081,  1.0729,  0.5243, -0.9525,
            0.5459]]]])
torch.Size([1, 1, 8, 8])
torch.Size([1, 1, 6, 6])
Conv X: tensor([[[[ 3.4187e-01, -6.5614e-01, -2.7153e-01, -3.0435e-01,  1.0122e-02,
           -4.0183e-03],
          [ 1.5786e-01, -1.6811e-01, -5.8853e-01,  3.3191e-01, -6.1250e-01,

In [57]:
print(conv.weight)
print(conv.bias)

Parameter containing:
tensor([[[[ 0.3276,  0.1306,  0.0370],
          [-0.1876, -0.1801,  0.0048],
          [-0.2366,  0.0408, -0.1571]]]], requires_grad=True)
Parameter containing:
tensor([-0.2314], requires_grad=True)


In [30]:
print(NN())

NN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)


In [5]:

def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    print("Adapted Network")
    return evaluation_error, evaluation_accuracy


def perform_experiment(dataset,
        ways=5,
        shots=5,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
):
    
    Meta_Train_Accuracy = []
    Meta_Train_Error = []
    Meta_Val_Accuracy = []
    Meta_Val_Error = []
    
#     Iterations = []
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets(dataset,
                                                  train_samples=2*shots,
                                                  train_ways=ways,
                                                  test_samples=2*shots,
                                                  test_ways=ways,
                                                  root='~/data',
    )

    
    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    print("*" * 100)
    print(model)
    print("*" * 100)
#     return
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    #Meta optimizer
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            print("Task:", task)
            # Compute meta-training loss
            learner = maml.clone()
            batch = tasksets.train.sample()
#             print("Tasks Batch:", batch)
#             continue
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
#         print('\n')
#         print('Iteration', iteration)
#         print('Meta Train Error', meta_train_error / meta_batch_size)
#         print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
#         print('Meta Valid Error', meta_valid_error / meta_batch_size)
#         print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        break
        meta_train_error =  meta_train_error / meta_batch_size
        meta_train_accuracy = meta_train_accuracy / meta_batch_size
        meta_val_error =  meta_valid_error / meta_batch_size
        meta_val_accuracy = meta_valid_accuracy / meta_batch_size
        if(iteration % 4 ==0):
            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error)
            print('Meta Train Accuracy', meta_train_accuracy)
            print('Meta Valid Error', meta_val_error)
            print('Meta Valid Accuracy', meta_val_accuracy)

        Meta_Train_Accuracy.append(meta_train_accuracy)
        Meta_Train_Error.append(meta_train_error)
        Meta_Val_Accuracy.append(meta_val_accuracy)
        Meta_Val_Error.append(meta_val_error)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-testing loss
        learner = maml.clone()
        batch = tasksets.test.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()
#     print('Meta Test Error', meta_test_error / meta_batch_size)
#     print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)
    Meta_Test_Error = meta_test_error / meta_batch_size
    print('Meta Test Error', Meta_Test_Error)
    Meta_Test_Accuracy = meta_test_accuracy / meta_batch_size
    print('Meta Test Accuracy', Meta_Test_Accuracy)
    
    if not os.path.exists('plots'):
        os.makedirs('plots')
    if not os.path.exists('plots/acc'):
        os.makedirs('plots/acc')
    if not os.path.exists('plots/loss'):
        os.makedirs('plots/loss')
        
    ###### Plot Accuracies ######
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    fig = plt.figure()
    ax = plt.subplot(111)
    ax.plot(list(range(0, len(Meta_Train_Accuracy))), Meta_Train_Accuracy, label="Meta Train ")
    ax.plot(list(range(0, len(Meta_Val_Accuracy))), Meta_Val_Accuracy, label="Meta Val")
#     ax.text((len(Meta_Val_Accuracy)/2), 0, 'Meta Test:{0}'.format(round(meta_test_accuracy, 2)), style='italic',
#         bbox={'facecolor': 'red', 'alpha': 0.25, 'pad': 5})
    # place a text box in upper left in axes coords
    ax.text(0.05, 0.5, 'Meta Test:{0}'.format(round(Meta_Test_Accuracy, 2)), transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)
    plt.title('Adaption')
    plt.xlabel('Iteration')
    plt.ylabel('Meta Accuracy')
    ax.legend()
    plt.savefig('./plots/acc/{0}_ways_{1}_shots_{2}_Acc_I_{3}.png'.format(dataset, ways, shots, num_iterations),
               dpi=150)
    ###### Plot Accuracies ######

    ###### Plot Errors ######
    fig = plt.figure()
    ax = plt.subplot(111)
    ax.plot(list(range(0, len(Meta_Train_Error))), Meta_Train_Error, label="Meta Train ")
    ax.plot(list(range(0, len(Meta_Val_Error))), Meta_Val_Accuracy, label="Meta Val Lss")
#     ax.text((len(Meta_Val_Accuracy)/2),0, 'Meta Test Accuracy:{0}  Meta Test Loss:{1}'.format(round(meta_test_accuracy, 2),
#                               round(meta_test_error, 2)), style='italic',
#         bbox={'facecolor': 'red', 'alpha': 0.25, 'pad': 5})
    ax.text(0.05, 0.5, 'MTestAcc:{0}  MTestLoss:{1}'.format(round(Meta_Test_Accuracy, 2),
                                           round(Meta_Test_Error, 2)), transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)
    plt.title('Adaption')
    plt.xlabel('Iteration')
    plt.ylabel('Meta Loss')
    ax.legend()
    plt.savefig('./plots/loss/{0}_ways_{1}_shots_{2}_Loss_I_{3}.png'.format(dataset, ways, shots, num_iterations), dpi=150)
    ###### Plot Errors ######


if __name__ == '__main__':
    Ways = [5]
    Shots = [1]#, 5]
    Iterations = [1]#, 10000, 60000, 120000]
    for ways in Ways:
        for shots in Shots:
            for iteration in Iterations:
                
                perform_experiment(dataset='mini-imagenet',
                ways=ways,
                shots=shots,
                meta_lr=0.003,
                fast_lr=0.5,
                meta_batch_size=4,
                adaptation_steps=1,
                num_iterations=iteration,
                cuda=True,
                seed=seed)

****************************************************************************************************
CNN4(
  (features): Sequential(
    (0): ConvBase(
      (0): ConvBlock(
        (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): ConvBlock(
        (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (2): ConvBlock(
        (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
        (normalize): BatchNorm2d(32, eps