# Implementation of Prototypical Networks for Few Shot Learning (https://arxiv.org/abs/1703.05175) in Pytorch

> Prototypical Networks
> Datasets such as ImageNet contain a multitude of examples per class, but that much labelled data is not always available
> Typical deep learning architectures rely on substantial amounts of data to reach good results

> With N-shot learning, we want a model that can be trained using a very limited number of images (only N labelled examples per class)

> Prototypical networks are one of the most popular deep learning algorithms for few-shot learning

> Unlike typical deep learning architectures, prototypical networks do not classify the image directly; instead they learn a mapping from an image to a point (an embedding vector) in a metric space.
> Classification then simply computes the distance between the embedding of an image and the prototype of each class. Because only the distance from one point to another is used, the "origin" point of the space does not matter

In [9]:
import multiprocessing as mp
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm
import numpy as np 
import pandas as pd
import os
print(os.listdir("C:/Users/bokhy/Desktop/Python/Python-Projects/Pytorch/data/Omniglot"))

['images_background', 'images_evaluation']


# 1. Read / Import Data

> The network was trained on the Omniglot dataset. The Omniglot data set is designed for developing more human-like learning algorithms. It contains 1,623 different handwritten characters from 50 different alphabets
> Data: https://github.com/brendenlake/omniglot

In [10]:
# List the alphabet folders of the Omniglot "images_background" split.
train_dir = os.listdir('C:/Users/bokhy/Desktop/Python/Python-Projects/Pytorch/data/Omniglot/images_background/')
# NOTE(review): this module-level datax does not appear to be read by any later cell
# shown here; the loader functions below work on their own local datax.
datax = np.array([])

In [11]:
# Display the alphabet folder names collected in train_dir.
train_dir

['Alphabet_of_the_Magi',
 'Anglo-Saxon_Futhorc',
 'Arcadian',
 'Armenian',
 'Asomtavruli_(Georgian)',
 'Balinese',
 'Bengali',
 'Blackfoot_(Canadian_Aboriginal_Syllabics)',
 'Braille',
 'Burmese_(Myanmar)',
 'Cyrillic',
 'Early_Aramaic',
 'Futurama',
 'Grantha',
 'Greek',
 'Gujarati',
 'Hebrew',
 'Inuktitut_(Canadian_Aboriginal_Syllabics)',
 'Japanese_(hiragana)',
 'Japanese_(katakana)',
 'Korean',
 'Latin',
 'Malay_(Jawi_-_Arabic)',
 'Mkhedruli_(Georgian)',
 'N_Ko',
 'Ojibwe_(Canadian_Aboriginal_Syllabics)',
 'Sanskrit',
 'Syriac_(Estrangelo)',
 'Tagalog',
 'Tifinagh']

> Data Augmentation: Images are rotated by 90, 180 and 270 degrees, and each rotation is treated as one more class. Hence the total number of classes reaches 6,492 (1,623 * 4)

> TODO: Use Albumentation library to augment images

In [12]:
def image_rotate(img, angle):
    """
    Rotates an image about its centre and returns it with a new leading axis of length 1.

    img: image array whose first two axes are (rows, cols); a trailing channel axis is optional
    angle: rotation angle in degrees, forwarded to cv2.getRotationMatrix2D
    Returns the rotated image (same width and height as img) wrapped in an extra first axis,
    so that results can be stacked along axis 0.
    """
    # Only the first two axes are needed; slicing instead of a 3-way unpack lets
    # 2-D (grayscale) images through as well.
    rows, cols = img.shape[:2]
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    dst = cv2.warpAffine(img, M, (cols, rows))
    return np.expand_dims(dst, 0)

def read_alphabets(alphabet_directory, directory):
    """
    Reads all the characters from alphabet_directory and augments each image with 90, 180, 270 degrees of rotation.

    alphabet_directory: path of one alphabet folder, holding one sub-folder per character
    directory: alphabet name, used as the prefix of every label
    Returns (datax, datay):
        datax: for every image, the 28x28 resized original followed by its three rotations,
               stacked along axis 0 (None when no image was found)
        datay: array of labels '<directory>_<character>_<angle>' in the same order as datax
    """
    rotations = (90, 180, 270)
    # Collect the per-image arrays and stack them once at the end: calling np.vstack
    # inside the loop copies everything read so far for every new image.
    frames = []
    datay = []
    for character in os.listdir(alphabet_directory):
        character_dir = os.path.join(alphabet_directory, character)
        label_prefix = directory + '_' + character + '_'
        for img in os.listdir(character_dir):
            image = cv2.resize(cv2.imread(os.path.join(character_dir, img)), (28, 28))
            frames.append(np.expand_dims(image, 0))
            frames.extend(image_rotate(image, angle) for angle in rotations)
            datay.extend(label_prefix + str(angle) for angle in (0,) + rotations)
    datax = np.concatenate(frames, axis=0) if frames else None
    return datax, np.array(datay)

def read_images(base_directory):
    """
    Reads every alphabet folder under base_directory, using a pool of worker processes
    (one task per alphabet) to decrease the reading time.

    base_directory: folder that holds one sub-folder per alphabet
    Returns (datax, datay): the images and labels produced by read_alphabets for all
    alphabets, concatenated in os.listdir order. Returns (None, []) when nothing was read.
    """
    tasks = [(base_directory + '/' + directory + '/', directory)
             for directory in os.listdir(base_directory)]
    # pool.apply blocks until each call has returned, so the alphabets were read one
    # after the other; starmap hands all tasks to the workers at once and keeps the
    # results in task order. The with-block also releases the worker processes.
    # NOTE(review): on platforms that spawn workers (e.g. Windows) read_alphabets must be
    # importable by the worker processes -- confirm when running from a notebook.
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.starmap(read_alphabets, tasks)
    # An alphabet folder without images yields datax=None; leave it out of the stack.
    loaded = [(images, labels) for images, labels in results if images is not None]
    if not loaded:
        return None, []
    # One concatenate per output instead of re-stacking the running array per alphabet.
    datax = np.concatenate([images for images, _ in loaded], axis=0)
    datay = np.concatenate([labels for _, labels in loaded])
    return datax, datay

In [None]:
# Read the images_background folder into trainx / trainy; the IPython %time magic reports how long the call took.
%time trainx, trainy = read_images('C:/Users/bokhy/Desktop/Python/Python-Projects/Pytorch/data/Omniglot/images_background/')

Wall time: 0 ns


In [None]:
# Read the images_evaluation folder into testx / testy; the IPython %time magic reports how long the call took.
%time testx, testy = read_images('C:/Users/bokhy/Desktop/Python/Python-Projects/Pytorch/data/Omniglot/images_evaluation/')

> Making sure we have the shapes that we want

In [None]:
# Display the shapes of the image and label arrays of both splits.
trainx.shape, trainy.shape, testx.shape, testy.shape

# 2. Create a model

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from tqdm import trange
from time import sleep
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [None]:
# The loader stacks cv2 images channel-last, i.e. (N, 28, 28, 3), while the Conv2d
# layers of the embedding network expect channel-first input (N, 3, 28, 28).
# Move the channel axis right after the batch axis while converting to float tensors.
trainx = torch.from_numpy(trainx).float().permute(0, 3, 1, 2)
testx = torch.from_numpy(testx).float().permute(0, 3, 1, 2)
# The labels (trainy / testy) are intentionally left as NumPy string arrays.
if use_gpu:
    trainx = trainx.to(device)
    testx = testx.to(device)
trainx.size(), testx.size()

### "Image2vector CNN" architecture. It takes images of dimensions 28x28x3 and returns a vector of length 64

In [None]:
class Net(nn.Module):
    """
    Embedding CNN made of four identical convolutional stages followed by a flatten.
    """

    def sub_block(self, in_channels, out_channels=64, kernel_size=3):
        """
        Builds one stage: Conv2d (padding 1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # First stage takes 3 input channels; every stage outputs 64 channels.
        self.convnet1 = self.sub_block(3)
        self.convnet2 = self.sub_block(64)
        self.convnet3 = self.sub_block(64)
        self.convnet4 = self.sub_block(64)

    def forward(self, x):
        """
        Runs x through the four stages in order and flattens everything after the batch axis.
        """
        for stage in (self.convnet1, self.convnet2, self.convnet3, self.convnet4):
            x = stage(x)
        return x.flatten(1)

In [None]:
class PrototypicalNet(nn.Module):
    """
    Prototypical network: embeds images with Net and scores every query image against
    the mean support embedding (the prototype / centroid) of each class of an episode.
    """

    def __init__(self, use_gpu=False):
        super(PrototypicalNet, self).__init__()
        self.f = Net()
        self.gpu = use_gpu
        if self.gpu:
            self.f = self.f.cuda()

    def forward(self, datax, datay, Ns, Nc, Nq, total_classes):
        """
        Implementation of one episode in Prototypical Net
        datax: Training images
        datay: Corresponding labels of datax
        Nc: Number  of classes per episode
        Ns: Number of support data per class
        Nq:  Number of query data per class
        total_classes: Total classes in training set
        Returns (scores, Query_y):
            scores: (number of queries, Nc) matrix of NEGATIVE distances between each query
                    embedding and each class centroid, so that a softmax over a row gives
                    the highest probability to the nearest centroid
            Query_y: index of the true class of every query, matching the columns of scores
        """
        K = np.random.choice(total_classes, Nc, replace=False)
        query_batches = []
        Query_y = []
        Query_y_count = []
        centroid_per_class = {}
        class_label = {}
        for label_encoding, cls in enumerate(K):
            S_cls, Q_cls = self.random_sample_cls(datax, datay, Ns, Nq, cls)
            centroid_per_class[cls] = self.get_centroid(S_cls, Nc)
            class_label[cls] = label_encoding
            query_batches.append(Q_cls)
            Query_y.append(cls)
            Query_y_count.append(Q_cls.shape[0])
        Query_x = torch.cat(query_batches, 0)  # Joining all the query set together
        Query_y, Query_y_labels = self.get_query_y(Query_y, Query_y_count, class_label)
        distances = self.get_query_x(Query_x, centroid_per_class, Query_y_labels)
        # The class probability must decrease with the distance to the centroid.
        # Feeding raw distances to log_softmax would reward the farthest centroid.
        return -distances, Query_y

    def random_sample_cls(self, datax, datay, Ns, Nq, cls):
        """
        Randomly samples Ns examples as support set and Nq as Query set
        among the rows of datax whose label in datay equals cls.
        The two sets are disjoint slices of one random permutation.
        """
        members = torch.from_numpy(np.flatnonzero(np.asarray(datay) == cls))
        data = datax[members]
        perm = torch.randperm(data.shape[0])
        S_cls = data[perm[:Ns]]
        Q_cls = data[perm[Ns:Ns + Nq]]
        if self.gpu:
            S_cls = S_cls.cuda()
            Q_cls = Q_cls.cuda()
        return S_cls, Q_cls

    def get_centroid(self, S_cls, Nc):
        """
        Returns a centroid vector of support set for a class, as a (1, embedding size) tensor.
        S_cls: support images of one class
        Nc: unused; kept so that existing callers keep working. The centroid is the mean
            over the support samples actually given, not the sum divided by the number
            of classes.
        """
        return self.f(S_cls).mean(dim=0, keepdim=True)

    def get_query_y(self, Qy, Qyc, class_label):
        """
        Returns labeled representation of classes of Query set and a list of labels.
        Qy: class name of each group of queries
        Qyc: number of queries in each group
        class_label: unused; kept for backward compatibility
        Returns (Query_y, Query_y_labels): Query_y_labels are the sorted unique class names,
        Query_y[i] is the position of query i's class inside Query_y_labels.
        """
        labels = np.repeat(np.asarray(Qy), Qyc)
        # np.unique sorts the class names and return_inverse gives each label's index in
        # that sorted list -- the same encoding a LabelEncoder would produce.
        Query_y_labels, encoded = np.unique(labels, return_inverse=True)
        Query_y = torch.from_numpy(encoded.astype(np.int64))
        if self.gpu:
            Query_y = Query_y.cuda()
        return Query_y, Query_y_labels

    def get_centroid_matrix(self, centroid_per_class, Query_y_labels):
        """
        Returns the centroid matrix where each row is a centroid of a class,
        rows following the order of Query_y_labels.
        """
        centroid_matrix = torch.cat([centroid_per_class[label] for label in Query_y_labels])
        if self.gpu:
            centroid_matrix = centroid_matrix.cuda()
        return centroid_matrix

    def get_query_x(self, Query_x, centroid_per_class, Query_y_labels):
        """
        Returns distance matrix from each Query image to each centroid:
        entry [i, j] is the L2 distance between the embedding of query i and centroid j.
        """
        centroid_matrix = self.get_centroid_matrix(centroid_per_class, Query_y_labels)  # (n, d)
        Query_x = self.f(Query_x)  # (m, d)
        # Broadcast to (m, n, d) and reduce over the embedding axis explicitly, so the
        # result is (m, n) no matter which axis torch.pairwise_distance reduces over in
        # the installed PyTorch version. The small epsilon keeps sqrt differentiable at 0.
        diff = centroid_matrix.unsqueeze(0) - Query_x.unsqueeze(1)
        return (diff + 1e-6).pow(2).sum(dim=2).sqrt()

In [None]:
# Build the model (its embedding network is moved to the GPU when use_gpu is set)
# and an SGD optimizer with momentum over all of the model's parameters.
protonet = PrototypicalNet(use_gpu=use_gpu)
optimizer = optim.SGD(protonet.parameters(), lr = 0.01, momentum=0.99)

# 3. Train the Model

In [None]:
def train_step(datax, datay, Ns, Nc, Nq):
    """
    Runs one training episode and applies one optimizer update.

    datax, datay: images and labels the episode is sampled from
    Ns, Nc, Nq: support samples per class, classes per episode, query samples per class
    Returns (loss, acc): the NLL loss of the episode and the fraction of queries whose
    highest-scoring class equals their label.
    Uses the module-level protonet and optimizer.
    """
    optimizer.zero_grad()
    all_classes = np.unique(datay)
    scores, targets = protonet(datax, datay, Ns, Nc, Nq, all_classes)
    log_probs = F.log_softmax(scores, dim=-1)
    loss = F.nll_loss(log_probs, targets)
    loss.backward()
    optimizer.step()
    hits = log_probs.argmax(dim=1).eq(targets)
    return loss, hits.float().mean()

In [None]:
# Train for num_episode episodes (5 support samples, 60 classes, 5 queries per class)
# and report the mean loss / accuracy of every window of frame_size episodes.
num_episode, frame_size = 16000, 1000
frame_loss = frame_acc = 0

for i in range(num_episode):
    loss, acc = train_step(trainx, trainy, 5, 60, 5)
    frame_loss += loss.data
    frame_acc += acc.data
    if (i + 1) % frame_size:
        continue  # still inside the current window
    print("Frame Number:", (i + 1) // frame_size,
          'Frame Loss: ', frame_loss.item() / frame_size,
          'Frame Accuracy:', (frame_acc.item() * 100) / frame_size)
    # Start a fresh window.
    frame_loss = frame_acc = 0

# 4. Test the Model

In [None]:
def test_step(datax, datay, Ns, Nc, Nq):
    """
    Evaluates one episode without updating the model.

    datax, datay: images and labels the episode is sampled from
    Ns, Nc, Nq: support samples per class, classes per episode, query samples per class
    Returns (loss, acc): the NLL loss of the episode and the fraction of queries whose
    highest-scoring class equals their label.
    Uses the module-level protonet; its train/eval mode is restored before returning.
    """
    was_training = protonet.training
    # Evaluation mode: BatchNorm uses its running statistics and does not update them.
    protonet.eval()
    try:
        # No gradients are needed here, so skip building the autograd graph.
        with torch.no_grad():
            Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, np.unique(datay))
            pred = torch.log_softmax(Qx, dim=-1)
            loss = F.nll_loss(pred, Qy)
            acc = torch.mean((torch.argmax(pred, 1) == Qy).float())
    finally:
        protonet.train(was_training)
    return loss, acc

In [None]:
# Evaluate on num_test_episode episodes (5 support samples, 60 classes, 15 queries per class)
# and print the mean loss and the mean accuracy in percent.
num_test_episode = 2000
avg_loss = avg_acc = 0

for _ in range(num_test_episode):
    loss, acc = test_step(testx, testy, 5, 60, 15)
    avg_loss += loss.data
    avg_acc += acc.data

mean_loss = avg_loss.item() / num_test_episode
mean_acc_percent = (avg_acc.item() * 100) / num_test_episode
print('Avg Loss: ', mean_loss, 'Avg Accuracy:', mean_acc_percent)

In [None]:
> Weight-loading for later use

In [None]:
def load_weights(filename, protonet, use_gpu):
    """
    Loads the state dict stored in filename into protonet and returns protonet.

    filename: path of a file readable by torch.load that holds a state dict
    protonet: module that receives the weights (modified in place)
    use_gpu: when False, all stored tensors are remapped to the CPU while loading
    """
    # map_location is an argument of torch.load; passing it to load_state_dict
    # raised a TypeError on the CPU path.
    map_location = None if use_gpu else 'cpu'
    protonet.load_state_dict(torch.load(filename, map_location=map_location))
    return protonet

> Limitations

> It works well at predicting images of characters from the same dataset, because they share similar traits. However, if we were to use this model on a different type of image (e.g. fruits), it would not perform as well, as fruits and handwriting share few characteristics.

> Reference: https://blog.floydhub.com/n-shot-learning/