In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as f
import csv 
import random
import re
import os
import unicodedata
import codecs
import itertools

In [2]:
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

# Data Preprocessing

In [3]:
lines_filepath = os.path.join("cornell movie-dialogs corpus","movie_lines.txt")
conv_filepath = os.path.join("cornell movie-dialogs corpus","movie_conversations.txt")

In [4]:
# Visualize some lines
with open(lines_filepath,'r',encoding="iso-8859-1") as file:
    lines = file.readlines()
    for line in lines[:8]:
        #print(line.strip())
        pass

In [5]:
# Splitting each line of the file into a dictioanry of fields (LineId,CharacterId,MovieId,Character,Text)
line_fields = ["lineId","characterId","movieId","character","text"]
lines = {}
with open(lines_filepath,'r',encoding="iso-8859-1") as f:
    for line in f:
        values = line.split(" +++$+++ ")
        # Extract fileds
        lineobj = {}
        for i,field in enumerate(line_fields):
            lineobj[field] = values[i]
            #print(lineobj)
        lines[lineobj['lineId']] = lineobj
        #print(lines)

In [6]:
# Groups fields of lines from 'Loadlines' into conversations based on "movie_conversations.txt"
conv_fields = ["character1Id","character2Id","movieId","utterenceIDs"]
conversations = []
with open(conv_filepath,'r',encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        # Extract Fields
        convObj = {}
        for i,field in enumerate(conv_fields):
            convObj[field] = values[i]
        # Convert string result from split to list, since convObj["utterenceIDs"] == "["L598765","L567890","...."]"
        lineIds = eval(convObj["utterenceIDs"])
        #print(lineIds)
        # Reassemble lines
        convObj["lines"] = []
        for lineId in lineIds:
            convObj["lines"].append(lines[lineId])
        conversations.append(convObj)

In [7]:
# Extract pairs of sentences from conversations
qa_pairs = []
for conversation in conversations:
    # Iterate over all the lines of conversation
    for i in range(len(conversation["lines"]) - 1):
        inputLine = conversation["lines"][i]["text"].strip()
        targetLine = conversation["lines"][i+1]["text"].strip()
        # Filter wrong samples if one of the lists is empty
        if inputLine and targetLine:
            qa_pairs.append([inputLine,targetLine])

In [8]:
# Define path to new file
datafile = os.path.join("cornell movie-dialogs corpus","formatted_movie_lines.txt")
delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter,"unicode_escape"))

# write new csv file
print("\n writing newly formatted file")
with open(datafile,"w",encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile,delimiter=delimiter)
    for pair in qa_pairs:
        #print(pair[0])
        writer.writerow(pair)
print("Done writing to the file ")


 writing newly formatted file
Done writing to the file 


In [9]:
# Visualize some Lines
datafile = os.path.join("cornell movie-dialogs corpus","formatted_movie_lines.txt")
with open (datafile,'r',encoding="iso-8859-1") as file:
    lines = file.readlines()
    for line in lines[:8]:
        print(line)

Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.	Well, I thought we'd start with pronunciation, if that's okay with you.

Well, I thought we'd start with pronunciation, if that's okay with you.	Not the hacking and gagging and spitting part.  Please.

Not the hacking and gagging and spitting part.  Please.	Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?

You're asking me out.  That's so cute. What's your name again?	Forget it.

No, no, it's my fault -- we didn't have a proper introduction ---	Cameron.

Cameron.	The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.

The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.	Seems like she could get a date easy enough...

Why?	Unsolved mystery.  She used to be really popular when she started h

In [10]:
PAD_TOKEN = 0
SOS_TOKEN = 0
EOS_TOKEN = 2

class Vocabulary:
    def __init__(self,name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_TOKEN:"PAD",SOS_TOKEN:"SOS",EOS_TOKEN:"EOS"}
        self.num_words = 3
        
    def addSentence(self,sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    def addWord(self,word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
            
    # Remove words below a certain count threshold
    def trim(self, min_count):
        keep_words = []
        for k,v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
                
        print('keep_words {}/{} = {:.4f}'.format(len(keep_words),len(self.word2index),len(keep_words)/len(self.word2index)))
    
        # Reinitialize the dictionaries 
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_TOKEN:"PAD",SOS_TOKEN:"SOS",EOS_TOKEN:"EOS"}
        self.num_words = 3 
        
        for word in keep_words:
              self.addWord(word)
              

In [11]:
# Turn a unicode string to plain ASCII
import unicodedata
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD',s) if unicodedata.category(c) != 'Mn')

In [12]:
''.join(['h','e','l'])

'hel'

In [13]:
unicodeToAscii("Adiós,Pequeño....")

'Adios,Pequeno....'

In [14]:
# Lowercase, trim white spaces, lines... etc, and remove non letters character.
import re
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # Replace any .!? by a whitespace  + the character --> '!' = ' !'. \1 means that first bracketed group -->[,!] .
    # r is to not consider \1 as a character (r to escape a backslash). + means one or more
    s = re.sub(r"([.!?])",r" \1",s)
    # Remove character that is not a sequence of lower or uppercase
    s = re.sub(r"[^a-zA-Z.!?]+",r" ",s)
    # Remove a sequence of whitespace character
    s = re.sub(r"\s+",r" ",s).strip()
    return s

In [15]:
normalizeString("aaa123!s's   dd?")

'aaa !s s dd ?'

In [16]:
import os
datafile = os.path.join("cornell movie-dialogs corpus","formatted_movie_lines.txt")
# Read the files and split into lines
print("Reading and Processing the file........... please wait")
lines = open(datafile,encoding="iso-8859-1").read().strip().split('\n')
# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in pair.split('\t')] for pair in lines]
print("Done Reading !!!")
voc = Vocabulary("cornell movie-dialogs corpus")

Reading and Processing the file........... please wait
Done Reading !!!


In [17]:
lines[0].split('\t')

['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you."]

In [18]:
# Returns true if both sentences in a pair 'p' are under the max length threshold
MAX_LENGTH = 10 # Maximum sentence length to consider
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split()) < MAX_LENGTH and len(p[1].split()) < MAX_LENGTH

# Filter pairs using filter pair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [19]:
pairs = [pair for pair in pairs if len(pair) > 1]
print("There are {} pairs/conversations in the dataset ".format(len(pairs)))
pairs = filterPairs(pairs)
print("After filtering there are {} pairs/conversations ".format(len(pairs)))

There are 221282 pairs/conversations in the dataset 
After filtering there are 64243 pairs/conversations 


In [20]:
# Loop through each pair and add the question and reply e=sentence to the vocabulary
for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
print("counted words :", voc.num_words)
for pair in pairs[:10]:
    print(pair)
    

counted words : 18109
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [21]:
MIN_COUNT = 3   # minimum word count threshold for trimming
def trimRareWords(voc,pairs,MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # filter out paired with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_ourput = True
        
        # Check your sentence 
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
                
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
                
        # only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_ourput:
            keep_pairs.append(pair)
            
    print("Trimmed from {} pairs to {},{:.4f} of total ".format(len(pairs),len(keep_pairs),len(keep_pairs)/len(pairs)))
    return keep_pairs

# Print voc and pairs
pairs = trimRareWords(voc,pairs,MIN_COUNT)

keep_words 7840/18106 = 0.4330
Trimmed from 64243 pairs to 57971,0.9024 of total 


# Preparing the data

In [24]:
def indexFromSentence(voc,sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_TOKEN]

In [26]:
pairs[1][0]

'you have my word . as a gentleman'

In [27]:
# Test the function
indexFromSentence(voc,pairs[1][0])

[7, 8, 9, 10, 4, 11, 12, 13, 2]

In [30]:
# Define some samples for output
inp = []
out = []
for pair in pairs[:10]:
    inp.append(pair[0])
    out.append(pair[1])
print(inp)
print(len(inp))
indexes = [indexFromSentence(voc,sentence) for sentence in inp]
indexes

['there .', 'you have my word . as a gentleman', 'hi .', 'have fun tonight ?', 'well no . . .', 'then that s all you had to say .', 'but', 'do you listen to this crap ?', 'what good stuff ?', 'the real you .']
10


[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 22, 6, 2],
 [33, 34, 4, 4, 4, 2],
 [35, 36, 37, 38, 7, 39, 40, 41, 4, 2],
 [42, 2],
 [47, 7, 48, 40, 45, 49, 6, 2],
 [50, 51, 52, 6, 2],
 [53, 54, 7, 4, 2]]

In [31]:
def zeroPadding(l,fillvalue=0):
    return list(itertools.zip_longest(*1,fillvalue=fillvalue))

In [32]:
leng = [leng(ind) for ind in indexes]
max(leng)

NameError: name 'leng' is not defined