- すべての言語の文字を混ぜて問題を作っている。
- すべての文字を均等に学習に使いたい的な？


- サブ問題だけで学習させた時と、すべての文字を一緒くたに混ぜて学習させた時とを比較する。

In [51]:
import sys
import numpy as np
from PIL import Image
import os
import random

import torch
import torch.utils.data as data


# 何を返せばいいの？
class Omniglot(data.Dataset):
    """Randomly sampled N-way classification task over Omniglot characters.

    Characters from every alphabet under the root directory are pooled into
    one flat list (alphabets are not distinguished), ``num_classes`` of them
    are sampled, and for each sampled character ``num_instances`` images go
    to the train split and the next ``num_instances`` images go to the val
    split. Indexing the dataset yields ``(image, label)`` pairs from the
    split selected by ``split``.

    Args:
        num_classes: Number of character classes to sample.
        num_instances: Number of images per class in each split.
        mode: ``'train'`` reads from ``images_background``; any other value
            reads from ``images_evaluation``. Ignored when ``root`` is given.
        split: ``'train'`` or ``'val'``; selects which sampled ids/labels
            ``__len__`` / ``__getitem__`` / ``__iter__`` expose.
        transform: Optional callable applied to each loaded image array.
        target_transform: Optional callable applied to each label.
        root: Optional directory that directly contains one sub-directory
            per alphabet. Overrides the path derived from ``mode``.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'val'``, or (from
            ``random.sample``) if ``num_classes`` exceeds the number of
            available characters.
    """

    def __init__(self, num_classes, num_instances, mode='train',
                 split='train', transform=None, target_transform=None,
                 root=None):
        super().__init__()

        if split not in ('train', 'val'):
            raise ValueError(
                "split must be 'train' or 'val', got {!r}".format(split))

        if root is None:
            # Both splits live next to each other under the same dataset
            # directory; the evaluation path previously lacked the '../'
            # prefix used by the background path.
            root = ('../omniglot/images_background' if mode == 'train'
                    else '../omniglot/images_evaluation')
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        # Every alphabet directory (stray files such as .DS_Store are skipped).
        languages = [name for name in os.listdir(self.root)
                     if os.path.isdir(os.path.join(self.root, name))]

        # Flat list of every character as "<alphabet>/<character>";
        # alphabets are deliberately not kept separate.
        chars = []
        for lang in languages:
            lang_dir = os.path.join(self.root, lang)
            chars += [os.path.join(lang, name)
                      for name in os.listdir(lang_dir)
                      if os.path.isdir(os.path.join(lang_dir, name))]
        print("chars[:10]\n", chars[:10], "\n")

        classes = random.sample(chars, num_classes)
        print("classes\n", classes, "\n")

        # Map each sampled class path to an integer label 0..num_classes-1.
        labels = dict(zip(classes, np.arange(len(classes))))
        print("labels\n", labels, "\n")

        instances = dict()

        self.train_ids = []
        self.val_ids = []

        # Sample the same number of train and val images from every class.
        for c in classes:
            temp = [os.path.join(c, x)
                    for x in os.listdir(os.path.join(self.root, c))]
            # random.shuffle works in place and returns None, so take a
            # full-length sample to get a shuffled copy instead.
            instances[c] = random.sample(temp, len(temp))
            self.train_ids += instances[c][:num_instances]
            self.val_ids += instances[c][num_instances:num_instances * 2]

        print("self.train_ids\n", self.train_ids, "\n")
        print("self.val_ids\n", self.val_ids, "\n")
        self.train_labels = [labels[self.get_class(x)] for x in self.train_ids]
        self.val_labels = [labels[self.get_class(x)] for x in self.val_ids]

        # Images are taken class by class, so with one instance per class
        # the labels come out like np.arange(num_classes).
        print("self.train_labels\n", self.train_labels, "\n")
        print("self.val_labels\n", self.val_labels, "\n")

        # Ids/labels served by __len__ / __getitem__ / __iter__.
        self.split = split
        if split == 'train':
            self.img_ids = self.train_ids
            self.labels = self.train_labels
        else:
            self.img_ids = self.val_ids
            self.labels = self.val_labels

    def get_class(self, instance):
        """Return the "<alphabet>/<character>" part of an image id."""
        # Ids are built with os.path.join, so strip the file name with
        # os.path rather than splitting on a hard-coded '/'.
        return os.path.dirname(instance)

    def load_image(self, idx):
        """Load the image with id ``idx`` as a float32 array of a 28x28 RGB image."""
        # Context manager closes the underlying file once the pixel data
        # has been copied by convert().
        with Image.open(os.path.join(self.root, idx)) as raw:
            im = raw.convert('RGB')
        im = im.resize((28, 28), resample=Image.LANCZOS)  # per Chelsea's implementation
        im = np.array(im, dtype=np.float32)
        return im

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` of the selected split."""
        img_id = self.img_ids[idx]
        im = self.load_image(img_id)
        if self.transform is not None:
            im = self.transform(im)
        label = self.labels[idx]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return im, label

    def __iter__(self):
        """Yield ``(image, label)`` pairs of the selected split in order."""
        for idx in range(len(self)):
            yield self[idx]
        

In [52]:
# Sample 20 character classes with 1 instance per class for each split,
# using the default mode='train'; the constructor prints what it sampled.
omniglot = Omniglot(20, 1)

chars[:10]
 ['Hebrew/character12', 'Hebrew/character02', 'Hebrew/character09', 'Hebrew/character07', 'Hebrew/character19', 'Hebrew/character21', 'Hebrew/character11', 'Hebrew/character18', 'Hebrew/character15', 'Hebrew/character14'] 

classes
 ['Alphabet_of_the_Magi/character20', 'Arcadian/character01', 'Tifinagh/character26', 'Grantha/character16', 'Braille/character10', 'Japanese_(hiragana)/character19', 'Balinese/character02', 'Armenian/character04', 'N_Ko/character03', 'Inuktitut_(Canadian_Aboriginal_Syllabics)/character11', 'Tifinagh/character29', 'Syriac_(Estrangelo)/character01', 'Futurama/character12', 'Gujarati/character04', 'Tifinagh/character30', 'Japanese_(katakana)/character06', 'Grantha/character41', 'Gujarati/character02', 'Tifinagh/character33', 'Asomtavruli_(Georgian)/character39'] 

labels
 {'Alphabet_of_the_Magi/character20': 0, 'Arcadian/character01': 1, 'Tifinagh/character26': 2, 'Grantha/character16': 3, 'Braille/character10': 4, 'Japanese_(hiragana)/character19':