Predicting a time series with a recurrent neural network (RNN)

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
import matplotlib.pyplot as plt

###### Reading data

In [None]:
# Load the raw CSV (latin-1 encoded); pandas infers the header row.
df = pd.read_csv(
    '/content/data_boston.csv',
    encoding='latin1',
    header='infer',
)

###### Preprocessing

In [None]:
# Keep only the target column, still as a one-column DataFrame.
df = df.loc[:, ['PRICE']]

# Min-max scale the series. The fitted scaler is kept so that predictions can
# be mapped back to the original units later with inverse_transform.
scaler = MinMaxScaler()
ts_scaled = scaler.fit_transform(df)

# Add a leading batch axis: (time_steps, 1) -> (1, time_steps, 1), matching the
# (batch, time, features) layout of the model's Input(shape=(None, 1)).
ts_scaled_2 = ts_scaled[np.newaxis, :, :]

###### Training parameters

In [None]:
# Training hyper-parameters: samples per gradient step, number of passes over
# the data, and the Adam learning rate.
batch_size, n_epochs, learn_rate = 1, 1000, 1e-4

###### Model definition

In [None]:
# Sequential model whose input is a variable-length sequence of scalars:
# shape (batch, time_steps=None, features=1).
model = tf.keras.Sequential([tf.keras.Input(shape=(None, 1))])

In [None]:
# Model layers: a 100-unit SimpleRNN that returns its output at every time
# step, followed by a per-time-step linear Dense projection to one value.
# The input shape is already fixed by the tf.keras.Input layer added when the
# model was created, so `input_shape` is not repeated here (Keras 3 emits a
# warning when `input_shape` is passed to a layer in a Sequential model).
model.add(tf.keras.layers.SimpleRNN(units=100, return_sequences=True))
model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(units=1, activation="linear")))

In [None]:
# Print a summary of the model's layers.
model.summary()

###### Optimizer and model compilation

In [None]:
# Adam optimizer using the learning rate from the training parameters.
optimizer = tf.keras.optimizers.Adam(learning_rate=learn_rate)
model.compile(loss="mse", optimizer=optimizer, metrics=["mse"])

# Train the model. Without this step `summary` (whose 'mse' history is plotted
# in the next cell) would be undefined and the forecast would use untrained
# weights.
# One-step-ahead targets: the input is the series without its last value and
# the target is the same series shifted forward by one, so the output at time
# step t is trained to match the value at t + 1. This matches how the forecast
# loop uses the last output as the next value.
X_train = ts_scaled_2[:, :-1, :]
y_train = ts_scaled_2[:, 1:, :]
summary = model.fit(
    X_train,
    y_train,
    batch_size=batch_size,
    epochs=n_epochs,
    verbose=0,  # per-epoch logging for 1000 epochs would flood the output
)

###### Visualize training history

In [None]:
# Plot the training MSE recorded for each epoch.
# (Spanish labels: "Training history", "Epoch", "Training".)
mse_per_epoch = summary.history['mse']
plt.plot(mse_per_epoch, c="b")
plt.title('Historial de Entrenamiento')
plt.xlabel('Época')
plt.ylabel('MSE')
plt.legend(['Entrenamiento'], loc='upper right')
plt.show()

###### Seed window and autoregressive forecast

In [None]:
# Autoregressive forecast: start from (up to) the first n_ts_seed scaled values
# and repeatedly append the model's prediction for the next step.
n_ts_seed = 510
n_predict_time_steps = 20

ts_seed = ts_scaled[:n_ts_seed]
for i in range(n_predict_time_steps):
    # The whole sequence so far is fed as one batch: (1, time_steps, 1).
    X = ts_seed.reshape(1, -1, 1)
    y_pred = model.predict(X, verbose=0)
    # Only the output at the final time step is taken as the next value.
    y_last = y_pred[0, -1, 0]
    ts_seed = np.vstack([ts_seed, [[y_last]]])

###### Visualize prediction

In [None]:
# Map the seed + forecast back to the original PRICE units and overlay it on
# the observed data, zoomed in on the tail of the series and the forecast.
# (Spanish labels: "Datos" = data, "Ajuste" = fit.)
ts = scaler.inverse_transform(ts_seed)

plt.plot(df['PRICE'], c='b', linewidth=2, linestyle="-", label="Datos")  # observed
plt.plot(ts, c='r', linewidth=2, linestyle="--", label="Ajuste")  # seed + forecast

x_end = n_ts_seed + n_predict_time_steps + 10
plt.xlim(350, x_end)
plt.legend()
plt.tight_layout()
plt.savefig("out.png")
plt.show()