In [None]:
import os
import pickle

import keras
import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.layers import LSTM, Dense
from keras.models import Sequential
from sklearn.model_selection import StratifiedKFold

import training_data_generator

# Weights file path given to the ModelCheckpoint callbacks and read by
# load_pretrained_weights().
weights_save_name = "weights_lstm_sequences.hdf5"

# Most recently built Keras model; None until one of the
# generate_and_train_* functions assigns it.
model = None
# Flag intended to mark that a model has been built.
model_generated = False
# Placeholder for cross-validation scores; nothing in this file appends to it.
cvscores = []

# Save the dataset to, and load it from, a pickle file
def save_dataset(filename=None, data=None):
    """Pickle a dataset to disk.

    Args:
        filename: Destination path. Defaults to the module-level
            ``dataset_filename``, which must then be defined elsewhere.
        data: Object to pickle. Defaults to the module-level ``dataset``.
    """
    # Fall back to the globals so existing zero-argument calls keep working.
    if filename is None:
        filename = dataset_filename
    if data is None:
        data = dataset
    with open(filename, "wb") as fp:
        pickle.dump(data, fp)
#
def load_dataset(filename=None):
    """Unpickle a dataset, store it in the module-level ``dataset``, and return it.

    Args:
        filename: Source path. Defaults to the module-level
            ``dataset_filename``, which must then be defined elsewhere.

    Returns:
        The unpickled object (also assigned to the global ``dataset``).
    """
    global dataset
    if filename is None:
        filename = dataset_filename
    # SECURITY: pickle.load can execute arbitrary code; only load files
    # this project wrote itself.
    with open(filename, "rb") as fp:
        dataset = pickle.load(fp)
    return dataset

# Load checkpointed weights into the current model if the weights file exists
def load_pretrained_weights():
    """Load ``weights_save_name`` into the global ``model`` when possible.

    Returns:
        True if weights were loaded; False if no model has been built yet
        or the weights file does not exist.

    Errors raised while reading an existing file (for example a corrupt file
    or an architecture mismatch) propagate. Previously a bare ``except:``
    reported them as "weights do not exist".
    """
    if model is None:
        print('No model has been built yet. Build the model before loading weights')
        return False
    if not os.path.isfile(weights_save_name):
        print('Pre-trained weights do not exist. Please train model to obtain weights')
        return False
    model.load_weights(weights_save_name)
    return True
        
def _append_channel_axis(rows):
    """Stack ``rows`` into an array and append a trailing axis of size 1.

    Equivalent to ``arr.reshape(*arr.shape, 1)``, except that an empty input
    yields shape ``(0, 2, 1)`` so ``shape[1:]`` is still the per-sample shape.
    """
    arr = np.asarray(rows)
    if arr.size == 0:
        return arr.reshape(0, 2, 1)
    return arr[..., np.newaxis]


def generateSequenceData(pattern_count):
    """Generate start/end coordinate arrays for LSTM training.

    Uses ``training_data_generator`` to produce nested
    batch -> pattern -> sequence data. Every sequence becomes one sample:
    elements 0-1 go to ``start_xy`` and elements 2-3 go to ``end_xy``.

    Args:
        pattern_count: Passed straight to
            ``training_data_generator.training_data_generator``.

    Returns:
        (start_xy, end_xy): arrays with a trailing axis of size 1 added.
        They have shape ``(n_sequences, 2, 1)`` when sequence elements are
        scalars, and ``(0, 2, 1)`` when no sequences were produced.
    """
    tdg = training_data_generator.training_data_generator(pattern_count)
    trainingdata = tdg.produceTrainingSequences()
    starts = []
    ends = []
    # Flatten: the batch/pattern grouping is not preserved in the output.
    for batch in trainingdata:
        for pattern in batch:
            for sequence in pattern:
                starts.append([sequence[0], sequence[1]])
                ends.append([sequence[2], sequence[3]])
    return _append_channel_axis(starts), _append_channel_axis(ends)
    

def generate_and_train_model(start_xy, end_xy, load_weights = False):
    """Build an LSTM regressor, optionally warm-start it, and train it with MSE.

    The new model is assigned to the module-level ``model`` (and returned);
    ``model_generated`` is set to True once training finishes.

    Args:
        start_xy: Input array; ``start_xy.shape[1:]`` is the LSTM input shape.
        end_xy: Target array; ``end_xy.shape[1]`` sets the number of outputs.
        load_weights: If True, load ``weights_save_name`` into the fresh model
            before training, so training resumes from the saved weights.

    Returns:
        The trained Keras model.
    """
    global model, model_generated
    model = Sequential()
    model.add(LSTM(64, return_sequences=False, input_shape=start_xy.shape[1:]))
    # Output width follows the targets instead of a hard-coded 3, which did
    # not match the 2-wide targets generateSequenceData produces.
    # NOTE(review): relu clips negative outputs; confirm targets are non-negative.
    model.add(Dense(end_xy.shape[1], activation='relu'))
    model.compile(loss='mse', optimizer='adam')
    if load_weights:
        # Must happen before fit(); loading afterwards discarded the training.
        load_pretrained_weights()
    callbacks_list = [ModelCheckpoint(weights_save_name, monitor='loss', verbose=1, save_best_only=True, mode='auto', save_weights_only=True)]
    # Full-batch training; the checkpoint callback must be passed to fit()
    # or no weights file is ever written.
    model.fit(start_xy, end_xy, batch_size=len(start_xy), epochs=5000, verbose=2, callbacks=callbacks_list)
    model_generated = True
    return model


def generate_and_train_model_with_evaluation(start_xy, end_xy,val_start_xy, val_end_xy, test_start_xy, test_end_xy, load_weights = False, n_preview = 10):
    """Build and train an LSTM model with validation data, then evaluate it on a test set.

    Prints the test-set score, then the first ``n_preview`` test inputs with
    their true targets and with the model's predictions. The model is
    assigned to the module-level ``model``.

    Args:
        start_xy, end_xy: Training inputs/targets; ``start_xy.shape[1:]`` is
            the LSTM input shape and ``end_xy.shape[1]`` the output width.
        val_start_xy, val_end_xy: Validation data passed to ``fit``.
        test_start_xy, test_end_xy: Test data used for ``evaluate``/``predict``.
        load_weights: If True, load ``weights_save_name`` before training.
        n_preview: Maximum number of test samples to print (capped at the
            test-set size).

    Returns:
        (history, scores): the ``fit`` history and the ``evaluate`` result.
    """
    global model
    model = Sequential()
    model.add(LSTM(64, return_sequences=False, input_shape=start_xy.shape[1:]))
    model.add(Dense(end_xy.shape[1], activation='relu'))
    # NOTE(review): binary cross-entropy on coordinate targets with an
    # unbounded relu output looks mismatched (the MSE setup was commented
    # out). Confirm the targets are scaled to [0, 1].
    model.compile(optimizer=keras.optimizers.RMSprop(),loss=keras.losses.BinaryCrossentropy())
    if load_weights:
        load_pretrained_weights()
    callbacks_list = [ModelCheckpoint(weights_save_name, monitor='loss', verbose=1, save_best_only=True, mode='auto', save_weights_only=True)]
    history = model.fit(start_xy, end_xy, batch_size=len(start_xy), epochs=50, validation_data=(val_start_xy, val_end_xy), verbose=2, callbacks=callbacks_list)
    scores = model.evaluate(test_start_xy, test_end_xy, verbose=0)
    print("scores: ",scores)
    # Cap the preview so a test set smaller than n_preview cannot raise IndexError.
    preview = min(n_preview, len(test_start_xy))
    print("predicting for: ")
    for start, end in zip(test_start_xy[:preview], test_end_xy[:preview]):
        print("start: (",start[0],"|",start[1],") end: (",end[0],"|",end[1],")")
    predictions = model.predict(test_start_xy[:preview])
    print("predictions shape:", predictions.shape)
    print("predictions:")
    for start, pred in zip(test_start_xy[:preview], predictions):
        print("start: (",start[0],"|",start[1],") end: (",pred[0],"|",pred[1],")")
    return history, scores

# Build independent train / validation / test sets (500 / 50 / 50 patterns),
# then train the model and report its test-set results.
print("start generating sequences")
(start_xy, end_xy), (val_start_xy, val_end_xy), (test_start_xy, test_end_xy) = (
    generateSequenceData(pattern_count) for pattern_count in (500, 50, 50)
)
print("training and evaluation")
generate_and_train_model_with_evaluation(
    start_xy,
    end_xy,
    val_start_xy=val_start_xy,
    val_end_xy=val_end_xy,
    test_start_xy=test_start_xy,
    test_end_xy=test_end_xy,
)
