Bi-gram implementation using Neural Network approach (Pytorch)

* Importing libraries

In [69]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import re
from collections import defaultdict

* Constructing the Bi-gram language model class

In [72]:
class BigramLanguageModel(nn.Module):
    """
    A PyTorch-based bigram language model for processing and generating text based on a two-word sliding window approach.

    """
    def __init__(self,embedding_dim):
        """
        Initializes the BigramLanguageModel.

        Args:
            embedding_dim (int): The size of the embeddings for each token.
        """
        super(BigramLanguageModel,self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.vocab = None
        self.vocab_size = None
        self.encoder = None
        self.decoder = None
        self.token_embedding_table = None
        self.embedding_dim = embedding_dim
        

    def preprocessing(self,text_file):
        """
        Processes text files to create a vocabulary and a corpus of processed words.

        Args:
            text_file (str): Path to the text file to be processed.

        Returns:
            tuple: A tuple containing the vocabulary list and the processed corpus.
        """
        # Opening the text file
        with open(text_file, "r") as f:
            text = f.read()

        # Initialize lists and sets
        processed_corpus = []
        vocabulary = set()
        word_counter = defaultdict(int)

        # Count the frequency of each word
        for word in text.split():
            word = word.lower()  # Case folding
            # word = word.strip('.,():"')  # Strip common punctuation
            if word:  # Check if word is not empty
                if word.endswith('.'):
                    word = word.rstrip('.')
                    if word != '':
                        processed_corpus.append(word)
                        processed_corpus.append('.')
                        word_counter['.'] += 1
                else:
                    processed_corpus.append(word)
                word_counter[word] += 1

        # Replace rare words with '<UNK>' and update vocabulary
        final_corpus = []
        for word in processed_corpus:
            if word_counter[word] < 2:  # If the count of the word is less than 2
                final_corpus.append('<UNK>')
                vocabulary.add('<UNK>')
            else:
                final_corpus.append(word)
                vocabulary.add(word)

        vocabulary = list(vocabulary)  # Convert set to list

        # Verifying the vocabulary and preprocessed corpus
        print(f'Vocabulary size: {len(vocabulary)}')
        print(f"Vocabulary: {vocabulary}")
        print(f"Preprocessed corpus: {final_corpus}")

        return vocabulary, final_corpus


    def prep(self,text_file):
        """
        Prepares the model by setting up the vocabulary, and the training and validation datasets.

        Args:
            text_file (str): Path to the text file used for preparing the model.
        """
        # creating vocabulary
        self.vocab, self.processed_corpus = self.preprocessing(text_file)
        self.vocab_size = len(self.vocab)

        word2Index = {word:i for i,word in enumerate(self.vocab)}

        index2Word = {i:word for i,word in enumerate(self.vocab)}

        # encoder and decoder for text embedding
        self.encoder = lambda text: [word2Index[word] for word in text]
        # self.decoder = lambda nums: ' '.join([index2Word[i] for i in nums])
        self.decoder = lambda nums: [index2Word[i] for i in nums]

        # creating train and val set
        n = len(self.processed_corpus)
        self.train_text = self.processed_corpus[:int(n*0.9)]
        self.val_text = self.processed_corpus[int(n*0.9):]

        # generating the embedding and converting to tensor
        self.train_data = torch.tensor(self.encoder(self.train_text), dtype=torch.long)
        self.val_data = torch.tensor(self.encoder(self.val_text), dtype=torch.long)

        # embedding table for the vocabulary
        self.token_embedding_table = nn.Embedding(self.vocab_size,self.embedding_dim)
        self.fc = nn.Linear(self.embedding_dim,self.vocab_size)  

    def get_batch(self,split='train', input_length=8, batch_size=32):
        """
        Generates a batch of data for training or validation.

        Args:
            split (str): Type of the split ('train' or 'val').
            input_length (int): Length of the input sequences.
            batch_size (int): Size of each batch.

        Returns:
            tuple: A tuple containing batched input and target tensors.
        """
        data = self.train_data if split=='train' else self.val_data

        # getting random index from the data for creating batches
        index = torch.randint(len(data)-input_length,(batch_size,))

        inputs_batch = torch.stack([data[i:i+input_length] for i in index])
        targets_batch = torch.stack([data[i+1:i+1+input_length] for i in index])

        inputs_batch = inputs_batch.to(self.device)
        targets_batch = targets_batch.to(self.device)

        return inputs_batch, targets_batch

    def forward(self,input_ids,target=None):
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input tensor of word indices.
            target (torch.Tensor, optional): Target tensor of word indices.

        Returns:
            tuple: Tuple containing the logits and optionally the cross-entropy loss if target is provided.
        """
        input_ids_embedding = self.token_embedding_table(input_ids)
        input_ids_embedding = self.fc(input_ids_embedding)

        if target is None:
            ce_loss = None
        else:
            batch_size, input_length, vocab_size = input_ids_embedding.shape
            token_prediction_target = input_ids_embedding.view(batch_size*input_length, vocab_size)
            token_actual_target = target.view(batch_size*input_length)
            ce_loss = F.cross_entropy(token_prediction_target,token_actual_target)
        return input_ids_embedding, ce_loss
    
    def fit(self,train_iter=100, eval_iter=10, lr=0.001):
        """
        Trains the model.

        Args:
            train_iter (int): Number of total training iterations.
            eval_iter (int): Interval for evaluation during training.
            lr (float): Learning rate for the optimizer.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr= lr)
        for epoch in range(train_iter):
            if epoch%eval_iter == 0:
                loss,perplexity = self.eval_loss(eval_iter)
                print(f"Epoch:{epoch}, training loss:{loss['train']}, validation loss:{loss['val']}")
                print(f"Training Perplexity:{perplexity['train']}, Validation Perplexity:{perplexity['val']}")

            optimizer.zero_grad()

            inputs,targets = self.get_batch(split='train')
            _, ce_loss = self(inputs,targets)

            ce_loss.backward()
            optimizer.step()

    @torch.no_grad()
    def eval_loss(self,eval_iters):
        """
        Evaluates the model over a set number of iterations and calculates loss and perplexity.

        Args:
            eval_iters (int): Number of iterations to evaluate.

        Returns:
            dict: A dictionary containing 'train' and 'val' losses and perplexities.
        """
        loss = {}
        perplexity = {}
        # switching to evaluation mode, not tracking the gradients
        self.eval()
        for dataset in ['train','val']:
            losses = torch.zeros(eval_iters)
            for iters in range(eval_iters):
                inputs,targets = self.get_batch(dataset)
                _, ce_loss = self(inputs,targets)
                losses[iters] = ce_loss.item()
            average_loss = losses.mean()
            loss[dataset] = average_loss
            perplexity[dataset] = torch.exp(average_loss)
        # again switching to training mode
        self.train()
        return loss, perplexity

    def generate(self,context_word_tokens,max_tokens):
        """
        Generates text based on the provided context of word tokens.

        Args:
            context_word_tokens (torch.Tensor): Initial context tokens for text generation.
            max_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: Generated text.
        """
        for _ in range(max_tokens):
            target_word_tokens_rep, _ = self(context_word_tokens)
            last_target_word_rep = target_word_tokens_rep[:,-1,:]
            prob_dist = F.softmax(last_target_word_rep,dim=1)

            # predicting the next word as the probability distribution given the context word
            next_word_token = torch.multinomial(prob_dist,num_samples=1)

            context_word_tokens = torch.cat((context_word_tokens,next_word_token), dim=1)

        context_tokens = context_word_tokens.squeeze(0).tolist()
        output = self.decoder(context_tokens)
        # Decode the sequence of word indices to get the output text
        output_text = ""
        prev = ''
        for text in output:
            if prev == text:
                continue
            else:
                prev = text
            if text == '.':
                output_text = ''.join([output_text,text])
                output_text = ''.join([output_text,'\n'])
            else:
                output_text = ' '.join([output_text,text])
        return output_text


* Training the model

In [74]:
# creating an instance of the model
model = BigramLanguageModel(100)
filename = "WarrenBuffet.txt"

# processing the data to create the train and val dataset
model.prep(filename)

# training the model
model.fit(3000,100,0.001)

Vocabulary size: 3954
Epoch:0, training loss:8.453529357910156, validation loss:8.452406883239746
Training Perplexity:4691.60205078125, Validation Perplexity:4686.3388671875
Epoch:100, training loss:7.612222671508789, validation loss:7.702279090881348
Training Perplexity:2022.76904296875, Validation Perplexity:2213.38671875
Epoch:200, training loss:7.031450271606445, validation loss:7.2680253982543945
Training Perplexity:1131.670654296875, Validation Perplexity:1433.7166748046875
Epoch:300, training loss:6.621212005615234, validation loss:6.9428019523620605
Training Perplexity:750.8546142578125, Validation Perplexity:1035.6680908203125
Epoch:400, training loss:6.257720470428467, validation loss:6.719757080078125
Training Perplexity:522.027587890625, Validation Perplexity:828.6162109375
Epoch:500, training loss:6.015256881713867, validation loss:6.4763007164001465
Training Perplexity:409.63104248046875, Validation Perplexity:649.5635986328125
Epoch:600, training loss:5.733469009399414, 

* Generating the output sentences

In [77]:
# generate method is called to predict the next words given a word from vocabulary
bigram_output = model.generate(context_word_tokens=torch.tensor([[13]], dtype=torch.long, device=model.device), max_tokens=500)
print(bigram_output)

 tests and operating last year's.
 for ibm's stock for the number on sunday for shareholders resemble and earnings.
 as the beginning that april 29 11 2,459 enjoy the best gain (loss) <UNK> rolling device for example, it's more substantial sums of a profit as his son, malcolm g.
 <UNK> and bring the need to tony have citizens or write-downs that meet graduated gets markets.
 when <UNK> company percentage points of berkadia.
 here netjets' is to carry one of retail by insurance - fargo we had their names, some record against that we <UNK> therefore fallen of income - we will continue to the runaway friend <UNK> assets <UNK>.
 competition the gain regulators.
 charlie and their employment were <UNK> however, below will get very managers.
 you to <UNK> at prices of board than its top spot.
 our policies and made some serious mistakes into go.
 these <UNK> and sell at nebraska furniture (furniture retailing) as a remarkable entrepreneur <UNK> above foreign-exchange profits.
 its berkshire'