<a href="https://colab.research.google.com/github/unclepeddy/deeplearning/blob/master/6-misc-examples/bot_with_lstm_seq2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install tensorflow==2.0

In [0]:
import json
import logging

from numpy import argmax, append as np_append, array as np_array, zeros

from tensorflow.keras.activations import softmax
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import to_categorical

## Step 1: Prepare Conversation Data

Download json file containing the 1:1 thread over which you want to train your language model and preprocess it to extract lists of questions and answers. 

In [0]:
# Download the data
!wget -O conversation.json http://storage.googleapis.com/peddy-ai-dl-data/fb-messages/kate.json

In [0]:
def extract_qa_from_conversation(filename, subject):
  """Extract question and answers from 1-on-1 thread.
  Args:
    filename: Name of file containing thread json object downloaded from fb
    subject: Name of subject for whom we are training the language model
  Returns
    (questions, answers) tuple of the same length
  """

  with open(filename) as json_file:
    data = json.load(json_file)

  participants = list(map(lambda x: x['name'], data['participants']))
  assert len(participants) == 2
  assert subject in participants

  raw_messages = data['messages']
  
  # Ensure we start at a message sent not by the subject
  start = None 
  for start in reversed(range(len(raw_messages))):
    if (raw_messages[start]['sender_name'] != subject):
      break

  assert start > 0

  # Aggregate and capture consecutive messages by the same sender
  sender = raw_messages[start]['sender_name']
  text = ""
  questions = list()
  answers = list()
  timestamp = raw_messages[start]['timestamp_ms']

  for i in range(start, -1, -1):
    message = raw_messages[i]

    if (message.get('content') == None or len(message['content']) == 0): 
      logging.warning('message [%s] did not have content - skipping', message)
      continue

    # Escape non-ASCII characters
    message['content'] = message['content'].encode('ascii', errors='ignore').decode()
     
    if (message['sender_name'] == sender):
      text += " " + message['content'].replace('\n', ' ')
    else:
      if (sender == subject):
        answers.append("<BEG>{}<END>".format(text))
      else:
        questions.append(text)
      sender = message['sender_name']
      text = message['content'].replace('\n', ' ')

  # Only append the last message if it were the subject's response
  if (sender == subject):
    answers.append("<BEG>{}<END>".format(text))
  assert len(questions) == len(answers), "questions contains %d elements while answers contains %d elements." % (len(questions), len(answers))

  return (questions, answers)

In [5]:
# Build 2 separate lists containing questions and answers
(q, a) = extract_qa_from_conversation('conversation.json', 'Pedram Pejman')

print("Question: {}\nAnswer: {}".format(q[0], a[0]))

Question:  Kate sent a photo. rad
Answer: <BEG>You can now call each other and see information like Active Status and when you've read messages. When you forget what you were gonna say so you just show off your manicure<END>


## Step 2: Train a simple Seq2Seq Model

We'll use a simple seq2seq architecture to learn a language model over the conversation data we have prepared.

In [0]:
def preprocess_texts(texts, tokenizer, one_hot=False):
  """Returns np.array with texts tokenized and padded to length of longest text."""
  # TODO(peddy): Add OOV Token handling
  tokenized_texts = tokenizer.texts_to_sequences(texts)
  # TODO(peddy): extract max_len as a hparam
  max_len = max([len(x) for x in tokenized_texts])
  padded_texts = pad_sequences(tokenized_texts, maxlen=max_len, padding='post')
  if one_hot:
    vocab_size = len(tokenizer.word_index) + 1
    padded_texts = to_categorical(padded_texts, vocab_size)
  return np_array(padded_texts)

tokenizer = Tokenizer()
tokenizer.fit_on_texts(q + a)
tokenizer.reverse_word_index = dict([(v, k) for (k, v) in tokenizer.word_index.items()])

In [0]:
encoder_input_data = preprocess_texts(q, tokenizer, False)
decoder_input_data = preprocess_texts(a, tokenizer, False)
# TODO(peddy): Drop the <BEG> tag from the response
decoder_output_data = preprocess_texts(a, tokenizer, True)

In [0]:
vocab_size = len(tokenizer.word_index) + 1
embedding_dim = 200
lstm_units = 200
epochs = 4
batch_size = 4
max_len = 32

In [9]:
enc_inputs = Input(shape=(None, ), name='enc_inputs')
enc_embedding = Embedding(vocab_size, embedding_dim, mask_zero=True, name='enc_embedding')(enc_inputs)
_, enc_state_h, enc_state_c = LSTM(lstm_units, return_state=True, name='enc_lstm')(enc_embedding)
enc_states = [enc_state_h, enc_state_c]

dec_inputs = Input(shape=(None, ), name='dec_inputs')
dec_embedding = Embedding(vocab_size, embedding_dim, mask_zero=True, name='dec_embedding')(dec_inputs)
dec_lstm = LSTM(lstm_units, return_state=True, return_sequences=True)
dec_lstm_outputs, _, _ = dec_lstm(dec_embedding, initial_state=enc_states)
dec_softmax = Dense(vocab_size, activation=softmax, name='dec_softmax')
dec_outputs = dec_softmax(dec_lstm_outputs)

training_model = Model([enc_inputs, dec_inputs], dec_outputs)
training_model.compile(loss=categorical_crossentropy, optimizer=RMSprop())

training_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
enc_inputs (InputLayer)         [(None, None)]       0                                            
__________________________________________________________________________________________________
dec_inputs (InputLayer)         [(None, None)]       0                                            
__________________________________________________________________________________________________
enc_embedding (Embedding)       (None, None, 200)    85000       enc_inputs[0][0]                 
__________________________________________________________________________________________________
dec_embedding (Embedding)       (None, None, 200)    85000       dec_inputs[0][0]                 
______________________________________________________________________________________________

In [11]:
training_model.fit([encoder_input_data, decoder_input_data], decoder_output_data, batch_size=batch_size, epochs=epochs)

Train on 23 samples
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tensorflow.python.keras.callbacks.History at 0x7f60f11a9e80>

In [0]:
# Build inference model

inf_dec_state_h_input = Input(shape=(lstm_units, ), name='inf_dec_state_h_input')
inf_dec_state_c_input = Input(shape=(lstm_units, ), name='inf_dec_state_c_input')
inf_dec_state_inputs = [inf_dec_state_h_input, inf_dec_state_c_input]

inf_dec_lstm_outputs, inf_dec_state_h, inf_dec_state_c = dec_lstm(dec_embedding, initial_state=inf_dec_state_inputs)
inf_dec_state = [inf_dec_state_h, inf_dec_state_c]
inf_dec_outputs = dec_softmax(inf_dec_lstm_outputs)

inf_encoder = Model(enc_inputs, enc_states)
inf_decoder = Model([dec_inputs] + inf_dec_state_inputs, [inf_dec_outputs] + inf_dec_state)

In [0]:
tokenizer.reverse_word_index[0] = "OOV"

In [14]:
example_sentence = "do you have the whole thing"

tokens = preprocess_texts([example_sentence], tokenizer)

dec_states = inf_encoder.predict(tokens)
dec_seq = zeros((1, 1))
terminate = False
pred_tokens = list()

while not terminate:
  dec_pred, h, c = inf_decoder.predict([dec_seq] + dec_states)
  pred_word_index = argmax(dec_pred[0, -1, :])
  pred_word = tokenizer.reverse_word_index[pred_word_index]
  pred_tokens.append(pred_word)
  dec_seq = zeros((1,1))
  dec_seq[0, 0] = pred_word_index 
  dec_states = [h, c]
  if (pred_word == 'end' or len(pred_tokens) > max_len):
    break
print(pred_tokens)

['beg', 'beg', 'beg', 'you', 'you', 'you', 'you', 'you', 'you', 'end']
