In [1]:
from typing import Dict, List, Iterator

In [2]:
import numpy as np
import random
from pathlib import Path
from PIL import Image

In [3]:
def load_and_process_image(img_path: Path, rot: float):
    with Image.open(img_path) as img:
        img = img.rotate(rot)
        img = img.resize((28, 28))
        
        data = np.asarray(img, dtype=np.float32)
        data = np.transpose(data)  # (width, height) => (height, width)
        data = data.reshape((28, 28, 1))  # (height, width) => (height, width, channel)
        return data

result = load_and_process_image(Path("data/omniglot/data/Angelic/character15/0979_01.png"), 90)
assert result.shape == (28, 28, 1)
assert result.dtype == np.float32

In [4]:
OMNIGLOT_CACHE = {}


def load_class_image(data_dir: Path, clazz: str) -> List[np.ndarray]:
    if clazz not in OMNIGLOT_CACHE:
        alphabet, character, raw_rot = clazz.split('/')
        rot = float(raw_rot[3:])

        image_dir = data_dir / 'data' / alphabet / character

        class_images = sorted(image_dir.glob('*.png'))

        if len(class_images) == 0:
            raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(clazz, data_dir))

        image_list = [load_and_process_image(img_path, rot) for img_path in class_images]
            
        OMNIGLOT_CACHE[clazz] = image_list

    return OMNIGLOT_CACHE[clazz]

result = load_class_image(Path("data/omniglot"), "Angelic/character01/rot000")
assert len(result) == 20

In [5]:
def read_images(data_dir: Path, split: str) -> Dict[str, List[np.ndarray]]:
    split_dir = data_dir / "splits" / "vinyals"
    
    class_names = []
    with open(split_dir / "{:s}.txt".format(split), 'r') as f:
        for class_name in f.readlines():
            class_names.append(class_name.rstrip('\n'))
            
    images = {clazz: load_class_image(data_dir, clazz) for clazz in class_names}
    
    return images


result = read_images(Path("data/omniglot"), "train")
assert len(result) == 4112

In [6]:
def extract_episode(data_dir: Path, split: str, n_support, n_query) -> Dict[str, Dict[str, List[np.ndarray]]]:
    data = read_images(data_dir, split)

    result = {}
    for clazz, images in data.items():
        random.shuffle(images)
        xs = images[:n_support]
        xq = images[n_support:n_support + n_query]
        
        result[clazz] = {"xs": np.stack(xs), "xq": np.stack(xq)}
        
    return result

result = extract_episode(Path("data/omniglot"), "train", 5, 3)
assert len(result) == 4112

some_key = list(result.keys())[0]
assert result[some_key]["xs"].shape == (5, 28, 28, 1)
assert result[some_key]["xq"].shape == (3, 28, 28, 1)

In [7]:
class EpisodeBatcher:
    def __init__(self, data_dir: Path, split: str, n_support, n_query, n_episodes, n_way):
        self.episode_config = {"data_dir": data_dir, "split": split, "n_support": n_support, "n_query": n_query}
        self.n_episodes = n_episodes
        self.n_way = n_way
    
    def __len__(self):
        return self.n_episodes
    
    def __iter__(self):
        episode = extract_episode(**self.episode_config)
        class_list = list(episode.keys())
        for e in range(self.n_episodes):
            random.shuffle(class_list)
            n_way_class = class_list[:self.n_way]
            data = [v for k, v in episode.items() if k in n_way_class]
            
            xs = np.stack([i["xs"] for i in data])
            xq = np.stack([i["xq"] for i in data])
            
            yield xs, xq

result = list(EpisodeBatcher(Path("data/omniglot"), "train", 5, 5, 6, 5))
assert len(result) == 6

for xs, xq in EpisodeBatcher(Path("data/omniglot"), "train", 5, 3, 6, 10):    
    assert xs.shape == (10, 5, 28, 28, 1)
    assert xq.shape == (10, 3, 28, 28, 1)
    
    break

In [8]:
def pairwise_dist(A, B):  
    """
    Computes pairwise distances between each elements of A and each elements of B.
    Args:
    A,    [m,d] matrix
    B,    [n,d] matrix
    Returns:
    D,    [m,n] matrix of pairwise distances
    """
    # squared norms of each row in A and B
    na = tf.reduce_sum(tf.square(A), 1)
    nb = tf.reduce_sum(tf.square(B), 1)

    # na as a row and nb as a co"lumn vectors
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # return pairwise euclidead difference matrix
    D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 0.0))
    return D

In [9]:
import tensorflow as tf


class Img2Vec(tf.keras.Model):
    """
    Img2Vec CNN which takes image of dimension (28x28x?) and return column vector length 64
    """

    def __init__(self):
        super(Img2Vec, self).__init__()

        self.convnet1 = self.sub_block()
        self.convnet2 = self.sub_block()
        self.convnet3 = self.sub_block()
        self.convnet4 = self.sub_block()

    def sub_block(self, out_channels=64, kernel_size=3):
        block = tf.keras.models.Sequential(
            [
                tf.keras.layers.Conv2D(out_channels, kernel_size, padding="same"),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.ReLU(),
                tf.keras.layers.MaxPool2D(),
            ]
        )
        return block

    def call(self, x):
        x = self.convnet1(x)
        x = self.convnet2(x)
        x = self.convnet3(x)
        x = self.convnet4(x)
        x = tf.keras.layers.Flatten()(x)
        return x

In [25]:
img2vec = Img2Vec()
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_acc = tf.keras.metrics.Mean(name='train_acc')

def train_step(xs, xq):
    n_class = xs.shape[0]
    assert xq.shape[0] == n_class
    n_support = xs.shape[1]
    n_query = xq.shape[1]
    
    with tf.GradientTape() as tape:
        target_idx = tf.broadcast_to(
            tf.reshape(tf.range(0, n_class), (-1, 1, 1)),
            (n_class, n_query, 1)
        )

        x = np.concatenate([
            np.reshape(xs, [n_class * n_support, *xs.shape[2:]]),
            np.reshape(xq, [n_class * n_query, *xq.shape[2:]])
        ])

        z = img2vec(x)

        z_proto = tf.math.reduce_mean(tf.reshape(z[:n_class * n_support], (n_class, n_support, -1)), 1)

        z_q = z[n_class * n_support:]

        dists = pairwise_dist(z_q, z_proto)

        log_p_y = tf.reshape(tf.nn.log_softmax(-1 * dists, 1), (n_class, n_query, -1))

        loss_val = -1 * tf.reduce_mean(tf.reshape(tf.gather(log_p_y, target_idx, axis=2, batch_dims=2), (-1, )))
        
        y_hat = tf.math.argmax(log_p_y, 2, output_type=tf.int32)
        acc_val = tf.reduce_mean(tf.cast(tf.math.equal(y_hat, tf.squeeze(target_idx)), tf.float32))

    
    gradients = tape.gradient(loss_val, img2vec.trainable_variables)
    optimizer.apply_gradients(zip(gradients, img2vec.trainable_variables))

    train_loss(loss_val)
    train_acc(acc_val)

In [None]:
import time
from tqdm import tqdm

EPOCHES = 10000
EPISODES = 100

for epoch in range(EPOCHES):
    train_loss.reset_states()
    train_acc.reset_states()
        
    episode_generator = EpisodeBatcher(Path("data/omniglot"), "train", 5, 5, EPISODES, 60)
    for episode in tqdm(episode_generator, desc="Batch {}".format(epoch)):
        train_step(*episode)
        
    template = 'Epoch {}, Loss: {}, Acc: {}'
    print(template.format(epoch + 1, train_loss.result(), train_acc.result()))

Batch 0: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:52<00:00,  1.13s/it]
Batch 1:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 1, Loss: 1.6581666469573975, Acc: 0.587766706943512


Batch 1: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:28<00:00,  1.13it/s]
Batch 2:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 2, Loss: 0.6782087087631226, Acc: 0.8044664859771729


Batch 2: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:32<00:00,  1.08it/s]
Batch 3:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 3, Loss: 0.37346938252449036, Acc: 0.8883668780326843


Batch 3: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:42<00:00,  1.03s/it]
Batch 4:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 4, Loss: 0.2669461667537689, Acc: 0.9206336140632629


Batch 4: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:32<00:00,  1.09it/s]
Batch 5:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 5, Loss: 0.22348469495773315, Acc: 0.9323664903640747


Batch 5: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:48<00:00,  1.08s/it]
Batch 6:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 6, Loss: 0.19715291261672974, Acc: 0.9402332305908203


Batch 6: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.14it/s]
Batch 7:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 7, Loss: 0.18443991243839264, Acc: 0.9436662197113037


Batch 7: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:23<00:00,  1.20it/s]
Batch 8:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 8, Loss: 0.1657325178384781, Acc: 0.9496997594833374


Batch 8: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:24<00:00,  1.18it/s]
Batch 9:   0%|                                                                                          | 0/100 [00:00<?, ?it/s]

Epoch 9, Loss: 0.1602976769208908, Acc: 0.950499951839447


Batch 9: 100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:24<00:00,  1.18it/s]
Batch 10:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 10, Loss: 0.13771668076515198, Acc: 0.9577998518943787


Batch 10: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:24<00:00,  1.19it/s]
Batch 11:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 11, Loss: 0.13586026430130005, Acc: 0.9597001075744629


Batch 11: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:24<00:00,  1.19it/s]
Batch 12:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 12, Loss: 0.12791933119297028, Acc: 0.9572665691375732


Batch 12: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:25<00:00,  1.17it/s]
Batch 13:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 13, Loss: 0.12732936441898346, Acc: 0.9608666300773621


Batch 13: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:24<00:00,  1.19it/s]
Batch 14:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 14, Loss: 0.12196223437786102, Acc: 0.9622999429702759


Batch 14: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:24<00:00,  1.18it/s]
Batch 15:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 15, Loss: 0.11095141619443893, Acc: 0.9660670757293701


Batch 15: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [03:10<00:00,  1.90s/it]
Batch 16:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 16, Loss: 0.10093966126441956, Acc: 0.9685666561126709


Batch 16: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [36:03<00:00, 21.64s/it]
Batch 17:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 17, Loss: 0.10753487795591354, Acc: 0.9676334261894226


Batch 17: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:44<00:00,  1.05s/it]
Batch 18:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 18, Loss: 0.10465271770954132, Acc: 0.9678334593772888


Batch 18: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:49<00:00,  1.09s/it]
Batch 19:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 19, Loss: 0.1059812381863594, Acc: 0.9669000506401062


Batch 19: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:56<00:00,  1.16s/it]
Batch 20:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 20, Loss: 0.09568878263235092, Acc: 0.970033586025238


Batch 20: 100%|████████████████████████████████████████████████████████████████████████████| 100/100 [4:02:08<00:00, 145.29s/it]
Batch 21:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 21, Loss: 0.09750793129205704, Acc: 0.9711000323295593


Batch 21: 100%|█████████████████████████████████████████████████████████████████████████████| 100/100 [1:24:08<00:00, 50.48s/it]
Batch 22:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 22, Loss: 0.08536353707313538, Acc: 0.9727335572242737


Batch 22: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.15it/s]
Batch 23:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 23, Loss: 0.09130334109067917, Acc: 0.9709333777427673


Batch 23: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.14it/s]
Batch 24:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 24, Loss: 0.08493991196155548, Acc: 0.9744336009025574


Batch 24: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:26<00:00,  1.15it/s]
Batch 25:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 25, Loss: 0.08441689610481262, Acc: 0.9732001423835754


Batch 25: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.15it/s]
Batch 26:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 26, Loss: 0.08357072621583939, Acc: 0.975199818611145


Batch 26: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:30<00:00,  1.11it/s]
Batch 27:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 27, Loss: 0.08207185566425323, Acc: 0.9751332998275757


Batch 27: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.14it/s]
Batch 28:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 28, Loss: 0.08294637501239777, Acc: 0.9756000638008118


Batch 28: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:30<00:00,  1.11it/s]
Batch 29:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 29, Loss: 0.08167175203561783, Acc: 0.9752665758132935


Batch 29: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [05:20<00:00,  3.21s/it]
Batch 30:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 30, Loss: 0.07733689993619919, Acc: 0.9759998321533203


Batch 30: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:26<00:00,  1.15it/s]
Batch 31:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 31, Loss: 0.08196026086807251, Acc: 0.9756665825843811


Batch 31: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:29<00:00,  1.12it/s]
Batch 32:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 32, Loss: 0.07579225301742554, Acc: 0.9770665764808655


Batch 32: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.15it/s]
Batch 33:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 33, Loss: 0.07257801294326782, Acc: 0.9783331155776978


Batch 33: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:26<00:00,  1.15it/s]
Batch 34:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 34, Loss: 0.07135836780071259, Acc: 0.977699875831604


Batch 34: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.15it/s]
Batch 35:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 35, Loss: 0.06489169597625732, Acc: 0.9797333478927612


Batch 35: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.14it/s]
Batch 36:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 36, Loss: 0.07293123006820679, Acc: 0.9784665107727051


Batch 36: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:28<00:00,  1.13it/s]
Batch 37:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 37, Loss: 0.07211887091398239, Acc: 0.9773666262626648


Batch 37: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:26<00:00,  1.15it/s]
Batch 38:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 38, Loss: 0.07471217215061188, Acc: 0.977866530418396


Batch 38: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:26<00:00,  1.15it/s]
Batch 39:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 39, Loss: 0.06766016781330109, Acc: 0.978966474533081


Batch 39: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:31<00:00,  1.09it/s]
Batch 40:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 40, Loss: 0.06566406041383743, Acc: 0.9792333245277405


Batch 40: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:45<00:00,  1.05s/it]
Batch 41:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 41, Loss: 0.06261061877012253, Acc: 0.9790666103363037


Batch 41: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.14it/s]
Batch 42:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 42, Loss: 0.06085619330406189, Acc: 0.9815664887428284


Batch 42: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:25<00:00,  1.17it/s]
Batch 43:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 43, Loss: 0.07030083984136581, Acc: 0.9773666858673096


Batch 43: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:28<00:00,  1.13it/s]
Batch 44:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 44, Loss: 0.06048328056931496, Acc: 0.9807662963867188


Batch 44: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:38<00:00,  1.02it/s]
Batch 45:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 45, Loss: 0.06218019127845764, Acc: 0.9808666110038757


Batch 45: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:30<00:00,  1.11it/s]
Batch 46:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 46, Loss: 0.06509291380643845, Acc: 0.9794996380805969


Batch 46: 100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [01:27<00:00,  1.15it/s]
Batch 47:   0%|                                                                                         | 0/100 [00:00<?, ?it/s]

Epoch 47, Loss: 0.061215486377477646, Acc: 0.9808666110038757


Batch 47:   8%|██████▏                                                                       | 8/100 [35:14<2:44:45, 107.45s/it]