In [27]:
import os
import re
import time
import random

import numpy as np
import pandas as pd
from collections import namedtuple

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

In [28]:
def load_book(path):
    """Load a book from its file"""
    input_file = os.path.join(path)
    with open(input_file, encoding='utf-8', errors='replace') as f:
        book = f.read()
    return book

In [29]:
# Collect all of the book file names
path = '/kaggle/input/books-dataset/books/'
book_files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]
book_files = book_files[1:]

In [30]:
# Load the books using the file names
books = []
for book in book_files:
    books.append(load_book(path+book))

In [31]:
# Compare the number of words in each book 
for i in range(len(books)):
    print("There are {} words in {}.".format(len(books[i].split()), book_files[i]))

There are 361612 words in Anna_Karenina_by_Leo_Tolstoy.rtf.
There are 96185 words in The_Adventures_of_Tom_Sawyer_by_Mark_Twain.rtf.
There are 194282 words in The_Romance_of_Lust_by_Anonymous.rtf.
There are 53211 words in The_Prince_by_Nicolo_Machiavelli.rtf.
There are 30423 words in Alices_Adventures_in_Wonderland_by_Lewis_Carroll.rtf.
There are 163109 words in Emma_by_Jane_Austen.rtf.
There are 110213 words in The_Adventures_of_Sherlock_Holmes_by_Arthur_Conan_Doyle.rtf.
There are 480495 words in The_Count_of_Monte_Cristo_by_Alexandre_Dumas.rtf.
There are 113452 words in David_Copperfield_by_Charles_Dickens.rtf.
There are 25395 words in Metamorphosis_by_Franz_Kafka.rtf.
There are 126999 words in Pride_and_Prejudice_by_Jane_Austen.rtf.
There are 83657 words in The_Picture_of_Dorian_Gray_by_Oscar_Wilde.rtf.
There are 166996 words in Dracula_by_Bram_Stoker.rtf.
There are 165188 words in Oliver_Twist_by_Charles_Dickens.rtf.
There are 78912 words in Frankenstein_by_Mary_Shelley.rtf.
There 

In [32]:
def clean_text(text):
    '''Remove unwanted characters and extra spaces from the text'''
    text = re.sub(r'\n', ' ', text) 
    text = re.sub(r'[{}@_*>()\\#%+=\[\]]','', text)
    text = re.sub('a0','', text)
    text = re.sub('\'92t','\'t', text)
    text = re.sub('\'92s','\'s', text)
    text = re.sub('\'92m','\'m', text)
    text = re.sub('\'92ll','\'ll', text)
    text = re.sub('\'91','', text)
    text = re.sub('\'92','', text)
    text = re.sub('\'93','', text)
    text = re.sub('\'94','', text)
    text = re.sub('\.','. ', text)
    text = re.sub('\!','! ', text)
    text = re.sub('\?','? ', text)
    text = re.sub(' +',' ', text)
    return text

In [33]:
# Clean the text of the books
clean_books = []
for book in books:
    clean_books.append(clean_text(book))

In [34]:
# Check to ensure the text has been cleaned properly
clean_books[0][:500]

'rtf1ansiansicpg1252cocoartf1404cocoasubrtf470 fonttblf0fmodernfcharset0 Courier; colortbl;red255green255blue255;red0green0blue0; margl1440margr1440vieww10800viewh8400viewkind0 deftab720 pardpardeftab720sl280partightenfactor0 f0fs24 cf2 expnd0expndtw0kerning0 outl0strokewidth0 strokec2 The Project Gutenberg EBook of Anna Karenina, by Leo Tolstoy This eBook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever. You may copy it, give it away or re-use it under the '

In [35]:
# Create a dictionary to convert the vocabulary (characters) to integers
vocab_to_int = {}
count = 0
for book in clean_books:
    for character in book:
        
        if character not in vocab_to_int:
            vocab_to_int[character] = count
            count += 1

# Add special tokens to vocab_to_int
codes = ['<PAD>','<EOS>','<GO>']
for code in codes:
    vocab_to_int[code] = count
    count += 1

In [36]:
# Check the size of vocabulary and all of the values
vocab_size = len(vocab_to_int)
print("The vocabulary contains {} characters.".format(vocab_size))
print(sorted(vocab_to_int))

The vocabulary contains 78 characters.
[' ', '!', '"', '$', '&', "'", ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<EOS>', '<GO>', '<PAD>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [37]:
# Create another dictionary to convert integers to their respective characters
int_to_vocab = {}
for character, value in vocab_to_int.items():
    int_to_vocab[value] = character

In [38]:
# Split the text from the books into sentences.
sentences = []
for book in clean_books:
    for sentence in book.split('. '):
        sentences.append(sentence + '.')
print("There are {} sentences.".format(len(sentences)))

There are 131951 sentences.


In [39]:
# Check to ensure the text has been split correctly.
sentences[:5]

['rtf1ansiansicpg1252cocoartf1404cocoasubrtf470 fonttblf0fmodernfcharset0 Courier; colortbl;red255green255blue255;red0green0blue0; margl1440margr1440vieww10800viewh8400viewkind0 deftab720 pardpardeftab720sl280partightenfactor0 f0fs24 cf2 expnd0expndtw0kerning0 outl0strokewidth0 strokec2 The Project Gutenberg EBook of Anna Karenina, by Leo Tolstoy This eBook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever.',
 'You may copy it, give it away or re-use it under the terms of the Project Gutenberg License included with this eBook or online at http://www.',
 'gutenberg.',
 'org/license.',
 'Title: Anna Karenina Author: Leo Tolstoy Release Date: July 01, 1998 EBook 1399 Reposted: April 02, 2005 corrections, reposted to new folder structure by David Widger Reposted: December 09, 2011 corrections Reposted: December 15, 2012 corrections, conversion to HTML Reposted: February 14, 2013 conversion to XHTML Strict by David Widger Reposted: February 22, 2013 cor

In [40]:
# Convert sentences to integers
int_sentences = []

for sentence in sentences:
    int_sentence = []
    for character in sentence:
        int_sentence.append(vocab_to_int[character])
    int_sentences.append(int_sentence)

In [41]:
# Find the length of each sentence
lengths = []
for sentence in int_sentences:
    lengths.append(len(sentence))
lengths = pd.DataFrame(lengths, columns=["counts"])
lengths.describe()

Unnamed: 0,counts
count,131951.0
mean,120.569787
std,116.983324
min,1.0
25%,46.0
50%,92.0
75%,160.0
max,8906.0


In [42]:
# Limit the data we will use to train our model
max_length = 128
min_length = 10

good_sentences = []

for sentence in int_sentences:
    if len(sentence) <= max_length and len(sentence) >= min_length:
        good_sentences.append(sentence)

print("We will use {} to train and test our model.".format(len(good_sentences)))

We will use 79039 to train and test our model.


In [43]:
# Split the data into training and testing sentences
training, testing = train_test_split(good_sentences, test_size = 0.15, random_state = 2)

print("Number of training sentences:", len(training))
print("Number of testing sentences:", len(testing))

Number of training sentences: 67183
Number of testing sentences: 11856


In [44]:
# Sort the sentences by length to reduce padding, which will allow the model to train faster
training_sorted = []
testing_sorted = []

for i in range(min_length, max_length+1):
    for sentence in training:
        if len(sentence) == i:
            training_sorted.append(sentence)
    for sentence in testing:
        if len(sentence) == i:
            testing_sorted.append(sentence)

In [45]:
def noise_maker(sentence, threshold):
    '''Relocate, remove, or add characters to create spelling mistakes'''

    letters = ['a','b','c','d','e','f','g','h','i','j','k','l','m',
           'n','o','p','q','r','s','t','u','v','w','x','y','z',]
    
    noisy_sentence = []
    i = 0
    while i < len(sentence):
        random = np.random.uniform(0,1,1)
        # Most characters will be correct since the threshold value is high
        if random < threshold:
            noisy_sentence.append(sentence[i])
        else:
            new_random = np.random.uniform(0,1,1)
            # ~33% chance characters will swap locations
            if new_random > 0.67:
                if i == (len(sentence) - 1):
                    # If last character in sentence, it will not be typed
                    continue
                else:
                    # if any other character, swap order with following character
                    noisy_sentence.append(sentence[i+1])
                    noisy_sentence.append(sentence[i])
                    i += 1
            # ~33% chance an extra lower case letter will be added to the sentence
            elif new_random < 0.33:
                random_letter = np.random.choice(letters, 1)[0]
                noisy_sentence.append(vocab_to_int[random_letter])
                noisy_sentence.append(sentence[i])
            # ~33% chance a character will not be typed
            else:
                pass     
        i += 1
    return noisy_sentence

In [46]:
# Check to ensure noise_maker is making mistakes correctly.
threshold = 0.9
for sentence in training_sorted[:5]:
    print(sentence)
    print(noise_maker(sentence, threshold))
    print()

[58, 23, 8, 1, 7, 13, 5, 19, 52, 43]
[23, 58, 8, 1, 7, 13, 5, 19, 52, 43]

[39, 5, 7, 10, 24, 1, 20, 23, 41, 43]
[13, 39, 5, 7, 10, 24, 1, 1, 20, 23, 41, 43]

[33, 0, 23, 6, 7, 22, 23, 5, 1, 43]
[33, 0, 8, 23, 6, 7, 22, 23, 5, 1, 43]

[10, 16, 1, 23, 5, 17, 23, 0, 10, 43]
[10, 16, 1, 24, 23, 5, 17, 16, 23, 0, 10, 43]

[61, 19, 22, 23, 5, 41, 19, 7, 1, 43]
[61, 19, 22, 23, 41, 19, 7, 1, 43]



In [47]:
def pad_sentence_batch(sentence_batch):
    """Pad sentences with <PAD> so that each sentence of a batch has the same length"""
    max_sentence = max([len(sentence) for sentence in sentence_batch])
    return [sentence + [vocab_to_int['<PAD>']] * (max_sentence - len(sentence)) for sentence in sentence_batch]

In [48]:
def get_batches(sentences, batch_size, threshold):
    """Batch sentences, noisy sentences, and the lengths of their sentences together.
       With each epoch, sentences will receive new mistakes"""
    
    for batch_i in range(0, len(sentences)//batch_size):
        start_i = batch_i * batch_size
        sentences_batch = sentences[start_i:start_i + batch_size]
        
        sentences_batch_noisy = []
        for sentence in sentences_batch:
            sentences_batch_noisy.append(noise_maker(sentence, threshold))
            
        sentences_batch_eos = []
        for sentence in sentences_batch:
            sentence.append(vocab_to_int['<EOS>'])
            sentences_batch_eos.append(sentence)
            
        pad_sentences_batch = np.array(pad_sentence_batch(sentences_batch_eos))
        pad_sentences_noisy_batch = np.array(pad_sentence_batch(sentences_batch_noisy))
        
        # Need the lengths for the _lengths parameters
        pad_sentences_lengths = []
        for sentence in pad_sentences_batch:
            pad_sentences_lengths.append(len(sentence))
        
        pad_sentences_noisy_lengths = []
        for sentence in pad_sentences_noisy_batch:
            pad_sentences_noisy_lengths.append(len(sentence))
        
        yield pad_sentences_noisy_batch, pad_sentences_batch, pad_sentences_noisy_lengths, pad_sentences_lengths

In [49]:
import torch
from torch.utils.data import Dataset, DataLoader

class SpellCorrectionDataset(Dataset):
    def __init__(self, sentences, vocab_to_int, noise_threshold=0.9):
        """
        sentences: list of "good" sentences (each is a list of integer tokens)
        vocab_to_int: dictionary mapping characters (and special tokens) to integers
        noise_threshold: probability that a given character remains unchanged
        """
        self.sentences = sentences
        self.vocab_to_int = vocab_to_int
        self.noise_threshold = noise_threshold

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        # Get the clean sentence (make a copy to avoid in-place modifications)
        clean_sentence = self.sentences[idx][:]

        # Append <EOS> token to the clean sentence if it isn't there already.
        if clean_sentence[-1] != self.vocab_to_int['<EOS>']:
            clean_sentence = clean_sentence + [self.vocab_to_int['<EOS>']]

        # Generate a noisy version on the fly.
        # Note: noise_maker should work with a sentence represented as a list of ints.
        noisy_sentence = noise_maker(self.sentences[idx], self.noise_threshold)
        
        # You might want to also append an <EOS> token to the noisy sentence, 
        # depending on how your model handles sequence termination.
        # For now, we leave it as is.

        # Convert lists to torch tensors.
        return (torch.tensor(noisy_sentence, dtype=torch.long),
                torch.tensor(clean_sentence, dtype=torch.long))


def pad_collate_fn(batch):
    """
    Collate function to pad sequences in the batch.
    Each batch item is a tuple: (noisy_sentence, clean_sentence)
    """
    # Unpack batch items.
    noisy_batch, clean_batch = zip(*batch)

    # Determine max lengths in the batch.
    max_noisy_len = max([s.size(0) for s in noisy_batch])
    max_clean_len = max([s.size(0) for s in clean_batch])
    
    # Pad noisy sequences.
    padded_noisy = [
        torch.cat([s, torch.tensor([vocab_to_int['<PAD>']] * (max_noisy_len - s.size(0)),
                                     dtype=torch.long)])
        for s in noisy_batch
    ]
    # Pad clean sequences.
    padded_clean = [
        torch.cat([s, torch.tensor([vocab_to_int['<PAD>']] * (max_clean_len - s.size(0)),
                                     dtype=torch.long)])
        for s in clean_batch
    ]

    # Optionally, collect the original lengths.
    noisy_lengths = torch.tensor([s.size(0) for s in noisy_batch], dtype=torch.long)
    clean_lengths = torch.tensor([s.size(0) for s in clean_batch], dtype=torch.long)
    
    # Stack into batch tensors.
    padded_noisy = torch.stack(padded_noisy)
    padded_clean = torch.stack(padded_clean)
    
    return padded_noisy, padded_clean, noisy_lengths, clean_lengths


# Create your custom dataset instance
dataset = SpellCorrectionDataset(training_sorted, vocab_to_int, noise_threshold=0.9)

# Split the dataset into training and validation sets (e.g., 85% training, 15% validation)
train_size = int(0.85 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Define batch size
batch_size = 128

# Create DataLoaders for training and validation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate_fn)

# Optional: Inspect one batch from the training loader
for batch in train_loader:
    src, trg, src_lengths, trg_lengths = batch
    print("Source batch shape:", src.shape)
    print("Target batch shape:", trg.shape)
    break


Source batch shape: torch.Size([128, 130])
Target batch shape: torch.Size([128, 128])


In [50]:
import torch
import torch.nn as nn
import random

###############################################
# Attention Module (Bahdanau-style Attention) #
###############################################

class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # The attention layer will take the concatenation of the decoder's current hidden state
        # and each encoder output, then output an energy scalar.
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        
    def forward(self, decoder_hidden, encoder_outputs):
        """
        Args:
            decoder_hidden: [batch size, hidden_dim] -- current decoder hidden state (from the top layer)
            encoder_outputs: [batch size, src_len, hidden_dim] -- all encoder hidden states
            
        Returns:
            attention_weights: [batch size, src_len] -- normalized weights over the encoder outputs
        """
        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]
        
        # Repeat decoder_hidden src_len times to concatenate with each encoder output.
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, src_len, 1)  # [batch, src_len, hidden_dim]
        
        # Concatenate and compute energy
        energy = torch.tanh(self.attn(torch.cat((decoder_hidden, encoder_outputs), dim=2)))  # [batch, src_len, hidden_dim]
        # Compute unnormalized attention scores
        attention = self.v(energy).squeeze(2)  # [batch, src_len]
        
        # Normalize with softmax over src_len
        return torch.softmax(attention, dim=1)

#######################################
# Encoder (Modified to return outputs)#
#######################################

class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        # batch_first=True: input shape [batch, seq_len, features]
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """
        src: [batch size, src length]
        Returns:
            outputs: [batch size, src length, hidden_dim] (all hidden states)
            hidden: [n_layers, batch size, hidden_dim]
            cell: [n_layers, batch size, hidden_dim]
        """
        embedded = self.dropout(self.embedding(src))  # [batch, src_len, embedding_dim]
        outputs, (hidden, cell) = self.rnn(embedded)
        return outputs, hidden, cell

#####################################
# Decoder (with Attention)          #
#####################################

class Decoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        # The input to the LSTM now is the concatenation of the embedding and the context vector.
        self.rnn = nn.LSTM(embedding_dim + hidden_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.attention = Attention(hidden_dim)
        # The fully-connected layer takes the concatenated [LSTM output; context] vector.
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """
        Args:
            input: [batch size] - current token for each sentence in the batch
            hidden: [n_layers, batch size, hidden_dim]
            cell: [n_layers, batch size, hidden_dim]
            encoder_outputs: [batch size, src length, hidden_dim]
        
        Returns:
            prediction: [batch size, output_dim]
            hidden: updated hidden state
            cell: updated cell state
        """
        # Add time dimension to input: [batch size] -> [batch size, 1]
        input = input.unsqueeze(1)
        embedded = self.dropout(self.embedding(input))  # [batch, 1, embedding_dim]
        
        # Compute attention weights using the top layer decoder hidden state (i.e., hidden[-1])
        a = self.attention(hidden[-1], encoder_outputs)  # [batch, src_len]
        a = a.unsqueeze(1)  # [batch, 1, src_len]
        # Compute context vector as the weighted sum of encoder outputs.
        context = torch.bmm(a, encoder_outputs)  # [batch, 1, hidden_dim]
        
        # Concatenate the embedded input and context vector.
        rnn_input = torch.cat((embedded, context), dim=2)  # [batch, 1, embedding_dim + hidden_dim]
        
        # Pass through the LSTM
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        # output: [batch, 1, hidden_dim]
        output = output.squeeze(1)   # [batch, hidden_dim]
        context = context.squeeze(1) # [batch, hidden_dim]
        
        # Concatenate output and context and pass through the final linear layer.
        prediction = self.fc_out(torch.cat((output, context), dim=1))  # [batch, output_dim]
        
        return prediction, hidden, cell

#####################################
# Seq2Seq (Updated)                 #
#####################################

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
        # Ensure the encoder and decoder have matching dimensions and layers.
        assert encoder.hidden_dim == decoder.hidden_dim, "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.n_layers == decoder.n_layers, "Encoder and decoder must have equal number of layers!"

    def forward(self, src, trg, teacher_forcing_ratio):
        """
        src: [batch size, src length]
        trg: [batch size, trg length]
        teacher_forcing_ratio: probability to use teacher forcing
        Returns:
            outputs: [batch size, trg length, output_dim]
        """
        batch_size = src.shape[0]
        trg_length = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        # Tensor to store decoder outputs.
        outputs = torch.zeros(batch_size, trg_length, trg_vocab_size).to(self.device)
        
        # Encoder forward pass returns all outputs.
        encoder_outputs, hidden, cell = self.encoder(src)
        
        # The first token of each target sentence should be the <sos> token.
        input = trg[:, 0]  # [batch size]
        
        for t in range(1, trg_length):
            # Pass through the decoder one time step at a time, providing encoder_outputs for attention.
            output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            
            # Decide whether to use teacher forcing.
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)  # [batch size]
            input = trg[:, t] if teacher_force else top1
        
        return outputs


In [51]:
PAD_IDX = vocab_to_int['<PAD>']
SOS_IDX = vocab_to_int['<GO>']  # or '<SOS>' depending on your token naming

# Hyperparameters (adjust as needed)
INPUT_DIM = len(vocab_to_int)      # Size of the vocabulary
OUTPUT_DIM = len(vocab_to_int)     # For seq2seq, input and output vocabs are often the same
EMBEDDING_DIM = 96
HIDDEN_DIM = 128
N_LAYERS = 2
DROPOUT = 0.65

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model components
encoder = Encoder(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT)
decoder = Decoder(OUTPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT)
model = Seq2Seq(encoder, decoder, device).to(device)
device

device(type='cuda')

In [52]:
import torch.optim as optim
import torch.nn as nn
import time
from tqdm import tqdm  # Import tqdm for progress bars

# Hyperparameters
LEARNING_RATE = 0.001
N_EPOCHS = 50
CLIP = 1  # For gradient clipping

# Define the optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
PAD_IDX = vocab_to_int['<PAD>']
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

def train(model, iterator, optimizer, criterion, clip, teacher_forcing_ratio):
    """
    Performs one epoch of training.
    """
    model.train()
    epoch_loss = 0

    # Wrap the iterator with tqdm for a progress bar
    progress_bar = tqdm(iterator, desc="Training", leave=False)
    for batch in progress_bar:
        src, trg, src_lengths, trg_lengths = batch
        src, trg = src.to(device), trg.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass: teacher forcing is applied based on the given ratio
        output = model(src, trg, teacher_forcing_ratio)
        # output shape: [batch size, trg length, output dim]
        
        # Reshape for loss computation
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)   # [batch_size * trg_len, output_dim]
        trg = trg.contiguous().view(-1)                      # [batch_size * trg_len]
        
        loss = criterion(output, trg)
        loss.backward()
        
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        epoch_loss += loss.item()
        
        # Optionally update the progress bar description with the current loss
        progress_bar.set_postfix(loss=loss.item())
    
    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    """
    Evaluates the model on a validation set without teacher forcing.
    """
    model.eval()
    epoch_loss = 0

    # Wrap the iterator with tqdm for a progress bar
    progress_bar = tqdm(iterator, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            src, trg, src_lengths, trg_lengths = batch
            src, trg = src.to(device), trg.to(device)
            
            # Turn off teacher forcing during evaluation
            output = model(src, trg, teacher_forcing_ratio=0)
            
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            trg = trg.contiguous().view(-1)
            
            loss = criterion(output, trg)
            epoch_loss += loss.item()
            
            progress_bar.set_postfix(loss=loss.item())
    
    return epoch_loss / len(iterator)


# Example training loop with tqdm progress for each epoch
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    
    # Use tqdm to track the epochs as well if desired.
    train_loss = train(model, train_loader, optimizer, criterion, CLIP, teacher_forcing_ratio=0.2)
    valid_loss = evaluate(model, val_loader, criterion)
    
    end_time = time.time()
    epoch_mins, epoch_secs = divmod(int(end_time - start_time), 60)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-model.pt')
        print("Model Saved!")
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')


                                                                      

Model Saved!
Epoch: 01 | Time: 3m 29s
	Train Loss: 3.052
	 Val. Loss: 3.053


                                                                      

Model Saved!
Epoch: 02 | Time: 3m 29s
	Train Loss: 2.875
	 Val. Loss: 2.758


                                                                      

Model Saved!
Epoch: 03 | Time: 3m 29s
	Train Loss: 2.508
	 Val. Loss: 2.352


                                                                      

Model Saved!
Epoch: 04 | Time: 3m 29s
	Train Loss: 2.292
	 Val. Loss: 2.159


                                                                      

Epoch: 05 | Time: 3m 29s
	Train Loss: 2.183
	 Val. Loss: 2.257


                                                                      

Model Saved!
Epoch: 06 | Time: 3m 29s
	Train Loss: 2.090
	 Val. Loss: 2.026


                                                                      

Model Saved!
Epoch: 07 | Time: 3m 29s
	Train Loss: 2.013
	 Val. Loss: 1.998


                                                                      

Model Saved!
Epoch: 08 | Time: 3m 29s
	Train Loss: 1.981
	 Val. Loss: 1.977


                                                                      

Model Saved!
Epoch: 09 | Time: 3m 29s
	Train Loss: 1.952
	 Val. Loss: 1.952


                                                                      

Epoch: 10 | Time: 3m 29s
	Train Loss: 1.894
	 Val. Loss: 2.172


                                                                      

Epoch: 11 | Time: 3m 29s
	Train Loss: 1.863
	 Val. Loss: 2.016


                                                                      

Model Saved!
Epoch: 12 | Time: 3m 29s
	Train Loss: 1.788
	 Val. Loss: 1.951


                                                                      

Model Saved!
Epoch: 13 | Time: 3m 29s
	Train Loss: 1.761
	 Val. Loss: 1.881


                                                                      

Epoch: 14 | Time: 3m 29s
	Train Loss: 1.709
	 Val. Loss: 1.989


                                                                      

Epoch: 15 | Time: 3m 29s
	Train Loss: 1.649
	 Val. Loss: 2.006


                                                                      

Epoch: 16 | Time: 3m 29s
	Train Loss: 1.600
	 Val. Loss: 2.220


                                                                      

Epoch: 17 | Time: 3m 29s
	Train Loss: 1.575
	 Val. Loss: 1.948


                                                                      

Epoch: 18 | Time: 3m 29s
	Train Loss: 1.543
	 Val. Loss: 2.036


                                                                      

Epoch: 19 | Time: 3m 30s
	Train Loss: 1.476
	 Val. Loss: 2.189


                                                                      

Epoch: 20 | Time: 3m 29s
	Train Loss: 1.468
	 Val. Loss: 2.118


                                                                      

Epoch: 21 | Time: 3m 29s
	Train Loss: 1.443
	 Val. Loss: 2.130


                                                                     

KeyboardInterrupt: 

In [53]:
model = Seq2Seq(encoder, decoder, device).to(device)
def predict_sentence(model, sentence, vocab_to_int, int_to_vocab, device, max_length=50):
    """
    Generate a prediction for a single sentence using the attention-based seq2seq model.
    
    Args:
        model (nn.Module): The trained Seq2Seq model with attention.
        sentence (str): The noisy input sentence as a string.
        vocab_to_int (dict): Mapping from tokens (e.g., characters) to their integer IDs.
        int_to_vocab (dict): Mapping from integer IDs to tokens.
        device (torch.device): Device on which to run the model.
        max_length (int): Maximum length for the generated sentence.
    
    Returns:
        str: The predicted (corrected) sentence as a string.
    """
    model.eval()  # Set model to evaluation mode

    # Convert the input sentence into token IDs.
    # Only include tokens that exist in vocab_to_int.
    token_ids = [vocab_to_int[char] for char in sentence if char in vocab_to_int]
    
    # Create a tensor and add a batch dimension: [1, src_len]
    src_tensor = torch.LongTensor(token_ids).unsqueeze(0).to(device)

    with torch.no_grad():
        # Get the encoder outputs as well as the final hidden and cell states.
        encoder_outputs, hidden, cell = model.encoder(src_tensor)
    
    # Initialize the decoder input with the <GO> token.
    SOS_IDX = vocab_to_int['<GO>']
    input_token = torch.LongTensor([SOS_IDX]).to(device)

    predicted_tokens = []

    for _ in range(max_length):
        with torch.no_grad():
            # Pass the current input token, hidden, cell, and encoder_outputs to the decoder.
            output, hidden, cell = model.decoder(input_token, hidden, cell, encoder_outputs)
        
        # Get the token with the highest probability.
        top1 = output.argmax(1).item()

        # If the <EOS> token is predicted, stop decoding.
        if top1 == vocab_to_int['<EOS>']:
            break

        predicted_tokens.append(top1)
        # Update the input token for the next time step.
        input_token = torch.LongTensor([top1]).to(device)
    
    # Convert the list of token IDs back to a string.
    predicted_sentence = ''.join([int_to_vocab[token] for token in predicted_tokens])
    return predicted_sentence

# Example usage:
noisy_sentence = "Ths is an exmple sentnce."
prediction = predict_sentence(model, noisy_sentence, vocab_to_int, int_to_vocab, device)
print("Predicted sentence:", prediction)


Predicted sentence: es is expleentcccceccec
