In [1]:
from tensorflow.keras.datasets import mnist

In [14]:
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import matplotlib.pyplot as plt
%matplotlib inline
import math

In [28]:
### input: array
### output: normalized array [0, 1]
def normalize_arr(arr):
    """Min-max scale ``arr`` so its values span [0, 1].

    The minimum and maximum are taken over the whole array (all axes), so
    every element is shifted and scaled by the same two numbers.

    Parameters
    ----------
    arr : array_like
        Numeric input.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``arr``. Floating (and complex) inputs
        keep their dtype; any other dtype is computed in float64. A constant
        array (max == min) maps to all zeros instead of NaN.
    """
    arr = np.asarray(arr)
    if not np.issubdtype(arr.dtype, np.inexact):
        # Do the arithmetic in float64 so integer subtraction cannot wrap
        # around (e.g. int8: 127 - (-128) overflows) and bool arrays, which
        # do not support '-', still work.
        arr = arr.astype(np.float64)

    lo = np.min(arr)
    hi = np.max(arr)
    span = hi - lo
    if span == 0:
        # Constant input: avoid 0/0 (NaN + RuntimeWarning). Every element is
        # already at the minimum, so (arr - lo) is all zeros.
        span = 1
    return (arr - lo) / span

### input: 2d images with shape=(n_images, height, width)
### output: array of flattened patches with shape=(n_images, n_patches, patch_size*patch_size)
def extract_patches(images, patch_size, pad_width):
    """Extract every ``patch_size x patch_size`` window from a stack of images.

    Parameters
    ----------
    images : array_like
        Array of shape (n_images, height, width).
    patch_size : int
        Side length of the square window; must be >= 1.
    pad_width : int
        Number of zero pixels added on each side of the height and width
        axes before windowing; must be >= 0. With an odd ``patch_size``,
        ``pad_width = patch_size // 2`` yields one patch per original pixel.

    Returns
    -------
    np.ndarray
        Shape (n_images, n_patches, patch_size * patch_size), where
        n_patches = (height + 2*pad_width - patch_size + 1)
                  * (width  + 2*pad_width - patch_size + 1).
        Window positions are ordered row-major, and each patch is flattened
        row-major. When NumPy can reshape without copying (e.g.
        patch_size == 1) the result may be a read-only view that shares
        memory with the (padded) input.

    Raises
    ------
    ValueError
        If ``images`` is not 3-D, ``patch_size < 1`` or ``pad_width < 0``.
    """
    images = np.asarray(images)
    if images.ndim != 3:
        raise ValueError(
            f"images must have shape (n_images, height, width), got {images.shape}")
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    if pad_width < 0:
        raise ValueError(f"pad_width must be >= 0, got {pad_width}")

    # Zero-pad only the two spatial axes; skip the (copying) np.pad call
    # entirely when no padding is requested.
    if pad_width > 0:
        images = np.pad(images,
                        pad_width=((0, 0), (pad_width, pad_width), (pad_width, pad_width)),
                        mode='constant', constant_values=0)

    # View of shape (n_images, n_rows, n_cols, patch_size, patch_size)
    patches = sliding_window_view(images,
                                  window_shape=(patch_size, patch_size),
                                  axis=(1, 2))

    # Merge the window-position axes and flatten each window into a vector.
    n_images, n_rows, n_cols = patches.shape[:3]
    return patches.reshape(n_images, n_rows * n_cols, patch_size * patch_size)

In [36]:
### load mnist dataset and data preprocessing
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# NOTE(review): x_test is not normalized / patched in this cell — confirm it
# is processed consistently (same normalization) wherever it is used later.

print('original   shape:', x_train.shape)
print('           dtype:', x_train.dtype)

# change to np.float32, then min-max normalize to [0, 1]
x_train = normalize_arr(x_train.astype(np.float32))

print('normalized shape:', x_train.shape)
print('           dtype:', x_train.dtype)

# extract patches from x_train; pad_width = patch_size // 2 (odd patch_size)
# keeps one patch per pixel, i.e. 28 * 28 = 784 patches per MNIST image
patch_size = 3
pad_width = patch_size // 2

x_train_patches = extract_patches(x_train, patch_size, pad_width)
print('patches    shape:', x_train_patches.shape)
print('           dtype:', x_train_patches.dtype)

# delete x_train
# Only copy the patches if they might still alias x_train's buffer (a view
# would keep x_train alive after `del`). Checking `.base` is not enough: a
# copying reshape returns a view of its own fresh buffer, so `.base` is set
# even when nothing is shared. Copying then would just double peak memory
# (~1.7 GB for 60000 x 784 x 9 float32).
if np.may_share_memory(x_train_patches, x_train):
    x_train_patches = x_train_patches.copy()
del x_train

original   shape: (60000, 28, 28)
           dtype: uint8
normalized shape: (60000, 28, 28)
           dtype: float32
patches    shape: (60000, 784, 9)
           dtype: float32
