In [1]:
%matplotlib inline
from __future__ import print_function
from PIL import Image
import numpy as np
import tensorflow as tf
import os
import glob
import matplotlib.pyplot as plt

In [2]:
def conv_block(inputs, out_channels, kernel_size=3, strides=2, padding='SAME', name='conv'):
    with tf.variable_scope(name):
        conv = tf.layers.conv2d(inputs, out_channels, kernel_size=kernel_size, strides=strides, padding=padding)
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
        conv = tf.nn.relu(conv)
        #conv = tf.contrib.layers.max_pool2d(conv, 2)
        return conv

In [3]:
def encoder(x, h_dim, z_dim, reuse=False):
    with tf.variable_scope('encoder', reuse=reuse):
        net = conv_block(x, h_dim, kernel_size=3, strides=2, padding='SAME', name='conv_1') # 42x42
        net = conv_block(net, h_dim, kernel_size=3, strides=2, padding='SAME', name='conv_2') # 21x21
        net = conv_block(net, h_dim,  kernel_size=3, strides=2, padding='VALID', name='conv_3') # 10x10
        net = conv_block(net, z_dim, name='conv_4') # 5x5
        net = tf.contrib.layers.flatten(net)
        return net

In [4]:
def deconv_block(inputs, out_channels, size=3, stride=2, padding='SAME', name='deconv'):
    with tf.variable_scope(name):
        conv = tf.layers.conv2d_transpose(inputs, out_channels, kernel_size=size, strides=stride, padding=padding)
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
        conv = tf.nn.relu(conv)
        return conv

In [5]:
def decoder(x, h_dim, z_dim, reuse=False):
    with tf.variable_scope('decoder', reuse=reuse):
        net = tf.layers.dense(x, 5 * 5 * 64)
        net = tf.reshape(net, [-1, 5, 5, 64])
        net = deconv_block(net, h_dim, size=4, stride=2, padding='SAME', name='deconv_1') # 10x10
        net = deconv_block(net, h_dim, size=3, stride=2, padding='VALID', name='deconv_2') # 21x21
        net = deconv_block(net, h_dim, size=4, stride=2, padding='SAME', name='deconv_3') # 42x42
        net = deconv_block(net, h_dim, size=4, stride=2, padding='SAME', name='deconv_4') # 84x84
        net = tf.layers.conv2d(net, 3, 3, padding='SAME')
        print(net.shape)
        net = tf.nn.tanh(net)
        return net

In [6]:
def euclidean_distance(a, b):
    # a.shape = N x D
    # b.shape = M x D
    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)

In [7]:
n_epochs = 100
n_episodes = 100
n_way = 20
n_shot = 5
n_query = 15
n_examples = 350
im_width, im_height, channels = 84, 84, 3
h_dim = 64
z_dim = 64

In [8]:
# Load Train Dataset
train_dataset = np.load('mini-imagenet-train.npy')
n_classes = train_dataset.shape[0]
print(train_dataset.shape)

(64, 350, 84, 84, 3)


In [9]:
x = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
q = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
x_shape = tf.shape(x)
q_shape = tf.shape(q)
num_classes, num_support = x_shape[0], x_shape[1]
num_queries = q_shape[1]
y = tf.placeholder(tf.int64, [None, None])
y_one_hot = tf.one_hot(y, depth=num_classes)
emb_in = encoder(tf.reshape(x, [num_classes * num_support, im_height, im_width, channels]), h_dim, z_dim)
emb_dim = tf.shape(emb_in)[-1]
emb_x = tf.reduce_mean(tf.reshape(emb_in, [num_classes, num_support, emb_dim]), axis=1)
emb_q = encoder(tf.reshape(q, [num_classes * num_queries, im_height, im_width, channels]), h_dim, z_dim, reuse=True)

q_hat = decoder(emb_q, h_dim, z_dim)
q_label = tf.reshape(q, [num_classes * num_queries, im_height, im_width, channels])
recon_loss = tf.reduce_mean(tf.square(q_label-q_hat))

dists = euclidean_distance(emb_q, emb_x)
log_p_y = tf.reshape(tf.nn.log_softmax(-dists), [num_classes, num_queries, -1])
ce_loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, log_p_y), axis=-1), [-1]))
acc = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(log_p_y, axis=-1), y)))

(?, 84, 84, 3)


In [10]:
train_op = tf.train.AdamOptimizer().minimize(ce_loss+recon_loss)

In [11]:
sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [12]:
for ep in range(n_epochs):
    for epi in range(n_episodes):
        epi_classes = np.random.permutation(n_classes)[:n_way]
        support = np.zeros([n_way, n_shot, im_height, im_width, channels], dtype=np.float32)
        query = np.zeros([n_way, n_query, im_height, im_width, channels], dtype=np.float32)
        for i, epi_cls in enumerate(epi_classes):
            selected = np.random.permutation(n_examples)[:n_shot + n_query]
            support[i] = train_dataset[epi_cls, selected[:n_shot]]
            query[i] = train_dataset[epi_cls, selected[n_shot:]]
        # support = np.expand_dims(support, axis=-1)
        # query = np.expand_dims(query, axis=-1)
        labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.uint8)
        _, ls, ac = sess.run([train_op, ce_loss, acc], feed_dict={x: support, q: query, y:labels})
        if (epi+1) % 50 == 0:
            print('[epoch {}/{}, episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(ep+1, n_epochs, epi+1, n_episodes, ls, ac))

[epoch 1/100, episode 50/100] => loss: 2.78017, acc: 0.21000
[epoch 1/100, episode 100/100] => loss: 2.67847, acc: 0.18667
[epoch 2/100, episode 50/100] => loss: 2.75918, acc: 0.15333
[epoch 2/100, episode 100/100] => loss: 2.56399, acc: 0.23000
[epoch 3/100, episode 50/100] => loss: 2.72149, acc: 0.26000
[epoch 3/100, episode 100/100] => loss: 2.63347, acc: 0.19667
[epoch 4/100, episode 50/100] => loss: 2.50564, acc: 0.18333
[epoch 4/100, episode 100/100] => loss: 2.46438, acc: 0.20000
[epoch 5/100, episode 50/100] => loss: 2.58558, acc: 0.22333
[epoch 5/100, episode 100/100] => loss: 2.52959, acc: 0.18667
[epoch 6/100, episode 50/100] => loss: 2.43113, acc: 0.22000
[epoch 6/100, episode 100/100] => loss: 2.40348, acc: 0.20000
[epoch 7/100, episode 50/100] => loss: 2.56218, acc: 0.26000
[epoch 7/100, episode 100/100] => loss: 2.40252, acc: 0.18000
[epoch 8/100, episode 50/100] => loss: 2.48335, acc: 0.23333
[epoch 8/100, episode 100/100] => loss: 2.48831, acc: 0.25000
[epoch 9/100, ep

[epoch 67/100, episode 50/100] => loss: 1.51980, acc: 0.42667
[epoch 67/100, episode 100/100] => loss: 1.64937, acc: 0.36333
[epoch 68/100, episode 50/100] => loss: 1.54091, acc: 0.42333
[epoch 68/100, episode 100/100] => loss: 1.59546, acc: 0.39333
[epoch 69/100, episode 50/100] => loss: 1.45297, acc: 0.42333
[epoch 69/100, episode 100/100] => loss: 1.47947, acc: 0.43667
[epoch 70/100, episode 50/100] => loss: 1.69378, acc: 0.36000
[epoch 70/100, episode 100/100] => loss: 1.51725, acc: 0.42000
[epoch 71/100, episode 50/100] => loss: 1.56655, acc: 0.41333
[epoch 71/100, episode 100/100] => loss: 1.63956, acc: 0.37667
[epoch 72/100, episode 50/100] => loss: 1.38129, acc: 0.48667
[epoch 72/100, episode 100/100] => loss: 1.74691, acc: 0.34000
[epoch 73/100, episode 50/100] => loss: 1.58679, acc: 0.37667
[epoch 73/100, episode 100/100] => loss: 1.68783, acc: 0.37333
[epoch 74/100, episode 50/100] => loss: 1.62780, acc: 0.40333
[epoch 74/100, episode 100/100] => loss: 1.42847, acc: 0.46000


In [13]:
# Load Test Dataset
test_dataset = np.load('mini-imagenet-test.npy')
n_test_classes = test_dataset.shape[0]
print(test_dataset.shape)

(20, 350, 84, 84, 3)


In [14]:
n_test_episodes = 600
n_test_way = 5
n_test_shot = 5
n_test_query = 15

In [15]:
print('Testing...')
avg_acc = 0.
for epi in range(n_test_episodes):
    epi_classes = np.random.permutation(n_test_classes)[:n_test_way]
    support = np.zeros([n_test_way, n_test_shot, im_height, im_width, channels], dtype=np.float32)
    query = np.zeros([n_test_way, n_test_query, im_height, im_width, channels], dtype=np.float32)
    for i, epi_cls in enumerate(epi_classes):
        selected = np.random.permutation(n_examples)[:n_test_shot + n_test_query]
        support[i] = test_dataset[epi_cls, selected[:n_test_shot]]
        query[i] = test_dataset[epi_cls, selected[n_test_shot:]]
    # support = np.expand_dims(support, axis=-1)
    # query = np.expand_dims(query, axis=-1)
    labels = np.tile(np.arange(n_test_way)[:, np.newaxis], (1, n_test_query)).astype(np.uint8)
    ls, ac = sess.run([ce_loss, acc], feed_dict={x: support, q: query, y:labels})
    avg_acc += ac
    if (epi+1) % 50 == 0:
        print('[test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(epi+1, n_test_episodes, ls, ac))
avg_acc /= n_test_episodes
print('Average Test Accuracy: {:.5f}'.format(avg_acc))

Testing...
[test episode 50/600] => loss: 0.81334, acc: 0.61333
[test episode 100/600] => loss: 0.71534, acc: 0.78667
[test episode 150/600] => loss: 0.77452, acc: 0.68000
[test episode 200/600] => loss: 0.65571, acc: 0.74667
[test episode 250/600] => loss: 0.59001, acc: 0.77333
[test episode 300/600] => loss: 0.67461, acc: 0.62667
[test episode 350/600] => loss: 0.66137, acc: 0.70667
[test episode 400/600] => loss: 0.72775, acc: 0.64000
[test episode 450/600] => loss: 1.31093, acc: 0.49333
[test episode 500/600] => loss: 1.14540, acc: 0.52000
[test episode 550/600] => loss: 1.00560, acc: 0.54667
[test episode 600/600] => loss: 0.78214, acc: 0.62667
Average Test Accuracy: 0.61720
