In [1]:
from typing import Dict, List, Iterator

In [43]:
import numpy as np
import random
from pathlib import Path
from PIL import Image

In [36]:
def load_and_process_image(img_path: Path, rot: float, size: int = 28) -> np.ndarray:
    """Load one image file, rotate it, shrink it and return it as a float32 array.

    Args:
        img_path: path of the image file to open with PIL.
        rot: rotation angle in degrees, passed straight to ``Image.rotate``.
        size: side length in pixels of the square output (default 28, the
            value the rest of this notebook expects).

    Returns:
        A ``(size, size, 1)`` float32 array whose two pixel axes are
        transposed (see the comment below).
    """
    with Image.open(img_path) as raw:
        img = raw.rotate(rot).resize((size, size))

        data = np.asarray(img, dtype=np.float32)
        # A PIL image converts to a [row, column] == (height, width) array, so
        # this transpose does not "restore" (height, width); it swaps the axes,
        # i.e. mirrors the pixels across the main diagonal. The shape is
        # unchanged because the image is square.
        # NOTE(review): kept for compatibility with existing outputs -- confirm
        # the mirroring is intentional.
        data = np.transpose(data)
        # Append a channel axis: (size, size) -> (size, size, 1).
        return data.reshape((size, size, 1))

# Smoke test: rotate one known drawing by 90 degrees and check the array format.
sample_path = Path("data/omniglot") / "data" / "Angelic" / "character15" / "0979_01.png"
result = load_and_process_image(sample_path, 90)
assert tuple(result.shape) == (28, 28, 1)
assert result.dtype == np.dtype("float32")

In [37]:
# Memo of processed images, keyed by class name ("<alphabet>/<character>/rot<deg>").
OMNIGLOT_CACHE: Dict[str, List[np.ndarray]] = {}


def load_class_image(data_dir: Path, clazz: str) -> List[np.ndarray]:
    """Load (and memoise) every image of one Omniglot class.

    ``clazz`` must look like ``<alphabet>/<character>/rot<degrees>``, e.g.
    ``Angelic/character01/rot000``. Every ``*.png`` under
    ``<data_dir>/data/<alphabet>/<character>`` is loaded in sorted filename
    order and rotated by ``<degrees>`` via ``load_and_process_image``.

    Returns:
        A new list on every call, so callers may shuffle or trim it without
        corrupting ``OMNIGLOT_CACHE``. The arrays inside are shared with the
        cache and must not be modified in place.

    Raises:
        FileNotFoundError: no PNG files were found for the class. This is a
            subclass of ``Exception``, so older ``except Exception`` handlers
            still apply.
        ValueError: ``clazz`` does not split into exactly three parts, or the
            rotation suffix is not numeric.
    """
    if clazz not in OMNIGLOT_CACHE:
        alphabet, character, raw_rot = clazz.split('/')
        rot = float(raw_rot[len('rot'):])  # "rot090" -> 90.0

        image_dir = data_dir / 'data' / alphabet / character
        class_images = sorted(image_dir.glob('*.png'))

        if not class_images:
            raise FileNotFoundError(
                "No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(clazz, image_dir))

        OMNIGLOT_CACHE[clazz] = [load_and_process_image(img_path, rot) for img_path in class_images]

    # Shallow copy: protects the cached list's order/length from caller mutation.
    return list(OMNIGLOT_CACHE[clazz])

# Smoke test: one unrotated class; the check expects 20 drawings for it.
result = load_class_image(data_dir=Path("data/omniglot"), clazz="Angelic/character01/rot000")
assert 20 == len(result)

In [49]:
def read_images(data_dir: Path, split: str) -> Dict[str, List[np.ndarray]]:
    """Load every class listed in one split file of the Vinyals Omniglot splits.

    Reads ``<data_dir>/splits/vinyals/<split>.txt``, which holds one class name
    per line (e.g. ``Angelic/character01/rot000``). Each class is loaded with
    ``load_class_image``.

    Blank or whitespace-only lines (such as a trailing empty line) are skipped.
    Surrounding whitespace, including a Windows ``\\r``, is stripped from each
    name.

    Returns:
        Mapping from class name to that class's list of image arrays.
    """
    split_file = data_dir / "splits" / "vinyals" / "{:s}.txt".format(split)

    with open(split_file, 'r', encoding='utf-8') as f:
        class_names = [line.strip() for line in f if line.strip()]

    return {clazz: load_class_image(data_dir, clazz) for clazz in class_names}


# Smoke test: load the whole Vinyals training split; the check expects 4112 classes.
result = read_images(data_dir=Path("data/omniglot"), split="train")
assert 4112 == len(result)

In [52]:
def extract_episode(data_dir: Path, split: str, n_support: int, n_query: int) -> Dict[str, Dict[str, List[np.ndarray]]]:
    """Randomly split every class of ``split`` into support and query images.

    Args:
        data_dir: Omniglot root directory, as passed to ``read_images``.
        split: split-file name, e.g. ``"train"``.
        n_support: number of images per class placed in ``"xs"``.
        n_query: number of images per class placed in ``"xq"``, taken right
            after the support images.

    Returns:
        ``{class_name: {"xs": [...], "xq": [...]}}``. The two lists never
        overlap. They are silently shorter than requested when a class has
        fewer than ``n_support + n_query`` images.
    """
    data = read_images(data_dir, split)

    result = {}
    for clazz, images in data.items():
        # ``images`` may be the very list held in OMNIGLOT_CACHE; shuffling it
        # in place would silently reorder the cache, so shuffle a copy.
        shuffled = list(images)
        random.shuffle(shuffled)
        result[clazz] = {
            "xs": shuffled[:n_support],
            "xq": shuffled[n_support:n_support + n_query],
        }

    return result

# Smoke test: 5 support + 5 query images for every training class.
result = extract_episode(Path("data/omniglot"), "train", n_support=5, n_query=5)
assert len(result) == 4112

# Spot-check the first class: both slices should hold exactly 5 images.
some_key = next(iter(result))
episode = result[some_key]
assert len(episode["xs"]) == 5
assert len(episode["xq"]) == 5

In [5]:
def generate_episode_batch(data: Dict[str, Dict[str, List[np.ndarray]]], n_episodes: int, n_way) -> Iterator[Dict[str, List[np.ndarray]]]:
    """Generator intended to yield ``n_episodes`` few-shot episode batches.

    NOTE(review): the body visible here yields the literal ``...`` (Ellipsis)
    ``n_episodes`` times and does not read ``data`` or ``n_way``. It looks
    like an unfinished stub. TODO implement the episode sampling.

    Args:
        data: presumably the ``{class: {"xs": [...], "xq": [...]}}`` mapping
            returned by ``extract_episode``; confirm against the caller.
        n_episodes: number of items the generator yields.
        n_way: currently unused; presumably the number of classes per
            episode. TODO confirm.
    """
    for i in range(n_episodes):
        yield ...