<a href="https://colab.research.google.com/github/nikshrimali/ENDGAME_MERGER/blob/main/Assignment13/Transformers_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

from torchtext.data import Field, BucketIterator, LabelField, TabularDataset


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [2]:
# Getting the zip cornel movie dialogs file. We will use this dataset to train our chatbot
!wget 'http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip'

--2021-02-13 10:36:54--  http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.36
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.36|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9916637 (9.5M) [application/zip]
Saving to: ‘cornell_movie_dialogs_corpus.zip’


2021-02-13 10:36:54 (38.9 MB/s) - ‘cornell_movie_dialogs_corpus.zip’ saved [9916637/9916637]



In [3]:
# Unzip the dataset
!unzip '/content/cornell_movie_dialogs_corpus.zip'

Archive:  /content/cornell_movie_dialogs_corpus.zip
   creating: cornell movie-dialogs corpus/
  inflating: cornell movie-dialogs corpus/.DS_Store  
   creating: __MACOSX/
   creating: __MACOSX/cornell movie-dialogs corpus/
  inflating: __MACOSX/cornell movie-dialogs corpus/._.DS_Store  
  inflating: cornell movie-dialogs corpus/chameleons.pdf  
  inflating: __MACOSX/cornell movie-dialogs corpus/._chameleons.pdf  
  inflating: cornell movie-dialogs corpus/movie_characters_metadata.txt  
  inflating: cornell movie-dialogs corpus/movie_conversations.txt  
  inflating: cornell movie-dialogs corpus/movie_lines.txt  
  inflating: cornell movie-dialogs corpus/movie_titles_metadata.txt  
  inflating: cornell movie-dialogs corpus/raw_script_urls.txt  
  inflating: cornell movie-dialogs corpus/README.txt  
  inflating: __MACOSX/cornell movie-dialogs corpus/._README.txt  


In [4]:
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join("/content/", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

printLines(os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [5]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            utterance_id_pattern = re.compile('L[0-9]+')
            lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [6]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't dat

In [7]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    return pairs

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")

    return pairs


# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
print(corpus_name)
pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

cornell movie-dialogs corpus
Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [8]:
type(pairs)

list

In [9]:
import spacy
spacy_en = spacy.load('en')

def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

STAT = Field(tokenize= tokenize_en, 
            init_token='<sos>', 
            eos_token='<eos>', 
            lower=True,
            batch_first=True,
            include_lengths=True)

RESP = Field(tokenize = tokenize_en, 
            init_token='<sos>', 
            eos_token='<eos>', 
            lower=True,
            batch_first=True,
            include_lengths=True)

In [10]:
fields = {'Statement': ('s', STAT), 'Response': ('r', RESP)}

In [11]:
len(pairs)

64271

In [12]:
# Adding custom dataset to the Torch Dataset Library
import pandas as pd
raw_data = {'Statement' : [stat[0] for stat in pairs], 'Response': [resp[1] for resp in pairs]}
df = pd.DataFrame(raw_data, columns=["Statement", "Response"])

# Create vocab - min count = 3n

In [13]:
df.head(10)

Unnamed: 0,Statement,Response
0,there .,where ?
1,you have my word . as a gentleman,you re sweet .
2,hi .,looks like things worked out tonight huh ?
3,you know chastity ?,i believe we share an art instructor
4,have fun tonight ?,tons
5,well no . . .,then that s all you had to say .
6,then that s all you had to say .,but
7,but,you always been this selfish ?
8,do you listen to this crap ?,what crap ?
9,what good stuff ?,the real you .


In [14]:
# Dividing the data into train and validation dataset

train_df = df.sample(frac = 0.90) 
  
# Creating dataframe with rest of the 10% values 
valid_df = df.drop(train_df.index)

In [15]:
print(f'train df {train_df}')
print(f'Valid df {valid_df}')

train_df.to_csv('train.csv', index=False)
valid_df.to_csv('valid.csv', index=False)

train df                                                Statement                            Response
23669                                          yes sir .           a good man . good flyer .
41914              oh will you stop attacking hannah ? !                              oh now
41134                            sure but what is this ?     that s not your affair . name .
51168  i ll investigate mr . clarendon s financial po...                i don t understand .
56206                                what is it norman ?                     where are you ?
...                                                  ...                                 ...
41969                          yes you do recall right ?                      i recall you .
41326                             that s wonderful mom .             what s a revival tent ?
49794            well get a message through to him too .  brilliant . word perfect i d say .
28424                                     some of them .     

In [16]:
# Using tabular dataset to process the text

train_data, test_data = TabularDataset.splits(
                                path = '',   
                                train = './train.csv',
                                test = './valid.csv',
                                format = 'csv',
                                fields = fields)

In [17]:
 BATCH_SIZE = 24
 device = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
STAT.build_vocab(train_data, min_freq = 3, max_size= 10000)
RESP.build_vocab(test_data, min_freq = 3, max_size= 10000)

In [19]:
BATCH_SIZE = 24

train_iterator, test_iterator = BucketIterator.splits(
    (train_data, test_data), 
    batch_size = BATCH_SIZE,
    sort=False,
    device = device)

In [20]:
class Encoder(nn.Module):
    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_Layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):

        
        super().__init__()

        self.device = device

        # below we are breaking down the embedding into input and positional embedding
        self.tok_embd = nn.Embedding(num_embeddings=input_dim, embedding_dim=hid_dim)
        self.pos_embd = nn.Embedding(max_length, hid_dim)

        # We also add layers for multi-headed processing

        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  pf_dim,
                                                  n_heads,
                                                  dropout,
                                                  device)
                                    for _ in range(n_Layers)])

        # When we add two embeddings, we multiply our embeddings with a scale parameter, which helps us to maintain
        # our values in a certain range

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

        # Creating linear layer to reduce the dimension of the vector


        # We also add a dropout value for regularization

        self.dropout = nn.Dropout(dropout)

    
    def forward(self, input_src, src_mask):

        # input_src = [batch_size, src_len]
        batch_size = input_src.shape[0]
        src_len = input_src.shape[1]

        # Is src_len same in all the cases?

        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size,1).to(self.device)
        # print(f'Pos shape {pos.shape}, {input_src.shape}')
        #pos = [batch_size, src_len]

        input_embd = self.tok_embd(input_src)
        pos_embd = self.pos_embd(pos)

        # input_embd = pos_embd = [batch_size, src_len, embedding_dim]
        src = self.dropout(input_embd*self.scale + pos_embd)
        # src = [batch_size, src_len, hid_dim]

        # what does encoder returns?
        # 
        for layer in self.layers:
            src = layer(src, src_mask)

        return src


In [21]:
class EncoderLayer(nn.Module):
    def __init__(self,
                 hid_dim,
                 pf_dim,
                 n_heads,
                 dropout,
                 device):
        
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim) # Layer norm after attention layer
        self.self_ff_layer_norm = nn.LayerNorm(hid_dim) # Layer norm after feed forward layer
        self.self_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device) # Multi-head attention layer
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout) 
    
    def forward(self, src, src_mask):
        # what is src mask?????
        # src = [batch_size, src_len, hid_dim]
        # Why source has the hidden dim?
        # src_mask = [batch_size, 1, 1, src_len]
        
        _src, _ = self.self_attention(src, src, src, src_mask) # Self attention layer
        src = self.self_attn_layer_norm(self.dropout(_src) + src) # Add and Norm layer with residual connection

        # src = [batch_size, src_len, hid_dim]
        # Pointwise feedforward
        _src = self.positionwise_feedforward(src)
        src = self.self_ff_layer_norm(self.dropout(_src) + src)
        # src = [batch_size, src len, hid_dim]
        return src

In [22]:
class MultiHeadAttention(nn.Module):
    def __init__(self,
                 hid_dim,
                 n_heads,
                 dropout,
                 device):
        
        super().__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim//n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
    
    def forward(self, query, key, value, mask=None):

        batch_size = query.shape[0]
        # query = [batch size, query_len, hid_dim]
        # key = [batch size, key_len, hid_dim]
        # value = [batch size, value_len, hid_dim]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Q = [batch size, query_len, hid_dim]
        # K = [batch size, key_len, hid_dim]
        # V = [batch size, value_len, hid_dim]

        # How hid_dim is divided into n_heads, head_dim?
        # print('Attention batch, heads, head_dim', batch_size, -1, self.n_heads, self.head_dim)

        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Q=K=V = [batch_size, n_heads, query/key/value_len, head_dim]

        # Transposing the input embedding and then doing matrix multiplication

        # [batch_size, src_len, src_len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2))/ self.scale
        # print('Shape of energy', energy.shape)
        # energy = [batch_size, n_heads, query_len, key_length]
        # As you can notice the head dim is not there in the energy vector. only query_len and key_len are there

        # This energy is then multiplied with the value to calculate attention

        if mask is not None:
            # If mask values are close to zero, set it to very small values, we do this because?
            energy = energy.masked_fill(mask == 0, -1e10)

        # After matrix mul of query and key, and scaling we will apply softmax to get the output in a distribution of 0 to +1.
        # This value will act as an attention vector for us

        attention = torch.softmax(energy, dim=-1)

        # Attention is then further multiplied by the values to get the contextual embeddings

        x = torch.matmul(self.dropout(attention), V)

        #x = [batch_size, query_length, n_heads, hid_dim]

        x = x.permute(0, 2, 1, 3).contiguous()
        #x = [batch_size, n_heads, query_length, hid_dim]

        x = x.view(batch_size, -1, self.hid_dim)

        #x = [batch_size, query_len, hid_dim]

        x = self.fc_o(x)

        #x = [batch_size, query_len, hid_dim]

        return x, attention

In [23]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        self.fc1 = nn.Linear(hid_dim, pf_dim)
        self.fc2 = nn.Linear(pf_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = [batch_size, seq_len, hid_dim] 
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        # x = [batch_size, seq_len, hid_dim]

        return x

In [24]:
class Decoder(nn.Module):
    def __init__(self,
                 output_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length=100):

        super().__init__()
        self.device = device
        self.dropout = nn.Dropout(dropout)

        self.tok_embd = nn.Embedding(output_dim, hid_dim)
        self.pos_embd = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        self.fc_out = nn.Linear(hid_dim, output_dim)
    
    def forward(self, trg, enc_src, trg_mask, src_mask):

        # src = [batch_size, trg_len]

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]

        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size,1).to(device)

        # pos = [batch_size, trg_len]

        trg = self.dropout(self.tok_embd(trg)* self.scale + self.pos_embd(pos))

        for layer in self.layers:
            # Why src_mask and trg_mask
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        
        output = self.fc_out(trg)
        # output = [batch_size, trg_len, output_dim, ]
        return output, attention

In [25]:
class DecoderLayer(nn.Module):
    def __init__(self,
                 hid_dim,
                 n_heads,
                 pf_dim,
                 dropout,
                 device):
        super().__init__()

        self.self_attn_lyr_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_lyr_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        # trg = [batch_size, trg_len, hid_dim]
        # enc_src = [batch_size, src_len, hid_dim]
        # trg_mask = [batch_size, 1, trg_len, trg_len]
        # src_mask = [batch_size, 1, 1, src_len]

        # Self 
        # print('Target shape and mask', trg.shape, trg_mask.shape)
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)

        # Layer Norm - Dropout, Relu, residual connection
        trg = self.self_attn_lyr_norm(self.dropout(_trg) + trg)

        # query, key, value
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_lyr_norm(self.dropout(_trg) + trg)

        trg = self.positionwise_feedforward(trg)
        # trg = [batch_size, trg_len, hid_dim]
        # attention = [batch_size, n_heads, trg_len, src_len]

        trg = self.ff_layer_norm(trg + self.dropout(_trg))
        return trg, attention

In [26]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 encoder,
                 decoder,
                 src_pad_idx,
                 trg_pad_idx,
                 device):
        
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):

        # src = [batch_size, src_len]
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        # src_mask = [batch_size, 1, 1, src_len]

        return src_mask

    def make_trg_mask(self, trg):
        # trg = [batch_size, trg_len]
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # print(f'Make target mask {trg.shape}, {trg_pad_mask.shape}')
        # trg_mask = [batch_size, 1, 1, trg_len]
    
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()
        
        #trg_mask = [batch_size, 1, trg_len, trg_len]
        
        trg_mask = trg_pad_mask & trg_sub_mask # What is this & operator???

        # trg_mask = [batch_size, 1, trg_len, trg_len]
        # print(f'Target mask shape make_trg_mask {trg_mask.shape}')

        return trg_mask

    def forward(self, src, trg):
        # src = [batch_size, src_len]
        # trg = [batch_size, trg_len]
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)

        # enc_src = [batch_size, src_len, hid_dim, output_dim]
        
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)

        # output = [batch_size, trg_len, output_dim]
        # attention = [batch_size, n_heads, trg_len, src_len]

        return output, attention

In [27]:
INPUT_DIM = len(STAT.vocab)
OUTPUT_DIM = len(RESP.vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)


dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT,
              device)

In [28]:
SRC_PAD_IDX = STAT.vocab.stoi[STAT.pad_token]
TRG_PAD_IDX = RESP.vocab.stoi[RESP.pad_token]

model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

In [29]:
# Applied Xaviar uniform instead of 
def init_weights(m):
    for name, param in m.named_parameters():
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)
        else:
            nn.init.constant_(param.data, 0)
            
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (tok_embd): Embedding(4577, 256)
    (pos_embd): Embedding(100, 256)
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiHeadAttention(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout)

In [30]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

In [31]:
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

In [32]:
model.apply(initialize_weights);

In [33]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [34]:
LEARNING_RATE = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

In [35]:
# Model Training code

def train(model,
          iterator,
          optimizer,
          criterion,
          clip):
    
    model.train()
    epoch_loss = 0

    for i, batch in enumerate(iterator):

        src, src_len = batch.s
        trg, trg_len = batch.r

        optimizer.zero_grad()
        output, _ = model(src, trg[:,:-1])

        # output = [batch_size, trg_len-1, output_dim]

        output_dim = output.shape[-1]
            
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)
        loss = criterion(output, trg)

        # Perform backpropatation
        loss.backward()

        # Clip gradients: gradients are modified in place
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        # Adjust model weights
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [36]:
def evaluate(model, iterator, criterion):

    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        
        for i, batch in enumerate(iterator):

            src, src_len = batch.s
            trg, trg_len = batch.r

            output, _ = model(src, trg[:,:-1])

            #output = [batch_size, trg_len-1, outputdim]
            #trg = [batch_size, trg_len]

            output_dim = output.shape[-1]

            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)

            loss = criterion(output, trg)
            epoch_loss +=loss.item()

    return epoch_loss/len(iterator)


In [38]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
import time
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, test_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 1m 40s
	Train Loss: 4.202 | Train PPL:  66.818
	 Val. Loss: 4.331 |  Val. PPL:  76.056
Epoch: 02 | Time: 1m 39s
	Train Loss: 4.202 | Train PPL:  66.789
	 Val. Loss: 4.330 |  Val. PPL:  75.961
Epoch: 03 | Time: 1m 39s
	Train Loss: 4.202 | Train PPL:  66.800
	 Val. Loss: 4.330 |  Val. PPL:  75.913
Epoch: 04 | Time: 1m 39s
	Train Loss: 4.201 | Train PPL:  66.761
	 Val. Loss: 4.330 |  Val. PPL:  75.929
Epoch: 05 | Time: 1m 40s
	Train Loss: 4.201 | Train PPL:  66.721
	 Val. Loss: 4.330 |  Val. PPL:  75.962
