In [28]:
import numpy as np
import json
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from nltk import word_tokenize
import pandas as pd
from torch.utils.data import DataLoader,Dataset
import os
from tqdm import tqdm

In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = os.getcwd()  # base directory; model checkpoints are written beneath it

# Raw corpus: a single `paragraph` column, one text per row.
data = pd.read_csv('./dataset/data.csv')
print(data)

# Precomputed unigram counts, used to build the vocabulary.
with open("./word_freq.json", "r") as json_file:
    word_freq = json.load(json_file)

                                              paragraph
0     Between 1995 and 2010, a study was conducted r...
1     Poverty represents a worldwide crisis. It is t...
2     The left chart shows the population change hap...
3     Human beings are facing many challenges nowada...
4     Information about the thousands of visits from...
...                                                 ...
998   Efficient Learning of Continuous-Time Hidden M...
999   Expectation Particle Belief Propagation. We pr...
1000  Latent Bayesian melding for integrating indivi...
1001  Distributionally Robust Logistic Regression. T...
1002  Variational Dropout and the Local Reparameteri...

[1003 rows x 1 columns]


In [30]:
def preprocess_text(text):
    """Tokenise a paragraph for the next-word model.

    The text is lower-cased, each character in ``. , ! ? @ # : $ % ^ & * ( ) _ + = -``
    is replaced by a space, and the result is split with NLTK's ``word_tokenize``.

    Args:
        text: raw paragraph string.

    Returns:
        list of token strings.
    """
    lowered = text.lower()
    without_punct = re.sub(r'([.,!?@#:$%^&*()_+=-])', ' ', lowered)
    return word_tokenize(without_punct)

# One token list per paragraph, stored alongside the raw text.
paragraph_tokens = data['paragraph'].apply(preprocess_text)
data['tokens'] = paragraph_tokens

print(data['tokens'].head())

0    [between, 1995, and, 2010, a, study, was, cond...
1    [poverty, represents, a, worldwide, crisis, it...
2    [the, left, chart, shows, the, population, cha...
3    [human, beings, are, facing, many, challenges,...
4    [information, about, the, thousands, of, visit...
Name: tokens, dtype: object


In [31]:
sequence_length = 2  # how many preceding words form the model's context
inputs, targets = [], []
for tokens in data['tokens']:
    # A paragraph needs at least sequence_length + 1 tokens to yield one (context, target) pair.
    if len(tokens) <= sequence_length:
        continue
    for end in range(sequence_length, len(tokens)):
        inputs.append(tokens[end - sequence_length:end])
        targets.append(tokens[end])

print(f'Number of sequences: {len(inputs)}')
print(f'First input sequence: {inputs[0]}')
print(f'First target: {targets[0]}')

Number of sequences: 211598
First input sequence: ['between', '1995']
First target: and


In [32]:
# Vocabulary: special tokens first, then every word in word_freq seen at least min_freq times.
min_freq = 2
special_tokens = ['<PAD>', '<UNK>']  # '<PAD>' must stay at index 0 (the embedding uses padding_idx=0)
# Special tokens are filtered out of word_freq. A duplicate would be re-indexed by the dict
# below, moving '<PAD>'/'<UNK>' off their slots and leaving an index >= vocab_size.
vocab = special_tokens + [
    word for word, count in word_freq.items()
    if count >= min_freq and word not in special_tokens
]

word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

vocab_size = len(word_to_idx)
print(f'Vocabulary size: {vocab_size}')
print(f'First 10 words in vocabulary: {vocab[:10]}')

# Encode contexts and targets as ids; words outside the vocabulary map to <UNK>.
unk_idx = word_to_idx['<UNK>']
inputs_int = [[word_to_idx.get(word, unk_idx) for word in context] for context in inputs]
targets_int = [word_to_idx.get(word, unk_idx) for word in targets]


Vocabulary size: 7314
First 10 words in vocabulary: ['<PAD>', '<UNK>', 'between', '1995', 'and', '2010', 'a', 'study', 'was', 'conducted']


In [33]:
# Persist both lookup tables. JSON object keys are always strings, so idx_to_word's
# integer keys are written (and will be read back) as "0", "1", ...
for out_name, mapping in (("word_to_idx.json", word_to_idx),
                          ("idx_to_word.json", idx_to_word)):
    with open(out_name, "w") as json_file:
        json.dump(mapping, json_file)

In [34]:
# Every sequence is used for training; this notebook has no validation/test split.
# The dtype is pinned to int64: np.array of Python ints defaults to int32 on some platforms
# (e.g. Windows with NumPy < 2), and CrossEntropyLoss rejects int32 class targets.
train_inputs = np.array(inputs_int, dtype=np.int64)
train_targets = np.array(targets_int, dtype=np.int64)
print(f'Training samples: {train_inputs.shape[0]}')

Training samples: 211598


In [35]:
class NextWordDataset(Dataset):
    """Map-style dataset of (context token ids, next-token id) pairs.

    Args:
        X: array-like of token-id contexts, one row per sample.
        y: array-like of next-token ids, one per sample.

    Raises:
        ValueError: if X and y contain a different number of samples.
    """

    def __init__(self, X, y):
        # Force int64: embedding lookups need integer ids, and CrossEntropyLoss needs Long targets.
        # np.array of Python ints can be int32 on some platforms.
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.long)
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same number of samples, got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

batch_size = 128
train_dataset = NextWordDataset(train_inputs, train_targets)
# Reshuffled every epoch; no seed is set, so batch order differs between runs.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [37]:
# Unidirectional LSTM baseline: model, loss and optimizer.
from lstm_bilstm import LSTM_Model

embedding_dim = 256
hidden_size = 128
num_layers = 2
lstm_dropout = 0.3

lstm_model = LSTM_Model(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
    lstm_dropout=lstm_dropout,
    bidirectional=False,
)
lstm_model.to(device)

criterion = nn.CrossEntropyLoss()
# Only parameters that require gradients are handed to Adam.
lstm_trainable_params = [p for p in lstm_model.parameters() if p.requires_grad]
optimizer = optim.Adam(lstm_trainable_params, lr=0.001)
print(lstm_model)

LSTM_Model(
  (embeddings): Embedding(7314, 256, padding_idx=0)
  (lstm): LSTM(256, 128, num_layers=2, batch_first=True, dropout=0.3)
  (fc): Linear(in_features=128, out_features=7314, bias=True)
)


In [38]:
def train_model(model, train_loader, bidirectional=False, *, optimizer=None,
                criterion=None, num_epochs=100, patience=5, top_k=5,
                checkpoint_dir=None, device=None):
    """Train a next-word model with top-k accuracy tracking and early stopping.

    Each epoch makes one pass over ``train_loader``. It prints the mean batch loss and
    the top-``k`` accuracy, and saves ``best_model.pth`` whenever the mean training loss
    improves. Training stops after ``patience`` consecutive epochs without improvement.

    NOTE(review): early stopping and checkpointing use the *training* loss because the
    notebook has no validation split; a held-out loader would be a better signal.

    Args:
        model: torch module whose output is scored with ``criterion(outputs, targets)``
            and ``torch.topk(outputs, dim=1)``, i.e. presumably (batch, vocab) logits.
        train_loader: iterable of ``(batch_X, batch_y)`` tensor pairs.
        bidirectional: only selects the checkpoint sub-directory ('bilstm' vs 'lstm').
        optimizer: optimizer over ``model``'s trainable parameters. Defaults to the
            notebook-global ``optimizer`` (the original behaviour).
        criterion: loss function. Defaults to the notebook-global ``criterion``.
        num_epochs: maximum number of epochs.
        patience: epochs without improvement before stopping early.
        top_k: k for the top-k accuracy metric (clamped to the number of classes).
        checkpoint_dir: directory for ``best_model.pth``. Defaults to
            ``<file_path>/model_checkpoints/<lstm|bilstm>`` and is created if missing.
        device: device batches are moved to. Defaults to the device of the model's
            first parameter.

    Raises:
        ValueError: if ``optimizer`` does not manage every trainable parameter of
            ``model`` (e.g. a stale global optimizer built for another model), or if
            ``train_loader`` yields no batches.
    """
    if optimizer is None:
        # Backward compatibility: fall back to the optimizer defined in the notebook namespace.
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]
    if device is None:
        device = next(model.parameters()).device
    if checkpoint_dir is None:
        checkpoint_dir = os.path.join(file_path, 'model_checkpoints',
                                      'bilstm' if bidirectional else 'lstm')

    # The global optimizer is rebound whenever a model cell runs. Training a model with another
    # model's optimizer would silently never update its weights, so refuse that case.
    trainable = {id(p) for p in model.parameters() if p.requires_grad}
    managed = {id(p) for group in optimizer.param_groups for p in group['params']}
    if not trainable <= managed:
        raise ValueError("optimizer does not manage all trainable parameters of `model`; "
                         "it was probably created for a different model")

    os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create directories
    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')

    best_train_loss = float('inf')
    early_stop_counter = 0

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        correct_topk_predictions = 0
        total_predictions = 0

        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch [{epoch+1}/{num_epochs}]")

            for batch_X, batch_y in tepoch:
                targets = batch_y.to(device)
                outputs = model(batch_X.to(device))
                loss = criterion(outputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                epoch_loss += batch_loss
                num_batches += 1

                # Metric only: keep topk out of the autograd graph.
                with torch.no_grad():
                    k = min(top_k, outputs.size(1))
                    _, topk_predictions = torch.topk(outputs.detach(), k=k, dim=1)
                    correct_topk_predictions += (topk_predictions == targets.view(-1, 1)).sum().item()
                total_predictions += targets.size(0)

                tepoch.set_postfix(loss=batch_loss,
                                   topk_accuracy=correct_topk_predictions / total_predictions)

        if num_batches == 0:
            raise ValueError("train_loader produced no batches")

        avg_train_loss = epoch_loss / num_batches
        train_topk_accuracy = correct_topk_predictions / total_predictions
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Top-{top_k} Accuracy: {train_topk_accuracy:.4f}')

        if avg_train_loss < best_train_loss:
            best_train_loss = avg_train_loss
            early_stop_counter = 0
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss
            }, checkpoint_path)
        else:
            early_stop_counter += 1
            print(f'No improvement in training loss. Early stop counter: {early_stop_counter}/{patience}')

        if early_stop_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break
        


In [39]:
# Train the unidirectional LSTM (uses the `criterion`/`optimizer` created in the previous cell).
train_model(lstm_model, train_loader, bidirectional=False)

Epoch [1/100]: 100%|██████████| 1654/1654 [00:03<00:00, 459.62batch/s, loss=6.93, topk_accuracy=0.221]


Epoch [1/100], Training Loss: 6.4716, Top-5 Accuracy: 0.2210


Epoch [2/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.33batch/s, loss=5.74, topk_accuracy=0.282]


Epoch [2/100], Training Loss: 5.7979, Top-5 Accuracy: 0.2816


Epoch [3/100]: 100%|██████████| 1654/1654 [00:03<00:00, 481.99batch/s, loss=5.06, topk_accuracy=0.31] 


Epoch [3/100], Training Loss: 5.4730, Top-5 Accuracy: 0.3095


Epoch [4/100]: 100%|██████████| 1654/1654 [00:03<00:00, 478.57batch/s, loss=5.8, topk_accuracy=0.334] 


Epoch [4/100], Training Loss: 5.2204, Top-5 Accuracy: 0.3337


Epoch [5/100]: 100%|██████████| 1654/1654 [00:03<00:00, 498.40batch/s, loss=4.26, topk_accuracy=0.354]


Epoch [5/100], Training Loss: 5.0080, Top-5 Accuracy: 0.3536


Epoch [6/100]: 100%|██████████| 1654/1654 [00:03<00:00, 478.78batch/s, loss=4.5, topk_accuracy=0.37]  


Epoch [6/100], Training Loss: 4.8350, Top-5 Accuracy: 0.3698


Epoch [7/100]: 100%|██████████| 1654/1654 [00:03<00:00, 488.00batch/s, loss=3.12, topk_accuracy=0.385]


Epoch [7/100], Training Loss: 4.6833, Top-5 Accuracy: 0.3852


Epoch [8/100]: 100%|██████████| 1654/1654 [00:03<00:00, 491.40batch/s, loss=5.25, topk_accuracy=0.4]  


Epoch [8/100], Training Loss: 4.5577, Top-5 Accuracy: 0.3998


Epoch [9/100]: 100%|██████████| 1654/1654 [00:03<00:00, 495.42batch/s, loss=4.95, topk_accuracy=0.412]


Epoch [9/100], Training Loss: 4.4441, Top-5 Accuracy: 0.4125


Epoch [10/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.46batch/s, loss=3.69, topk_accuracy=0.424]


Epoch [10/100], Training Loss: 4.3378, Top-5 Accuracy: 0.4245


Epoch [11/100]: 100%|██████████| 1654/1654 [00:03<00:00, 492.72batch/s, loss=3.93, topk_accuracy=0.436]


Epoch [11/100], Training Loss: 4.2501, Top-5 Accuracy: 0.4362


Epoch [12/100]: 100%|██████████| 1654/1654 [00:03<00:00, 483.31batch/s, loss=3.18, topk_accuracy=0.446]


Epoch [12/100], Training Loss: 4.1636, Top-5 Accuracy: 0.4464


Epoch [13/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.77batch/s, loss=2.99, topk_accuracy=0.455]


Epoch [13/100], Training Loss: 4.0877, Top-5 Accuracy: 0.4552


Epoch [14/100]: 100%|██████████| 1654/1654 [00:03<00:00, 493.37batch/s, loss=3.85, topk_accuracy=0.466]


Epoch [14/100], Training Loss: 4.0180, Top-5 Accuracy: 0.4655


Epoch [15/100]: 100%|██████████| 1654/1654 [00:03<00:00, 492.99batch/s, loss=5.17, topk_accuracy=0.474]


Epoch [15/100], Training Loss: 3.9534, Top-5 Accuracy: 0.4743


Epoch [16/100]: 100%|██████████| 1654/1654 [00:03<00:00, 496.31batch/s, loss=4.09, topk_accuracy=0.483]


Epoch [16/100], Training Loss: 3.8939, Top-5 Accuracy: 0.4830


Epoch [17/100]: 100%|██████████| 1654/1654 [00:03<00:00, 476.69batch/s, loss=3.03, topk_accuracy=0.491]


Epoch [17/100], Training Loss: 3.8376, Top-5 Accuracy: 0.4908


Epoch [18/100]: 100%|██████████| 1654/1654 [00:03<00:00, 472.42batch/s, loss=3.09, topk_accuracy=0.498]


Epoch [18/100], Training Loss: 3.7871, Top-5 Accuracy: 0.4977


Epoch [19/100]: 100%|██████████| 1654/1654 [00:03<00:00, 490.36batch/s, loss=5.03, topk_accuracy=0.505]


Epoch [19/100], Training Loss: 3.7447, Top-5 Accuracy: 0.5046


Epoch [20/100]: 100%|██████████| 1654/1654 [00:03<00:00, 491.47batch/s, loss=3.21, topk_accuracy=0.51] 


Epoch [20/100], Training Loss: 3.6940, Top-5 Accuracy: 0.5101


Epoch [21/100]: 100%|██████████| 1654/1654 [00:03<00:00, 483.03batch/s, loss=3.27, topk_accuracy=0.517]


Epoch [21/100], Training Loss: 3.6523, Top-5 Accuracy: 0.5169


Epoch [22/100]: 100%|██████████| 1654/1654 [00:03<00:00, 478.96batch/s, loss=3.07, topk_accuracy=0.522]


Epoch [22/100], Training Loss: 3.6128, Top-5 Accuracy: 0.5219


Epoch [23/100]: 100%|██████████| 1654/1654 [00:03<00:00, 470.40batch/s, loss=2.73, topk_accuracy=0.527]


Epoch [23/100], Training Loss: 3.5789, Top-5 Accuracy: 0.5275


Epoch [24/100]: 100%|██████████| 1654/1654 [00:03<00:00, 435.08batch/s, loss=4.03, topk_accuracy=0.533]


Epoch [24/100], Training Loss: 3.5442, Top-5 Accuracy: 0.5327


Epoch [25/100]: 100%|██████████| 1654/1654 [00:03<00:00, 491.22batch/s, loss=2.9, topk_accuracy=0.538] 


Epoch [25/100], Training Loss: 3.5093, Top-5 Accuracy: 0.5377


Epoch [26/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.39batch/s, loss=2.44, topk_accuracy=0.543]


Epoch [26/100], Training Loss: 3.4786, Top-5 Accuracy: 0.5427


Epoch [27/100]: 100%|██████████| 1654/1654 [00:03<00:00, 466.13batch/s, loss=3.9, topk_accuracy=0.545] 


Epoch [27/100], Training Loss: 3.4545, Top-5 Accuracy: 0.5451


Epoch [28/100]: 100%|██████████| 1654/1654 [00:03<00:00, 484.26batch/s, loss=3.74, topk_accuracy=0.55] 


Epoch [28/100], Training Loss: 3.4228, Top-5 Accuracy: 0.5501


Epoch [29/100]: 100%|██████████| 1654/1654 [00:03<00:00, 447.91batch/s, loss=1.92, topk_accuracy=0.554]


Epoch [29/100], Training Loss: 3.3945, Top-5 Accuracy: 0.5543


Epoch [30/100]: 100%|██████████| 1654/1654 [00:03<00:00, 451.97batch/s, loss=3.16, topk_accuracy=0.557]


Epoch [30/100], Training Loss: 3.3679, Top-5 Accuracy: 0.5572


Epoch [31/100]: 100%|██████████| 1654/1654 [00:03<00:00, 478.61batch/s, loss=3.59, topk_accuracy=0.561]


Epoch [31/100], Training Loss: 3.3442, Top-5 Accuracy: 0.5614


Epoch [32/100]: 100%|██████████| 1654/1654 [00:03<00:00, 455.76batch/s, loss=2.1, topk_accuracy=0.566] 


Epoch [32/100], Training Loss: 3.3187, Top-5 Accuracy: 0.5657


Epoch [33/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.75batch/s, loss=3.25, topk_accuracy=0.567]


Epoch [33/100], Training Loss: 3.3022, Top-5 Accuracy: 0.5669


Epoch [34/100]: 100%|██████████| 1654/1654 [00:03<00:00, 481.19batch/s, loss=3.86, topk_accuracy=0.571]


Epoch [34/100], Training Loss: 3.2794, Top-5 Accuracy: 0.5707


Epoch [35/100]: 100%|██████████| 1654/1654 [00:03<00:00, 499.55batch/s, loss=2.68, topk_accuracy=0.573]


Epoch [35/100], Training Loss: 3.2621, Top-5 Accuracy: 0.5727


Epoch [36/100]: 100%|██████████| 1654/1654 [00:03<00:00, 465.49batch/s, loss=3.48, topk_accuracy=0.575]


Epoch [36/100], Training Loss: 3.2433, Top-5 Accuracy: 0.5754


Epoch [37/100]: 100%|██████████| 1654/1654 [00:03<00:00, 475.94batch/s, loss=2.58, topk_accuracy=0.579]


Epoch [37/100], Training Loss: 3.2231, Top-5 Accuracy: 0.5789


Epoch [38/100]: 100%|██████████| 1654/1654 [00:03<00:00, 497.75batch/s, loss=3.41, topk_accuracy=0.582]


Epoch [38/100], Training Loss: 3.2062, Top-5 Accuracy: 0.5824


Epoch [39/100]: 100%|██████████| 1654/1654 [00:03<00:00, 494.72batch/s, loss=3.41, topk_accuracy=0.585]


Epoch [39/100], Training Loss: 3.1849, Top-5 Accuracy: 0.5846


Epoch [40/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.09batch/s, loss=3.01, topk_accuracy=0.587]


Epoch [40/100], Training Loss: 3.1728, Top-5 Accuracy: 0.5865


Epoch [41/100]: 100%|██████████| 1654/1654 [00:03<00:00, 480.92batch/s, loss=1.91, topk_accuracy=0.589]


Epoch [41/100], Training Loss: 3.1580, Top-5 Accuracy: 0.5886


Epoch [42/100]: 100%|██████████| 1654/1654 [00:03<00:00, 453.12batch/s, loss=2.75, topk_accuracy=0.591]


Epoch [42/100], Training Loss: 3.1396, Top-5 Accuracy: 0.5914


Epoch [43/100]: 100%|██████████| 1654/1654 [00:03<00:00, 477.41batch/s, loss=4.12, topk_accuracy=0.593]


Epoch [43/100], Training Loss: 3.1241, Top-5 Accuracy: 0.5932


Epoch [44/100]: 100%|██████████| 1654/1654 [00:03<00:00, 477.28batch/s, loss=3.11, topk_accuracy=0.596]


Epoch [44/100], Training Loss: 3.1100, Top-5 Accuracy: 0.5960


Epoch [45/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.91batch/s, loss=3.5, topk_accuracy=0.598] 


Epoch [45/100], Training Loss: 3.0942, Top-5 Accuracy: 0.5981


Epoch [46/100]: 100%|██████████| 1654/1654 [00:03<00:00, 463.48batch/s, loss=2.54, topk_accuracy=0.6]  


Epoch [46/100], Training Loss: 3.0819, Top-5 Accuracy: 0.5997


Epoch [47/100]: 100%|██████████| 1654/1654 [00:03<00:00, 443.96batch/s, loss=3.35, topk_accuracy=0.602]


Epoch [47/100], Training Loss: 3.0685, Top-5 Accuracy: 0.6020


Epoch [48/100]: 100%|██████████| 1654/1654 [00:03<00:00, 484.13batch/s, loss=3.32, topk_accuracy=0.604]


Epoch [48/100], Training Loss: 3.0565, Top-5 Accuracy: 0.6039


Epoch [49/100]: 100%|██████████| 1654/1654 [00:03<00:00, 491.98batch/s, loss=2.88, topk_accuracy=0.605]


Epoch [49/100], Training Loss: 3.0471, Top-5 Accuracy: 0.6051


Epoch [50/100]: 100%|██████████| 1654/1654 [00:03<00:00, 479.01batch/s, loss=2.5, topk_accuracy=0.607] 


Epoch [50/100], Training Loss: 3.0350, Top-5 Accuracy: 0.6069


Epoch [51/100]: 100%|██████████| 1654/1654 [00:03<00:00, 495.45batch/s, loss=3.32, topk_accuracy=0.609]


Epoch [51/100], Training Loss: 3.0255, Top-5 Accuracy: 0.6092


Epoch [52/100]: 100%|██████████| 1654/1654 [00:03<00:00, 495.79batch/s, loss=4.46, topk_accuracy=0.611]


Epoch [52/100], Training Loss: 3.0121, Top-5 Accuracy: 0.6106


Epoch [53/100]: 100%|██████████| 1654/1654 [00:03<00:00, 491.16batch/s, loss=2.65, topk_accuracy=0.611]


Epoch [53/100], Training Loss: 3.0006, Top-5 Accuracy: 0.6109


Epoch [54/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.29batch/s, loss=3.62, topk_accuracy=0.614]


Epoch [54/100], Training Loss: 2.9914, Top-5 Accuracy: 0.6139


Epoch [55/100]: 100%|██████████| 1654/1654 [00:03<00:00, 493.36batch/s, loss=2.45, topk_accuracy=0.616]


Epoch [55/100], Training Loss: 2.9793, Top-5 Accuracy: 0.6160


Epoch [56/100]: 100%|██████████| 1654/1654 [00:03<00:00, 484.99batch/s, loss=4.09, topk_accuracy=0.616]


Epoch [56/100], Training Loss: 2.9701, Top-5 Accuracy: 0.6163


Epoch [57/100]: 100%|██████████| 1654/1654 [00:03<00:00, 461.36batch/s, loss=2.58, topk_accuracy=0.618]


Epoch [57/100], Training Loss: 2.9564, Top-5 Accuracy: 0.6176


Epoch [58/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.85batch/s, loss=2.76, topk_accuracy=0.62] 


Epoch [58/100], Training Loss: 2.9483, Top-5 Accuracy: 0.6197


Epoch [59/100]: 100%|██████████| 1654/1654 [00:03<00:00, 469.07batch/s, loss=3.56, topk_accuracy=0.62] 


Epoch [59/100], Training Loss: 2.9438, Top-5 Accuracy: 0.6196


Epoch [60/100]: 100%|██████████| 1654/1654 [00:03<00:00, 477.06batch/s, loss=3.22, topk_accuracy=0.623]


Epoch [60/100], Training Loss: 2.9304, Top-5 Accuracy: 0.6230


Epoch [61/100]: 100%|██████████| 1654/1654 [00:03<00:00, 480.02batch/s, loss=2.79, topk_accuracy=0.623]


Epoch [61/100], Training Loss: 2.9253, Top-5 Accuracy: 0.6230


Epoch [62/100]: 100%|██████████| 1654/1654 [00:03<00:00, 454.49batch/s, loss=2.78, topk_accuracy=0.625]


Epoch [62/100], Training Loss: 2.9180, Top-5 Accuracy: 0.6250


Epoch [63/100]: 100%|██████████| 1654/1654 [00:03<00:00, 472.77batch/s, loss=3.28, topk_accuracy=0.625]


Epoch [63/100], Training Loss: 2.9084, Top-5 Accuracy: 0.6255


Epoch [64/100]: 100%|██████████| 1654/1654 [00:03<00:00, 469.60batch/s, loss=3.54, topk_accuracy=0.626]


Epoch [64/100], Training Loss: 2.9005, Top-5 Accuracy: 0.6264


Epoch [65/100]: 100%|██████████| 1654/1654 [00:03<00:00, 465.20batch/s, loss=3.4, topk_accuracy=0.628] 


Epoch [65/100], Training Loss: 2.8912, Top-5 Accuracy: 0.6282


Epoch [66/100]: 100%|██████████| 1654/1654 [00:03<00:00, 470.32batch/s, loss=1.94, topk_accuracy=0.628]


Epoch [66/100], Training Loss: 2.8877, Top-5 Accuracy: 0.6276


Epoch [67/100]: 100%|██████████| 1654/1654 [00:03<00:00, 489.55batch/s, loss=3.21, topk_accuracy=0.63] 


Epoch [67/100], Training Loss: 2.8777, Top-5 Accuracy: 0.6298


Epoch [68/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.27batch/s, loss=3.2, topk_accuracy=0.632] 


Epoch [68/100], Training Loss: 2.8663, Top-5 Accuracy: 0.6317


Epoch [69/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.64batch/s, loss=3.22, topk_accuracy=0.632]


Epoch [69/100], Training Loss: 2.8632, Top-5 Accuracy: 0.6318


Epoch [70/100]: 100%|██████████| 1654/1654 [00:03<00:00, 485.95batch/s, loss=2.5, topk_accuracy=0.633] 


Epoch [70/100], Training Loss: 2.8594, Top-5 Accuracy: 0.6332


Epoch [71/100]: 100%|██████████| 1654/1654 [00:03<00:00, 466.38batch/s, loss=3.69, topk_accuracy=0.634]


Epoch [71/100], Training Loss: 2.8506, Top-5 Accuracy: 0.6335


Epoch [72/100]: 100%|██████████| 1654/1654 [00:03<00:00, 482.34batch/s, loss=3.33, topk_accuracy=0.634]


Epoch [72/100], Training Loss: 2.8430, Top-5 Accuracy: 0.6343


Epoch [73/100]: 100%|██████████| 1654/1654 [00:03<00:00, 494.90batch/s, loss=2.83, topk_accuracy=0.636]


Epoch [73/100], Training Loss: 2.8359, Top-5 Accuracy: 0.6363


Epoch [74/100]: 100%|██████████| 1654/1654 [00:03<00:00, 464.55batch/s, loss=3.36, topk_accuracy=0.636]


Epoch [74/100], Training Loss: 2.8332, Top-5 Accuracy: 0.6365


Epoch [75/100]: 100%|██████████| 1654/1654 [00:03<00:00, 476.05batch/s, loss=1.7, topk_accuracy=0.638] 


Epoch [75/100], Training Loss: 2.8247, Top-5 Accuracy: 0.6379


Epoch [76/100]: 100%|██████████| 1654/1654 [00:03<00:00, 484.52batch/s, loss=1.93, topk_accuracy=0.639]


Epoch [76/100], Training Loss: 2.8197, Top-5 Accuracy: 0.6392


Epoch [77/100]: 100%|██████████| 1654/1654 [00:03<00:00, 471.34batch/s, loss=3.23, topk_accuracy=0.639]


Epoch [77/100], Training Loss: 2.8118, Top-5 Accuracy: 0.6386


Epoch [78/100]: 100%|██████████| 1654/1654 [00:03<00:00, 459.67batch/s, loss=2.67, topk_accuracy=0.641]


Epoch [78/100], Training Loss: 2.8026, Top-5 Accuracy: 0.6410


Epoch [79/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.62batch/s, loss=2.94, topk_accuracy=0.641]


Epoch [79/100], Training Loss: 2.8008, Top-5 Accuracy: 0.6409


Epoch [80/100]: 100%|██████████| 1654/1654 [00:03<00:00, 500.21batch/s, loss=2.92, topk_accuracy=0.643]


Epoch [80/100], Training Loss: 2.7948, Top-5 Accuracy: 0.6427


Epoch [81/100]: 100%|██████████| 1654/1654 [00:03<00:00, 487.17batch/s, loss=2.56, topk_accuracy=0.643]


Epoch [81/100], Training Loss: 2.7909, Top-5 Accuracy: 0.6426


Epoch [82/100]: 100%|██████████| 1654/1654 [00:03<00:00, 473.03batch/s, loss=2.85, topk_accuracy=0.644]


Epoch [82/100], Training Loss: 2.7834, Top-5 Accuracy: 0.6438


Epoch [83/100]: 100%|██████████| 1654/1654 [00:03<00:00, 462.58batch/s, loss=2.71, topk_accuracy=0.645]


Epoch [83/100], Training Loss: 2.7758, Top-5 Accuracy: 0.6452


Epoch [84/100]: 100%|██████████| 1654/1654 [00:03<00:00, 479.85batch/s, loss=2.53, topk_accuracy=0.644]


Epoch [84/100], Training Loss: 2.7713, Top-5 Accuracy: 0.6443


Epoch [85/100]: 100%|██████████| 1654/1654 [00:03<00:00, 497.17batch/s, loss=2.62, topk_accuracy=0.646]


Epoch [85/100], Training Loss: 2.7695, Top-5 Accuracy: 0.6459


Epoch [86/100]: 100%|██████████| 1654/1654 [00:03<00:00, 464.56batch/s, loss=2.91, topk_accuracy=0.647]


Epoch [86/100], Training Loss: 2.7669, Top-5 Accuracy: 0.6469


Epoch [87/100]: 100%|██████████| 1654/1654 [00:03<00:00, 493.03batch/s, loss=2.59, topk_accuracy=0.648]


Epoch [87/100], Training Loss: 2.7579, Top-5 Accuracy: 0.6482


Epoch [88/100]: 100%|██████████| 1654/1654 [00:03<00:00, 486.45batch/s, loss=2.97, topk_accuracy=0.649]


Epoch [88/100], Training Loss: 2.7512, Top-5 Accuracy: 0.6490


Epoch [89/100]: 100%|██████████| 1654/1654 [00:03<00:00, 503.71batch/s, loss=3.39, topk_accuracy=0.649]


Epoch [89/100], Training Loss: 2.7470, Top-5 Accuracy: 0.6488


Epoch [90/100]: 100%|██████████| 1654/1654 [00:03<00:00, 488.01batch/s, loss=4.59, topk_accuracy=0.65] 


Epoch [90/100], Training Loss: 2.7443, Top-5 Accuracy: 0.6496


Epoch [91/100]: 100%|██████████| 1654/1654 [00:03<00:00, 490.65batch/s, loss=2.08, topk_accuracy=0.65] 


Epoch [91/100], Training Loss: 2.7431, Top-5 Accuracy: 0.6497


Epoch [92/100]: 100%|██████████| 1654/1654 [00:03<00:00, 475.83batch/s, loss=3.14, topk_accuracy=0.652]


Epoch [92/100], Training Loss: 2.7346, Top-5 Accuracy: 0.6518


Epoch [93/100]: 100%|██████████| 1654/1654 [00:03<00:00, 477.06batch/s, loss=2.34, topk_accuracy=0.652]


Epoch [93/100], Training Loss: 2.7291, Top-5 Accuracy: 0.6519


Epoch [94/100]: 100%|██████████| 1654/1654 [00:03<00:00, 475.90batch/s, loss=2.53, topk_accuracy=0.653]


Epoch [94/100], Training Loss: 2.7240, Top-5 Accuracy: 0.6530


Epoch [95/100]: 100%|██████████| 1654/1654 [00:03<00:00, 472.93batch/s, loss=2.79, topk_accuracy=0.653]


Epoch [95/100], Training Loss: 2.7231, Top-5 Accuracy: 0.6530


Epoch [96/100]: 100%|██████████| 1654/1654 [00:03<00:00, 469.91batch/s, loss=2.87, topk_accuracy=0.654]


Epoch [96/100], Training Loss: 2.7166, Top-5 Accuracy: 0.6540


Epoch [97/100]: 100%|██████████| 1654/1654 [00:03<00:00, 476.96batch/s, loss=2.28, topk_accuracy=0.653]


Epoch [97/100], Training Loss: 2.7147, Top-5 Accuracy: 0.6535


Epoch [98/100]: 100%|██████████| 1654/1654 [00:03<00:00, 472.07batch/s, loss=2.13, topk_accuracy=0.656]


Epoch [98/100], Training Loss: 2.7054, Top-5 Accuracy: 0.6557


Epoch [99/100]: 100%|██████████| 1654/1654 [00:03<00:00, 475.36batch/s, loss=2.64, topk_accuracy=0.655]


Epoch [99/100], Training Loss: 2.7041, Top-5 Accuracy: 0.6555


Epoch [100/100]: 100%|██████████| 1654/1654 [00:03<00:00, 492.80batch/s, loss=2.38, topk_accuracy=0.655]


Epoch [100/100], Training Loss: 2.7040, Top-5 Accuracy: 0.6546


In [40]:
# Bidirectional LSTM: same hyperparameters as the baseline, with bidirectional=True.
embedding_dim, hidden_size = 256, 128
num_layers, lstm_dropout = 2, 0.3

bilstm_config = dict(
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
    vocab_size=vocab_size,
    lstm_dropout=lstm_dropout,
    bidirectional=True,
)
bilstm_model = LSTM_Model(**bilstm_config)
bilstm_model.to(device)

criterion = nn.CrossEntropyLoss()
# Rebinds the global `optimizer` to this model's trainable parameters.
bilstm_trainable_params = [p for p in bilstm_model.parameters() if p.requires_grad]
optimizer = optim.Adam(bilstm_trainable_params, lr=0.001)
print(bilstm_model)

LSTM_Model(
  (embeddings): Embedding(7314, 256, padding_idx=0)
  (lstm): LSTM(256, 128, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (fc): Linear(in_features=256, out_features=7314, bias=True)
)


In [41]:
train_model(bilstm_model, train_loader, bidirectional=True)  # checkpoints go under model_checkpoints/bilstm

Epoch [1/100]: 100%|██████████| 1654/1654 [00:04<00:00, 383.44batch/s, loss=4.59, topk_accuracy=0.254]


Epoch [1/100], Training Loss: 6.1938, Top-5 Accuracy: 0.2536


Epoch [2/100]: 100%|██████████| 1654/1654 [00:04<00:00, 385.53batch/s, loss=6.79, topk_accuracy=0.319]


Epoch [2/100], Training Loss: 5.4453, Top-5 Accuracy: 0.3193


Epoch [3/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.66batch/s, loss=6.41, topk_accuracy=0.356]


Epoch [3/100], Training Loss: 5.0333, Top-5 Accuracy: 0.3563


Epoch [4/100]: 100%|██████████| 1654/1654 [00:04<00:00, 383.82batch/s, loss=4.07, topk_accuracy=0.387]


Epoch [4/100], Training Loss: 4.7098, Top-5 Accuracy: 0.3868


Epoch [5/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.72batch/s, loss=4.17, topk_accuracy=0.414]


Epoch [5/100], Training Loss: 4.4460, Top-5 Accuracy: 0.4138


Epoch [6/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.69batch/s, loss=4.13, topk_accuracy=0.438]


Epoch [6/100], Training Loss: 4.2299, Top-5 Accuracy: 0.4379


Epoch [7/100]: 100%|██████████| 1654/1654 [00:04<00:00, 384.96batch/s, loss=4.06, topk_accuracy=0.461]


Epoch [7/100], Training Loss: 4.0485, Top-5 Accuracy: 0.4607


Epoch [8/100]: 100%|██████████| 1654/1654 [00:04<00:00, 385.00batch/s, loss=3.12, topk_accuracy=0.48] 


Epoch [8/100], Training Loss: 3.8959, Top-5 Accuracy: 0.4796


Epoch [9/100]: 100%|██████████| 1654/1654 [00:04<00:00, 388.30batch/s, loss=4.09, topk_accuracy=0.498]


Epoch [9/100], Training Loss: 3.7615, Top-5 Accuracy: 0.4984


Epoch [10/100]: 100%|██████████| 1654/1654 [00:04<00:00, 399.51batch/s, loss=4.6, topk_accuracy=0.514] 


Epoch [10/100], Training Loss: 3.6468, Top-5 Accuracy: 0.5144


Epoch [11/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.71batch/s, loss=3.61, topk_accuracy=0.53] 


Epoch [11/100], Training Loss: 3.5425, Top-5 Accuracy: 0.5302


Epoch [12/100]: 100%|██████████| 1654/1654 [00:04<00:00, 383.14batch/s, loss=2.61, topk_accuracy=0.543]


Epoch [12/100], Training Loss: 3.4513, Top-5 Accuracy: 0.5434


Epoch [13/100]: 100%|██████████| 1654/1654 [00:04<00:00, 381.31batch/s, loss=3.24, topk_accuracy=0.556]


Epoch [13/100], Training Loss: 3.3683, Top-5 Accuracy: 0.5559


Epoch [14/100]: 100%|██████████| 1654/1654 [00:04<00:00, 381.72batch/s, loss=3.74, topk_accuracy=0.568]


Epoch [14/100], Training Loss: 3.2992, Top-5 Accuracy: 0.5678


Epoch [15/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.60batch/s, loss=3.3, topk_accuracy=0.576] 


Epoch [15/100], Training Loss: 3.2377, Top-5 Accuracy: 0.5765


Epoch [16/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.90batch/s, loss=3.9, topk_accuracy=0.586] 


Epoch [16/100], Training Loss: 3.1743, Top-5 Accuracy: 0.5857


Epoch [17/100]: 100%|██████████| 1654/1654 [00:04<00:00, 382.74batch/s, loss=2.62, topk_accuracy=0.594]


Epoch [17/100], Training Loss: 3.1208, Top-5 Accuracy: 0.5940


Epoch [18/100]: 100%|██████████| 1654/1654 [00:04<00:00, 388.44batch/s, loss=4.1, topk_accuracy=0.601] 


Epoch [18/100], Training Loss: 3.0733, Top-5 Accuracy: 0.6010


Epoch [19/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.20batch/s, loss=4.64, topk_accuracy=0.608]


Epoch [19/100], Training Loss: 3.0308, Top-5 Accuracy: 0.6083


Epoch [20/100]: 100%|██████████| 1654/1654 [00:04<00:00, 392.08batch/s, loss=3.34, topk_accuracy=0.615]


Epoch [20/100], Training Loss: 2.9831, Top-5 Accuracy: 0.6154


Epoch [21/100]: 100%|██████████| 1654/1654 [00:04<00:00, 394.39batch/s, loss=3.63, topk_accuracy=0.621]


Epoch [21/100], Training Loss: 2.9456, Top-5 Accuracy: 0.6209


Epoch [22/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.10batch/s, loss=2.85, topk_accuracy=0.626]


Epoch [22/100], Training Loss: 2.9138, Top-5 Accuracy: 0.6257


Epoch [23/100]: 100%|██████████| 1654/1654 [00:04<00:00, 389.55batch/s, loss=2.82, topk_accuracy=0.631]


Epoch [23/100], Training Loss: 2.8826, Top-5 Accuracy: 0.6310


Epoch [24/100]: 100%|██████████| 1654/1654 [00:04<00:00, 382.56batch/s, loss=3.49, topk_accuracy=0.637]


Epoch [24/100], Training Loss: 2.8484, Top-5 Accuracy: 0.6371


Epoch [25/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.13batch/s, loss=5, topk_accuracy=0.64]    


Epoch [25/100], Training Loss: 2.8228, Top-5 Accuracy: 0.6405


Epoch [26/100]: 100%|██████████| 1654/1654 [00:04<00:00, 391.32batch/s, loss=3.42, topk_accuracy=0.645]


Epoch [26/100], Training Loss: 2.7944, Top-5 Accuracy: 0.6448


Epoch [27/100]: 100%|██████████| 1654/1654 [00:04<00:00, 375.47batch/s, loss=3.34, topk_accuracy=0.649]


Epoch [27/100], Training Loss: 2.7700, Top-5 Accuracy: 0.6487


Epoch [28/100]: 100%|██████████| 1654/1654 [00:04<00:00, 400.18batch/s, loss=1.88, topk_accuracy=0.651]


Epoch [28/100], Training Loss: 2.7471, Top-5 Accuracy: 0.6511


Epoch [29/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.53batch/s, loss=3.29, topk_accuracy=0.655]


Epoch [29/100], Training Loss: 2.7244, Top-5 Accuracy: 0.6550


Epoch [30/100]: 100%|██████████| 1654/1654 [00:04<00:00, 396.91batch/s, loss=2.52, topk_accuracy=0.659]


Epoch [30/100], Training Loss: 2.7025, Top-5 Accuracy: 0.6589


Epoch [31/100]: 100%|██████████| 1654/1654 [00:04<00:00, 365.41batch/s, loss=2.39, topk_accuracy=0.661]


Epoch [31/100], Training Loss: 2.6829, Top-5 Accuracy: 0.6613


Epoch [32/100]: 100%|██████████| 1654/1654 [00:04<00:00, 384.36batch/s, loss=1.91, topk_accuracy=0.665]


Epoch [32/100], Training Loss: 2.6670, Top-5 Accuracy: 0.6646


Epoch [33/100]: 100%|██████████| 1654/1654 [00:04<00:00, 384.45batch/s, loss=3.79, topk_accuracy=0.668]


Epoch [33/100], Training Loss: 2.6462, Top-5 Accuracy: 0.6677


Epoch [34/100]: 100%|██████████| 1654/1654 [00:04<00:00, 388.88batch/s, loss=3.52, topk_accuracy=0.669]


Epoch [34/100], Training Loss: 2.6316, Top-5 Accuracy: 0.6695


Epoch [35/100]: 100%|██████████| 1654/1654 [00:04<00:00, 379.06batch/s, loss=1.62, topk_accuracy=0.673]


Epoch [35/100], Training Loss: 2.6144, Top-5 Accuracy: 0.6728


Epoch [36/100]: 100%|██████████| 1654/1654 [00:04<00:00, 383.29batch/s, loss=3.27, topk_accuracy=0.675]


Epoch [36/100], Training Loss: 2.5979, Top-5 Accuracy: 0.6749


Epoch [37/100]: 100%|██████████| 1654/1654 [00:04<00:00, 376.42batch/s, loss=2.86, topk_accuracy=0.676]


Epoch [37/100], Training Loss: 2.5860, Top-5 Accuracy: 0.6761


Epoch [38/100]: 100%|██████████| 1654/1654 [00:04<00:00, 392.93batch/s, loss=3.06, topk_accuracy=0.678]


Epoch [38/100], Training Loss: 2.5771, Top-5 Accuracy: 0.6780


Epoch [39/100]: 100%|██████████| 1654/1654 [00:04<00:00, 353.79batch/s, loss=3.65, topk_accuracy=0.68] 


Epoch [39/100], Training Loss: 2.5593, Top-5 Accuracy: 0.6799


Epoch [40/100]: 100%|██████████| 1654/1654 [00:04<00:00, 358.72batch/s, loss=3.12, topk_accuracy=0.682]


Epoch [40/100], Training Loss: 2.5469, Top-5 Accuracy: 0.6824


Epoch [41/100]: 100%|██████████| 1654/1654 [00:04<00:00, 379.51batch/s, loss=3.38, topk_accuracy=0.684]


Epoch [41/100], Training Loss: 2.5332, Top-5 Accuracy: 0.6842


Epoch [42/100]: 100%|██████████| 1654/1654 [00:04<00:00, 369.16batch/s, loss=1.91, topk_accuracy=0.687]


Epoch [42/100], Training Loss: 2.5230, Top-5 Accuracy: 0.6867


Epoch [43/100]: 100%|██████████| 1654/1654 [00:04<00:00, 375.06batch/s, loss=3.11, topk_accuracy=0.688]


Epoch [43/100], Training Loss: 2.5153, Top-5 Accuracy: 0.6878


Epoch [44/100]: 100%|██████████| 1654/1654 [00:04<00:00, 382.79batch/s, loss=1.71, topk_accuracy=0.69] 


Epoch [44/100], Training Loss: 2.4997, Top-5 Accuracy: 0.6900


Epoch [45/100]: 100%|██████████| 1654/1654 [00:04<00:00, 371.33batch/s, loss=2.1, topk_accuracy=0.692] 


Epoch [45/100], Training Loss: 2.4870, Top-5 Accuracy: 0.6917


Epoch [46/100]: 100%|██████████| 1654/1654 [00:04<00:00, 367.46batch/s, loss=2.95, topk_accuracy=0.693]


Epoch [46/100], Training Loss: 2.4844, Top-5 Accuracy: 0.6930


Epoch [47/100]: 100%|██████████| 1654/1654 [00:04<00:00, 375.36batch/s, loss=2.76, topk_accuracy=0.693]


Epoch [47/100], Training Loss: 2.4705, Top-5 Accuracy: 0.6933


Epoch [48/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.53batch/s, loss=1.8, topk_accuracy=0.696] 


Epoch [48/100], Training Loss: 2.4626, Top-5 Accuracy: 0.6958


Epoch [49/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.10batch/s, loss=1.68, topk_accuracy=0.696]


Epoch [49/100], Training Loss: 2.4568, Top-5 Accuracy: 0.6961


Epoch [50/100]: 100%|██████████| 1654/1654 [00:04<00:00, 389.83batch/s, loss=2, topk_accuracy=0.699]   


Epoch [50/100], Training Loss: 2.4383, Top-5 Accuracy: 0.6993


Epoch [51/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.36batch/s, loss=3.14, topk_accuracy=0.699]


Epoch [51/100], Training Loss: 2.4375, Top-5 Accuracy: 0.6995


Epoch [52/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.86batch/s, loss=2.45, topk_accuracy=0.701]


Epoch [52/100], Training Loss: 2.4272, Top-5 Accuracy: 0.7014


Epoch [53/100]: 100%|██████████| 1654/1654 [00:04<00:00, 386.64batch/s, loss=3.48, topk_accuracy=0.702]


Epoch [53/100], Training Loss: 2.4189, Top-5 Accuracy: 0.7022


Epoch [54/100]: 100%|██████████| 1654/1654 [00:04<00:00, 393.93batch/s, loss=2.66, topk_accuracy=0.703]


Epoch [54/100], Training Loss: 2.4120, Top-5 Accuracy: 0.7032


Epoch [55/100]: 100%|██████████| 1654/1654 [00:04<00:00, 392.60batch/s, loss=2.25, topk_accuracy=0.704]


Epoch [55/100], Training Loss: 2.4080, Top-5 Accuracy: 0.7044


Epoch [56/100]: 100%|██████████| 1654/1654 [00:04<00:00, 386.20batch/s, loss=3.04, topk_accuracy=0.704]


Epoch [56/100], Training Loss: 2.4024, Top-5 Accuracy: 0.7042


Epoch [57/100]: 100%|██████████| 1654/1654 [00:04<00:00, 389.28batch/s, loss=1.83, topk_accuracy=0.706]


Epoch [57/100], Training Loss: 2.3929, Top-5 Accuracy: 0.7060


Epoch [58/100]: 100%|██████████| 1654/1654 [00:04<00:00, 390.44batch/s, loss=2.81, topk_accuracy=0.708]


Epoch [58/100], Training Loss: 2.3869, Top-5 Accuracy: 0.7080


Epoch [59/100]: 100%|██████████| 1654/1654 [00:04<00:00, 391.63batch/s, loss=2.03, topk_accuracy=0.707]


Epoch [59/100], Training Loss: 2.3808, Top-5 Accuracy: 0.7073


Epoch [60/100]: 100%|██████████| 1654/1654 [00:04<00:00, 396.93batch/s, loss=2.99, topk_accuracy=0.71] 


Epoch [60/100], Training Loss: 2.3712, Top-5 Accuracy: 0.7100


Epoch [61/100]: 100%|██████████| 1654/1654 [00:04<00:00, 386.31batch/s, loss=2.81, topk_accuracy=0.71] 


Epoch [61/100], Training Loss: 2.3663, Top-5 Accuracy: 0.7100


Epoch [62/100]: 100%|██████████| 1654/1654 [00:04<00:00, 390.75batch/s, loss=2.29, topk_accuracy=0.711]


Epoch [62/100], Training Loss: 2.3607, Top-5 Accuracy: 0.7112


Epoch [63/100]: 100%|██████████| 1654/1654 [00:04<00:00, 390.42batch/s, loss=2.26, topk_accuracy=0.712]


Epoch [63/100], Training Loss: 2.3567, Top-5 Accuracy: 0.7115


Epoch [64/100]: 100%|██████████| 1654/1654 [00:04<00:00, 394.66batch/s, loss=2.7, topk_accuracy=0.711] 


Epoch [64/100], Training Loss: 2.3552, Top-5 Accuracy: 0.7114


Epoch [65/100]: 100%|██████████| 1654/1654 [00:04<00:00, 386.15batch/s, loss=2.15, topk_accuracy=0.713]


Epoch [65/100], Training Loss: 2.3486, Top-5 Accuracy: 0.7126


Epoch [66/100]: 100%|██████████| 1654/1654 [00:04<00:00, 402.00batch/s, loss=1.33, topk_accuracy=0.714]


Epoch [66/100], Training Loss: 2.3405, Top-5 Accuracy: 0.7137


Epoch [67/100]: 100%|██████████| 1654/1654 [00:04<00:00, 390.39batch/s, loss=2.09, topk_accuracy=0.715]


Epoch [67/100], Training Loss: 2.3317, Top-5 Accuracy: 0.7149


Epoch [68/100]: 100%|██████████| 1654/1654 [00:04<00:00, 401.60batch/s, loss=2.34, topk_accuracy=0.715]


Epoch [68/100], Training Loss: 2.3320, Top-5 Accuracy: 0.7153
No improvement in training loss. Early stop counter: 1/5


Epoch [69/100]: 100%|██████████| 1654/1654 [00:04<00:00, 394.34batch/s, loss=1.72, topk_accuracy=0.717]


Epoch [69/100], Training Loss: 2.3253, Top-5 Accuracy: 0.7171


Epoch [70/100]: 100%|██████████| 1654/1654 [00:04<00:00, 382.89batch/s, loss=1.98, topk_accuracy=0.717]


Epoch [70/100], Training Loss: 2.3212, Top-5 Accuracy: 0.7168


Epoch [71/100]: 100%|██████████| 1654/1654 [00:04<00:00, 385.70batch/s, loss=3.09, topk_accuracy=0.716]


Epoch [71/100], Training Loss: 2.3206, Top-5 Accuracy: 0.7157


Epoch [72/100]: 100%|██████████| 1654/1654 [00:04<00:00, 398.14batch/s, loss=1.47, topk_accuracy=0.718]


Epoch [72/100], Training Loss: 2.3085, Top-5 Accuracy: 0.7183


Epoch [73/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.25batch/s, loss=0.863, topk_accuracy=0.72]


Epoch [73/100], Training Loss: 2.3017, Top-5 Accuracy: 0.7200


Epoch [74/100]: 100%|██████████| 1654/1654 [00:04<00:00, 380.03batch/s, loss=3.16, topk_accuracy=0.719]


Epoch [74/100], Training Loss: 2.3024, Top-5 Accuracy: 0.7193
No improvement in training loss. Early stop counter: 1/5


Epoch [75/100]: 100%|██████████| 1654/1654 [00:04<00:00, 379.87batch/s, loss=2.2, topk_accuracy=0.72]  


Epoch [75/100], Training Loss: 2.2958, Top-5 Accuracy: 0.7201


Epoch [76/100]: 100%|██████████| 1654/1654 [00:04<00:00, 385.22batch/s, loss=3.03, topk_accuracy=0.72] 


Epoch [76/100], Training Loss: 2.2957, Top-5 Accuracy: 0.7202


Epoch [77/100]: 100%|██████████| 1654/1654 [00:04<00:00, 392.88batch/s, loss=2.35, topk_accuracy=0.72] 


Epoch [77/100], Training Loss: 2.2972, Top-5 Accuracy: 0.7201
No improvement in training loss. Early stop counter: 1/5


Epoch [78/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.68batch/s, loss=1.5, topk_accuracy=0.722] 


Epoch [78/100], Training Loss: 2.2879, Top-5 Accuracy: 0.7219


Epoch [79/100]: 100%|██████████| 1654/1654 [00:04<00:00, 374.96batch/s, loss=3.06, topk_accuracy=0.723]


Epoch [79/100], Training Loss: 2.2839, Top-5 Accuracy: 0.7227


Epoch [80/100]: 100%|██████████| 1654/1654 [00:04<00:00, 373.43batch/s, loss=2.18, topk_accuracy=0.722]


Epoch [80/100], Training Loss: 2.2785, Top-5 Accuracy: 0.7224


Epoch [81/100]: 100%|██████████| 1654/1654 [00:04<00:00, 376.73batch/s, loss=2.69, topk_accuracy=0.724]


Epoch [81/100], Training Loss: 2.2788, Top-5 Accuracy: 0.7238
No improvement in training loss. Early stop counter: 1/5


Epoch [82/100]: 100%|██████████| 1654/1654 [00:04<00:00, 369.02batch/s, loss=2.5, topk_accuracy=0.725] 


Epoch [82/100], Training Loss: 2.2675, Top-5 Accuracy: 0.7246


Epoch [83/100]: 100%|██████████| 1654/1654 [00:04<00:00, 381.73batch/s, loss=2.79, topk_accuracy=0.726]


Epoch [83/100], Training Loss: 2.2649, Top-5 Accuracy: 0.7257


Epoch [84/100]: 100%|██████████| 1654/1654 [00:04<00:00, 386.27batch/s, loss=2.07, topk_accuracy=0.724]


Epoch [84/100], Training Loss: 2.2653, Top-5 Accuracy: 0.7241
No improvement in training loss. Early stop counter: 1/5


Epoch [85/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.43batch/s, loss=2.85, topk_accuracy=0.725]


Epoch [85/100], Training Loss: 2.2639, Top-5 Accuracy: 0.7252


Epoch [86/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.01batch/s, loss=2.25, topk_accuracy=0.726]


Epoch [86/100], Training Loss: 2.2574, Top-5 Accuracy: 0.7259


Epoch [87/100]: 100%|██████████| 1654/1654 [00:04<00:00, 374.68batch/s, loss=1.86, topk_accuracy=0.727]


Epoch [87/100], Training Loss: 2.2552, Top-5 Accuracy: 0.7271


Epoch [88/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.99batch/s, loss=2.61, topk_accuracy=0.726]


Epoch [88/100], Training Loss: 2.2528, Top-5 Accuracy: 0.7264


Epoch [89/100]: 100%|██████████| 1654/1654 [00:04<00:00, 397.63batch/s, loss=2.92, topk_accuracy=0.728]


Epoch [89/100], Training Loss: 2.2445, Top-5 Accuracy: 0.7279


Epoch [90/100]: 100%|██████████| 1654/1654 [00:04<00:00, 388.51batch/s, loss=2.95, topk_accuracy=0.728]


Epoch [90/100], Training Loss: 2.2472, Top-5 Accuracy: 0.7278
No improvement in training loss. Early stop counter: 1/5


Epoch [91/100]: 100%|██████████| 1654/1654 [00:04<00:00, 373.11batch/s, loss=2.48, topk_accuracy=0.728]


Epoch [91/100], Training Loss: 2.2427, Top-5 Accuracy: 0.7281


Epoch [92/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.37batch/s, loss=2.63, topk_accuracy=0.727]


Epoch [92/100], Training Loss: 2.2387, Top-5 Accuracy: 0.7272


Epoch [93/100]: 100%|██████████| 1654/1654 [00:04<00:00, 396.44batch/s, loss=2.5, topk_accuracy=0.73]  


Epoch [93/100], Training Loss: 2.2351, Top-5 Accuracy: 0.7298


Epoch [94/100]: 100%|██████████| 1654/1654 [00:04<00:00, 378.69batch/s, loss=2.79, topk_accuracy=0.729]


Epoch [94/100], Training Loss: 2.2357, Top-5 Accuracy: 0.7292
No improvement in training loss. Early stop counter: 1/5


Epoch [95/100]: 100%|██████████| 1654/1654 [00:04<00:00, 379.81batch/s, loss=2.4, topk_accuracy=0.73]  


Epoch [95/100], Training Loss: 2.2325, Top-5 Accuracy: 0.7301


Epoch [96/100]: 100%|██████████| 1654/1654 [00:04<00:00, 395.57batch/s, loss=1.89, topk_accuracy=0.731]


Epoch [96/100], Training Loss: 2.2248, Top-5 Accuracy: 0.7310


Epoch [97/100]: 100%|██████████| 1654/1654 [00:04<00:00, 394.44batch/s, loss=1.85, topk_accuracy=0.731]


Epoch [97/100], Training Loss: 2.2246, Top-5 Accuracy: 0.7314


Epoch [98/100]: 100%|██████████| 1654/1654 [00:04<00:00, 393.92batch/s, loss=1.41, topk_accuracy=0.732]


Epoch [98/100], Training Loss: 2.2190, Top-5 Accuracy: 0.7317


Epoch [99/100]: 100%|██████████| 1654/1654 [00:04<00:00, 393.23batch/s, loss=2.7, topk_accuracy=0.732] 


Epoch [99/100], Training Loss: 2.2198, Top-5 Accuracy: 0.7317
No improvement in training loss. Early stop counter: 1/5


Epoch [100/100]: 100%|██████████| 1654/1654 [00:04<00:00, 387.29batch/s, loss=2.09, topk_accuracy=0.734]

Epoch [100/100], Training Loss: 2.2201, Top-5 Accuracy: 0.7336
No improvement in training loss. Early stop counter: 2/5





In [42]:
# Restore the best saved weights for both models.
# map_location=device remaps stored tensors onto the device chosen at the top of the
# notebook, so a checkpoint written on a GPU machine still loads on a CPU-only kernel.
# weights_only=True restricts unpickling to tensors/primitive containers (safer loading).
lstm_check_point_path = './model_checkpoints/lstm/best_model.pth'
bilstm_check_point_path = './model_checkpoints/bilstm/best_model.pth'
lstm_check_point = torch.load(lstm_check_point_path, map_location=device, weights_only=True)
bilstm_check_point = torch.load(bilstm_check_point_path, map_location=device, weights_only=True)

lstm_model.load_state_dict(lstm_check_point['model_state_dict'])
bilstm_model.load_state_dict(bilstm_check_point['model_state_dict'])

<All keys matched successfully>

In [43]:
# Autoregressively keep on predicting next word
def predict_next_words(model, input_text, n, show_top_5=False):
    """Greedily extend ``input_text`` by ``n`` words using ``model``.

    At every step the full token-id sequence generated so far is fed to
    ``model``; the top-1 of its five best-scoring candidates is appended to
    the sequence and to the text, and the updated sentence is printed.

    Args:
        model: next-word model, called as ``model(ids)`` where ``ids`` has
            shape (1, seq_len) and lives on ``device``. Its output is indexed
            as (batch, vocab) scores -- presumably logits; TODO confirm.
        input_text: prompt; lower-cased and tokenised with ``preprocess_text``.
        n: number of words to generate. ``n <= 0`` returns the lower-cased
            prompt without tokenising it or running the model.
        show_top_5: if True, also print the five best candidates at each step.

    Returns:
        The lower-cased prompt followed by the generated words, space-separated.

    Raises:
        ValueError: if the prompt yields no tokens (the model needs at least one).
        KeyError: if any prompt token is missing from ``word_to_idx``.
    """
    text = input_text.lower()
    model.eval()
    if n <= 0:
        return text

    tokens = preprocess_text(text)
    if not tokens:
        raise ValueError(f"input_text {input_text!r} contains no tokens to condition on")
    # Report every out-of-vocabulary token at once instead of a bare KeyError on the first.
    unknown = [token for token in tokens if token not in word_to_idx]
    if unknown:
        raise KeyError(f"tokens not in vocabulary: {unknown}")
    input_ids = [word_to_idx[token] for token in tokens]

    with torch.no_grad():
        for _ in range(n):
            outputs = model(torch.tensor(input_ids).unsqueeze(0).to(device))
            # Indices of the 5 highest scores for the single batch row, as plain Python ints.
            top_ids = torch.topk(outputs, k=5, dim=1).indices[0].tolist()
            predicted_words = [idx_to_word.get(idx, '<UNK>') for idx in top_ids]
            if show_top_5:
                print(f"Top five predicted words - {predicted_words}")
            # Feed the top-1 id back as an int (not a device tensor) so input_ids stays a list of ints.
            input_ids.append(top_ids[0])
            text = text + ' ' + predicted_words[0]
            print(f"updated sentence --> {text}\n")

    return text

LSTM next-word generation

In [48]:
# Prompts for the generation demo; the BiLSTM cell further down reuses this list.
sample_texts = ['generative model', 'they are', 'it is believed']
for sample in sample_texts:
    print(f"Input text --> {sample}")
    # Extend each prompt by 7 words, printing the top-5 candidates at every step.
    predict_next_words(lstm_model, sample, 7, show_top_5=True)


Input text --> generative model
Top five predicted words - ['classifier', 'for', 'belonging', 'can', 'of']
updated sentence --> generative model classifier

Top five predicted words - ['we', 'to', 'that', 'designed', 'classifier']
updated sentence --> generative model classifier we

Top five predicted words - ['study', 'show', 'present', 'demonstrate', 'propose']
updated sentence --> generative model classifier we study

Top five predicted words - ['the', 'here', 'and', 'this', 'natural']
updated sentence --> generative model classifier we study the

Top five predicted words - ['physical', 'algorithm', 'development', 'joint', 'performance']
updated sentence --> generative model classifier we study the physical

Top five predicted words - ['outlook', 'steam', 'properties', '<UNK>', 'model']
updated sentence --> generative model classifier we study the physical outlook

Top five predicted words - ['in', 'that', 'a', 'the', 'our']
updated sentence --> generative model classifier we study 

BiLSTM next-word generation

In [49]:
# Same prompts as the LSTM demo, now continued by bilstm_model.
for sample in sample_texts:
    print(f"Input text --> {sample}")
    # Extend each prompt by 7 words, printing the top-5 candidates at every step.
    predict_next_words(bilstm_model, sample, 7, show_top_5=True)

Input text --> generative model
Top five predicted words - ['for', 'classifier', 'belonging', 'of', 'specifically']
updated sentence --> generative model for

Top five predicted words - ['each', '3d', 'the', 'sequential', 'random']
updated sentence --> generative model for each

Top five predicted words - ['long', 'operating', 'model', 'classification', 'unit']
updated sentence --> generative model for each long

Top five predicted words - ['learning', 'inputs', 'term', 'model', 'sequential']
updated sentence --> generative model for each long learning

Top five predicted words - ['methods', 'we', 'of', 'from', 'bayesian']
updated sentence --> generative model for each long learning methods

Top five predicted words - ['for', 'of', 'in', 'we', 'and']
updated sentence --> generative model for each long learning methods for

Top five predicted words - ['gaussian', 'convolutional', 'a', 'sparse', 'multi']
updated sentence --> generative model for each long learning methods for gaussian

I