In [1]:
import ast
import json
import math
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset

In [2]:
# Paths to the two raw corpus files, relative to the notebook's working directory.
corpus_movie_conv = "./data/Cornell Movie Dataset/movie_conversations.txt"
corpus_movie_lines = "./data/Cornell Movie Dataset/movie_lines.txt"
# Maximum number of words kept per utterance; the pair-building cell slices
# every tokenised utterance to this length.
max_sentence_len = 25

In [3]:
# Read the conversation index into a list of raw lines (newlines kept).
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default, and so it matches how movie_lines.txt is read.
with open(corpus_movie_conv, 'r', encoding="latin1") as conv_file:
    conv = conv_file.readlines()

In [4]:
# Load every utterance record as a raw text line (newlines kept), decoded as latin1.
with open(corpus_movie_lines, mode='r', encoding="latin1") as lines_file:
    lines = list(lines_file)

In [5]:
# Inspect the fields of the first record by splitting it on the ' +++$+++ ' delimiter.
lines[0].split(" +++$+++ ")

['L1045', 'u0', 'm0', 'BIANCA', 'They do not!\n']

In [6]:
# Index the utterances: the first delimited field of each record is the key and
# the last field is the value (kept exactly as read; no stripping happens here).
# If a key repeats, the later record wins.
split_records = (record.split(' +++$+++ ') for record in lines)
lines_dict = {fields[0]: fields[-1] for fields in split_records}

In [7]:
# Deletion table for remove_punc. The character set is written as a raw string
# so the backslash is a literal character to strip, not the start of the
# invalid "\," escape sequence (which newer Python versions warn about).
_PUNCT_TABLE = str.maketrans('', '', r'''!()-[]{};:\,{}./'@#$%^&*_~"''')


def remove_punc(string):
    """Return ``string`` lower-cased with the listed punctuation characters removed.

    Only the characters in the deletion table are stripped; anything else
    (for example ``?``, ``<``, ``>``, ``+``, ``=``) is left in place.
    Whitespace is not touched, so removing a standalone symbol can leave
    consecutive spaces behind.

    Args:
        string: The text to clean.

    Returns:
        The cleaned, lower-cased text. An empty input yields an empty string.
    """
    # A single translate() pass replaces per-character string concatenation.
    return string.translate(_PUNCT_TABLE).lower()

In [22]:
# Build [question_tokens, reply_tokens] pairs from every conversation: each
# utterance is paired with the utterance that immediately follows it.
pairs = []

for con in conv:
    # The last delimited field is parsed as a Python literal (presumably a list
    # of line-id strings -- confirm against the corpus file).
    # ast.literal_eval accepts literals only, so a corrupted or malicious line
    # cannot run arbitrary code the way eval() would.
    ids = ast.literal_eval(con.split(' +++$+++ ')[-1].strip())

    # zip(ids, ids[1:]) yields each consecutive (current, next) id pair and
    # yields nothing for conversations with fewer than two utterances.
    for first_id, second_id in zip(ids, ids[1:]):
        first = remove_punc(lines_dict[first_id].strip())
        second = remove_punc(lines_dict[second_id].strip())

        # Tokenise on whitespace and cap both sides at max_sentence_len words.
        pairs.append([
            first.split()[:max_sentence_len],
            second.split()[:max_sentence_len],
        ])

In [23]:
# Tally token occurrences over both sides of every pair. The question is
# counted before the reply so the Counter's insertion order is stable.
word_freq = Counter()

for pair in pairs:
    for side in (0, 1):
        word_freq.update(pair[side])


In [38]:
# Build the token -> integer id vocabulary.
# Note the strict comparison: a token is kept only if it occurs MORE than
# min_word_freq times.
min_word_freq = 5
words = [token for token, count in word_freq.items() if count > min_word_freq]

# Regular tokens get ids 1..len(words); id 0 is reserved for padding.
word_map = {token: index for index, token in enumerate(words, start=1)}

# Each special token takes the next free id after everything added so far.
for special in ('<unk>', '<start>', '<end>'):
    word_map[special] = len(word_map) + 1
word_map['<pad>'] = 0

In [39]:
# Report the number of entries in word_map (this count includes the special
# <unk>, <start>, <end> and <pad> entries added above).
print(f"Total words are {len(word_map)}")

Total words are 20079


In [40]:
# Write the vocabulary mapping to disk as JSON.
with open('./data/WORDMAP_corpus.json', mode='w') as wordmap_file:
    json.dump(obj=word_map, fp=wordmap_file)

In [41]:
def encode_question(words, word_map, max_len=None):
    """Encode a tokenised question as a fixed-length list of vocabulary ids.

    Args:
        words: List of word strings.
        word_map: Mapping from word to integer id. Must contain the
            ``'<unk>'`` and ``'<pad>'`` entries.
        max_len: Length of the returned list. When ``None`` (the default) the
            module-level ``max_sentence_len`` is used.

    Returns:
        A list of exactly ``max_len`` ids: words missing from ``word_map`` map
        to ``'<unk>'``, and the tail is filled with ``'<pad>'``.
    """
    if max_len is None:
        max_len = max_sentence_len

    unk_id = word_map['<unk>']
    # Truncate first so an over-long input cannot produce an over-long row;
    # every encoded question must have the same length to be batched.
    token_ids = [word_map.get(word, unk_id) for word in words[:max_len]]
    return token_ids + [word_map['<pad>']] * (max_len - len(token_ids))

In [42]:
def encode_reply(words, word_map, max_len=None):
    """Encode a tokenised reply as a fixed-length list of vocabulary ids.

    The reply is wrapped in ``'<start>'`` / ``'<end>'`` markers and then
    padded, so the result is two entries longer than an encoded question.

    Args:
        words: List of word strings.
        word_map: Mapping from word to integer id. Must contain the
            ``'<start>'``, ``'<end>'``, ``'<unk>'`` and ``'<pad>'`` entries.
        max_len: Maximum number of words kept. When ``None`` (the default) the
            module-level ``max_sentence_len`` is used.

    Returns:
        A list of exactly ``max_len + 2`` ids.
    """
    if max_len is None:
        max_len = max_sentence_len

    unk_id = word_map['<unk>']
    # Truncate first so an over-long input cannot produce an over-long row;
    # every encoded reply must have the same length to be batched.
    token_ids = [word_map.get(word, unk_id) for word in words[:max_len]]
    padding = [word_map['<pad>']] * (max_len - len(token_ids))
    return [word_map['<start>']] + token_ids + [word_map['<end>']] + padding

In [43]:
# Turn every tokenised pair into [question_ids, reply_ids] using the vocabulary.
pairs_encoded = [
    [encode_question(pair[0], word_map), encode_reply(pair[1], word_map)]
    for pair in pairs
]

In [68]:
# Write the encoded pairs to disk as JSON; the Dataset class below loads this file.
with open('./data/pairs_encoded.json', mode='w') as pairs_file:
    json.dump(obj=pairs_encoded, fp=pairs_file)

In [69]:
class Dataset(torch.utils.data.Dataset):
    """Dataset over encoded (question, reply) pairs stored in a JSON file.

    The base class is referenced by its fully qualified name so that this
    class does not inherit from (and then shadow) the imported ``Dataset``
    name; re-running the cell therefore always produces the same class.

    Args:
        pairs_path: Path to a JSON file holding a list of
            ``[question_ids, reply_ids]`` entries. Defaults to the file
            written by the preprocessing cells.
    """

    def __init__(self, pairs_path='./data/pairs_encoded.json'):
        # The context manager closes the file as soon as it has been parsed
        # instead of leaving the handle open.
        with open(pairs_path) as pairs_file:
            self.pairs = json.load(pairs_file)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """Return the ``i``-th pair as two ``LongTensor`` objects (question, reply)."""
        question_ids, reply_ids = self.pairs[i][0], self.pairs[i][1]
        return torch.LongTensor(question_ids), torch.LongTensor(reply_ids)

    def __len__(self):
        """Return the number of pairs that were loaded."""
        return self.dataset_size


In [70]:
# Batch the encoded pairs for training: 100 pairs per batch, shuffled.
train_dataset = Dataset()
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
    pin_memory=True,
)

In [75]:
def create_masks(question, reply_input, reply_target):
    """Build the boolean masks for a batch of questions and replies.

    Id 0 is treated as padding everywhere. Relies on a module-level ``device``
    that must be defined before this function is called (it is not defined in
    this cell).

    Args:
        question: Tensor of question ids (assumed (batch, seq) -- TODO confirm).
        reply_input: Tensor of reply ids fed to the decoder (same assumption).
        reply_target: Tensor of reply ids used as targets (same assumption).

    Returns:
        A tuple ``(question_mask, reply_input_mask, reply_target_mask)``:
        ``question_mask`` is ``question != 0`` moved to ``device`` with two
        singleton axes inserted at dim 1; ``reply_input_mask`` is the
        non-padding mask AND-ed with a lower-triangular mask, with a singleton
        axis inserted at dim 1; ``reply_target_mask`` is ``reply_target != 0``.
    """
    pad_id = 0

    def subsequent_mask(size):
        # Lower-triangular ones (diagonal included) with a leading singleton
        # axis so it broadcasts against the batch dimension.
        lower_triangle = torch.tril(torch.ones(size, size))
        return lower_triangle.type(dtype=torch.uint8).unsqueeze(0)

    question_mask = question.ne(pad_id).to(device).unsqueeze(1).unsqueeze(1)

    # Combine "not padding" with "not a later position" for the reply input.
    reply_not_pad = reply_input.ne(pad_id).unsqueeze(1)
    lower_mask = subsequent_mask(reply_input.size(-1)).type_as(reply_not_pad)
    reply_input_mask = (reply_not_pad & lower_mask).unsqueeze(1)

    reply_target_mask = reply_target.ne(pad_id)

    return question_mask, reply_input_mask, reply_target_mask


In [None]:
class Embeddings(nn.Module):
    def __init__(self, vocab_size, d_model, max_len = 50):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embeddings(vocab_size, d_model)