# Train a neural network for janken (rock-paper-scissors) with Keras

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras

## Create an input pipeline
### Load data

In [None]:
# Load the 'rock_paper_scissors' dataset from TFDS as separate train and test
# splits. `as_supervised=True` yields (image, label) tuples rather than
# feature dicts, which is the shape `recast` below expects.
splits = tfds.load(
    'rock_paper_scissors',
    as_supervised=True,
    split=['train', 'test'],
)
ds_train, ds_test = splits

### Get the training dataset spec

In [None]:
# Build a DatasetSpec describing the structure of the training dataset's
# elements, and display it as the cell output.
ds_train_spec = tf.data.DatasetSpec.from_value(
    ds_train
)
ds_train_spec

### Build the training pipeline

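1. Cast the `uint8` RGB values to `float32` and scale them into `[0, 1]`.
2. Cache the data before shuffling (recommended when it fits in memory), so the expensive map runs only once while each epoch still gets a fresh shuffle.
3. Shuffle.
4. Set up batches.
5. Prefetch, so the input pipeline prepares upcoming batches while the model trains on the current one (see the [tf.data performance guide](https://www.tensorflow.org/guide/data_performance#prefetching)).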

In [None]:
def recast(image, label):
    """Return ``(image, label)`` with the image cast to float32 and divided by 255.

    The label is passed through unchanged.
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 255., label


# Training pipeline: normalise -> cache -> shuffle -> batch -> prefetch.
# Caching precedes shuffling so that every epoch sees a new order.
ds_train = (
    ds_train
    .map(recast, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(buffer_size=1000)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

### Build the test pipeline

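Note: unlike the training pipeline, caching here is done *after* batching. The test set is never shuffled, so its batches are identical in every epoch and the already-batched result can be cached directly.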

In [None]:
# Test pipeline: normalise -> batch -> cache -> prefetch.
# There is no shuffle, so the batches are the same every epoch and can be cached.
ds_test = (
    ds_test
    .map(recast, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Create and train the model

In [None]:
# A small CNN with four Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages of
# increasing width, followed by a dense classifier. The model expects
# 300x300x3 inputs (see input_shape) and outputs a 3-way softmax, one
# probability per class.
model = keras.models.Sequential([
    # Feature extractor.
    keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(300, 300, 3)),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    # Classifier head.
    keras.layers.Flatten(),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(3, activation='softmax'),
])

In [None]:
# Print the model's per-layer summary (output shapes and parameter counts).
model.summary()

In [None]:
# The final layer is a softmax over 3 classes. The labels are presumably
# integer class ids (as_supervised=True), so the *sparse* categorical
# cross-entropy is used instead of one-hot targets. TODO confirm the label
# encoding.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Number of full passes over the training set. This is a training
# hyperparameter chosen by watching the validation metrics. It is not derived
# from the dataset size or batch size: that ratio is the number of batches
# *per* epoch, which Keras already handles on its own.
EPOCHS = 20

# Train on the batched training pipeline and report loss/accuracy on the test
# pipeline after every epoch. Keep the History object for later inspection
# (e.g. plotting learning curves).
history = model.fit(ds_train, epochs=EPOCHS, validation_data=ds_test)