<a href="https://colab.research.google.com/github/gangulasreeja/ISA/blob/main/LegalT5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install rouge_score

In [None]:
from datasets import load_dataset

# Load the "ninadn/indian-legal" dataset; later cells read its 'Text' and
# 'Summary' columns and its 'train' and 'test' splits.
ds = load_dataset(path="ninadn/indian-legal")

In [None]:
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import numpy as np
import nltk
nltk.download('punkt')

In [None]:
# Maximum tokenized lengths for the source documents and the target summaries.
max_input, max_target = 512, 128
batch_size = 3

# Checkpoint name used for both the tokenizer here and the model loaded later.
model_checkpoints = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoints)

In [None]:
def preprocess_data(data_to_process):
    """Tokenize a batch of documents together with their reference summaries.

    Args:
        data_to_process: Batch mapping with 'Text' and 'Summary' columns, each
            a sequence of strings (entries may be None).

    Returns:
        The tokenizer output for the documents, with the token ids of the
        tokenized summaries stored under 'labels'.
    """
    # Missing entries are replaced by a single space so the tokenizer always
    # receives a string.
    documents = [doc if doc is not None else " " for doc in data_to_process['Text']]
    summaries = [ref if ref is not None else " " for ref in data_to_process['Summary']]

    encoded = tokenizer(
        documents,
        max_length=max_input,
        padding='max_length',
        truncation=True,
    )
    # NOTE(review): summaries are padded to max_target by the tokenizer, so the
    # padded label positions presumably hold the pad token id rather than the
    # -100 ignore index -- confirm whether the loss should skip them.
    encoded_summaries = tokenizer(
        text_target=summaries,
        max_length=max_target,
        padding='max_length',
        truncation=True,
    )
    encoded['labels'] = encoded_summaries['input_ids']
    return encoded
# Run preprocess_data over `ds` in batched mode, so each call receives a batch
# of rows rather than a single example.
tokenize_data = ds.map(preprocess_data, batched=True)

In [None]:
num_valid = 200

# Carve the validation set out of the training split only once. Without this
# guard, re-running the cell would split the already-reduced training set
# again: it would shrink by another `num_valid` rows and the previous
# validation rows would be discarded.
# NOTE(review): assumes the loaded dataset has no split named 'valid' of its
# own -- confirm.
if 'valid' not in tokenize_data:
    train_valid_split = tokenize_data['train'].train_test_split(
        test_size=num_valid, shuffle=True, seed=42
    )

    # Rename keys for clarity
    tokenize_data['train'] = train_valid_split['train']
    tokenize_data['valid'] = train_valid_split['test']

# Print sizes to confirm
print(f"Train dataset size: {len(tokenize_data['train'])}")
print(f"Validation dataset size: {len(tokenize_data['valid'])}")
# Label fixed: this line reports the 'test' split (it previously said "Text").
print(f"Test dataset size: {len(tokenize_data['test'])}")

In [None]:
# Show the tokenized dataset object as the cell's output.
tokenize_data

In [None]:
# Load the seq2seq model from the same checkpoint name used for the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoints)

In [None]:
import evaluate

batch_size = 3

# Batch collator built from the tokenizer and the model.
collator = transformers.DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
)

# ROUGE metric object used by compute_rouge.
metric = evaluate.load("rouge")

In [None]:
def compute_rouge(pred):
  """Score generated summaries against the reference summaries with ROUGE.

  Args:
    pred: A (predictions, labels) pair of token-id arrays, one row per
      example. `predictions` may also be a tuple whose first element holds
      the generated ids.

  Returns:
    dict with every score returned by `metric.compute` multiplied by 100,
    plus 'gen_len' (mean number of non-pad tokens per prediction); all
    values are rounded to 4 decimals.
  """
  predictions, labels = pred
  if isinstance(predictions, tuple):
    predictions = predictions[0]

  # -100 is the padding/ignore index used for labels (and for padding when
  # batches of different lengths are concatenated). It is not a real token id
  # and cannot be decoded, so map it to the pad token before decoding.
  pad_id = tokenizer.pad_token_id
  predictions = np.asarray(predictions)
  labels = np.asarray(labels)
  predictions = np.where(predictions != -100, predictions, pad_id)
  labels = np.where(labels != -100, labels, pad_id)

  # decode the predictions and the labels
  decode_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

  # compute results and express them as percentages
  res = metric.compute(predictions=decode_predictions, references=decode_labels, use_stemmer=True)
  res = {key: value * 100 for key, value in res.items()}

  # generated length = number of non-pad tokens in each prediction
  pred_lens = [np.count_nonzero(row != pad_id) for row in predictions]
  res['gen_len'] = np.mean(pred_lens)

  return {k: round(v, 4) for k, v in res.items()}

In [None]:
# The keyword that selects when evaluation runs was renamed from
# `evaluation_strategy` to `eval_strategy` in newer transformers releases.
# The install cell does not pin a version, so use whichever name the installed
# release declares.
_args_fields = getattr(transformers.Seq2SeqTrainingArguments, '__dataclass_fields__', {})
_eval_strategy_key = 'eval_strategy' if 'eval_strategy' in _args_fields else 'evaluation_strategy'

args = transformers.Seq2SeqTrainingArguments(
    'legal-summ',
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    predict_with_generate=True,
    # Generate up to the same length the reference summaries are tokenized to.
    # Left unset, evaluation generates at the model's default length instead,
    # so ROUGE would be computed on much shorter outputs than the references.
    generation_max_length=max_target,
    eval_accumulation_steps=1,
    # NOTE(review): T5-family checkpoints are reported to be numerically
    # unstable under fp16 -- confirm the loss stays finite, or switch to bf16.
    fp16=True,
    **{_eval_strategy_key: 'epoch'},
)

In [None]:
# Wire the model, training arguments, datasets, collator and ROUGE callback
# into a single trainer; evaluation uses the 'valid' split.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenize_data['train'],
    eval_dataset=tokenize_data['valid'],
    data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_rouge,
)

In [None]:
# Launch the training run configured on `trainer`.
trainer.train()