In [None]:
"""
    Base network: convolutional building block (conv -> batch norm -> ReLU -> max-pool)
"""
def conv_block(inputs, out_channels, name='conv'):
    """One encoder stage: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool.

    Args:
        inputs: 4-D feature tensor in TensorFlow's default channels-last layout.
        out_channels: number of convolution filters.
        name: variable scope wrapping every op/variable of this stage.

    Returns:
        The pooled activations. The 2x2 max-pool (default stride 2, 'VALID'
        padding) halves each spatial dimension, rounding down.
    """
    with tf.variable_scope(name):
        out = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
        # updates_collections=None makes batch norm update its moving statistics
        # in place, so the train op needs no explicit UPDATE_OPS dependency.
        # NOTE(review): is_training is left at its default, so batch statistics
        # are presumably also used at test time -- confirm this is intended.
        out = tf.contrib.layers.batch_norm(out, updates_collections=None, decay=0.99, scale=True, center=True)
        out = tf.nn.relu(out)
        return tf.contrib.layers.max_pool2d(out, 2)

"""
    Base network: four-block convolutional embedding encoder
"""
def encoder(x, h_dim, z_dim, reuse=False):
    """Embed a batch of images with four stacked conv blocks, then flatten.

    Args:
        x: 4-D image batch; the callers reshape episode tensors into
            (batch, H, W, C) before calling.
        h_dim: number of filters in the first three conv blocks.
        z_dim: number of filters in the fourth (last) conv block.
        reuse: forwarded to the 'encoder' variable scope; True makes this call
            share the weights created by an earlier call.

    Returns:
        2-D tensor (batch, features) produced by tf.contrib.layers.flatten.
    """
    stages = (
        ('conv_1', h_dim),
        ('conv_2', h_dim),
        ('conv_3', h_dim),
        ('conv_4', z_dim),
    )
    with tf.variable_scope('encoder', reuse=reuse):
        features = x
        for scope_name, width in stages:
            features = conv_block(features, width, name=scope_name)
        return tf.contrib.layers.flatten(features)

In [None]:
"""
    Key difference from Matching Networks: using (squared) Euclidean distance,
    rather than cosine similarity, as the metric between data points for
    clustering in the embedding space improves results.
    
    https://www.baeldung.com/cs/euclidean-distance-vs-cosine-similarity
"""
def euclidean_distance(a, b):
    """Pairwise (scaled) squared Euclidean distance between rows of ``a`` and ``b``.

    Args:
        a: 2-D tensor of shape (N, D), e.g. query embeddings.
        b: 2-D tensor of shape (M, D), e.g. class prototypes.

    Returns:
        Tensor of shape (N, M) whose entry [i, j] is
        ``mean_k (a[i, k] - b[j, k]) ** 2``.

    Note:
        This averages over the feature axis, so it equals the squared Euclidean
        distance divided by D. The Prototypical Networks paper uses the plain
        sum. The mean is kept because switching would rescale the logits fed to
        the softmax by a factor of D (an effective temperature change) and
        therefore change training behaviour.
    """
    # (N, 1, D) - (1, M, D) broadcasts to (N, M, D) without first building the
    # two explicitly tiled (N, M, D) copies of the inputs that tf.tile would.
    diff = tf.expand_dims(a, axis=1) - tf.expand_dims(b, axis=0)
    return tf.reduce_mean(tf.square(diff), axis=2)

In [3]:
# Meta-training hyper-parameters (one "episode" = one sampled few-shot task).
n_epochs, n_episodes = 100, 100   # outer training passes, and episodes sampled per pass
n_way = 20                        # classes per training episode
n_shot = 5                        # labelled support images per class
n_query = 15                      # query images per class (scored against the class prototypes)
n_examples = 350                  # per-class example indices the sampler draws from (0..n_examples-1);
                                  # assumes every class in the .npy has at least this many -- TODO confirm
# Image geometry expected by the placeholders / encoder.
im_height = 84
im_width = 84
channels = 3
# Encoder widths: h_dim filters in conv blocks 1-3, z_dim filters in block 4.
h_dim = z_dim = 64

In [None]:
# Load Train Dataset
train_dataset = np.load('mini-imagenet-train.npy')     # data-set
n_classes = train_dataset.shape[0]                     # class count (task 구성시 해당 class들의 combination)
print(train_dataset.shape)

In [None]:
# Episode inputs. The two leading dimensions are dynamic, so the same graph is fed
# 20-way episodes during training and 5-way episodes during testing.
# Support images: (ways, shots, H, W, C), e.g. (20, 5, 84, 84, 3) in training.
x = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
# Query images: (ways, queries, H, W, C), e.g. (20, 15, 84, 84, 3) in training.
# Its first dimension must equal x's, because emb_q below reshapes q with num_classes.
q = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
x_shape = tf.shape(x)                                                          # 1-D tensor of 5 dims
q_shape = tf.shape(q)                                                          # 1-D tensor of 5 dims
# Per-episode way / shot counts (dynamic); not the same as n_classes, the dataset's class count.
num_classes, num_support = x_shape[0], x_shape[1]                              # e.g. 20 ways, 5 shots in training
num_queries = q_shape[1]                                                       # e.g. 15 queries per class
y = tf.placeholder(tf.int64, [None, None])                                     # (ways, queries): episode-local class index of each query
# (ways, queries, ways)
y_one_hot = tf.one_hot(y, depth=num_classes)


# Embed all ways*shots support images at once (e.g. 20 * 5 = 100) -> (ways*shots, emb_dim).
emb_in = encoder(tf.reshape(x, [num_classes * num_support, im_height, im_width, channels]), h_dim, z_dim)
# Flattened embedding size (1600 for 84x84 inputs with z_dim=64, per the original note).
emb_dim = tf.shape(emb_in)[-1]
# Class prototypes: mean of each class's support embeddings (its centroid in embedding space).
# Reshape to (ways, shots, emb_dim), average over shots -> emb_x is (ways, emb_dim).
emb_x = tf.reduce_mean(tf.reshape(emb_in, [num_classes, num_support, emb_dim]), axis=1)

# Embed all ways*queries query images (e.g. 20 * 15 = 300) -> emb_q is (ways*queries, emb_dim).
# reuse=True shares the encoder weights created by the first call above, so this must come after it.
emb_q = encoder(tf.reshape(q, [num_classes * num_queries, im_height, im_width, channels]), h_dim, z_dim, reuse=True)

# (ways*queries, ways), e.g. (300, 20): distance from every query embedding to every class prototype.
dists = euclidean_distance(emb_q, emb_x) 

# Log-softmax over negative distances gives log p(class | query); nearer prototypes get higher probability.
# Reshaped to (ways, queries, ways), e.g. (20, 15, 20): [true class, query index, candidate class].
log_p_y = tf.reshape(tf.nn.log_softmax(-dists), [num_classes, num_queries, -1])
# Mean negative log-likelihood of each query's true class.
ce_loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, log_p_y), axis=-1), [-1]))
# Fraction of queries whose nearest prototype (argmax probability) is the true class.
# NOTE: this is a graph tensor; do not rebind `acc` to a value fetched via sess.run.
acc = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(log_p_y, axis=-1), y)))

In [None]:
# Single optimisation step on the episode loss. learning_rate=0.001 is Adam's
# default, spelled out here so it is visible as a tunable hyper-parameter.
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(ce_loss)

In [None]:
# tf syntax
# Build the initializer for every variable created so far (encoder + Adam slots),
# then open an interactive session and run it.
init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init_op)

In [None]:

# 전체 epoch loop = 100번
# Meta-training: n_epochs passes of n_episodes episodes each. Every episode is an
# n_way-way, n_shot-shot task with n_query query images per class, sampled from the
# training classes.
# Fail fast on configurations the sampler cannot satisfy: with fewer than n_way
# classes, permutation(...)[:n_way] returns fewer ids and the remaining rows of
# `support`/`query` would silently stay all-zero.
assert n_classes >= n_way, (n_classes, n_way)
assert n_examples >= n_shot + n_query, (n_examples, n_shot, n_query)
for ep in range(n_epochs):
    for epi in range(n_episodes):
        # n_way distinct class ids drawn without replacement from all n_classes.
        epi_classes = np.random.permutation(n_classes)[:n_way]
        # Fed to placeholder x: (n_way, n_shot, H, W, C).
        support = np.zeros([n_way, n_shot, im_height, im_width, channels], dtype=np.float32)
        # Fed to placeholder q: (n_way, n_query, H, W, C).
        query = np.zeros([n_way, n_query, im_height, im_width, channels], dtype=np.float32)
        for i, epi_cls in enumerate(epi_classes):
            # n_shot + n_query distinct example indices for this class: the first
            # n_shot form the support set, the rest the query set.
            selected = np.random.permutation(n_examples)[:n_shot + n_query]
            support[i] = train_dataset[epi_cls, selected[:n_shot]]
            query[i] = train_dataset[epi_cls, selected[n_shot:]]
        # (n_way, n_query): row i holds episode-local label i, matching the row order of `query`.
        # int64 matches placeholder y's dtype (uint8 would wrap around for n_way > 255).
        labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.int64)
        # Fetch into fresh names: rebinding `acc` to the returned numpy value would
        # make the next sess.run fail on a non-tensor fetch.
        _, loss_val, acc_val = sess.run([train_op, ce_loss, acc], feed_dict={x: support, q: query, y: labels})
        if (epi + 1) % 50 == 0:
            print('[epoch {}/{}, episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(ep + 1, n_epochs, epi + 1, n_episodes, loss_val, acc_val))

In [None]:
# Load Test Dataset
# Load the held-out meta-test split. It is indexed like the training split, so the
# expected layout is (n_test_classes, examples_per_class, im_height, im_width, channels).
test_dataset = np.load('mini-imagenet-test.npy')
n_test_classes = test_dataset.shape[0]   # number of candidate classes for test episodes
print(test_dataset.shape)
# Fail fast on a mismatched array instead of relying on silent NumPy broadcasting
# (or an index error) inside the evaluation loop.
assert test_dataset.ndim == 5, test_dataset.shape
assert test_dataset.shape[2:] == (im_height, im_width, channels), test_dataset.shape
assert test_dataset.shape[1] >= n_examples, (test_dataset.shape, n_examples)

In [None]:
# Evaluation protocol: n_test_episodes episodes, each n_test_way-way / n_test_shot-shot
# with n_test_query query images per class.
n_test_episodes, n_test_way, n_test_shot, n_test_query = 600, 5, 5, 15

In [None]:
# Meta-testing: average query accuracy over n_test_episodes random
# n_test_way-way, n_test_shot-shot episodes drawn from the held-out classes.
print('Testing...')
# Fail fast on configurations the sampler cannot satisfy (otherwise rows of
# `support`/`query` would silently stay all-zero or slices would be mis-shaped).
assert n_test_classes >= n_test_way, (n_test_classes, n_test_way)
assert n_examples >= n_test_shot + n_test_query, (n_examples, n_test_shot, n_test_query)
avg_acc = 0.
for epi in range(n_test_episodes):
    epi_classes = np.random.permutation(n_test_classes)[:n_test_way]
    support = np.zeros([n_test_way, n_test_shot, im_height, im_width, channels], dtype=np.float32)
    query = np.zeros([n_test_way, n_test_query, im_height, im_width, channels], dtype=np.float32)
    for i, epi_cls in enumerate(epi_classes):
        selected = np.random.permutation(n_examples)[:n_test_shot + n_test_query]
        support[i] = test_dataset[epi_cls, selected[:n_test_shot]]
        query[i] = test_dataset[epi_cls, selected[n_test_shot:]]
    # int64 matches placeholder y's dtype.
    labels = np.tile(np.arange(n_test_way)[:, np.newaxis], (1, n_test_query)).astype(np.int64)
    # Fetch into fresh names so the `acc` tensor is not rebound to a numpy value,
    # which would break the next sess.run call.
    loss_val, acc_val = sess.run([ce_loss, acc], feed_dict={x: support, q: query, y: labels})
    avg_acc += acc_val
    if (epi + 1) % 50 == 0:
        print('[test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(epi + 1, n_test_episodes, loss_val, acc_val))
avg_acc /= n_test_episodes
print('Average Test Accuracy: {:.5f}'.format(avg_acc))