## Install Required Libraries

In [None]:
# Install the libraries this notebook needs. Versions are pinned for transformers,
# peft, datasets and scikit-learn — presumably so that the Trainer API used later
# (e.g. `evaluation_strategy` in TrainingArguments) stays reproducible.
# NOTE(review): peft, sentence-transformers, umap-learn, gcsfs, PyPDF2, pymupdf,
# faiss-cpu, langchain, langchain-openai and gradio are not imported anywhere in this
# notebook — presumably used by other work; confirm before removing.
!pip install -q \
    transformers==4.40.0 \
    peft==0.10.0 \
    accelerate \
    datasets==2.19.1 \
    scikit-learn==1.4.2 \
    sentence-transformers \
    umap-learn \
    gcsfs \
    PyPDF2 \
    pymupdf \
    faiss-cpu \
    langchain \
    langchain-openai \
    gradio

# Force numpy back to 1.26.4 after the installs above (presumably because one of them
# pulls in an incompatible numpy). NOTE(review): in Colab the runtime usually has to be
# restarted before an already-running kernel picks up the reinstalled numpy — confirm.
!pip install numpy==1.26.4 --force-reinstall --no-cache-dir

## Necessary Imports

In [None]:
import os
import json
import re
import pandas as pd
import numpy as np


#Visualization
import matplotlib.pyplot as plt
import seaborn as sns


#Machine Learning / Transformers
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding
)

#Datasets
from datasets import Dataset
from google.colab import files

## Upload and Load Dataset

In [None]:
# Open the Colab file picker; the first uploaded file is read as JSON.
# NOTE(review): the JSON is presumably a list of case records with keys such as
# "summary", "petitioner_argument", "respondent_argument" and "verdict" (the columns
# used further down) — confirm against the data source.
uploaded = files.upload()
if not uploaded:
    # Fail with an actionable message instead of an IndexError on an empty dict.
    raise ValueError("No file was uploaded; please select the case JSON file.")
filename = next(iter(uploaded))

with open(filename, "r", encoding="utf-8") as f:
    data = json.load(f)

df = pd.DataFrame(data)
df.head()

## Text Cleaning & Case Formatting

In [None]:
def _is_missing(value) -> bool:
    """Return True for None and scalar missing markers (NaN, pd.NA, NaT)."""
    return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))


def clean_text(text) -> str:
    """Normalize a free-text field for the classifier.

    Steps: lowercase; turn newlines/tabs into spaces; drop quotes; drop every
    character outside ``a-z 0-9 space , . [ ] ( ) / - :``; collapse runs of dots
    and whitespace; map ``/`` and comma runs to a single ``,``; remove empty comma
    pairs and leading/trailing commas; rewrite ``p,p,c`` as ``ppc``.

    Missing values (None, NaN, pd.NA) return "" instead of being stringified into
    the literal tokens "none" / "nan". Any other value is passed through ``str()``.
    """
    if _is_missing(text):
        return ""
    text = str(text).lower()
    text = re.sub(r"[\n\r\t]", " ", text)
    text = re.sub(r"[\"']", "", text)
    text = re.sub(r"[^a-z0-9 ,.\[\]()/\-:]", "", text)
    text = re.sub(r"\.{2,}", ".", text)
    text = re.sub(r"\s+", " ", text)
    # Slashes are treated as list separators, e.g. "302/34" -> "302,34".
    text = re.sub(r"[,/]+", ",", text)
    text = re.sub(r"(,\s*,)+", ",", text)
    text = re.sub(r"(,\s*$)|(^\s*,)", "", text)
    # NOTE(review): only comma/slash-separated "p,p,c" is collapsed; a dotted
    # "p.p.c." survives unchanged because dots are kept above — confirm intent.
    text = re.sub(r"\b(p\s*,\s*p\s*,\s*c)\b", "ppc", text)
    return text.strip()


def _format_sections(sections) -> str:
    """Render offence sections (list/tuple, single string/number, or missing) as text."""
    if _is_missing(sections):
        return ""
    if pd.api.types.is_scalar(sections):
        # A single section stored as "302" must not be joined char-by-char.
        return str(sections)
    return ", ".join(str(section) for section in sections)


def format_case(row):
    """Build the tagged model input string for one case.

    ``row`` may be a DataFrame row (``pd.Series``) or a plain dict. "summary",
    "petitioner_argument" and "respondent_argument" are required (KeyError if
    absent). "case_type" and "offence_sections" are optional. Absent keys and
    NaN/None values (as produced by ``pd.DataFrame`` for records lacking a key)
    both become empty text.
    """
    summary = clean_text(row["summary"])
    petitioner = clean_text(row["petitioner_argument"])
    respondent = clean_text(row["respondent_argument"])
    case_type = clean_text(row.get("case_type", ""))
    sections = clean_text(_format_sections(row.get("offence_sections", [])))
    return f"[SUMMARY] {summary} [PETITIONER] {petitioner} [RESPONDENT] {respondent} [CASE TYPE] {case_type} [SECTIONS] {sections}"


## Clean Verdicts, Filter Rare Labels & Encode Labels

In [None]:
# Clean verdicts so differently formatted spellings of the same verdict collapse together.
df["verdict"] = df["verdict"].apply(clean_text)

# Drop verdicts with fewer than MIN_SAMPLES_PER_LABEL examples. Filtering happens
# BEFORE encoding so the surviving classes get contiguous ids 0..K-1; encoding first
# would leave gaps, giving the classifier output units for classes it never sees.
MIN_SAMPLES_PER_LABEL = 10
verdict_counts = df["verdict"].value_counts()
valid_verdicts = verdict_counts[verdict_counts >= MIN_SAMPLES_PER_LABEL].index
df = df[df["verdict"].isin(valid_verdicts)].reset_index(drop=True)
assert df["verdict"].nunique() >= 2, "Need at least two verdict classes with enough samples to train a classifier."

# Encode labels; label_mapping maps plain-int class id -> verdict string.
label_encoder = LabelEncoder()
df["label"] = label_encoder.fit_transform(df["verdict"])
label_mapping = dict(enumerate(label_encoder.classes_))
df["label_name"] = df["label"].map(label_mapping)


## Visualize Label Distribution

In [None]:
# Bar chart of how many cases remain per verdict, most frequent first.
verdict_order = df["label_name"].value_counts().index

fig, ax = plt.subplots(figsize=(12, 6))
sns.countplot(data=df, x="label_name", palette="Set3", order=verdict_order, ax=ax)
ax.set_title("Filtered Verdict Distribution")
ax.set_xlabel("Verdict Label")
ax.set_ylabel("Count")
ax.tick_params(axis="x", labelrotation=45)
fig.tight_layout()
plt.show()


## Train/Validation Split & Dataset Formatting

In [None]:
# Stratified 80/20 split (fixed seed) keeps verdict proportions similar in both sets.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["label"],
    random_state=42,
)

# Attach the tagged model input text to each split as a new "text" column.
train_df = train_df.assign(text=train_df.apply(format_case, axis=1))
val_df = val_df.assign(text=val_df.apply(format_case, axis=1))

# Only the model input and the target are handed to Hugging Face datasets.
model_columns = ["text", "label"]
train_ds = Dataset.from_pandas(train_df[model_columns])
val_ds = Dataset.from_pandas(val_df[model_columns])

## Tokenization

In [None]:
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")

# Longer inputs are truncated; 512 tokens is the usual BERT-base input limit.
MAX_SEQ_LENGTH = 512


def tokenize(batch):
    """Tokenize a batch of formatted case texts (``batch["text"]``).

    Sequences are truncated to ``MAX_SEQ_LENGTH`` but deliberately NOT padded here.
    The ``DataCollatorWithPadding`` given to the Trainer pads each mini-batch to its
    own longest sequence, instead of every example always being 512 tokens.
    """
    return tokenizer(batch["text"], truncation=True, max_length=MAX_SEQ_LENGTH)


def _drop_internal_columns(ds):
    """Remove bookkeeping columns whose names start with "__" (e.g. a saved pandas index)."""
    internal_cols = [col for col in ds.column_names if col.startswith("__")]
    return ds.remove_columns(internal_cols)


# Each dataset computes its own list of columns to drop, so the two splits need not match.
tokenized_train_ds = _drop_internal_columns(train_ds.map(tokenize, batched=True))
tokenized_val_ds = _drop_internal_columns(val_ds.map(tokenize, batched=True))


## Load Model

In [None]:
# Output size covers every encoded class id; cast from numpy int64 to a plain int.
num_labels = int(df["label"].max()) + 1

# Store human-readable verdict names in the model config, so the saved checkpoint
# (and anything loading it later) reports verdicts instead of LABEL_<i>. Any id
# without a mapping entry keeps the generic LABEL_<i> name.
id2label = {i: str(label_mapping.get(i, f"LABEL_{i}")) for i in range(num_labels)}
label2id = {name: i for i, name in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    "nlpaueb/legal-bert-base-uncased",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

## Training Setup & Metrics

In [None]:
# Evaluate, checkpoint and log once per epoch; the checkpoint with the best
# validation accuracy is reloaded into the model when training ends.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=30,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    logging_dir="./logs",
    # Each BERT-base checkpoint is several hundred MB, so 30 uncapped per-epoch saves
    # can fill Colab's disk. With load_best_model_at_end=True the Trainer still
    # retains the best checkpoint in addition to the most recent one.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)

# Pads each batch dynamically to its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def compute_metrics(pred):
    """Score an evaluation pass for the Trainer.

    Class predictions are the argmax of ``pred.predictions`` along axis 1. They are
    compared to ``pred.label_ids``. Returns accuracy plus support-weighted
    precision, recall and F1 (zero_division=0, so empty classes score 0, not NaN).
    """
    y_true, logits = pred.label_ids, pred.predictions
    y_pred = np.argmax(logits, axis=1)

    weighted_scores = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    precision, recall, f1 = weighted_scores[:3]

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Wire model, data and metrics together: batches are padded by the collator,
# and compute_metrics is called on every per-epoch evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_val_ds,
    compute_metrics=compute_metrics,
)


## Train Model

In [None]:
# Run fine-tuning; the summary object returned by train() is shown as the cell output.
train_output = trainer.train()
train_output

## Evaluation & Plotting

In [None]:
#Auto-detect latest checkpoint
checkpoint_dirs = [ckpt for ckpt in os.listdir("./results") if ckpt.startswith("checkpoint")]
latest_ckpt = sorted(checkpoint_dirs, key=lambda x: int(x.split("-")[-1]))[-1]
trainer_state_path = f"./results/{latest_ckpt}/trainer_state.json"

with open(trainer_state_path, "r") as f:
    logs = json.load(f)

log_history = logs["log_history"]
epochs, eval_loss, eval_accuracy, eval_precision, eval_recall, eval_f1 = [], [], [], [], [], []

for entry in log_history:
    if "eval_loss" in entry:
        epochs.append(entry["epoch"])
        eval_loss.append(entry.get("eval_loss"))
        eval_accuracy.append(entry.get("eval_accuracy"))
        eval_precision.append(entry.get("eval_precision"))
        eval_recall.append(entry.get("eval_recall"))
        eval_f1.append(entry.get("eval_f1"))

#Plot
plt.figure(figsize=(14, 7))
plt.plot(epochs, eval_loss, label="Eval Loss", marker='o', color='crimson')
plt.plot(epochs, eval_accuracy, label="Accuracy", marker='s', color='blue')
plt.plot(epochs, eval_precision, label="Precision", marker='^', color='green')
plt.plot(epochs, eval_recall, label="Recall", marker='v', color='darkorange')
plt.plot(epochs, eval_f1, label="F1 Score", marker='D', color='purple')
plt.xlabel("Epochs")
plt.ylabel("Metric Value")
plt.title("Evaluation Metrics Over Epochs")
plt.ylim(0, 1.05)
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()
plt.tight_layout()
plt.show()

## Save Model

In [None]:
BEST_MODEL_DIR = "./best_model"

# load_best_model_at_end=True (see TrainingArguments) means the trainer now holds the
# best-scoring checkpoint; write it and the tokenizer to one directory.
trainer.save_model(BEST_MODEL_DIR)
tokenizer.save_pretrained(BEST_MODEL_DIR)

# Persist the class-id -> verdict mapping next to the weights. Without it, the
# downloaded model's predictions cannot be decoded in a fresh session. Keys are
# cast to int because json cannot serialize numpy integer keys.
with open(os.path.join(BEST_MODEL_DIR, "label_mapping.json"), "w", encoding="utf-8") as fh:
    json.dump({int(k): str(v) for k, v in label_mapping.items()}, fh, indent=2)

## Inference / Prediction Function

In [None]:
#Load best saved model and tokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

#Load from saved directory
model_path = "./best_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

#Move model to correct device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

#Prediction Function
def predict_verdict(case, model, tokenizer, label_mapping, max_length=512):
    """Predict the verdict for a single case record.

    Args:
        case: dict (or DataFrame row) with the fields read by ``format_case``.
        model: sequence-classification model; inputs are sent to ``model.device``.
            The caller is expected to have put it in eval mode.
        tokenizer: tokenizer matching ``model``.
        label_mapping: mapping from predicted class id (int) to verdict name.
        max_length: truncation length in tokens (default 512, as in training).

    Returns:
        The verdict name for the highest-scoring class.

    Raises:
        KeyError: if the predicted class id is not in ``label_mapping``.
    """
    formatted_text = format_case(case)

    # Follow the model's own device instead of a module-level global, so the
    # function works wherever the caller placed the model. A single input needs no padding.
    inputs = tokenizer(
        formatted_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = int(torch.argmax(logits, dim=1).item())

    return label_mapping[predicted_label]

#Example Test Case
sample_case = {
    "summary": "The appellant was convicted under Section 302 for murder but argued that the evidence was circumstantial.",
    "petitioner_argument": "The petitioner argued there was no direct witness and the case relied on weak circumstantial evidence.",
    "respondent_argument": "The prosecution maintained that the motive and weapon recovery strongly supported conviction.",
    "case_type": "Criminal Appeal",
    "offence_sections": ["302", "34"]
}

#Predict verdict
predicted_verdict = predict_verdict(sample_case, model, tokenizer, label_mapping)
print("Predicted Verdict:", predicted_verdict)


## Download Model

In [None]:
import shutil

from google.colab import files

# Build best_model.zip (entries prefixed with "best_model/") from scratch each time.
# `zip -r` onto an existing archive only adds/updates entries, so files deleted from
# best_model/ would linger after a re-run; make_archive always writes a new archive.
archive_path = shutil.make_archive("best_model", "zip", root_dir=".", base_dir="best_model")
files.download(archive_path)
