# Prototypical Networks for Few-shot Learning

The idea is to solve the few-shot learning problem for classification.
- The classifier must generalize to new classes not seen in the training set, given only a small number of examples of each new class.
- So the model is not given an entirely new dataset to train on; it only gets a small number of examples of each new class.
Humans are well suited to this task: given a single image, we can learn what the object is and classify it.
Prototypical networks learn a metric space in which classification can be performed by computing distances to prototype representations of each class.
- What is the prototype representation of a class?
    - I bet this paper is all about this; we have to read more.
- Their assumptions:
    - Our approach, prototypical networks, is based on the idea that there exists an embedding in which points cluster around a single prototype representation for each class.
- They learn an embedding of the meta-data into a shared space to serve as the prototype for each class.




# Dataset Loading 

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, sys
import cv2
from PIL import Image
import numpy as np
import multiprocessing as mp
import pandas as pd
from threading import Thread

In [33]:
def read_image(path, use_pillow=False, height=256):
    """
    Reads and returns an image.

    Arguments:
        path: str - path to the image
        use_pillow: bool - if True, open the file with Pillow and return the
            PIL image as-is (it is NOT resized and NOT converted to numpy)
        height: int - side length in pixels of the square the image is
            resized to on the OpenCV path (used for both width and height)
    Returns
        img: numpy array resized to (height, height) with the colour channels
            as decoded by cv2.imread (BGR order per the OpenCV docs) on the
            default path, or a PIL Image when use_pillow is True
    Raises
        OSError: if OpenCV cannot read the file (missing, unreadable or
            unsupported format)
    """
    if use_pillow:
        return Image.open(path)

    # Flag 1 asks OpenCV for a 3-channel colour image.
    img = cv2.imread(path, 1)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than with an opaque error from
    # cv2.resize.
    if img is None:
        raise OSError(f"cv2.imread could not read image: {path!r}")
    return cv2.resize(img, (height, height))

def load_all_images(paths):
    """
    Reads every image in `paths` with read_image (default settings) and
    returns them as a list, in the same order as `paths`.

    Prints the position and path of every 1000th image as a progress report.
    """
    loaded = []
    for position, image_path in enumerate(paths):
        loaded.append(read_image(image_path))
        if position % 1000:
            continue
        # progress report: positions 0, 1000, 2000, ...
        print(position)
        print(image_path)
    return loaded


"""
BirdsDatasetMeta should return two subsets of the dataset:
1. Support set, which is used to compute the prototype representation of each class.
2. Query set, which is used to train the model.
Open questions:
1. Do they have the same classes (are the classes in the support set the same as the classes in the query set)?
    - Yes, they should be the same.
2. Is it OK if the support set and the query set contain the same images?
    - No: it is not recommended; the two sets should not share images.
"""
def start_process():
    """Print the name of the current multiprocessing process."""
    process_name = mp.current_process().name
    print('Starting', process_name)
class BirdsDatasetMeta(Dataset):
    """
    Birds dataset reader for meta learning.

    Every item is one randomly drawn episode: `n_class_per_batch` distinct
    classes, and for each class `n_support + n_query` randomly chosen images.
    Splitting those images into a support part and a query part is left to
    the caller.

    Arguments:
        images_pd: DataFrame with a "class_id" column and an "index" column.
            The "index" values are used directly to index `imgs`, so they
            must be valid 0-based positions into it.
        imgs: array of all images that supports indexing with an integer
            array (e.g. a numpy array). None is not supported yet.
        classes: 1-D sequence of the class ids episodes are drawn from.
        transform: stored on the instance; currently NOT applied by
            __getitem__.
        n_class_per_batch: number of distinct classes per episode.
        n_support: number of support images per class.
        n_query: number of query images per class.
    """
    def __init__(self, images_pd, imgs, classes, transform=None, n_class_per_batch=5, n_support=10, n_query=10):
        super().__init__()
        self.n_class_per_batch = n_class_per_batch
        self.n_support = n_support
        self.n_query = n_query
        self.images_pd = images_pd
        self.imgs = imgs
        self.classes = classes
        # Previously accepted and silently dropped; keep it reachable.
        self.transform = transform

    def __len__(self):
        return len(self.classes)  # each class is one sample

    def __getitem__(self, idx):
        """
        Draws one random episode. `idx` is ignored: the classes and images
        are sampled with np.random on every call.

        Returns:
            feats: float32 array; (n_class_per_batch, n_support + n_query, ...)
                when every chosen class has at least n_support + n_query rows
            labels: int32 array holding the class id of every returned image
        Raises:
            NotImplementedError: if the dataset was built with imgs=None
        """
        if self.imgs is None:
            # Loading images lazily from disk is not implemented yet.
            raise NotImplementedError("not implemented error")

        per_class = self.n_support + self.n_query
        # Distinct classes for this episode, so every class contributes the
        # same number of shots.
        chosen_classes = np.random.choice(self.classes, self.n_class_per_batch, replace=False)
        feats = []
        labels = []
        for c in chosen_classes:
            datax_cls = self.images_pd[self.images_pd["class_id"] == c]
            # Shuffle this class' image positions and keep the first ones.
            perm = np.random.permutation(datax_cls["index"].values)
            sample_cls = self.imgs[perm[:per_class]]
            feats.append(sample_cls)
            labels.append([c] * len(sample_cls))
        return np.float32(feats), np.int32(labels)

In [3]:
ROOT = "./CUB_200_2011"
images_path_file = os.path.join(ROOT, "images.txt")
images_class_file = os.path.join(ROOT, "image_class_labels.txt")
# Each non-empty line of images.txt is "<image_id> <relative/path/to/image.jpg>".
with open(images_path_file, 'r') as f:
    # strip() removes the line terminator without eating a real character
    # when the last line has no trailing newline; blank lines are skipped.
    images = [line.strip().split(" ") for line in f if line.strip()]
# Row layout: [index, class_id, path]
#   index    - image_id - 1. The image ids in images.txt start at 1, while
#              the "index" column is used as a 0-based position into the
#              image list loaded below (same order as the rows), so shift it.
#   class_id - the 3-digit prefix of the class folder name ("001.Black_...").
#   path     - image path under ROOT/images.
images = [[int(im_path[0]) - 1, int(im_path[-1][:3]), os.path.join(ROOT, "images", im_path[-1])] for im_path in images]
images_pd = pd.DataFrame(images, columns=["index", "class_id", "path"])
# Loads every image; row i of images_pd corresponds to imgs[i].
imgs = load_all_images(images_pd["path"].values)
classes = images_pd["class_id"].unique()

0
./CUB_200_2011\images\001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg
1000
./CUB_200_2011\images\019.Gray_Catbird/Gray_Catbird_0039_21040.jpg
2000
./CUB_200_2011\images\036.Northern_Flicker/Northern_Flicker_0037_28751.jpg
3000
./CUB_200_2011\images\052.Pied_billed_Grebe/Pied_Billed_Grebe_0091_35276.jpg
4000
./CUB_200_2011\images\069.Rufous_Hummingbird/Rufous_Hummingbird_0046_59647.jpg
5000
./CUB_200_2011\images\086.Pacific_Loon/Pacific_Loon_0022_75405.jpg
6000
./CUB_200_2011\images\103.Sayornis/Sayornis_0079_98434.jpg
7000
./CUB_200_2011\images\120.Fox_Sparrow/Fox_Sparrow_0113_114389.jpg
8000
./CUB_200_2011\images\137.Cliff_Swallow/Cliff_Swallow_0005_133696.jpg
9000
./CUB_200_2011\images\154.Red_eyed_Vireo/Red_Eyed_Vireo_0036_156727.jpg
10000
./CUB_200_2011\images\170.Mourning_Warbler/Mourning_Warbler_0078_795377.jpg
11000
./CUB_200_2011\images\187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0030_796144.jpg


In [7]:
# Convert the list of loaded images into a single float32 array, then drop
# the list so only one copy of the image data stays referenced.
# NOTE: because of the `del`, re-running this cell raises the NameError
# recorded in the output below.
imgsn = np.float32(imgs)
del imgs

NameError: name 'imgs' is not defined

In [34]:
# Episodic dataset with the default sampling settings
# (5 classes per episode, 10 support + 10 query images per class).
data = BirdsDatasetMeta(images_pd, imgsn, classes)

In [38]:
# Read classes.txt into a list of raw lines (line terminators are kept).
with open(os.path.join(ROOT, "classes.txt"), "r") as f:
    class_info = f.readlines()
# to print the name of the class:
# class_info[label-1]

In [16]:
# len(data)

In [35]:
# Draw one random episode (the index passed to the dataset does not select
# anything). Note that this rebinds the name `imgs` to the episode's images.
imgs, labels = data[0]

In [37]:
labels.shape

(5, 20)

# Model 

In [40]:
import torch.nn as nn

The model is composed of four convolutional blocks. Each block comprises a 64-filter 3 × 3 convolution, a batch normalization layer, a ReLU nonlinearity and a 2 × 2 max-pooling layer.

In [46]:
class PrototypeModel(nn.Module):
    """
    Convolutional embedding network: four conv blocks applied in sequence,
    followed by flattening everything but the batch dimension.

    Arguments:
        in_dim: number of channels of the input images
        hidden_dim: number of output channels of the first three blocks
        out_dim: number of output channels of the last block (a channel
            count, not the length of the flattened embedding)
    """
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.in_dim, self.hidden_dim, self.out_dim = in_dim, hidden_dim, out_dim
        # (input channels, output channels) of each block, in order.
        channel_plan = [
            (in_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, out_dim),
        ]
        stages = [self.block(c_in, c_out) for c_in, c_out in channel_plan]
        self.conv_blocks = nn.Sequential(*stages)

    def block(self, in_channels, out_channels):
        """Builds one block: 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max-pool."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Runs the conv blocks on `x` and returns one flat vector per batch item."""
        feature_maps = self.conv_blocks(x)
        batch_size = feature_maps.size(0)
        return feature_maps.view(batch_size, -1)

In [47]:
# 3 input channels, 256 hidden channels, 1600 output channels.
# NOTE(review): the last argument is the channel count of the final block,
# not the flattened embedding length; with 256x256 inputs the flattened
# output has 1600 * 16 * 16 = 409600 features (see the shape printed
# below). Confirm this is what is intended.
model = PrototypeModel(3, 256, 1600)

In [49]:
# Smoke test: push a single all-zero 3x256x256 image through the model to
# inspect the size of the resulting embedding.
img = torch.zeros((1, 3,256,256))
ret = model(img)

In [50]:
ret.shape

torch.Size([1, 409600])