In [16]:
import argparse
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import os
import glob
import matplotlib.pyplot as plt
from load_data import DataGenerator

In [17]:
class ProtoNet(tf.keras.Model):
    
    def __init__(self, num_filters, latent_dim):
        super(ProtoNet, self).__init__()
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        num_filter_list = self.num_filters + [latent_dim]
        self.convs = []
        for i, num_filter in enumerate(num_filter_list):
            block_parts = [layers.Conv2D(filters=num_filter, kernel_size=3, padding='SAME', activation='linear')]
            block_parts += [layers.BatchNormalization()]
            block_parts += [layers.Activation('relu')]
            block_parts += [layers.MaxPool2D()]
            block = tf.keras.Sequential(block_parts, name = 'conv_block_{}'.format(i))
            self.__setattr__("conv{}".format(i), block)
            self.convs.append(block)
        self.flatten = tf.keras.layers.Flatten()
        
    def call(self, inp):
        out = inp
        for conv in self.convs:
            out = conv(out)
        out = self.flatten(out)
        return out

In [18]:
def calc_euclidian_dists(x, y):
    """
    Calculate euclidian distance between two 3D tensors.
    Args:
        x (tf.Tensor):
        y (tf.Tensor):
    Returns (tf.Tensor): 2-dim tensor with distances.
    """
    n = x.shape[0]
    m = y.shape[0]
    x = tf.tile(tf.expand_dims(x, 1), [1, m, 1])
    y = tf.tile(tf.expand_dims(y, 0), [n, 1, 1])
    return tf.reduce_mean(tf.math.pow(x - y, 2), 2)


In [19]:
def ProtoLoss(x_latent, q_latent, labels_onehot, num_classes, num_support, num_queries):
    """
        calculates the prototype network loss using the latent representation of x
        and the latent representation of the query set
        Args:
            x_latent: latent representation of supports with shape [N*S, D], where D is the latent dimension
            q_latent: latent representation of queries with shape [N*Q, D], where D is the latent dimension
            labels_onehot: one-hot encodings of the labels of the queries with shape [N, Q, N]
            num_classes: number of classes (N) for classification
            num_support: number of examples (S) in the support set
            num_queries: number of examples (Q) in the query set
        Returns:
            ce_loss: the cross entropy loss between the predicted labels and true labels
            acc: the accuracy of classification on the queries
    """
    #############################
    #### YOUR CODE GOES HERE ####
    
    # compute the prototypes
    x_latent_test = np.reshape(x_latent, (num_classes, -1, x_latent.shape[1]))
    prototypes = tf.math.reduce_mean(x_latent_test, axis=1)
    dists = calc_euclidian_dists(q_latent, prototypes)
    
    
    # compute cross entropy loss
    labels_onehot = np.reshape(labels_onehot, (-1, num_classes))
    labels_sparse = tf.argmax(labels_onehot, axis=1)
    
    ### using cross-entropy loss
    softmaxes = tf.nn.softmax(-dists,axis=1)
    scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    ce_loss = scce(labels_sparse, softmaxes)
    scacc = tf.keras.metrics.SparseCategoricalAccuracy()
    acc = scacc(labels_sparse, softmaxes)
    """
    ## using original loss from the original paper 
    log_p_y = tf.nn.log_softmax(-dists, axis=1)
    ce_loss = -tf.reduce_mean(tf.reduce_sum(tf.multiply(labels_onehot, log_p_y), axis=1))
    
    eq = tf.cast(tf.equal(
            tf.cast(tf.argmax(softmaxes, axis=1), tf.int32), 
            tf.cast(labels_sparse, tf.int32)), tf.float32)
    acc = tf.reduce_mean(eq)
    """


    # return the cross-entropy loss and accuracy
    #ce_loss, acc = None, None
    #############################
    return ce_loss, acc

In [20]:
def plot_results(logs):
    
    fig, ax = plt.subplots()
    
    steps = [x for x in range(len(logs['ce_loss']))]
    
    ax.plot(steps, logs['ce_loss'], label = 'train loss')
    ax.plot(steps, logs['val_ce_loss'], label = 'val_loss')
    ax.plot(steps, logs['acc'], label = 'train accuracy')
    ax.plot(steps, logs['val_acc'], label = 'val accuracy')
    ax.set(xlabel='Iterations', title='ProtoNet meta-training')
    
    
    ax.grid()
    ax.legend()
    plt.show()
    
                                                                                                                                                        

In [21]:
DATA_PATH = './omniglot_resized'
NO_CLASSES = 3
NO_SHOTS = 5
NO_QUERIES = 5
NO_CLASSES_META_TEST = 3
NO_SHOTS_META_TEST = 5
NO_QUERIES_META_TEST = 5
NO_EPOCHS = 20
NO_EPISODES = 100
IM_WIDTH, IM_HEIGHT, CHANNELS = 28, 28, 1
NUM_FILTERS = 16
LATENT_DIM = 16
NUM_CONV_LAYERS = 3
NO_META_TEST_EPISODES = 1000


In [22]:
batch_size = 32
optimizer = tf.keras.optimizers.Adam()
model = ProtoNet([NUM_FILTERS] * NUM_CONV_LAYERS, LATENT_DIM)
data_generator = DataGenerator(NO_CLASSES, 
                               NO_SHOTS + NO_QUERIES,
                               NO_CLASSES_META_TEST,
                               NO_SHOTS_META_TEST + NO_QUERIES_META_TEST,
                               config = {'data_folder': DATA_PATH})
logs = {'ce_loss': [],
        'acc': [],
        'val_ce_loss': [],
        'val_acc' : []}

for epoch in range(NO_EPOCHS):
    for episode in range(NO_EPISODES):
        images, labels = data_generator.sample_batch(batch_type = "meta_train",
                                                   batch_size =  batch_size, 
                                                   shuffle=True)

        support = images[:, :, :NO_SHOTS, :]
        query = images[:, :, NO_SHOTS:, :]
        labels_ph = labels[:, :, NO_SHOTS:, :]
        
        num_support = support.shape[2]
        num_queries = query.shape[2]
        
        support = support.reshape(-1, 28, 28, 1)
        query = query.reshape(-1, 28, 28, 1)
        
        #def train_step(inputs, labels, model, loss_object, optimizer):
        with tf.GradientTape() as tape:
            x_latent = model(support)
            q_latent = model(query)
            ce_loss, acc = ProtoLoss(x_latent, q_latent, labels_ph, NO_CLASSES, num_support, num_queries)
        gradients = tape.gradient(ce_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        if episode % 50 == 0:
            images, labels = data_generator.sample_batch(batch_type = "meta_val",
                                                   batch_size =  batch_size, 
                                                   shuffle=True)
            support = images[:, :, :NO_SHOTS, :]
            query = images[:, :, NO_SHOTS:, :]
            labels_ph = labels[:, :, NO_SHOTS:, :]
        
            num_support = support.shape[2]
            num_queries = query.shape[2]
        
            support = support.reshape(-1, 28, 28, 1)
            query = query.reshape(-1, 28, 28, 1)
        
            x_latent = model(support)
            q_latent = model(query)
            val_ce_loss, val_acc = ProtoLoss(x_latent, q_latent, labels_ph, NO_CLASSES, num_support, num_queries)
            
            logs['ce_loss'].append(ce_loss)
            logs['val_ce_loss'].append(val_ce_loss)
            logs['acc'].append(acc)
            logs['val_acc'].append(val_acc)
            
            print('[Epoch {}/{}, Episode {}/{}] => meta-training loss: {:.5f}, meta-training acc: {:.5f}, meta-val loss: {:.5f}, meta-val acc: {:.5f}'.format(
                                                                                                                                                            epoch,
                                                                                                                                                            NO_EPOCHS,
                                                                                                                                                            episode,
                                                                                                                                                            NO_EPISODES,
                                                                                                                                                            ce_loss,
                                                                                                                                                            acc,
                                                                                                                                                            val_ce_loss,
                                                                                                                                                            val_acc))
plot_results(logs)
model.save_weights('./checkpoints/my_checkpoint')




To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

[Epoch 0/20, Episode 0/100] => meta-training loss: 1.09861, meta-training acc: 0.35417, meta-val loss: 1.09862, meta-val acc: 0.35417


KeyboardInterrupt: 

In [None]:
print('Testing...')
meta_test_accuracies = []
for episode in range(NO_META_TEST_EPISODES):
    #############################
    #### YOUR CODE GOES HERE ####

    # sample a batch of test data and partition into
    # support and query sets
    images, labels = data_generator.sample_batch(batch_type = "meta_test",
                                                 batch_size =  batch_size, 
                                                 shuffle=False)
    print(images.shape)
    print(labels.shape)
    support = images[:, :, :NO_SHOTS_META_TEST, :]
    query = images[:, :, NO_SHOTS_META_TEST:, :]
    labels_ph = labels[:, :, NO_SHOTS_META_TEST:, :]
    
    print(support.shape)
    print(query.shape)

    num_support = support.shape[2]
    num_queries = query.shape[2]

    support = support.reshape(-1, 28, 28, 1)
    query = support.reshape(-1, 28, 28, 1)
    
    print(support.shape)
    print(query.shape)
    print(labels_ph.shape)
    x_latent = model(support)
    q_latent = model(query)
    test_ce_loss, test_acc = ProtoLoss(x_latent, q_latent, labels_ph, NO_CLASSES_META_TEST, num_support, num_queries)
    #############################
    meta_test_accuracies.append(ac)
    if episode % 50 == 0:
        print('[Meta-test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(episode,
                                                                              NO_META_TEST_EPISODES, 
                                                                              test_loss,
                                                                              test_acc))
avg_acc = np.mean(meta_test_accuracies)
stds = np.std(meta_test_accuracies)
print('Average Meta-Test Accuracy: {:.5f}, Meta-Test Accuracy Std: {:.5f}'.format(avg_acc, stds))