# TFLite Micro Seq2Seq: LSTM w/ attention & bidirectional encoder

In this notebook, we first build an attentional LSTM-based encoder-decoder model for sequence-to-sequence applications such as machine translation. We then convert it to TFLite models so that it can be deployed with TFLite Micro.

Full instructions are available at https://github.com/da03/TFLite-Micro-Seq2Seq. They include pretrained TFLite models and the source code for deploying the converted models to an Arduino Nano 33 BLE.

## Dependencies

We tested our models with TensorFlow 2.1.0. We chose this version because, at the time of writing, the latest precompiled package for Arduino was built with TF 2.1.0.

In [None]:
!apt-get -qq update && apt-get -qq install xxd
!pip install -q tensorflow==2.1.0
!pip install -U numpy==1.18.5

In [None]:
import os

import numpy as np
import tensorflow as tf

## Data

Seq2seq models can be used for a variety of applications, such as machine translation, document summarization, and converting images of math formulas to LaTeX. For demonstration purposes, we use a number-to-words task here. The goal is to convert a number into English words:

|      Input |       Output         |
|------------|----------------------|
|    7929    |seven thousand nine hundred and twenty nine|
|   842259   |eight hundred and forty two thousand two hundred and fifty nine|
|   508217   |five hundred and eight thousand two hundred and seventeen|
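
To make the task concrete, here is a small rule-based converter that reproduces the three examples above. It is a hypothetical reference for generating or sanity-checking inputs, not the script used to build the dataset. It assumes inputs below one million and the "X hundred and Y" convention seen in the table. The real data may differ on edge cases such as 1005.

```python
# Hypothetical reference converter (not the dataset's generator).
ONES = ["zero", "one", "two", "three", "four", "five", "six", "seven",
        "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen",
        "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"]
TENS = ["", "", "twenty", "thirty", "forty", "fifty",
        "sixty", "seventy", "eighty", "ninety"]


def below_thousand(n):
  """Words for 1 <= n <= 999, using the "X hundred and Y" convention."""
  words = []
  hundreds, rest = divmod(n, 100)
  if hundreds:
    words += [ONES[hundreds], "hundred"]
    if rest:
      words.append("and")
  if rest >= 20:
    tens, ones = divmod(rest, 10)
    words.append(TENS[tens])
    if ones:
      words.append(ONES[ones])
  elif rest:
    words.append(ONES[rest])
  return words


def number_to_words(n):
  """Convert 0 <= n < 1,000,000 to space-separated English words."""
  if not 0 <= n < 1_000_000:
    raise ValueError("only 0 <= n < 1,000,000 is handled in this sketch")
  if n == 0:
    return "zero"
  thousands, rest = divmod(n, 1000)
  words = []
  if thousands:
    words += below_thousand(thousands) + ["thousand"]
  if rest:
    words += below_thousand(rest)
  return " ".join(words)


print(number_to_words(7929))  # seven thousand nine hundred and twenty nine
```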

The model we train for this task is small. Its modules, configurations, and parameter counts are:

|      Module |    Configuration |   #Parameters         |
|------------|----------|------------|
|    Src Embeddings    |embedding size 64| 1k |
|   Tgt Embeddings   |embedding size 64| 2k |
|   Encoder LSTM (l2r)   |hidden size 32| 12k |
|   Encoder LSTM (r2l)   |hidden size 32| 12k |
|   Decoder LSTM   |hidden size 64| 38k|
| Total | - | 65k|
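
The counts in the table follow from standard formulas. An embedding table has V × d weights. A Keras `LSTMCell` with input size d and hidden size h has a kernel (d × 4h), a recurrent kernel (h × 4h), and a bias (4h), for a total of 4h(d + h + 1) weights. The sketch below reproduces the encoder rows (12,416 per direction).

The decoder LSTM cell alone comes to 33,024 weights. The 38k in the table therefore presumably also counts the `Dense` projection from the 128-dim [context; h] vector to the target vocabulary. That reading is our assumption, and the exact figure depends on `TGT_VOCAB_SIZE`.

```python
def lstm_cell_params(input_size, hidden_size):
  """Weights in a Keras LSTMCell: kernel + recurrent kernel + bias (4 gates)."""
  return 4 * hidden_size * (input_size + hidden_size + 1)


def embedding_params(vocab_size, embedding_size):
  """Weights in an embedding table (one row per word id, including padding)."""
  return vocab_size * embedding_size


def dense_params(input_size, output_size):
  """Weights in a Dense layer with bias."""
  return input_size * output_size + output_size


# Encoder: 64-dim embeddings feed each 32-unit direction.
enc_lstm = lstm_cell_params(64, 32)  # per direction
# Decoder: the attention context is *added* to the 64-dim embedding,
# so the cell input stays 64-dim.
dec_lstm = lstm_cell_params(64, 64)
# Output projection (assumed part of the Decoder row): dense_params(128, TGT_VOCAB_SIZE)
print(enc_lstm, dec_lstm)  # 12416 33024
```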

In [None]:
!wget -nv -N -P data https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/master/data/train.src
!wget -nv -N -P data https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/master/data/train.tgt
!wget -nv -N -P data https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/master/data/dev.src
!wget -nv -N -P data https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/master/data/dev.tgt
!wget -nv -N -P data https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/master/data/test.src
!wget -nv -N -P data https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/master/data/test.tgt

Let's take a look at our dataset:

In [None]:
with open('data/dev.src') as fsrc:
  with open('data/dev.tgt') as ftgt:
    print(f'{"Source":20s} {"Target":70s}')
    for src, tgt, _ in zip(fsrc, ftgt, range(3)):
      print(f'{src.strip():20s} {tgt.strip():70s}')

## Data Preprocessing

We load and tokenize the data, build source and target vocabularies, and convert words to word ids.

For each target sentence, we prepend a special beginning-of-sentence token `<bos>` and append a special end-of-sentence token `<eos>`. The decoder input drops the final `<eos>`, and the decoder ground-truth target drops the initial `<bos>`. As a result, the target at every step is the next token. For example, if the sentence is `seven thousand`, the decoder input is `<bos> seven thousand` and the decoder ground-truth target is `seven thousand <eos>`.

Finally, we convert the preprocessed data to NumPy arrays, padding every sequence with zeros (id 0) to a fixed length. The training targets are one-hot encoded for the categorical cross-entropy loss. The test labels are kept as word ids.
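
The shift described above can be sketched in plain Python. The helper names and toy ids are hypothetical. In the notebook, ids come from the Keras `Tokenizer`, and padding uses `pad_sequences(..., padding='post')`, which appends zeros (id 0) after the tokens as `post_pad` does.

```python
def make_decoder_pair(target_sentence):
  """Return (decoder_input, decoder_target) token lists for one sentence."""
  tokens = ['<bos>'] + target_sentence.split() + ['<eos>']
  return tokens[:-1], tokens[1:]  # input drops <eos>, target drops <bos>


def post_pad(ids, max_len, pad_id=0):
  """Right-pad a list of ids to max_len.

  Truncates from the end, unlike pad_sequences' default truncating='pre'.
  """
  return (ids + [pad_id] * max_len)[:max_len]


toy_vocab = {'<bos>': 1, '<eos>': 2, 'seven': 3, 'thousand': 4}  # hypothetical ids
dec_in, dec_out = make_decoder_pair('seven thousand')
print(dec_in)   # ['<bos>', 'seven', 'thousand']
print(dec_out)  # ['seven', 'thousand', '<eos>']
print(post_pad([toy_vocab[w] for w in dec_in], 5))   # [1, 3, 4, 0, 0]
print(post_pad([toy_vocab[w] for w in dec_out], 5))  # [3, 4, 2, 0, 0]
```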

In [None]:
# Build vocabulary from training data
src_train = []
tgt_train = []

with open('data/train.src') as fsrc:
  with open('data/train.tgt') as ftgt:
    for src, tgt in zip(fsrc, ftgt):
      src_train.append(src.strip())
      tgt_train.append('<bos> ' + tgt.strip() + ' <eos>')

print (f'Size of training set: {len(src_train)}')
src_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='', lower=False)
src_tokenizer.fit_on_texts(src_train)

tgt_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='', lower=False)
tgt_tokenizer.fit_on_texts(tgt_train)

SRC_VOCAB_SIZE = len(src_tokenizer.word_index) + 1
print(f'Size of source vocab: {SRC_VOCAB_SIZE}')

TGT_VOCAB_SIZE = len(tgt_tokenizer.word_index) + 1
print(f'Size of target vocab: {TGT_VOCAB_SIZE}')

# Vocabulary lists ordered by word id (id 1 first; id 0 is reserved for padding).
# The C++ export below relies on this ordering.
src_vocab = list(src_tokenizer.word_index)
tgt_vocab = list(tgt_tokenizer.word_index)

In [None]:
# Prepare training data
# encoder_input_data
tokenized_src = src_tokenizer.texts_to_sequences(src_train)
max_len_src = max([len(x) for x in tokenized_src])
padded_src = tf.keras.preprocessing.sequence.pad_sequences(tokenized_src, 
                                                           maxlen=max_len_src, 
                                                           padding='post')
encoder_input_data = np.array(padded_src)

# decoder_input_data: <bos> w1 ... wn (the trailing <eos> is dropped)
tokenized_tgt = tgt_tokenizer.texts_to_sequences(tgt_train)
max_len_tgt = max([len(x) for x in tokenized_tgt])
decoder_input_seqs = [x[:-1] for x in tokenized_tgt]
padded_tgt = tf.keras.preprocessing.sequence.pad_sequences(decoder_input_seqs,
                                                           maxlen=max_len_tgt,
                                                           padding='post')
decoder_input_data = np.array(padded_tgt)

# decoder_output_data: w1 ... wn <eos> (the leading <bos> is dropped), one-hot encoded
decoder_output_seqs = [x[1:] for x in tokenized_tgt]
padded_tgt = tf.keras.preprocessing.sequence.pad_sequences(decoder_output_seqs,
                                                           maxlen=max_len_tgt,
                                                           padding='post')
onehot_tgt = tf.keras.utils.to_categorical(padded_tgt, TGT_VOCAB_SIZE)
decoder_output_data = np.array(onehot_tgt)

In [None]:
# Prepare test data
src_test = []
tgt_test = []

with open('data/test.src') as fsrc:
  with open('data/test.tgt') as ftgt:
    for src, tgt in zip(fsrc, ftgt):
      src_test.append(src.strip())
      tgt_test.append('<bos> ' + tgt.strip() + ' <eos>')

# encoder_input_data
tokenized_src_test = src_tokenizer.texts_to_sequences(src_test)
tokenized_src_test = [x[:max_len_src] for x in tokenized_src_test]
padded_src_test = tf.keras.preprocessing.sequence.pad_sequences(tokenized_src_test, 
                                                                maxlen=max_len_src, 
                                                                padding='post')
encoder_input_data_test = np.array(padded_src_test)

# decoder_input_data: drop the trailing <eos>, truncate to max_len_tgt
tokenized_tgt_test = tgt_tokenizer.texts_to_sequences(tgt_test)
decoder_input_seqs_test = [x[:-1][:max_len_tgt] for x in tokenized_tgt_test]
padded_tgt_test = tf.keras.preprocessing.sequence.pad_sequences(decoder_input_seqs_test,
                                                                maxlen=max_len_tgt,
                                                                padding='post')
decoder_input_data_test = np.array(padded_tgt_test)

# labels_test: drop the leading <bos>, truncate to max_len_tgt (kept as ids, not one-hot)
labels_seqs_test = [x[1:][:max_len_tgt] for x in tokenized_tgt_test]
padded_tgt_test = tf.keras.preprocessing.sequence.pad_sequences(labels_seqs_test,
                                                                maxlen=max_len_tgt,
                                                                padding='post')
labels_test = np.array(padded_tgt_test)

## Model

We use an LSTM-based attentional encoder-decoder model. The encoder is a bidirectional LSTM, and the decoder is an LSTM with attention. A diagram of the model is shown below.

<img src="https://raw.githubusercontent.com/da03/TFLite-Micro-Seq2Seq/main/img/encoder_decoder_attn_1layer.png" alt="attentional encoder-decoder illustration" />

For the encoder, we use the built-in `tf.keras.layers.LSTMCell`. For the decoder's attentional LSTM cell, we use a custom implementation that avoids operations not supported by TFLite Micro, such as [transpose](https://github.com/tensorflow/tensorflow/issues/43472).
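
Numerically, the cell computes dot-product attention in four steps:

1. Score each source position against the decoder LSTM output.
2. Apply a softmax over the source positions.
3. Sum the weighted encoder outputs into a context vector.
4. Concatenate the context with the LSTM output.

The NumPy sketch below is illustrative only and uses random inputs. It shows that the split-and-add loop used in `AttnLSTMCell` gives the same context as an ordinary sum over the source axis. Our reading that the loop is meant to express the sum with only element-wise additions is an assumption.

```python
import numpy as np


def attention_split_add(values, query):
  """Dot-product attention written like AttnLSTMCell (NumPy, batch-major).

  values: (bsz, src_len, H) encoder outputs, also used as keys.
  query:  (bsz, H) decoder LSTM output.
  Returns (attention weights (bsz, src_len, 1), [context; query] of shape (bsz, 2H)).
  """
  bsz, src_len, hid = values.shape
  scores = values @ query.reshape(bsz, hid, 1)               # bsz, src_len, 1
  scores = np.exp(scores - scores.max(axis=1, keepdims=True))
  weights = scores / scores.sum(axis=1, keepdims=True)        # softmax over source positions
  weighted = weights * values                                 # bsz, src_len, H
  # Sum over source positions with split + element-wise adds.
  chunks = np.split(weighted, src_len, axis=1)                # src_len arrays of (bsz, 1, H)
  context = chunks[0]
  for chunk in chunks[1:]:
    context = context + chunk
  context = context.reshape(bsz, hid)
  return weights, np.concatenate([context, query], axis=-1)


rng = np.random.default_rng(0)
toy_values = rng.standard_normal((2, 6, 64))  # illustrative sizes
toy_query = rng.standard_normal((2, 64))
toy_weights, toy_out = attention_split_add(toy_values, toy_query)
# The split-and-add context equals an ordinary sum over the source axis.
reference = (toy_weights * toy_values).sum(axis=1)
print(np.allclose(toy_out[:, :64], reference))  # True
```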

In [None]:
class AttnLSTMCell(tf.keras.layers.Layer):
  """LSTM cell followed by dot-product attention over the encoder outputs.

  Written to avoid ops that TFLite Micro does not support (e.g. transpose).
  The output is the concatenation [context; h] of size 2 * hidden_size.
  """
  def __init__(self, hidden_size, output_size, **kwargs):
    self.h = hidden_size
    self.o = output_size

    self.state_size = tf.TensorShape([hidden_size])
    self.output_size = tf.TensorShape([output_size])

    super(AttnLSTMCell, self).__init__(**kwargs)

  def build(self, input_shapes):
    self.decoder_lstm_cell = tf.keras.layers.LSTMCell(self.h)

  def call(self, inputs, states):
    inputs, encoder_outputs = inputs
    src_len = encoder_outputs.shape[1] # static max_len_src
    outputs, new_states = self.decoder_lstm_cell(inputs, states)
    # Encoder outputs act as both keys and values (their size equals self.h)
    query = tf.keras.layers.Reshape((self.h, 1))(outputs) # bsz, H, 1
    values = encoder_outputs # bsz, max_len_src, H
    scores = tf.matmul(values, query) # bsz, max_len_src, 1
    scores = tf.keras.layers.Softmax(1)(scores) # normalize over source positions
    context = scores * values # bsz, max_len_src, H
    # Sum over source positions using split + element-wise adds
    context_list = tf.split(context, num_or_size_splits=src_len, axis=1)
    context = context_list[0]
    for i in range(1, src_len):
      context = context + context_list[i]
    context = tf.keras.layers.Reshape((self.h,))(context) # bsz, H
    outputs = tf.keras.layers.Concatenate(axis=-1)([context, outputs]) # bsz, 2H

    return outputs, new_states

  def get_config(self):
    return {"hidden_size": self.h, "output_size": self.o}

With the customized attentional LSTM cell, we are ready to build the full model.

In [None]:
HIDDEN_SIZE = 64

hidden_size = HIDDEN_SIZE
enc_hidden_size = hidden_size // 2
dec_hidden_size = hidden_size

src_embedding_size = hidden_size
tgt_embedding_size = hidden_size

# Input tensors:
# Encoder inputs
encoder_inputs = tf.keras.layers.Input(shape=(max_len_src,)) # bsz, max_len_src
decoder_inputs = tf.keras.layers.Input(shape=(max_len_tgt,)) # bsz, max_len_tgt

# Encoder
encoder_embedding_layer = tf.keras.layers.Embedding(SRC_VOCAB_SIZE, 
                                                    src_embedding_size) 
encoder_embeddings = encoder_embedding_layer(encoder_inputs) # bsz, max_len_src, src_embedding_size
encoder_lstm_cell_fw = tf.keras.layers.LSTMCell(enc_hidden_size)
encoder_lstm_cell_bw = tf.keras.layers.LSTMCell(enc_hidden_size)
encoder_lstm_layer_fw = tf.keras.layers.RNN(encoder_lstm_cell_fw, 
                                            return_sequences=True, 
                                            return_state=True, 
                                            go_backwards=False)
encoder_lstm_layer_bw = tf.keras.layers.RNN(encoder_lstm_cell_bw, 
                                            return_sequences=True, 
                                            return_state=True, 
                                            go_backwards=True)

encoder_lstm_layer = tf.keras.layers.Bidirectional(encoder_lstm_layer_fw, 
                                                   merge_mode='concat',
                                                   backward_layer=encoder_lstm_layer_bw)
# Bidirectional re-creates (at least) the forward layer from its config, so fetch
# the cells it actually uses; these cells are exported to TFLite later
encoder_lstm_cell_fw = encoder_lstm_layer.forward_layer.cell 
encoder_lstm_cell_bw = encoder_lstm_layer.backward_layer.cell
encoder_outputs, encoder_states_h_fw, encoder_states_c_fw, \
  encoder_states_h_bw, encoder_states_c_bw \
                                        = encoder_lstm_layer(encoder_embeddings)
encoder_states_h = tf.keras.layers.Concatenate(axis=1)([encoder_states_h_fw, 
                                                        encoder_states_h_bw])
encoder_states_c = tf.keras.layers.Concatenate(axis=1)([encoder_states_c_fw, 
                                                        encoder_states_c_bw])
encoder_states = (encoder_states_h, encoder_states_c)

# Decoder
decoder_embedding_layer = tf.keras.layers.Embedding(TGT_VOCAB_SIZE, tgt_embedding_size)
decoder_embeddings = decoder_embedding_layer(decoder_inputs) # bsz, max_len_tgt, tgt_embedding_size

decoder_lstm_cell = AttnLSTMCell(dec_hidden_size, TGT_VOCAB_SIZE)
decoder_proj_layer = tf.keras.layers.Dense(TGT_VOCAB_SIZE)

logits = []
decoder_state = encoder_states
context = None
for t in range(max_len_tgt):
  decoder_embedding = decoder_embeddings[:, t]
  # Feed context vector to decoder input (see model diagram)
  if context is not None:
    decoder_embedding = decoder_embedding + context
  decoder_output, decoder_state = decoder_lstm_cell([decoder_embedding, encoder_outputs], decoder_state)
  logit = decoder_proj_layer(decoder_output) # bsz, vocab_size
  context = decoder_output[:, :dec_hidden_size]
  logits.append(logit)

decoder_logits = tf.stack(logits, 1)

# Compile model
model = tf.keras.models.Model([encoder_inputs, decoder_inputs], decoder_logits)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              experimental_run_tf_function=False)

model.summary()

## Optimization

In [None]:
# Train model
BATCH_SIZE = 64
EPOCHS = 32
model.fit([encoder_input_data, decoder_input_data], 
          decoder_output_data, 
          batch_size=BATCH_SIZE, 
          epochs=EPOCHS) 

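## Evaluation

We report sequence-level exact-match accuracy on the test set. An example counts as correct only if the predicted token matches the label at every position, including padding. The decoder is fed the ground-truth previous tokens (teacher forcing), so this number may be higher than the accuracy of free-running greedy decoding.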

In [None]:
# Evaluate model
total = 0
correct = 0
for encoder_input_d, decoder_input_d, label_d in zip(encoder_input_data_test, 
                                                     decoder_input_data_test, 
                                                     labels_test):
  encoder_input_d = encoder_input_d.reshape((1, -1))
  decoder_input_d = decoder_input_d.reshape((1, -1))
  label_d = label_d.reshape((1, -1))
  logits = model.predict_on_batch([encoder_input_d, decoder_input_d])
  predictions = logits.argmax(-1)
  if (predictions == label_d).all():
    correct += 1
  total += 1
print (f'Test accuracy: {100.0*correct/total}%')
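
The loop above calls the model once per test example. The same metric can also be computed in one batched pass. `exact_match_accuracy` below is a hypothetical helper, checked here on toy arrays. The commented-out lines show how it could be applied to the notebook's test arrays.

```python
import numpy as np


def exact_match_accuracy(logits, labels):
  """Fraction of sequences whose argmax prediction equals the label at every position.

  logits: (num_examples, max_len_tgt, vocab_size); labels: (num_examples, max_len_tgt).
  """
  predictions = logits.argmax(-1)
  return float((predictions == labels).all(axis=1).mean())


# In the notebook this could be (not run here):
# logits = model.predict([encoder_input_data_test, decoder_input_data_test], batch_size=BATCH_SIZE)
# print(f'Test accuracy: {100.0 * exact_match_accuracy(logits, labels_test)}%')

# Toy check: 2 sequences of length 3 over a vocabulary of 4 tokens.
toy_labels = np.array([[1, 2, 0], [3, 3, 0]])
toy_logits = np.eye(4)[np.array([[1, 2, 0], [3, 1, 0]])]  # one-hot "logits"
print(exact_match_accuracy(toy_logits, toy_labels))  # 0.5
```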

## Conversion to TFLite

Once trained, the model can be deployed to microcontrollers. First, we convert it to TFLite. Specifically, we export three single-step models:

- the left-to-right encoder LSTM cell,
- the right-to-left encoder LSTM cell, and
- the attentional decoder cell together with its output projection.

We don't convert the unrolled encoder because TFLite Micro does not support subgraphs, so the loop over time steps runs outside the models. We also don't convert the Embedding layers, because they are not yet supported by TFLite Micro. A workaround is shown later.
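
Because each exported model performs a single time step, the loop over time moves to the host (Python here, C++ on the Arduino). The sketch below shows that pattern, with a hypothetical `step` callable standing in for one interpreter invocation. Each direction fills one half of a per-position memory bank, as the TFLite inference code later in this notebook does. The toy step function is a placeholder, not an LSTM.

```python
import numpy as np


def run_direction(step, embeddings, hidden_size, reverse=False):
  """Unroll a single-step cell over one sequence on the host.

  step(x, h, c) -> (output, h, c) stands in for one interpreter invocation.
  embeddings: (seq_len, emb_size). Returns outputs (seq_len, hidden_size) and the final (h, c).
  """
  seq_len = embeddings.shape[0]
  h = np.zeros((1, hidden_size), dtype=np.float32)
  c = np.zeros((1, hidden_size), dtype=np.float32)
  outputs = np.zeros((seq_len, hidden_size), dtype=np.float32)
  order = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
  for t in order:
    out, h, c = step(embeddings[t:t + 1], h, c)
    outputs[t] = out[0]  # stored at position t, also when running right-to-left
  return outputs, h, c


def toy_cell_step(x, h, c):
  """Placeholder cell (not an LSTM): running sum of inputs; assumes emb_size == hidden_size."""
  c = c + x
  return c, c, c


toy_emb = np.arange(6, dtype=np.float32).reshape(3, 2)  # 3 time steps, size 2
toy_fw, toy_h_fw, _ = run_direction(toy_cell_step, toy_emb, 2)
toy_bw, toy_h_bw, _ = run_direction(toy_cell_step, toy_emb, 2, reverse=True)
# Per-position memory bank [forward; backward]
toy_memory_bank = np.concatenate([toy_fw, toy_bw], axis=-1)  # (seq_len, 4)
```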

In [None]:
# Create tflite model
# Encoder LSTM Cell (left-to-right)
# Inputs
encoder_embedding_fw = tf.keras.Input(shape=(src_embedding_size,)) # bsz, src_embedding_size
encoder_state_h_fw = tf.keras.Input(shape=(enc_hidden_size,)) # bsz, enc_hidden_size
encoder_state_c_fw = tf.keras.Input(shape=(enc_hidden_size,)) # bsz, enc_hidden_size

encoder_output_fw, (encoder_state_h_out_fw, encoder_state_c_out_fw) \
               = encoder_lstm_cell_fw(encoder_embedding_fw, 
                                      (encoder_state_h_fw, encoder_state_c_fw))

enc_fw_micro_model =  tf.keras.models.Model([encoder_embedding_fw, 
                                             encoder_state_h_fw, 
                                             encoder_state_c_fw], 
                                            [encoder_output_fw, 
                                             encoder_state_h_out_fw, 
                                             encoder_state_c_out_fw])
enc_fw_micro_model.summary()

converter = tf.lite.TFLiteConverter.from_keras_model(enc_fw_micro_model)
buffer_enc_fw = converter.convert()
open('enc_model_fw.tflite', 'wb').write(buffer_enc_fw)

# Encoder LSTM Cell (right-to-left)
# Inputs
encoder_embedding_bw = tf.keras.Input(shape=(src_embedding_size,)) # bsz, src_embedding_size
encoder_state_h_bw = tf.keras.Input(shape=(enc_hidden_size,)) # bsz, enc_hidden_size
encoder_state_c_bw = tf.keras.Input(shape=(enc_hidden_size,)) # bsz, enc_hidden_size

encoder_output_bw, (encoder_state_h_out_bw, encoder_state_c_out_bw) \
               = encoder_lstm_cell_bw(encoder_embedding_bw, 
                                      (encoder_state_h_bw, encoder_state_c_bw))

enc_bw_micro_model =  tf.keras.models.Model([encoder_embedding_bw, 
                                             encoder_state_h_bw, 
                                             encoder_state_c_bw], 
                                            [encoder_output_bw, 
                                             encoder_state_h_out_bw, 
                                             encoder_state_c_out_bw])
enc_bw_micro_model.summary()

converter = tf.lite.TFLiteConverter.from_keras_model(enc_bw_micro_model)
buffer_enc_bw = converter.convert()
open('enc_model_bw.tflite', 'wb').write(buffer_enc_bw)

# Decoder LSTM Cell
# Inputs
encoder_output = tf.keras.Input(shape=(max_len_src, dec_hidden_size,)) # bsz, max_len_src, dec_hidden_size
decoder_embedding = tf.keras.Input(shape=(tgt_embedding_size,)) # bsz, tgt_embedding_size
decoder_state_h = tf.keras.Input(shape=(dec_hidden_size,)) # bsz, hidden_size
decoder_state_c = tf.keras.Input(shape=(dec_hidden_size,)) # bsz, hidden_size

decoder_output, (decoder_state_h_out, decoder_state_c_out) \
              = decoder_lstm_cell([decoder_embedding, encoder_output], 
                                  (decoder_state_h, decoder_state_c))
decoder_context = decoder_output[:, :dec_hidden_size]
decoder_logit = decoder_proj_layer(decoder_output) # bsz, tgt_vocab_size


dec_micro_model =  tf.keras.models.Model([encoder_output, 
                                          decoder_embedding, 
                                          decoder_state_h, 
                                          decoder_state_c], 
                                         [decoder_logit, 
                                          decoder_context, 
                                          decoder_state_h_out, 
                                          decoder_state_c_out])
dec_micro_model.summary()

converter = tf.lite.TFLiteConverter.from_keras_model(dec_micro_model)
buffer_dec = converter.convert()
open('dec_model.tflite' , 'wb').write(buffer_dec)

## Conversion to TFLite Micro (Optional)

We can now package the models for TFLite Micro and deploy them to an Arduino Nano 33 BLE. One issue is that TFLite Micro does not support Embedding layers, because it lacks the gather operation they rely on. To work around this, we implement the embedding lookup directly in C++: the embedding tables and vocabularies are written out as C arrays alongside the model byte arrays. The generated C++ files are written to the `c_src` folder.
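
The generated `set_enc_embed` and `set_dec_embed` functions work in three steps:

1. Scan the vocabulary strings linearly for the token.
2. Fall back to id 0 (the padding row) if the token is unknown.
3. Copy one row of the embedding table into the model's input tensor.

A Python mirror of that logic can help check the C++ output on the host. `lookup_embedding` is a hypothetical name. `words` is assumed to be ordered by id starting at 1, like `src_vocab` and `tgt_vocab` in this notebook.

```python
import numpy as np


def lookup_embedding(token, words, embeddings):
  """Mirror of the generated C++ embedding lookup.

  words: vocabulary strings where words[i - 1] has id i (id 0 is padding).
  embeddings: (vocab_size, embedding_size) table whose row 0 is the padding row.
  Unknown tokens fall back to id 0, as in the C++ code.
  """
  word_id = 0
  for i in range(1, len(words) + 1):  # linear scan, like the C++ for-loop
    if words[i - 1] == token:
      word_id = i
      break
  return embeddings[word_id].copy()


# Toy table: 3 real words plus the padding row, embedding size 2 (illustrative values).
toy_words = ['and', 'hundred', 'thousand']
toy_table = np.array([[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=np.float32)
print(lookup_embedding('hundred', toy_words, toy_table))  # -> toy_table[2]
print(lookup_embedding('million', toy_words, toy_table))  # unknown -> toy_table[0]
```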

Alternatively, you can skip this section and go straight to the next one, which runs inference with TFLite in Python.

In [None]:
# Create C files for Arduino
c_folder = "c_src"
!mkdir -p {c_folder}
model_h_file = os.path.join(c_folder, 'model.h')
model_cpp_file = os.path.join(c_folder, 'model.cpp')
# model.h
bos_word_id = tgt_tokenizer.word_index['<bos>']
eos_word_id = tgt_tokenizer.word_index['<eos>']
model_h_str = f"""
#ifndef MODEL_H_
#define MODEL_H_

#include "Arduino.h"
#include "tensorflow/lite/c/common.h"

extern const unsigned char g_enc_model_fw[];
extern const unsigned char g_enc_model_bw[];
extern const unsigned char g_dec_model[];
void set_enc_embed(TfLiteTensor* ptr, String token);
void set_dec_embed(TfLiteTensor* ptr, String token);
String id_to_word(int idx);

const int max_len_src = {max_len_src};
const int max_len_tgt = {max_len_tgt};
const int bos_word_id = {bos_word_id};
const int eos_word_id = {eos_word_id};
const int enc_hidden_size = {enc_hidden_size};
const int dec_hidden_size = {dec_hidden_size};
const int src_embedding_size = {src_embedding_size};
const int tgt_embedding_size = {tgt_embedding_size};
const int src_vocab_size = {SRC_VOCAB_SIZE};
const int tgt_vocab_size = {TGT_VOCAB_SIZE};
#endif
"""

with open(model_h_file, 'w') as fout:
  fout.write(model_h_str)

# model.cpp - enc rnn left-to-right
!xxd -i  enc_model_fw.tflite > tmp

with open('tmp') as fin:
  with open(model_cpp_file, 'w') as fout:
    text = fin.read().strip()
    lines = list(text.split('\n'))
    fout.write('#include "model.h"\n')
    fout.write('alignas(8) const unsigned char g_enc_model_fw[] = {\n')
    for line in lines[1:-1]:
      fout.write(line + '\n')

# model.cpp - enc rnn right-to-left
!xxd -i  enc_model_bw.tflite > tmp

with open('tmp') as fin:
  with open(model_cpp_file, 'a') as fout:
    text = fin.read().strip()
    lines = list(text.split('\n'))
    fout.write('alignas(8) const unsigned char g_enc_model_bw[] = {\n')
    for line in lines[1:-1]:
      fout.write(line + '\n')

# model.cpp - dec rnn
!xxd -i  dec_model.tflite > tmp

with open('tmp') as fin:
  with open(model_cpp_file, 'a') as fout:
    text = fin.read().strip()
    lines = list(text.split('\n'))
    fout.write('alignas(8) const unsigned char g_dec_model[] = {\n')
    for line in lines[1:-1]:
      fout.write(line + '\n')

# model.cpp - embeddings
def to_c(array):
  return '{' + ','.join([ '{' + ','.join([str(i) for i in item]) + '}' for item in array]) + '}'

src_embeddings_str = to_c(encoder_embedding_layer.get_weights()[0])  # (SRC_VOCAB_SIZE, src_embedding_size)
src_vocab_str = '{' + ",".join(['"' + item + '"' for item in src_vocab]) + '}'
enc_embedding_str = f"""
void set_enc_embed(TfLiteTensor* ptr, String token) {{
  float embeddings[src_vocab_size][src_embedding_size] = {src_embeddings_str};
  const char *words[src_vocab_size-1] = {src_vocab_str};
  int word_id = 0;
  for (int i = 1; i < src_vocab_size; i++) {{
    String token2 = words[i-1];
    if (token == token2) {{
      word_id = i;
      break;
    }}
  }}
  for (int i = 0; i < src_embedding_size; i++) {{
    ptr->data.f[i] = embeddings[word_id][i];
  }}
}}
"""

with open(model_cpp_file, 'a') as fout:
  fout.write('\n')
  fout.write(enc_embedding_str)

tgt_embeddings_str = to_c(decoder_embedding_layer.get_weights()[0])  # (TGT_VOCAB_SIZE, tgt_embedding_size)
tgt_vocab_str = '{' + ",".join(['"' + item + '"' for item in tgt_vocab]) + '}'
dec_embedding_str = f"""
void set_dec_embed(TfLiteTensor* ptr, String token) {{
  float embeddings[tgt_vocab_size][tgt_embedding_size] = {tgt_embeddings_str};
  const char *words[tgt_vocab_size-1] = {tgt_vocab_str};
  int word_id = 0;
  for (int i = 1; i < tgt_vocab_size; i++) {{
    String token2 = words[i-1];
    if (token == token2) {{
      word_id = i;
      break;
    }}
  }}
  for (int i = 0; i < tgt_embedding_size; i++) {{
    ptr->data.f[i] = embeddings[word_id][i];
  }}
}}

String id_to_word(int idx) {{
  const char *words[tgt_vocab_size-1] = {tgt_vocab_str};
  if (idx == 0) {{
    return String("pad");
  }}
  String str = words[idx-1];
  return str;
}}
"""

with open(model_cpp_file, 'a') as fout:
  fout.write('\n')
  fout.write(dec_embedding_str)

Next, download the generated C++ files, copy them into the `src` folder of the Arduino project, and upload the project to the Arduino Nano 33 BLE. See https://github.com/da03/TFLite-Micro-Seq2Seq for details.
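
The model byte arrays in `model.cpp` come from `xxd -i`, which is why the first cell installs `xxd`. If `xxd` is unavailable, the same kind of C array can be produced in pure Python. `to_c_byte_array` below is a hypothetical helper, not part of TensorFlow or `xxd`. It formats bytes in the `alignas(8) const unsigned char ...[] = { ... };` style used in the generated file.

```python
def to_c_byte_array(data, name, bytes_per_line=12):
  """Format raw bytes as an 8-byte-aligned C array, similar to `xxd -i` output."""
  lines = [f'alignas(8) const unsigned char {name}[] = {{']
  for start in range(0, len(data), bytes_per_line):
    chunk = data[start:start + bytes_per_line]
    lines.append('  ' + ', '.join(f'0x{b:02x}' for b in chunk) + ',')
  lines.append('};')
  return '\n'.join(lines) + '\n'


# In the notebook this could replace one xxd round trip (not run here):
# with open(model_cpp_file, 'a') as fout:
#   fout.write(to_c_byte_array(buffer_dec, 'g_dec_model'))
print(to_c_byte_array(b'\x1c\x00\x00\x00TFL3', 'g_example'))
```

C and C++ both accept a trailing comma after the last initializer, so the output compiles as-is.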

## Inference using TFLite (Optional)

This section shows how to run inference with the TFLite interpreter in Python. It follows the same step-by-step workflow as the Arduino project, so it is also useful for debugging.
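
The decoding loop in the inference cell is plain greedy search:

1. Start from `<bos>`.
2. Feed the previous prediction and the attention context back in.
3. Take the argmax as the next token.
4. Stop at `<eos>` or after a length limit.

Stripped of the interpreter calls, the control flow looks like this. `step` is a hypothetical callable standing in for one decoder invocation, and the toy step simply emits a fixed sequence.

```python
import numpy as np


def greedy_decode(step, state, bos_id, eos_id, max_len):
  """Greedy decoding with input feeding (control flow only).

  step(prev_id, context, state) -> (logits, context, state) stands in for one
  decoder invocation; context is None at the first step.
  Returns at most max_len predicted ids, excluding <eos>.
  """
  prev_id, context, output_ids = bos_id, None, []
  for _ in range(max_len):
    logits, context, state = step(prev_id, context, state)
    prev_id = int(np.argmax(logits))  # greedy: most likely next token
    if prev_id == eos_id:
      break
    output_ids.append(prev_id)
  return output_ids


def toy_decoder_step(prev_id, context, state):
  """Placeholder decoder: emits ids 3, 4 and then <eos> (id 2), ignoring its inputs."""
  script = [3, 4, 2]
  logits = np.zeros(5)
  logits[script[min(state, len(script) - 1)]] = 1.0
  return logits, np.zeros(4), state + 1


print(greedy_decode(toy_decoder_step, 0, bos_id=1, eos_id=2, max_len=10))  # [3, 4]
```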



In [None]:
def sentence_to_ids(sentence):
  words = sentence.lower().split()
  tokens_list = []
  for word in words:
    tokens_list.append(src_tokenizer.word_index[word])
  return tf.keras.preprocessing.sequence.pad_sequences([tokens_list], 
                                                       maxlen=max_len_src,
                                                       padding='post')

In [None]:
# Initialize the TFLite interpreters
interpreter_enc_fw = tf.lite.Interpreter(model_content=buffer_enc_fw)
interpreter_enc_fw.allocate_tensors()

interpreter_enc_bw = tf.lite.Interpreter(model_content=buffer_enc_bw)
interpreter_enc_bw.allocate_tensors()

interpreter_dec = tf.lite.Interpreter(model_content=buffer_dec)
interpreter_dec.allocate_tensors()

input_details_enc_fw = interpreter_enc_fw.get_input_details()
output_details_enc_fw = interpreter_enc_fw.get_output_details()

input_details_enc_bw = interpreter_enc_bw.get_input_details()
output_details_enc_bw = interpreter_enc_bw.get_output_details()

input_details_dec = interpreter_dec.get_input_details()
output_details_dec = interpreter_dec.get_output_details()

# Get source sentence
encoder_input = sentence_to_ids(input('Enter tokenized source sentence: '))

# Source embedding table (row i = word id i). Embedding layers are not part of
# the TFLite models, so the lookup happens here (and in C++ on the Arduino).
src_embedding_table = encoder_embedding_layer.get_weights()[0]

# Left-to-right encoder: fills the first enc_hidden_size features of the memory bank
h_fw = np.zeros((1, enc_hidden_size))
c_fw = np.zeros((1, enc_hidden_size))
memory_bank = np.zeros((1, max_len_src, dec_hidden_size))
for t in range(max_len_src):
  embeddings = src_embedding_table[encoder_input[0][t]].reshape((1, -1))
  interpreter_enc_fw.set_tensor(input_details_enc_fw[0]["index"], embeddings.astype(np.float32))
  interpreter_enc_fw.set_tensor(input_details_enc_fw[1]["index"], h_fw.astype(np.float32))
  interpreter_enc_fw.set_tensor(input_details_enc_fw[2]["index"], c_fw.astype(np.float32))
  interpreter_enc_fw.invoke()
  out_fw = interpreter_enc_fw.get_tensor(output_details_enc_fw[0]["index"])
  h_fw = interpreter_enc_fw.get_tensor(output_details_enc_fw[1]["index"])
  c_fw = interpreter_enc_fw.get_tensor(output_details_enc_fw[2]["index"])
  memory_bank[:, t, :enc_hidden_size] = out_fw

# Right-to-left encoder: fills the remaining features, position by position
h_bw = np.zeros((1, enc_hidden_size))
c_bw = np.zeros((1, enc_hidden_size))
for t in range(max_len_src-1, -1, -1):
  embeddings = src_embedding_table[encoder_input[0][t]].reshape((1, -1))
  interpreter_enc_bw.set_tensor(input_details_enc_bw[0]["index"], embeddings.astype(np.float32))
  interpreter_enc_bw.set_tensor(input_details_enc_bw[1]["index"], h_bw.astype(np.float32))
  interpreter_enc_bw.set_tensor(input_details_enc_bw[2]["index"], c_bw.astype(np.float32))
  interpreter_enc_bw.invoke()
  out_bw = interpreter_enc_bw.get_tensor(output_details_enc_bw[0]["index"])
  h_bw = interpreter_enc_bw.get_tensor(output_details_enc_bw[1]["index"])
  c_bw = interpreter_enc_bw.get_tensor(output_details_enc_bw[2]["index"])
  memory_bank[:, t, enc_hidden_size:] = out_bw

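# Special token ids, defined here as well so that this cell also works when the
# optional C++ export section was skipped
bos_word_id = tgt_tokenizer.word_index['<bos>']
eos_word_id = tgt_tokenizer.word_index['<eos>']
tgt_embedding_table = decoder_embedding_layer.get_weights()[0]

decoded_translation = ''
prev_word_id = bos_word_id
# Initial decoder state: concatenated final encoder states (as in training)
h = np.concatenate([h_fw, h_bw], axis=-1)
c = np.concatenate([c_fw, c_bw], axis=-1)
context = None
while True:
  embeddings = tgt_embedding_table[prev_word_id].reshape((1, -1))
  # Input feeding: add the previous attention context to the embedding
  if context is not None:
    embeddings = embeddings + context.reshape((1, -1))
  interpreter_dec.set_tensor(input_details_dec[0]["index"], memory_bank.astype(np.float32))
  interpreter_dec.set_tensor(input_details_dec[1]["index"], embeddings.astype(np.float32))
  interpreter_dec.set_tensor(input_details_dec[2]["index"], h.astype(np.float32))
  interpreter_dec.set_tensor(input_details_dec[3]["index"], c.astype(np.float32))
  interpreter_dec.invoke()
  # NOTE: the TFLite output order may differ from the Keras order (logit, context, h, c).
  # The indices below assume the order produced by the converter here; check
  # output_details_dec if you change the model or the TF version.
  dec_outputs = interpreter_dec.get_tensor(output_details_dec[2]["index"])
  context = interpreter_dec.get_tensor(output_details_dec[3]["index"])
  h = interpreter_dec.get_tensor(output_details_dec[0]["index"])
  c = interpreter_dec.get_tensor(output_details_dec[1]["index"])
  sampled_word_id = int(np.argmax(dec_outputs[0]))
  prev_word_id = sampled_word_id
  if sampled_word_id == eos_word_id or len(decoded_translation.split()) > max_len_tgt:
    break
  if sampled_word_id in tgt_tokenizer.index_word:  # id 0 (padding) has no word
    decoded_translation += ' ' + tgt_tokenizer.index_word[sampled_word_id]

print(decoded_translation.strip())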

## Acknowledgements

* This notebook is based on [Chatbot using seq2seq LSTM models](https://colab.research.google.com/drive/1FKhOYhOz8d6BKLVVwL1YMlmoFQ2ML1DS).
* The number-to-words task used here comes from CS187 at Harvard.
* This project is a course project of [CS249](https://scholar.harvard.edu/vijay-janapa-reddi/classes/cs249r-tinyml) at Harvard.