<a href="https://colab.research.google.com/github/endteamschoolofai/END/blob/main/Session13/Assignment_13_with_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.jit import script, trace

import torchtext
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time
import csv
import re
import os
import unicodedata
import codecs
import itertools
from io import open

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
mkdir data/

In [5]:
cd data/

/content/data


In [6]:
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip

--2021-02-25 18:14:46--  http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.36
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.36|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9916637 (9.5M) [application/zip]
Saving to: ‘cornell_movie_dialogs_corpus.zip’


2021-02-25 18:14:47 (11.0 MB/s) - ‘cornell_movie_dialogs_corpus.zip’ saved [9916637/9916637]



In [7]:
!unzip cornell_movie_dialogs_corpus.zip

Archive:  cornell_movie_dialogs_corpus.zip
   creating: cornell movie-dialogs corpus/
  inflating: cornell movie-dialogs corpus/.DS_Store  
   creating: __MACOSX/
   creating: __MACOSX/cornell movie-dialogs corpus/
  inflating: __MACOSX/cornell movie-dialogs corpus/._.DS_Store  
  inflating: cornell movie-dialogs corpus/chameleons.pdf  
  inflating: __MACOSX/cornell movie-dialogs corpus/._chameleons.pdf  
  inflating: cornell movie-dialogs corpus/movie_characters_metadata.txt  
  inflating: cornell movie-dialogs corpus/movie_conversations.txt  
  inflating: cornell movie-dialogs corpus/movie_lines.txt  
  inflating: cornell movie-dialogs corpus/movie_titles_metadata.txt  
  inflating: cornell movie-dialogs corpus/raw_script_urls.txt  
  inflating: cornell movie-dialogs corpus/README.txt  
  inflating: __MACOSX/cornell movie-dialogs corpus/._README.txt  


In [8]:
cd ..

/content


In [9]:
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join("data", corpus_name)

def printLines(file, n=50):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

printLines(os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'
b'L868 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ The "real you".\n'
b'L867 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ What good stuff?\n'
b"L866 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ 

In [10]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            utterance_id_pattern = re.compile('L[0-9]+')
            lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [11]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)
print(conversations[:10])

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Processing corpus...

Loading conversations...
[{'character1ID': 'u0', 'character2ID': 'u2', 'movieID': 'm0', 'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n", 'lines': [{'lineID': 'L194', 'characterID': 'u0', 'movieID': 'm0', 'character': 'BIANCA', 'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'}, {'lineID': 'L195', 'characterID': 'u2', 'movieID': 'm0', 'character': 'CAMERON', 'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"}, {'lineID': 'L196', 'characterID': 'u0', 'movieID': 'm0', 'character': 'BIANCA', 'text': 'Not the hacking and gagging and spitting part.  Please.\n'}, {'lineID': 'L197', 'characterID': 'u2', 'movieID': 'm0', 'character': 'CAMERON', 'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}, {'character1ID': 'u0', 'character2ID': 'u2', 'movieID': 'm0', 'utteranceIDs': "['L198', 'L199']\n", 'lines': 

In [12]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [13]:
MAX_LENGTH = 50  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 210810 sentence pairs
Counting words...
Counted words: 45895

pairs:
['can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .', 'well i thought we d start with pronunciation if that s okay with you .']
['well i thought we d start with pronunciation if that s okay with you .', 'not the hacking and gagging and spitting part . please .']
['not the hacking and gagging and spitting part . please .', 'okay . . . then how bout we try out some french cuisine . saturday ? night ?']
['you re asking me out . that s so cute . what s your name again ?', 'forget it .']
['no no it s my fault we didn t have a proper introduction', 'cameron .']
['cameron .', 'the thing is cameron i m at the mercy of a particularly hideous breed of loser . my sister . i can t date until she does .']
['the thing is cameron i m at the mercy of a particularl

In [14]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 25809 / 45892 = 0.5624
Trimmed from 210810 pairs to 187173, 0.8879 of total


In [15]:
 def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[ 168,  224,  117,   42,   56],
        [1464,   34,   34,  547,    8],
        [  34,  117,  181,   12,    2],
        [ 473,   21,   21,   34,    0],
        [ 209, 6996,  321,    8,    0],
        [ 126,  777, 1236,    2,    0],
        [  25,  674,   22,    0,    0],
        [ 600,    8,    2,    0,    0],
        [  34,    2,    0,    0,    0],
        [  22,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([11,  9,  8,  6,  3])
target_variable: tensor([[  126,    34,    25,   303,   931],
        [   34,    51,   258,  1039,   241],
        [  600,    35,    65,    22,     2],
        [   53,   667,   211,     2,     0],
        [   34,  2214,   181,     0,     0],
        [  600,    22, 10418,     0,     0],
        [   53,     5,    22,     0,     0],
        [    8,  2214,    54,     0,     0],
        [    2,    22,    25,     0,     0],
        [    0,     2,   366,     0,     0],
        [    0,     0,    65,     0,  

In [16]:
nTotal = mask.sum()
nTotal.item()

45

In [17]:
def batchTrainData(voc = voc, pair_batch= pairs):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    input_token, output_token = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
        token = [words for words in pair[0].split()] 
        input_token.append(token)
        token = [words for words in pair[1].split()] 
        output_token.append(token)
    #print('input_batch', input_batch)
    #print('output_batch', output_batch)
    #print('input_token', input_token)
    #print('output_token', output_token)
    return input_batch, input_token, output_batch, output_token

batches = batchTrainData(voc, [random.choice(pairs) for _ in range(8)])
input_batch, input_token, output_batch, output_token = batches
input_batch, output_batch

(['shit stop with the money ! i never asked for a nickel . i was just doing this . and you have to fuck it up with a price tag .',
  'i didn t have the knowledge of klingon anatomy i needed .',
  'no you re not . give me that dollar seventy .',
  'he s playing a stupid joke sir .',
  'now wait . . .',
  'i like that one .',
  'don t they though ?',
  'you making a feature ?'],
 ['i didn t mean to . it s just the way i am .',
  'you say you are due for retirement . may i ask do your hands shake ?',
  'i ll give you half of it . here s seventy cents .',
  'what ?',
  'with all she was doin . with all the shit she kept doing ! you stayed stuck to that bitch s ass and you wouldn t let go .',
  'no there . there s a good one . do you like that ?',
  'so work must be going well ?',
  'um hm . live sound .'])

In [18]:
class Encoder(nn.Module):
    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 50):
        super().__init__()

        self.device = device
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = src.shape[0]
        src_len = src.shape[1]
        
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        
        #pos = [batch size, src len]
        
        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        #src = [batch size, src len, hid dim]
            
        return src

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim,  
                 dropout, 
                 device):
        super().__init__()
        
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 
                                                                     pf_dim, 
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, 1, 1, src len] 
                
        #self attention
        _src, _ = self.self_attention(src, src, src, src_mask)
        
        #dropout, residual connection and layer norm
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        
        #src = [batch size, src len, hid dim]
        
        #positionwise feedforward
        _src = self.positionwise_feedforward(src)
        
        #dropout, residual and layer norm
        src = self.ff_layer_norm(src + self.dropout(_src))
        
        #src = [batch size, src len, hid dim]
        
        return src

In [20]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
    def forward(self, query, key, value, mask = None):
        
        batch_size = query.shape[0]
        
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
                
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]

        #print(Q.shape, K.shape, V.shape)
                
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]
                
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        #energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            #print(energy.shape, mask.shape)
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
                
        #attention = [batch size, n heads, query len, key len]
                
        x = torch.matmul(self.dropout(attention), V)
        
        #x = [batch size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        #x = [batch size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        
        #x = [batch size, query len, hid dim]
        
        x = self.fc_o(x)
        
        #x = [batch size, query len, hid dim]
        
        return x, attention

In [21]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        #x = [batch size, seq len, hid dim]
        
        x = self.dropout(torch.relu(self.fc_1(x)))
        
        #x = [batch size, seq len, pf dim]
        
        x = self.fc_2(x)
        
        #x = [batch size, seq len, hid dim]
        
        return x

In [22]:
class Decoder(nn.Module):
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 max_length = 50):
        super().__init__()
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
                
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
                            
        #pos = [batch size, trg len]
        #print()
        #print("before in the decoder layer", trg.shape)    
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
                
        #trg = [batch size, trg len, hid dim]
        
        for layer in self.layers:
            #print("in the decoder layer", trg.shape, enc_src.shape)
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        
        output = self.fc_out(trg)
        
        #output = [batch size, trg len, output dim]
            
        return output, attention

In [23]:
class DecoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device):
        super().__init__()
        
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 
                                                                     pf_dim, 
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        #self attention
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        
        #dropout, residual connection and layer norm
        #print("before", trg.shape)
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
        #print("after", trg.shape, enc_src.shape)
            
        #trg = [batch size, trg len, hid dim]
            
        #encoder attention
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        # query, key, value
        
        #dropout, residual connection and layer norm
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
                    
        #trg = [batch size, trg len, hid dim]
        
        #positionwise feedforward
        _trg = self.positionwise_feedforward(trg)
        
        #dropout, residual and layer norm
        trg = self.ff_layer_norm(trg + self.dropout(_trg))
        
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return trg, attention

In [24]:
class Seq2Seq(nn.Module):
    def __init__(self, 
                 encoder, 
                 decoder, 
                 device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        #self.src_pad_idx = src_pad_idx
        #self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        
        #src = [batch size, src len]
        
        src_mask = (src).unsqueeze(1).unsqueeze(2)

        #src_mask = [batch size, 1, 1, src len]

        return src_mask
    
    def make_trg_mask(self, trg):
        
        #trg = [batch size, trg len]
        trg_pad_mask = (trg).unsqueeze(1).unsqueeze(2)
        
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()

        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask
        
        #trg_mask = [batch size, 1, trg len, trg len]
        
        return trg_mask

    def forward(self, src, trg):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len]
                
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]
        
        enc_src = self.encoder(src, src_mask)
        
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        output = F.softmax(output, dim=0)
        return output, attention

In [25]:
INPUT_DIM = voc.num_words
OUTPUT_DIM = voc.num_words
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device)

In [26]:
model = Seq2Seq(enc, dec, device).to(device)

In [27]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 23,828,692 trainable parameters


In [28]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

In [29]:
model.apply(initialize_weights);

In [30]:
LEARNING_RATE = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE) #encoder decoder separate?

In [31]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [32]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    #print("loss", loss)
    return loss, nTotal.item()

In [33]:
def train(input_variable, lengths, target_variable, mask, max_target_len, model, optimizer, batch_size, clip):
    
    # Set device options
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths for rnn packing should always be on the cpu
    lengths = lengths.to("cpu")

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    optimizer.zero_grad()

    input_variable = input_variable.permute(1,0)
    target_variable = target_variable.permute(1,0)

    output, _ = model(input_variable, target_variable)
                
        #output = [batch size, trg len - 1, output dim]
        #trg = [batch size, trg len]                                                     
            
        #output_dim = output.shape[-1]
            
        #output = output.contiguous().view(-1, output_dim)
        #target_variable = target_variable[:,1:].contiguous().view(-1)
                
        #output = [batch size * trg len - 1, output dim]
        #trg = [batch size * trg len - 1]

    for t in range(max_target_len):
        target = target_variable.permute(1,0)
        output_ = output.permute(1,0,2)
        mask_loss, nTotal = maskNLLLoss(output_[t], target[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal  
        
    loss.backward()
        
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
    optimizer.step()

    return sum(print_losses) / n_totals

In [34]:
4000/730


5.47945205479452

In [35]:
pairs[0]

['not the hacking and gagging and spitting part . please .',
 'okay . . . then how bout we try out some french cuisine . saturday ? night ?']

In [36]:
def train_iters(model, iters, optimizer, batch_size, clip):
    
    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(iters)]

    model.train()

    loss = 0
    print_loss = 0

    for iteration in range(iters): 
        start_time = time.time()

        #training_batch = batch2TrainData(voc, [random.choice(pairs) for _ in range(64)])
        #input_variable, lengths, target_variable, mask, max_target_len = training_batch

        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, model, optimizer, batch_size, clip)
        print_loss += loss
        
        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Iter: {iteration+1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {loss:.3f} | Train PPL: {math.exp(loss):7.3f}')

In [37]:
iters = 4000
CLIP = 1
batch_size = 256
train_loss = train_iters(model, iters, optimizer, batch_size, CLIP)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Iter: 1501 | Time: 0m 0s
	Train Loss: 1.031 | Train PPL:   2.804
Iter: 1502 | Time: 0m 0s
	Train Loss: 1.035 | Train PPL:   2.814
Iter: 1503 | Time: 0m 0s
	Train Loss: 1.087 | Train PPL:   2.964
Iter: 1504 | Time: 0m 0s
	Train Loss: 0.976 | Train PPL:   2.655
Iter: 1505 | Time: 0m 0s
	Train Loss: 1.013 | Train PPL:   2.753
Iter: 1506 | Time: 0m 0s
	Train Loss: 0.992 | Train PPL:   2.696
Iter: 1507 | Time: 0m 0s
	Train Loss: 1.083 | Train PPL:   2.954
Iter: 1508 | Time: 0m 0s
	Train Loss: 1.009 | Train PPL:   2.743
Iter: 1509 | Time: 0m 0s
	Train Loss: 1.029 | Train PPL:   2.798
Iter: 1510 | Time: 0m 0s
	Train Loss: 1.029 | Train PPL:   2.799
Iter: 1511 | Time: 0m 0s
	Train Loss: 1.005 | Train PPL:   2.731
Iter: 1512 | Time: 0m 0s
	Train Loss: 1.025 | Train PPL:   2.786
Iter: 1513 | Time: 0m 0s
	Train Loss: 0.994 | Train PPL:   2.703
Iter: 1514 | Time: 0m 0s
	Train Loss: 0.967 | Train PPL:   2.629
Iter: 1515 | Time: 0m 0s


In [38]:
def evaluate(model, iters, batch_size):
    
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(iters)]

    model.eval()
    
    val_loss = 0
    print_loss = 0

    loss = 0
    print_losses = []
    n_totals = 0
    
    best_valid_loss = float('inf')

    with torch.no_grad():
    
        for iteration in range(iters):
            start_time = time.time()

            #training_batch = batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
            #input_variable, lengths, target_variable, mask, max_target_len = training_batch

            training_batch = training_batches[iteration - 1]
            # Extract fields from batch
            input_variable, lengths, target_variable, mask, max_target_len = training_batch
        
            # Set device options
            input_variable = input_variable.to(device)
            target_variable = target_variable.to(device)
            mask = mask.to(device)
            # Lengths for rnn packing should always be on the cpu
            lengths = lengths.to("cpu")

            input_variable = input_variable.permute(1,0)
            target_variable = target_variable.permute(1,0)

            output, _ = model(input_variable, target_variable)
            
            #output = [batch size, trg len - 1, output dim]
            #trg = [batch size, trg len]
            
            #output_dim = output.shape[-1]
            
            #output = output.contiguous().view(-1, output_dim)
            #target_variable = target_variable[:,1:].contiguous().view(-1)
            
            #output = [batch size * trg len - 1, output dim]
            #trg = [batch size * trg len - 1]
            
            for t in range(max_target_len):
                target = target_variable.permute(1,0)
                output_ = output.permute(1,0,2)
                mask_loss, nTotal = maskNLLLoss(output_[t], target[t], mask[t])
                loss += mask_loss
                print_losses.append(mask_loss.item() * nTotal)
                n_totals += nTotal   

            val_loss = sum(print_losses) / n_totals
            print_loss += val_loss

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if val_loss < best_valid_loss:
                best_valid_loss = val_loss
                torch.save(model.state_dict(), 'model.pt')

            print(f'Iter: {iteration+1:02} | Time: {epoch_mins}m {epoch_secs}s')
            print(f'\t Val. Loss: {val_loss:.3f} |  Val. PPL: {math.exp(val_loss):7.3f}')

In [39]:
valid_loss = evaluate(model, iters, batch_size) 

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Iter: 1501 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1502 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1503 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1504 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1505 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1506 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1507 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1508 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1509 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1510 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1511 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1512 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1513 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1514 | Time: 0m 0s
	 Val. Loss: 1.008 |  Val. PPL:   2.741
Iter: 1515 | Time: 0m 0s


In [83]:
class Inference(nn.Module):
    def __init__(self, model):
        super(Inference, self).__init__()
        self.model = model
    
    def forward(self, input_seq, max_length):
        
        input_seq = input_seq.permute(1,0)
        src_mask = model.make_src_mask(input_seq)

        with torch.no_grad():
            enc_src = model.encoder(input_seq, src_mask)

        #print(enc_src.shape, enc_src)
        #print("src_mask", src_mask.shape, src_mask)

        # Initialize tensors to append decoded words to
        trg_tensor = torch.ones(1, 1, 1, device=device, dtype=torch.long)
        #print("trg_tensor", trg_tensor.shape, trg_tensor)
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)

        for i in range(max_length):
            trg_mask = model.make_trg_mask(trg_tensor).squeeze(0)
            #print("trg_mask", trg_mask.shape, trg_mask)

            with torch.no_grad():
                output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
                output = output.squeeze(0)
                #print("output", output.shape, output)
                #output = F.softmax(output, dim=0)
                #print(output.shape, output)
                decoder_scores, decoder_input = torch.max(output, dim=2)
                #decoder_scores = decoder_scores.squeeze(1)
                #decoder_input = decoder_input.squeeze(1)
                #print("decoder_scores, decoder_input", decoder_scores.shape, decoder_input.shape)

                all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
                all_scores = torch.cat((all_scores, decoder_scores), dim=0)

                trg_tensor = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

In [85]:
def evaluate(model, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to('cpu')
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(model, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(model, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [86]:
# Initialize search module
searcher = Inference(model)

# Begin chatting (uncomment and run the following line to begin)
evaluateInput(model, searcher, voc)

> hey, coming for dinner?
Bot: officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer
> so fond of officer? 
Bot: officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer officer
> q
