In [150]:
import ast
import codecs
import csv
import itertools
import os
import random
import re
import unicodedata

import torch
import torch.nn as nn #Neural Networks package
import torch.nn.functional as F
from torch import optim #Optimisers

In [151]:
# Pick the compute device once: the GPU when PyTorch reports CUDA support, the CPU otherwise
CUDA = torch.cuda.is_available()
device = torch.device("cuda") if CUDA else torch.device("cpu")

### Part 1: Preprocessing

In [152]:
# Both raw corpus files live in the same directory, so build the two paths in one pass
lines_filepath, conv_filepath = (
    os.path.join('cornell movie-dialogs corpus', corpus_file)
    for corpus_file in ('movie_lines.txt', 'movie_conversations.txt')
)

In [153]:
#Visualise some lines
# Decode with the same encoding the loaders below use (iso-8859-1). Relying on the
# platform default plus errors="ignore" silently drops or mis-decodes non-ASCII bytes,
# and the result differs between operating systems.
with open(lines_filepath, 'r', encoding='iso-8859-1') as file:
    lines = file.readlines()
for line in lines[:8]:
    print(line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [154]:
#Parse every row of the lines file into a dict of fields
#(lineID, characterID, movieID, character, text) and index those dicts by lineID
lines_fields = ['lineID', 'characterID', 'movieID', 'character', 'text']
lines = {}
with open(lines_filepath, 'r', encoding='iso-8859-1') as f:
    for raw_row in f:
        parts = raw_row.split(' +++$+++ ')
        # Positional lookup (rather than zip) so a row with too few columns still raises IndexError
        record = {field: parts[position] for position, field in enumerate(lines_fields)}
        lines[record['lineID']] = record

In [155]:
#Groups the line dicts stored in `lines` into conversations based on "movie_conversations.txt"
conv_fields = ['characterID', 'character2ID', 'movieID', 'utteranceIDs']
conversations = []
with open(conv_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(' +++$+++ ')
        #Extract fields
        convObj = {field: values[i] for i, field in enumerate(conv_fields)}
        #convObj['utteranceIDs'] is the text "['id123', 'id123213', ...]"; turn it into a real list.
        #ast.literal_eval only accepts Python literals, so unlike eval() it cannot
        #execute code that happens to be embedded in the data file.
        lineIds = ast.literal_eval(convObj['utteranceIDs'])
        #Reassemble lines: replace every utterance id by its parsed line dict
        convObj['lines'] = [lines[lineId] for lineId in lineIds]
        conversations.append(convObj)

In [156]:
#Build [input, reply] pairs from every two consecutive lines of each conversation
qa_pairs = []
for conversation in conversations:
    utterances = [entry['text'].strip() for entry in conversation['lines']]
    #Pair each utterance with the one that follows it; the last one has no reply
    for inputLine, targetLine in zip(utterances, utterances[1:]):
        #Skip pairs in which either side is empty after stripping
        if not (inputLine and targetLine):
            continue
        qa_pairs.append([inputLine, targetLine])

In [157]:
#Define path to new file
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')
#'\t' is already the tab character, so no unescaping step is required
delimiter = '\t'

#Write new csv file
print('\nWriting newly formatted fisle...')
#newline='' hands line-ending control to the csv module; without it, platforms that
#translate '\n' on write (Windows) emit '\r\r\n' after every row
with open(datafile, "w", encoding="utf-8", newline='') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter)
    #One row per [input, reply] pair
    writer.writerows(qa_pairs)
print('Done writing to file')


Writing newly formatted fisle...
Done writing to file


In [158]:
#Peek at the first rows of the formatted file as raw bytes, so tabs and line endings stay visible
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')
with open(datafile, 'rb') as file:
    lines = list(file)
for raw_row in itertools.islice(lines, 8):
    print(raw_row)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\n"
b'Why?\tUnsolved myster

In [159]:
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token <START>
EOS_token = 2 # End-of-sentence token <END>

class Vocabulary:
    """Word <-> index mapping with per-word occurrence counts.

    Indices 0, 1 and 2 are reserved for the PAD, SOS and EOS tokens; real words
    are numbered from 3 upwards in the order they are first added.

    Attributes:
        name: label given to this vocabulary.
        word2index: dict mapping each word to its integer index.
        word2count: dict mapping each word to the number of times it was added.
        index2word: dict mapping each index (reserved tokens included) to its word.
        num_words: number of entries in index2word, reserved tokens included.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 #Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every token of `sentence`, split on single spaces, to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register `word` under the next free index, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words +=1
        else:
            self.word2count[word] += 1

    # Remove words below certain threshold
    def trim(self, min_count):
        """Drop every word seen fewer than `min_count` times and re-index the rest.

        Surviving words keep their relative order and their original counts, so
        calling trim again later still compares against real frequencies.
        Prints how many words were kept.
        """
        kept_counts = {word: count for word, count in self.word2count.items() if count >= min_count}
        total_words = len(self.word2index)
        # An empty vocabulary has nothing to keep; report a ratio of 0 instead of dividing by zero
        kept_ratio = len(kept_counts) / total_words if total_words else 0.0
        print('keep_words {} / {} = {:.4f}'
              .format(len(kept_counts), len(self.word2count), kept_ratio))

        #Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 #Count default tokens

        for word, count in kept_counts.items():
            self.addWord(word)
            # addWord starts the count at 1; restore the frequency observed before trimming
            self.word2count[word] = count

In [160]:
# Strip accents from a unicode string to bring it closer to plain ASCII
def unicodeToAscii(s):
    """Return `s` with combining marks removed.

    The string is decomposed with NFD normalisation, then every character whose
    Unicode category is 'Mn' (non-spacing mark) is dropped. Characters that have
    no such decomposition are returned unchanged.
    """
    decomposed = unicodedata.normalize('NFD', s)
    base_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(base_chars)

In [161]:
#Lowercase, strip accents, and keep only letters plus the . ! ? punctuation marks
def normalizeString(s):
    """Return a cleaned, lower-cased, single-spaced version of `s`.

    After lower-casing, stripping and accent removal, three regex passes run in order:
      1. put a space in front of every '.', '!' or '?'  ('!' becomes ' !');
      2. replace each run of characters other than letters and . ! ? by one space;
      3. collapse each run of whitespace into one space.
    Leading and trailing whitespace is removed from the result.
    """
    cleaned = unicodeToAscii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),
        (r"[^a-zA-Z.!?]+", r" "),
        (r"\s+", r" "),
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

In [162]:
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')
# Read the file and split into lines
print('Reading nad processing file....Please Wait')
# Use a context manager so the file handle is closed as soon as the text has been read
with open(datafile, encoding='utf-8') as f:
    lines = f.read().strip().split('\n')
# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in pair.split('\t')] for pair in lines]
print('total pairs: ', len(pairs))
print('Done Reading!')
voc = Vocabulary('cornell movie-dialogs corpus')

Reading nad processing file....Please Wait
total pairs:  221282
Done Reading!


In [163]:
MAX_LENGTH = 10 # Maximum sentence length to consider (in words, exclusive bound)

def filterPair(p):
    """Return True when the first two sentences of pair `p` each have fewer than MAX_LENGTH words.

    Sentences are checked in order and the check stops at the first one that is too long.
    """
    return all(len(p[position].split()) < MAX_LENGTH for position in (0, 1))

def filterPairs(pairs):
    """Return a list of the pairs in `pairs` that satisfy filterPair."""
    return list(filter(filterPair, pairs))

In [164]:
# Discard rows with fewer than two fields, then apply the sentence-length filter
pairs = list(filter(lambda pair: len(pair) > 1, pairs))
print("There are {} pairs/conversations in the dataset".format(len(pairs)))
pairs = filterPairs(pairs)
print("After filtering, there are {} pairs/conversations".format(len(pairs)))

There are 221282 pairs/conversations in the dataset
After filtering, there are 64271 pairs/conversations


In [165]:
#Feed every sentence of every pair into the vocabulary
for sentence in itertools.chain.from_iterable(pairs):
    voc.addSentence(sentence)

print('Counted words: ', voc.num_words)

Counted words:  18054


In [166]:
MIN_COUNT = 3 # Minimum number of occurrences a word needs to stay in the vocabulary

def trimRareWords(voc, pairs, MIN_COUNT):
    """Trim rare words from `voc`, then drop every pair that uses a trimmed word.

    Args:
        voc: vocabulary object exposing `trim(min_count)` and a `word2index` mapping.
        pairs: sequence of pairs whose first two items are the input and output sentences.
        MIN_COUNT: threshold passed on to `voc.trim`.

    Returns:
        A new list holding only the pairs whose input and output sentences consist
        entirely of words still present in `voc.word2index`.
    """
    # Trim words used under the MIN_COUNT threshold
    voc.trim(MIN_COUNT)

    def _all_words_known(sentence):
        # Sentences are split on single spaces, matching how they were added to the vocabulary
        return all(word in voc.word2index for word in sentence.split(' '))

    # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
    keep_pairs = [pair for pair in pairs
                  if _all_words_known(pair[0]) and _all_words_known(pair[1])]

    # Guard the ratio so an empty dataset reports 0 instead of raising ZeroDivisionError
    kept_ratio = len(keep_pairs) / len(pairs) if pairs else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), kept_ratio))
    return keep_pairs

# Shrink the vocabulary and keep only the pairs it still fully covers
pairs = trimRareWords(voc=voc, pairs=pairs, MIN_COUNT=MIN_COUNT)

keep_words 7840 / 18051 = 0.4343
Trimmed from 64271 pairs to 53191, 0.8276049851410434 of total


In [None]:
def indexesFromSentence(voc, sentence):
    """Map `sentence` to a list of vocabulary indices terminated by EOS_token.

    The sentence is split on single spaces and each token is looked up in
    `voc.word2index`; a token missing from that mapping raises KeyError.
    """
    token_ids = []
    for word in sentence.split(' '):
        token_ids.append(voc.word2index[word])
    token_ids.append(EOS_token)
    return token_ids
