<a href="https://colab.research.google.com/github/np2802/Indian-Legal-Semantic-Searcher/blob/main/tsdae_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Transformer-based Sequential Denoising Auto-Encoder (TSDAE)

### Installation of dependencies

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset and model directories used below are reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Pin sentence-transformers to the release this notebook was run with (2.7.0, per the install
# log); the training cells are written against the 2.x `SentenceTransformer.fit` API.
# pytorch-lightning is imported by the checkpoint-callback cell further down.
%pip install transformers
%pip install -U sentence-transformers==2.7.0 pytorch-lightning

Collecting sentence-transformers
  Downloading sentence_transformers-2.7.0-py3-none-any.whl (171 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.11.0->sentence-transform

### Implementation

In [None]:
##########################################
### 1. Load Clean English Legal Dataset ###
##########################################
import torch
import nltk
nltk.download('punkt')
# NOTE(review): newer NLTK releases load the sentence tokenizer from 'punkt_tab';
# on older releases this download just reports the package as unknown.
nltk.download('punkt_tab')
from nltk.tokenize import sent_tokenize
import glob

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

file_path = "/content/drive/MyDrive/FYP/Dataset/SelectedFiles/*.txt"

# Read the files in a deterministic (sorted) order. Each document is sentence-split on its own
# so the last sentence of one file is never glued onto the first sentence of the next.
documents = []
for filename in sorted(glob.glob(file_path)):
    with open(filename, 'r', encoding='utf-8') as f:
        documents.append(f.read())
all_text = "\n".join(documents)
sent_list = [sentence for document in documents for sentence in sent_tokenize(document)]

print(">> Total Number of Corpus : {}".format(len(sent_list)))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


>> Total Number of Corpus : 72677


In [None]:
#######################################
## 2. Denoising Auto-Encoder Dataset ##
#######################################
from torch.utils.data import Dataset
from sentence_transformers.readers.InputExample import InputExample
import random
import numpy as np

class DenoisingAutoEncoderDataset(Dataset):
    """Dataset of (noised sentence, original sentence) pairs for TSDAE training.

    Each item is an ``InputExample`` whose ``texts`` are ``[noise_fn(sentence), sentence]``.
    The encoder sees the corrupted text and the decoder has to reconstruct the original.
    No label is attached to the example.

    Args:
        sentences: indexable sequence of raw sentences.
        noise_fn: callable mapping a sentence to its corrupted version. Defaults to
            ``DenoisingAutoEncoderDataset.delete`` (random word deletion).
    """

    def __init__(self, sentences, noise_fn = lambda sent :  DenoisingAutoEncoderDataset.delete(sent)):
        self.sentences = sentences
        self.noise_fn = noise_fn

    def __getitem__(self, item):
        sent = self.sentences[item]
        # texts = [noised text, original text]; no label is set.
        return InputExample(texts=[self.noise_fn(sent), sent])

    def __len__(self):
        return len(self.sentences)

    # Noise function
    @staticmethod
    def delete(text, del_ratio = 0.55):
        """Randomly delete words from ``text``.

        Each NLTK word token is dropped independently with probability ``del_ratio``, but at
        least one token is always kept. The surviving tokens are re-joined with single spaces.

        NOTE(review): an earlier note here said a 60 % deletion ratio performed best, while
        this default is 0.55. Confirm which value is intended.

        Args:
            text: sentence to corrupt.
            del_ratio: probability of deleting each token.

        Returns:
            The corrupted sentence, or ``text`` unchanged when it has no tokens.
        """
        words_tok = nltk.word_tokenize(text)
        n = len(words_tok)
        if n == 0:
            return text

        keep_or_not = np.random.rand(n) > del_ratio  # boolean mask: True = keep this token
        if not keep_or_not.any():
            # Every token was deleted: keep one at random so the encoder input is never empty.
            keep_or_not[np.random.choice(n)] = True
        words_processed = " ".join(np.array(words_tok)[keep_or_not])
        return words_processed

In [None]:
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedModel


class DenoisingAutoEncoderLoss(nn.Module):
    """TSDAE reconstruction loss.

    The SentenceTransformer ``model`` encodes the noised sentence into one pooled vector. A
    causal-LM decoder with cross-attention to that single vector must reconstruct the original
    sentence token by token. The loss is token-level cross-entropy, ignoring padding tokens.

    Args:
        model: SentenceTransformer whose first module exposes ``auto_model`` (the encoder).
        decoder_name: Hugging Face name/path of the decoder. Required when
            ``tie_encoder_decoder`` is False. When tying, it is replaced by the encoder's name.
        tie_encoder_decoder: build the decoder from the encoder checkpoint and share its weights
            with the encoder.
    """

    def __init__(self, model, decoder_name=None, tie_encoder_decoder=True):
        super().__init__()
        self.encoder = model
        self.tokenizer_encoder = model.tokenizer
        encoder_name = model[0].auto_model.config._name_or_path

        if decoder_name is None:
            assert tie_encoder_decoder, "Must indicate the decoder_name argument when tie_encoder_decoder = False"
        if tie_encoder_decoder:
            decoder_name = encoder_name

        self.tokenizer_decoder = AutoTokenizer.from_pretrained(decoder_name)
        # Targets need re-tokenizing only when encoder and decoder use different tokenizer classes.
        self.need_retokenization = type(self.tokenizer_encoder) is not type(self.tokenizer_decoder)

        decoder_config = AutoConfig.from_pretrained(decoder_name)
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True
        kwargs_decoder = {'config': decoder_config}
        self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name, **kwargs_decoder)

        if self.tokenizer_decoder.pad_token is None:
            # Decoders without a pad token reuse EOS so target batches can be padded.
            self.tokenizer_decoder.pad_token = self.tokenizer_decoder.eos_token
            self.decoder.config.pad_token_id = self.decoder.config.eos_token_id

        if tie_encoder_decoder:
            if len(self.tokenizer_encoder) != len(self.tokenizer_decoder):
                self.tokenizer_decoder = self.tokenizer_encoder
                self.decoder.resize_token_embeddings(len(self.tokenizer_decoder))
            decoder_base_model_prefix = self.decoder.base_model_prefix
            # Pass the actual encoder name rather than a hard-coded checkpoint id, so the loss
            # is not tied to one specific model.
            PreTrainedModel._tie_encoder_decoder_weights(
                model[0].auto_model,
                self.decoder._modules[decoder_base_model_prefix],
                self.decoder.base_model_prefix,
                base_encoder_name=encoder_name
            )

    def retokenize(self, sentence_features):
        """Re-encode encoder-tokenized ids with the decoder's tokenizer.

        The ids are decoded to text with the encoder tokenizer (special tokens dropped). They
        are then re-tokenized with the decoder tokenizer, padded to the longest sequence, and
        moved to the device of the input ids.
        """
        input_ids = sentence_features['input_ids']
        device = input_ids.device
        sentences_decoded = self.tokenizer_encoder.batch_decode(
            input_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        retokenized = self.tokenizer_decoder(
            sentences_decoded,
            padding=True,
            truncation='longest_first',
            return_tensors="pt",
            max_length=None).to(device)
        return retokenized

    def forward(self, sentence_features, labels):
        """Compute the reconstruction loss for one batch.

        Args:
            sentence_features: pair of tokenized feature dicts, (noised text, original text).
            labels: unused here. Presumably required by the SentenceTransformer.fit calling
                convention.

        Returns:
            Mean cross-entropy over all non-padding target tokens.
        """
        source_features, target_features = tuple(sentence_features)  # (noised text, original text)
        if self.need_retokenization:
            target_features = self.retokenize(target_features)

        # 1. Sentence embedding of the noised text from the encoder
        reps = self.encoder(source_features)['sentence_embedding']  # [batch_size, hidden_dim]

        # Teacher forcing: the decoder reads tokens [0, L-1) and is trained to predict
        # tokens [1, L), i.e. the target sequence shifted left by one position.
        target_length = target_features['input_ids'].shape[1]
        decoder_input_ids = target_features['input_ids'].clone()[:, :target_length - 1]  # drop last token
        label_ids = target_features['input_ids'][:, 1:]                                  # drop first token

        # 2. Decode while cross-attending to the single sentence vector
        decoder_outputs = self.decoder(
            input_ids = decoder_input_ids,
            inputs_embeds = None,
            attention_mask = None,
            encoder_hidden_states = reps[:, None],  # (batch_size, hidden_dim) -> (batch_size, 1, hidden_dim)
            encoder_attention_mask = source_features['attention_mask'][:, 0:1],
            labels = None,
            return_dict = None,
            use_cache = False)

        # 3. Token-level cross-entropy, ignoring padding positions
        lm_logits = decoder_outputs[0]  # [batch_size, seq_length - 1, vocab_size]
        ce_loss_fct = nn.CrossEntropyLoss(ignore_index = self.tokenizer_decoder.pad_token_id)
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), label_ids.reshape(-1))
        return loss


In [None]:
################################################
## 4. TSDAE unsupervised-embeddings training  ##
################################################
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, models
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, PreTrainedModel

# dataset with noise function
train_data = DenoisingAutoEncoderDataset(sent_list)

# dataloader
# No custom collate_fn. SentenceTransformer.fit (2.x) installs its own smart-batching collator
# on every training DataLoader, and torch's default_collate cannot batch InputExample objects anyway.
loader = DataLoader(
    train_data,
    batch_size = 8,
    shuffle = True,
    drop_last = True,
    pin_memory= True)

# Transformers models
model_name = 'nlpaueb/Legal-bert-base-uncased'
pooling_strategy = 'mean'
bert = models.Transformer(model_name)

# Sentence Embedding using Mean Pooling
pooling = models.Pooling(bert.get_word_embedding_dimension(), pooling_strategy)  # cls, mean, max
model = SentenceTransformer(modules = [bert, pooling])
# nn.Module.to moves submodules in place, so `bert` and `pooling` end up on `device` too.
model = model.to(device)

# TSDAE loss; the decoder shares its weights with the encoder
loss = DenoisingAutoEncoderLoss(model, tie_encoder_decoder = True)

## Model training

In [None]:
import os
import torch
import tensorflow as tf

# Train
epochs = 1
warmup_steps = int(len(loader) * epochs * 0.10)  # Warmup 10 % (has no effect with scheduler='constantlr')

# Create the checkpoint directory if it doesn't exist
checkpoint_dir = '/content/drive/MyDrive/FYP/models-final/tsdae_trained_model/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)


def epoch_checkpoint_file(epoch):
    """Return the path of the checkpoint written after ``epoch`` completed epochs."""
    return os.path.join(checkpoint_dir, 'cp-{:02d}-{:04d}.pt'.format(epoch, 0))


def latest_epoch_checkpoint():
    """Return the highest-epoch ``cp-XX-0000.pt`` file in ``checkpoint_dir``, or None if absent."""
    import glob
    candidates = glob.glob(os.path.join(checkpoint_dir, 'cp-*-0000.pt'))
    if not candidates:
        return None
    return max(candidates, key=lambda path: int(os.path.basename(path).split('-')[1]))


# Resume from the most recent end-of-epoch checkpoint, if any.
# SentenceTransformer.fit creates its own optimizer on every call and does not expose it, so only
# model weights (encoder + TSDAE decoder) and the epoch counter can be checkpointed and restored.
start_epoch = 0
latest_checkpoint = latest_epoch_checkpoint()
if latest_checkpoint is not None:
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'decoder_state_dict' in checkpoint:
        loss.decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Train one epoch per fit() call so a checkpoint is written after every epoch.
# (fit's `callback` argument is an evaluator hook taking (score, epoch, steps), not a Keras
# callback, so no callback is passed here.)
for epoch in range(start_epoch, epochs):
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=1,  # Train for 1 epoch per iteration
        warmup_steps=warmup_steps,
        weight_decay=0,
        scheduler='constantlr',
        optimizer_params={'lr': 3e-5},
        show_progress_bar=True,
    )

    # Save the end-of-epoch checkpoint as a single file
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'decoder_state_dict': loss.decoder.state_dict(),
    }, epoch_checkpoint_file(epoch + 1))

In [None]:
import os
import tensorflow as tf
import keras
import pytorch_lightning

# Train
epochs = 5
warmup_steps = int(len(loader) * epochs * 0.10)  # Warmup 10% (has no effect with scheduler='constantlr')

# Keep checkpoints on the mounted Drive, like the previous training cell. A relative path resolves
# against the notebook's working directory, which on Colab is presumably non-persistent local disk.
checkpoint_dir = '/content/drive/MyDrive/FYP/models-final/tsdae_trained_model/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

import os
import pytorch_lightning as pl

class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.

    NOTE: this is a PyTorch Lightning callback. It only fires when training through a
    ``pl.Trainer`` and has no effect on ``SentenceTransformer.fit``.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="N-Step-Checkpoint",
        use_modelcheckpoint_filename=False,
    ):
        """
        Args:
            save_step_frequency: how often to save, in global steps (positive integer)
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.

        Raises:
            ValueError: if save_step_frequency is not positive.
        """
        if save_step_frequency <= 0:
            raise ValueError("save_step_frequency must be a positive integer")
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename

    def on_train_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx):
        """Save a checkpoint after each train batch whose global step is a multiple of the frequency.

        Uses ``on_train_batch_end`` because the older ``on_batch_end`` hook was removed in
        recent Lightning releases and would silently never fire.
        """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency != 0:
            return

        checkpoint_callback = trainer.checkpoint_callback
        if (self.use_modelcheckpoint_filename
                and checkpoint_callback is not None
                and checkpoint_callback.filename):
            filename = checkpoint_callback.filename
        else:
            filename = f"{self.prefix}_{epoch=}_{global_step=}.ckpt"
        # Fall back to the trainer's root dir when no ModelCheckpoint callback is configured.
        dirpath = checkpoint_callback.dirpath if checkpoint_callback is not None else trainer.default_root_dir
        ckpt_path = os.path.join(dirpath, filename)
        trainer.save_checkpoint(ckpt_path)


# Train the model, one epoch per fit() call.
# SentenceTransformer.fit's `callback` is an evaluator hook taking (score, epoch, steps) and is only
# invoked after evaluation, so a Lightning callback passed there never runs. Use fit's built-in
# step checkpointing instead (every 1000 steps).
start_epoch = 0
for epoch in range(start_epoch, epochs):
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=1,  # Train for 1 epoch per iteration
        warmup_steps=warmup_steps,
        weight_decay=0,
        scheduler='constantlr',
        optimizer_params={'lr': 3e-5},
        show_progress_bar=True,
        # fit() restarts its step counter on every call, so each epoch gets its own directory
        # to stop later epochs from overwriting earlier step checkpoints.
        checkpoint_path=os.path.join(checkpoint_dir, 'epoch-{:02d}'.format(epoch + 1)),
        checkpoint_save_steps=1000,
    )

## Saving Model

In [None]:
# Persist the trained SentenceTransformer (the Transformer + Pooling `model` built above) to Drive.
TSDAE_MODEL_DIR = '/content/drive/MyDrive/FYP/models/tsdae_trained_model'
model.save(TSDAE_MODEL_DIR)

## Model testing

### Load the model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the trained model and tokenizer.
# In sentence-transformers 2.x, SentenceTransformer.save writes the first (Transformer) module's
# weights and tokenizer files directly into model_path, with no 'tokenizer' sub-folder. The saved
# weights are a plain encoder without a language-modelling head, so load the model back as a
# SentenceTransformer (which also restores the mean-pooling module).
from sentence_transformers import SentenceTransformer

model_path = '/content/drive/MyDrive/FYP/models/tsdae_trained_model'
tokenizer_path = model_path
model = SentenceTransformer(model_path)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)


### Cosine similarity between base LegalBERT and TSDAE-trained sentence embeddings

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

def mean_pooled_embeddings(encoder, encoded):
    """Mean-pool an encoder's last hidden states over non-padding tokens.

    Mirrors the 'mean' pooling strategy the TSDAE model was trained with, so base and trained
    models are compared on the same kind of sentence vector.

    Args:
        encoder: Hugging Face model whose first output is the last hidden state.
        encoded: tokenizer output containing 'input_ids' and 'attention_mask' tensors.

    Returns:
        Tensor of shape (batch_size, hidden_dim).
    """
    with torch.no_grad():
        token_embeddings = encoder(input_ids=encoded['input_ids'],
                                   attention_mask=encoded['attention_mask'])[0]
    mask = encoded['attention_mask'].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
    return summed / counts


# Load LegalBERT model and tokenizer
model_name = "nlpaueb/Legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Example sentences for testing
test_sentences = [
    "Constitution of India, petitioner seeks to challenge amendments.",
    "By the said deed the goodwill was not assigned.",
    "Central Bureau of Investigation (in short 'CBI') questions legality of the judgment."
]

# Embed with the base LegalBERT model (batched; padded positions are masked out of the mean)
bert_inputs = tokenizer(test_sentences, return_tensors='pt', padding=True, truncation=True)
bert_embeddings = mean_pooled_embeddings(model, bert_inputs)

# Load the trained model and tokenizer. SentenceTransformer.save stores the tokenizer files
# directly in model_path (there is no 'tokenizer' sub-folder).
model_path = '/content/drive/MyDrive/FYP/models/tsdae_trained_model'
tokenizer_path = model_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModel.from_pretrained(model_path)

# Embed with the trained model, using its own tokenizer
trained_inputs = tokenizer(test_sentences, return_tensors='pt', padding=True, truncation=True)
embeddings = mean_pooled_embeddings(model, trained_inputs)

# Per-sentence cosine similarity between BERT and trained-model embeddings
similarities = F.cosine_similarity(bert_embeddings, embeddings, dim=1)
for sentence, similarity in zip(test_sentences, similarities):
    print(f"Sentence: {sentence}")
    print(f"Average Cosine Similarity between BERT and Trained Model embeddings: {similarity.item()}")
    print()


In [None]:
from transformers import BertTokenizer, BertForMaskedLM
import torch

from transformers import AutoTokenizer

# Load the trained model and tokenizer. SentenceTransformer.save writes the tokenizer files
# directly into model_path (there is no 'tokenizer' sub-folder).
model_path = '/content/drive/MyDrive/FYP/models/tsdae_trained_model'
tokenizer_path = model_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# AutoModel returns a bare encoder without `.logits`; a masked-LM head is needed for predictions.
# NOTE(review): the TSDAE checkpoint presumably stores only the encoder, so parts of the MLM head
# may be newly initialised (transformers logs a warning). Treat these predictions with caution.
model = BertForMaskedLM.from_pretrained(model_path)

# Test sentences with masked tokens
test_sentences = [
    "Constitution of India, petitioner seeks to challenge [MASK] amendments.",
    "By the said deed the goodwill was not [MASK] .",
    "Central Bureau of [MASK] (in short 'CBI') questions legality of the judgment."
]

# Tokenize as one padded batch. This adds [CLS]/[SEP] as BERT expects and returns an attention
# mask so padding positions are ignored.
encoded = tokenizer(test_sentences, padding=True, return_tensors='pt')

# Perform inference
with torch.no_grad():
    outputs = model(**encoded)

# Most likely token id at every position
predicted_token_ids = torch.argmax(outputs.logits, dim=-1)

# Print the predictions at the [MASK] positions (handles zero or several masks per sentence)
for i, sentence in enumerate(test_sentences):
    mask_positions = (encoded['input_ids'][i] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids[i, mask_positions].tolist())
    print(f"Sentence: {sentence}")
    print(f"Predicted Token: {' '.join(predicted_tokens)}")
    print()