In [None]:
from random import randint

from numpy import argmax
from numpy import array
from numpy import eye
from keras.utils import to_categorical
from keras.models import Sequential, Model
from keras.layers import Dense, Conv1D, MaxPooling1D, GlobalMaxPooling1D, SeparableConv1D
from keras.layers import Flatten, Dropout, Input, LSTM, BatchNormalization, Activation, TimeDistributed
from keras.layers.embeddings import Embedding

from utils.text_sum_models import Seq2Seq

In [None]:
# generate target given source sequence
def predict_sequence(infenc, infdec, source, n_steps, cardinality):
    """Greedily decode ``n_steps`` output vectors for ``source``.

    Args:
        infenc: inference encoder; ``infenc.predict(source)`` must return the
            initial decoder states as a list (e.g. ``[h, c]``).
        infdec: inference decoder; ``infdec.predict([x, h, c])`` must return
            ``(probs, h, c)``.
        source: encoder input passed straight to ``infenc.predict`` (presumably
            a batch of one sample, since only row 0 of each prediction is kept).
        n_steps: number of decoding steps.
        cardinality: width of each decoder input/output vector.

    Returns:
        numpy array with one row per step, each row being ``probs[0, 0, :]``.
        The raw probability vector of one step is fed back as the next step's
        input (no argmax / re-one-hotting).
    """
    states = infenc.predict(source)
    # Start-of-sequence input: an all-zero vector shaped (1, 1, cardinality).
    decoder_input = array([0.0] * cardinality).reshape(1, 1, cardinality)
    step_outputs = []
    for _step in range(n_steps):
        probs, hidden, cell = infdec.predict([decoder_input] + states)
        step_outputs.append(probs[0, 0, :])
        # Carry the recurrent state forward and feed the prediction back in.
        states, decoder_input = [hidden, cell], probs
    return array(step_outputs)

In [None]:
# generate a sequence of random integers
def generate_sequence(length, n_unique):
    """Return a list of ``length`` random ints drawn uniformly from 1..n_unique-1.

    The value 0 is never drawn. ``get_dataset`` uses 0 as the decoder's
    start token (``[0] + ...``), so it must not collide with real symbols.
    Raises ``ValueError`` (from ``randint``) when ``n_unique < 2`` and ``length > 0``.
    """
    highest = n_unique - 1
    sequence = []
    for _ in range(length):
        sequence.append(randint(1, highest))
    return sequence

In [None]:
# prepare data for the LSTM
def get_dataset(n_in, n_out, cardinality, n_samples):
    """Build a one-hot "reverse the prefix" dataset for the encoder-decoder.

    Each sample draws a random source sequence of ``n_in`` ints with
    ``generate_sequence``. The target is the first ``n_out`` source values,
    reversed. The decoder input (teacher forcing) is that target shifted right
    by one step, with 0 as the start token.

    Args:
        n_in: source sequence length.
        n_out: target length; slicing semantics apply, so it is capped at ``n_in``.
        cardinality: number of symbols, i.e. the one-hot width.
        n_samples: number of samples to generate. 0 yields empty arrays.

    Returns:
        ``(X1, X2, y)`` float32 arrays with these shapes, where ``T`` is
        ``len(source[:n_out])``:
        - ``X1``: ``(n_samples, n_in, cardinality)``, the encoder input.
        - ``X2``: ``(n_samples, T, cardinality)``, the shifted decoder input.
        - ``y``: ``(n_samples, T, cardinality)``, the target.
    """
    sources, decoder_inputs, targets = [], [], []
    for _ in range(n_samples):
        source = generate_sequence(n_in, cardinality)
        target = source[:n_out][::-1]
        sources.append(source)
        # ([0] + target)[:-1] equals [0] + target[:-1] for non-empty targets,
        # and stays length-aligned (empty) when the target is empty.
        decoder_inputs.append(([0] + target)[:-1])
        targets.append(target)

    target_len = len(range(n_in)[:n_out])  # same length rule as source[:n_out]
    one_hot = eye(cardinality, dtype='float32')

    def encode(seqs, length):
        # Row-index the identity matrix: ints (n, length) -> (n, length, cardinality).
        # The explicit reshape keeps the trailing dims when n_samples == 0
        # (the old per-sample stacking crashed with IndexError in that case).
        return one_hot[array(seqs, dtype=int).reshape(len(seqs), length)]

    X1 = encode(sources, n_in)
    X2 = encode(decoder_inputs, target_len)
    y = encode(targets, target_len)
    return X1, X2, y

In [None]:
# decode a one hot encoded string
def one_hot_decode(encoded_seq):
    """Map each one-hot (or probability) vector in ``encoded_seq`` to its argmax index.

    Args:
        encoded_seq: iterable of 1-D vectors, e.g. one sample of shape
            ``(steps, cardinality)``.

    Returns:
        list of plain Python ``int`` indices, one per vector. Converting from
        NumPy scalars keeps printed output readable (``[12, 7, 3]``). Under
        NumPy >= 2, NumPy scalars inside a list print as ``np.int64(12)``.
    """
    return [int(argmax(vector)) for vector in encoded_seq]

In [None]:
# configure problem: vocabulary size and encoder/decoder sequence lengths.
# 50 real symbols plus index 0, which generate_sequence never draws and
# get_dataset uses as the decoder start token.
n_features, n_steps_in, n_steps_out = 50 + 1, 6, 3

In [None]:
# generate a single source and target sequence, then show shapes and decoded indices
X1, X2, y = get_dataset(n_steps_in, n_steps_out, n_features, 1)
print(*(arr.shape for arr in (X1, X2, y)))
print('X1=%s, X2=%s, y=%s' % tuple(one_hot_decode(arr[0]) for arr in (X1, X2, y)))

In [None]:
# generate training dataset (100000 random samples)
X1, X2, y = get_dataset(n_in=n_steps_in, n_out=n_steps_out,
                        cardinality=n_features, n_samples=100000)
print(*(arr.shape for arr in (X1, X2, y)))

In [None]:
# generate val dataset (100 random samples, drawn independently of the training set)
X1_val, X2_val, y_val = get_dataset(n_in=n_steps_in, n_out=n_steps_out,
                                    cardinality=n_features, n_samples=100)
print(*(arr.shape for arr in (X1_val, X2_val, y_val)))

In [None]:
def define_models(n_input, n_output, n_units):
    """Build an LSTM encoder-decoder plus its two step-wise inference halves.

    Args:
        n_input: width of each encoder input vector.
        n_output: width of each decoder input vector and of the softmax output.
        n_units: LSTM units in both the encoder and the decoder.

    Returns:
        ``(model, encoder_model, decoder_model)``:
        - ``model`` takes ``[encoder_input, decoder_input]`` and returns
          per-step softmax over ``n_output`` classes. Use it for training.
        - ``encoder_model`` maps an encoder input to its final ``[h, c]``.
        - ``decoder_model`` takes ``[decoder_input, h, c]`` and returns
          ``[softmax, h, c]``.

    All three models reuse the same LSTM and Dense layer objects, so training
    ``model`` updates the weights the inference models use.
    """
    # --- training graph ---------------------------------------------------
    enc_in = Input(shape=(None, n_input))
    enc_lstm = LSTM(n_units, return_state=True)
    _enc_seq, enc_h, enc_c = enc_lstm(enc_in)  # only the final states are used
    enc_states = [enc_h, enc_c]

    dec_in = Input(shape=(None, n_output))
    dec_lstm = LSTM(n_units, return_sequences=True, return_state=True)
    dec_seq, _, _ = dec_lstm(dec_in, initial_state=enc_states)
    dec_dense = Dense(n_output, activation='softmax')
    train_model = Model([enc_in, dec_in], dec_dense(dec_seq))

    # --- inference graphs (same layer instances, new entry points) ---------
    enc_model = Model(enc_in, enc_states)

    h_in = Input(shape=(n_units,))
    c_in = Input(shape=(n_units,))
    step_states_in = [h_in, c_in]
    step_seq, step_h, step_c = dec_lstm(dec_in, initial_state=step_states_in)
    dec_model = Model([dec_in] + step_states_in,
                      [dec_dense(step_seq)] + [step_h, step_c])

    return train_model, enc_model, dec_model

In [None]:
# define model: the trainable graph plus the inference encoder/decoder sharing its layers
train, infenc, infdec = define_models(n_input=n_features, n_output=n_features, n_units=128)
train.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

In [None]:
# train model; the held-out X1_val / X2_val / y_val set generated earlier was
# otherwise unused, so pass it to report val_loss / val_acc after each epoch.
train.fit([X1, X2], y, epochs=1, validation_data=([X1_val, X2_val], y_val))

In [None]:
# Spot-check greedy decoding on 10 fresh random samples.
# Dedicated *_test names: reassigning X1 / X2 / y here would overwrite the
# training arrays, so re-running the fit cell afterwards would silently train
# on a single sample. The loop variable also avoids `_`, which IPython uses for
# the last cell output.
for trial in range(10):
    X1_test, X2_test, y_test = get_dataset(n_steps_in, n_steps_out, n_features, 1)
    target = predict_sequence(infenc, infdec, X1_test, n_steps_out, n_features)
    print('X=%s y=%s, yhat=%s' % (one_hot_decode(X1_test[0]), one_hot_decode(y_test[0]), one_hot_decode(target)))

In [None]:
# input_dim = n_features
# output_dim = n_features
# hidden_dim = 51
# input_seq_length = n_steps_in
# output_seq_length = n_steps_out
# epochs = 1
# optimizer = 'adam'
# batch_size = 128

# new_experiment = Seq2Seq(epochs=epochs,
#                                metrics=['accuracy'],
#                                optimizer=optimizer,
#                                batch_size=batch_size, 
#                                input_dim=input_dim,
#                                output_dim=output_dim,
#                                hidden_dim=hidden_dim,
#                                input_seq_length=input_seq_length,
#                                output_seq_length=output_seq_length,
#                                verbose=True)
# new_experiment.build_model()
# new_experiment.model.summary()

In [None]:
# new_experiment.run_experiment(X1, X2, y, X1_val, X2_val, y_val)

In [None]:
# for _ in range(10):
#     X1, X2, y = get_dataset(n_steps_in, n_steps_out, n_features, 1)
#     target = predict_sequence(new_experiment.encoder_model, new_experiment.decoder_model, X1, n_steps_out, n_features)
#     print('X=%s y=%s, yhat=%s' % (one_hot_decode(X1[0]), one_hot_decode(y[0]), one_hot_decode(target)))