# Chatbot Tutorial

## 1. Preparations

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

## 2. Load & Preprocess Data

In [2]:
corpus_name = 'cornell_movie_dialogs_corpus'
corpus = os.path.join('data', corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as fr:
        lines = fr.readlines()
    for line in lines[:n]:
        print(line)

In [3]:
printLines(os.path.join(corpus, 'movie_lines.txt'))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


### Create formatted data file

Formatted data file: each line contains a tab-separated pair of ***a query sentence*** and ***a response sentence***.
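
As a minimal sketch of this format, the snippet below writes two made-up pairs to an in-memory buffer with the same `csv.writer` settings used later in this tutorial, then reads them back. The sample sentences and the use of `io.StringIO` instead of a real file are assumptions for illustration only.

```python
import csv
import io

# Two made-up <query, response> pairs (illustrative data, not from the corpus)
pairs = [
    ["Hi.", "Hello there."],
    ['She said "no".', "Okay."],
]

# Write one pair per line, with a tab between query and response
buffer = io.StringIO()
writer = csv.writer(buffer, delimiter="\t", lineterminator="\n")
for pair in pairs:
    writer.writerow(pair)

text = buffer.getvalue()
print(repr(text))

# csv.writer quotes a field that contains a double quote;
# csv.reader undoes that quoting when the data is read back
buffer.seek(0)
restored = list(csv.reader(buffer, delimiter="\t"))
print(restored)
```

Reading the file with `csv.reader` recovers the original strings exactly, whereas a plain `line.split('\t')` would keep any quotes that the writer added.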

In [4]:
# Splits each line of the file into a dictionary of fields
# lines is a dict; each entry has key=lineID and value=lineObj={lineID:xxx, characterID:xxx, movieID:xxx, character:xxx, text:xxx}
def loadLines(file, cols):
    lines = {}
    with open(file, 'r', encoding='iso-8859-1') as fr:
        for line in fr:
            values = line.split(' +++$+++ ')
            lineObj = {col: values[i] for i, col in enumerate(cols)}
            lines[lineObj['lineID']] = lineObj
    return lines
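
To make the parsing step concrete, here is a small sketch that splits one raw row (copied from the sample output of `movie_lines.txt` shown earlier) in the same way as `loadLines`. The variable names `raw` and `line_obj` are only for illustration.

```python
# Field names as used in this tutorial
cols = ["lineID", "characterID", "movieID", "character", "text"]

# One raw row from the sample output of movie_lines.txt
raw = "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n"

# Split on the corpus separator and label each field by position
values = raw.split(" +++$+++ ")
line_obj = {col: values[i] for i, col in enumerate(cols)}
print(line_obj)

# The trailing newline stays in the text field at this stage;
# the tutorial strips it later when sentence pairs are extracted
print(repr(line_obj["text"]))
```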

In [5]:
# Groups cols of lines from `loadLines` into conversations based on movie_conversations.txt
# conversations is a list; each element is convObj: {col1:xxx, col2:xxx, ..., lines: [lineObj1, lineObj2, ..., lineObjm]}
def loadConversations(file, lines, cols):
    conversations = []
    with open(file, 'r', encoding='iso-8859-1') as fr:
        for line in fr:
            values = line.split(' +++$+++ ')
            convObj = {col: values[i] for i, col in enumerate(cols)}
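            # Extract IDs such as 'L194' with a regex rather than calling eval() on file content
            convObj['lines'] = [lines[lineId] for lineId in re.findall(r'L[0-9]+', convObj['utteranceIDs'])]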
            conversations.append(convObj)
    return conversations
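
The `utteranceIDs` field is stored as text that looks like a Python list literal. The sketch below shows two ways to turn it into a list of line IDs without executing arbitrary code. The sample row is an assumption about the shape of `movie_conversations.txt`, not a line quoted from this notebook.

```python
import ast
import re

# Hypothetical row in the assumed shape of movie_conversations.txt
raw = "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']\n"
cols = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

values = raw.split(" +++$+++ ")
conv = {col: values[i] for i, col in enumerate(cols)}

# Option 1: parse the text as a literal (only literals are accepted, no code runs)
ids_literal = ast.literal_eval(conv["utteranceIDs"].strip())

# Option 2: pull out every token that looks like a line ID
ids_regex = re.findall(r"L[0-9]+", conv["utteranceIDs"])

print(ids_literal)
print(ids_regex)
```

Both options give the same list for well-formed rows. The regex version is more forgiving of stray characters, while `ast.literal_eval` raises an error on malformed input.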

> 刘尧: The training data consists of the sentence pairs generated from all sentences in each conversation: <previous sentence, next sentence\>

In [6]:
# Extracts pairs of sentences from conversations
# qa_pair is a list; each element is a sentence pair: [conv1_text1,conv1_text2], [1_2,1_3], [1_3,1_4], [1_4,1_5], ..., [2_1,2_2], [2_2,2_3], ...
def extractSentencePairs(conversations):
    qa_pair = []
    for conv in conversations:
        for i in range(len(conv['lines']) - 1):  # Ignore the last line (no answer for it)
            inputLine = conv['lines'][i]['text'].strip()
            targetLine = conv['lines'][i + 1]['text'].strip()
            if inputLine and targetLine:
                qa_pair.append([inputLine, targetLine])
    return qa_pair
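
The pairing logic can be seen on a toy conversation. This is a simplified sketch that works on a plain list of strings rather than on the `convObj` dictionaries; the helper name `consecutive_pairs` and the sample utterances are hypothetical.

```python
def consecutive_pairs(texts):
    """Pair each utterance with the one that follows it, skipping empty ones."""
    pairs = []
    for i in range(len(texts) - 1):  # the last utterance has no response
        query = texts[i].strip()
        response = texts[i + 1].strip()
        if query and response:
            pairs.append([query, response])
    return pairs


# A made-up conversation; the third utterance is blank after stripping
conversation = ["Hi.", "Hello.", "  ", "How are you?", "Fine."]
result = consecutive_pairs(conversation)
print(result)
# [['Hi.', 'Hello.'], ['How are you?', 'Fine.']]
```

A conversation with n non-empty utterances yields n - 1 pairs, and every utterance except the first and last appears once as a response and once as a query.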

In [7]:
datafile = os.path.join(corpus, 'formatted_movie_lines.txt')
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, 'unicode_escape'))  # Unescapes sequences such as a literal backslash-t; worth studying more closely later

lines = {}
conversations = []
MOVIE_LINES_COLS = ['lineID', 'characterID', 'movieID', 'character', 'text']
MOVIE_CONVERSATIONS_COLS = ['character1ID', 'character2ID', 'movieID', 'utteranceIDs']
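
The `codecs.decode(delimiter, 'unicode_escape')` call in the cell above is easier to understand with a small experiment. The sketch below compares a two-character escaped string with a real tab character; the variable names are only for illustration.

```python
import codecs

escaped = r"\t"   # two characters: a backslash followed by "t"
real_tab = "\t"   # one character: an actual tab

# unicode_escape turns the escaped form into the real character
decoded = codecs.decode(escaped, "unicode_escape")
print(len(escaped), len(decoded))

# A string that already holds a real tab passes through unchanged
unchanged = codecs.decode(real_tab, "unicode_escape")
print(decoded == real_tab, unchanged == real_tab)
```

So with `delimiter = '\t'` the call changes nothing. It only matters when the delimiter arrives as escaped text, for example from a command-line argument or a config file.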

In [8]:
print('\nProcessing corpus and loading conversations...')
lines = loadLines(os.path.join(corpus, 'movie_lines.txt'), MOVIE_LINES_COLS)
conversations = loadConversations(os.path.join(corpus, 'movie_conversations.txt'), lines, MOVIE_CONVERSATIONS_COLS)


Processing corpus and loading conversations...


In [9]:
print('\nWriting newly formatted file ...')
with open(datafile, 'w', encoding='utf8') as fw:
    writer = csv.writer(fw, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)
print('\nSample lines from file: ')
printLines(datafile)


Writing newly formatted file ...

Sample lines from file: 
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a 

### Load and trim data

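> 刘尧: Encapsulate the vocabulary, its auxiliary or derived variables, and the related methods in a class. This both protects the data and makes it convenient to use. Excellent!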

Because we are dealing with sequences of **words**, we need a **vocabulary**: a mapping from each unique word that we encounter in our dataset to an index value.

For this we define a ***Vocabulary*** class, which, besides its *name*, has 5 attributes and 3 methods:

- 5 attributes

    - **word2index**: A mapping from each word to its index

    - **index2word**: A reverse mapping from each index to its word

    - **word2count**: A mapping from each word to its count

    - num_words: The number of entries in the vocabulary, including the 3 special tokens

    - trimmed: Whether infrequently seen words have already been trimmed

- 3 methods

    - **addWord**: Adds a word to the vocabulary

    - addSentence: Adds all words in a sentence

    - trim: Removes infrequently seen words

In [12]:
# Default word tokens
PAD_TOKEN = 0
SOS_TOKEN = 1  # Start of sentence
EOS_TOKEN = 2  # End of sentence 

MAX_LENGTH = 10

In [20]:
class Vocabulary(object):
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}  # Does not contain PAD, SOS, EOS by default
        self.word2count = {}
        self.index2word = {PAD_TOKEN: 'PAD', SOS_TOKEN: 'SOS', EOS_TOKEN: 'EOS'}
        self.num_words = 3  # SOS, EOS, PAD
        
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words    # Each new word gets the next available index
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
    
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        
        keep_words = [k for k, v in self.word2count.items() if v >= min_count]
        print(f'keep_words {len(keep_words)} / {len(self.word2index)} = {len(keep_words) / len(self.word2index): .4f}')
        
        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_TOKEN: 'PAD', SOS_TOKEN: 'SOS', EOS_TOKEN: 'EOS'}
        self.num_words = 3
        
        for word in keep_words:
            self.addWord(word)
            
        self.trimmed = True
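
To show how the three mappings relate, here is a compact sketch that mirrors the bookkeeping of `addWord` and `addSentence` with plain dictionaries. It is a simplified stand-in rather than the class itself; the function name `add_sentence` and the sample sentences are assumptions.

```python
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = 0, 1, 2

word2index, word2count = {}, {}
index2word = {PAD_TOKEN: "PAD", SOS_TOKEN: "SOS", EOS_TOKEN: "EOS"}


def add_sentence(sentence):
    """Register every word of a space-separated sentence."""
    for word in sentence.split(" "):
        if word not in word2index:
            index = len(index2word)  # next free index; 3 for the first word
            word2index[word] = index
            index2word[index] = word
            word2count[word] = 1
        else:
            word2count[word] += 1


add_sentence("hi there .")
add_sentence("hi again .")

print(word2index)  # {'hi': 3, 'there': 4, '.': 5, 'again': 6}
print(word2count)  # {'hi': 2, 'there': 1, '.': 2, 'again': 1}

# Encode a sentence as indices and decode it back
encoded = [word2index[w] for w in "hi again .".split(" ")]
decoded = [index2word[i] for i in encoded]
print(encoded, decoded)
```

Indices 0 to 2 are reserved for the special tokens, so the first real word always receives index 3.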

Some data preprocessing:

- **unicodeToAscii**: Convert the Unicode strings to ASCII

- **normalizeString**: Convert all letters to lowercase and trim all non-letter characters except for basic punctuation

- **filterPairs**: Keep only the pairs in which both sentences are shorter than the *MAX_LENGTH* threshold (counted in words)

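> 刘尧: Routine preprocessing steps like these are best wrapped into individual functions for easy reuse. They can go into a **general-purpose coding utilities script**!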

In [13]:
def unicodeToAscii(s):
    """Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427"""
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

In [14]:
def normalizeString(s):
    """Lowercase, trim, and remove non-letter characters"""
    s = unicodeToAscii(s.lower().strip())
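    s = re.sub(r'([.!?])', r' \1', s)      # Insert a space before each of . ! ? so they become separate tokens
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)  # Replace any run of characters other than letters and . ! ? with a space
    s = re.sub(r'\s+', r' ', s).strip()    # Collapse runs of whitespace into a single space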
    return s
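
A quick way to check what the normalization does is to run it on a couple of strings. The sketch below repeats the two functions under snake_case names so that it runs on its own; the sample inputs are made up.

```python
import re
import unicodedata


def unicode_to_ascii(s):
    """Drop combining marks (accents) after NFD decomposition."""
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


def normalize_string(s):
    """Lowercase, split off . ! ?, and replace other non-letters with spaces."""
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s


print(normalize_string("Aren't you coming?  Saturday!"))
# aren t you coming ? saturday !
print(normalize_string("Café, s'il vous plaît..."))
# cafe s il vous plait . . .
```

Apostrophes are replaced by spaces, which is why contractions such as "you're" show up as "you re" in the printed pairs.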

In [15]:
def readVocs(datafile, corpus_name):
    """Read <query, response> pairs; return an empty Vocabulary object and the normalized pairs"""
    with open(datafile, encoding='utf8') as fr:
        lines = fr.read().strip().split('\n')
    pairs = [[normalizeString(s) for s in line.split('\t')] for line in lines]
    voc = Vocabulary(corpus_name)
    return voc, pairs

In [17]:
def filterPair(pair):
    """Return True iff both sentences in pair are under the MAX_LENGTH threshold"""
    return len(pair[0].split(' ')) < MAX_LENGTH and len(pair[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    """Filter pairs using filterPair function"""
    return [pair for pair in pairs if filterPair(pair)]
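
Note that the length check is strict: a sentence with exactly *MAX_LENGTH* words is rejected. The sketch below illustrates this on synthetic sentences; the helper name `is_short` and the data are hypothetical.

```python
MAX_LENGTH = 10


def is_short(sentence, max_length=MAX_LENGTH):
    """True when the sentence has fewer than max_length space-separated words."""
    return len(sentence.split(" ")) < max_length


# Synthetic sentences with exactly 9 and exactly 10 words
nine = " ".join(["w"] * 9)
ten = " ".join(["w"] * 10)

pairs = [[nine, "ok ."], [ten, "ok ."], ["hi .", ten]]

# A pair survives only if both of its sentences pass the check
kept = [p for p in pairs if is_short(p[0]) and is_short(p[1])]
print(len(pairs), "->", len(kept))
```

Only the first pair survives, because each of the other two contains a 10-word sentence.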

In [21]:
def loadPrepareData(corpus, corpus_name, datafile, save_dir):  # NOTE: corpus and save_dir are not used inside this function
    """Using the functions above, return a populated Vocabulary object and pairs list"""
    print('Start preparing training data ...')
    vocabulary, pairs = readVocs(datafile, corpus_name)
    print('Read {!s} sentence pairs'.format(len(pairs)))
    pairs = filterPairs(pairs)
    print('Trimmed to {!s} sentence pairs'.format(len(pairs)))
    print('Counting words ...')
    for pair in pairs:
        vocabulary.addSentence(pair[0])
        vocabulary.addSentence(pair[1])
    print('Counted words: ', vocabulary.num_words)
    return vocabulary, pairs

In [31]:
# Load/Assemble vocabulary and pairs
save_dir = os.path.join('data', 'save')
vocabulary, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)

Start preparing training data ...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words ...
Counted words:  18008


In [32]:
for pair in pairs[:10]:
    print(pair)

['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


Another tactic that helps training converge faster is **trimming rarely used words out of our vocabulary**.

Shrinking the feature space also makes the function that the model must learn to approximate easier.

We will do this as a two-stage process:
    
- Trim words used fewer than *MIN_COUNT* times using the *Vocabulary.trim* function

- Filter out pairs with trimmed words

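> 刘尧: First identify the rare words and remove them from the Vocabulary, which makes them **OOV words**; then remove the pairs containing these OOV words from the training data!

> 刘尧: Question: what should happen when the model meets an OOV word at inference time? As in training, should the input be processed with trimRareWords first?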

In [33]:
MIN_COUNT = 3
def trimRareWords(vocabulary, pairs, MIN_COUNT):
    """Based on MIN_COUNT, remove rare words from vocabulary, then drop pairs containing any rare word from the training/inference data"""
    vocabulary.trim(MIN_COUNT)
    
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input, keep_output = True, True
        
        # Check whether either sentence in the pair contains an OOV word; if so, drop the pair
        for word in input_sentence.split(' '):
            if word not in vocabulary.word2index:
                keep_input = False
                break
        for word in output_sentence.split(' '):
            if word not in vocabulary.word2index:
                keep_output = False
                break
        if keep_input and keep_output:
            keep_pairs.append(pair)
            
    print(f'Trimmed from {len(pairs)} pairs to {len(keep_pairs)}, {len(keep_pairs) / len(pairs): .4f} of total')
    return keep_pairs

In [34]:
pairs = trimRareWords(vocabulary, pairs, MIN_COUNT)

keep_words 7823 / 18005 =  0.4345
Trimmed from 64271 pairs to 53165,  0.8272 of total
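
The two-stage trimming can be reproduced on a toy dataset. This sketch uses `collections.Counter` in place of the `Vocabulary` class, which is a simplification; the pairs and the threshold value are made up.

```python
from collections import Counter

# Made-up, already normalized pairs
pairs = [
    ["hi there .", "hi ."],
    ["hi again .", "bye ."],
    ["see you there .", "bye ."],
]
MIN_COUNT = 2

# Stage 1: count every word and keep those seen at least MIN_COUNT times
counts = Counter(
    word for pair in pairs for sentence in pair for word in sentence.split(" ")
)
keep_words = {word for word, count in counts.items() if count >= MIN_COUNT}

# Stage 2: keep a pair only if all words in both sentences survived stage 1
kept_pairs = [
    pair for pair in pairs
    if all(word in keep_words for sentence in pair for word in sentence.split(" "))
]

print(sorted(keep_words))  # ['.', 'bye', 'hi', 'there']
print(kept_pairs)          # [['hi there .', 'hi .']]
```

One rare word is enough to discard a whole pair, which is why the share of pairs removed can differ a lot from the share of words removed.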


## 3. Prepare Data for Models

## 4. Define Models

### Seq2Seq Model



### Encoder

### Decoder

## 5. Define Training Procedure

### Masked loss

### Single training iteration

### Training iteration

## 6. Define Evaluation

### Greedy decoding

### Evaluate my text

## 7. Run Model

### Run Training

### Run Evaluation

## 8. Conclusion