# Neural Conversational Model
**Jin Yeom**  
jin.yeom@hudl.com

This notebook reproduces the [chatbot tutorial](https://pytorch.org/tutorials/beginner/chatbot_tutorial.html) from the official PyTorch documentation. While the chatbot itself is interesting, our focus will be on learning how to work with sequence data and recurrent neural networks.

In [42]:
import os
import codecs
import csv

In [8]:
import torch
from torch import nn, optim
from torch.nn import functional as F

In [9]:
print('PyTorch version:', torch.__version__)

PyTorch version: 1.0.1.post2


In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

device: cuda


## Dataset

We'll start by downloading the [Cornell Movie-Dialogs Corpus](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) dataset.

In [16]:
%%bash
mkdir -p datasets  # -p: don't fail if the directory already exists
cd datasets
wget -q http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
unzip -q cornell_movie_dialogs_corpus.zip
rm cornell_movie_dialogs_corpus.zip

In [39]:
dataset_path = 'datasets/cornell movie-dialogs corpus'
lines_path = os.path.join(dataset_path, 'movie_lines.txt')
convs_path = os.path.join(dataset_path, 'movie_conversations.txt')

In [36]:
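from itertools import islice

def peek(filename, n=10):
    """Print the first n lines of a file as raw bytes."""
    with open(filename, 'rb') as f:
        # islice stops after n lines instead of reading the whole file into memory.
        for line in islice(f, n):
            print(line)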

In [40]:
peek(lines_path)

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'
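Each record in `movie_lines.txt` holds five fields separated by the literal token ` +++$+++ `: line ID, character ID, movie ID, character name, and the utterance text. The sketch below (the helper name `parse_line_record` is hypothetical) splits one of the lines printed above and pairs the pieces with field names, much as `load_lines` does next. Note that the text field keeps its trailing newline until it is stripped later.

```python
# Field names as used for movie_lines.txt in this notebook.
FIELDS = ['lineID', 'characterID', 'movieID', 'character', 'text']
SEPARATOR = ' +++$+++ '

def parse_line_record(raw):
    """Split one raw movie_lines.txt record into a field-name -> value dict."""
    values = raw.split(SEPARATOR)
    # zip stops at the shorter sequence, so a record with too few fields
    # yields a partial dict rather than an IndexError.
    return dict(zip(FIELDS, values))

record = parse_line_record('L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n')
print(record['lineID'], record['character'], repr(record['text']))
# L1045 BIANCA 'They do not!\n'
```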


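Now for some data preprocessing. We parse `movie_lines.txt` into a dictionary of lines keyed by line ID, group those lines into conversations using `movie_conversations.txt`, and extract consecutive (input, target) sentence pairs from each conversation.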

In [43]:
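def load_lines(filename, fields):
    """Parse movie_lines.txt into a dict mapping lineID -> {field: value}."""
    lines = {}
    with open(filename, 'r', encoding='iso-8859-1') as f:
        for line in f:
            # Fields are ' +++$+++ '-separated; the text field keeps its trailing newline.
            values = line.split(' +++$+++ ')
            line_obj = {}
            for i, field in enumerate(fields):
                line_obj[field] = values[i]
            lines[line_obj['lineID']] = line_obj
    return lines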
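Both files are opened with `encoding='iso-8859-1'` rather than UTF-8. One property worth knowing: ISO-8859-1 (Latin-1) maps each of the 256 byte values to exactly one code point (U+0000 to U+00FF), so decoding with it never raises, while a strict UTF-8 decode fails on invalid byte sequences. The check below illustrates that property; it says nothing about how the corpus was actually encoded, so non-ASCII characters could still come out as the wrong glyphs if the source bytes used a different encoding.

```python
# Every possible byte value, 0x00 through 0xFF.
all_bytes = bytes(range(256))

# Latin-1 maps byte b to the code point with the same value, so this cannot fail.
latin1_text = all_bytes.decode('iso-8859-1')
assert len(latin1_text) == 256
assert all(ord(ch) == b for ch, b in zip(latin1_text, all_bytes))

sample = b'caf\xe9'  # 'café' encoded as Latin-1
print(sample.decode('iso-8859-1'))  # café

# Under UTF-8, 0xE9 announces a 3-byte sequence that never completes,
# so a strict decode raises instead of returning text.
utf8_failed = False
try:
    sample.decode('utf-8')
except UnicodeDecodeError as err:
    utf8_failed = True
    print('utf-8 decode failed:', err.reason)
```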

In [44]:
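import ast

def load_convs(filename, lines, fields):
    """Parse movie_conversations.txt, attaching the line objects of each conversation."""
    convs = []
    with open(filename, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            conv_obj = {}
            for i, field in enumerate(fields):
                conv_obj[field] = values[i]
            # utteranceIDs is a list literal such as "['L1044', 'L1045']";
            # literal_eval parses it without executing arbitrary code, unlike eval.
            line_ids = ast.literal_eval(conv_obj['utteranceIDs'].strip())
            conv_obj['lines'] = []
            for line_id in line_ids:
                conv_obj['lines'].append(lines[line_id])
            convs.append(conv_obj)
    return convs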
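Each record in `movie_conversations.txt` ends with the IDs of its utterances, stored as a Python list literal in text form. Turning that string back into a list with `eval` would execute arbitrary code if the file were ever tampered with; `ast.literal_eval` from the standard library accepts only literals (strings, numbers, tuples, lists, dicts, sets, booleans, `None`), so it is a safer fit for this job. The record below is hypothetical, written in the file's format using two line IDs printed earlier; it is not copied from the corpus.

```python
import ast

SEPARATOR = ' +++$+++ '
CONV_FIELDS = ['character1ID', 'character2ID', 'movieID', 'utteranceIDs']

# Hypothetical record in the movie_conversations.txt format (not taken from the file).
raw = "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L1044', 'L1045']\n"

conv = dict(zip(CONV_FIELDS, raw.split(SEPARATOR)))
# literal_eval parses the list literal without executing any code.
utterance_ids = ast.literal_eval(conv['utteranceIDs'].strip())
print(utterance_ids)  # ['L1044', 'L1045']

# Anything that is not a plain literal, such as a function call, is rejected.
try:
    ast.literal_eval("__import__('os').getcwd()")
    rejected = False
except ValueError:
    rejected = True
print('function call rejected:', rejected)
```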

In [45]:
def extract_sentence_pairs(convs):
    """Return (input, target) pairs of consecutive, non-empty utterances."""
    qa_pairs = []
    for conv in convs:
        for i in range(len(conv['lines']) - 1):
            input_line = conv['lines'][i]['text'].strip()
            target_line = conv['lines'][i+1]['text'].strip()
            if input_line and target_line:
                qa_pairs.append((input_line, target_line))
    return qa_pairs
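`extract_sentence_pairs` slides a window of two over each conversation: every utterance becomes an input and the next one its target, so a conversation of n lines yields at most n - 1 pairs, and a pair is dropped when either side is empty after stripping. The same idea can be written with `zip` over the list and its one-step-shifted copy. The helper name `sentence_pairs` and the conversation below are made up for illustration.

```python
def sentence_pairs(utterances):
    """Pair each utterance with the one that follows it, skipping blanks."""
    stripped = [u.strip() for u in utterances]
    # zip(xs, xs[1:]) produces the overlapping windows (x0, x1), (x1, x2), ...
    return [(q, a) for q, a in zip(stripped, stripped[1:]) if q and a]

# Hypothetical four-line conversation; the third line is blank.
conversation = ['Hi.\n', 'Hello there.\n', '   \n', 'Bye.\n']
print(sentence_pairs(conversation))
# [('Hi.', 'Hello there.')]
```

The blank line removes two candidate pairs at once, since it would have been the target of one pair and the input of the next.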

In [48]:
data_path = os.path.join(dataset_path, 'formatted_movie_lines.txt')
delimiter = '\t'  # one pair per row: input and target separated by a tab

lines_fields = ['lineID', 'characterID', 'movieID', 'character', 'text']
convs_fields = ['character1ID', 'character2ID', 'movieID', 'utteranceIDs']

print("Processing corpus...", end='')
lines = load_lines(lines_path, lines_fields)
print("done")

print("Loading conversations...", end='')
convs = load_convs(convs_path, lines, convs_fields)
print("done")

print("Writing formatted file...", end='')
# newline='' lets the csv module control line endings, as its documentation recommends.
with open(data_path, 'w', encoding='utf-8', newline='') as f:
    writer = csv.writer(f, delimiter=delimiter, lineterminator='\n')
    for pair in extract_sentence_pairs(convs):
        writer.writerow(pair)
print("done")

Processing corpus...done
Loading conversations...done
Writing formatted file...done
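The formatted file is written with `csv.writer`, whose default quoting (`csv.QUOTE_MINIMAL`) wraps a field in double quotes, and doubles every quote inside it, whenever the field contains the delimiter, the quote character, or a line-terminator character. Utterances containing quotes, such as BIANCA's `"persona"` line above, will therefore be stored quoted, and reading the file back with a plain `split('\t')` would leave that quoting in the text. A minimal round-trip sketch, with an in-memory buffer standing in for the real file and hypothetical pairs built from lines printed earlier, shows that `csv.reader` with the same delimiter recovers the original strings:

```python
import csv
import io

# Hypothetical pairs; the first contains double quotes, like the "persona" line above.
pairs = [
    ('I\'m kidding.  You know how sometimes you just become this "persona"?', 'No'),
    ('They do not!', 'They do to!'),
]

# An in-memory buffer stands in for formatted_movie_lines.txt.
buffer = io.StringIO()
writer = csv.writer(buffer, delimiter='\t', lineterminator='\n')
writer.writerows(pairs)
text = buffer.getvalue()

# Naive parsing keeps csv's quoting (outer quotes, doubled inner quotes).
naive = [line.split('\t') for line in text.splitlines()]
print(naive[0][0])

# csv.reader with the same delimiter undoes the quoting.
rows = [tuple(row) for row in csv.reader(io.StringIO(text), delimiter='\t')]
print(rows == pairs)  # True
```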


In [49]:
peek(data_path)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\n"
b'Why?\tUnsolved mystery.  She used t