<a href="https://colab.research.google.com/github/bhuvnk/skunworks/blob/main/Phase1/3.%20MAML/MAML_MiniImagenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MAML: MiniImagenet dataset

taken from implementation by: github.com/oscarknagg

https://github.com/oscarknagg/few-shot/blob/master/few_shot/maml.py

# Important
Training Iterations on paper: 60,000

Implemented in this notebook: 500 (5 epochs × 100 episodes)

# Data Stuff

Mini Imagenet pickle by some gentleman:
https://drive.google.com/file/d/1fJAK5WZTjerW7EWHHQAR9pRJVNg1T1Y7/view


In [None]:
# Download the MiniImagenet archive from Google Drive (needs Colab user authentication)
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate the Colab user, then hand the application-default credentials to PyDrive.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
# Reference the shared file by its Drive id and save its content as miniImageNet.zip.
downloaded = drive.CreateFile({'id':"1fJAK5WZTjerW7EWHHQAR9pRJVNg1T1Y7"})   
downloaded.GetContentFile('miniImageNet.zip') 

In [None]:
!mkdir miniImageNet
!unzip -qq miniImageNet.zip -d miniImageNet

## Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import multiprocessing as mp
import os
import cv2
from tqdm.notebook import tqdm
import pickle as pkl
import random

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [None]:
# Check whether a CUDA device is visible to PyTorch
print(torch.cuda.is_available())

True


## Read data

## Mini Imagenet DATA

In [None]:
# Locations of the extracted MiniImagenet pickles (train / val / test class splits)
data_path = '/content/miniImageNet'
train_path = f'{data_path}/miniImageNet_category_split_train_phase_train.pickle'
val_path = f'{data_path}/miniImageNet_category_split_val.pickle'
test_path = f'{data_path}/miniImageNet_category_split_test.pickle'

In [None]:
def load_data(data_file):
    """Load a pickled dataset and group its images by class label.

    The file is read with ``read_file``; its ``'data'`` and ``'labels'``
    entries are then combined into a dict that maps every label to a NumPy
    array holding the ``data`` rows carrying that label. A progress bar is
    shown while the per-label arrays are built.
    """
    contents = read_file(data_file)
    images, targets = contents['data'], contents['labels']
    indices_by_label = buildLabelIndex(targets)

    grouped = {}
    for label, indices in tqdm(indices_by_label.items()):
        # Index `images` with the list of row positions belonging to this label.
        grouped[label] = np.array(images[indices])
    return grouped

def read_file(data_file):
    """Unpickle ``data_file`` and return the stored object.

    A default load is attempted first. If it fails with ``UnicodeDecodeError``
    (what Python 3 raises for Python 2 pickles containing non-ASCII byte
    strings), the file is read again with ``encoding='latin1'``. Any other
    error (missing file, corrupt pickle, keyboard interrupt, ...) propagates
    to the caller instead of being silently retried.

    NOTE(security): unpickling can execute arbitrary code; only use this on
    files from a trusted source.
    """
    try:
        with open(data_file, 'rb') as fo:
            return pkl.load(fo)
    except UnicodeDecodeError:
        # Legacy (Python 2) pickle: decode 8-bit strings byte-for-byte.
        with open(data_file, 'rb') as fo:
            return pkl.load(fo, encoding='latin1')

def buildLabelIndex(labels):
    """Map every distinct label to the list of positions where it occurs.

    Positions are appended in iteration order, and labels appear in the
    returned dict in order of first occurrence.
    """
    positions_by_label = {}
    for position, label in enumerate(labels):
        positions_by_label.setdefault(label, []).append(position)
    return positions_by_label

# Task Sampler

In [None]:
import numpy as np

def extract_sample(data, task_params):
    """Sample one N-way K-shot classification task from ``data``.

    Args:
        data: dict mapping a class key to an array of that class's images,
            indexed as (image, height, width, channel) with pixel values that
            are divided by 255 here.
        task_params: dict with integer entries ``'k_shot'`` (support images
            per class), ``'n_way'`` (number of classes) and ``'n_query'``
            (query images per class).

    Returns:
        ``((X_train, y_train), (X_test, y_test))`` where the ``X`` tensors are
        float32 and channels-first (image, channel, height, width), the ``y``
        tensors are int64 class indices in ``range(n_way)``, and both sets are
        shuffled independently.

    Raises:
        ValueError: if a sampled class holds fewer than ``k_shot + n_query``
            images (images and labels would otherwise get out of step).
    """
    k_shot = task_params['k_shot']
    n_way = task_params['n_way']
    n_query = task_params['n_query']
    per_class = k_shot + n_query

    # random.sample requires a sequence; passing dict keys directly raises
    # TypeError on Python 3.11+ (and was deprecated before that).
    sampled_classes = random.sample(list(data.keys()), n_way)

    support_images, support_labels = [], []
    query_images, query_labels = [], []

    for k, cls in enumerate(sampled_classes):
        class_images = np.asarray(data[cls])  # all the images for that class
        if len(class_images) < per_class:
            raise ValueError(
                f'class {cls!r} has {len(class_images)} images, '
                f'but k_shot + n_query = {per_class} are required'
            )

        # Permute indices instead of the image array itself, so only the
        # k_shot + n_query selected images are copied.
        chosen = np.random.permutation(len(class_images))[:per_class]
        picked = class_images[chosen] / 255.

        support_images.append(picked[:k_shot])
        query_images.append(picked[k_shot:])

        support_labels.extend([k] * k_shot)
        query_labels.extend([k] * n_query)

    # Shuffle indices
    train_idx = np.random.permutation(len(support_labels))
    test_idx = np.random.permutation(len(query_labels))

    # One contiguous array per set converts to a tensor far faster than a
    # Python list of arrays; then move channels first and apply the shuffle.
    X_train = torch.from_numpy(np.concatenate(support_images)).float().permute(0, 3, 1, 2)[train_idx]
    y_train = torch.tensor(support_labels, dtype=torch.long)[train_idx]

    X_test = torch.from_numpy(np.concatenate(query_images)).float().permute(0, 3, 1, 2)[test_idx]
    y_test = torch.tensor(query_labels, dtype=torch.long)[test_idx]

    return (X_train, y_train), (X_test, y_test)

In [None]:
# Task Parameters: shots per class, classes per task, query images per class
task_params = dict(k_shot=1, n_way=5, n_query=5)

## Build model

MAML Network

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

def conv_block(in_channels, out_channels):
    '''Build one conv stage: 3x3 convolution (padding 1) -> batch norm -> ReLU -> 2x2 max-pool (stride 2).'''

    stages = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*stages)

def functional_conv_block(x, weights, biases, bn_weights, bn_biases):
    '''Stateless counterpart of conv_block that takes its parameters as arguments.

    Applies conv (padding 1) -> batch norm -> ReLU -> 2x2 max-pool (stride 2)
    using the given tensors. Batch norm always normalises with the statistics
    of the current batch (training=True, no running mean/variance).
    '''

    convolved = F.conv2d(x, weights, biases, padding=1)
    normalised = F.batch_norm(
        convolved,
        running_mean=None,
        running_var=None,
        weight=bn_weights,
        bias=bn_biases,
        training=True,
    )
    return F.max_pool2d(F.relu(normalised), kernel_size=2, stride=2)

In [None]:
class MAMLClassifier(nn.Module):
    """Four 64-channel conv blocks followed by a linear classification head.

    Besides the regular ``forward`` the model offers ``task_forward``, which
    runs the same architecture with an explicit parameter dict so that
    task-adapted weights can be evaluated without touching the module's own
    parameters.

    Args:
        n_way: number of output classes.
        input_channel: channels of the input images.
        final_layer_size: number of flattened features entering the head,
            i.e. 64 * H * W of the last feature map. Each block halves the
            spatial size, so an 84x84 input gives 84 -> 42 -> 21 -> 10 -> 5
            and needs 64 * 5 * 5 = 1600; the default 64 fits a 1x1 final map.
    """

    def __init__(self, n_way, input_channel=3, final_layer_size=64):
        super().__init__()

        # Attribute names double as the parameter-name prefixes
        # ('conv1.0.weight', ...) that task_forward looks up.
        self.conv1 = conv_block(input_channel, 64)
        self.conv2 = conv_block(64, 64)
        self.conv3 = conv_block(64, 64)
        self.conv4 = conv_block(64, 64)

        self.head = nn.Linear(final_layer_size, n_way)

    def forward(self, x):
        """Regular forward pass using the module's own parameters."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)

        # Flatten to (batch_size, final_layer_size) before the head.
        flat = x.contiguous().view(x.size(0), -1)
        return self.head(flat)

    def task_forward(self, x, params):
        """Functional forward pass that uses the tensors in ``params``.

        ``params`` maps parameter names as produced by ``named_parameters()``
        to tensors; this is the vessel for the task-specific thetas.
        """
        for block in range(1, 5):
            x = functional_conv_block(
                x,
                params[f'conv{block}.0.weight'],
                params[f'conv{block}.0.bias'],
                params[f'conv{block}.1.weight'],
                params[f'conv{block}.1.bias'],
            )

        # Flatten to (batch_size, final_layer_size) before the head.
        flat = x.contiguous().view(x.size(0), -1)
        return F.linear(flat, params['head.weight'], params['head.bias'])

# MAML Train

In [None]:
%%time
# Load the meta-training split, grouped by class label
train_data = load_data(train_path)

HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))


CPU times: user 13 s, sys: 1min 9s, total: 1min 22s
Wall time: 1min 22s


In [None]:
# Task Parameters: shots per class, classes per task, query images per class
task_params = dict(k_shot=1, n_way=5, n_query=15)

In [None]:
# Class labels present in the training split
train_data.keys()

dict_keys([52, 59, 41, 16, 8, 13, 39, 50, 7, 26, 24, 31, 21, 12, 63, 2, 18, 0, 14, 6, 30, 27, 19, 5, 22, 3, 62, 29, 11, 10, 61, 45, 49, 32, 43, 1, 58, 28, 53, 42, 46, 57, 25, 23, 34, 33, 4, 17, 35, 60, 51, 47, 54, 20, 56, 55, 44, 36, 38, 15, 48, 40, 9, 37])

In [None]:
# Draw one task and inspect the support / query tensor shapes
(X_train, y_train), (X_test, y_test) = extract_sample(train_data, task_params)
X_train.shape, y_train.shape, X_test.shape, y_test.shape

(torch.Size([5, 3, 84, 84]),
 torch.Size([5]),
 torch.Size([75, 3, 84, 84]),
 torch.Size([75]))

In [None]:
# Support and query labels of the sampled task
y_train, y_test

(tensor([4, 2, 3, 1, 0]),
 tensor([2, 2, 3, 0, 3, 2, 3, 4, 0, 2, 3, 0, 2, 1, 1, 2, 1, 4, 3, 4, 4, 4, 2, 3,
         4, 0, 0, 3, 1, 2, 4, 2, 4, 1, 1, 3, 1, 1, 0, 0, 4, 3, 2, 0, 1, 1, 0, 3,
         1, 4, 2, 1, 0, 0, 2, 3, 4, 4, 4, 1, 0, 1, 0, 4, 2, 2, 3, 0, 0, 3, 4, 3,
         3, 1, 2]))

In [None]:
# Hyperparameters (scaled down; the larger settings are kept in the comments below)
inner_train_steps = 5  # gradient steps taken on the support set of each task
alpha = 0.01 # Inner LR, used for the manual task-weight updates
beta = 0.001 # Meta LR, passed to the Adam optimizer

# batch_size = 32
batch_size = 4  # tasks per meta-update

# epochs = 60
# num_episodes = 1000
epochs = 5
num_episodes = 100  # meta-updates per epoch: epochs * num_episodes = 500 in total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flattened feature size entering the linear head: 64 channels * 5 * 5
# (84x84 inputs are halved by each of the four max-pools: 84 -> 42 -> 21 -> 10 -> 5).
final_layer_size = 1600


In [None]:
# Meta-model: n_way output classes, head sized for the flattened conv features
model = MAMLClassifier(n_way=task_params['n_way'], final_layer_size = final_layer_size)

In [None]:
# Loss Function (used for both the support-set and the query-set loss)
criterion = nn.CrossEntropyLoss()
# Optimizer for the meta-parameters, with beta as its learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=beta)

# Mount model to device (last expression, so the cell also displays the architecture)
model.to(device)

MAMLClassifier(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv4): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Batch

In [None]:
from collections import OrderedDict
from tqdm.notebook import tqdm

from tqdm.notebook import tnrange

# Start Meta-Training
# Every episode adapts the current meta-weights to `batch_size` freshly sampled
# tasks (inner loop) and then updates the meta-weights from the mean query loss.
for epoch in range(1, epochs+1):

    epoch_accuracy = []
    pbar = tqdm(total=num_episodes, desc='Epoch {}'.format(epoch))
    
    # Meta Episode
    for episode in range(num_episodes):
        
        # Query losses / accuracies of the tasks in this meta-batch
        task_losses = []
        task_accuracies = []
        
        # Task Fine-tuning
        for task_idx in range(batch_size):
            # Get the train (support) and val (query) splits of a new task
            train_sample, test_sample = extract_sample(train_data, task_params)
            X_train = train_sample[0].to(device)
            y_train = train_sample[1].to(device)
            X_val = test_sample[0].to(device)
            y_val = test_sample[1].to(device)
            
            # Create a task model using current meta model weights
            # (name -> parameter dict, in the format task_forward expects)
            task_weights = OrderedDict(model.named_parameters())
            
            # Fine-tune
            for step in range(inner_train_steps):
                # Forward pass
                logits = model.task_forward(X_train, task_weights)
                # Loss
                loss = criterion(logits, y_train)
                # Compute Gradients
                # create_graph=True keeps the gradient computation differentiable, so the
                # meta-loss backward pass below can propagate through these inner updates.
                gradients = torch.autograd.grad(loss, task_weights.values(), create_graph=True)
                # Manual Gradient Descent on the task weights
                # (builds new tensors; the model's own parameters are left untouched)
                task_weights = OrderedDict(
                                    (name, param - alpha * grad) # Gradient update on parameters(task theta)
                                    for ((name, param), grad) in zip(task_weights.items(), gradients)
                                )
            
            # Testing on the Query Set (Val) with the adapted task weights
            val_logits = model.task_forward(X_val, task_weights)
            val_loss = criterion(val_logits, y_val)
            
            # Calculating accuracy: fraction of query images whose top class matches the label
            y_pred = val_logits.softmax(dim=1)
            accuracy = torch.eq(y_pred.argmax(dim=-1), y_val).sum().item() / y_pred.shape[0]
            
            task_accuracies.append(accuracy)
            task_losses.append(val_loss)

            # record for epoch
            epoch_accuracy.append(accuracy)
        
        # Meta Update
        # NOTE: task_forward calls F.batch_norm with training=True regardless of the
        # module's train/eval mode, so this call does not change the passes above.
        model.train()
        optimizer.zero_grad()
        # Meta Loss: mean query loss over the tasks of this meta-batch
        meta_batch_loss = torch.stack(task_losses).mean()
        # Meta backpropagation
        meta_batch_loss.backward()
        # Meta Optimization
        optimizer.step()
        
        meta_batch_accuracy = torch.Tensor(task_accuracies).mean()
        
        # Progress Bar Logging
        pbar.update(1)
        pbar.set_postfix({'Loss': meta_batch_loss.item(), 
                          'Accuracy': meta_batch_accuracy.item()})
        
    pbar.close()
    print(f'Mean Epoch Accuracy {np.array(epoch_accuracy).mean()}')


# Save Model: meta-weights plus the task configuration they were trained with
torch.save({'weights': model.state_dict(),
            'task_params': task_params}, "meta_model_para.pth")

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 2', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 3', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 4', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 5', style=ProgressStyle(description_width='initial'…




## Validating or testing

In [None]:
# Optional: restore the meta-model from the checkpoint saved after training.
# # Load Checkpoint
# checkpoint = torch.load('/content/meta_model_para.pth')

# task_params = checkpoint['task_params']

# # load the weights into the model again
# # (final_layer_size must equal the value used for training, otherwise
# #  load_state_dict fails on the shape of head.weight)
# model = MAMLClassifier(n_way=task_params['n_way'], final_layer_size=final_layer_size)
# model.load_state_dict(checkpoint['weights'])

In [None]:
# Load Validation data Data
%%time
test_dataset = load_data(val_path)

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))




In [None]:
# Hyperparameters
# inner_train_steps = 1
alpha = 0.01 # Inner LR
batch_size = 4
num_episodes = 100

# Loss Function
criterion = nn.CrossEntropyLoss()

# Inner-loop adaptation steps at evaluation time: 10 here (meta-training above used 5)
inner_train_steps = 10

# # Mount model to device
# model.to(device)

In [None]:
# Evaluation
# Each sampled task adapts a copy of the meta-weights on its support set and is
# scored on its query set. No meta-update happens here, so the inner loop uses
# plain first-order gradients (no second-order graph) and the query pass runs
# without gradient tracking; the adapted weights and predictions are unchanged.
pbar = tqdm(total=num_episodes, desc='Evaluating')

overall_accuracies = []
# Meta Episode
for episode in range(num_episodes):

    task_losses = []
    task_accuracies = []

    # Task Fine-tuning
    for task_idx in range(batch_size):
        # Get the train (support) and val (query) splits of a new task
        train_sample, test_sample = extract_sample(test_dataset, task_params)
        X_train = train_sample[0].to(device)
        y_train = train_sample[1].to(device)
        X_val = test_sample[0].to(device)
        y_val = test_sample[1].to(device)

        # Create a task model using current meta model weights
        task_weights = OrderedDict(model.named_parameters())

        # Fine-tune
        for step in range(inner_train_steps):
            # Forward pass
            logits = model.task_forward(X_train, task_weights)
            # Loss
            loss = criterion(logits, y_train)
            # Compute Gradients (first-order only: nothing differentiates through them)
            gradients = torch.autograd.grad(loss, task_weights.values())
            # Manual Gradient Descent on the task weights; detach so each step
            # starts from fresh leaf tensors and earlier graphs can be freed.
            task_weights = OrderedDict(
                (name, (param - alpha * grad).detach().requires_grad_())
                for ((name, param), grad) in zip(task_weights.items(), gradients)
            )

        # Testing on the Query Set (Val)
        with torch.no_grad():
            val_logits = model.task_forward(X_val, task_weights)
            val_loss = criterion(val_logits, y_val)

        # Calculating accuracy (argmax of the logits equals argmax of their softmax)
        accuracy = torch.eq(val_logits.argmax(dim=-1), y_val).sum().item() / y_val.shape[0]

        task_accuracies.append(accuracy)
        overall_accuracies.append(accuracy)
        task_losses.append(val_loss)

    # Meta Loss and Accuracy (for reporting only)
    meta_batch_loss = torch.stack(task_losses).mean()
    meta_batch_accuracy = torch.Tensor(task_accuracies).mean()

    pbar.update(1)
    pbar.set_postfix({'Loss': meta_batch_loss.item(), 
                      'Accuracy': meta_batch_accuracy.item()})

pbar.close()
print(f'Mean Accuracy {np.array(overall_accuracies).mean()}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, description='Evaluating', style=ProgressStyle(description_width='initi…


Mean Accuracy 0.3514333333333333
