In [6]:
import numpy as np
import pandas as pd
import wandb
import tensorflow as tf
from plotutil import PlotCallback

wandb.init()
config = wandb.config
config.repeated_predictions = False
config.batch_size = 40
config.look_back = 4
config.epochs = 500


def load_data(data_type="airline"):
    """read a CSV into a dataframe"""
    if data_type == "flu":
        df = pd.read_csv('flusearches.csv')
        data = df.flu.astype('float32').values
    elif data_type == "airline":
        df = pd.read_csv('international-airline-passengers.csv')
        data = df.passengers.astype('float32').values
    elif data_type == "sin":
        df = pd.read_csv('sin.csv')
        data = df.sin.astype('float32').values
    return data


def create_dataset(dataset):
    """convert an array of values into a dataset matrix"""
    dataX, dataY = [], []
    for i in range(len(dataset)-config.look_back-1):
        a = dataset[i:(i+config.look_back)]
        dataX.append(a)
        dataY.append(dataset[i + config.look_back])
    return np.array(dataX), np.array(dataY)


data = load_data("airline")

E1029 18:32:57.931004 139851127641920 jupyter.py:104] Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
wandb: Wandb version 0.8.13 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [7]:
# normalize data to between 0 and 1
max_val = max(data)
min_val = min(data)
data = (data-min_val)/(max_val-min_val)

# split into train and test sets
split = int(len(data) * 0.70)
train = data[:split]
test = data[split-config.look_back-2:]

In [11]:
trainX, trainY = create_dataset(train)
testX, testY = create_dataset(test)

In [24]:
# Add channel dimension
trainX = trainX[:, :, np.newaxis]
testX = testX[:, :, np.newaxis]

In [27]:
# create and fit the RNN
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.SimpleRNN(5, input_shape=(config.look_back, 1))) # 5 here is the size of the input
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
model.compile(loss='mae', optimizer='rmsprop') #rmsprop 
model.fit(trainX, trainY, epochs=config.epochs, batch_size=config.batch_size, validation_data=(testX, testY),  callbacks=[
          PlotCallback(trainX, trainY, testX, testY,
                       config.look_back, config.repeated_predictions),
          wandb.keras.WandbCallback(input_type="time")])

Train on 95 samples, validate on 45 samples
Epoch 1/500


E1029 18:34:56.941313 139851127641920 jupyter.py:104] Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
wandb: Wandb version 0.8.13 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
W1029 18:34:57.504840 139851127641920 callbacks.py:257] Method (on_train_batch_end) is slow compared to the batch update (0.581494). Check your callbacks.




W1029 18:34:57.515868 139851127641920 callbacks.py:257] Method (on_train_batch_end) is slow compared to the batch update (0.290772). Check your callbacks.


Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78/500
Epoch 7

<tensorflow.python.keras.callbacks.History at 0x7f30d45316d8>