Adapted from the <a href="https://pytorch.org/tutorials/beginner/chatbot_tutorial.html">PyTorch Chatbot Tutorial</a> by Matthew Inkawhich.

In [1]:
import ast
import codecs
import csv
import itertools
import math
import os
import random
import re
import unicodedata
from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.jit import script, trace

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda') if USE_CUDA else torch.device('cpu')

# Load and Preprocess Data
### (Cornell Movie-Dialogs Corpus)

In [3]:
# Data root, given relative to the notebook's working directory.
DATA = '../../data'
# Name of the corpus folder inside DATA.
corpus_name = 'cornell movie-dialogs corpus'
# Full path of the corpus folder; the corpus files are looked up under it.
corpus = os.path.join(DATA, corpus_name)

In [4]:
def print_lines(file, n=10):
    '''Print the first `n` raw lines of `file`.

    The file is opened in binary mode, so every line is printed as a
    `bytes` literal with its line terminator visible.

    Args:
        file: Path of the file to preview.
        n: Number of lines to print (default 10). `None` prints every
            line; a negative value keeps list-slice semantics, i.e.
            everything except the last `-n` lines.
    '''
    with open(file, 'rb') as f:
        if n is not None and n < 0:
            # A negative count is relative to the end of the file, so the
            # whole file has to be read to know where to stop.
            head = f.readlines()[:n]
        else:
            # Stop reading after `n` lines instead of loading the whole
            # file into memory just to show its head.
            head = itertools.islice(f, n)
        for line in head:
            print(line)

In [5]:
# Preview the first raw lines of movie_lines.txt to see how its fields are laid out.
print_lines(os.path.join(corpus, 'movie_lines.txt'))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


### Create Formatted Data File

In [6]:
def load_lines(file_name, fields):
    '''Parse a ' +++$+++ '-separated file into a dict of records.

    Each line is split on the separator and its leading values are paired
    positionally with `fields`. Values are not stripped, so the last column
    of a line keeps its trailing newline.

    Args:
        file_name: Path of the file to read (decoded as ISO-8859-1).
        fields: Field names in column order; must include 'lineID'.

    Returns:
        Dict mapping every 'lineID' value to its {field: value} record.
        A later line with the same ID replaces an earlier one.

    Raises:
        IndexError: If a line has fewer columns than `fields`.
    '''
    separator = ' +++$+++ '
    records = {}
    with open(file_name, 'r', encoding='iso-8859-1') as f:
        for raw in f:
            columns = raw.split(separator)
            record = {name: columns[idx] for idx, name in enumerate(fields)}
            records[record['lineID']] = record
    return records

In [18]:
def load_conversations(file_name, lines, fields):
    '''
    Groups fields of lines from load_lines() into conversations based on
    movie_conversations.txt

    Args:
        file_name: Path of the ' +++$+++ '-separated conversations file
            (decoded as ISO-8859-1).
        lines: Dict mapping line IDs to line records, as returned by
            load_lines().
        fields: Field names in column order; must include 'utteranceIDs',
            whose value is the text of a Python list of line IDs.

    Returns:
        List with one dict per conversation: the parsed fields plus a
        'lines' entry holding the referenced line records in order.

    Raises:
        IndexError: If a row has fewer columns than `fields`.
        ValueError, SyntaxError: If 'utteranceIDs' is not a valid literal.
        KeyError: If a referenced line ID is missing from `lines`.
    '''
    conversations = []
    with open(file_name, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            conv_obj = {field: values[i] for i, field in enumerate(fields)}
            # literal_eval only accepts Python literals, so a crafted data
            # file cannot execute arbitrary code the way eval() would.
            line_ids = ast.literal_eval(conv_obj['utteranceIDs'].strip())
            conv_obj['lines'] = [lines[line_id] for line_id in line_ids]
            conversations.append(conv_obj)
    return conversations

In [8]:
def extract_sentence_pairs(conversations):
    '''Build [input, target] pairs from consecutive lines of each conversation.

    Args:
        conversations: Iterable of dicts whose 'lines' entry is a list of
            line records, each with a 'text' string.

    Returns:
        List of two-item lists [text of line i, text of line i + 1], with
        surrounding whitespace stripped. A pair is dropped when either
        text is empty after stripping.
    '''
    pairs = []
    for conv in conversations:
        utterances = conv['lines']
        # Walk the conversation as (line, next line) couples.
        for current, following in zip(utterances, utterances[1:]):
            prompt = current['text'].strip()
            reply = following['text'].strip()
            # Skip the pair when either sentence is an empty string.
            if not (prompt and reply):
                continue
            pairs.append([prompt, reply])
    return pairs

In [9]:
# Path of the formatted sentence-pair file inside the corpus folder.
datafile = os.path.join(corpus, 'formatted_movie_lines.txt')
# Resolve escape sequences in the delimiter (a no-op for a literal tab).
delimiter = str(codecs.decode('\t', 'unicode_escape'))

In [11]:
# Empty placeholders for the parsed corpus.
lines = dict()
conversations = list()

# Field names, in column order, for rows of movie_lines.txt.
MOVIE_LINES_FIELDS = [
    'lineID',
    'characterID',
    'movieID',
    'character',
    'text',
]

# Field names, in column order, for rows of movie_conversations.txt.
MOVIE_CONVERSATION_FIELDS = [
    'character1ID',
    'character2ID',
    'movieID',
    'utteranceIDs',
]

In [19]:
# Parse every utterance of the corpus into a record keyed by its line ID.
print('\nProcessing corpus...')
lines = load_lines(
    file_name=os.path.join(corpus, 'movie_lines.txt'),
    fields=MOVIE_LINES_FIELDS,
)

# Attach those records to the conversations that reference them.
print('\nLoading conversations...')
conversations = load_conversations(
    file_name=os.path.join(corpus, 'movie_conversations.txt'),
    lines=lines,
    fields=MOVIE_CONVERSATION_FIELDS,
)


Processing corpus...

Loading conversations...


In [20]:
print('\nWriting newly formatted file...')
# newline='' lets the csv module control line endings itself; without it,
# platforms that translate '\n' on write (e.g. Windows) would turn the
# writer's '\r\n' terminator into '\r\r\n'.
with open(datafile, 'w', encoding='utf-8', newline='') as out:
    writer = csv.writer(out, delimiter=delimiter)
    # One row per [input, target] sentence pair, joined by `delimiter`.
    writer.writerows(extract_sentence_pairs(conversations))


Writing newly formatted file...


In [22]:
# Print the first raw lines of the formatted file as a visual sanity check.
print('\nSample lines from file')
print_lines(datafile)


Sample lines from file
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\n"

# Load and Trim Data