#Importing Libraries

In [None]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
import torch.utils.data
import math
import torch.nn.functional as F

#HyperParameters

In [None]:
batch_size = 64  # sequences per batch handed to the DataLoader
max_len = 16     # max words kept per question/reply; replies also get <start>/<end>
num_heads = 8    # NOTE(review): not referenced by the cells shown here (the model uses n_head)

#Mount Google Drive

In [None]:
# Mount Google Drive
from google.colab import drive as gdrive
gdrive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change directory and list files
import os
os.chdir("/content/drive/My Drive/AI/cornell movie-dialogs corpus/")
!ls  # List files in the current directory

chameleons.pdf		movie_characters_metadata.txt  pairs_encoded.json
checkpoint_9.pth	movie_conversations.txt        raw_script_urls.txt
checkpoint_final_9.pth	movie_lines.txt		       README.txt
checkpoint_final9.pth	movie_titles_metadata.txt      WORDMAP_corpus.json


In [None]:
corpus_movie_conv = '/content/drive/My Drive/AI/cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = '/content/drive/My Drive/AI/cornell movie-dialogs corpus/movie_lines.txt'

#Data Preparation

- In our approach, we establish a fixed length for our sequences, ensuring consistency in our data processing.
- As we handle data in batches, it's crucial to determine this maximum length beforehand.
- By doing so, we can efficiently store our data in matrices,
streamlining the input process for our neural network.
- To accommodate sentences shorter than the designated maximum length, we employ padding.
- In this notebook the maximum length is `max_len = 16` words (tokens) per question or reply. Longer utterances are truncated and shorter ones are padded with `<pad>`, which gives the data pipeline a standardized shape.


##Reading the Movie Conversation and Lines

##Understanding conversations

### Conversation Grouping
- The conversation data is structured such that consecutive lines form coherent conversations.
- Each group of lines represents a single conversation.

### Example
- For instance, the lines with IDs L194 to L197 constitute one conversation.
- Similarly, the lines with IDs L198 and L199 form another conversation.

This grouping approach facilitates the analysis and processing of conversations within the dataset, enabling efficient handling of sequential dialogues.


In [None]:
# import os

# file_path = 'path/to/your/file.txt'

# if os.path.exists(file_path):
#     print("File exists.")
# else:
#     print("File does not exist.")


In [None]:
with open(corpus_movie_conv, 'r') as c:
    conv = c.readlines()

In [None]:
# conv

##Understanding Lines

### Explanation of Line Content
- Each line in the dataset corresponds to a specific utterance within a conversation.
- The content of each line includes the actual saying, either a question or a reply, along with the associated character.

### Example Illustration
- The line with ID L1045 contains the utterance "They do not!" (spoken by BIANCA).
- The subsequent line provides the continuation of the conversation.
- For instance, if we examine the first conversation:
  - The initial line represents the question posed.
  - The following line serves as the reply to that question.
  - This pattern continues throughout the conversation.
- To access a specific question, one can refer to the line number corresponding to the start of that question.
- Similarly, the subsequent line contains the reply to the preceding question.

This organization of the dataset enables easy identification and extraction of both questions and replies within the conversations.


In [None]:
with open(corpus_movie_lines, 'r', encoding='ISO-8859-1') as l:
    lines = l.readlines()

# lines

##Data to dictionary

In [None]:
lines[0].split(" +++$+++ ")
# we need index and what was said

['L1045', 'u0', 'm0', 'BIANCA', 'They do not!\n']

In [None]:
# Map each line ID (first " +++$+++ " field, e.g. 'L1045') to the utterance text
# (last field). The text keeps its trailing newline; callers strip it later.
lines_dic = {
    fields[0]: fields[-1]
    for fields in (raw_line.split(" +++$+++ ") for raw_line in lines)
}

# lines_dic[0]

In [None]:
lines_dic["L197"]

"Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"

##Cleaning the conversation

In [None]:
# Characters stripped by remove_punc. The raw string keeps the backslash literal;
# the former non-raw '\<' was an invalid escape sequence (SyntaxWarning in 3.12+)
# that only happened to keep the backslash.
_PUNCTUATION = r'''!()-[]{};:"\<>/@#$%^&*_~'''
_PUNCT_TABLE = str.maketrans('', '', _PUNCTUATION)


def remove_punc(string):
    """
    Remove a fixed set of punctuation characters from a string and lowercase it.

    Removed characters: ! ( ) - [ ] { } ; : " < > / @ # $ % ^ & * _ ~ and the
    backslash. Periods, commas, question marks and apostrophes are NOT removed,
    so tokens such as "right." or "you're" survive unchanged. The vocabulary
    built from this output depends on exactly this character set.

    Parameters:
    string (str): The input text.

    Returns:
    str: The input with those characters deleted, converted to lowercase.
    """
    # str.translate deletes all listed characters in one C-level pass,
    # instead of growing a new string one character at a time with `+=`.
    return string.translate(_PUNCT_TABLE).lower()


This code iterates over conversations in a dataset, extracting conversation IDs, and then creating question-answer pairs based on these IDs. It removes punctuation and leading/trailing whitespace from the lines corresponding to each ID, splits the lines into words, and limits the length of each to a specified maximum length. Finally, it appends the question-answer pair to a list of pairs.

In [None]:
# this is string and we need to convert this to a python list.
conv[0].split(" +++$+++ ")[-1]

"['L194', 'L195', 'L196', 'L197']\n"

In [None]:
import ast
# literal_eval parses the list literal without executing arbitrary code.
ast.literal_eval(conv[0].split(" +++$+++ ")[-1])

['L194', 'L195', 'L196', 'L197']

In [None]:
import ast

# Build [question_words, reply_words] pairs from consecutive utterances
# of every conversation.
pairs = []

for conv_idx, con in enumerate(conv):
    try:
        # The last " +++$+++ " field is a Python-style list literal of line IDs,
        # e.g. "['L194', 'L195', 'L196', 'L197']". literal_eval parses it
        # without executing arbitrary code (unlike eval).
        ids = ast.literal_eval(con.split(" +++$+++ ")[-1])

        # Each utterance is the "question" for the utterance that follows it.
        for q_id, r_id in zip(ids, ids[1:]):
            # Strip surrounding whitespace, then drop punctuation and lowercase.
            first = remove_punc(lines_dic[q_id].strip())
            second = remove_punc(lines_dic[r_id].strip())

            # Tokenize on whitespace and keep at most max_len words per side.
            pairs.append([first.split()[:max_len], second.split()[:max_len]])
    except (KeyError, ValueError, SyntaxError, TypeError) as err:
        # Malformed conversation line or a line ID missing from lines_dic:
        # report it and continue. Pairs found before the error are kept.
        print("Error on i =", conv_idx, con, repr(err))


In [None]:
lines_dic["L194"].strip()

'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.'

In [None]:
question = lines_dic["L194"].strip()
reply = lines_dic["L195"].strip()

q_list = question.split()
r_list = reply.split()


In [None]:
# q_list

In [None]:
# now qa pair is a 2d list
qa_pair = [q_list, r_list]
# qa_pair

In [None]:
# pairs = pairs[:1000]

In [None]:
len(pairs)

221616

In [None]:
# confirming that all the pairs have one 2 list
for p in pairs:
  if len(p) != 2:
    print(len(p))

##Word-to-Index Dictionary for Word Embeddings

### Introduction
Now we'll focus on constructing a word-to-index dictionary, an essential step in utilizing word embeddings. Word embeddings represent each word in a vocabulary as a dense vector, typically obtained from a one-hot encoding followed by an embedding layer. This process allows for more efficient representation and processing of textual data.

### Process Overview
- **Mapping Words to Indices**: Each unique word in the dataset will be assigned a unique index. This index will serve as the basis for creating one-hot vectors.
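- **From Indices to Vectors**: PyTorch's `nn.Embedding` takes these indices directly and looks up the matching row of its weight matrix. This is equivalent to multiplying a one-hot vector by that matrix, but no one-hot vectors are ever built explicitly.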
- **Utilizing Embedding Layers**: The embedding layer, explored in detail later, maps each index to a dense word embedding. These embeddings are learned during training and can capture semantic relationships between words.

### Steps:
1. **Collecting Unique Words**: The first step involves gathering all the unique words present in the datasets.
2. **Calculating Word Frequencies**: We need to determine how often each word occurs in our dataset.
3. **Filtering Low-Frequency Words**: Infrequent words are removed. Here, any word seen 8 times or fewer (`min_word_freq = 8`) is dropped and later encoded as `<unk>`. This keeps the vocabulary small and reduces the size of the model's output layer.

By following these steps, we ensure that our word-to-index dictionary effectively represents the vocabulary of our dataset while maintaining efficiency in computational resources.


##Creating Word Frequency Dictionary using collections
This code iterates over each question-answer pair in the list of pairs and updates a Counter object called word_freq with the frequencies of words appearing in both the questions and answers. The update() method increments the counts for each word encountered in the pairs.


In [None]:
# Count how often every word appears across all questions and replies.
# Insertion order (question words first, then reply words, pair by pair)
# determines the order of the vocabulary built below.
word_freq = Counter()
for pair in pairs:
    question_words = pair[0]
    reply_words = pair[1]
    word_freq.update(question_words)
    word_freq.update(reply_words)

In [None]:
# word_freq

##Filtering Words by Frequency:
Words whose frequency does not exceed the threshold (`min_word_freq`), i.e. words occurring `min_word_freq` times or fewer, are filtered out of the word frequency dictionary (`word_freq`).

**Creating Word-to-Index Mapping:**
- Each remaining word is assigned a unique index in the `word_map` dictionary, starting from 1.
- The index is incremented for each word in the list of filtered words, creating a word-to-index mapping.

**Adding Special Tokens:**
- Special tokens are added to the `word_map` dictionary: `<unk>` (unknown), `<start>` (start-of-sequence) and `<end>` (end-of-sequence) receive the next free indices after the regular words, and `<pad>` (padding) is assigned index 0.
- These tokens are crucial for data preprocessing and model training, allowing for handling of out-of-vocabulary words, marking sequence boundaries, and managing variable-length sequences.

The resulting `word_map` dictionary provides a comprehensive mapping of words to indices, including special tokens, facilitating efficient data processing and model training.


In [None]:
# Frequency cutoff. The comparison below is strict (`>`), so a word must occur
# at least min_word_freq + 1 times to enter the vocabulary.
min_word_freq = 8

# Surviving words, in first-seen order.
words = [word for word, freq in word_freq.items() if freq > min_word_freq]

# Regular words get indices 1..len(words); index 0 is reserved for <pad>.
word_map = {}
for idx, word in enumerate(words, start=1):
    word_map[word] = idx

# <unk>, <start> and <end> take the next free indices after the vocabulary.
for special in ('<unk>', '<start>', '<end>'):
    word_map[special] = len(word_map) + 1
word_map['<pad>'] = 0  # Padding token with index 0

In [None]:
# word_map

In [None]:
print("Total words are {}.".format(len(word_map)))

Total words are 17512.


##Saving the WordMap

In [None]:
with open('WORDMAP_corpus.json', 'w') as j:
    json.dump(word_map, j)

## Encoding Words Using Word Mapping

After creating the `word_map`, the next step is to encode the words using this mapping. Since neural networks require numerical inputs rather than strings, we need to represent words as indices in the `word_map`.

### Function Definitions
Two functions will be created for encoding: one for questions and one for replies.

### Function: `encode_enc_inp` (questions)
- **Input Arguments:**
  - `words`: List of words in the question.
  - `word_map`: Mapping of words to indices (`word_map`).

- **Explanation:**
  - `encode_enc_inp` converts each word in the question into its index in `word_map` (unknown words become `<unk>`). It then pads the result with `<pad>` up to `max_len` entries.

### Function: `encode_dec_inp` (replies)
- **Input Arguments:**
  - `words`: List of words in the reply.
  - `word_map`: Mapping of words to indices (`word_map`).

- **Explanation:**
  - `encode_dec_inp` converts the reply's words the same way, wraps them in `<start>` and `<end>`, and pads with `<pad>`, giving `max_len + 2` entries.


In [None]:
def encode_enc_inp(words, word_map, length=None):
    """
    Encode a question (encoder input) as a fixed-length list of word indices.

    Parameters:
    words (list): Tokens of the question (callers truncate them to max_len).
    word_map (dict): Word-to-index mapping containing '<unk>' and '<pad>'.
    length (int, optional): Length to pad to. Defaults to the notebook-level
        `max_len` hyperparameter, which is the original behavior.

    Returns:
    list: One index per word (out-of-vocabulary words map to '<unk>'),
    followed by '<pad>' indices up to `length`. A sequence already longer
    than `length` is returned unpadded (it is not truncated here).
    """
    if length is None:
        length = max_len  # notebook-level hyperparameter

    unk = word_map['<unk>']
    enc_c = [word_map.get(word, unk) for word in words]

    # A list multiplied by a non-positive count is empty, so overlong inputs
    # simply receive no padding.
    enc_c += [word_map['<pad>']] * (length - len(words))
    return enc_c


In [None]:
def encode_dec_inp(words, word_map, length=None):
    """
    Encode a reply (decoder sequence) as a list of word indices.

    The result is '<start>', one index per word ('<unk>' for unknown words),
    '<end>', then '<pad>' indices. For len(words) <= length the total
    length is length + 2.

    Parameters:
    words (list): Tokens of the reply (callers truncate them to max_len).
    word_map (dict): Word-to-index mapping containing '<unk>', '<start>',
        '<end>' and '<pad>'.
    length (int, optional): Number of word slots to pad to. Defaults to the
        notebook-level `max_len` hyperparameter (the original behavior).

    Returns:
    list: Encoded reply as a sequence of indices.
    """
    if length is None:
        length = max_len  # notebook-level hyperparameter

    unk = word_map['<unk>']
    body = [word_map.get(word, unk) for word in words]
    padding = [word_map['<pad>']] * (length - len(words))
    return [word_map['<start>']] + body + [word_map['<end>']] + padding


In [None]:
# Encode every pair: the question becomes a padded encoder input, and the
# reply becomes <start> ... <end> plus padding.
pairs_encoded = [
    [encode_enc_inp(pair[0], word_map), encode_dec_inp(pair[1], word_map)]
    for pair in pairs
]


In [None]:
pairs_encoded[10]

[[17509, 100, 17509, 4, 101, 53, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [17510, 103, 104, 39, 105, 106, 24, 17509, 17511, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

##Saving Number coded WordMap

In [None]:
fname = "pairs_encoded.json"
with open(fname, 'w') as p:
    json.dump(pairs_encoded, p)

#Custom Dataset Class
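- If you are not sure how a custom `Dataset` works, refer to the PyTorch "Datasets & DataLoaders" tutorial (the video originally referenced here was not linked).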

In [None]:
class MovieDataset(Dataset):
    """
    PyTorch dataset over encoded question/reply pairs stored in a JSON file.

    Each item is a tuple ``(enc_inp, dec_inp, dec_out)`` of LongTensors:
    - ``enc_inp``: the encoded question.
    - ``dec_inp``: the encoded reply without its last index (decoder input).
    - ``dec_out``: the encoded reply without its first index (decoder target,
      i.e. ``dec_inp`` shifted left by one position).

    Args:
    -----
    path (str): JSON file holding a list of [question_indices, reply_indices]
        pairs. Defaults to 'pairs_encoded.json' in the current directory.

    Attributes:
    -----------
    pairs (list): List of encoded question-reply pairs.
    dataset_size (int): Total number of question-reply pairs in the dataset.
    """

    def __init__(self, path='pairs_encoded.json'):
        """Load the encoded pairs from ``path`` and record how many there are."""
        # `with` closes the file as soon as it is parsed (the previous
        # json.load(open(...)) left the handle open).
        with open(path) as fh:
            self.pairs = json.load(fh)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """
        Retrieve the encoded pair at index i.

        Returns:
        --------
        tuple: (enc_inp, dec_inp, dec_out) LongTensors.
        """
        enc_inp = torch.LongTensor(self.pairs[i][0])
        dec = torch.LongTensor(self.pairs[i][1])

        # Teacher forcing: the decoder reads tokens [0, n-1) and learns to
        # predict tokens [1, n).
        dec_inp = dec[:-1]
        dec_out = dec[1:]

        return enc_inp, dec_inp, dec_out

    def __len__(self):
        """Return the total number of question-reply pairs."""
        return self.dataset_size


In [None]:
train_data = MovieDataset()

In [None]:
q_r = train_data[10]
q_r

(tensor([17509,   100, 17509,     4,   101,    53,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0]),
 tensor([17510,   103,   104,    39,   105,   106,    24, 17509, 17511,     0,
             0,     0,     0,     0,     0,     0,     0]),
 tensor([  103,   104,    39,   105,   106,    24, 17509, 17511,     0,     0,
             0,     0,     0,     0,     0,     0,     0]))

In [None]:
rev_word_map = {v: k for k, v in word_map.items()}

In [None]:
def tensor_to_sentence(t, clean=False, rev_map=None):
  """Convert a 1-D tensor of word indices back into a space-separated string.

  Args:
    t: 1-D tensor of word indices (CPU or CUDA).
    clean: if True, '<pad>' tokens are dropped from the output entirely,
      so no stray spaces are left behind.
    rev_map: index-to-word mapping; defaults to the notebook-level
      `rev_word_map`.

  Returns:
    str: the decoded words joined by single spaces.
  """
  if rev_map is None:
    rev_map = rev_word_map
  # .tolist() works for tensors on any device; .numpy() raises on CUDA tensors.
  tokens = [rev_map[v] for v in t.detach().tolist()]
  if clean:
    tokens = [tok for tok in tokens if tok != "<pad>"]
  return " ".join(tokens)


In [None]:
q_words = tensor_to_sentence(q_r[0])
r_words = tensor_to_sentence(q_r[1])
q_words, r_words

('<unk> ma <unk> this is my head <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>',
 "<start> right. see? you're ready for the <unk> <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>")

#Custom Dataloader

In [None]:
train_loader = DL(train_data,
                  batch_size = batch_size,
                  shuffle=True,
                  pin_memory=True)

In [None]:
# Shape check on one batch. The question (enc_inp) holds max_len = 16 indices.
# The encoded reply holds max_len + 2 = 18 (with <start> and <end>); dec_inp drops
# its last index and dec_out its first, so both are 17 long.
# Sentences shorter than max_len words are filled with <pad>.

for i, (enc_inp, dec_inp, dec_out) in enumerate(train_loader):
  print(enc_inp.shape, dec_inp.shape, dec_out.shape)
  print(tensor_to_sentence(enc_inp[0]))
  print(tensor_to_sentence(dec_inp[0]))
  print(tensor_to_sentence(dec_out[0]))

  break



torch.Size([64, 16]) torch.Size([64, 17]) torch.Size([64, 17])
he <unk> of not feeling well. i thought he was drunk  he <unk> <pad> <pad>
<start> that <unk> his dying so quickly. in your <unk> have you never seen men who <unk>
that <unk> his dying so quickly. in your <unk> have you never seen men who <unk> <end>


#Setting the device

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

#Mask

The masks are created inside `Transformer.forward` (defined below) to support the attention mechanisms of the model:
- Padding masks for the *input question* (encoder input) and the *reply* (decoder input) mark `<pad>` positions so attention ignores them. The question's padding mask is also used for the encoder memory.
- A subsequent (causal) mask prevents each decoder position from attending to later positions.
- All masks are passed directly to `torch.nn.Transformer`.

##Example
- Sentence: `<start>Hello how are you <end>`
- reply_input: `<start>Hello how are you`
  - reply_input is input to our decoder
- reply_target: `Hello how are you<end>`
  - reply_target is the target to our decoder
- Remember we are doing supervised learning.

In [None]:
# # Batched scenario
# t = torch.triu(torch.ones((2, 4, 4)))
# t.transpose(1, 2)

#Embeddings class with Positional Detail

In [None]:
class TokenEmbedding(nn.Module):
    """Learned lookup table mapping word indices to dense vectors.

    Args:
        vocab_size: number of rows in the table (largest index + 1).
        embed_size: length of each embedding vector.
        pad_id: index of the padding token (nn.Embedding's padding_idx).
    """

    def __init__(self, vocab_size, embed_size, pad_id):
        super(TokenEmbedding, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size,
                                            embed_size,
                                            padding_idx=pad_id)
        self.init_weights()

    def init_weights(self):
        """Initialise weights uniformly in [-0.1, 0.1], keeping the padding row zero.

        nn.Embedding zero-fills the padding_idx row on construction, but a blanket
        uniform_() overwrites it. That row never receives gradient, so it would
        remain a fixed random vector. Re-zeroing it restores the usual zero
        padding embedding.
        """
        initrange = 0.1
        weight = self.token_embedding.weight
        with torch.no_grad():
            weight.uniform_(-initrange, initrange)
            pad_idx = self.token_embedding.padding_idx
            if pad_idx is not None:
                weight[pad_idx].zero_()

    def forward(self, x):
        """Return the embeddings of index tensor x, shape x.shape + (embed_size,)."""
        return self.token_embedding(x)

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal positional encodings.

    ref: https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/model/embedding/position.py

    Args:
        d_model: encoding dimension (odd values are supported).
        max_len: number of positions in the precomputed table.
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()

        # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed in log space as exp(-2i * ln(10000) / d_model)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Registered as a buffer: stored in state_dict and moved by .to(device),
        # but not a Parameter, so it is never trained. (This replaces the former
        # `pe.require_grad = False`, a misspelling that had no effect.)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the (1, x.size(1), d_model) slice of encodings for x's length."""
        return self.pe[:, :x.size(1)]

In [None]:
import math
class Embeddings(nn.Module):
    """Token embedding scaled by sqrt(embed_size) plus sinusoidal position encoding.

    Args:
        vocab: word-to-index mapping. Its length sets the vocabulary size and
            vocab["<pad>"] is used as the padding index.
        embed_size: embedding dimension.
        max_len: base sequence length. The positional table holds max_len + 2
            positions.
    """

    def __init__(self, vocab, embed_size, max_len):
        super(Embeddings, self).__init__()
        n_tokens = len(vocab)
        pad_index = vocab["<pad>"]
        self.token_embedding = TokenEmbedding(vocab_size=n_tokens,
                                              embed_size=embed_size,
                                              pad_id=pad_index)
        self.embed_size = embed_size
        self.pos_embedding = PositionalEmbedding(d_model=embed_size,
                                                 max_len=max_len + 2)

    def forward(self, x):
        """Embed an index tensor x whose sequence axis is dim 1."""
        scale = math.sqrt(self.embed_size)
        return self.token_embedding(x) * scale + self.pos_embedding(x)


In [None]:
class Transformer(nn.Module):
    def __init__(self, vocab,
                 d_model=512,
                 n_head=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_len=15) -> None:
        """Encoder-decoder seq2seq model built on torch.nn.Transformer.

        Args:
            vocab (dict): word-to-index mapping. len(vocab) is the vocabulary
                size and vocab["<pad>"] marks padding positions.
            d_model, n_head, num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout: passed to torch.nn.Transformer.
            max_len (int): the shared positional table covers max_len + 2 tokens.
        """
        super(Transformer, self).__init__()
        self.vocab = vocab

        # Encoder and decoder inputs share one embedding module.
        self.input_embedding = Embeddings(vocab, d_model, max_len)

        # NOTE: "transfomrer" is misspelled, but the attribute name is part of the
        # state_dict keys of saved checkpoints, so it must not be renamed.
        self.transfomrer = torch.nn.Transformer(d_model=d_model,
                                                nhead=n_head,
                                                num_encoder_layers=num_encoder_layers,
                                                num_decoder_layers=num_decoder_layers,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout,
                                                batch_first=True)

        # Projects each decoder output vector to one logit per vocabulary entry.
        self.proj_vocab_layer = nn.Linear(in_features=d_model,
                                          out_features=len(vocab))

        self.init_weights()

    def init_weights(self):
        """Zero the output bias and draw output weights uniformly from [-0.1, 0.1]."""
        initrange = 0.1
        self.proj_vocab_layer.bias.data.zero_()
        self.proj_vocab_layer.weight.data.uniform_(-initrange, initrange)

    def forward(self, enc_input: torch.Tensor, dec_input: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits of shape (batch, dec_len, len(vocab)).

        enc_input / dec_input are (batch, seq) index tensors (batch_first=True).
        """
        x_enc_embed = self.input_embedding(enc_input.long())
        x_dec_embed = self.input_embedding(dec_input.long())

        pad_id = self.vocab["<pad>"]

        # Boolean masks: True means "ignore this key position". They must stay
        # boolean. A float key_padding_mask is ADDED to the attention scores, so
        # the previous 0.0/1.0 float masks gave padding a +1 bias instead of
        # masking it out.
        src_key_padding_mask = enc_input == pad_id
        tgt_key_padding_mask = dec_input == pad_id

        # If a question is padding everywhere, every score in that row would be
        # -inf and attention would yield NaN. Keep its first position visible.
        all_pad = src_key_padding_mask.all(dim=1)
        src_key_padding_mask[:, 0] = src_key_padding_mask[:, 0] & ~all_pad
        memory_key_padding_mask = src_key_padding_mask

        # Causal mask: True strictly above the diagonal, so position j cannot attend
        # to later positions. It is built on the input's device, not a global one.
        tgt_len = dec_input.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=dec_input.device),
            diagonal=1)

        # transformer ref: https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer
        feature = self.transfomrer(src=x_enc_embed,
                                   tgt=x_dec_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask)

        return self.proj_vocab_layer(feature)





In [None]:
word_map["<pad>"], len(word_map)

(0, 17512)

In [None]:
model = Transformer(word_map, max_len=15).to(device)
for i, (enc_inp, dec_inp, dec_out) in enumerate(train_loader):
  print(enc_inp.shape, dec_inp.shape, dec_out.shape)
  enc_inp, dec_inp = enc_inp.to(device), dec_inp.to(device)
  out = model(enc_inp, dec_inp)
  print(out.shape, dec_out.shape)

  # For one sentence from the batch the model returns a (dec_len, vocab_size)
  # block of logits, one vocab_size-long row per decoder position, e.g.
  # hello - [vocab_size tensor with logit values]
  # how - [vocab_size tensor with logit values]
  # are - [vocab_size tensor with logit values]
  # you - [vocab_size tensor with logit values]
  # Taking the argmax of each row gives 17 predicted tokens; the loss compares
  # these logits against the 17 target tokens in dec_out.
  print(out[0].shape, dec_out[0].shape)
  break

torch.Size([64, 16]) torch.Size([64, 17]) torch.Size([64, 17])
torch.Size([64, 17, 17512]) torch.Size([64, 17])
torch.Size([17, 17512]) torch.Size([17])


#Creating the Model

#Optimizer Adam Warm Up

In [None]:
class AdamWarmup:
    """Drive an optimizer's learning rate with the inverse-square-root warm-up schedule.

        lr = model_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The LR grows linearly for `warmup_steps` steps, then decays as 1/sqrt(step).

    Args:
        model_size: model width used to scale the schedule.
        warmup_steps: number of linear warm-up steps.
        optimizer: torch optimizer. Every param group's 'lr' is overwritten on
            each step().
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Return the LR for `current_step`. Before the first step() this is 0.0."""
        step = self.current_step
        if step <= 0:
            # step ** -0.5 would raise ZeroDivisionError at step 0.
            return 0.0
        return self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))

    def zero_grad(self):
        """Clear gradients of the wrapped optimizer (convenience pass-through)."""
        self.optimizer.zero_grad()

    def step(self):
        """Advance the schedule by one step, apply the new LR, then step the optimizer."""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        self.optimizer.step()

# Loss with Label Smoothing

In [None]:
class LossWithLS(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    At each position the target puts `1 - smooth` on the true class and
    spreads `smooth` evenly over the other `size - 1` classes. The loss is
    KLDivLoss(reduction='batchmean') between log_softmax(x) and that target.

    Args:
        size: number of classes (vocabulary size); must equal x.size(-1).
        smooth: label-smoothing mass.
        ignore_index: optional target index (e.g. the <pad> id) whose positions
            contribute nothing to the loss. The default None keeps every
            position, which matches the original behavior.
    """

    def __init__(self, size, smooth, ignore_index=None):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='batchmean')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size
        self.ignore_index = ignore_index
        self.true_dist = None  # kept for compatibility; not used

    def forward(self, x, target):
        """Compute the loss.

        Args:
            x: unnormalised logits with classes on the last dimension.
            target: integer class indices with shape x.shape[:-1].

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if x.size(-1) != size.
        """
        if x.size(-1) != self.size:
            raise ValueError(f"expected {self.size} classes, got {x.size(-1)}")

        # The target distribution is a constant; no gradient flows through it.
        with torch.no_grad():
            true_dist = torch.full_like(x, self.smooth / (self.size - 1))
            true_dist.scatter_(-1, target.unsqueeze(-1), self.confidence)
            if self.ignore_index is not None:
                # An all-zero target row contributes 0 to the KL sum.
                ignored = (target == self.ignore_index).unsqueeze(-1)
                true_dist.masked_fill_(ignored, 0.0)

        return self.criterion(F.log_softmax(x, dim=-1), true_dist)


# Example usage on random data. It only demonstrates shapes; the value is meaningless.
# The demo_* names keep this cell from overwriting notebook globals such as
# `batch_size`, which later training cells rely on.
demo_batch_size = 64
demo_max_words = 26
demo_vocab_size = 18243
demo_smooth = 0.1

# Random tensors for demonstration
demo_prediction = torch.randn(demo_batch_size, demo_max_words, demo_vocab_size)
demo_target = torch.randint(0, demo_vocab_size, (demo_batch_size, demo_max_words))

print(demo_prediction.shape, demo_target.shape)
# Initialize and compute loss
demo_loss_fn = LossWithLS(size=demo_vocab_size, smooth=demo_smooth)
demo_loss = demo_loss_fn(demo_prediction, demo_target)
print(f'Loss: {demo_loss.item()}')

torch.Size([64, 26, 18243]) torch.Size([64, 26])
Loss: 234.2530517578125


#Evaluation of the Model

In [None]:
def evaluate(model, enc_inp, max_len, word_map, rev_map=None):
    """Greedily decode a reply to one encoded question.

    Args:
        model: seq2seq model called as model(enc_inp, dec_inp), returning logits
            whose last axis indexes the vocabulary.
        enc_inp: encoded question of shape (1, src_len), already on the model's device.
        max_len: at most max_len - 1 tokens are generated.
        word_map: word-to-index mapping containing '<start>' and '<end>'.
        rev_map: index-to-word mapping. Defaults to the notebook-level
            `reverse_word_map`.

    Returns:
        str: generated words joined by spaces, with '<start>'/'<end>' removed.

    The model's previous train/eval mode is restored on return. The training
    loop calls this mid-epoch, and it used to leave the model in eval mode
    (dropout off) for the rest of that epoch.
    """
    if rev_map is None:
        rev_map = reverse_word_map
    start_symbol = word_map['<start>']
    end_symbol = word_map['<end>']

    was_training = model.training
    model.eval()
    try:
        # Decoder input starts as just <start>, on the same device as the question.
        dec_inp = torch.full((1, 1), start_symbol, dtype=torch.long, device=enc_inp.device)
        with torch.no_grad():  # inference only: no autograd graph is needed
            for _ in range(max_len - 1):
                output = model(enc_inp, dec_inp)
                # Greedy choice from the logits of the most recent position.
                next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
                dec_inp = torch.cat([dec_inp, next_token], dim=1)
                if next_token.item() == end_symbol:
                    break
    finally:
        model.train(was_training)

    tgt_tokens = dec_inp.squeeze(0).tolist()
    return ' '.join(rev_map[token] for token in tgt_tokens
                    if token not in (start_symbol, end_symbol))

# Index-to-word lookup used by evaluate() to turn generated indices into words.
# It is built the same way as rev_word_map above, so the contents are identical.
reverse_word_map = {v: k for k, v in word_map.items()}

In [None]:
questions = ["Hello how are you?", "I like Fruits", "Are you hungry?"]
def getResults(transformer, questions):
  """Print the model's greedy reply to each question in `questions`.

  Questions go through the same preprocessing as the training data:
  remove_punc (which lowercases and strips the same punctuation set), then
  truncation to max_len words. Without it, capitalised words such as "Hello"
  are not in the lowercase vocabulary and would all map to '<unk>'.
  """
  unk = word_map['<unk>']
  for q in questions:
    tokens = remove_punc(q.strip()).split()[:max_len]
    if not tokens:
      # An empty sequence cannot be fed to the encoder.
      print("\t", "<empty question>")
      continue
    enc_qus = [word_map.get(word, unk) for word in tokens]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    sentence = evaluate(transformer, question, max_len, word_map)
    print("\t", sentence)

In [None]:
transformer = Transformer(word_map, max_len=15).to(device)

questions = ["Hello how are you?", "I like Fruits", "Are you hungry?"]

getResults(transformer, questions)

	 lamp ben. ben. ben. ben. ben. ben. ben. ben. ben. ben. ben. ben. ben. ben.
	 interviewed interviewed interviewed interviewed interviewed interviewed interviewed interviewed interviewed interviewed interviewed interviewed momma momma momma
	 lamp ben. lamp ben. ben. ben. ben. ben. ben. ben. ben. ben. ben. ben. ben.


#Training the Model

In [None]:
# Smaller alternative configuration:
# d_model = 200
# n_head = 2
# num_encoder_layers = 2
# num_decoder_layers = num_encoder_layers
# dim_feedforward = 200
# dropout = 0.2

d_model = 512
n_head = 2
num_encoder_layers = 2
num_decoder_layers = num_encoder_layers
dim_feedforward = 512
dropout = 0.2

epochs = 10

transformer = Transformer(word_map,
                 d_model=d_model,
                 n_head=n_head,
                 num_encoder_layers=num_encoder_layers,
                 num_decoder_layers=num_decoder_layers,
                 dim_feedforward=dim_feedforward,
                 dropout=dropout,
                 max_len=15).to(device)

# lr=0.0 is a placeholder; AdamWarmup overwrites the LR before every step.
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0.00, betas=(0.9, 0.98), eps=1e-9)
# The warm-up schedule is scaled by the model width actually used.
transformer_optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam_optimizer)
criterion = LossWithLS(len(word_map), 0.1)

for epoch in range(epochs):

    transformer.train()
    sum_loss = 0.0
    count = 0

    for i, (enc_inp, dec_inp, dec_out) in enumerate(train_loader):

        # Number of sequences in this batch (the last batch may be smaller).
        # The previous shape[-1] was the sequence length, not the batch size.
        samples = enc_inp.size(0)

        # Move to device
        enc_inp = enc_inp.to(device)
        dec_inp, dec_out = dec_inp.to(device), dec_out.to(device)

        # Forward pass and loss
        out = transformer(enc_inp, dec_inp)
        loss = criterion(out, dec_out)

        # Backprop with gradient clipping, then a scheduled optimizer step
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(transformer.parameters(), 0.5)
        transformer_optimizer.step()

        # Running mean of the loss over the epoch, weighted by batch size
        sum_loss += loss.item() * samples
        count += samples

        if i % (batch_size * 5) == 0:
            print("Epoch [{}][{}/{}]\tLoss: {:.3f}\tLR: {:.3e}".format(
                epoch,
                i,
                len(train_loader),
                sum_loss / count,
                transformer_optimizer.lr))

            getResults(transformer, questions)
            # Sampling switches the model to eval mode; switch back so dropout
            # stays active for the rest of the epoch.
            transformer.train()

    # state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
    # torch.save(state, 'checkpoint_' + str(epoch) + '.pth.tar')

Epoch [0][0/3463]	Loss: 163.562	LR: 0.00000
	 park. horse fight running kentucky running kentucky alternative. kentucky gee, investigation? move, detachment amazing. particularly
	 anybody's oz. file? deputy trade anybody's exit. powers, red gee, boyfriend's cohaagen kentucky witch, stay,
	 professional. singer, crash. would... edward powwow crash. heard... fingers cigarette attached. deputy bullshit cigarette sleeping


KeyboardInterrupt: 

In [None]:
state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
torch.save(state, 'checkpoint_final' + str(epoch) + '.pth')

In [None]:
questions = ["Hello how are you?",
             "Are you hungry?",
             "How is life going?",
             "I am sad",
             "I kiss a girl"]

getResults(transformer, questions)

In [None]:
transformer_optimizer

In [None]:
state = {
    'epoch': epoch,
    'transformer_state_dict': transformer.state_dict(),
    'transformer_optimizer_state_dict': transformer_optimizer.optimizer.state_dict()
}
torch.save(state, 'checkpoint_final_' + str(epoch) + '.pth')

In [None]:
# Load the checkpoint
checkpoint = torch.load('checkpoint_final_9.pth')

In [None]:
# Architecture used to rebuild the model before loading checkpoint_final_9.pth.
# These values must match the configuration the checkpoint was trained with,
# otherwise load_state_dict (strict by default) rejects the mismatched tensor shapes.
# NOTE(review): they equal the commented-out configuration in the training cell,
# not the d_model=512 run executed there. Confirm which run produced the checkpoint.
d_model = 200
n_head = 2
num_encoder_layers = 2
num_decoder_layers = num_encoder_layers
dim_feedforward = 200
dropout = 0.2


epochs = 10  # not used by the reload cells below

transformer = Transformer(word_map,
                 d_model=d_model,
                 n_head=n_head,
                 num_encoder_layers=num_encoder_layers,
                 num_decoder_layers=num_decoder_layers,
                 dim_feedforward=dim_feedforward,
                 dropout=dropout,
                 max_len=15).to(device)

In [None]:
# Restore the model and optimizer state
transformer.load_state_dict(checkpoint['transformer_state_dict'])
transformer_optimizer.optimizer.load_state_dict(checkpoint['transformer_optimizer_state_dict'])

# Restore the last epoch
start_epoch = checkpoint['epoch']


In [None]:
questions = ["Do you eat Fruits?",
             "Lets go to France?",
             "I am just happy",
             "I kiss a girl"]

getResults(transformer, questions)

	 i don't know what you're talking about.
	 i don't want to see you again.
	 i don't want to hear it.
	 i don't know what to do with you, <unk>
