In [1]:
import copy
import os
import warnings

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (BartTokenizer, BartForConditionalGeneration, 
                          T5Tokenizer, T5ForConditionalGeneration, 
                          BertTokenizer, EncoderDecoderModel)
from datasets import load_metric
from rouge_score import rouge_scorer
from sacrebleu import corpus_bleu

# Run on the CUDA device when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
# Custom Dataset
class LegalSummarizationDataset(Dataset):
    """Pairs source documents with their reference summaries and tokenises both.

    A document in ``input_dir`` is paired with the file of the *same name* in
    ``target_dir``. Pairs are ordered by file name, so index ``i`` refers to the same
    pair on every run and every operating system. Files present in only one of the
    two directories are skipped with a warning; subdirectories are ignored.

    Args:
        input_dir: directory of source documents, read as UTF-8 text.
        target_dir: directory of reference summaries, read as UTF-8 text.
        tokenizer: callable tokenizer (e.g. a Hugging Face tokenizer) applied to both
            sides. It is read at item time via ``self.tokenizer``, so it may be swapped
            on a copy of the dataset.
        max_input_length: pad/truncate length for the source tokens.
        max_target_length: pad/truncate length for the summary tokens.
    """

    def __init__(self, input_dir, target_dir, tokenizer, max_input_length=512, max_target_length=128):
        input_names = {f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))}
        target_names = {f for f in os.listdir(target_dir) if os.path.isfile(os.path.join(target_dir, f))}

        # os.listdir order is arbitrary, so pairing two independent listings by position
        # can silently match a judgement with the wrong summary. Pair by file name.
        paired = sorted(input_names & target_names)
        unpaired = sorted(input_names ^ target_names)
        if unpaired:
            warnings.warn(
                f"Skipping {len(unpaired)} file(s) without a counterpart: {unpaired[:5]}"
            )

        self.input_files = [os.path.join(input_dir, f) for f in paired]
        self.target_files = [os.path.join(target_dir, f) for f in paired]
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.input_files)

    @staticmethod
    def _read_text(path):
        """Return the full UTF-8 contents of ``path``."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as 1-D tensors: input_ids, attention_mask, labels."""
        input_text = self._read_text(self.input_files[idx])
        target_text = self._read_text(self.target_files[idx])

        inputs = self.tokenizer(input_text, max_length=self.max_input_length, padding="max_length", truncation=True, return_tensors="pt")
        targets = self.tokenizer(target_text, max_length=self.max_target_length, padding="max_length", truncation=True, return_tensors="pt")

        # return_tensors="pt" adds a leading batch axis of size 1; drop only that axis
        # (a bare squeeze() would also collapse any other size-1 dimension).
        return {
            'input_ids': inputs.input_ids.squeeze(0),
            'attention_mask': inputs.attention_mask.squeeze(0),
            'labels': targets.input_ids.squeeze(0),
        }

In [6]:
# Load datasets
# Root of the IN-Abs corpus. Set the LAWSAGE_DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific default.
DATA_DIR = os.environ.get('LAWSAGE_DATA_DIR', 'D:/Sem-5/IMD/Project/LawSage.AI/static/dataset/IN-Abs')
SEED = 42  # fixes weight-init/dropout RNG and the shuffle order for reproducible runs
torch.manual_seed(SEED)

train_dataset = LegalSummarizationDataset(
    input_dir=os.path.join(DATA_DIR, 'train-data', 'judgement'),
    target_dir=os.path.join(DATA_DIR, 'train-data', 'summary'),
    tokenizer=BartTokenizer.from_pretrained('facebook/bart-large'))

# A dedicated seeded generator makes the shuffle order independent of other torch RNG use.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,
                          generator=torch.Generator().manual_seed(SEED))

In [7]:
# Model Wrapper
def _load_seq2seq_model(model_name, model_class, tokenizer):
    """Instantiate ``model_class`` from ``model_name`` on ``device``, configured for seq2seq training."""
    if model_class in (BartForConditionalGeneration, T5ForConditionalGeneration):
        return model_class.from_pretrained(model_name).to(device)
    if model_class == EncoderDecoderModel:
        # Encoder and decoder are both initialised from the same checkpoint (e.g. BERT2BERT).
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(model_name, model_name)
        # Such a model ships without seq2seq special tokens. Without decoder_start_token_id,
        # forward(labels=...) cannot build decoder inputs; without eos/pad, generate()
        # cannot stop or pad. Set them on both the model and generation configs.
        for cfg in (model.config, getattr(model, 'generation_config', None)):
            if cfg is not None:
                cfg.decoder_start_token_id = tokenizer.cls_token_id
                cfg.eos_token_id = tokenizer.sep_token_id
                cfg.pad_token_id = tokenizer.pad_token_id
        return model.to(device)
    raise ValueError(f"Unsupported model_class: {model_class!r}")


def _loader_for_tokenizer(loader, tokenizer):
    """Return a DataLoader over ``loader``'s dataset, tokenised with ``tokenizer``."""
    dataset = loader.dataset
    if getattr(dataset, 'tokenizer', tokenizer) is tokenizer:
        return loader
    # The dataset tokenises lazily via self.tokenizer, so a shallow copy with the
    # tokenizer swapped yields token ids in this model's vocabulary.
    retokenized = copy.copy(dataset)
    retokenized.tokenizer = tokenizer
    return DataLoader(
        retokenized,
        batch_size=loader.batch_size,
        shuffle=isinstance(loader.sampler, torch.utils.data.RandomSampler),
        generator=loader.generator,
    )


def train_model(model_name, tokenizer_name, model_class, epochs=3, loader=None, lr=1e-5):
    """Fine-tune a pretrained seq2seq model on the summarisation data.

    Args:
        model_name: Hugging Face hub id or local path for both tokenizer and weights.
        tokenizer_name: tokenizer *class* (e.g. ``BartTokenizer``); ``from_pretrained``
            is called on it with ``model_name``.
        model_class: ``BartForConditionalGeneration``, ``T5ForConditionalGeneration``
            or ``EncoderDecoderModel``.
        epochs: number of passes over the training data.
        loader: DataLoader yielding dicts with 'input_ids', 'attention_mask', 'labels'.
            Defaults to the notebook-global ``train_loader``. Its dataset is
            re-tokenised with this model's tokenizer when it carries a different one.
        lr: AdamW learning rate.

    Returns:
        ``(model, tokenizer)``; the model is on ``device`` and still in train mode.

    Raises:
        ValueError: if ``model_class`` is unsupported or the loader yields no batches.
    """
    tokenizer = tokenizer_name.from_pretrained(model_name)
    model = _load_seq2seq_model(model_name, model_class, tokenizer)
    data = _loader_for_tokenizer(train_loader if loader is None else loader, tokenizer)
    if len(data) == 0:
        raise ValueError("train_model: training loader yields no batches")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    pad_id = tokenizer.pad_token_id

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in data:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            if pad_id is not None:
                # Labels are padded to max_length; -100 is the ignore index of the
                # cross-entropy loss, so padding does not contribute to the loss.
                labels = labels.masked_fill(labels == pad_id, -100)

            optimizer.zero_grad()
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Report the mean over the epoch rather than just the last batch's loss.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(data)}")

    return model, tokenizer

In [8]:
# Evaluation Metrics
def evaluate_model(model, tokenizer, dataset, batch_size=2, max_length=128):
    """Generate summaries for ``dataset`` and score them with ROUGE and BLEU.

    Args:
        model: seq2seq model exposing ``generate``, already on ``device``.
        tokenizer: tokenizer matching ``model``; decodes predictions and references.
        dataset: map-style dataset of dicts with 'input_ids', 'attention_mask', 'labels'.
            If it exposes a ``tokenizer`` attribute other than ``tokenizer``, a shallow
            copy using ``tokenizer`` is evaluated instead, so token ids always match
            the model's vocabulary.
        batch_size: generation batch size.
        max_length: maximum generated length, in tokens.

    Returns:
        ``(rouge1, rouge2, rougeL, bleu)``: mean ROUGE F-measures and corpus BLEU.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if getattr(dataset, 'tokenizer', tokenizer) is not tokenizer:
        dataset = copy.copy(dataset)
        dataset.tokenizer = tokenizer

    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    model.eval()
    predictions, references = [], []

    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=batch_size):
            generated_ids = model.generate(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                max_length=max_length,
            )
            labels = batch['labels']
            if tokenizer.pad_token_id is not None:
                # -100 (loss ignore index) is not a vocabulary id and cannot be decoded.
                labels = labels.masked_fill(labels == -100, tokenizer.pad_token_id)

            predictions.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
            # NOTE(review): references are the tokenised summaries truncated to the
            # dataset's max_target_length, not the full reference text.
            references.extend(tokenizer.batch_decode(labels, skip_special_tokens=True))

    if not predictions:
        raise ValueError("evaluate_model: dataset is empty, nothing to score")

    # ROUGE Evaluation — RougeScorer.score takes (target, prediction) in that order.
    rouge_scores = [rouge.score(ref, pred) for pred, ref in zip(predictions, references)]
    n = len(rouge_scores)
    rouge1 = sum(s['rouge1'].fmeasure for s in rouge_scores) / n
    rouge2 = sum(s['rouge2'].fmeasure for s in rouge_scores) / n
    rougeL = sum(s['rougeL'].fmeasure for s in rouge_scores) / n

    # BLEU Evaluation — sacrebleu takes a list of reference *streams*, each aligned
    # with the hypotheses; there is one reference per prediction here.
    bleu = corpus_bleu(predictions, [references]).score

    print(f"ROUGE-1: {rouge1}, ROUGE-2: {rouge2}, ROUGE-L: {rougeL}")
    print(f"BLEU: {bleu}")

    return rouge1, rouge2, rougeL, bleu

In [9]:
# Train BART, T5, and BERT (BERT is fine-tuned as a BERT2BERT EncoderDecoderModel)
models = {
    'BART': ('facebook/bart-large', BartTokenizer, BartForConditionalGeneration),
    'T5': ('t5-large', T5Tokenizer, T5ForConditionalGeneration),
    'BERT': ('bert-base-uncased', BertTokenizer, EncoderDecoderModel)
}

results = {}  # model label -> (rouge1, rouge2, rougeL, bleu)
for model_name, (pretrained_model, tokenizer_cls, model_type) in models.items():
    # Release the previous iteration's model before loading the next large checkpoint;
    # otherwise the global `model` keeps it resident on the GPU while the new one trains.
    model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"\nTraining {model_name}...\n")
    model, tokenizer = train_model(pretrained_model, tokenizer_cls, model_type)
    print(f"\nEvaluating {model_name}...\n")
    # NOTE(review): this scores on the same split the model was trained on, so it
    # measures fit rather than generalisation; evaluate on held-out data before
    # reporting these numbers.
    results[model_name] = evaluate_model(model, tokenizer, train_dataset)

results


Training BART...



KeyboardInterrupt: 