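# LSTM Chatbot

Trains and evaluates a sequence-to-sequence (encoder–decoder) model that generates an answer for a given question, using question–answer pairs loaded via `src.data`.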

## Preliminaries

In [14]:
import torch
import os
import random
import pandas as pd

from src import data, vocab, model, train, evaluate

# Set constants
project_root = "03-RNNs/final_project/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_model = False  # set True to (re)train; otherwise only the saved checkpoint is used

# Seed every RNG the notebook relies on (`random.choice` for sampling pairs,
# torch for e.g. model weight initialisation) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

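## 1. Load and preprocess data

The text returned by `data.load_data()` is already lower-cased and stemmed (e.g. "notr dame", "main build"), which is why words in the outputs below look truncated.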

In [15]:
# Load the question/answer dataset and preview its first rows
data_df = data.load_data()
data_df.head(n=5)

Unnamed: 0,Question,Answer
0,to whom did the virgin mari alleg appear in 18...,saint bernadett soubir
1,what is in front of the notr dame main build,a copper statu of christ
2,the basilica of the sacr heart at notr dame is...,the main build
3,what is the grotto at notr dame,a marian place of prayer and reflect
4,what sit on top of the main build at notr dame,a golden statu of the virgin mari


In [16]:
tuple(data_df.shape)  # dataset dimensions: (rows, columns)

(98169, 2)

In [17]:
# For development, use only the first N_SAMPLES records to keep training fast.
# Set N_SAMPLES = None to use the full dataset (iloc[:None] selects every row).
N_SAMPLES = 15_000
data_df = data_df.iloc[:N_SAMPLES, :]

## 2. Build question and answer vocabularies

In [18]:
# Questions form the source (encoder-side) vocabulary,
# answers the target (decoder-side) vocabulary.
Q_vocab = vocab.Vocab()
A_vocab = vocab.Vocab()

pairs = data.get_pairs(data_df)

# Feed each question into Q_vocab and each answer into A_vocab
for pair in pairs:
    question, answer = pair[0], pair[1]
    Q_vocab.add_sentence(question)
    A_vocab.add_sentence(answer)

In [19]:
# Sanity check: map a known (stemmed) question to its vocabulary indexes
sample_question = "what is in front of the notr dame main build"
Q_vocab.get_indexes_from_sentence(sample_question)

[14, 15, 10, 16, 17, 5, 18, 19, 20, 21]

In [20]:
# Show one randomly chosen question/answer pair
random_pair = random.choice(pairs)
question, answer = random_pair[0], random_pair[1]
print(f"Random Question Answer pair:\n> Q: {question} \n< A: {answer} \n")

# Report how many distinct words each vocabulary holds
print(
    f"Question vocabulary (input): {Q_vocab.n_words} words\n"
    f"Answer vocabulary (output): {A_vocab.n_words} words"
)

Random Question Answer pair:
> Q: what type of process theolog doe c robert mesl promot 
< A: process natur ie a process theolog without god 

Question vocabulary (input): 9139 words
Answer vocabulary (output): 9837 words


## 3. Initialize model

In [21]:
# Training hyperparameters
learning_rate = 0.01
hidden_size = 128  # shared by encoder and decoder
batch_size = 128
epochs = 2

model_name = "seq2seq_exp_1"

# Input size = question vocabulary, output size = answer vocabulary
input_vocab_size = Q_vocab.n_words
output_vocab_size = A_vocab.n_words
seq2seq = model.Seq2Seq(input_vocab_size, hidden_size, output_vocab_size, model_name)
seq2seq = seq2seq.to(device)

## 4. Train model

In [22]:
# Convert every question and answer into its tensor representation
questions = [pair[0] for pair in pairs]
answers = [pair[1] for pair in pairs]
Q_tensors = [Q_vocab.get_tensor_from_sentence(q) for q in questions]
A_tensors = [A_vocab.get_tensor_from_sentence(a) for a in answers]

In [23]:
# Optionally warm-start training from a previously saved best-loss checkpoint.
load_model_weights = False
if load_model_weights:
    checkpoint_path = f"checkpoints/{model_name}_best_loss.pt"
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model_state_dict = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        # Only a missing file is expected; other failures (corrupt file,
        # state-dict/architecture mismatch) must surface rather than be
        # reported as "not found".
        print(f"Model checkpoints not found: {checkpoint_path}")
    else:
        seq2seq.load_state_dict(model_state_dict)
        print("Model weights loaded")

In [24]:
if train_model:
    # Create the checkpoint directory up front: torch.save below (and presumably
    # train.train's best-loss checkpoint, which section 5 loads from
    # "checkpoints/") fails if it does not exist.
    os.makedirs("checkpoints", exist_ok=True)
    train.train(
        Q_tensors,
        A_tensors,
        seq2seq,
        epochs,
        batch_size,
        1,  # NOTE(review): unnamed positional argument of train.train — confirm its meaning
        learning_rate,
        interactive_tracking=False,
    )
    # Save the weights from the final epoch under a separate "_last_loss" name
    torch.save(
        seq2seq.state_dict(),
        os.path.join("checkpoints", seq2seq.model_name + "_last_loss.pt"),
    )

## 5. Load best model weights for inference

In [25]:
# Load the best-loss checkpoint for inference.
checkpoint_path = f"checkpoints/{model_name}_best_loss.pt"
try:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # without it torch.load raises, which the old bare `except:` hid.
    model_state_dict = torch.load(checkpoint_path, map_location=device)
except FileNotFoundError:
    # Only a missing file is expected here; any other error (corrupt file,
    # state-dict/architecture mismatch) should surface instead of silently
    # leaving the model with untrained weights.
    print(f"Model checkpoints not found: {checkpoint_path}")
else:
    seq2seq.load_state_dict(model_state_dict)
    print("Model weights loaded")

Model weights loaded


## 6. Evaluate the model on random question–answer pairs

In [26]:
# Compare predicted answers with reference answers for a few randomly drawn
# pairs (drawn with replacement, so a pair may appear twice).
n_to_evaluate = 5
with torch.no_grad():  # inference only; no autograd graph is needed
    for _ in range(n_to_evaluate):
        pair = random.choice(pairs)
        pred_words = evaluate.evaluate(seq2seq, Q_vocab, A_vocab, pair)
        print(f"> Question: {pair[0]}")
        print(f"< Answer: {pair[1]}")
        print(f"< Prediction: {' '.join(pred_words)}\n")

> Question: how much did schwarzenegg make from the film total recal on top of 15 of gross
< Anwer: 10 million
< Prediction: more

> Question: what town is the crow and gate locat in
< Anwer: crowborough
< Prediction: east

> Question: when did the zhengd emperor rule
< Anwer: 1505–1521
< Prediction: april

> Question: which countri lie on congo northeast border
< Anwer: central african republ
< Prediction: the and and

> Question: in how mani countri doe unfpa oper
< Anwer: 150
< Prediction: three

