# Leveraging Sentence Transformers Embeddings for Multilabel Text Classification with LightGBM

In this notebook, the aim is to train on embeddings from a sentence-transformers model (`BAAI/bge-small-en-v1.5`) instead of frequency-based vectorization such as TF-IDF.
The hope is that giving LightGBM a richer encoding of the input text will eventually yield better results.
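
To make the motivation concrete, here is a small dependency-free sketch (not part of the original experiment) of why frequency-based vectors can struggle: two sentences that mean nearly the same thing but share no words get a cosine similarity of exactly zero. The toy sentences and the helper names `tfidf_vectors` and `cosine` are illustrative assumptions; a dense sentence embedding is intended to place such paraphrases close together instead.

```python
import math
from collections import Counter


def tfidf_vectors(docs):
    """Return one {term: weight} dict per document (term frequency x smoothed IDF)."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: in how many documents does each term appear?
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        vectors.append({
            term: (count / len(tokens))
            * (math.log((1 + n_docs) / (1 + doc_freq[term])) + 1)
            for term, count in counts.items()
        })
    return vectors


def cosine(a, b):
    """Cosine similarity between two sparse {term: weight} vectors."""
    dot = sum(weight * b.get(term, 0.0) for term, weight in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


docs = [
    "hackers breached servers",        # toy sentence (assumption)
    "attackers compromised networks",  # paraphrase with no shared words
    "hackers breached servers again",  # near-duplicate of the first
]
vecs = tfidf_vectors(docs)
paraphrase_similarity = cosine(vecs[0], vecs[1])
near_duplicate_similarity = cosine(vecs[0], vecs[2])
print(paraphrase_similarity, near_duplicate_similarity)
```

The paraphrase pair scores zero because the vectors have no term in common, while the near-duplicate pair scores high purely on word overlap.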


In [1]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score, classification_report

import lightgbm as lgb

from src.reporting.visualize import save_experiment_results



In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

Now, let's load the data, compute the embeddings, and fit a classifier with default settings:

In [3]:
# Load data
train_df = pd.read_csv('../data/processed/clean_train.csv')
valid_df = pd.read_csv('../data/processed/clean_valid.csv')

# Initialize Sentence Transformer Model
model = SentenceTransformer('BAAI/bge-small-en-v1.5')

# Transform 'clean_content' using Sentence Transformer
X_train = model.encode(train_df['clean_content'].to_list(), show_progress_bar=True)
X_valid = model.encode(valid_df['clean_content'].to_list(), show_progress_bar=True)

# Prepare labels for multilabel classification
y_train = train_df[['cyber_label', 'environmental_issue']]
y_valid = valid_df[['cyber_label', 'environmental_issue']]

# One LightGBM classifier per label, wrapped for multilabel prediction
multioutput_classifier = MultiOutputClassifier(lgb.LGBMClassifier(verbosity=2), n_jobs=-1)
multioutput_classifier.fit(X_train, y_train)


Batches: 100%|██████████| 32/32 [00:07<00:00,  4.12it/s]
Batches: 100%|██████████| 8/8 [00:01<00:00,  4.34it/s]
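
`MultiOutputClassifier` handles the two labels by fitting one independent copy of the base estimator per label column. The dependency-free sketch below mimics that idea with a hypothetical `PerLabelWrapper` and a toy `ToyCentroidClassifier` standing in for LightGBM; both names and the toy data are assumptions for illustration, not the scikit-learn implementation.

```python
class ToyCentroidClassifier:
    """Toy stand-in for a real estimator: predicts the class with the nearest centroid."""

    def fit(self, X, y):
        self.centroids_ = {}
        for label in set(y):
            rows = [row for row, target in zip(X, y) if target == label]
            # Column-wise mean of all rows belonging to this class.
            self.centroids_[label] = [sum(col) / len(rows) for col in zip(*rows)]
        return self

    def predict(self, X):
        def sq_dist(a, b):
            return sum((p - q) ** 2 for p, q in zip(a, b))

        return [
            min(self.centroids_, key=lambda label: sq_dist(row, self.centroids_[label]))
            for row in X
        ]


class PerLabelWrapper:
    """Hypothetical wrapper: trains one estimator per label column."""

    def __init__(self, make_estimator):
        self.make_estimator = make_estimator

    def fit(self, X, Y):
        # Y has one column per label; each column gets its own estimator.
        self.estimators_ = [
            self.make_estimator().fit(X, list(column)) for column in zip(*Y)
        ]
        return self

    def predict(self, X):
        per_label = [estimator.predict(X) for estimator in self.estimators_]
        # Transpose back to one row of label predictions per sample.
        return [list(row) for row in zip(*per_label)]


# Toy data: label 0 follows the first feature, label 1 follows the second.
X_toy = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
Y_toy = [[0, 0], [0, 1], [1, 0], [1, 1]]

wrapper = PerLabelWrapper(ToyCentroidClassifier).fit(X_toy, Y_toy)
predictions = wrapper.predict([[0.9, 0.1], [0.2, 0.8]])
print(predictions)
```

Because the per-label models never see each other's targets, any correlation between `cyber_label` and `environmental_issue` is ignored by this kind of wrapper.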


In [4]:

# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)


save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="lightgbm-bge-small-default",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.95      0.99      0.97       235
           1       0.75      0.35      0.48        17

    accuracy                           0.95       252
   macro avg       0.85      0.67      0.73       252
weighted avg       0.94      0.95      0.94       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.90      0.97      0.93       200
           1       0.84      0.60      0.70        52

    accuracy                           0.89       252
   macro avg       0.87      0.78      0.82       252
weighted avg       0.89      0.89      0.89       252

                     precision    recall  f1-score   support

        cyber_label       0.75      0.35      0.48        17
environmental_issue       0.84      0.60      0.70        52

          micro avg       0.82      0.54      0.65        69
          macr



All reports and plots have been generated and saved successfully.


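Good results: the macro avg F1 is 0.73 for `cyber_label` and 0.82 for `environmental_issue`. Recall on the positive class is the weak spot (0.35 and 0.60), with only 17 and 52 positive examples out of 252 in the validation set.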
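
As a quick sanity check on how the macro avg F1 values relate to the per-class scores, the sketch below recomputes the positive-class F1 and the macro average from the rounded numbers printed in the reports above. Because the inputs are already rounded to two decimals, the results can differ from the report in the last digit; the helper name `f1` is just for illustration.

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Positive-class precision and recall as printed for the default run.
cyber_f1 = f1(0.75, 0.35)          # report shows 0.48
environmental_f1 = f1(0.84, 0.60)  # report shows 0.70

# Macro avg F1 for one label: unweighted mean of the class-0 and class-1 F1 scores.
cyber_macro_f1 = (0.97 + 0.48) / 2          # report shows 0.73
environmental_macro_f1 = (0.93 + 0.70) / 2  # report shows 0.82

print(f"positive-class F1: {cyber_f1:.3f} {environmental_f1:.3f}")
print(f"macro avg F1:      {cyber_macro_f1:.3f} {environmental_macro_f1:.3f}")
```

The macro average gives the rare positive class the same weight as the majority class, which is why it sits well below the accuracy figure for `cyber_label`.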

Let's try the upsampled training set:

In [5]:
# Load data
train_df = pd.read_csv('../data/processed/clean_train_upsampled.csv')
valid_df = pd.read_csv('../data/processed/clean_valid.csv')

# Transform 'clean_content' using Sentence Transformer
X_train = model.encode(train_df['clean_content'].to_list(), show_progress_bar=True)
X_valid = model.encode(valid_df['clean_content'].to_list(), show_progress_bar=True)

# Prepare labels for multilabel classification
y_train = train_df[['cyber_label', 'environmental_issue']]
y_valid = valid_df[['cyber_label', 'environmental_issue']]

# MultiOutput Classifier
multioutput_classifier = MultiOutputClassifier(lgb.LGBMClassifier(verbosity=2), n_jobs=-1)
multioutput_classifier.fit(X_train, y_train)


# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="lightgbm-bge-small-upsampled-dataset-default",
)

Batches: 100%|██████████| 91/91 [00:20<00:00,  4.41it/s]
Batches: 100%|██████████| 8/8 [00:01<00:00,  4.25it/s]


Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.95      0.99      0.97       235
           1       0.62      0.29      0.40        17

    accuracy                           0.94       252
   macro avg       0.79      0.64      0.68       252
weighted avg       0.93      0.94      0.93       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.92      0.96      0.94       200
           1       0.83      0.67      0.74        52

    accuracy                           0.90       252
   macro avg       0.88      0.82      0.84       252
weighted avg       0.90      0.90      0.90       252

                     precision    recall  f1-score   support

        cyber_label       0.62      0.29      0.40        17
environmental_issue       0.83      0.67      0.74        52

          micro avg       0.80      0.58      0.67        69
          macr



All reports and plots have been generated and saved successfully.
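
The notebook does not show how `clean_train_upsampled.csv` was produced. One common approach, sketched here purely as an assumption, is random oversampling: minority-class rows are sampled with replacement until each class has as many rows as the largest one. The function name `random_oversample` and the toy rows are hypothetical, and the sketch balances a single binary label only. As in the cell above, only the training split is resampled; the validation set stays untouched.

```python
import random
from collections import Counter


def random_oversample(rows, label_key, seed=0):
    """Duplicate minority-class rows at random until all classes are the same size."""
    rng = random.Random(seed)
    by_class = {}
    for row in rows:
        by_class.setdefault(row[label_key], []).append(row)

    target = max(len(group) for group in by_class.values())
    balanced = []
    for group in by_class.values():
        balanced.extend(group)
        # Sample with replacement to fill the gap up to the largest class.
        balanced.extend(rng.choices(group, k=target - len(group)))
    rng.shuffle(balanced)
    return balanced


# Toy training rows (assumption): 2 positives and 8 negatives.
toy_rows = [{"text": f"doc {i}", "cyber_label": int(i < 2)} for i in range(10)]
upsampled = random_oversample(toy_rows, "cyber_label")
label_counts = Counter(row["cyber_label"] for row in upsampled)
print(label_counts)
```

Since the added rows are exact copies, the model sees no new information about the minority class, only a stronger emphasis on the examples it already has.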


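The results are mixed: `cyber_label` dropped a little (positive-class F1 from 0.48 to 0.40), while `environmental_issue` improved slightly (from 0.70 to 0.74). Back to the cleaned set, let's play around with the hyperparameters: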

In [6]:
# Load data
train_df = pd.read_csv('../data/processed/clean_train.csv')
valid_df = pd.read_csv('../data/processed/clean_valid.csv')


# Transform 'clean_content' using Sentence Transformer
X_train = model.encode(train_df['clean_content'].to_list(), show_progress_bar=True)
X_valid = model.encode(valid_df['clean_content'].to_list(), show_progress_bar=True)

# Prepare labels for multilabel classification
y_train = train_df[['cyber_label', 'environmental_issue']]
y_valid = valid_df[['cyber_label', 'environmental_issue']]

# MultiOutput Classifier
multioutput_classifier = MultiOutputClassifier(
    lgb.LGBMClassifier(
        verbosity=0,
        min_data_in_leaf=30,  # same setting as the sklearn-API name min_child_samples
        class_weight='balanced',
        learning_rate=0.15,
        n_estimators=300,
    ),
    n_jobs=-1,
)
multioutput_classifier.fit(X_train, y_train)

y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="lightgbm-bge-small-cleaned-dataset-class-weight-balanced",
)

Batches: 100%|██████████| 32/32 [00:07<00:00,  4.24it/s]
Batches: 100%|██████████| 8/8 [00:01<00:00,  4.29it/s]


Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.96      0.99      0.97       235
           1       0.70      0.41      0.52        17

    accuracy                           0.95       252
   macro avg       0.83      0.70      0.75       252
weighted avg       0.94      0.95      0.94       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.92      0.95      0.94       200
           1       0.80      0.69      0.74        52

    accuracy                           0.90       252
   macro avg       0.86      0.82      0.84       252
weighted avg       0.90      0.90      0.90       252

                     precision    recall  f1-score   support

        cyber_label       0.70      0.41      0.52        17
environmental_issue       0.80      0.69      0.74        52

          micro avg       0.78      0.62      0.69        69
          macr



All reports and plots have been generated and saved successfully.
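
For context on `class_weight='balanced'` in the cell above: the documented heuristic weights each class by `n_samples / (n_classes * class_count)`, so the rarer class counts for more during training. The sketch below applies that formula to the validation-set support shown in the reports (235 vs. 17 for `cyber_label`) purely as an illustration; the actual weights depend on the training-set counts, which this notebook does not print. The helper name `balanced_class_weights` is hypothetical.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Illustrative label vector using the cyber_label support from the reports.
labels = [0] * 235 + [1] * 17
weights = balanced_class_weights(labels)
print(weights)

# After weighting, both classes contribute the same total weight.
total_weight_per_class = {cls: weights[cls] * labels.count(cls) for cls in weights}
print(total_weight_per_class)
```

This reweighting plays a similar role to upsampling, but without duplicating rows in the training file.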


In [7]:

# MultiOutput Classifier
multioutput_classifier = MultiOutputClassifier(
    lgb.LGBMClassifier(
        verbosity=0,
        min_data_in_leaf=20,
        class_weight='balanced',
        boosting_type='dart',
        num_leaves=50,
        learning_rate=0.1,
        n_estimators=400,
    ),
    n_jobs=-1,
)
multioutput_classifier.fit(X_train, y_train)

y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="lightgbm-bge-small-cleaned-dataset-class-weight-balanced-dart-400estm",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.96      1.00      0.97       235
           1       0.86      0.35      0.50        17

    accuracy                           0.95       252
   macro avg       0.91      0.67      0.74       252
weighted avg       0.95      0.95      0.94       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.91      0.95      0.93       200
           1       0.79      0.65      0.72        52

    accuracy                           0.89       252
   macro avg       0.85      0.80      0.82       252
weighted avg       0.89      0.89      0.89       252

                     precision    recall  f1-score   support

        cyber_label       0.86      0.35      0.50        17
environmental_issue       0.79      0.65      0.72        52

          micro avg       0.80      0.58      0.67        69
          macr



All reports and plots have been generated and saved successfully.
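
With four runs logged so far, it can help to line up the positive-class F1 scores side by side. The sketch below simply copies those values from the reports above into a dictionary and ranks the runs by their mean across the two labels; averaging the two labels equally is an assumption made here for illustration, not a metric the notebook itself reports.

```python
# Positive-class (label == 1) F1 per run, copied from the classification reports above.
runs = {
    "lightgbm-bge-small-default": {
        "cyber_label": 0.48, "environmental_issue": 0.70,
    },
    "lightgbm-bge-small-upsampled-dataset-default": {
        "cyber_label": 0.40, "environmental_issue": 0.74,
    },
    "lightgbm-bge-small-cleaned-dataset-class-weight-balanced": {
        "cyber_label": 0.52, "environmental_issue": 0.74,
    },
    "lightgbm-bge-small-cleaned-dataset-class-weight-balanced-dart-400estm": {
        "cyber_label": 0.50, "environmental_issue": 0.72,
    },
}


def mean_f1(scores):
    """Unweighted mean of the per-label F1 scores (illustrative choice)."""
    return sum(scores.values()) / len(scores)


# Sort run names from highest to lowest mean positive-class F1.
ranked = sorted(runs, key=lambda name: mean_f1(runs[name]), reverse=True)
for name in ranked:
    scores = runs[name]
    print(
        f"{mean_f1(scores):.2f}  "
        f"cyber={scores['cyber_label']:.2f}  "
        f"env={scores['environmental_issue']:.2f}  {name}"
    )
best_run = ranked[0]
```

On these numbers the non-DART run with `class_weight='balanced'` comes out ahead, though with only 17 positive `cyber_label` examples in the validation set the differences between runs are small.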


In [8]:

# MultiOutput Classifier
multioutput_classifier = MultiOutputClassifier(
    lgb.LGBMClassifier(
        verbosity=0,
        min_data_in_leaf=20,
        class_weight='balanced',
        boosting_type='dart',
        num_leaves=20,
        learning_rate=0.1,
        n_estimators=100,
    ),
    n_jobs=-1,
)
multioutput_classifier.fit(X_train, y_train)

y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="lightgbm-bge-small-cleaned-dataset-class-weight-balanced-dart-100estm",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.95      0.98      0.97       235
           1       0.60      0.35      0.44        17

    accuracy                           0.94       252
   macro avg       0.78      0.67      0.71       252
weighted avg       0.93      0.94      0.93       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.93      0.94      0.93       200
           1       0.76      0.71      0.73        52

    accuracy                           0.89       252
   macro avg       0.84      0.83      0.83       252
weighted avg       0.89      0.89      0.89       252

                     precision    recall  f1-score   support

        cyber_label       0.60      0.35      0.44        17
environmental_issue       0.76      0.71      0.73        52

          micro avg       0.73      0.62      0.67        69
          macr



[LightGBM] [Info] Number of positive: 1452, number of negative: 1452
[LightGBM] [Debug] Dataset::GetMultiBinFromAllFeatures: sparse rate 0.000000
[LightGBM] [Debug] init for col-wise cost 0.000009 seconds, init for row-wise cost 0.016731 seconds
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.128102 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 97555
[LightGBM] [Info] Number of data points in the train set: 2904, number of used features: 384
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 9
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 8
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth = 9
[LightGBM] [Debug] Trained a tree with leaves = 31 and depth