In [None]:
!pip install transformers torch scikit-learn



In [None]:
!pip install transformers[torch]
!pip install accelerate -U



In [None]:
import dataclasses

import pandas as pd
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from transformers import RobertaTokenizer, RobertaForSequenceClassification

# Load the CSV file
# Fraction of rows kept for a quick experiment; set to 1.0 to use all rows.
SAMPLE_FRAC = 0.1

file_path = '/content/drive/MyDrive/legal_text_classification[2].csv'
df = (
    pd.read_csv(file_path)
    # Rows without text cannot be tokenized; rows without an outcome would
    # otherwise show up as an extra NaN value in the label mapping built from
    # case_outcome.unique().
    .dropna(subset=['case_text', 'case_outcome'])
    # Use a smaller subset of the data for quick training
    .sample(frac=SAMPLE_FRAC, random_state=42)
)

# Dataset that tokenizes case texts lazily, one example per __getitem__ call.
class LegalDataset(Dataset):
    """Torch ``Dataset`` pairing raw case texts with integer labels.

    Args:
        texts: Raw case texts. Any iterable is accepted; it is materialized into
            a list so ``__getitem__`` indexes by position (a pandas Series with a
            non-default index would otherwise be indexed by label).
        labels: Integer class ids, aligned one-to-one with ``texts``.
        tokenizer: A Hugging Face tokenizer (callable). It must belong to the
            same checkpoint as the model the dataset is fed to.
        max_length: Every item is padded / truncated to exactly this many tokens.

    Each item is a dict with ``text`` (the raw string), ``input_ids`` and
    ``attention_mask`` (1-D tensors of length ``max_length``) and ``labels``
    (a 0-d ``torch.long`` tensor).

    Raises:
        ValueError: if ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Calling the tokenizer directly replaces the deprecated encode_plus;
        # the options are unchanged.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            # return_tensors='pt' yields a batch of one (1, max_length); flatten
            # to 1-D so the Trainer's collator can stack items itself.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }

# Convert case outcomes to numerical labels.
# Ids follow sorted outcome order (key=str also tolerates a missing value), so
# the mapping does not depend on the row order of the sampled frame.
label_mapping = {
    label: idx
    for idx, label in enumerate(sorted(df['case_outcome'].unique(), key=str))
}
df['label'] = df['case_outcome'].map(label_mapping)

# Split the data. Stratifying keeps the validation split's class proportions
# close to the full sample's. scikit-learn can only stratify when every class
# has at least 2 rows and the validation split has room for one row per class,
# so fall back to a plain random split otherwise.
val_size = 0.2
label_counts = df['label'].value_counts()
can_stratify = label_counts.min() >= 2 and int(len(df) * val_size) >= len(label_counts)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['case_text'].tolist(),
    df['label'].tolist(),
    test_size=val_size,
    random_state=42,
    stratify=df['label'].tolist() if can_stratify else None,
)

# Tokenizers: one per pretrained checkpoint. A model must be fed ids produced by
# the tokenizer of its own checkpoint.
bert_tokenizer, roberta_tokenizer = (
    BertTokenizer.from_pretrained('bert-base-uncased'),
    RobertaTokenizer.from_pretrained('roberta-base'),
)

# Datasets built with the *BERT* tokenizer: valid input for the BERT model only.
max_length = 512  # tokens per example; LegalDataset pads/truncates to this length
train_dataset, val_dataset = (
    LegalDataset(split_texts, split_labels, bert_tokenizer, max_length)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
    )
)

# Training arguments (shared hyperparameters for both models).
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases, and the old keyword has since been removed. Use whichever field the
# installed TrainingArguments defines, so the cell runs on old and new versions.
_eval_strategy_kw = (
    'eval_strategy' if hasattr(TrainingArguments, 'eval_strategy') else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,  # Reduced number of epochs for quicker training
    per_device_train_batch_size=4,  # Reduced batch size
    per_device_eval_batch_size=4,
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_steps=50,
    save_steps=50,
    # With save_steps=50, an unbounded run accumulates a full checkpoint
    # (weights + optimizer state) every 50 steps; keep only the latest two.
    save_total_limit=2,
    seed=42,  # the library default, made explicit for reproducibility
    **{_eval_strategy_kw: 'steps'},
)

# Define the models.
# The classification heads are newly (randomly) initialised (see the
# "newly initialized: classifier.*" warnings in the output). Trainer seeds only
# when it is constructed, after these models already exist, so seed here to
# make the head initialisation reproducible.
torch.manual_seed(42)

# Store the outcome names in the model configs, so saved checkpoints and
# predictions can be mapped back to case outcomes.
id2label = {idx: label for label, idx in label_mapping.items()}

bert_model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(label_mapping),
    id2label=id2label,
    label2id=label_mapping,
)
roberta_model = RobertaForSequenceClassification.from_pretrained(
    'roberta-base',
    num_labels=len(label_mapping),
    id2label=id2label,
    label2id=label_mapping,
)

# Define the compute metrics function
def compute_metrics(p):
    """Classification metrics for ``Trainer`` evaluation.

    Args:
        p: ``EvalPrediction``-like object with ``predictions`` (per-class
           scores, class axis last) and ``label_ids``.

    Returns:
        dict with ``accuracy``; support-weighted ``f1``, ``precision`` and
        ``recall``; and ``f1_macro`` (unweighted mean over classes). Macro-F1 is
        much lower than weighted F1 when the model predicts only the majority
        class.
    """
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    # zero_division=0 yields the same numbers as scikit-learn's default ("warn")
    # but without an UndefinedMetricWarning whenever some class is never
    # predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    _, _, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'f1_macro': f1_macro,
    }

# Per-model training arguments: each run gets its own output/logging
# sub-directory. With one shared output_dir, both runs write identically named
# `checkpoint-<step>` folders, so RoBERTa's checkpoints would overwrite BERT's.
bert_training_args = dataclasses.replace(
    training_args,
    output_dir=f'{training_args.output_dir}/bert',
    logging_dir=f'{training_args.logging_dir}/bert',
    run_name='bert',
)
roberta_training_args = dataclasses.replace(
    training_args,
    output_dir=f'{training_args.output_dir}/roberta',
    logging_dir=f'{training_args.logging_dir}/roberta',
    run_name='roberta',
)

# RoBERTa needs inputs produced by its own tokenizer. `train_dataset` and
# `val_dataset` hold BERT WordPiece ids, which do not correspond to RoBERTa's
# vocabulary.
roberta_train_dataset = LegalDataset(train_texts, train_labels, roberta_tokenizer, max_length)
roberta_val_dataset = LegalDataset(val_texts, val_labels, roberta_tokenizer, max_length)

# Trainer for the BERT model (BERT-tokenized datasets)
trainer_bert = Trainer(
    model=bert_model,
    args=bert_training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Trainer for the RoBERTa model (RoBERTa-tokenized datasets)
trainer_roberta = Trainer(
    model=roberta_model,
    args=roberta_training_args,
    train_dataset=roberta_train_dataset,
    eval_dataset=roberta_val_dataset,
    compute_metrics=compute_metrics,
)

# Training: the two models are fine-tuned one after the other.
trainer_bert.train()
trainer_roberta.train()

# Evaluation of each model, as it stands after training, on its validation split.
eval_result_bert = trainer_bert.evaluate()
eval_result_roberta = trainer_roberta.evaluate()

# NOTE(review): in the saved run, accuracy stayed at 0.474849 at every eval step
# and weighted precision equalled accuracy squared (0.225482). That pattern
# means the classifier predicted a single class for every example. Check
# per-class / macro metrics before reading anything into the weighted scores.
# Side-by-side table: one row per metric returned by evaluate(), one column per model.
eval_comparison = pd.DataFrame({'BERT': eval_result_bert, 'RoBERTa': eval_result_roberta})
eval_comparison


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
50,1.3978,1.696352,0.474849,0.305769,0.225482,0.474849
100,1.5909,1.648722,0.474849,0.305769,0.225482,0.474849
150,1.7296,1.626017,0.474849,0.305769,0.225482,0.474849
200,1.3339,1.653222,0.474849,0.305769,0.225482,0.474849
250,1.3001,1.630164,0.474849,0.305769,0.225482,0.474849
300,1.6737,1.638623,0.474849,0.305769,0.225482,0.474849
350,1.5564,1.628368,0.474849,0.305769,0.225482,0.474849
400,1.5189,1.627325,0.474849,0.305769,0.225482,0.474849


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
50,1.3978,1.696352,0.474849,0.305769,0.225482,0.474849
100,1.5909,1.648722,0.474849,0.305769,0.225482,0.474849
150,1.7296,1.626017,0.474849,0.305769,0.225482,0.474849
200,1.3339,1.653222,0.474849,0.305769,0.225482,0.474849
250,1.3001,1.630164,0.474849,0.305769,0.225482,0.474849
300,1.6737,1.638623,0.474849,0.305769,0.225482,0.474849
350,1.5564,1.628368,0.474849,0.305769,0.225482,0.474849
400,1.5189,1.627325,0.474849,0.305769,0.225482,0.474849
450,1.7811,1.620203,0.474849,0.305769,0.225482,0.474849


  _warn_prf(average, modifier, msg_start, len(result))


Step,Training Loss,Validation Loss
