In [10]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import LSTM, Dense, Input, RNN
from tensorflow.keras.datasets import mnist
import numpy as np

from rnn_layers import PhasedLSTM

In [2]:
def transform_xy(X, y):
    img_size = X.shape
    time_steps = img_size[1]*img_size[2]
    x = tf.cast(tf.reshape(X, [-1, time_steps, 1]), dtype=tf.int32)    
    t = tf.tile([tf.range(time_steps)], [img_size[0], 1])
    t = tf.expand_dims(t, 2)
    X = tf.concat([x, t], axis=2)
    uniques, _ = tf.unique(y)
    n_classes = uniques.shape[0]
    y_one_hot = tf.one_hot(y, n_classes)
    return X, y_one_hot

In [5]:
(X, y), (X_test, y_test) = mnist.load_data()
X_train, y_train = X[:-10000], y[:-10000]
X_val, y_val = X[10000:], y[10000:]

# Transforming input and labels
X_train, y_train = transform_xy(X_train, y_train)
X_val, y_val     = transform_xy(X_val, y_val)
X_test, y_test   = transform_xy(X_test, y_test)

In [13]:
model = tf.keras.Sequential()
model.add(Input(shape=(None, 2), name="inputs"))
model.add(RNN(PhasedLSTM(128)))
model.add(Dense(10))
model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
rnn_3 (RNN)                  (None, 128)               66944     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 68,234
Trainable params: 68,234
Non-trainable params: 0
_________________________________________________________________


In [14]:
model.compile(optimizer="adam", 
              loss=keras.losses.CategoricalCrossentropy(),
              metrics=[keras.metrics.CategoricalAccuracy()])

In [15]:
history = model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=2,
    validation_data=(X_val, y_val),
)

Epoch 1/2
  8/391 [..............................] - ETA: 17:52 - loss: 6.1926 - categorical_accuracy: 0.1035

KeyboardInterrupt: 