# 🔍 BERT for Multi-label Depression Symptom Detection with ReDSM5

This notebook demonstrates how to fine-tune a BERT-based model for **multi-label classification of depression symptoms** using the `ReDSM5` dataset, a Reddit corpus annotated for the nine DSM-5 symptom criteria of a major depressive episode. Together with a `NO_SYMPTOMS` label, this gives the 10 output classes used throughout.

We walk through the full pipeline:
- Loading and preprocessing the dataset,
- Tokenizing posts using `bert-base-uncased`,
- Wrapping inputs into a PyTorch-compatible `Dataset`,
- Fine-tuning a `BertForSequenceClassification` model for multi-label classification,
- Evaluating the model using standard classification metrics.

This notebook serves as a reproducible baseline for the experiments presented in the [ReDSM5 paper](https://huggingface.co/irlab-udc/redsm5), and is intended for researchers and practitioners interested in interpretable, symptom-level mental health NLP.

> ⚠️ This task is sensitive and related to mental health. The dataset should be used for research purposes only, with careful ethical considerations.


## 📦 Importing Dependencies

In this section, we import all necessary Python libraries for data manipulation, model training, and evaluation. These include:
- `pandas` for data handling,
- `scikit-learn` for splitting data and evaluation metrics,
- `torch` and `transformers` for building and fine-tuning the BERT model,
- `tqdm` for progress bars during training.

In [39]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm

## 📊 Loading and Preparing the Dataset

We load the `ReDSM5` dataset from a CSV file. Each row contains:
- the Reddit post (`text`),
- one or more DSM-5 symptom labels (`labels`),
- and a sentence-level clinical explanation (`explanation`).

We preprocess the `labels` column by splitting the semicolon-separated values into Python lists, which allows a post to carry several labels. Then, we use `MultiLabelBinarizer` to convert these lists into binary indicator vectors (multi-hot rather than one-hot: one position per label, and more than one position may be 1).

Finally, we split the dataset into training and testing subsets (80/20 split) to prepare for model training and evaluation.

In [41]:
# Load dataset
data = pd.read_csv('data/redsm5.csv')
data['labels'] = data['labels'].apply(lambda x: x.split(';'))  # Convert labels to list

# MultiLabel Binarization
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(data['labels'])
texts = data['text'].tolist()

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

print(f'{texts[0][0:15]}... ---> {labels[0]}')

Voices of Recov... ---> [0 0 0 0 0 1 0 0 0 0]
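
To make the encoding concrete, the following minimal sketch runs the same split-then-binarize steps on three made-up label strings (they are illustrative and not rows from `ReDSM5`). `MultiLabelBinarizer` sorts the class names, so each column of the output corresponds to one entry of `classes_`, and `inverse_transform` maps the vectors back to label names.

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Made-up label strings in the same semicolon-separated format (not real ReDSM5 rows)
raw_labels = ["DEPRESSED_MOOD;FATIGUE", "NO_SYMPTOMS", "FATIGUE"]
label_lists = [s.split(";") for s in raw_labels]

toy_mlb = MultiLabelBinarizer()
encoded = toy_mlb.fit_transform(label_lists)

# Column order follows the sorted class names
print(list(toy_mlb.classes_))

# One row per post, one column per class; a row may contain several 1s
print(encoded.tolist())

# Recover the label names from the binary vectors
decoded = toy_mlb.inverse_transform(encoded)
print(decoded)
```

The same `inverse_transform` call can be applied to model predictions later on to turn binary vectors back into symptom names.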


## 🧾 Tokenization and Dataset Wrapping

We load the `bert-base-uncased` tokenizer from Hugging Face. Then we define a helper function that tokenizes all input texts at once, truncating anything longer than BERT's maximum input size (512 tokens) and padding the rest to the longest sequence in the set.

A custom `TextDataset` class is defined, inheriting from PyTorch’s `Dataset`. It stores tokenized inputs and label tensors. This structure enables batch-wise training using `DataLoader`.

We create two data loaders:
- `train_loader` for shuffled mini-batches used during training,
- `test_loader` for evaluation without shuffling.


In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_texts(texts, tokenizer, max_len=512):
    return tokenizer(texts, padding=True, truncation=True, max_length=max_len, return_tensors='pt')

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = tokenize_texts(texts, tokenizer)
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.texts.items()}, self.labels[idx]

train_dataset = TextDataset(X_train, y_train, tokenizer)
test_dataset = TextDataset(X_test, y_test, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
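
As a rough illustration of what `padding=True` and `truncation=True` produce, the sketch below applies the same two steps to lists of small integer IDs. `pad_and_truncate` is a hypothetical helper written for this example and is not the Hugging Face implementation; in particular, the real tokenizer keeps the closing `[SEP]` token when it truncates, which this sketch does not model.

```python
def pad_and_truncate(sequences, max_len, pad_id=0):
    """Truncate each sequence to max_len, then pad all of them to the longest one."""
    truncated = [seq[:max_len] for seq in sequences]
    longest = max(len(seq) for seq in truncated)
    # Padded positions get pad_id in input_ids and 0 in the attention mask
    input_ids = [seq + [pad_id] * (longest - len(seq)) for seq in truncated]
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in truncated]
    return input_ids, attention_mask


# Two toy "token ID" sequences of length 3 and 5 (invented values)
ids, mask = pad_and_truncate([[1, 2, 3], [1, 4, 5, 6, 7]], max_len=4)
print(ids)
print(mask)
```

The attention mask is what lets the model ignore padded positions, which is why the notebook passes the full tokenizer output to the model rather than `input_ids` alone.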


## 🤖 BERT Model Initialization

We initialize a `BertForSequenceClassification` model with a multi-label classification head. Key parameters include:
- the number of output labels (derived from the dataset),
- the task type: `"multi_label_classification"`.

This model extends BERT with a final linear layer sized for our 10 label categories (nine symptoms plus `NO_SYMPTOMS`). Finally, we move the model to GPU if available to accelerate training.


In [None]:
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(mlb.classes_),
    problem_type='multi_label_classification'
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

## 🧮 Loss Function and Optimizer

We use `BCEWithLogitsLoss`, the standard loss function for multi-label classification. It combines a sigmoid layer and binary cross-entropy in a single, numerically stable function and scores each label independently, so a post can be positive for several symptoms at once.

The optimizer is `AdamW`, a variant of Adam that keeps per-parameter adaptive step sizes and applies decoupled weight decay for regularization.

A learning rate of `2e-5` is chosen, consistent with standard fine-tuning practice for BERT-based models.


In [None]:
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)
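
To show what this loss computes, here is a minimal pure-Python sketch of binary cross-entropy on raw logits, averaged over every (post, label) entry. It is an illustration on invented numbers, not PyTorch's implementation; the helper names `bce_with_logits` and `naive_bce` are hypothetical.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (sample, label) entries, from raw logits."""
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y in zip(row_x, row_y):
            # Stable rewrite of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


def naive_bce(logits, targets):
    """Same quantity computed the textbook way: sigmoid first, then log."""
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y in zip(row_x, row_y):
            p = 1.0 / (1.0 + math.exp(-x))
            total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
            count += 1
    return total / count


# One post, three labels (invented logits and targets)
toy_logits = [[2.0, -1.0, 0.0]]
toy_targets = [[1.0, 0.0, 1.0]]

stable = bce_with_logits(toy_logits, toy_targets)
naive = naive_bce(toy_logits, toy_targets)
print(stable, naive)
```

Both functions agree on moderate logits; the stable form matters for large-magnitude logits, where the sigmoid saturates and the naive logarithm can fail.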

## 🏋️ Training the Model

We train the BERT model for 5 epochs using the training data. For each batch:
1. The input tokens and attention masks are moved to the GPU (if available).
2. The model produces logits for each symptom class.
3. We compute the binary cross-entropy loss (`BCEWithLogitsLoss`) between predictions and true labels.
4. We perform backpropagation and update the model parameters using the AdamW optimizer.

A progress bar is displayed using `tqdm`, and its loss value is refreshed after every batch, so the figure shown at the end of an epoch is the loss of that epoch's last batch.


In [None]:
model.train()
for epoch in range(5):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        inputs, targets = batch
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(**inputs).logits
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loop.set_description(f"Epoch {epoch+1}")
        loop.set_postfix(loss=loss.item())

Epoch 1: 100%|██████████| 149/149 [00:28<00:00,  5.20it/s, loss=0.307]
Epoch 2: 100%|██████████| 149/149 [00:26<00:00,  5.62it/s, loss=0.289]
Epoch 3: 100%|██████████| 149/149 [00:26<00:00,  5.61it/s, loss=0.281]
Epoch 4: 100%|██████████| 149/149 [00:26<00:00,  5.60it/s, loss=0.266]
Epoch 5: 100%|██████████| 149/149 [00:26<00:00,  5.60it/s, loss=0.173]
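
Because each value above is the loss of a single batch of 8 posts, it can be noisy. One option, sketched here under the assumption that a per-epoch mean is more informative, is to keep a running average. `RunningMean` is a hypothetical helper that is not part of the notebook, and the batch losses below are invented.

```python
class RunningMean:
    """Tracks the mean of the values seen so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def value(self):
        # Return 0.0 before any update to avoid dividing by zero
        return self.total / self.count if self.count else 0.0


# Invented per-batch losses standing in for loss.item() values
batch_losses = [0.70, 0.50, 0.30]

meter = RunningMean()
for batch_loss in batch_losses:
    meter.update(batch_loss)

print(f"last batch: {batch_losses[-1]:.2f}, epoch mean: {meter.value:.2f}")
```

In the training loop this would amount to creating a new meter at the start of each epoch, calling `meter.update(loss.item())` after each step, and passing `meter.value` to `loop.set_postfix`.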


## 🧪 Evaluation on the Test Set

After training, we evaluate the model using the held-out test data. The model is set to evaluation mode (`model.eval()`), and gradient computation is disabled to speed up inference and reduce memory usage.

We:
- Pass each test batch through the model.
- Apply a sigmoid activation to obtain probabilities.
- Threshold these probabilities at 0.5 to get binary predictions for each symptom.
- Collect all predictions and true labels for final metric computation.


In [None]:
model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets.cpu().numpy()

        outputs = torch.sigmoid(model(**inputs).logits).cpu().numpy()
        preds = (outputs > 0.5).astype(int)

        all_preds.extend(preds)
        all_targets.extend(targets)
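
The sigmoid-and-threshold step can be sketched in plain Python on invented logits. The example also shows that thresholding probabilities at 0.5 is equivalent to thresholding logits at 0, and that a post may end up with no predicted label at all.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Two posts, three labels (invented logits)
toy_logits = [[2.2, -0.4, 0.1], [-3.0, 0.0, -1.5]]

probs = [[sigmoid(x) for x in row] for row in toy_logits]
preds = [[int(p > 0.5) for p in row] for row in probs]

# Same decision taken directly on the logits: sigmoid(x) > 0.5 exactly when x > 0
preds_from_logits = [[int(x > 0.0) for x in row] for row in toy_logits]

print(preds)
print(preds_from_logits)
```

The second toy post receives no positive label because none of its probabilities exceeds 0.5; the same can happen to real test posts, since each label is decided independently.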

## 📈 Performance Metrics

We compute standard multi-label classification metrics using `classification_report` from `scikit-learn`. This includes precision, recall, and F1-score per class, as well as micro, macro, weighted, and samples averages.

- **Micro average** pools true positives, false positives, and false negatives across all labels before computing the metric, so frequent labels dominate.
- **Macro average** calculates metrics independently per class and then takes their unweighted mean, so every label counts equally regardless of frequency.
- **Weighted average** averages the per-class metrics weighted by support, so it reflects the label distribution of the test set.
- **Samples average** calculates metrics per post over its set of labels and then averages across posts.

We also report the overall accuracy. For multi-label input, `accuracy_score` computes subset accuracy: a post counts as correct only if all 10 of its labels are predicted exactly.
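
To see how strict exact-match accuracy is, the following sketch compares it with a per-label accuracy on a small invented example. Both helper functions are hypothetical and written only for illustration; the notebook itself relies on `accuracy_score`.

```python
def subset_accuracy(y_true, y_pred):
    """Fraction of samples whose full label vector is predicted exactly."""
    matches = sum(1 for t, p in zip(y_true, y_pred) if list(t) == list(p))
    return matches / len(y_true)


def per_label_accuracy(y_true, y_pred):
    """Fraction of individual (sample, label) decisions that are correct."""
    correct = sum(int(a == b) for t, p in zip(y_true, y_pred) for a, b in zip(t, p))
    total = sum(len(t) for t in y_true)
    return correct / total


# Four posts, three labels (invented values)
toy_true = [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 1, 0]]
toy_pred = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]

exact = subset_accuracy(toy_true, toy_pred)
per_label = per_label_accuracy(toy_true, toy_pred)
print(f"subset accuracy: {exact:.2f}, per-label accuracy: {per_label:.2f}")
```

A single wrong label makes the whole post count as an error under subset accuracy, so this figure is typically much lower than the share of individual label decisions that are correct.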


In [None]:
print(classification_report(all_targets, all_preds, target_names=mlb.classes_, zero_division=0))
print(f"Accuracy: {accuracy_score(all_targets, all_preds):.2f}")

                   precision    recall  f1-score   support

        ANHEDONIA       0.76      0.52      0.62        25
  APPETITE_CHANGE       0.00      0.00      0.00        10
 COGNITIVE_ISSUES       0.00      0.00      0.00        10
   DEPRESSED_MOOD       0.61      0.40      0.48        70
          FATIGUE       0.75      0.64      0.69        28
      NO_SYMPTOMS       0.39      0.41      0.40        73
      PSYCHOMOTOR       0.00      0.00      0.00         8
     SLEEP_ISSUES       0.00      0.00      0.00        18
SUICIDAL_THOUGHTS       0.81      0.61      0.69        28
    WORTHLESSNESS       0.73      0.61      0.67        72

        micro avg       0.61      0.44      0.51       342
        macro avg       0.41      0.32      0.36       342
     weighted avg       0.55      0.44      0.48       342
      samples avg       0.49      0.46      0.47       342

Accuracy: 0.41
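
The averaged rows can be checked by hand. The sketch below recomputes the macro and weighted F1 from the per-class values printed in the report, and the micro F1 from the micro precision and recall. Because the inputs are already rounded to two decimals, the results only match the report up to rounding.

```python
# (f1-score, support) per class, copied from the report above (rounded values)
report = {
    "ANHEDONIA": (0.62, 25),
    "APPETITE_CHANGE": (0.00, 10),
    "COGNITIVE_ISSUES": (0.00, 10),
    "DEPRESSED_MOOD": (0.48, 70),
    "FATIGUE": (0.69, 28),
    "NO_SYMPTOMS": (0.40, 73),
    "PSYCHOMOTOR": (0.00, 8),
    "SLEEP_ISSUES": (0.00, 18),
    "SUICIDAL_THOUGHTS": (0.69, 28),
    "WORTHLESSNESS": (0.67, 72),
}

total_support = sum(support for _, support in report.values())

# Macro: unweighted mean of per-class F1, so the four zero-F1 classes pull it down
macro_f1 = sum(f1 for f1, _ in report.values()) / len(report)

# Weighted: per-class F1 weighted by how often each class occurs in the test set
weighted_f1 = sum(f1 * support for f1, support in report.values()) / total_support

# Micro: harmonic mean of the pooled precision and recall from the "micro avg" row
micro_p, micro_r = 0.61, 0.44
micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)

print(f"macro={macro_f1:.3f} weighted={weighted_f1:.3f} micro={micro_f1:.3f}")
```

The macro score is the lowest of the three because `APPETITE_CHANGE`, `COGNITIVE_ISSUES`, `PSYCHOMOTOR`, and `SLEEP_ISSUES` each contribute an F1 of zero with the same weight as the frequent classes.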
