In [None]:
import tensorflow as tf
from tensorflow import keras

## Solucionar el tema de los unstable gradients:

### Aplicar Layer Normalization a una SimpleRNNCell:

`SimpleRNNCell`: this class processes a single time step of the input sequence, whereas `tf.keras.layers.SimpleRNN` processes the whole sequence.

In [None]:
class LNSimpleRNNCell(keras.layers.Layer):
    """Simple RNN cell with layer normalization applied before the activation.

    Each step runs: SimpleRNNCell (no activation) -> LayerNormalization ->
    activation. The activated, normalized tensor is returned both as the
    step output and as the single new hidden state.

    Args:
        units: Number of units; used as both ``state_size`` and ``output_size``.
        activation: Activation identifier resolved with
            ``keras.activations.get`` (default ``"tanh"``).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, activation="tanh", **kwargs):
        super().__init__(**kwargs)
        # In a simple RNN cell both the state size and the output size are
        # the number of units.
        self.state_size = units
        self.output_size = units
        # No activation in the inner cell: layer normalization has to be
        # applied in between the linear step and the activation.
        self.simple_rnn_cell = keras.layers.SimpleRNNCell(units, activation=None)
        self.layer_norm = keras.layers.LayerNormalization()
        self.activation = keras.activations.get(activation)

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return the all-zeros initial hidden state as a one-element list.

        When ``inputs`` is given, the batch size and dtype are taken from it
        and override the ``batch_size`` / ``dtype`` arguments.
        """
        if inputs is not None:
            batch_size = tf.shape(inputs)[0]  # inputs is (batch, ...)
            dtype = inputs.dtype
        if dtype is None:
            # Callers may pass only `batch_size`; tf.zeros needs a concrete
            # dtype, so fall back to this layer's compute dtype.
            dtype = self.compute_dtype
        return [tf.zeros([batch_size, self.state_size], dtype=dtype)]

    def call(self, inputs, states):
        """Run one time step and return ``(output, [new_hidden_state])``."""
        # Order: RNN cell -> layer normalization -> activation.
        # The inner cell's own state output is discarded: the normalized,
        # activated tensor is what gets carried over as the hidden state.
        outputs, _ = self.simple_rnn_cell(inputs, states)
        norm_outputs = self.activation(self.layer_norm(outputs))
        return norm_outputs, [norm_outputs]  # output, hidden state

    def get_config(self):
        """Return the constructor arguments so the cell can be re-created."""
        config = super().get_config()
        config.update(
            units=self.state_size,
            activation=keras.activations.serialize(self.activation),
        )
        return config
        

In [None]:
# Sequence-to-sequence model assembled layer by layer from the custom cells.
model = keras.models.Sequential()
# Wrap each LNSimpleRNNCell in a generic Keras RNN layer.
# input_shape: None = any number of time steps, 1 = one feature (univariate).
model.add(
    keras.layers.RNN(
        LNSimpleRNNCell(20),
        return_sequences=True,
        input_shape=[None, 1],
    )
)
# return_sequences=True on every recurrent layer because the model is
# sequence-to-sequence.
model.add(keras.layers.RNN(LNSimpleRNNCell(20), return_sequences=True))
# TimeDistributed applies the Dense(10) layer at every time step.
model.add(keras.layers.TimeDistributed(keras.layers.Dense(10)))

In [None]:
# para la evaluacion queremos tener en cuenta solo el mse del ultimo time_step
def last_time_step_mse(Y_true, Y_pred):
    """Mean squared error evaluated on the last time step only.

    Both arguments are indexed with ``[:, -1]``: every entry along the first
    axis, the final entry along the second axis, and any remaining axes kept
    whole. The layout is assumed to be (batch, time_steps, features).
    """
    final_targets = Y_true[:, -1]
    final_predictions = Y_pred[:, -1]
    return keras.metrics.mean_squared_error(final_targets, final_predictions)

In [None]:
# Train on the MSE over all time steps; additionally report the MSE of the
# last time step only, which is the figure used for evaluation.
model.compile(
    optimizer="adam",
    loss="mse",
    metrics=[last_time_step_mse],
)

history = model.fit(
    X_train,
    Y_train,
    validation_data=(X_valid, Y_valid),
    epochs=20,
)

### Dropout (inputs) y Recurrent Dropout (hidden states):

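Esto ya lo implementa Keras en todas las capas recurrentes (excepto la capa genérica `keras.layers.RNN`) y en todas las RNNCells, mediante los argumentos `dropout` (para los inputs) y `recurrent_dropout` (para los hidden states).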