# In [1]:
import pandas as pd
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Load the labelled data and split it into training, validation, and test sets.
# The original code split an undefined `balanced_df`; the frame that is actually
# loaded here is `df`, so that is what gets split. If class balancing is needed,
# apply it to `df` before these two calls.
df = pd.read_csv("<citations_data_path>")

# First hold out 20% for testing, then take 20% of the remainder for validation
# (64% / 16% / 20% overall). Both splits are stratified on 'label' so each
# subset keeps the same class proportions, and use a fixed seed for repeatability.
train_val_df, test_df = train_test_split(df, test_size=0.2, random_state=2, stratify=df['label'])
train_df, val_df = train_test_split(train_val_df, test_size=0.2, random_state=2, stratify=train_val_df['label'])

# Load tokenizer
model_name = "nlpaueb/legal-bert-base-uncased"  # Legal-BERT checkpoint; another checkpoint such as bert-base-uncased can be substituted
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenization callback used with Dataset.map
def preprocess_data(examples):
    """Tokenize the 'case' and 'citation' fields of *examples* as a text pair.

    Every encoded pair is truncated and padded to exactly 128 tokens.
    Returns whatever the module-level ``tokenizer`` returns for the pair.
    """
    first_texts = examples['case']
    second_texts = examples['citation']
    encoding_options = {
        "truncation": True,
        "padding": 'max_length',
        "max_length": 128,
    }
    return tokenizer(first_texts, second_texts, **encoding_options)

# Convert each pandas split to a Hugging Face Dataset, tokenize it, and expose
# the model inputs as PyTorch tensors.
def _build_torch_dataset(frame):
    """Return *frame* as a tokenized ``Dataset`` formatted for PyTorch.

    The frame is converted with ``Dataset.from_pandas``, tokenized in batches
    with ``preprocess_data``, and then restricted to the columns fed to the
    model.
    """
    dataset = Dataset.from_pandas(frame).map(preprocess_data, batched=True)

    columns = ["input_ids", "attention_mask", "label"]
    # The inputs are encoded as (case, citation) pairs. When the tokenizer
    # emits token_type_ids (BERT-style tokenizers do for pairs), keep them so
    # the model can tell the two segments apart; previously they were dropped.
    # Tokenizers that do not produce the column are still supported.
    if "token_type_ids" in dataset.column_names:
        columns.insert(1, "token_type_ids")

    return dataset.with_format("torch", columns=columns)


train_dataset = _build_torch_dataset(train_df)
val_dataset = _build_torch_dataset(val_df)
test_dataset = _build_torch_dataset(test_df)

# Load model with a two-class classification head
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Metric callback handed to the Trainer
def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall, and F1 for binary predictions.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` must support
            ``argmax(axis=-1)``; the index of the largest value along the
            last axis is taken as the predicted class.

    Returns:
        A dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"``
        and ``"f1"``. Precision, recall and F1 use ``average="binary"``.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)  # Get predicted class

    # zero_division=0 yields the same 0.0 score that scikit-learn falls back
    # to by default, but without emitting UndefinedMetricWarning when a metric
    # is undefined (e.g. the positive class is never predicted).
    binary_options = {"average": "binary", "zero_division": 0}

    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, **binary_options),
        "recall": recall_score(labels, predictions, **binary_options),
        "f1": f1_score(labels, predictions, **binary_options),
    }

# Training configuration, gathered in one mapping and unpacked into TrainingArguments
training_config = {
    # Where results and logs are written
    "output_dir": "./results",
    "logging_dir": "./logs",
    # Evaluate and save a checkpoint at the end of every epoch
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "save_total_limit": 2,                  # cap on the number of saved checkpoints
    # Optimisation hyperparameters
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "num_train_epochs": 10,
    "weight_decay": 0.01,
    "warmup_steps": 500,                    # warmup steps for the learning-rate scheduler
    # Reload the checkpoint with the highest accuracy once training ends
    "load_best_model_at_end": True,
    "metric_for_best_model": "accuracy",
    "greater_is_better": True,
    # Logging
    "logging_steps": 50,                    # log every 50 steps
    "report_to": "none",                    # no external trackers such as WandB
    # Mixed precision
    "fp16": True,
}
training_args = TrainingArguments(**training_config)

# Bundle the model, data splits, tokenizer, and metric callback for the Trainer
trainer_inputs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,      # used for the evaluations run during training
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_inputs)

# Run fine-tuning
trainer.train()

# Evaluate on the test split and report the resulting metrics
eval_results = trainer.evaluate(eval_dataset=test_dataset)
print("Evaluation Results:", eval_results)

# Predict on the test set
predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)  # Convert logits to class predictions

# Save predictions to CSV.
# test_df came out of train_test_split, so take an explicit copy before adding
# a column; this avoids pandas' SettingWithCopyWarning on the assignment.
test_df = test_df.copy()
test_df["predicted_label"] = predicted_labels

# Use one variable for the path so the confirmation message always names the
# file that was actually written (the old message named a different file).
predictions_path = "test_predictions_legal_bert.csv"
test_df.to_csv(predictions_path, index=False)
print(f"Predictions saved to {predictions_path}")