<a href="https://colab.research.google.com/github/bhuvnk/skunworks/blob/main/Phase1/3.%20MAML/MAML_Omniglot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MAML: Omniglot dataset

Adapted from the implementation by Oscar Knagg (github.com/oscarknagg):

https://github.com/oscarknagg/few-shot/blob/master/few_shot/maml.py

# Data downloading

Official Omniglot dataset repository: https://github.com/brendenlake/omniglot

In [1]:
!wget https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip
!wget https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip

!unzip -qq images_background.zip
!unzip -qq images_evaluation.zip

--2020-11-08 09:11:51--  https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip [following]
--2020-11-08 09:11:52--  https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6462886 (6.2M) [application/zip]
Saving to: ‘images_evaluation.zip’


2020-11-08 09:11:52 (82.6 MB/s) - ‘images_evaluation.zip’ saved [6462886/6462886]

--2020-11-08 09:11:52--  https://github.com/brendenlake/omniglot/raw

## Import libraries

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import multiprocessing as mp
import os
import cv2
import tqdm

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [3]:
# Check whether PyTorch can see a CUDA device (prints True or False).
print(torch.cuda.is_available())

True


## Read data

## MAML DATA

In [4]:
from tqdm.notebook import tqdm
def load_characters(root, alphabet):
    """Load every character image of one alphabet.

    Walks ``root/alphabet/<character>/<image>``, reads each image with
    ``cv2.imread`` (default flags), resizes it to 28x28 and divides it by 255.

    Args:
        root: Directory that contains one sub-directory per alphabet.
        alphabet: Name of the alphabet sub-directory to read.

    Returns:
        A tuple ``(X, y)`` of parallel lists: ``X`` holds the resized, scaled
        image arrays and ``y`` the labels ``'<alphabet>_<character>'``.
        Both lists are empty when the alphabet contains no readable image.
    """
    import warnings  # local import: only needed on the unreadable-file path

    X = []
    y = []

    alphabet_path = os.path.join(root, alphabet)

    # sorted(): os.listdir returns entries in arbitrary order; sorting makes the
    # returned sample order identical from run to run.
    for char in sorted(os.listdir(alphabet_path)):
        char_path = os.path.join(alphabet_path, char)

        # Stray files next to the character folders (e.g. '.DS_Store') would make
        # os.listdir(char_path) raise NotADirectoryError.
        if not os.path.isdir(char_path):
            continue

        label = f'{alphabet}_{char}'

        for img in sorted(os.listdir(char_path)):
            img_path = os.path.join(char_path, img)
            image = cv2.imread(img_path)

            # cv2.imread signals failure by returning None instead of raising;
            # passing that on to cv2.resize would fail with an opaque error.
            if image is None:
                warnings.warn(f'Skipping unreadable image file: {img_path}')
                continue

            image = cv2.resize(image, (28, 28)) / 255

            X.append(image)
            y.append(label)

    return X, y


def load_data(root):
    """Load all alphabets found under ``root`` into two NumPy arrays.

    Args:
        root: Directory that contains one sub-directory per alphabet.

    Returns:
        A tuple ``(X_data, y_data)``: the images collected by
        ``load_characters`` for every alphabet, and their matching labels,
        each converted with ``np.array``.
    """
    X_data = []
    y_data = []

    print('Loading Data')

    # Keep only directories (stray files such as '.DS_Store' would make
    # load_characters fail) and sort them, because os.listdir order is arbitrary.
    alphabets = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )

    for alphabet in tqdm(alphabets):
        X, y = load_characters(root, alphabet)
        X_data.extend(X)
        y_data.extend(y)

    return np.array(X_data), np.array(y_data)

In [5]:
def extract_sample(X_data, y_data, task_params):
    """Sample one N-way, K-shot task (support set and query set).

    Randomness comes from NumPy's global RNG (``np.random``).

    Args:
        X_data: Array of images indexed along the first axis; each image must be
            3-D with channels last, because the result is permuted with
            ``(0, 3, 1, 2)``.
        y_data: Array of class labels, one per image in ``X_data``.
        task_params: Dict with the integer keys ``'k_shot'`` (support images per
            class), ``'n_way'`` (classes per task) and ``'n_query'`` (query
            images per class).

    Returns:
        ``((X_train, y_train), (X_test, y_test))``: float32 image tensors in
        channels-first layout and int64 label tensors. Labels are the position
        ``0 .. n_way-1`` of the class within this task, and each set is shuffled.

    Raises:
        ValueError: If a sampled class has fewer than ``k_shot + n_query``
            images (also raised by ``np.random.choice`` when ``n_way`` exceeds
            the number of distinct classes).
    """
    k_shot = task_params['k_shot']
    n_way = task_params['n_way']
    n_query = task_params['n_query']
    per_class = k_shot + n_query

    X_data = np.asarray(X_data)
    y_data = np.asarray(y_data)

    X_train = []
    y_train = []

    X_test = []
    y_test = []

    # Randomly sample n_way classes
    sampled_cls = np.random.choice(np.unique(y_data), n_way, replace=False)

    for i, c in enumerate(sampled_cls):
        # Select all images belonging to class c
        X_data_c = X_data[y_data == c]

        # Without this check a small class would yield fewer images than labels,
        # and the shuffled label tensor would silently stop matching the images.
        if len(X_data_c) < per_class:
            raise ValueError(
                f'Class {c!r} has {len(X_data_c)} images, but k_shot + n_query = '
                f'{per_class} are required'
            )

        # Permute indices rather than the image array itself: same draw, but no
        # copy of every image of the class.
        picked = np.random.permutation(len(X_data_c))[:per_class]
        sample_images = X_data_c[picked]

        X_train.append(sample_images[:k_shot])
        X_test.append(sample_images[k_shot:])

        y_train.extend([i] * k_shot)
        y_test.extend([i] * n_query)

    # Shuffle indices
    train_idx = np.random.permutation(len(y_train))
    test_idx = np.random.permutation(len(y_test))

    # Assemble and shuffle in NumPy, then convert once; building a tensor from a
    # Python list of ndarrays is very slow.
    X_train = torch.from_numpy(np.concatenate(X_train)[train_idx]).float().permute(0, 3, 1, 2)
    y_train = torch.from_numpy(np.asarray(y_train, dtype=np.int64)[train_idx])

    X_test = torch.from_numpy(np.concatenate(X_test)[test_idx]).float().permute(0, 3, 1, 2)
    y_test = torch.from_numpy(np.asarray(y_test, dtype=np.int64)[test_idx])

    return (X_train, y_train), (X_test, y_test)

## Build model

MAML Network

In [6]:
import torch
from torch import nn
import torch.nn.functional as F

def conv_block(in_channels, out_channels):
    '''Build one stage: 3x3 convolution (padding 1) -> batch norm -> ReLU -> 2x2 max-pool.

    Returns an nn.Sequential whose sub-modules sit at indices 0..3 in that order.
    '''
    stages = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]
    return nn.Sequential(*stages)

def functional_conv_block(x, weights, biases, bn_weights, bn_biases):
    '''Stateless counterpart of conv_block that takes its parameters as arguments.

    Applies conv (padding 1) -> batch norm -> ReLU -> 2x2 max-pool using the given
    tensors, so task-specific weights can be plugged in. Batch norm is called with
    no running statistics and training=True.
    '''
    conv_out = F.conv2d(x, weights, bias=biases, padding=1)
    normed = F.batch_norm(conv_out, None, None, weight=bn_weights, bias=bn_biases, training=True)
    activated = F.relu(normed)
    return F.max_pool2d(activated, kernel_size=2, stride=2)

In [7]:
class MAMLClassifier(nn.Module):
    '''Four-stage convolutional classifier with an extra functional forward pass.

    forward() uses the module's own parameters; task_forward() runs the same
    architecture with parameters supplied by the caller, which is what the MAML
    inner loop needs for its task-adapted weights.

    Args:
        n_way: Number of output classes.
        in_channels: Channels of the input images (default 3).
        hidden_channels: Output channels of every conv stage (default 64).
    '''

    def __init__(self, n_way, in_channels=3, hidden_channels=64):
        super(MAMLClassifier, self).__init__()

        # Attribute names conv1..conv4 and head define the parameter names
        # ('conv1.0.weight', ..., 'head.bias') that task_forward looks up.
        self.conv1 = conv_block(in_channels, hidden_channels)
        self.conv2 = conv_block(hidden_channels, hidden_channels)
        self.conv3 = conv_block(hidden_channels, hidden_channels)
        self.conv4 = conv_block(hidden_channels, hidden_channels)

        # The head takes hidden_channels features, i.e. it expects the spatial map
        # to be 1x1 after the four 2x2 poolings (28 -> 14 -> 7 -> 3 -> 1).
        self.head = nn.Linear(hidden_channels, n_way)

    def forward(self, x):
        '''Return class logits for a batch of images using the module parameters.'''

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        # Flatten everything but the batch dimension
        feat = x.flatten(1)

        return self.head(feat)

    def task_forward(self, x, params):
        '''Functional forward pass: return class logits computed with `params`.

        `params` maps parameter names as produced by named_parameters()
        ('conv{i}.0.weight', 'conv{i}.0.bias', 'conv{i}.1.weight',
        'conv{i}.1.bias' for i in 1..4, plus 'head.weight' and 'head.bias')
        to tensors; it is the vessel for the task-specific thetas.
        '''

        for block in (1, 2, 3, 4):
            x = functional_conv_block(x,
                                      params[f'conv{block}.0.weight'],
                                      params[f'conv{block}.0.bias'],
                                      params[f'conv{block}.1.weight'],
                                      params[f'conv{block}.1.bias'])

        # Flatten everything but the batch dimension
        feat = x.flatten(1)

        return F.linear(feat, params['head.weight'], params['head.bias'])

# MAML Train

In [8]:
# Task parameters read by extract_sample() for every sampled task
task_params = {'k_shot': 1,   # support images drawn per sampled class
               'n_way': 5,    # classes sampled per task
               'n_query': 5}  # query images drawn per sampled class

In [9]:
# Load the meta-training split: all images under ./images_background,
# labelled '<alphabet>_<character>' by load_characters().
X_train_dataset, y_train_dataset = load_data('images_background')

Loading Data


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))




In [10]:
# Draw one example task. The first pair is the support set (X_support, y_support),
# the second pair is the query set (X_query, y_query).
# Note: X_train / y_train are reassigned by the meta-training loop further down.
(X_train, y_train), (X_test, y_test) = extract_sample(X_train_dataset, y_train_dataset, task_params)

In [11]:
# Hyperparameters
inner_train_steps = 1  # gradient steps taken on each task's support set
alpha = 0.4 # Inner LR: step size of the manual update of the task weights
beta = 0.001 # Meta LR: learning rate handed to the Adam meta-optimizer
epochs = 5  # passes of the outer training loop
batch_size = 32  # tasks sampled per meta-update (not images per batch)
num_episodes = 100  # meta-updates per epoch
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Important
Training iterations in the paper: 60,000

Meta-iterations run in this notebook: 500 (`epochs` × `num_episodes` = 5 × 100)

In [12]:
# Meta-model: one output logit per class of a task (n_way)
model = MAMLClassifier(n_way=task_params['n_way'])

In [13]:
# Mount the model on the target device *before* building the optimizer, which is
# the order recommended by the torch.optim documentation.
model = model.to(device)

# Loss function used for both the inner (task) loss and the meta loss
criterion = nn.CrossEntropyLoss()
# Meta-optimizer: updates the meta-parameters with learning rate beta
optimizer = torch.optim.Adam(model.parameters(), lr=beta)

# Last expression of the cell: display the architecture
model

MAMLClassifier(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv4): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Batch

In [14]:
from collections import OrderedDict
from tqdm.notebook import tqdm
from tqdm import tqdm_notebook

from tqdm.notebook import tnrange

# Start Meta-Training
# task_forward() always normalises with batch statistics, so the module mode does
# not affect the computation below; set it once rather than once per episode.
model.train()

for epoch in range(1, epochs+1):

    pbar = tqdm(total=num_episodes, desc='Epoch {}'.format(epoch))

    # Meta Episode: one meta-update computed from `batch_size` freshly sampled tasks
    for episode in range(num_episodes):

        task_losses = []
        task_accuracies = []

        # The meta-gradient is accumulated into .grad task by task
        optimizer.zero_grad()

        # Task Fine-tuning
        for task_idx in range(batch_size):
            # Get the train (support) and val (query) splits
            train_sample, test_sample = extract_sample(X_train_dataset, y_train_dataset, task_params)
            X_train = train_sample[0].to(device)
            y_train = train_sample[1].to(device)
            X_val = test_sample[0].to(device)
            y_val = test_sample[1].to(device)

            # Create a task model using current meta model weights
            task_weights = OrderedDict(model.named_parameters())

            # Fine-tune
            for step in range(inner_train_steps):
                # Forward pass
                logits = model.task_forward(X_train, task_weights)
                # Loss
                loss = criterion(logits, y_train)
                # create_graph=True keeps the inner update differentiable, so the
                # meta-gradient can flow through it (second-order MAML)
                gradients = torch.autograd.grad(loss, task_weights.values(), create_graph=True)
                # Manual Gradient Descent on the task weights
                task_weights = OrderedDict(
                                    (name, param - alpha * grad) # Gradient update on parameters(task theta)
                                    for ((name, param), grad) in zip(task_weights.items(), gradients)
                                )

            # Testing on the Query Set (Val)
            val_logits = model.task_forward(X_val, task_weights)
            val_loss = criterion(val_logits, y_val)

            # Back-propagate this task's share of the meta loss immediately. The
            # accumulated gradient equals that of mean(task losses), but each
            # task's graph is freed here instead of all batch_size graphs being
            # kept alive until a single backward at the end of the episode.
            (val_loss / batch_size).backward()

            # Calculating accuracy (argmax of the logits equals argmax of their softmax)
            y_pred = val_logits.argmax(dim=-1)
            accuracy = torch.eq(y_pred, y_val).sum().item() / y_val.shape[0]

            task_accuracies.append(accuracy)
            task_losses.append(val_loss.item())

        # Meta Optimization
        optimizer.step()

        # Meta Loss / Accuracy, for logging only
        meta_batch_loss = torch.Tensor(task_losses).mean()
        meta_batch_accuracy = torch.Tensor(task_accuracies).mean()

        # Progress Bar Logging
        pbar.update(1)
        pbar.set_postfix({'Loss': meta_batch_loss.item(), 
                          'Accuracy': meta_batch_accuracy.item()})

    pbar.close()


# Save Model
torch.save({'weights': model.state_dict(),
            'task_params': task_params}, "meta_model_para.pth")

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 2', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 3', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 4', style=ProgressStyle(description_width='initial'…




HBox(children=(FloatProgress(value=0.0, description='Epoch 5', style=ProgressStyle(description_width='initial'…




## Validating or testing

In [15]:
# # Load Checkpoint
# checkpoint = torch.load('/content/meta_model_para.pth')

# task_params = checkpoint['task_params']

# # load the weights into the model again
# model = MAMLClassifier(n_way=task_params['n_way'])
# model.load_state_dict(checkpoint['weights'])

In [16]:
# Load the evaluation split: all images under ./images_evaluation
X_test_dataset, y_test_dataset = load_data('images_evaluation')

Loading Data


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




In [17]:
# Evaluation hyperparameters
alpha = 0.4 # Inner LR
batch_size = 32 # tasks sampled per evaluation episode
num_episodes = 100

# Loss Function
criterion = nn.CrossEntropyLoss()

# Validation inner steps are 3, as given in the paper.
# (The earlier `inner_train_steps = 1` in this cell was overwritten and is removed.)
inner_train_steps = 3

# # Mount model to device
# model.to(device)

In [18]:
# Evaluation
# `tqdm` is tqdm.notebook.tqdm (imported in earlier cells); the tqdm_notebook
# alias used before is deprecated and printed a warning.
pbar = tqdm(total=num_episodes, desc='Evaluating')

overall_accuracies = []
# Meta Episode
for episode in range(num_episodes):

    task_losses = []
    task_accuracies = []

    # Task Fine-tuning
    for task_idx in range(batch_size):
        # Get the train (support) and val (query) splits
        train_sample, test_sample = extract_sample(X_test_dataset, y_test_dataset, task_params)
        X_train = train_sample[0].to(device)
        y_train = train_sample[1].to(device)
        X_val = test_sample[0].to(device)
        y_val = test_sample[1].to(device)

        # Create a task model using current meta model weights
        task_weights = OrderedDict(model.named_parameters())

        # Fine-tune
        for step in range(inner_train_steps):
            # Forward pass
            logits = model.task_forward(X_train, task_weights)
            # Loss
            loss = criterion(logits, y_train)
            # No meta-update follows at evaluation time, so plain first-order
            # gradients are enough. create_graph=True would build an unused
            # second-order graph and change none of the computed values.
            gradients = torch.autograd.grad(loss, task_weights.values())
            # Manual Gradient Descent on the task weights. detach() cuts the link
            # to the previous step; requires_grad_ lets the next step
            # differentiate with respect to the updated weights.
            task_weights = OrderedDict(
                                (name, (param - alpha * grad).detach().requires_grad_(True))
                                for ((name, param), grad) in zip(task_weights.items(), gradients)
                            )

        # Testing on the Query Set (Val); nothing is back-propagated from here
        with torch.no_grad():
            val_logits = model.task_forward(X_val, task_weights)
            val_loss = criterion(val_logits, y_val)

            # Calculating accuracy (argmax of the logits equals argmax of their softmax)
            y_pred = val_logits.argmax(dim=-1)
            accuracy = torch.eq(y_pred, y_val).sum().item() / y_val.shape[0]

        task_accuracies.append(accuracy)
        overall_accuracies.append(accuracy)
        task_losses.append(val_loss)

    # Meta Loss and Accuracy
    meta_batch_loss = torch.stack(task_losses).mean()
    meta_batch_accuracy = torch.Tensor(task_accuracies).mean()

    pbar.update(1)
    pbar.set_postfix({'Loss': meta_batch_loss.item(), 
                      'Accuracy': meta_batch_accuracy.item()})

pbar.close()
print(f'Mean Accuracy {np.array(overall_accuracies).mean()}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, description='Evaluating', style=ProgressStyle(description_width='initi…


Mean Accuracy 0.9292625000000001
