In [1]:
%matplotlib inline
from __future__ import print_function
from PIL import Image
import numpy as np
import tensorflow as tf
import os
import glob
import matplotlib.pyplot as plt

In [2]:
def conv_block(inputs, out_channels, name='conv', is_training=True):
    """Conv(3x3, SAME) -> batch norm -> ReLU -> 2x2 max-pool, inside scope `name`.

    Args:
        inputs: 4-D image batch; the encoder feeds [N, H, W, C] (channels last).
        out_channels: number of 3x3 convolution filters.
        name: variable scope holding this block's variables.
        is_training: forwarded to batch_norm. True (the default, identical to the
            previous behaviour) normalises with the statistics of the fed batch and
            updates the moving averages; False normalises with the stored moving
            averages. May be a Python bool or a boolean tensor.

    Returns:
        Tensor with `out_channels` channels whose height and width are halved
        (rounded down) by the stride-2, VALID-padded max-pool.
    """
    with tf.variable_scope(name):
        conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
        # updates_collections=None: moving mean/variance updates run in place as part
        # of this op, so the train op needs no UPDATE_OPS control dependency.
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99,
                                            scale=True, center=True,
                                            is_training=is_training)
        conv = tf.nn.relu(conv)
        conv = tf.contrib.layers.max_pool2d(conv, 2)
        return conv

In [3]:
def encoder(x, h_dim, z_dim, reuse=False):
    """Embed an image batch with four conv blocks and flatten the result.

    Args:
        x: [N, H, W, C] image batch (the graph cell reshapes episodes to this).
        h_dim: filter count of the first three conv blocks.
        z_dim: filter count of the last conv block.
        reuse: pass True to share the variables of a previous call, so support
            and query images go through the same weights.

    Returns:
        [N, features] tensor. Each block halves height and width, so for 48x48
        inputs this is 3 * 3 * z_dim features.
    """
    widths = (h_dim, h_dim, h_dim, z_dim)
    with tf.variable_scope('encoder', reuse=reuse):
        features = x
        # Scopes are named conv_1 .. conv_4, one per entry in `widths`.
        for layer_idx, width in enumerate(widths, start=1):
            features = conv_block(features, width, name='conv_{}'.format(layer_idx))
        return tf.contrib.layers.flatten(features)

In [4]:
def euclidean_distance(a, b):
    """Pairwise *mean* squared distance between the rows of `a` and `b`.

    Args:
        a: N x D tensor (the graph passes query embeddings).
        b: M x D tensor (the graph passes class prototypes).

    Returns:
        N x M tensor whose (i, j) entry is mean_k (a[i, k] - b[j, k]) ** 2, i.e.
        the squared Euclidean distance divided by D, not the distance itself.
    """
    # [N, 1, D] - [1, M, D] broadcasts to every pairwise difference, [N, M, D].
    pairwise_diff = tf.expand_dims(a, axis=1) - tf.expand_dims(b, axis=0)
    return tf.reduce_mean(tf.square(pairwise_diff), axis=2)

In [5]:
# Training schedule: n_epochs * n_episodes optimisation steps in total.
n_epochs = 32
n_episodes = 10
# Episode shape: n_way classes, each with n_shot support and n_query query images.
n_way = 2
n_shot = 3
n_query = 3
n_examples = n_shot+n_query  # images stored per training class (see the loading cell)
# Images are resized to im_width x im_height and fed with a single channel.
im_width, im_height, channels = 48, 48, 1
# Filter counts of the first three conv blocks (h_dim) and of the last one (z_dim).
h_dim = 64
z_dim = 64
# Seed episode sampling (NumPy) and weight initialisation (TensorFlow) so runs are
# reproducible. tf.set_random_seed sets the seed of the current default graph, so it
# must run before the model graph is built in the cells below.
random_seed = 0
np.random.seed(random_seed)
tf.set_random_seed(random_seed)

In [6]:
# Load the meta-training split. Each non-blank line of train.txt names one class
# as "<genre>/<character>/<prefix><angle>": an image directory plus a rotation.
root_dir = './data/BP_61'
train_path = os.path.join(root_dir, 'train.txt')
with open(train_path, 'r') as train:
    train_classes = [line.rstrip() for line in train if line.strip()]
n_classes = len(train_classes)
# Episodes sample uniformly from all n_examples slots of a class, so every slot
# must hold a real image (an unfilled slot would stay all-zero).
train_dataset = np.zeros([n_classes, n_examples, im_height, im_width], dtype=np.float32)
for i, tc in enumerate(train_classes):
    genre, character, rotation = tc.split('/')
    rotation = float(rotation[3:])  # drop the 3-character prefix (presumably "rot")
    im_dir = os.path.join(root_dir, genre, character)
    im_files = sorted(glob.glob(os.path.join(im_dir, '*.png')))
    if len(im_files) != n_examples:
        raise ValueError('expected {} images in {}, found {}'.format(
            n_examples, im_dir, len(im_files)))
    for j, im_file in enumerate(im_files):
        # `with` closes the file handle; np.asarray replaces np.array(..., copy=False),
        # which raises on NumPy >= 2 when the float32 conversion needs a copy.
        with Image.open(im_file) as img:
            im = np.asarray(img.rotate(rotation).resize((im_width, im_height)),
                            dtype=np.float32) / 255.0
        train_dataset[i, j] = im
print(train_dataset.shape)

['./data/BP_61/train/left/1.png', './data/BP_61/train/left/11.png', './data/BP_61/train/left/3.png', './data/BP_61/train/left/5.png', './data/BP_61/train/left/7.png', './data/BP_61/train/left/9.png']
./data/BP_61/train/left/1.png
./data/BP_61/train/left/11.png
./data/BP_61/train/left/3.png
./data/BP_61/train/left/5.png
./data/BP_61/train/left/7.png
./data/BP_61/train/left/9.png
['./data/BP_61/train/right/0.png', './data/BP_61/train/right/10.png', './data/BP_61/train/right/2.png', './data/BP_61/train/right/4.png', './data/BP_61/train/right/6.png', './data/BP_61/train/right/8.png']
./data/BP_61/train/right/0.png
./data/BP_61/train/right/10.png
./data/BP_61/train/right/2.png
./data/BP_61/train/right/4.png
./data/BP_61/train/right/6.png
./data/BP_61/train/right/8.png
(2, 6, 48, 48)


In [7]:
# Build the prototypical-network graph for one episode.
# NOTE(review): encoder() is first called with reuse=False, so re-running this cell
# in the same kernel/graph will likely fail because the encoder variables already
# exist -- restart the kernel (or build in a fresh graph) before re-running.
# Support images, fed as [classes, shots, H, W, C] by the training/test loops.
x = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
# Query images, fed as [classes, queries, H, W, C].
q = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
x_shape = tf.shape(x)
q_shape = tf.shape(q)
num_classes, num_support = x_shape[0], x_shape[1]
num_queries = q_shape[1]
# Episode-local class index (0..num_classes-1) of every query, [classes, queries].
y = tf.placeholder(tf.int64, [None, None])
y_one_hot = tf.one_hot(y, depth=num_classes)
# Embed all support images as one batch, then average over the shot axis to get
# one prototype per class: [num_classes, emb_dim].
emb_x = encoder(tf.reshape(x, [num_classes * num_support, im_height, im_width, channels]), h_dim, z_dim)
emb_dim = tf.shape(emb_x)[-1]
emb_x = tf.reduce_mean(tf.reshape(emb_x, [num_classes, num_support, emb_dim]), axis=1)
# Embed queries with the same encoder weights (reuse=True). The reshape uses
# num_classes from x, so q must contain the same number of classes as x.
emb_q = encoder(tf.reshape(q, [num_classes * num_queries, im_height, im_width, channels]), h_dim, z_dim, reuse=True)
# [num_classes * num_queries, num_classes]: each query against each prototype.
dists = euclidean_distance(emb_q, emb_x)
# Log-softmax over negative distances, reshaped back to [classes, queries, classes]
# (queries were flattened class-major above, so this reshape inverts it).
log_p_y = tf.reshape(tf.nn.log_softmax(-dists), [num_classes, num_queries, -1])
# Mean negative log-probability of each query's true class.
ce_loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, log_p_y), axis=-1), [-1]))
# Fraction of queries whose most likely class equals the label.
acc = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(log_p_y, axis=-1), y)))

In [8]:
# Adam with its default learning rate (0.001) written out explicitly. No UPDATE_OPS
# control dependency is needed: batch_norm is built with updates_collections=None,
# so its moving averages are updated in place.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(ce_loss)

In [9]:
# Create the initializer for every global variable built so far, then open an
# InteractiveSession (installed as the default session) and run it.
init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init_op)

In [18]:
# Episodic meta-training: every episode samples n_way classes, then n_shot support
# and n_query query images per class, and takes one optimiser step.
assert n_way <= n_classes, 'n_way ({}) exceeds the number of training classes ({})'.format(n_way, n_classes)
assert n_shot + n_query <= n_examples, 'not enough stored images per class for one episode'
# Query labels are the episode-local class index, identical for every episode, so
# build them once. int64 matches the `y` placeholder (uint8 would overflow past 255 ways).
labels = np.tile(np.arange(n_way, dtype=np.int64)[:, np.newaxis], (1, n_query))
for ep in range(n_epochs):
    for epi in range(n_episodes):
        epi_classes = np.random.permutation(n_classes)[:n_way]
        support = np.zeros([n_way, n_shot, im_height, im_width], dtype=np.float32)
        query = np.zeros([n_way, n_query, im_height, im_width], dtype=np.float32)
        for i, epi_cls in enumerate(epi_classes):
            # Disjoint support/query images drawn from the same class.
            selected = np.random.permutation(n_examples)[:n_shot + n_query]
            support[i] = train_dataset[epi_cls, selected[:n_shot]]
            query[i] = train_dataset[epi_cls, selected[n_shot:]]
        # Append the channel axis expected by the x / q placeholders.
        support = np.expand_dims(support, axis=-1)
        query = np.expand_dims(query, axis=-1)
        _, ls, ac = sess.run([train_op, ce_loss, acc], feed_dict={x: support, q: query, y: labels})
        print('[epoch {}/{}, episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(ep+1, n_epochs, epi+1, n_episodes, ls, ac))

[epoch 1/32, episode 1/100] => loss: 0.01208, acc: 1.00000
[epoch 1/32, episode 2/100] => loss: 0.01175, acc: 1.00000
[epoch 1/32, episode 3/100] => loss: 0.01091, acc: 1.00000
[epoch 1/32, episode 4/100] => loss: 0.01041, acc: 1.00000
[epoch 1/32, episode 5/100] => loss: 0.00981, acc: 1.00000
[epoch 1/32, episode 6/100] => loss: 0.00946, acc: 1.00000
[epoch 1/32, episode 7/100] => loss: 0.00906, acc: 1.00000
[epoch 1/32, episode 8/100] => loss: 0.00868, acc: 1.00000
[epoch 1/32, episode 9/100] => loss: 0.00823, acc: 1.00000
[epoch 1/32, episode 10/100] => loss: 0.00803, acc: 1.00000
[epoch 1/32, episode 11/100] => loss: 0.00761, acc: 1.00000
[epoch 1/32, episode 12/100] => loss: 0.00730, acc: 1.00000
[epoch 1/32, episode 13/100] => loss: 0.00702, acc: 1.00000
[epoch 1/32, episode 14/100] => loss: 0.00690, acc: 1.00000
[epoch 1/32, episode 15/100] => loss: 0.00661, acc: 1.00000
[epoch 1/32, episode 16/100] => loss: 0.00635, acc: 1.00000
[epoch 1/32, episode 17/100] => loss: 0.00614, ac

[epoch 2/32, episode 39/100] => loss: 0.00119, acc: 1.00000
[epoch 2/32, episode 40/100] => loss: 0.00121, acc: 1.00000
[epoch 2/32, episode 41/100] => loss: 0.00118, acc: 1.00000
[epoch 2/32, episode 42/100] => loss: 0.00116, acc: 1.00000
[epoch 2/32, episode 43/100] => loss: 0.00117, acc: 1.00000
[epoch 2/32, episode 44/100] => loss: 0.00115, acc: 1.00000
[epoch 2/32, episode 45/100] => loss: 0.00114, acc: 1.00000
[epoch 2/32, episode 46/100] => loss: 0.00113, acc: 1.00000
[epoch 2/32, episode 47/100] => loss: 0.00113, acc: 1.00000
[epoch 2/32, episode 48/100] => loss: 0.00111, acc: 1.00000
[epoch 2/32, episode 49/100] => loss: 0.00110, acc: 1.00000
[epoch 2/32, episode 50/100] => loss: 0.00110, acc: 1.00000
[epoch 2/32, episode 51/100] => loss: 0.00109, acc: 1.00000
[epoch 2/32, episode 52/100] => loss: 0.00108, acc: 1.00000
[epoch 2/32, episode 53/100] => loss: 0.00110, acc: 1.00000
[epoch 2/32, episode 54/100] => loss: 0.00108, acc: 1.00000
[epoch 2/32, episode 55/100] => loss: 0.

[epoch 3/32, episode 76/100] => loss: 0.00055, acc: 1.00000
[epoch 3/32, episode 77/100] => loss: 0.00053, acc: 1.00000
[epoch 3/32, episode 78/100] => loss: 0.00053, acc: 1.00000
[epoch 3/32, episode 79/100] => loss: 0.00053, acc: 1.00000
[epoch 3/32, episode 80/100] => loss: 0.00052, acc: 1.00000
[epoch 3/32, episode 81/100] => loss: 0.00054, acc: 1.00000
[epoch 3/32, episode 82/100] => loss: 0.00052, acc: 1.00000
[epoch 3/32, episode 83/100] => loss: 0.00052, acc: 1.00000
[epoch 3/32, episode 84/100] => loss: 0.00054, acc: 1.00000
[epoch 3/32, episode 85/100] => loss: 0.00051, acc: 1.00000
[epoch 3/32, episode 86/100] => loss: 0.00052, acc: 1.00000
[epoch 3/32, episode 87/100] => loss: 0.00052, acc: 1.00000
[epoch 3/32, episode 88/100] => loss: 0.00051, acc: 1.00000
[epoch 3/32, episode 89/100] => loss: 0.00050, acc: 1.00000
[epoch 3/32, episode 90/100] => loss: 0.00050, acc: 1.00000
[epoch 3/32, episode 91/100] => loss: 0.00050, acc: 1.00000
[epoch 3/32, episode 92/100] => loss: 0.

[epoch 5/32, episode 13/100] => loss: 0.00031, acc: 1.00000
[epoch 5/32, episode 14/100] => loss: 0.00031, acc: 1.00000
[epoch 5/32, episode 15/100] => loss: 0.00032, acc: 1.00000
[epoch 5/32, episode 16/100] => loss: 0.00031, acc: 1.00000
[epoch 5/32, episode 17/100] => loss: 0.00031, acc: 1.00000
[epoch 5/32, episode 18/100] => loss: 0.00031, acc: 1.00000
[epoch 5/32, episode 19/100] => loss: 0.00032, acc: 1.00000
[epoch 5/32, episode 20/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 21/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 22/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 23/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 24/100] => loss: 0.00031, acc: 1.00000
[epoch 5/32, episode 25/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 26/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 27/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 28/100] => loss: 0.00030, acc: 1.00000
[epoch 5/32, episode 29/100] => loss: 0.

[epoch 6/32, episode 51/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 52/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 53/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 54/100] => loss: 0.00022, acc: 1.00000
[epoch 6/32, episode 55/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 56/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 57/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 58/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 59/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 60/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 61/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 62/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 63/100] => loss: 0.00021, acc: 1.00000
[epoch 6/32, episode 64/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 65/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 66/100] => loss: 0.00020, acc: 1.00000
[epoch 6/32, episode 67/100] => loss: 0.

[epoch 7/32, episode 89/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 90/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 91/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 92/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 93/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 94/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 95/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 96/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 97/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 98/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 99/100] => loss: 0.00015, acc: 1.00000
[epoch 7/32, episode 100/100] => loss: 0.00015, acc: 1.00000
[epoch 8/32, episode 1/100] => loss: 0.00014, acc: 1.00000
[epoch 8/32, episode 2/100] => loss: 0.00015, acc: 1.00000
[epoch 8/32, episode 3/100] => loss: 0.00014, acc: 1.00000
[epoch 8/32, episode 4/100] => loss: 0.00014, acc: 1.00000
[epoch 8/32, episode 5/100] => loss: 0.0001

[epoch 9/32, episode 27/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 28/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 29/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 30/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 31/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 32/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 33/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 34/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 35/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 36/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 37/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 38/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 39/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 40/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 41/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 42/100] => loss: 0.00011, acc: 1.00000
[epoch 9/32, episode 43/100] => loss: 0.

[epoch 10/32, episode 64/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 65/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 66/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 67/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 68/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 69/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 70/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 71/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 72/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 73/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 74/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 75/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 76/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 77/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 78/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 79/100] => loss: 0.00009, acc: 1.00000
[epoch 10/32, episode 80

[epoch 11/32, episode 99/100] => loss: 0.00007, acc: 1.00000
[epoch 11/32, episode 100/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 1/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 2/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 3/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 4/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 5/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 6/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 7/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 8/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 9/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 10/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 11/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 12/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 13/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 14/100] => loss: 0.00007, acc: 1.00000
[epoch 12/32, episode 15/100] =>

[epoch 13/32, episode 34/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 35/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 36/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 37/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 38/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 39/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 40/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 41/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 42/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 43/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 44/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 45/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 46/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 47/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 48/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 49/100] => loss: 0.00006, acc: 1.00000
[epoch 13/32, episode 50

[epoch 14/32, episode 70/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 71/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 72/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 73/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 74/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 75/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 76/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 77/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 78/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 79/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 80/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 81/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 82/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 83/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 84/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 85/100] => loss: 0.00005, acc: 1.00000
[epoch 14/32, episode 86

[epoch 16/32, episode 6/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 7/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 8/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 9/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 10/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 11/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 12/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 13/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 14/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 15/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 16/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 17/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 18/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 19/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 20/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 21/100] => loss: 0.00004, acc: 1.00000
[epoch 16/32, episode 22/100

[epoch 17/32, episode 42/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 43/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 44/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 45/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 46/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 47/100] => loss: 0.00003, acc: 1.00000
[epoch 17/32, episode 48/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 49/100] => loss: 0.00003, acc: 1.00000
[epoch 17/32, episode 50/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 51/100] => loss: 0.00003, acc: 1.00000
[epoch 17/32, episode 52/100] => loss: 0.00003, acc: 1.00000
[epoch 17/32, episode 53/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 54/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 55/100] => loss: 0.00003, acc: 1.00000
[epoch 17/32, episode 56/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 57/100] => loss: 0.00004, acc: 1.00000
[epoch 17/32, episode 58

[epoch 18/32, episode 78/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 79/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 80/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 81/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 82/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 83/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 84/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 85/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 86/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 87/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 88/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 89/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 90/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 91/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 92/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 93/100] => loss: 0.00003, acc: 1.00000
[epoch 18/32, episode 94

[epoch 20/32, episode 14/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 15/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 16/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 17/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 18/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 19/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 20/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 21/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 22/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 23/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 24/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 25/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 26/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 27/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 28/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 29/100] => loss: 0.00003, acc: 1.00000
[epoch 20/32, episode 30

[epoch 21/32, episode 50/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 51/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 52/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 53/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 54/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 55/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 56/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 57/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 58/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 59/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 60/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 61/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 62/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 63/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 64/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 65/100] => loss: 0.00002, acc: 1.00000
[epoch 21/32, episode 66

[epoch 22/32, episode 86/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 87/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 88/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 89/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 90/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 91/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 92/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 93/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 94/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 95/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 96/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 97/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 98/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 99/100] => loss: 0.00002, acc: 1.00000
[epoch 22/32, episode 100/100] => loss: 0.00002, acc: 1.00000
[epoch 23/32, episode 1/100] => loss: 0.00002, acc: 1.00000
[epoch 23/32, episode 2/

[epoch 24/32, episode 21/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 22/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 23/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 24/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 25/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 26/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 27/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 28/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 29/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 30/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 31/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 32/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 33/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 34/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 35/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 36/100] => loss: 0.00002, acc: 1.00000
[epoch 24/32, episode 37

[epoch 25/32, episode 56/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 57/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 58/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 59/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 60/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 61/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 62/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 63/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 64/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 65/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 66/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 67/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 68/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 69/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 70/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 71/100] => loss: 0.00002, acc: 1.00000
[epoch 25/32, episode 72

[epoch 26/32, episode 91/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 92/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 93/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 94/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 95/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 96/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 97/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 98/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 99/100] => loss: 0.00001, acc: 1.00000
[epoch 26/32, episode 100/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 1/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 2/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 3/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 4/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 5/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 6/100] => loss: 0.00001, acc: 1.00000
[epoch 27/32, episode 7/100] 

[epoch 28/32, episode 26/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 27/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 28/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 29/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 30/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 31/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 32/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 33/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 34/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 35/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 36/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 37/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 38/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 39/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 40/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 41/100] => loss: 0.00001, acc: 1.00000
[epoch 28/32, episode 42

[epoch 29/32, episode 62/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 63/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 64/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 65/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 66/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 67/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 68/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 69/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 70/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 71/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 72/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 73/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 74/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 75/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 76/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 77/100] => loss: 0.00001, acc: 1.00000
[epoch 29/32, episode 78

[epoch 30/32, episode 98/100] => loss: 0.00001, acc: 1.00000
[epoch 30/32, episode 99/100] => loss: 0.00001, acc: 1.00000
[epoch 30/32, episode 100/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 1/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 2/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 3/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 4/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 5/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 6/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 7/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 8/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 9/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 10/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 11/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 12/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 13/100] => loss: 0.00001, acc: 1.00000
[epoch 31/32, episode 14/100] =>

[epoch 32/32, episode 33/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 34/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 35/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 36/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 37/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 38/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 39/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 40/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 41/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 42/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 43/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 44/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 45/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 46/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 47/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 48/100] => loss: 0.00001, acc: 1.00000
[epoch 32/32, episode 49

In [11]:
# Load Test Dataset
# Same "<genre>/<character>/<prefix><angle>" line format as train.txt.
root_dir = './data/BP_61'
test_path = os.path.join(root_dir, 'test.txt')
with open(test_path, 'r') as test:
    test_classes = [line.rstrip() for line in test if line.strip()]
n_test_classes = len(test_classes)
# Fixed per-class capacity. Slots past a class's real image count stay all-zero,
# so test_counts records how many images each class actually has.
n_test_capacity = 500
test_dataset = np.zeros([n_test_classes, n_test_capacity, im_height, im_width], dtype=np.float32)
test_counts = np.zeros(n_test_classes, dtype=np.int64)
for i, tc in enumerate(test_classes):
    genre, character, rotation = tc.split('/')
    rotation = float(rotation[3:])  # drop the 3-character prefix (presumably "rot")
    im_dir = os.path.join(root_dir, genre, character)
    im_files = sorted(glob.glob(os.path.join(im_dir, '*.png')))
    if len(im_files) > n_test_capacity:
        raise ValueError('{} has {} images, more than the capacity of {}'.format(
            im_dir, len(im_files), n_test_capacity))
    if len(im_files) < n_test_capacity:
        print('warning: {} has only {} images; remaining slots are all-zero'.format(
            im_dir, len(im_files)))
    for j, im_file in enumerate(im_files):
        # `with` closes the file handle; np.asarray replaces np.array(..., copy=False),
        # which raises on NumPy >= 2 when the float32 conversion needs a copy.
        with Image.open(im_file) as img:
            im = np.asarray(img.rotate(rotation).resize((im_width, im_height)),
                            dtype=np.float32) / 255.0
        test_dataset[i, j] = im
    test_counts[i] = len(im_files)
print(test_dataset.shape)

(2, 500, 48, 48)


In [16]:
# Meta-test episode shape: n_test_way classes, each with n_test_shot support and
# n_test_query query images per episode.
n_test_way = 2
n_test_shot = 5
n_test_query = 480
# Number of evaluation episodes whose accuracies are averaged.
n_test_episodes = 20

In [17]:
print('Testing...')
assert n_test_episodes > 0, 'n_test_episodes must be positive'
assert n_test_way <= n_test_classes, 'n_test_way ({}) exceeds the number of test classes ({})'.format(n_test_way, n_test_classes)
assert n_test_shot + n_test_query <= test_dataset.shape[1], 'not enough image slots per test class'
# Episode-local labels are the same for every episode; int64 matches the `y` placeholder.
labels = np.tile(np.arange(n_test_way, dtype=np.int64)[:, np.newaxis], (1, n_test_query))
# NOTE(review): batch_norm runs in its default training mode, so evaluation normalises
# with statistics of the fed batch and (updates_collections=None) also updates the
# moving averages.
test_accs = []
for epi in range(n_test_episodes):
    epi_classes = np.random.permutation(n_test_classes)[:n_test_way]
    support = np.zeros([n_test_way, n_test_shot, im_height, im_width], dtype=np.float32)
    query = np.zeros([n_test_way, n_test_query, im_height, im_width], dtype=np.float32)
    for i, epi_cls in enumerate(epi_classes):
        # NOTE(review): only the first n_test_shot + n_test_query slots of each class
        # are ever drawn; later slots are never used.
        selected = np.random.permutation(n_test_shot + n_test_query)
        support[i] = test_dataset[epi_cls, selected[:n_test_shot]]
        query[i] = test_dataset[epi_cls, selected[n_test_shot:]]
    # Append the channel axis expected by the x / q placeholders.
    support = np.expand_dims(support, axis=-1)
    query = np.expand_dims(query, axis=-1)
    ls, ac = sess.run([ce_loss, acc], feed_dict={x: support, q: query, y: labels})
    test_accs.append(ac)
    print('[test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(epi+1, n_test_episodes, ls, ac))
avg_acc = float(np.mean(test_accs))
# Normal-approximation 95% confidence interval of the mean episode accuracy.
acc_ci95 = 1.96 * float(np.std(test_accs)) / np.sqrt(n_test_episodes)
print('Average Test Accuracy: {:.5f}'.format(avg_acc))
print('95% confidence interval: +/- {:.5f}'.format(acc_ci95))

Testing...
[test episode 1/1000] => loss: 0.69348, acc: 0.51771
[test episode 2/1000] => loss: 0.71608, acc: 0.50833
[test episode 3/1000] => loss: 0.70077, acc: 0.49271
[test episode 4/1000] => loss: 0.70046, acc: 0.51458
[test episode 5/1000] => loss: 0.70270, acc: 0.47708
[test episode 6/1000] => loss: 0.72425, acc: 0.51875
[test episode 7/1000] => loss: 0.69800, acc: 0.48229
[test episode 8/1000] => loss: 0.69542, acc: 0.51458
[test episode 9/1000] => loss: 0.69227, acc: 0.50729
[test episode 10/1000] => loss: 0.70039, acc: 0.48958
[test episode 11/1000] => loss: 0.69464, acc: 0.48854
[test episode 12/1000] => loss: 0.73493, acc: 0.51146
[test episode 13/1000] => loss: 0.69915, acc: 0.52188
[test episode 14/1000] => loss: 0.70848, acc: 0.47083
[test episode 15/1000] => loss: 0.69757, acc: 0.52188
[test episode 16/1000] => loss: 0.69857, acc: 0.50417
[test episode 17/1000] => loss: 0.69323, acc: 0.52500
[test episode 18/1000] => loss: 0.69436, acc: 0.51667
[test episode 19/1000] => 

KeyboardInterrupt: 