In [23]:
import torch
import datasets
import warnings
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Silence every Python warning from here on: no category, module or message
# filter is supplied, so nothing is exempt from the 'ignore' action.
warnings.filterwarnings('ignore')

In [24]:
class TextDataset(Dataset):
    """Map-style dataset pairing pre-tokenized texts with their labels.

    Every text is tokenized in one call at construction time; indexing only
    converts the stored rows for a single sample into tensors.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        """Tokenize ``texts`` up front and keep ``labels`` alongside them.

        Args:
            texts: Collection of inputs handed to ``tokenizer`` in a single call.
            labels: Indexable collection of labels, aligned with ``texts``.
            tokenizer: Callable invoked as
                ``tokenizer(texts, truncation=True, padding=True, max_length=...)``;
                its result must offer ``.items()`` yielding, per field, a
                container indexable by sample position.
            max_length: Forwarded to the tokenizer as ``max_length``.
        """
        tokenize_options = dict(truncation=True, padding=True, max_length=max_length)
        self.encodings = tokenizer(texts, **tokenize_options)
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors.

        The dict holds one tensor per tokenizer output field plus a
        ``'labels'`` tensor built from the stored label.
        """
        sample = {}
        for field, rows in self.encodings.items():
            sample[field] = torch.tensor(rows[idx])
        # Assigned last so it always reflects the stored label.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

class ModelEvaluator:
    """Load a sequence-classification checkpoint and score it against binary labels.

    Intended call order: construct, ``load_model()``, ``prepare_data(df)``,
    then ``evaluate(dataloader)``. ``prepare_data`` and ``evaluate`` raise
    ``RuntimeError`` if the tokenizer / model they need has not been loaded.
    """

    def __init__(self, model_name, tokenizer_name, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Record the checkpoint identifiers; nothing is loaded here.

        Args:
            model_name: Identifier passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            device: Device the model and input tensors are moved to. The
                default expression is evaluated once, when the class is defined.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.device = device
        # Populated by load_model().
        self.model = None
        self.tokenizer = None
        # Checkpoints whose name contains this substring have their three-class
        # predictions collapsed to binary in evaluate().
        self.is_three_class = 'deproberta-large-depression' in model_name

    def load_model(self):
        """Load tokenizer and model, move the model to ``self.device``, set eval mode.

        Returns:
            True on success. On any exception the error is printed and False
            is returned, so callers can skip this checkpoint and carry on.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            return True
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            return False

    def prepare_data(self, df, batch_size=8):
        """Build a DataLoader over the ``'text'`` and ``'label'`` columns of ``df``.

        Args:
            df: Table exposing ``df['text']`` and ``df['label']``, each with
                a ``.tolist()`` method.
            batch_size: Number of samples per batch.

        Returns:
            A ``DataLoader`` over a ``TextDataset`` built with this evaluator's
            tokenizer.

        Raises:
            RuntimeError: If no tokenizer has been loaded yet.
        """
        if self.tokenizer is None:
            raise RuntimeError(
                f"Tokenizer for {self.model_name} is not loaded; call load_model() first"
            )
        texts = df['text'].tolist()
        labels = df['label'].tolist()
        dataset = TextDataset(texts, labels, self.tokenizer)
        return DataLoader(dataset, batch_size=batch_size)

    def evaluate(self, dataloader):
        """Run the model over ``dataloader`` and return classification metrics.

        Args:
            dataloader: Iterable of batches, each a mapping with
                ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.

        Returns:
            The dict produced by ``calculate_metrics``.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError(
                f"Model {self.model_name} is not loaded; call load_model() first"
            )

        predictions = []
        true_labels = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Evaluating {self.model_name}"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                outputs = self.model(input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)

                if self.is_three_class:
                    # Map the three-class predictions to binary:
                    # severe (0) -> depression (1)
                    # moderate (1) -> depression (1)
                    # not depression (2) -> not depression (0)
                    preds = (preds < 2).long()  # 0 and 1 become 1, 2 becomes 0
                # NOTE(review): for every other checkpoint the argmax index is
                # used as-is, which assumes class index 1 means "depression" --
                # TODO confirm against each model's label mapping.

                predictions.extend(preds.cpu().tolist())
                # Labels are only compared on the host, so they are never
                # copied to the device.
                true_labels.extend(batch['labels'].tolist())

        return self.calculate_metrics(true_labels, predictions)

    @staticmethod
    def calculate_metrics(true_labels, predictions):
        """Compute accuracy and binary precision / recall / F1.

        Args:
            true_labels: Ground-truth labels.
            predictions: Predicted labels, aligned with ``true_labels``.

        Returns:
            Dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
        """
        accuracy = accuracy_score(true_labels, predictions)
        # zero_division=0 yields the same 0.0 scores as the default when a
        # metric is undefined, without depending on warnings being silenced.
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='binary', zero_division=0
        )

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

def benchmark_models(df, models=None, batch_size=8):
    """Evaluate each (model, tokenizer) pair on ``df`` and collect its metrics.

    Args:
        df: Table with ``'text'`` and ``'label'`` columns, as read by
            ``ModelEvaluator.prepare_data``.
        models: Optional iterable of ``(model_name, tokenizer_name)`` pairs.
            When None, the built-in default list below is used.
        batch_size: Batch size forwarded to ``ModelEvaluator.prepare_data``.

    Returns:
        Dict mapping each model name to the metrics dict returned by
        ``ModelEvaluator.evaluate``. Models that fail to load are reported on
        stdout and left out of the result.
    """
    if models is None:
        models = [
            # ('rafalposwiata/deproberta-large-depression','rafalposwiata/deproberta-large-depression'), This is RoBERTa-large!
            ('ShreyaR/finetuned-roberta-depression', 'ShreyaR/finetuned-roberta-depression'),
            ('malexandersalazar/xlm-roberta-base-cls-depression', 'FacebookAI/xlm-roberta-base'),
            ('ShreyaR/finetuned-distil-bert-depression', 'ShreyaR/finetuned-distil-bert-depression'),
            ('mrjunos/depression-reddit-distilroberta-base', 'mrjunos/depression-reddit-distilroberta-base'),
            ('mrm8488/distilroberta-base-finetuned-suicide-depression', 'mrm8488/distilroberta-base-finetuned-suicide-depression'),
        ]

    results = {}

    for model_name, tokenizer_name in models:
        print(f"\nEvaluating model: {model_name}")
        evaluator = ModelEvaluator(model_name, tokenizer_name)

        if evaluator.load_model():
            dataloader = evaluator.prepare_data(df, batch_size=batch_size)
            results[model_name] = evaluator.evaluate(dataloader)
        else:
            print(f"Skipping evaluation for {model_name} due to loading error")

    return results

def print_results(results):
    """Print a fixed-width table comparing the metrics of each model.

    Args:
        results: Mapping of model name to a dict holding numeric
            ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'`` entries.
    """
    rule = "-" * 100

    print("\nBenchmarking Results:")
    print(rule)
    print(f"{'Model':<50} | {'Accuracy':>10} | {'Precision':>10} | {'Recall':>10} | {'F1':>10}")
    print(rule)

    for model_name in results:
        metrics = results[model_name]
        # Keep only the text after the last '/', dropping any 'owner/' prefix.
        model_short_name = model_name.rpartition('/')[2]
        print(f"{model_short_name:<50} | {metrics['accuracy']:>10.4f} | {metrics['precision']:>10.4f} | "
              f"{metrics['recall']:>10.4f} | {metrics['f1']:>10.4f}")

In [25]:
# Entry point for the benchmark run: load the evaluation data, score every
# configured model on it, and print the comparison table.
if __name__ == "__main__":
    # Load the 'thePixel42/depression-detection' dataset via datasets.load_dataset.
    depression_detection = datasets.load_dataset('thePixel42/depression-detection')
    # Only the 'test' split is benchmarked. It is converted to pandas because
    # ModelEvaluator.prepare_data reads df['text'] and df['label'] via .tolist().
    depression_detection_df = depression_detection['test'].to_pandas()
    # Maps each successfully loaded model name to its metrics dict.
    results = benchmark_models(depression_detection_df)
    print_results(results)


Evaluating model: ShreyaR/finetuned-roberta-depression


Evaluating ShreyaR/finetuned-roberta-depression: 100%|██████████| 7500/7500 [03:04<00:00, 40.65it/s]



Evaluating model: malexandersalazar/xlm-roberta-base-cls-depression


Evaluating malexandersalazar/xlm-roberta-base-cls-depression: 100%|██████████| 7500/7500 [03:10<00:00, 39.34it/s]



Evaluating model: ShreyaR/finetuned-distil-bert-depression


Evaluating ShreyaR/finetuned-distil-bert-depression: 100%|██████████| 7500/7500 [01:42<00:00, 73.10it/s]



Evaluating model: mrjunos/depression-reddit-distilroberta-base


Evaluating mrjunos/depression-reddit-distilroberta-base: 100%|██████████| 7500/7500 [01:44<00:00, 71.62it/s]



Evaluating model: mrm8488/distilroberta-base-finetuned-suicide-depression


Evaluating mrm8488/distilroberta-base-finetuned-suicide-depression: 100%|██████████| 7500/7500 [01:41<00:00, 73.91it/s]



Benchmarking Results:
----------------------------------------------------------------------------------------------------
Model                                              |   Accuracy |  Precision |     Recall |         F1
----------------------------------------------------------------------------------------------------
finetuned-roberta-depression                       |     0.6820 |     0.6164 |     0.9701 |     0.7538
xlm-roberta-base-cls-depression                    |     0.7841 |     0.8458 |     0.6967 |     0.7640
finetuned-distil-bert-depression                   |     0.7135 |     0.6475 |     0.9419 |     0.7674
depression-reddit-distilroberta-base               |     0.7604 |     0.6937 |     0.9360 |     0.7968
distilroberta-base-finetuned-suicide-depression    |     0.5663 |     0.5605 |     0.6291 |     0.5928
