# BERT Multilabel Classification with ReDSM5

In [1]:
# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm

2025-03-19 19:04:11.236057: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742407451.432456  141921 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742407451.481304  141921 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742407451.996031  141921 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742407451.996054  141921 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742407451.996055  141921 computation_placer.cc:177] computation placer alr

In [3]:
# Load dataset
data = pd.read_csv('data/redsm5.csv')
data['labels'] = data['labels'].apply(lambda x: x.split(';'))  # Convert labels to list

# MultiLabel Binarization
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(data['labels'])
texts = data['text'].tolist()

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

print(f'{texts[0][0:15]}... ---> {labels[0]}')
data

Voices of Recov... ---> [0 0 0 0 0 1 0 0 0 0]


Unnamed: 0,text,labels,explanation
0,Voices of Recovery - Daily Meditation: Dec 20 ...,[NO_SYMPTOMS],The previous post would suggest the presence o...
1,Our Blind Spot [part four] [Part One](<URL>)\n...,[PSYCHOMOTOR],Psychomotor agitation is a characteristic symp...
2,Health update I'm getting so much better. Even...,[NO_SYMPTOMS],The person who wrote this post claims to have ...
3,"Dude, I can't even tell you the last time some...",[ANHEDONIA],This post suggests that this person is develop...
4,My sadness is immeasurable and my day is ruined.,[DEPRESSED_MOOD],This post describes intense pain and immeasura...
...,...,...,...
1479,"Jobs/careers?? Im diagnosed bipolar 1, and eve...","[WORTHLESSNESS, FATIGUE]",This post can suggest the presence of tirednes...
1480,My twin brother got in trouble because I taugh...,[WORTHLESSNESS],The above sentence can suggest feelings of gui...
1481,"Not very well, I constanly want to kill myself.",[SUICIDAL_THOUGHTS],This post suggests that the person has posted ...
1482,I know I cant concentrate with other people be...,[NO_SYMPTOMS],This publication suggests a possible attention...


In [4]:
# Tokenization
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_texts(texts, tokenizer, max_len=512):
    return tokenizer(texts, padding=True, truncation=True, max_length=max_len, return_tensors='pt')

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = tokenize_texts(texts, tokenizer)
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.texts.items()}, self.labels[idx]

train_dataset = TextDataset(X_train, y_train, tokenizer)
test_dataset = TextDataset(X_test, y_test, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [5]:
# Model setup
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(mlb.classes_),
    problem_type='multi_label_classification'
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [6]:
# Loss and Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)

In [7]:
# Training
model.train()
for epoch in range(5):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        inputs, targets = batch
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(**inputs).logits
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loop.set_description(f"Epoch {epoch+1}")
        loop.set_postfix(loss=loss.item())

Epoch 1: 100%|██████████| 149/149 [00:28<00:00,  5.20it/s, loss=0.307]
Epoch 2: 100%|██████████| 149/149 [00:26<00:00,  5.62it/s, loss=0.289]
Epoch 3: 100%|██████████| 149/149 [00:26<00:00,  5.61it/s, loss=0.281]
Epoch 4: 100%|██████████| 149/149 [00:26<00:00,  5.60it/s, loss=0.266]
Epoch 5: 100%|██████████| 149/149 [00:26<00:00,  5.60it/s, loss=0.173]


In [8]:
# Evaluation
model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets.cpu().numpy()

        outputs = torch.sigmoid(model(**inputs).logits).cpu().numpy()
        preds = (outputs > 0.5).astype(int)

        all_preds.extend(preds)
        all_targets.extend(targets)

In [9]:
print(classification_report(all_targets, all_preds, target_names=mlb.classes_, zero_division=0))
print(f"Accuracy: {accuracy_score(all_targets, all_preds):.2f}")

                   precision    recall  f1-score   support

        ANHEDONIA       0.76      0.52      0.62        25
  APPETITE_CHANGE       0.00      0.00      0.00        10
 COGNITIVE_ISSUES       0.00      0.00      0.00        10
   DEPRESSED_MOOD       0.61      0.40      0.48        70
          FATIGUE       0.75      0.64      0.69        28
      NO_SYMPTOMS       0.39      0.41      0.40        73
      PSYCHOMOTOR       0.00      0.00      0.00         8
     SLEEP_ISSUES       0.00      0.00      0.00        18
SUICIDAL_THOUGHTS       0.81      0.61      0.69        28
    WORTHLESSNESS       0.73      0.61      0.67        72

        micro avg       0.61      0.44      0.51       342
        macro avg       0.41      0.32      0.36       342
     weighted avg       0.55      0.44      0.48       342
      samples avg       0.49      0.46      0.47       342

Accuracy: 0.41
