### Action空間が単射のもの

In [1]:
import pickle
import torch
from torchvision import datasets, transforms

# Seed for NumPy's global RNG and the number of MNIST training images to draw
# before augmentation.
SEED, SAMPLE_SIZE = 0, 1000

In [3]:
import numpy as np

def get_funcs(delta_x, delta_y):
    """Build a pair of cyclic-shift operators for arrays with at least 2 axes.

    Returns ``(func_shift, func_unshift)``:
      * ``func_shift(A)`` rolls ``A`` by ``delta_x`` along axis 0 and by
        ``delta_y`` along axis 1, wrapping values around the edges.
      * ``func_unshift(A)`` rolls by the negated offsets, so
        ``func_unshift(func_shift(A))`` reproduces ``A``.

    Note: ``delta_x`` acts on axis 0, which is the row axis of an image.
    """
    forward = (delta_x, delta_y)
    backward = (-delta_x, -delta_y)

    def func_shift(A):
        return np.roll(A, forward, axis=(0, 1))

    def func_unshift(A):
        return np.roll(A, backward, axis=(0, 1))

    return func_shift, func_unshift

def augument_data(source):
    """Expand every image in *source* into four variants.

    For each image ``c`` (iterating over axis 0 of *source*) the output
    contains, in this order:
      1. ``c`` unchanged,
      2. ``c`` rotated by 90 degrees times a random ``k`` in {1, 2, 3}
         (drawn from NumPy's global RNG via ``np.random.randint``),
      3. ``c`` mirrored left-right,
      4. ``c`` mirrored up-down.

    Args:
        source: NumPy array of images stacked along axis 0. Assumes square
            images (e.g. MNIST 28x28). Rotations by 90/270 degrees swap
            height and width, and mixed shapes cannot be stacked.

    Returns:
        Array of ``4 * len(source)`` images with the same dtype as *source*.
        An empty *source* yields shape ``(0, *source.shape[1:])``.
    """
    if len(source) == 0:
        # np.array([]) would collapse to shape (0,) and drop the image shape.
        return np.empty((0, *source.shape[1:]), dtype=source.dtype)

    data = []
    for c in source:
        data.extend([
            c,
            np.rot90(c, np.random.randint(1, 4)),
            np.fliplr(c),
            np.flipud(c),
        ])
    return np.array(data, dtype=source.dtype)

In [4]:
# Available shift actions: funcs[i] is a (shift, unshift) pair built from the
# i-th offset. Offsets are applied by np.roll over axes (0, 1), i.e. (rows, cols).
funcs = [get_funcs(*delta) for delta in [(2, 2), (2, 0), (0, 2)]]

np.random.seed(SEED)
data = datasets.MNIST(root='./data/', train=True, download=True)
# Draw SAMPLE_SIZE *distinct* training-image indices. np.random.choice samples
# with replacement by default, which would allow duplicated images in the subset.
# (This consumes the RNG differently, so the drawn subset differs from the
# with-replacement version under the same seed.)
idx = np.random.choice(data.data.shape[0], SAMPLE_SIZE, replace=False)

In [5]:
# Take the sampled MNIST digits as a NumPy array (fancy indexing makes a copy),
# then expand each one into its augmented variants.
imgs = data.data.numpy()[idx]
augumented_data = augument_data(imgs)

In [6]:
# One row of 3 action indices (into `funcs`) per augmented image; saved later.
# Only funcs[0] and funcs[2] are ever drawn. funcs[1] is excluded, presumably
# on purpose to keep the action set "injective" per the notebook title — confirm.
train_filter_set = np.random.choice([0, 2], (augumented_data.shape[0], 3))
train_data = []
for original, action_seq in zip(augumented_data, train_filter_set):
    # Apply the sampled shifts one after another to a copy of the image.
    deteriorated = original.copy()
    for action in action_seq:
        shift, _unshift = funcs[action]
        deteriorated = shift(deteriorated)
    train_data.append([original, deteriorated])

In [7]:
# Stack the list of [original, shifted] pairs into a single ndarray.
train_data = np.asarray(train_data)

In [8]:
# Sanity check: (number of augmented images, [original, shifted] pair, rows, cols).
train_data.shape

(4000, 2, 28, 28)

In [71]:
# Persist everything in one .npz archive:
#   train_dataset     - [original, shifted] image pairs
#   original_dataset  - augmented images before shifting
#   train_func_labels - per-sample indices into `funcs` that were applied
np.savez(
    file='data/shift_dataset.npz',
    train_dataset=train_data,
    original_dataset=augumented_data,
    train_func_labels=train_filter_set,
)