### 単射となる変換のみを考える

In [10]:
import pickle
import os
import torch
from torchvision import datasets, transforms

from settings.shift_funcs import get_funcs
import numpy as np

# Default seed for numpy's global RNG in create_train_dataset;
# create_test_dataset seeds with SEED + 1000 instead.
SEED: int = 0
# Default number of MNIST training images sampled before augmentation
# (augument_data turns each sampled image into 4 images).
SAMPLE_SIZE: int = 1000

In [2]:
def augument_data(source):
    """Expand every image into itself plus three geometric variants.

    For each image in ``source`` the output contains, in order:
    the unchanged image, the image rotated by 1-3 quarter turns
    (the count is drawn from numpy's global RNG, one draw per image),
    the left-right mirror, and the up-down mirror.

    Args:
        source: numpy array of images; presumably (N, H, W) with square
            H == W so all variants share a shape -- TODO confirm for callers.

    Returns:
        numpy array of length 4 * len(source) with ``source.dtype``.
    """
    augmented = []
    for image in source:
        quarter_turns = np.random.randint(1, 4)
        augmented += [
            image,
            np.rot90(image, quarter_turns),
            np.fliplr(image),
            np.flipud(image),
        ]
    return np.array(augmented, dtype=source.dtype)

In [3]:
def apply(original_set, filter_indices_set, funcs):
    """Run a chain of transforms over each image and stack the results.

    The k-th row of ``filter_indices_set`` lists indices into ``funcs``;
    those transforms are applied left to right to a copy of
    ``original_set[k]``, so the inputs are never modified in place.

    Args:
        original_set: indexable collection of images (numpy arrays).
        filter_indices_set: iterable of index sequences, one per image.
        funcs: sequence of 2-element entries whose first item is the
            callable transform (the second item is ignored here).

    Returns:
        numpy array built from the transformed images, one per row of
        ``filter_indices_set``.
    """
    transformed = []
    for position, chain in enumerate(filter_indices_set):
        current = original_set[position].copy()
        for func_index in chain:
            transform, _unused = funcs[func_index]
            current = transform(current)
        transformed.append(current)
    return np.array(transformed)

In [14]:
def create_train_dataset(func_indices, shift_len, C=3, fname='', seed=SEED, size=SAMPLE_SIZE):
    """Build a shifted, augmented MNIST training set and save it as .npz.

    ``size`` images are sampled from the MNIST training split. Each is
    expanded into four variants by ``augument_data``. Every variant is then
    passed through ``C`` shift functions, each picked at random from
    ``func_indices``.

    Args:
        func_indices: allowed indices into the shift-function list built
            from the deltas (shift_len, shift_len), (shift_len, 0),
            (0, shift_len), in that order.
        shift_len: shift amount forwarded to ``get_funcs``.
        C: number of shift functions applied to each image.
        fname: output file name, relative to the ``data`` directory.
        seed: seed for numpy's global RNG, set before any sampling.
        size: number of MNIST images sampled before augmentation.

    Side effects:
        Downloads MNIST into ``../open_data/`` if needed. Creates the output
        directory if missing. Writes ``data/<fname>`` with the following keys:
        ``train_dataset`` holds the shifted images.
        ``train_func_labels`` holds the (N, C) function indices applied.
        ``original_dataset`` holds the augmented images before shifting.
    """
    deltas = [(shift_len, shift_len), (shift_len, 0), (0, shift_len)]
    funcs = [get_funcs(*delta) for delta in deltas]

    np.random.seed(seed)
    mnist = datasets.MNIST(root='../open_data/', train=True, download=True)
    # NOTE(review): np.random.choice samples with replacement by default, so the
    # same MNIST image may be drawn more than once. Left unchanged so datasets
    # already generated for a given seed remain reproducible.
    idx = np.random.choice(mnist.data.shape[0], size)
    imgs = mnist.data[idx].numpy()
    augumented_dataset = augument_data(imgs)

    train_filter_set = np.random.choice(func_indices, (augumented_dataset.shape[0], C))
    shifted = apply(augumented_dataset, train_filter_set, funcs)

    out_path = os.path.join('data', fname)
    # Callers other than create_dataset may not have created the directory yet.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    np.savez(
        out_path,
        train_dataset=shifted,
        train_func_labels=train_filter_set,
        original_dataset=augumented_dataset,
    )

In [5]:
# テストデータ
def create_test_dataset(shift_len, fname='', C=3, seed=None):
    """Build a shifted MNIST test set and save it as .npz.

    Every image of the MNIST test split is passed through ``C`` shift
    functions. Each is chosen at random from all three directions, built
    from the deltas (shift_len, shift_len), (shift_len, 0) and
    (0, shift_len).

    Args:
        shift_len: shift amount forwarded to ``get_funcs``.
        fname: output file name, relative to the ``data`` directory.
        C: number of shift functions applied to each image (default 3).
        seed: seed for numpy's global RNG. When None, uses ``SEED + 1000``
            (read at call time), so the test split draws differ from the
            training split's default seed.

    Side effects:
        Downloads MNIST into ``../open_data/`` if needed. Creates the output
        directory if missing. Writes ``data/<fname>`` with the following keys:
        ``test_dataset`` holds the shifted images.
        ``test_func_labels`` holds the (N, C) function indices applied.
    """
    if seed is None:
        seed = SEED + 1000
    deltas = [(shift_len, shift_len), (shift_len, 0), (0, shift_len)]
    funcs = [get_funcs(*delta) for delta in deltas]

    np.random.seed(seed)
    mnist_images = datasets.MNIST(root='../open_data/', train=False, download=True).data
    # Test images may use every shift direction, unlike the training variants.
    test_filter_set = np.random.choice([0, 1, 2], (mnist_images.shape[0], C))
    shifted = apply(mnist_images.numpy(), test_filter_set, funcs)

    out_path = os.path.join('data', fname)
    # Callers other than create_dataset may not have created the directory yet.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    np.savez(
        out_path,
        test_dataset=shifted,
        test_func_labels=test_filter_set,
    )

In [26]:
def create_dataset(shift_len):
    """Generate the test set and three training variants for one shift length.

    Files are written to ``data/shift<shift_len>/``:
    - ``test_dataset.npz`` contains all three shift functions.
    - ``diag_dataset.npz`` contains function index 0 only.
    - ``diag_vert_dataset.npz`` contains function indices 0 and 2.
    - ``diag_hori_dataset.npz`` contains function indices 0 and 1.

    Args:
        shift_len: shift amount used for every dataset in this directory.
    """
    subdir = 'shift%d' % shift_len
    os.makedirs(os.path.join('data', subdir), exist_ok=True)
    # Always forward this call's shift_len (not a notebook-global) so each
    # directory really contains data for its own shift length.
    create_test_dataset(shift_len=shift_len, fname=os.path.join(subdir, 'test_dataset.npz'))

    create_train_dataset(func_indices=[0], shift_len=shift_len,
                         fname=os.path.join(subdir, 'diag_dataset.npz'))
    create_train_dataset(func_indices=[0, 2], shift_len=shift_len,
                         fname=os.path.join(subdir, 'diag_vert_dataset.npz'))
    create_train_dataset(func_indices=[0, 1], shift_len=shift_len,
                         fname=os.path.join(subdir, 'diag_hori_dataset.npz'))

### 1. シフト1のデータセットを作成する

In [21]:
# Generate every dataset for shift length 1.
create_dataset(1)

### 2. シフト2のデータセットを作成する

In [22]:
# Generate every dataset for shift length 2.
create_dataset(2)

### 3. シフト3のデータセットを作成する

In [27]:
# Generate every dataset for shift length 3.
create_dataset(3)