# Neural Conversational Model
**Jin Yeom**  
jin.yeom@hudl.com

This notebook reproduces [this tutorial](https://pytorch.org/tutorials/beginner/chatbot_tutorial.html) from the official PyTorch documentation. While the tutorial itself is quite interesting, our focus will be on learning how to work with sequence data and recurrent neural networks. Hopefully, we'll never have to do anything with natural language models.

In [1]:
import codecs
import csv
import os
import re
import unicodedata

In [2]:
import torch
from torch import nn, optim
from torch.nn import functional as F

In [3]:
print('PyTorch version:', torch.__version__)

PyTorch version: 1.0.1.post2


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

device: cuda


## Data preprocessing

We'll start by downloading the [Cornell Movie-Dialogs Corpus](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) dataset.

In [5]:
%%bash
mkdir -p datasets
cd datasets
wget -q http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
unzip -q cornell_movie_dialogs_corpus.zip
rm cornell_movie_dialogs_corpus.zip

In [6]:
dataset_path = 'datasets/cornell movie-dialogs corpus'
lines_path = os.path.join(dataset_path, 'movie_lines.txt')
convs_path = os.path.join(dataset_path, 'movie_conversations.txt')

In [7]:
def peek(filename, n=10):
    with open(filename, 'rb') as f:
        lines = f.readlines()
        for line in lines[:n]:
            print(line)

In [8]:
peek(lines_path)

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


Let's begin with preprocessing the data to the correct format!
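
As a minimal sketch of what "the correct format" involves, the snippet below splits one raw line from the output above on the ` +++$+++ ` separator. The helper name `parse_line` is hypothetical and used only for illustration; the field names match the ones this notebook passes to its loaders.

```python
# Hypothetical helper for illustration only; not part of the tutorial code.
SEPARATOR = ' +++$+++ '
FIELDS = ['lineID', 'characterID', 'movieID', 'character', 'text']

def parse_line(raw, fields=FIELDS):
    """Split one raw corpus line into a {field: value} dict."""
    values = raw.rstrip('\n').split(SEPARATOR)
    return dict(zip(fields, values))

sample = 'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
record = parse_line(sample)
print(record['lineID'], '->', record['text'])
```

Each record is keyed by its `lineID`, which is what lets the conversation file refer to utterances by ID.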

In [9]:
def load_lines(filename, fields):
    lines = {}
    with open(filename, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            line_obj = {}
            for i, field in enumerate(fields):
                line_obj[field] = values[i]
            lines[line_obj['lineID']] = line_obj
    return lines

In [10]:
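import ast

def load_convs(filename, lines, fields):
    convs = []
    with open(filename, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            conv_obj = {}
            for i, field in enumerate(fields):
                conv_obj[field] = values[i]
            # utteranceIDs holds a Python list literal of line IDs;
            # literal_eval parses it without executing arbitrary code
            line_ids = ast.literal_eval(conv_obj['utteranceIDs'].strip())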
            conv_obj['lines'] = []
            for line_id in line_ids:
                conv_obj['lines'].append(lines[line_id])
            convs.append(conv_obj)
    return convs

In [11]:
def extract_sentence_pairs(convs):
    qa_pairs = []
    for conv in convs:
        for i in range(len(conv['lines']) - 1):
            input_line = conv['lines'][i]['text'].strip()
            target_line = conv['lines'][i+1]['text'].strip()
            if input_line and target_line:
                qa_pairs.append((input_line, target_line))
    return qa_pairs
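
To make the pairing rule concrete, here is a small sketch on a toy conversation: every utterance is paired with the one that follows it, and a pair is dropped if either side is empty. The function name `consecutive_pairs` and the toy data are assumptions for illustration; the logic mirrors `extract_sentence_pairs` on plain strings.

```python
def consecutive_pairs(utterances):
    """Pair each utterance with its successor, skipping pairs with an empty side."""
    pairs = []
    for prev, nxt in zip(utterances, utterances[1:]):
        prev, nxt = prev.strip(), nxt.strip()
        if prev and nxt:
            pairs.append((prev, nxt))
    return pairs

# Toy conversation with one empty utterance in the middle
toy_conv = ['She okay?', 'I hope so.', '', "Let's go."]
toy_pairs = consecutive_pairs(toy_conv)
print(toy_pairs)
```

Note that an empty utterance removes two pairs: the one where it would be the target and the one where it would be the input.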

In [12]:
data_path = os.path.join(dataset_path, 'formatted_movie_lines.txt')
delimiter = str(codecs.decode('\t', 'unicode_escape'))

lines = {}
convs = {}
lines_fields = ['lineID', 'characterID', 'movieID', 'character', 'text']
convs_fields = ['character1ID', 'character2ID', 'movieID', 'utteranceIDs']

print("Processing corpus...", end='')
lines = load_lines(lines_path, lines_fields)
print("done")

print("Loading conversations...", end='')
convs = load_convs(convs_path, lines, convs_fields)
print("done")

print("Writing formatted file...", end='')
with open(data_path, 'w', encoding='utf-8') as f:
    writer = csv.writer(f, delimiter=delimiter, lineterminator='\n')
    for pair in extract_sentence_pairs(convs):
        writer.writerow(pair)
print("done")

Processing corpus...done
Loading conversations...done
Writing formatted file...done


In [13]:
peek(data_path)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\n"
b'Why?\tUnsolved mystery.  She used t

Now, let's process the formatted data further to a suitable format for training. Here, we create some vocabulary features, which include mapping from words to their indices, reverse mapping from indices to words, a count of each word and total word count. Additionally, we'll trim down words whose counts are below a threshold.
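
Before the full class, a compact sketch of the same idea may help: count words, keep those that reach a minimum count, and assign indices after three reserved tokens. `build_vocab` and the toy sentences are hypothetical; this uses `collections.Counter` rather than the incremental bookkeeping of the class in the next cell.

```python
from collections import Counter

RESERVED = ['PAD', 'SOS', 'EOS']  # special tokens occupy indices 0, 1, 2

def build_vocab(sentences, min_count=1):
    """Return (word2index, index2word, counts) for space-separated sentences."""
    counts = Counter(w for s in sentences for w in s.split(' '))
    kept = [w for w, c in counts.items() if c >= min_count]
    index2word = RESERVED + kept
    # Regular words start right after the reserved tokens
    word2index = {w: i for i, w in enumerate(index2word) if i >= len(RESERVED)}
    return word2index, index2word, counts

toy = ['hi there .', 'hi !', 'where ?']
word2index, index2word, counts = build_vocab(toy, min_count=2)
print(word2index)
print(index2word)
```

With `min_count=2`, only `hi` survives in this toy corpus, so it receives index 3.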

In [29]:
class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.reset()
        
    def reset(self):
        self.words = ['PAD', 'SOS', 'EOS']
        self.word2index = {}
        self.word2count = {}
        self.count = 3
        
    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)
            
    def add_word(self, word):
        if word not in self.word2index:
            self.words.append(word)
            self.word2index[word] = self.count
            self.word2count[word] = 1
            self.count += 1
        else:
            self.word2count[word] += 1
            
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True
        trimmed = []
        for k, v in self.word2count.items():
            if v >= min_count:
                trimmed.append(k)
        print(f"Trimmed from {self.count} words to {len(trimmed)} words")
        print(f"Down to {len(trimmed)/self.count*100}%")
        self.reset()
        for word in trimmed:
            self.add_word(word)

In [15]:
def unicode2ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
            if unicodedata.category(c) != 'Mn')

In [16]:
def normalize(s):
    s = unicode2ascii(s.lower().strip())
    s = re.sub(r'([.!?])', r' \1', s)
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    return re.sub(r'\s+', r' ', s).strip()
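
A quick check of what this normalization does to a single string can be sketched as follows. The block restates the two functions under the hypothetical names `strip_accents` and `normalize_demo` so that it runs on its own; the sample sentence is made up.

```python
import re
import unicodedata

def strip_accents(s):
    """Decompose characters and drop combining marks (category 'Mn')."""
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def normalize_demo(s):
    s = strip_accents(s.lower().strip())
    s = re.sub(r'([.!?])', r' \1', s)       # put a space before . ! ?
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)   # anything else non-alphabetic becomes a space
    return re.sub(r'\s+', r' ', s).strip()  # collapse repeated whitespace

result = normalize_demo("Aren't you coming?  Déjà vu!")
print(result)  # aren t you coming ? deja vu !
```

Apostrophes are replaced by spaces, which is why contractions such as `you re` and `that s` appear split in the processed pairs.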

In [17]:
def read_vocs(filename, corpus):
    with open(filename, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    pairs = [[normalize(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus)
    return voc, pairs

In [18]:
def filter_pair(p, max_length):
    return (len(p[0].split(' ')) < max_length and
            len(p[1].split(' ')) < max_length)

In [19]:
def filter_pairs(pairs, max_length):
    return [p for p in pairs if filter_pair(p, max_length)]
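
One detail worth spelling out is that the comparison is strict, so with `max_length=10` a sentence may contain at most 9 space-separated tokens. The sketch below, using the hypothetical name `keep_pair`, checks that boundary on synthetic sentences.

```python
def keep_pair(pair, max_length=10):
    """Keep a pair only if both sentences have fewer than max_length tokens."""
    return all(len(s.split(' ')) < max_length for s in pair)

nine_tokens = ' '.join(['w'] * 9)
ten_tokens = ' '.join(['w'] * 10)

print(keep_pair([nine_tokens, nine_tokens]))  # True
print(keep_pair([nine_tokens, ten_tokens]))   # False
```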

In [20]:
def load_data(filename, corpus, max_length=10):
    voc, pairs = read_vocs(filename, corpus)
    print(f"Read {len(pairs)} sentence pairs")
    pairs = filter_pairs(pairs, max_length)
    print(f"Filtered to {len(pairs)} sentence pairs")
    for p in pairs:
        voc.add_sentence(p[0])
        voc.add_sentence(p[1])
    print(f"Word count: {voc.count}")
    return voc, pairs

In [31]:
voc, pairs = load_data(data_path, dataset_path)

Read 221282 sentence pairs
Filtered to 64271 sentence pairs
Word count: 18008


In [32]:
for pair in pairs[:10]:
    print(pair)

['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


Here, we trim words that are used infrequently. More specifically, we drop every word that appears fewer than 3 times in the corpus, along with any sentence pair that contains such a word.

In [33]:
def trim_words(voc, pairs, min_count=3):
    voc.trim(min_count)
    trimmed = []
    for p in pairs:
        valid_input = all(w in voc.word2index for w in p[0].split(' '))
        valid_output = all(w in voc.word2index for w in p[1].split(' '))
        if valid_input and valid_output:
            trimmed.append(p)
    print(f"Trimmed from {len(pairs)} pairs to {len(trimmed)} pairs")
    print(f"Down to {len(trimmed)/len(pairs)*100}%")
    return trimmed

In [34]:
pairs = trim_words(voc, pairs)

Trimmed from 18008 words to 7823 words
Down to 43.44180364282541%
Trimmed from 64271 pairs to 53165 pairs
Down to 82.72004481025657%


I guess this makes sense: about 83% of our sentence pairs use only the 43% of the vocabulary that appears at least 3 times, while the other 57% of the words are rare enough that they hardly ever come up. Ideally, we'd like our agent to be able to quickly adopt those rare words, but we'll just pretend they don't exist.
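
The percentages can be rechecked directly from the counts printed above. This is plain arithmetic on those four numbers; the variable names are chosen here for illustration.

```python
# Counts copied from the output of the trimming cell
words_before, words_after = 18008, 7823
pairs_before, pairs_after = 64271, 53165

kept_words = words_after / words_before
kept_pairs = pairs_after / pairs_before

print(f"vocabulary kept: {kept_words:.1%}, dropped: {1 - kept_words:.1%}")
print(f"pairs kept: {kept_pairs:.1%}, dropped: {1 - kept_pairs:.1%}")
```

Dropping more than half of the vocabulary costs well under a fifth of the sentence pairs, which is the trade-off the trimming step relies on.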

## Data preparation for training