In [12]:
import numpy as np
import os
import random
import imageio
from scipy import misc

In [13]:
def get_images(paths, labels, nb_samples = None, shuffle = True):
    """
    Collect (label, image path) pairs from a list of character folders.

    Inputs:
        paths: list of character folder paths
        labels: list or numpy array with the same length as `paths`;
            labels[i] is attached to every image taken from paths[i]
        nb_samples: number of .png images drawn (without replacement) from each
            folder; None keeps every .png image in the folder
        shuffle: if True, shuffle the combined list with `random.shuffle`
    Output:
        list of (label, image_path) tuples
    Raises:
        ValueError: if `labels` and `paths` differ in length, or (from
            `random.sample`) if a folder holds fewer than `nb_samples` .png files
    """
    if len(labels) != len(paths):
        raise ValueError(
            f"labels and paths must have the same length, got {len(labels)} and {len(paths)}")

    def sampler(candidates):
        if nb_samples is None:
            return candidates
        return random.sample(candidates, nb_samples)

    images_labels = []
    for label, path in zip(labels, paths):
        # os.listdir order is filesystem-dependent; sort it so that a seeded
        # `random` module draws the same images on every machine.
        png_files = sorted(name for name in os.listdir(path) if name.endswith('.png'))
        images_labels.extend((label, os.path.join(path, name)) for name in sampler(png_files))
    if shuffle:
        random.shuffle(images_labels)
    return images_labels

In [39]:
def pixels_to_unit_float(image):
    """
    Scale a raw pixel array to float32 values in [0, 1], based on its dtype.

    - bool arrays (which 1-bit PNGs may decode to) map False/True to 0.0/1.0
    - integer arrays are divided by their dtype's maximum (255 for uint8,
      65535 for uint16), so uint8 input gives exactly `x / 255.0`
    - any other dtype is cast to float32 unchanged (assumed to be in [0, 1]
      already; NOTE(review): confirm if float images are ever read)
    """
    if image.dtype == np.bool_:
        return image.astype(np.float32)
    if np.issubdtype(image.dtype, np.integer):
        return image.astype(np.float32) / np.float32(np.iinfo(image.dtype).max)
    return image.astype(np.float32)


def image_file_to_array(filename, dim_input):
    """
    Read an image file and return it as a flat, inverted float32 vector.

    Inputs:
        filename: path of the image file
        dim_input: total number of pixels, i.e. the flattened image length
    Output:
        float32 numpy array of shape [dim_input], with values 1 - (pixel scaled to [0, 1])
    Raises:
        ValueError: if the image does not contain exactly `dim_input` values
            (e.g. a multi-channel image or a different resolution)
    """
    image = imageio.v2.imread(filename)
    if image.size != dim_input:
        raise ValueError(
            f"{filename}: expected a single-channel image with {dim_input} pixels, "
            f"got shape {image.shape}")
    image = pixels_to_unit_float(image.reshape([dim_input]))
    # Invert intensities; presumably so dark strokes on a light background
    # become high values (NOTE(review): confirm against the dataset).
    return 1.0 - image

In [40]:
def pair_shuffle(array_a, array_b):
    """
    Apply one random permutation to two arrays so that their rows stay paired.

    Inputs:
        array_a: numpy array (e.g. images), permuted along axis 0
        array_b: numpy array (e.g. labels) with the same length along axis 0
    Output:
        (shuffled array_a, shuffled array_b); fancy indexing returns new
        arrays, so the inputs are left unmodified
    Raises:
        ValueError: if the arrays differ in length along axis 0 (otherwise
            the extra rows of array_b would be silently dropped)
    """
    if array_a.shape[0] != array_b.shape[0]:
        raise ValueError(
            f"arrays must have the same length along axis 0, "
            f"got {array_a.shape[0]} and {array_b.shape[0]}")
    # Draws from the global np.random state, so it respects np.random.seed.
    perm = np.random.permutation(array_a.shape[0])
    return array_a[perm], array_b[perm]

In [41]:
def LoadData(num_classes = 50, num_samples_per_class_train = 15, num_samples_per_class_test = 5, seed = 1,
             data_folder = './omniglot_resized'):
    """
    Load character images and split them into training and test sets.

    Inputs:
        num_classes: number of character classes to use; -1 means use every class found
        num_samples_per_class_train: number of images per class used for training
        num_samples_per_class_test: number of images per class used for testing
        seed: seed for both `random` and `np.random`, so repeated runs pick the
            same classes, images and ordering
        data_folder: root directory laid out as <data_folder>/<family>/<character>/*.png
    Output:
        tuple (train_image, train_label, test_image, test_label):
            train_image: float64 array of shape [num_classes * num_samples_per_class_train, 784],
                one flattened image (as returned by image_file_to_array) per row
            train_label: int array of shape [num_classes * num_samples_per_class_train],
                class indices in [0, num_classes)
            test_image: float64 array of shape [num_classes * num_samples_per_class_test, 784]
            test_label: int array of shape [num_classes * num_samples_per_class_test]
    Raises:
        AssertionError: if num_classes > 1623 or the per-class sample total exceeds 20
        ValueError: if fewer than `num_classes` character folders are found
    """
    random.seed(seed)
    np.random.seed(seed)
    num_samples_per_class = num_samples_per_class_train + num_samples_per_class_test
    assert num_classes <= 1623
    assert num_samples_per_class <= 20
    dim_input = 28 * 28   # 784: each image is flattened to a 28*28 vector

    # Collect character folders. os.listdir order is filesystem-dependent, so
    # sort before the seeded shuffle to choose the same classes everywhere.
    character_folders = [os.path.join(data_folder, family, character)
                         for family in sorted(os.listdir(data_folder))
                         if os.path.isdir(os.path.join(data_folder, family))
                         for character in sorted(os.listdir(os.path.join(data_folder, family)))
                         if os.path.isdir(os.path.join(data_folder, family, character))]
    random.shuffle(character_folders)
    if num_classes == -1:
        num_classes = len(character_folders)
    else:
        if len(character_folders) < num_classes:
            # Otherwise the missing classes would silently be blank images labeled 0.
            raise ValueError(
                f"requested {num_classes} classes but only {len(character_folders)} "
                f"character folders were found under {data_folder!r}")
        character_folders = character_folders[: num_classes]

    # Read images into a [sample, class, pixel] grid.
    all_images = np.zeros(shape = (num_samples_per_class, num_classes, dim_input))
    label_images = get_images(character_folders, list(range(num_classes)),
                              nb_samples = num_samples_per_class, shuffle = True)
    fill_count = np.zeros(num_classes, dtype=int)
    for label, imagefile in label_images:
        all_images[fill_count[label], label, :] = image_file_to_array(imagefile, dim_input)
        fill_count[label] += 1
    # get_images samples exactly num_samples_per_class images per class, so
    # every row of the label grid is [0, 1, ..., num_classes - 1] (as ints).
    all_labels = np.tile(np.arange(num_classes), (num_samples_per_class, 1))

    # Split along the sample axis, flatten, and shuffle each set.
    train_image = all_images[:num_samples_per_class_train].reshape(-1, dim_input)
    test_image  = all_images[num_samples_per_class_train:].reshape(-1, dim_input)
    train_label = all_labels[:num_samples_per_class_train].reshape(-1)
    test_label  = all_labels[num_samples_per_class_train:].reshape(-1)
    train_image, train_label = pair_shuffle(train_image, train_label)
    test_image, test_label = pair_shuffle(test_image, test_label)
    return train_image, train_label, test_image, test_label

In [42]:
# Load the default split: 50 classes, 15 train + 5 test images per class, seed 1.
train_image, train_label, test_image, test_label = LoadData(
    num_classes=50,
    num_samples_per_class_train=15,
    num_samples_per_class_test=5,
    seed=1,
)