In this notebook I reproduced the prototype network from https://arxiv.org/abs/1703.05175 by training and testing it on the omniglot dataset.
The analysis done here is not exactly the same as in the paper. However it still shows very good results even though this notebook is a reduced form of the paper.

In [2]:
import matplotlib.pyplot as plt
import torch
from torch import nn, cdist
from torchvision import datasets, transforms
import helper
from skimage import io
import os
from sklearn.model_selection import train_test_split
import random
import numpy as np
np.set_printoptions(precision=5)

In [3]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Resize(28),
                                ])

In [4]:
#get all paths, to sample randomly later
paths=['../input/omniglot/images_background/','../input/omniglot/images_evaluation/']
classes = []

for path in paths:
    for (dirpath, dirnames, filenames) in os.walk(path):
        alphabets = dirnames
        break

    for alphabet in alphabets:
        for (dirpath, dirnames, filenames) in os.walk(path+alphabet):
            for element in dirnames:
                classes.append(path+alphabet+'/'+element)
            break

print(len(classes))

1623


In [5]:
# split classes into training and testing sets
train_classes, test_classes = train_test_split(classes, random_state=1, train_size=0.86)
print(len(train_classes))


1395


In [6]:
classes[:5]

['../input/omniglot/images_background/Grantha/character15',
 '../input/omniglot/images_background/Grantha/character11',
 '../input/omniglot/images_background/Grantha/character35',
 '../input/omniglot/images_background/Grantha/character21',
 '../input/omniglot/images_background/Grantha/character38']

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

class ProtoNet(nn.Module):
    def __init__(self):
        super(ProtoNet, self).__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'),
            nn.BatchNorm2d(num_features=64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.MaxPool2d(kernel_size=2, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False),
            nn.ReLU()
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'),
            nn.BatchNorm2d(num_features=64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.MaxPool2d(kernel_size=2, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False),
            nn.ReLU()
        )
        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'),
            nn.BatchNorm2d(num_features=64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.MaxPool2d(kernel_size=2, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False),
            nn.ReLU()
        )
        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'),
            nn.BatchNorm2d(num_features=64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.MaxPool2d(kernel_size=2, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False),
            nn.ReLU()
        )
        

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        #x = self.block_4(x)
        x = torch.flatten(x, start_dim=1, end_dim=-1) # 展开为1维
        return x
    
model = ProtoNet().to(device)


Using cuda device


In [12]:
# now the actual training algorithm (algorithm 1 in the paper)

Nc = 60
N = 20
Ns = 5 # support
Nq = N - Ns # query
norm = 1/(Nc*Nq)
loss_function = nn.LogSoftmax(dim=0)
num_episodes = 700
print_interval = 10
recent_loss = []
recent_accuracy = []

for e in range(1,num_episodes+1):
    # samble Nc classes
    episode_classes = random.sample(train_classes, Nc) # 抽取60个类训练
    prototypes = [] # 每个类的原型
    query_tensors = [] # 所有类的查询集
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    for k in range(Nc):
        #load entire folder

        support = []
        query = []

        for (dirpath, dirnames, filenames) in os.walk(episode_classes[k]):
            for count, filename in enumerate(filenames,0):
                if not filename.endswith('.png'): continue
                image = transform(io.imread(episode_classes[k]+'/'+filename))
                if count < Ns:
                    support.append(image) # 前5个样本作为支撑集
                else:
                    query.append(image)   # 后15个作为查询集
        
        support_tensor = torch.stack(support)  # (5, 1, 28, 28)
        # query是 [15个shape(1,28,28)], 用stack可以把query中的15个 (1,28,28) tensor 合并为1个(15, 1, 28, 28)tensor
        query_tensor = torch.stack(query) # [15, 1, 28, 28]， 当前类
        query_tensors.append(query_tensor) # 所有类

        # now that we have our support and query sets
        # compute the prototypes with the support set

        prototypes.append(torch.mean(model(support_tensor.to(device)), axis=0))  # 支撑集上当前类的平均embedding
    
#     print(prototypes[0].shape) # [64]
    prototypes_tensor = torch.stack(prototypes).unsqueeze(0) # [15个[64]] --> (15, 64) --> (1, 15, 64)

    query_tensors = torch.stack(query_tensors) #shape = Nc, Nq, 1, 28, 28， 即所有类的query_tensor
#     print(query_tensors.shape) # [60, 15, 1, 28, 28]

    opt.zero_grad()
    loss = 0
    accuracy = 0
    for k in range(Nc):
        
        query_set = query_tensors[k]

        embedded_vectors = model(query_set.to(device)).unsqueeze(0) # shape = 1, Nq, 64
        distances = -cdist(prototypes_tensor, embedded_vectors).view(Nc,Nq) # 计算两两距离,https://blog.csdn.net/weixin_43509698/article/details/111463091
#         print(distances.shape) # [60, 15]，负值
#         print(distances[0])
        logsoftmax_dist = loss_function(distances) # 负值
        class_ls_dist =logsoftmax_dist[k]
        loss -= torch.sum(class_ls_dist) # class_ls_dist为负值
        
        #additional, calculate accuracy on training query images
        dist = distances.detach().to('cpu').numpy()
        
        pred_class = np.argmax(dist, axis=0)
        correct = [p == k for p in pred_class]
        accuracy += np.sum(correct)
        
    loss = loss*norm
    accuracy = accuracy*norm

    loss.backward()
    opt.step()
    
    recent_loss.append(loss.detach().to('cpu').numpy())
    recent_accuracy.append(accuracy)
    
    if e%print_interval == 0:
        print('episode', e, 'loss', np.mean(recent_loss), 'accuracy', np.mean(recent_accuracy))
        recent_loss = []
        recent_accuracy = []


episode 10 loss 2.9751658 accuracy 0.3857777777777778
episode 20 loss 2.7927127 accuracy 0.43133333333333335
episode 30 loss 2.5544593 accuracy 0.4845555555555555
episode 40 loss 2.448113 accuracy 0.5241111111111111
episode 50 loss 2.3549953 accuracy 0.5422222222222223
episode 60 loss 2.2702096 accuracy 0.5617777777777777
episode 70 loss 2.2168486 accuracy 0.5835555555555556
episode 80 loss 2.1074445 accuracy 0.6088888888888888
episode 90 loss 2.0663075 accuracy 0.6152222222222221
episode 100 loss 2.0019755 accuracy 0.6277777777777778
episode 110 loss 1.9944127 accuracy 0.6278888888888889
episode 120 loss 1.926268 accuracy 0.6446666666666667
episode 130 loss 1.9099905 accuracy 0.6445555555555555
episode 140 loss 1.8513796 accuracy 0.6648888888888889
episode 150 loss 1.8053716 accuracy 0.6721111111111111
episode 160 loss 1.7575285 accuracy 0.6795555555555556
episode 170 loss 1.736083 accuracy 0.681
episode 180 loss 1.6978697 accuracy 0.6968888888888889
episode 190 loss 1.6559618 accurac

In [14]:
# validate the trained network

Nc = 20
N = 20
Ns = 5
Nq = N - Ns
norm = 1/(Nc*Nq)

num_episodes = 100
print_interval = 10
recent_loss = []
recent_accuracy = []
total_loss = []
total_accuracy = []

for e in range(1,num_episodes+1):
    # samble Nc classes
    episode_classes = random.sample(test_classes, Nc)
    prototypes = []
    query_tensors = []

    for k in range(Nc):
        #load entire folder

        support = []
        query = []

        for (dirpath, dirnames, filenames) in os.walk(episode_classes[k]):
            for count, filename in enumerate(filenames,0):
                if not filename.endswith('.png'): continue
                image = transform(io.imread(episode_classes[k]+'/'+filename))
                if count < Ns:
                    support.append(image)
                else:
                    query.append(image)

        support_tensor = torch.stack(support)
        query_tensor = torch.stack(query)
        query_tensors.append(query_tensor)

        # now that we have our support and query sets
        # compute the prototypes with the support set

        prototypes.append(torch.mean(model(support_tensor.to(device)), axis=0))

    prototypes_tensor = torch.stack(prototypes).unsqueeze(0) # shape = 1, Nc, 64

    query_tensors = torch.stack(query_tensors) #shape = Nc, Nq, 1, 28, 28

    loss = 0
    accuracy = 0
    for k in range(Nc):
        
        query_set = query_tensors[k]

        embedded_vectors = model(query_set.to(device)).unsqueeze(0) # shape = 1, Nq, 64
        distances = -cdist(prototypes_tensor, embedded_vectors).view(Nc,Nq)
        logsoftmax_dist = loss_function(distances)
        class_ls_dist =logsoftmax_dist[k]
        loss -= torch.sum(class_ls_dist)
        
        #calculate accuracy query images
        dist = distances.detach().to('cpu').numpy()
        
        pred_class = np.argmax(dist, axis=0)
        correct = [p == k for p in pred_class]
        accuracy += np.sum(correct)
        
    loss = loss*norm
    accuracy = accuracy*norm
    
    recent_loss.append(loss.detach().to('cpu').numpy())
    recent_accuracy.append(accuracy)
    
    if e%print_interval == 0:
        loss = np.mean(recent_loss)
        accuracy = np.mean(recent_accuracy)
        print('episode', e, 'loss', loss, 'accuracy', accuracy)
        recent_loss = []
        recent_accuracy = []
        total_loss.append(loss)
        total_accuracy.append(accuracy)
    
print('total loss:', np.mean(total_loss), 'total accuracy:', np.mean(total_accuracy))

episode 10 loss 0.62443894 accuracy 0.8470000000000001
episode 20 loss 0.60585415 accuracy 0.8593333333333334
episode 30 loss 0.5644263 accuracy 0.8753333333333334
episode 40 loss 0.64522886 accuracy 0.8523333333333334
episode 50 loss 0.6055968 accuracy 0.8503333333333334
episode 60 loss 0.5934883 accuracy 0.8630000000000001
episode 70 loss 0.5951405 accuracy 0.8633333333333335
episode 80 loss 0.5802604 accuracy 0.8696666666666667
episode 90 loss 0.6307229 accuracy 0.8510000000000002
episode 100 loss 0.6509991 accuracy 0.837
total loss: 0.6096156 total accuracy: 0.8568333333333333


The model was trained with 60 classes per episode and tested with 20 classes per episode. We can see that the model gives pretty good predictions even though I just used 700 training episodes while the paper uses much more. More training and also using more classes by rotationg the training images would most likely lead to even better results. However my goal was to understand few-shot and one-shot learning therefore this toy example is already enough and could be easily extended to the full analysis from the paper.

In case you have any suggestions to improve on this notebook I would be happy if you would write a comment.