In [33]:
import numpy as np
import tensorflow as tf
from scipy.spatial.distance import cdist

In [55]:
# Episode configuration. NOTE: in prototypical-network terminology "way"
# usually means the number of classes per episode; in this notebook `way` is
# the number of *query* points per class and `classes_per_ep` plays the role
# of the way.
way  = 5  # number of query points per class
shot = 5  # number of support points per class
embedding_size = 64
classes_per_ep = 60
random_seed = 0  # fixed so the random episode (and every result below) is reproducible

np.random.seed(random_seed)

# generate dummy data using numpy: labels are grouped by class
# ([0]*shot + [1]*shot + ...) and embeddings are uniform noise in [0, 1)
classes = np.arange(classes_per_ep)
support_labels    = np.repeat(classes, shot)
support_embedding = np.random.rand(classes_per_ep*shot, embedding_size)
query_labels      = np.repeat(classes, way)
query_embedding   = np.random.rand(classes_per_ep*way, embedding_size)
# single-argument print(...) behaves identically under Python 2 and 3
print('query_embedding shape: {}'.format(query_embedding.shape))
print('support embedding shape: {}'.format(support_embedding.shape))

# randomly permute; labels and embeddings share one permutation so that row i
# of each embedding matrix still belongs to element i of its label vector
support_perm = np.random.permutation(len(support_labels))
support_labels = support_labels[support_perm]
support_embedding = support_embedding[support_perm]

query_perm = np.random.permutation(len(query_labels))
query_labels = query_labels[query_perm]
query_embedding = query_embedding[query_perm]

query_embedding shape: (300, 64)
support embedding shape: (300, 64)


In [56]:
# use numpy to compute distance between prototypes and query_embeddings

# compute prototypes: the prototype of class k is the mean of its support
# embeddings; row k of `prototypes` holds the prototype of label k
prototypes = np.zeros((classes_per_ep, embedding_size))
for k in classes:
    class_embedding = support_embedding[support_labels == k]
    class_prototype  = np.mean(class_embedding, axis=0)
    prototypes[k] = class_prototype

# squared Euclidean distance between every query point and every prototype;
# 'sqeuclidean' avoids the sqrt-then-square round trip of cdist(...)**2
distances = cdist(query_embedding, prototypes, 'sqeuclidean')
print('distances shape: {}'.format(distances.shape))
# use distances to predict the class of each query embedding: softmax over the
# negative distances. Shifting each row by its smallest distance keeps exp()
# from underflowing to 0 for every class (which would give 0/0 = NaN); the
# shift cancels out in the normalisation, so the probabilities are unchanged.
shifted_logits = -(distances - distances.min(axis=1, keepdims=True))
exp_logits     = np.exp(shifted_logits)
probs     = exp_logits / exp_logits.sum(axis=1, keepdims=True)
print('probs shape:  {}'.format(probs.shape))
preds     = np.argmax(probs, axis=1)
print('preds shape: {}'.format(preds.shape))

distances shape: (300, 60)
probs shape:  (300, 60)
preds shape: (300,)


In [57]:
# use tensorflow (TF1 graph mode) to reproduce the numpy distances above
support_embedding_placeholder = tf.placeholder(tf.float32, shape=(classes_per_ep*shot, embedding_size))
support_labels_placeholder = tf.placeholder(tf.int64, shape=(classes_per_ep*shot,))
# the query set has `way` points per class (not `shot`); the previous
# `classes_per_ep*shot` only worked because way == shot in this notebook
query_embedding_placeholder = tf.placeholder(tf.float32, shape=(classes_per_ep*way, embedding_size))

# compute the prototype for each class: per-class sum of support embeddings
# divided by the per-class support count (a class with no support rows would
# give 0/0 here)
ones = tf.ones_like(support_embedding_placeholder)
per_class_embedding_sum = tf.unsorted_segment_sum(support_embedding_placeholder, support_labels_placeholder, classes_per_ep)
class_counts = tf.unsorted_segment_sum(ones, support_labels_placeholder, classes_per_ep)
tf_prototypes = per_class_embedding_sum / class_counts

# squared euclidean distance via ||q||^2 + ||p||^2 - 2 q.p, broadcasting a
# (n_query, 1) column of query norms against a row of prototype norms
query_square_sum = tf.reshape(tf.reduce_sum(tf.square(query_embedding_placeholder), 1), shape=[-1, 1])
proto_square_sum = tf.reduce_sum(tf.square(tf_prototypes), 1)
tf_distances = tf.add(query_square_sum, proto_square_sum, name='square_sum') - 2*tf.matmul(query_embedding_placeholder, tf_prototypes, transpose_b=True)
# float32 cancellation in the expansion can produce tiny negative values; a
# squared distance is never negative, so clamp at zero
tf_distances = tf.maximum(tf_distances, 0.0)

# use distances to make prediction; softmax is monotonic, so the argmax of the
# logits picks the same class without an extra op (and without float32 ties)
logits = -1.0 * tf_distances
tf_predictions = tf.argmax(logits, 1)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    feed_dict = {support_embedding_placeholder : support_embedding,
                 support_labels_placeholder    : support_labels,
                 query_embedding_placeholder   : query_embedding}
    tf_dist = sess.run(tf_distances, feed_dict=feed_dict)

In [58]:
# Element-wise agreement between the TensorFlow (float32) and numpy (float64)
# distances. numpy elides most of a 300x60 matrix when printing, so also
# assert that *every* entry matches rather than trusting the visible corners.
close_mask = np.isclose(tf_dist, distances)
print(close_mask)
assert close_mask.all(), '{} of {} distances differ'.format(
    np.count_nonzero(~close_mask), close_mask.size)

[[ True  True  True ...,  True  True  True]
 [ True  True  True ...,  True  True  True]
 [ True  True  True ...,  True  True  True]
 ..., 
 [ True  True  True ...,  True  True  True]
 [ True  True  True ...,  True  True  True]
 [ True  True  True ...,  True  True  True]]
