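This notebook augments the training data and saves the resulting samples and their labels to the files aug_train.pkl and aug_train_y.pkl. Despite the .pkl extension, both files are raw NumPy memmap arrays, not pickles, and must be read back with np.memmap. It requires a patched version of Keras to augment images with 7 channels.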

In [1]:
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [2]:
# Load data
# The file is opened in a `with` block so the handle is closed as soon as the
# dict has been unpickled.
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted file.
with open('a2_dataTrain.pkl', 'rb') as data_file:
    data = pickle.load(data_file)

# Print the name and shape of every array stored in the dict
for k in data.keys():
    print(k, data[k].shape)

depth (77421, 120, 90)
gestureLabels (77421,)
subjectLabels (77421,)
segmentation (77421, 120, 90, 3)
rgb (77421, 120, 90, 3)


In [3]:
# Stack the RGB (3), segmentation (3) and depth (1) arrays along the last axis
# into a single 7-channel array. Depth has no channel axis, so one is appended.
X = np.concatenate(
    [
        data['rgb'],
        data['segmentation'],
        data['depth'][..., np.newaxis],
    ],
    axis=-1,
)
# The gesture labels are the targets
Y = data['gestureLabels']
# The raw dict is not used again below; drop the name
del data

In [4]:
# Hold out 10% of the samples as a test set. The split is stratified on the
# labels and uses a fixed random_state for reproducibility.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.1,
    stratify=Y,
    random_state=1337,
)
# The unsplit arrays are not used again below; drop the names
del X, Y

print(len(X_train), 'train samples')
print(len(X_test), 'test samples')

69678 train samples
7743 test samples


In [5]:
# Augmentation settings, collected in one place and handed to the generator
# as keyword arguments.
# NOTE(review): units of each range follow the installed (patched) Keras
# version -- confirm against its ImageDataGenerator documentation.
augmentation_options = {
    'rotation_range': 3,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1,
    'shear_range': 0.1,
    'channel_shift_range': 20,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}
datagen = ImageDataGenerator(**augmentation_options)

# Total number of augmented samples to generate
num_aug = 300000

In [6]:
# Disk-backed uint8 arrays that will hold the augmented samples and labels.
# mode='w+' creates each file, or overwrites it if it already exists.
aug_sample_shape = (num_aug,) + X_train.shape[1:]
aug_label_shape = (num_aug,)

X_aug = np.memmap('aug_train.pkl', mode='w+', dtype='uint8', shape=aug_sample_shape)
Y_aug = np.memmap('aug_train_y.pkl', mode='w+', dtype='uint8', shape=aug_label_shape)

In [7]:
# Generate samples
# Fill the memmaps batch by batch until exactly num_aug samples are stored.
# `i` is the number of samples written so far and the next write offset.
i = 0
with tqdm(total=num_aug) as pbar:
    # The outer loop restarts the generator if it ever runs out before
    # num_aug samples have been written.
    while i < num_aug:
        for X_batch, Y_batch in datagen.flow(X_train, Y_train, batch_size=1024, shuffle=True):
            # A batch may be larger than the space that is left; keep only
            # the part that still fits.
            n_keep = min(X_batch.shape[0], num_aug - i)
            # NOTE(review): the cast to uint8 assumes the augmented values
            # lie in 0..255 (presumably floats, which are truncated) -- confirm.
            X_aug[i:i + n_keep] = X_batch[:n_keep].astype(np.uint8)
            Y_aug[i:i + n_keep] = Y_batch[:n_keep]
            # Count only the samples actually written, so neither `i` nor the
            # progress bar overshoots num_aug on a truncated batch.
            i += n_keep
            pbar.update(n_keep)

            if i >= num_aug:
                break

        # Write pending changes of both memmaps to disk
        X_aug.flush()
        Y_aug.flush()
        print('Flushed.')

1024it [00:11, 92.99it/s]              

Flushed.





In [None]:
# Visual check: display the first five augmented samples, one figure each.
# The slice 2::-1 picks channels 2, 1, 0 in that order.
# NOTE(review): presumably this converts BGR storage to RGB for display -- confirm.
for sample_idx in range(5):
    preview = X_aug[sample_idx, ..., 2::-1].astype(np.uint8)
    plt.imshow(preview)
    plt.show()