# Omniglot Character Set Classification Using Prototypical Network 


Now we will see how to use prototypical networks to perform the classification task. We use omniglot dataset for performing classification. Omniglot dataset comprises of 1,623 handwritten characters from 50 different alphabets and each character has 20 different examples written by different people. Since we want our network to learn from little data, we train them in the same way. We sample five examples from each class and use that as our support set. We learn the embeddings of our support set using a sequence of four convolution blocks as our encoder and build the class prototype. Similarly, we sample five examples from each class for query set, learn the query set embeddings and predict the query set class by comparing the Euclidean distance between the query set embeddings and class prototype. Let us better understand this by going through it step by step.

First we import all the required libraries,

In [1]:
import os
import glob
from PIL import Image

import numpy as np
import tensorflow as tf

How can we convert this image into an array? we can use np.array function to convert these images into an array and then we reshape it to 28*28

In [118]:
image_name = 'candi/train/budha/candi_banyunibo_yogyakarta_budha/image4.PNG'
alphabet, character, rotation = 'budha/candi_banyunibo_yogyakarta_budha/rot000'.split('/')
rotation = float(rotation[3:])

# image_name = 'candi/data\\budha/candi_banyunibo_yogyakarta_budha\\image4.PNG'
# alphabet, rotation = 'budha/candi_jago_malang_budha/rot000'.split('/')
# candi/train/budha\candi_banyunibo_yogyakarta_budha
# rotation = float(rotation[3:])

In [124]:
value_nyoba = np.array(Image.open(image_name).rotate(rotation).resize((28, 28)).convert('L'), np.float32,copy=True)
value_nyoba

array([[126., 126., 122., 121., 120., 121., 122., 126., 125., 132., 121.,
        121., 121., 121., 121., 120., 119., 119., 119., 119., 120., 121.,
        119., 119., 119., 116., 113., 120.],
       [119., 121., 120., 121., 121., 121., 123., 124., 122., 119., 120.,
        121., 121., 121., 121., 121., 120., 120., 120., 120., 121., 121.,
        118., 119., 119., 120., 120., 128.],
       [119., 116., 121., 122., 120., 121., 122., 121., 121., 120., 122.,
        122., 122., 122., 121., 122., 123., 121., 122., 122., 118., 117.,
        120., 121., 122., 127., 125., 146.],
       [150., 130., 113., 120., 120., 121., 122., 121., 121., 120., 122.,
        122., 124., 124., 122., 123., 125., 125., 126., 125., 129., 133.,
        132., 119., 110., 109., 148., 154.],
       [171., 189., 121., 117., 121., 120., 121., 122., 123., 122., 123.,
        123., 127., 128., 124., 125., 127., 125., 121., 123., 134., 136.,
        142., 121., 129., 153., 158., 147.],
       [136., 173., 158., 118., 121

Now, that we have understood, what is in our dataset, let us load our dataset:

In [188]:
value_nyoba.shape

(28, 28)

In [189]:
root_dir = 'candi/train/'

We have the splitting details in the /data/omniglot/splits/train.txt file which has the language name, character number, rotation information and images in /data/omniglot/data/ directory.

In [190]:
train_split_path = os.path.join(root_dir, 'splits', 'train_2.txt')

with open(train_split_path, 'r') as train_split:
    train_classes = [line.rstrip() for line in train_split.readlines()]

In [191]:
#number of classes
no_of_classes = len(train_classes)

In [192]:
no_of_classes

104

In [193]:
train_classes

['budha/candi_banyunibo_yogyakarta_budha/rot000',
 'budha/candi_banyunibo_yogyakarta_budha/rot090',
 'budha/candi_banyunibo_yogyakarta_budha/rot180',
 'budha/candi_banyunibo_yogyakarta_budha/rot270',
 'budha/candi_borobudur_magelang_budha/rot000',
 'budha/candi_borobudur_magelang_budha/rot090',
 'budha/candi_borobudur_magelang_budha/rot180',
 'budha/candi_borobudur_magelang_budha/rot270',
 'budha/candi_brahu_mojokerto_budha/rot000',
 'budha/candi_brahu_mojokerto_budha/rot090',
 'budha/candi_brahu_mojokerto_budha/rot180',
 'budha/candi_brahu_mojokerto_budha/rot270',
 'budha/candi_jago_malang_budha/rot000',
 'budha/candi_jago_malang_budha/rot090',
 'budha/candi_jago_malang_budha/rot180',
 'budha/candi_jago_malang_budha/rot270',
 'budha/candi_kalasan_yogyakarta_budha/rot000',
 'budha/candi_kalasan_yogyakarta_budha/rot090',
 'budha/candi_kalasan_yogyakarta_budha/rot180',
 'budha/candi_kalasan_yogyakarta_budha/rot270',
 'budha/candi_lumbung_magelang_budha/rot000',
 'budha/candi_lumbung_mage

Now we set the number of examples to 20, as we have 20 example per class in our dataset, and also we set image width and height to 28 x 28:

In [194]:
#number of examples
num_examples = 20

#image width
img_width = 28

#image height
img_height = 28
channels = 1

Next, we initialize our training dataset with a shape as a number of classes, number of examples, image height and image width:

In [195]:
train_dataset = np.zeros([no_of_classes, num_examples, img_height, img_width], dtype=np.float32)

In [196]:
train_dataset

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
    

In [197]:
train_dataset.shape

(104, 20, 28, 28)

Now, we read all the images, convert it to numpy array and store it our train_dataset array with their label and values, that is,  train_dataset = [label, values]:

In [222]:
for label, name in enumerate(train_classes):
    alphabet, character, rotation = name.split('/')
    rotation = float(rotation[4:])
    img_dir = os.path.join(root_dir, alphabet, character)
    img_files = glob.glob(os.path.join(img_dir, '*.PNG'))
    # print(img_dir)
  
    
    for index, img_file in enumerate(img_files):
        values = np.array(Image.open(img_file).rotate(rotation).resize((img_width, img_height)).convert('L'), np.float32, copy=False)
        train_dataset[label, index] = values

In [223]:
# for label,name in enumerate(train_classes):

#     alphabet, rotation = name.split('/')
#     rotation = float(rotation[3:])
#     img_dir = os.path.join(root_dir, alphabet)
#     img_files = glob.glob(os.path.join(img_dir, '*.PNG'))

#     for index, img_file in enumerate(img_files):
#         values = 1. - np.array(Image.open(img_file).rotate(rotation).resize((img_width, img_height)), np.float32, copy=False)
#         # print(index)
#         train_dataset[label, index] = values



In [224]:
train_dataset.shape

(104, 20, 28, 28)

In [225]:
train_dataset

array([[[[183., 201., 223., ..., 224., 221., 216.],
         [196., 177., 211., ..., 210., 204., 199.],
         [229., 207., 225., ..., 190., 185., 184.],
         ...,
         [ 88., 127., 114., ..., 103.,  47.,  16.],
         [ 58.,  77.,  61., ..., 112.,  86.,  76.],
         [ 61.,  75.,  97., ..., 121., 109.,  88.]],

        [[112., 107., 109., ...,  99.,  95.,  94.],
         [106., 104., 104., ...,  93.,  93.,  94.],
         [ 99.,  97., 100., ...,  85.,  89.,  91.],
         ...,
         [ 52.,  50.,  61., ..., 101.,  89.,  83.],
         [ 87.,  91.,  93., ...,  99.,  97.,  92.],
         [127., 126., 127., ..., 126., 125., 126.]],

        [[232., 236., 236., ..., 223., 231., 234.],
         [226., 228., 230., ..., 223., 227., 229.],
         [220., 221., 227., ..., 218., 219., 221.],
         ...,
         [100., 121., 129., ...,  73.,  79., 110.],
         [ 90.,  80.,  85., ...,  64.,  78., 112.],
         [ 93.,  91.,  68., ...,  65.,  81., 101.]],

        ...,

  

Now that we have loaded our training data, we need to create embeddings for them. We generate the embeddings using a convolution operation as our inputs are images. So, we define a convolutional block with 64 filters with batch normalization and ReLU as the activation function. Followed by we perform max pooling operation:

In [226]:
import tensorflow as tf

def convolution_block(inputs, out_channels, name='conv'):

    conv = tf.keras.layers.Conv2D(out_channels, kernel_size=3, padding='same')(inputs)
    conv = tf.keras.layers.BatchNormalization(momentum=0.99, scale=True, center=True)(conv)
    conv = tf.keras.layers.Activation('relu')(conv)
    conv = tf.keras.layers.MaxPooling2D(pool_size=2)(conv)
    
    return conv

Now, we define our embedding function which gives us the embedding comprising of four convolutional blocks:

In [227]:

def get_embeddings(support_set, h_dim, z_dim, reuse=False):

        net = convolution_block(support_set, h_dim)
        net = convolution_block(net, h_dim)
        net = convolution_block(net, h_dim) 
        net = convolution_block(net, z_dim) 
        flatten_layer = tf.keras.layers.Flatten()
        net = flatten_layer(net)
        
        return net

Remember, we don't use our whole dataset for training, since we are one shot learning, we sample some data points from each class as a support set and train the network using the support set in an episodic fashion. 


Now we define some of the important variables, we consider a 60-way 5-shot learning scenario:

In [246]:
#number of classes
num_way = 2

#number of examples per class for support set
num_shot = 5  

#number of query points
num_query = 5 

#number of examples
num_examples = 20

h_dim = 64

z_dim = 64

Next, we initialize placeholders for our support set and query set:

In [247]:
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

support_set = tf.compat.v1.placeholder(tf.float32, [None, None, img_height, img_width, channels])
query_set = tf.compat.v1.placeholder(tf.float32, [None, None, img_height, img_width, channels])

And we store the shape of our support set and query set in support_set_shape and query_set_shape respectively:

In [248]:
support_set_shape = tf.shape(support_set)
query_set_shape = tf.shape(query_set)

Get the number of classes and number of data points in the support set and number of data points in the query set for initializing our support and query sets:

In [249]:
num_classes, num_support_points = support_set_shape[0], support_set_shape[1]

In [250]:
num_query_points = query_set_shape[1]

Next, we define the placeholder for our label:

In [251]:
y = tf.placeholder(tf.int64, [None, None])

#convert the label to one hot
y_one_hot = tf.one_hot(y, depth=num_classes)

Now, we generate the embeddings for our support set using our embedding function:

In [252]:
support_set_embeddings = get_embeddings(tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]), h_dim, z_dim)

We compute the prototype of each class which is the mean vector of the support set embeddings of the class:

In [253]:
embedding_dimension = tf.shape(support_set_embeddings)[-1]

class_prototype = tf.reduce_mean(tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]), axis=1)

Next, we use our same embedding function for getting embeddings of the query set:

In [254]:
query_set_embeddings = get_embeddings(tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]), h_dim, z_dim, reuse=True)

Now that, we have the class prototype and query set embeddings, we define a distance function which gives us the distance between the class prototypes and query set embeddings:

In [255]:
def euclidean_distance(a, b):

    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)


Calculate the distance between the class prototype and query set embeddings:

In [256]:
distance = euclidean_distance(class_prototype,query_set_embeddings)

Next, we get the probability for each class as a softmax to the distance:

In [257]:
predicted_probability = tf.reshape(tf.nn.log_softmax(-distance), [num_classes, num_query_points, -1])

Compute the loss

In [258]:
loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, predicted_probability), axis=-1), [-1]))

Calculate accuracy

In [259]:
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(predicted_probability, axis=-1), y)))

We use Adam optimizer for minimizing the loss:



In [260]:
train = tf.train.AdamOptimizer().minimize(loss)

Now, we start our tensorflow session and train the model,

In [261]:
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)

In [262]:
num_epochs = 20
num_episodes = 100

In [263]:
for epoch in range(num_epochs):
    
    for episode in range(num_episodes):
        
        # select 60 classes
        episodic_classes = np.random.permutation(no_of_classes)[:num_way]
        
        support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)
        
        query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)
        
        
        for index, class_ in enumerate(episodic_classes):
            selected = np.random.permutation(num_examples)[:num_shot + num_query]
            support[index] = train_dataset[class_, selected[:num_shot]]
            
            # 5 querypoints per classs
            query[index] = train_dataset[class_, selected[num_shot:]]
            
        support = np.expand_dims(support, axis=-1)
        query = np.expand_dims(query, axis=-1)
        labels = np.tile(np.arange(num_way)[:, np.newaxis], (1, num_query)).astype(np.uint8)
        _, loss_, accuracy_ = sess.run([train, loss, accuracy], feed_dict={support_set: support, query_set: query, y:labels})
        
        if (episode+1) % 10 == 0:
            print('Epoch {} : Episode {} : Loss: {}, Accuracy: {}'.format(epoch+1, episode+1, loss_, accuracy_))

Epoch 1 : Episode 10 : Loss: 4.409198760986328, Accuracy: 0.5
Epoch 1 : Episode 20 : Loss: 2.3422904014587402, Accuracy: 0.5
Epoch 1 : Episode 30 : Loss: 2.305872678756714, Accuracy: 0.5
Epoch 1 : Episode 40 : Loss: 2.3034584522247314, Accuracy: 0.5
Epoch 1 : Episode 50 : Loss: 2.3013694286346436, Accuracy: 0.5
Epoch 1 : Episode 60 : Loss: 2.301924228668213, Accuracy: 0.5
Epoch 1 : Episode 70 : Loss: 2.303318500518799, Accuracy: 0.5
Epoch 1 : Episode 80 : Loss: 2.311032772064209, Accuracy: 0.5
Epoch 1 : Episode 90 : Loss: 2.3025851249694824, Accuracy: 0.5
Epoch 1 : Episode 100 : Loss: 2.3015854358673096, Accuracy: 0.5
Epoch 2 : Episode 10 : Loss: 2.3030848503112793, Accuracy: 0.5
Epoch 2 : Episode 20 : Loss: 2.3025081157684326, Accuracy: 0.5
Epoch 2 : Episode 30 : Loss: 2.3029913902282715, Accuracy: 0.5
Epoch 2 : Episode 40 : Loss: 2.30238676071167, Accuracy: 0.5
Epoch 2 : Episode 50 : Loss: 2.3024327754974365, Accuracy: 0.6000000238418579
Epoch 2 : Episode 60 : Loss: 2.302691698074341