In [None]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from nltk import word_tokenize
from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm_pandas

import keras.backend as K
from keras import initializers, regularizers
from keras.activations import tanh, softmax
from keras.layers import Input, Activation, Embedding, RNN, LSTM, LSTMCell, Dense, Dropout, Concatenate
from keras.layers import GRU, Layer
from keras.layers import TimeDistributed, Bidirectional, Lambda
from keras.layers import concatenate
from keras.layers.core import Reshape
from keras.layers.recurrent import Recurrent
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical

In [None]:
# Sanity check: list the GPUs visible to the Keras TensorFlow backend.
# NOTE(review): `_get_available_gpus` is a private backend helper (leading underscore);
# presumably an empty result means no GPU was detected - confirm for the installed Keras.
K.tensorflow_backend._get_available_gpus()

In [None]:
# Clear the Keras backend session before any layers are created below
# (presumably so that re-running the notebook does not pile up old graph state - confirm).
K.clear_session()

# import embedding

In [None]:
# dictionaries, pretrained embeddings
# NOTE(security): pickle.load can execute arbitrary code stored in the file - only load
# these artefacts when they come from a trusted source.
with open('data/glv_w2idx.pkl', 'rb') as f:
    w2idx = pickle.load(f)
with open('data/glv_embed_matrix.pkl', 'rb') as f:
    embedding = pickle.load(f)

# need to append BOS ('\t') and EOS ('\n') tokens to embeddings
# give (consistently) random initialization since they don't actually mean anything
# padding already exists as '' at the end of the embedding

# index of the padding entry; must be computed before the two markers are added
# (assumes w2idx has exactly one entry per embedding row - TODO confirm)
pad = len(w2idx) - 1

# Each marker gets the next free row index and a seeded random vector. The vector width
# is read from the matrix itself so the cell also works with embeddings that are not 300-d.
w2idx['\t'] = embedding.shape[0]
np.random.seed(1)
embedding = np.append(embedding, np.random.rand(1, embedding.shape[1]), axis=0)

w2idx['\n'] = embedding.shape[0]
np.random.seed(2)
embedding = np.append(embedding, np.random.rand(1, embedding.shape[1]), axis=0)

# dataset

In [None]:
# placeholder dataset

# One tuple per example: (full sentence, phrase to replace, contraction to substitute).
df = pd.DataFrame(
    [
        ("I do not know what to say.", "do not", "don't"),
        ("The girl will not go to bed.", "will not", "won't"),
        ("He will not come tomorrow night.", "will not", "won't"),
        ("They would not want you to do that.", "would not", "wouldn't"),
        ("We can not believe that this happened.", "can not", "can't"),
        ("I could not handle the truth.", "could not", "couldn't"),
    ],
    columns=['Sentence', 'Original', 'Replacement'],
)
df.head()

In [None]:
# preprocessing
# change from text to indices

def sent_to_word_idx(df):
    """Encode every row's 'Sentence' as a list of vocabulary indices.

    Each sentence is split with ``word_tokenize``, wrapped in the start-of-sequence
    (tab) and end-of-sequence (newline) markers, lower-cased and looked up in
    ``w2idx``. Tokens missing from the vocabulary fall back to the ``pad`` index.

    The result is written to a new 'x_word' column; the frame is modified in place
    and also returned.
    """
    encoded_rows = []
    for _, record in tqdm(df.iterrows(), total = df.shape[0]):
        tokens = ['\t'] + word_tokenize(record['Sentence']) + ['\n']
        encoded_rows.append([w2idx.get(token.lower(), pad) for token in tokens])
    df['x_word'] = encoded_rows
    return df

def orig_to_place_idx(df):
    """Locate each row's 'Original' phrase inside its encoded sentence.

    The phrase is tokenised and mapped to vocabulary indices the same way the
    sentence was (lower-cased, unknown words -> ``pad``), then searched for as a
    contiguous run inside ``row['x_word']``. The first occurrence wins.

    Adds two columns and returns the (in-place modified) frame:
        y_start: position of the first phrase token in 'x_word'
        y_end:   position of the last phrase token in 'x_word' (inclusive)

    Rows whose phrase is empty or does not occur in the sentence get NaN in both
    columns, so the columns always stay aligned with the frame. Inspect or drop
    those rows before converting the positions to one-hot targets.
    """
    y_start = []
    y_end = []
    for idx, row in tqdm(df.iterrows(), total = df.shape[0]):
        phrase = [w2idx.get(word.lower(), pad) for word in word_tokenize(row['Original'])]
        sentence = list(row['x_word'])
        span = len(phrase)
        # Only windows that fit completely inside the sentence are tried, and plain
        # list equality is used, so a truncated window near the end can never be
        # compared against a longer phrase.
        last_start = len(sentence) - span if span else -1
        for start in range(last_start + 1):
            if sentence[start : start + span] == phrase:
                y_start.append(start)
                y_end.append(start + span - 1)
                break
        else:
            # no match (or empty phrase): keep one entry per row
            y_start.append(np.nan)
            y_end.append(np.nan)
    df['y_start'] = y_start
    df['y_end'] = y_end
    return df

def repl_to_word_idx(df):
    """Encode the 'Replacement' and 'Original' phrases as index sequences.

    Both phrases are tokenised, wrapped in the start-of-sequence (tab) and
    end-of-sequence (newline) markers, lower-cased and looked up in ``w2idx``
    (unknown words -> ``pad``). Having both sequences available is what allows
    teacher forcing later on.

    Adds the columns 'y_orig' and 'y_rep'; the frame is modified in place and
    also returned.
    """
    def encode_column(column):
        # one pass (and one progress bar) over the frame for the given text column
        encoded = []
        for _, record in tqdm(df.iterrows(), total = df.shape[0]):
            tokens = ['\t'] + word_tokenize(record[column]) + ['\n']
            encoded.append([w2idx.get(token.lower(), pad) for token in tokens])
        return encoded

    # replacement first, original second - same processing order as before
    y_rep = encode_column('Replacement')
    y_orig = encode_column('Original')
    df['y_orig'] = y_orig
    df['y_rep'] = y_rep
    return df

In [None]:
# Run the three preprocessing passes. Order matters: orig_to_place_idx reads the
# 'x_word' column that sent_to_word_idx creates.
df = sent_to_word_idx(df)
df = orig_to_place_idx(df)
df = repl_to_word_idx(df)
df.head()

In [None]:
# extract data to arrays from df, add pre-padding

X = pad_sequences(df['x_word'], value = pad).astype('int64')
y_rep = pad_sequences(df['y_rep'], value = pad).astype('int64')
y_orig = pad_sequences(df['y_orig'], value = pad).astype('int64')

# y_start / y_end were measured on the unpadded token lists, but pad_sequences pads on
# the left by default, so every token of a shorter sentence moves right by the number
# of pad entries inserted in front of it. Shift the pointer targets by the same amount
# so that they still point at the intended tokens inside X.
pad_offset = X.shape[1] - np.array([len(seq) for seq in df['x_word']])

y_start = to_categorical(np.array(df['y_start']) + pad_offset, num_classes = X.shape[1], dtype = 'int64')
y_end = to_categorical(np.array(df['y_end']) + pad_offset, num_classes = X.shape[1], dtype = 'int64')

# set up target data from output sequence, 1 timestep off from y_rep:
# with teacher forcing the decoder reads y_rep[t] and has to predict y_rep[t + 1].
# Using y_rep itself as the target would let the decoder copy its own input.
# The final position has no successor and is filled with the pad index.
y_rep_target = np.full_like(y_rep, pad)
y_rep_target[:, :-1] = y_rep[:, 1:]
y_rep_cat = np.array([to_categorical(x, num_classes = embedding.shape[0]) for x in y_rep_target])

# training model

In [None]:
num_units = 256 # I think 512 or 1024 is standard - lessening for memory purposes for now
epochs = 50
batch_size = len(X)  # one batch holds every sample, i.e. full-batch training
learning_rate = 0.1  # NOTE(review): the compile call in this notebook passes optimizer='adam' as a string, so this value does not appear to be used - confirm

# padded sequence lengths, read from the second axis of each array
input_len = X.shape[1]
orig_len = y_orig.shape[1]
repl_len = y_rep.shape[1]

In [None]:
# input sentences in form of word indices
# one Input per padded index array: full sentence, replacement phrase, original phrase
main_input = Input(shape = (input_len,), dtype = 'int64', name = 'main_input')
repl_input = Input(shape = (repl_len,), dtype = 'int64', name = 'repl_input')
orig_input = Input(shape = (orig_len,), dtype = 'int64', name = 'orig_input')

# embedding layer
# note for later: can use mask_zero parameter in embedding layer, but would need to go back and change some indices

# the embedding layer and its three lookups are created under an explicit CPU device scope
with tf.device('/cpu:0'):
    # initialised from the pretrained matrix and frozen (trainable = False)
    embedding_layer = Embedding(input_dim = embedding.shape[0],
                          output_dim = embedding.shape[1],
                          weights = [embedding],
                          trainable = False, 
                          name = 'embedding_layer')

    # the same layer instance embeds all three inputs, so they share one weight matrix
    input_embed = embedding_layer(main_input)
    repl_embed = embedding_layer(repl_input)
    orig_embed = embedding_layer(orig_input)

In [None]:
# helper function taken from https://github.com/datalogue/keras-attention/issues/15
def _time_distributed_dense(x, w, b=None, dropout=None,
                        input_dim=None, output_dim=None,
                        timesteps=None, training=None):
    """Apply the same dense projection `y . w + b` to every timestep y of x.

    # Arguments
        x: input tensor, indexed here as (batch, timesteps, input_dim).
        w: weight matrix.
        b: optional bias vector.
        dropout: optional dropout rate; when strictly between 0 and 1 a single
            dropout mask is drawn and reused for every timestep.
        input_dim: integer; optional dimensionality of the input.
        output_dim: integer; optional dimensionality of the output.
        timesteps: integer; optional number of timesteps.
        training: training phase tensor or boolean.
    # Returns
        Output tensor with one projected vector per timestep.
    """
    # anything the caller did not provide is read from the tensors themselves
    input_dim = input_dim or K.shape(x)[2]
    timesteps = timesteps or K.shape(x)[1]
    output_dim = output_dim or K.shape(w)[1]

    wants_dropout = dropout is not None and 0. < dropout < 1.
    if wants_dropout:
        # draw one mask from the first timestep's shape and repeat it along time
        first_step = K.reshape(x[:, 0, :], (-1, input_dim))
        keep_mask = K.dropout(K.ones_like(first_step), dropout)
        tiled_mask = K.repeat(keep_mask, timesteps)
        x = K.in_train_phase(x * tiled_mask, x, training=training)

    # fold batch and time into one axis so a single matrix product covers all steps
    flat = K.reshape(x, (-1, input_dim))
    projected = K.dot(flat, w)
    if b is not None:
        projected = K.bias_add(projected, b)

    # unfold back into a 3D tensor
    if K.backend() != 'tensorflow':
        return K.reshape(projected, (-1, timesteps, output_dim))

    unfolded = K.reshape(projected, K.stack([-1, timesteps, output_dim]))
    unfolded.set_shape([None, None, output_dim])
    return unfolded

# pointer network implementation
class PointerNet(LSTM):
    """LSTM subclass that outputs a softmax distribution over the input positions.

    At every step the LSTM cell is advanced, its hidden state is scored against
    every position of the input sequence, and the scores are normalised with a
    softmax. `call` returns only the first value produced by `K.rnn`, so the layer
    yields one distribution per sample.

    Trainable weights added on top of the LSTM's own:
        W1: (num_units, 1), applied to the input sequence
        W2: (num_units, 1), applied to the repeated hidden state
        vt: (seq_len, 1), multiplied elementwise into the tanh scores
    """
    def __init__(self, units, *args, **kwargs):
        # nothing extra to configure: all arguments go to the LSTM constructor
        super().__init__(units, *args, **kwargs)
    
    def build(self, input_shape):
        """Create the scoring weights, then build the underlying LSTM.

        `input_shape` is indexed as (batch, seq_len, features); seq_len must be a
        concrete number because `vt` is sized by it.
        """
        # immediately set variables for later use
        # keep same number as units as encoder LSTM by default
        self.num_units = input_shape[2]
        self.seq_len = input_shape[1]
        
        # add trainable attention weights
        self.W1 = self.add_weight(name="W1",
                                  shape=(self.num_units, 1),
                                  initializer="uniform",
                                  trainable=True)
        self.W2 = self.add_weight(name="W2",
                                  shape=(self.num_units, 1),
                                  initializer="uniform",
                                  trainable=True)
        self.vt = self.add_weight(name="vt",
                                  shape=(self.seq_len, 1),
                                  initializer='uniform',
                                  trainable=True)
        
        super(PointerNet, self).build(input_shape)
    
    def call(self, x):
        """Run `step` over the sequence and return the resulting pointer distribution.

        The whole input `x` is also handed to `K.rnn` as a constant, so every step
        can score the current hidden state against all positions.
        """
        initial_state = self.get_initial_state(x)
                
        # only the first return value of K.rnn is kept
        # (the last step's output per the Keras backend API - confirm)
        pointer, _, _ = K.rnn(self.step, x, initial_state, 
                              constants = [x], input_length = self.seq_len)
        
        return pointer # only need 1 pointer for whole sequence, so h/c don't matter for this task
    
    def step(self, x_input, states):
        """Advance the LSTM cell by one step and score every input position.

        Returns the softmax over positions together with the new [h, c] states.
        """
        # x_input = original input at current time stamp (batch_size, num_units)
        # states = 3 tensors:
        # states[0] = h hidden state (batch_size, num_units)
        # states[1] = c cell state/memory (batch_size, num_units)
        # states[2] = x next word input (batch_size, seq_len, num_units)        
        encoded = states[2]
        _, [h, c] = self.cell.call(x_input, states[0:2])
        # copy the hidden state once per input position so it lines up with `encoded`
        decoded = K.repeat(h, self.seq_len)

        # vt*tanh(W1*e+W2*d)
        W1_eij = _time_distributed_dense(encoded, self.W1, output_dim=1)
        W2_dij = _time_distributed_dense(decoded, self.W2, output_dim=1)
        U = self.vt * tanh(W1_eij + W2_dij)
        U = K.squeeze(U, 2) # removes a 1-dimension at 2nd axis

        # softmax over U to get probability distribution over input length
        pointer = softmax(U)
        return pointer, [h, c]
    
    def compute_output_shape(self, input_shape):
        """Report the output shape as (batch_size, seq_len)."""
        # input shape should be (batch_size, seq_len, units)
        # output shape should be (batch_size, seq_len)
        return (input_shape[0], input_shape[1])
    


In [None]:
### get indices

# sequence encoder over the embedded sentence; return_sequences keeps one output per token
lstm = LSTM(return_sequences = True, units = num_units, name = 'lstm')(input_embed)

# two separate pointer heads on the same encoder output: one for the start of the
# span to replace, one for its end
# NOTE(review): activation="softmax" is forwarded to the LSTM constructor, while
# PointerNet.step already applies softmax to its scores - confirm the cell activation is intended.
# NOTE(review): input_shape includes batch_size here; Keras' input_shape convention
# leaves the batch axis out - confirm whether this argument is needed at all.
y_start_output = PointerNet(units = num_units, activation="softmax", input_shape = (batch_size, input_len, num_units), name = 'y_start_output')(lstm)
y_end_output = PointerNet(units = num_units, activation="softmax", input_shape = (batch_size, input_len, num_units), name = 'y_end_output')(lstm)

In [None]:
### connect indices to main_input

## later use this to slice and concatenate with context
## in the meantime just assume y_orig is somehow the output

#y_indices = concatenate([K.argmax(y_start_output, axis = 1), K.argmax(y_end_output, axis = 1)])
# turn each pointer distribution into a single position (index of the largest probability)
y_start_sparse = Lambda(lambda x : K.argmax(x, axis = 1))(y_start_output)
y_end_sparse = Lambda(lambda x : K.argmax(x, axis = 1))(y_end_output)
# give each position an explicit length-1 axis so the two can be concatenated
y_start_reshape = Reshape((1,))(y_start_sparse)
y_end_reshape = Reshape((1,))(y_end_sparse)
# one (start, end) pair per sample, cast to int32
y_indices = K.cast(concatenate([y_start_reshape, y_end_reshape]), 'int32')
#return_input = concatenate([y_start_reshape, y_end_reshape, main_input], axis = 1)
# NOTE(review): gather_nd presumably reads each (start, end) pair as coordinates into the
# first two axes of input_embed, i.e. (batch, time), rather than as a per-sample time range -
# confirm this is the intended slicing. return_input is not passed to the Model built in
# this notebook.
return_input = tf.gather_nd(input_embed, y_indices)

# https://stackoverflow.com/questions/50820639/keras-how-to-slice-tensor-using-information-from-another-tensor

In [None]:
### feed encoder input (main_input), decoder input (repl_input) and sliced replacement text to enc-dec system

# these should change later to some sort of context-based or conditional model
# also with attention

# decoder given 2*units to accept bidirectional outputs
encoder = Bidirectional(LSTM(return_state = True, units = num_units), name = "encoder")
decoder = LSTM(return_sequences = True, return_state = True, name = "decoder", units = 2 * num_units)

# sequence is unnecessary for the encoder - just states, to start the decoder correctly
# state and sequence for decoder will be necessary in inference, but not right now
# The encoder reads the embedded original phrase (orig_embed). A recurrent layer needs
# a (batch, timesteps, features) tensor; orig_input only carries the 2D word indices.
enc_output, enc_h_forward, enc_c_forward, enc_h_backward, enc_c_backward = encoder(orig_embed)
# join forward and backward states so they match the decoder's 2 * num_units state size
enc_h = Concatenate()([enc_h_forward, enc_h_backward])
enc_c = Concatenate()([enc_c_forward, enc_c_backward])
# teacher forcing: the decoder reads the embedded replacement and starts from the encoder states
dec_output, _, _ = decoder(repl_embed, initial_state = [enc_h, enc_c])

# Dropout?

# per-timestep softmax over the whole vocabulary (one class per embedding row)
y_rep_output = TimeDistributed(Dense(embedding.shape[0], activation='softmax'), 
                               name = 'y_rep_output')(dec_output)

In [None]:
#### define & train model

# other parameters
# https://keras.io/examples/lstm_seq2seq/

# three inputs (sentence, original phrase, replacement phrase) and three outputs
# (start pointer, end pointer, replacement tokens)
model = Model(inputs = [main_input, orig_input, repl_input], outputs = [y_start_output, y_end_output, y_rep_output])

# a single loss / metric specification is given here for all three outputs
model.compile(optimizer = 'adam',
             loss = 'categorical_crossentropy',
             metrics = ['accuracy'])

# inputs and targets are listed in the same order as in the Model definition above
history = model.fit([X, y_orig, y_rep], [y_start, y_end, y_rep_cat], epochs = epochs, batch_size = batch_size)

In [None]:
# print the layer-by-layer overview of the model defined above
model.summary()

In [None]:
# training accuracy per epoch, one curve for each of the three model outputs
for metric_key, curve_label in [('y_start_output_acc', 'y_start accuracy'),
                                ('y_end_output_acc', 'y_end accuracy'),
                                ('y_rep_output_acc', 'y_rep accuracy')]:
    plt.plot(history.history[metric_key], label=curve_label)
plt.title('Accuracy')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend()

In [None]:
# training loss per epoch: the combined loss plus one curve per model output
plt.plot(history.history['loss'], label='overall train loss')
plt.plot(history.history['y_start_output_loss'], label='y_start loss')
plt.plot(history.history['y_end_output_loss'], label='y_end loss')
# the replacement decoder contributes to the overall loss as well, so show it too
plt.plot(history.history['y_rep_output_loss'], label='y_rep loss')
plt.title('Loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()

# inference mode

# tbd

In [None]:

# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states

# NOTE(review): encoder_inputs, encoder_states, latent_dim, decoder_lstm, decoder_inputs,
# decoder_dense, input_token_index and target_token_index are not defined in any cell of
# this notebook. This cell looks like an unadapted paste from the Keras seq2seq example
# and will raise NameError on a fresh kernel until those names are provided.

# Define sampling models
encoder_model = Model(encoder_inputs, encoder_states)

# the decoder's previous states become explicit inputs so it can be run one step at a time
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
# returns the token distribution together with the updated states
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())


def decode_sequence(input_seq):
    """Greedily decode one input sequence into a string.

    The input is encoded into state vectors, then the decoder is run one step at
    a time starting from the start character, always feeding back the most likely
    token. Decoding stops when the stop character is produced or the output grows
    longer than ``max_decoder_seq_length``. Assumes a batch of size 1.
    """
    # encoder states seed the decoder
    states = encoder_model.predict(input_seq)

    # one-hot start character as the first decoder input
    step_input = np.zeros((1, 1, num_decoder_tokens))
    step_input[0, 0, target_token_index['\t']] = 1.

    decoded = ''
    while True:
        token_probs, h, c = decoder_model.predict([step_input] + states)

        # greedy choice: most probable token at the last position
        best_index = np.argmax(token_probs[0, -1, :])
        best_char = reverse_target_char_index[best_index]
        decoded += best_char

        # done on the stop character or once the length limit is exceeded
        if best_char == '\n' or len(decoded) > max_decoder_seq_length:
            return decoded

        # feed the chosen token and the new states into the next step
        step_input = np.zeros((1, 1, num_decoder_tokens))
        step_input[0, 0, best_index] = 1.
        states = [h, c]


# NOTE(review): encoder_input_data and input_texts are not defined in any cell of this
# notebook, and the placeholder dataset above has far fewer than 100 rows - confirm
# before running.
for seq_index in range(100):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    # the slice keeps the leading batch axis, giving a batch of one
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)

In [None]:
# https://towardsdatascience.com/light-on-math-ml-attention-with-keras-dc8dbc1fad39

class AttentionLayer(Layer):
    """
    Additive (Bahdanau-style) attention, see https://arxiv.org/pdf/1409.0473.pdf.

    The layer takes [encoder_output_sequence, decoder_output_sequence] and returns
    two tensors: the context vectors and the attention weights, one entry per
    decoder timestep. It introduces three trainable weights: W_a, U_a and V_a.
    """

    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create W_a, U_a and V_a.

        `input_shape` must be a list [encoder_shape, decoder_shape]; the weight
        sizes are taken from the last axis of each.
        """
        assert isinstance(input_shape, list)
        # Create a trainable weight variable for this layer.

        # W_a: (encoder hidden, encoder hidden), applied to the encoder outputs
        self.W_a = self.add_weight(name='W_a',
                                   shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])),
                                   initializer='uniform',
                                   trainable=True)
        # U_a: (decoder hidden, encoder hidden), applied to one decoder state
        self.U_a = self.add_weight(name='U_a',
                                   shape=tf.TensorShape((input_shape[1][2], input_shape[0][2])),
                                   initializer='uniform',
                                   trainable=True)
        # V_a: (encoder hidden, 1), reduces each position to a single score
        self.V_a = self.add_weight(name='V_a',
                                   shape=tf.TensorShape((input_shape[0][2], 1)),
                                   initializer='uniform',
                                   trainable=True)

        super(AttentionLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs, verbose=False):
        """
        Compute attention weights and context vectors.

        inputs: [encoder_output_sequence, decoder_output_sequence]
        verbose: when True, print the shapes of the intermediate tensors.

        Returns (c_outputs, e_outputs): the context vectors and the attention
        weights, both produced by iterating over the decoder timesteps with K.rnn.
        """
        assert type(inputs) == list
        encoder_out_seq, decoder_out_seq = inputs
        if verbose:
            print('encoder_out_seq>', encoder_out_seq.shape)
            print('decoder_out_seq>', decoder_out_seq.shape)

        def energy_step(inputs, states):
            """ Step function: attention weights over the encoder positions for one decoder state """

            assert_msg = "States must be a list. However states {} is of type {}".format(states, type(states))
            assert isinstance(states, list) or isinstance(states, tuple), assert_msg

            """ Some parameters required for shaping tensors"""
            en_seq_len, en_hidden = encoder_out_seq.shape[1], encoder_out_seq.shape[2]
            de_hidden = inputs.shape[-1]

            """ Computing S.Wa where S=[s0, s1, ..., si]"""
            # <= batch_size*en_seq_len, latent_dim
            reshaped_enc_outputs = K.reshape(encoder_out_seq, (-1, en_hidden))
            # <= batch_size, en_seq_len, latent_dim
            W_a_dot_s = K.reshape(K.dot(reshaped_enc_outputs, self.W_a), (-1, en_seq_len, en_hidden))
            if verbose:
                print('wa.s>',W_a_dot_s.shape)

            """ Computing hj.Ua """
            U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1)  # <= batch_size, 1, latent_dim
            if verbose:
                print('Ua.h>',U_a_dot_h.shape)

            """ tanh(S.Wa + hj.Ua) """
            # the length-1 axis of U_a_dot_h broadcasts over the encoder positions
            # <= batch_size*en_seq_len, latent_dim
            reshaped_Ws_plus_Uh = K.tanh(K.reshape(W_a_dot_s + U_a_dot_h, (-1, en_hidden)))
            if verbose:
                print('Ws+Uh>', reshaped_Ws_plus_Uh.shape)

            """ softmax(va.tanh(S.Wa + hj.Ua)) """
            # <= batch_size, en_seq_len
            e_i = K.reshape(K.dot(reshaped_Ws_plus_Uh, self.V_a), (-1, en_seq_len))
            # <= batch_size, en_seq_len
            e_i = K.softmax(e_i)

            if verbose:
                print('ei>', e_i.shape)

            # the incoming state is not read; the weights are returned as output and as state
            return e_i, [e_i]

        def context_step(inputs, states):
            """ Step function: weighted sum of the encoder outputs for one set of attention weights """
            # <= batch_size, hidden_size
            c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)
            if verbose:
                print('ci>', c_i.shape)
            return c_i, [c_i]

        def create_inital_state(inputs, hidden_size):
            """ Build an all-zero state of shape (batch_size, hidden_size) from a 3D tensor """
            # We are not using initial states, but need to pass something to K.rnn function
            fake_state = K.zeros_like(inputs)  # <= (batch_size, enc_seq_len, latent_dim)
            fake_state = K.sum(fake_state, axis=[1, 2])  # <= (batch_size)
            fake_state = K.expand_dims(fake_state)  # <= (batch_size, 1)
            fake_state = K.tile(fake_state, [1, hidden_size])  # <= (batch_size, hidden_size)
            return fake_state

        fake_state_c = create_inital_state(encoder_out_seq, encoder_out_seq.shape[-1])  # <= (batch_size, latent_dim)
        fake_state_e = create_inital_state(encoder_out_seq, encoder_out_seq.shape[1])  # <= (batch_size, enc_seq_len)

        """ Computing energy outputs """
        # e_outputs => (batch_size, de_seq_len, en_seq_len)
        last_out, e_outputs, _ = K.rnn(
            energy_step, decoder_out_seq, [fake_state_e],
        )

        """ Computing context vectors """
        # iterates over the attention weights just computed, one decoder step at a time
        last_out, c_outputs, _ = K.rnn(
            context_step, e_outputs, [fake_state_c],
        )

        return c_outputs, e_outputs

    def compute_output_shape(self, input_shape):
        """ Shapes of the two outputs: context vectors first, attention weights second """
        return [
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
        ]

In [None]:
def define_nmt(hidden_size, batch_size, en_timesteps, en_vsize, fr_timesteps, fr_vsize):
    """ Build an attention-based encoder-decoder model plus its inference models.

    # Arguments
        hidden_size: number of units in each GRU direction.
        batch_size: fixed batch size for the training inputs; a falsy value leaves
            the batch axis unspecified.
        en_timesteps: number of encoder timesteps.
        en_vsize: size of the last axis of the encoder input.
        fr_timesteps: number of target timesteps; the decoder input has one fewer.
        fr_vsize: size of the last axis of the decoder input and of the softmax output.
    # Returns
        (full_model, encoder_model, decoder_model). full_model is compiled with
        adam / categorical_crossentropy; the two inference models reuse its layers
        and use a batch size of 1.
    """

    # Define an input sequence and process it.
    if batch_size:
        encoder_inputs = Input(batch_shape=(batch_size, en_timesteps, en_vsize), name='encoder_inputs')
        decoder_inputs = Input(batch_shape=(batch_size, fr_timesteps - 1, fr_vsize), name='decoder_inputs')
    else:
        encoder_inputs = Input(shape=(en_timesteps, en_vsize), name='encoder_inputs')
        decoder_inputs = Input(shape=(fr_timesteps - 1, fr_vsize), name='decoder_inputs')

    # Encoder GRU
    # bidirectional, returning the full output sequence and one final state per direction
    encoder_gru = Bidirectional(GRU(hidden_size, return_sequences=True, return_state=True, name='encoder_gru'), name='bidirectional_encoder')
    encoder_out, encoder_fwd_state, encoder_back_state = encoder_gru(encoder_inputs)

    # Set up the decoder GRU, using `encoder_states` as initial state.
    decoder_gru = Bidirectional(GRU(hidden_size, return_sequences=True, return_state=True, name='decoder_gru'), name='bidirectional_decoder')
    decoder_out, decoder_fwd_state, decoder_back_state = decoder_gru(decoder_inputs, initial_state=[encoder_fwd_state, encoder_back_state])

    # Attention layer
    attn_layer = AttentionLayer(name='attention_layer')
    attn_out, attn_states = attn_layer([encoder_out, decoder_out])

    # Concat attention input and decoder GRU output
    decoder_concat_input = Concatenate(axis=-1, name='concat_layer')([decoder_out, attn_out])

    # Dense layer
    dense = Dense(fr_vsize, activation='softmax', name='softmax_layer')
    dense_time = TimeDistributed(dense, name='time_distributed_layer')
    decoder_pred = dense_time(decoder_concat_input)

    # Full model
    full_model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=decoder_pred)
    full_model.compile(optimizer='adam', loss='categorical_crossentropy')

    full_model.summary()

    """ Inference model """
    # from here on the batch_size argument is overwritten: inference uses one sample at a time
    batch_size = 1

    """ Encoder (Inference) model """
    # the same encoder_gru layer object is called again, so the trained weights are reused
    encoder_inf_inputs = Input(batch_shape=(batch_size, en_timesteps, en_vsize), name='encoder_inf_inputs')
    encoder_inf_out, encoder_inf_fwd_state, encoder_inf_back_state = encoder_gru(encoder_inf_inputs)
    encoder_model = Model(inputs=encoder_inf_inputs, outputs=[encoder_inf_out, encoder_inf_fwd_state, encoder_inf_back_state])

    """ Decoder (Inference) model """
    # one decoder timestep per call; encoder outputs and previous states are explicit inputs
    decoder_inf_inputs = Input(batch_shape=(batch_size, 1, fr_vsize), name='decoder_word_inputs')
    encoder_inf_states = Input(batch_shape=(batch_size, en_timesteps, 2*hidden_size), name='encoder_inf_states')
    decoder_init_fwd_state = Input(batch_shape=(batch_size, hidden_size), name='decoder_fwd_init')
    decoder_init_back_state = Input(batch_shape=(batch_size, hidden_size), name='decoder_back_init')

    decoder_inf_out, decoder_inf_fwd_state, decoder_inf_back_state = decoder_gru(decoder_inf_inputs, initial_state=[decoder_init_fwd_state, decoder_init_back_state])
    attn_inf_out, attn_inf_states = attn_layer([encoder_inf_states, decoder_inf_out])
    decoder_inf_concat = Concatenate(axis=-1, name='concat')([decoder_inf_out, attn_inf_out])
    # wraps the same `dense` layer as the training model in a new TimeDistributed wrapper
    decoder_inf_pred = TimeDistributed(dense)(decoder_inf_concat)
    decoder_model = Model(inputs=[encoder_inf_states, decoder_init_fwd_state, decoder_init_back_state, decoder_inf_inputs],
                          outputs=[decoder_inf_pred, attn_inf_states, decoder_inf_fwd_state, decoder_inf_back_state])

    return full_model, encoder_model, decoder_model

In [None]:
# https://gist.github.com/oarriaga/7ac353a70fd68f953514f4d404c203ae
# https://machinelearningmastery.com/encoder-decoder-attention-sequence-to-sequence-prediction-keras/