<a href="https://colab.research.google.com/github/np2802/Indian-Legal-Semantic-Searcher/blob/main/dynamic%2Bstatistical_mlm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

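# Masked Language Modeling (MLM)

Further pre-trains Legal-BERT (`nlpaueb/legal-bert-base-uncased`) with masked language modeling on sentences extracted from the project's legal-text corpus. Two masking schemes are trained and compared, **dynamic MLM** and **statistical MLM**. For each run, the per-epoch training loss is plotted and the best checkpoint is saved to Google Drive.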

In [None]:
# Mount Google Drive so the dataset and model-checkpoint folders under /content/drive are reachable.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Unused leftover: TensorFlow is never imported in this notebook (training uses PyTorch),
# so this version check stays disabled; the cell can be deleted.
# print(tf.__version__)

### Installation of dependencies

In [None]:
# Install necessary packages
!pip install nltk
!pip install transformers



### Import necessary packages

In [None]:
import os
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertForMaskedLM, BertTokenizer
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


### Load the corpus and the base model

In [None]:
# Function to read and tokenize text files from a directory
def read_and_tokenize(directory, tokenize=None):
    """Read every regular file in ``directory`` and split its text into sentences.

    Args:
        directory: Path of the folder containing UTF-8 text files. Sub-directories are
            skipped (opening them would raise) instead of aborting the whole load.
        tokenize: Callable mapping a document string to a list of sentences.
            Defaults to NLTK's ``sent_tokenize``.

    Returns:
        A flat list of sentences from all files. Files are visited in sorted name order,
        because ``os.listdir`` order is arbitrary and would make the corpus order
        (and therefore training) vary between machines.
    """
    if tokenize is None:
        tokenize = sent_tokenize
    sentences = []
    for file_name in sorted(os.listdir(directory)):
        path = os.path.join(directory, file_name)
        if not os.path.isfile(path):
            continue
        with open(path, 'r', encoding='utf-8') as f:
            sentences.extend(tokenize(f.read()))
    return sentences

# Directory holding the corpus text files
directory_path = '/content/drive/MyDrive/FYP/Dataset/trial'
text_data = read_and_tokenize(directory_path)  # sentences from every file in directory_path

# Fail fast: an empty corpus would otherwise only surface later, during dataset
# construction or training, with a much less obvious error.
assert text_data, f"No sentences found in {directory_path}"

In [None]:
# Initialize the base tokenizer and MLM model. Both must come from the same checkpoint:
# the model's token embeddings are indexed by this tokenizer's vocabulary ids.
model_name = 'nlpaueb/legal-bert-base-uncased'  # canonical (lower-case) Hugging Face Hub id
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Dynamic MLM

In [None]:
# Custom Dataset class for dynamic MLM
class DynamicMLMPretrainingDataset(Dataset):
    """Sentence dataset that re-draws a random MLM mask every time an item is fetched.

    Each sentence is encoded with special tokens and right-padded to the length of the
    longest encoded sentence in the corpus. It is then masked on the fly, so the same
    sentence gets a different mask on every access.
    """

    def __init__(self, text_data, tokenizer):
        self.text_data = text_data
        self.tokenizer = tokenizer
        self.max_length = self.find_max_len()  # padded length shared by every item

    def __len__(self):
        return len(self.text_data)  # number of sentences

    def __getitem__(self, idx):
        '''
        Returns (masked_input_ids, labels) LongTensors of length ``self.max_length``.
        '''
        tokens = self.tokenizer.encode(self.text_data[idx], add_special_tokens=True)
        # Pad with self.tokenizer (not a notebook-global) so the pad id always matches
        # the tokenizer this dataset was built with.
        tokens = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
        # mask_tokens already returns fresh tensors; re-wrapping them with torch.tensor()
        # would only copy them again and emit a PyTorch UserWarning.
        return self.mask_tokens(tokens)

    def find_max_len(self):
        '''
        Returns the longest encoded sentence length (special tokens included), or 0
        for an empty corpus.
        NOTE(review): no truncation is applied; BERT-base checkpoints typically accept at
        most 512 positions, so confirm the corpus stays below that.
        '''
        max_length = max(
            (len(self.tokenizer.encode(text, add_special_tokens=True)) for text in self.text_data),
            default=0,
        )
        print("Text max length : {}".format(max_length))
        return max_length

    def mask_tokens(self, tokens, mask_ratio=0.15):
        '''
        Replaces each ordinary token with [MASK] independently with probability ``mask_ratio``.

        Returns (masked_tokens, labels). ``labels`` holds the original id at masked
        positions and -100 (ignored by the MLM loss) everywhere else.

        Padding, [CLS] and [SEP] are never masked. Otherwise most of the loss would be
        spent predicting [PAD] in the padding tail that fills short sentences.
        '''
        special_ids = {
            self.tokenizer.pad_token_id,
            self.tokenizer.cls_token_id,
            self.tokenizer.sep_token_id,
        }
        maskable = torch.tensor([token not in special_ids for token in tokens], dtype=torch.bool)
        masked_indices = (torch.rand(len(tokens)) < mask_ratio) & maskable

        masked_tokens = torch.tensor(tokens, dtype=torch.long)
        masked_tokens[masked_indices] = self.tokenizer.mask_token_id  # [MASK]
        labels = torch.tensor(tokens, dtype=torch.long)
        labels[~masked_indices] = -100
        return masked_tokens, labels

In [None]:
# Create DataLoader for batch training. Seed first so the shuffle order and the random
# masks drawn in __getitem__ are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
dataset = DynamicMLMPretrainingDataset(text_data, tokenizer)
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Optimizer. No separate loss function is needed: the MLM model returns `outputs.loss`
# itself when `labels` are passed (positions labelled -100 do not contribute).
optimizer = optim.Adam(model.parameters(), lr=1e-5)

Text max length : 280
6154


### Model Training

In [None]:
# Training loop: fine-tune with dynamic masking and checkpoint the model whenever the
# epoch's mean training loss improves.
# NOTE(review): the "best" checkpoint is chosen on training loss; there is no held-out set.
epochs = 5
best_loss = float('inf')
loss_values = []  # mean training loss of each epoch (plotted below)
model.train()
for epoch in tqdm(range(epochs)):
    epoch_loss_sum = 0.0
    num_batches = 0
    for batch_masked_tokens, batch_labels in dataloader:
        # Attend only to real tokens. A position is padding if it still holds [PAD], or if it
        # was masked and its original id (kept in batch_labels) is [PAD].
        attention_mask = (
            (batch_masked_tokens != tokenizer.pad_token_id)
            & (batch_labels != tokenizer.pad_token_id)
        ).long()

        optimizer.zero_grad()
        outputs = model(batch_masked_tokens, attention_mask=attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # .item() yields a plain float, so no autograd graph is kept alive across batches.
        epoch_loss_sum += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("The dataloader yielded no batches; check text_data.")

    # Average over all batches; the last batch alone is a small, noisy sample.
    epoch_loss = epoch_loss_sum / num_batches
    loss_values.append(epoch_loss)
    print(f"Epoch {epoch + 1}/{epochs} - Loss : {epoch_loss}")

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        model.save_pretrained('/content/drive/MyDrive/FYP/models/dynamic_mlm_trained_model')
        tokenizer.save_pretrained('/content/drive/MyDrive/FYP/models/dynamic_mlm_trained_model/tokenizer')

In [None]:
# Plot the per-epoch loss of the dynamic-MLM run. Epochs are numbered from 1 to match
# the "Epoch k/N" lines printed during training.
epoch_numbers = list(range(1, len(loss_values) + 1))
plt.plot(epoch_numbers, loss_values, marker='s', linestyle='--', color='green', label='dynamic-MLM')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Graph Over Epochs')
plt.xticks(epoch_numbers)
plt.grid(True)
plt.show()

### Model Testing

In [None]:
# Sanity-check the saved dynamic-MLM checkpoint with a fill-mask query.
from transformers import pipeline, AutoModelForMaskedLM

model_path = '/content/drive/MyDrive/FYP/models/dynamic_mlm_trained_model'
model = AutoModelForMaskedLM.from_pretrained(model_path)
# The checkpoint was fine-tuned from `model_name`, so its token ids are only meaningful
# with that checkpoint's tokenizer. Loading 'bert-base-uncased' here would pair the model
# with a different tokenizer/vocabulary.
tokenizer = BertTokenizer.from_pretrained(model_name)
text = f"They held entrance {tokenizer.mask_token} for admission to the post- graduate course"
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)
for prediction in pipe(text):
    print(f"{prediction['token_str']:>15}  {prediction['score']:.4f}")

[{'score': 0.2805054783821106, 'token': 14912, 'token_str': 'examinations', 'sequence': 'they held entrance examinations for admission to the post - graduate course'}, {'score': 0.24797526001930237, 'token': 13869, 'token_str': 'exams', 'sequence': 'they held entrance exams for admission to the post - graduate course'}, {'score': 0.2038165032863617, 'token': 5852, 'token_str': 'tests', 'sequence': 'they held entrance tests for admission to the post - graduate course'}, {'score': 0.18846404552459717, 'token': 7749, 'token_str': 'examination', 'sequence': 'they held entrance examination for admission to the post - graduate course'}, {'score': 0.013816867023706436, 'token': 11360, 'token_str': 'exam', 'sequence': 'they held entrance exam for admission to the post - graduate course'}]


## Statistical MLM

In [None]:
# Statistical MLM
class MLMPretrainingDataset(Dataset):
    """Sentence dataset for MLM that masks each non-special token via a Bernoulli draw.

    Every item is the sentence encoded with special tokens and right-padded to the
    longest encoded sentence in the corpus. A fresh mask is sampled on each access.
    """

    def __init__(self, text_data, tokenizer):
        self.text_data = text_data
        self.tokenizer = tokenizer
        self.max_length = self.find_max_len()  # padded length shared by every item

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        """Return (masked_input_ids, labels) LongTensors of length ``self.max_length``."""
        # Tokenize the text and pad it on the right up to the corpus-wide max length
        tokens = self.tokenizer.encode(self.text_data[idx], add_special_tokens=True)
        tokens = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))

        # mask_tokens builds brand-new tensors from the list, so no clone/detach is needed
        return self.mask_tokens(tokens)

    def find_max_len(self):
        """Return the longest encoded sentence length (special tokens included), or 0 if empty."""
        max_length = max(
            (len(self.tokenizer.encode(text, add_special_tokens=True)) for text in self.text_data),
            default=0,
        )
        print("Text max length : {}".format(max_length))
        return max_length

    def mask_tokens(self, tokens, mask_ratio=0.15):
        """Mask each ordinary token independently with probability ``mask_ratio``.

        Returns (masked_tokens, labels). ``labels`` holds the original id at masked
        positions and -100 (ignored by the MLM loss) elsewhere. [MASK], [SEP], [CLS] and
        [PAD] tokens are never selected. Masking padding would make the model spend most
        of its loss predicting [PAD].
        """
        probability_matrix = torch.full((len(tokens),), mask_ratio)  # 0.15 by default, as in original BERT
        all_special_ids = {
            self.tokenizer.mask_token_id,
            self.tokenizer.sep_token_id,
            self.tokenizer.cls_token_id,
            self.tokenizer.pad_token_id,
        }
        special_tokens_mask = torch.tensor(
            [token in all_special_ids for token in tokens], dtype=torch.bool
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()
        masked_tokens = torch.tensor(tokens, dtype=torch.long)
        masked_tokens[masked_indices] = self.tokenizer.mask_token_id

        labels = torch.tensor(tokens, dtype=torch.long)
        labels[~masked_indices] = -100  # Only compute loss on masked tokens

        return masked_tokens, labels

In [None]:
# Start the statistical-MLM run from the same base checkpoint and tokenizer as the dynamic
# run. Without this reload, `model` would be whatever the testing cell last loaded (the
# dynamic-MLM checkpoint) and `tokenizer` a different checkpoint's tokenizer. That would
# confound the comparison between the two masking schemes.
# To deliberately continue from the dynamic-MLM weights instead, load that checkpoint here.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)

# Create DataLoader for batch training (seeded for reproducible shuffling and masking)
SEED = 42
torch.manual_seed(SEED)
dataset = MLMPretrainingDataset(text_data, tokenizer)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Optimizer; the MLM model computes its own loss from `labels`, so no criterion is needed.
# NOTE(review): lr=5e-4 is 50x the dynamic run's 1e-5 and well above common BERT
# fine-tuning rates; keep only if intentional, since it also confounds the comparison.
optimizer = optim.Adam(model.parameters(), lr=0.0005)

Text max length : 436


### Model Training

In [None]:
# Training loop: fine-tune with statistical masking and checkpoint the model (and its
# tokenizer) whenever the epoch's mean training loss improves.
# NOTE(review): the "best" checkpoint is chosen on training loss; there is no held-out set.
epochs = 5
best_loss = float('inf')  # Initialize the best_loss with positive infinity
loss_values_1 = list()  # mean training loss of each epoch (plotted below)
model.train()
for epoch in tqdm(range(epochs), desc="Epochs", leave=True):
    epoch_loss_sum = 0.0
    num_batches = 0
    for batch_masked_tokens, batch_labels in dataloader:
        # Attend only to real tokens. A position is padding if it still holds [PAD], or if it
        # was masked and its original id (kept in batch_labels) is [PAD].
        attention_mask = (
            (batch_masked_tokens != tokenizer.pad_token_id)
            & (batch_labels != tokenizer.pad_token_id)
        ).long()

        optimizer.zero_grad()
        outputs = model(batch_masked_tokens, attention_mask=attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # .item() yields a plain float, so no autograd graph is kept alive across batches.
        epoch_loss_sum += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("The dataloader yielded no batches; check text_data.")

    # Average over all batches; the last batch alone is a small, noisy sample.
    epoch_loss = epoch_loss_sum / num_batches
    loss_values_1.append(epoch_loss)
    print(f"Epoch {epoch + 1}/{epochs} - Loss : {epoch_loss}")

    # Save the MLM trained model for later use when a new best loss is achieved
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        model.save_pretrained('/content/drive/MyDrive/FYP/models/statistical_mlm_trained_model')
        tokenizer.save_pretrained('/content/drive/MyDrive/FYP/models/statistical_mlm_trained_model/tokenizer')

In [None]:
# Plot the per-epoch loss of the statistical-MLM run. Epochs are numbered from 1 to match
# the "Epoch k/N" lines printed during training.
epoch_numbers = list(range(1, len(loss_values_1) + 1))
plt.plot(epoch_numbers, loss_values_1, marker='o', linestyle='-', color='blue', label='statistical-MLM')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Graph Over Epochs')
plt.xticks(epoch_numbers)
plt.grid(True)  # Add grid lines
plt.show()