# Leveraging Sentence Transformers Embeddings for Multilabel Text Classification with Scikit-Learn

In this notebook, the aim is to train on embeddings from a sentence-transformers model (`BAAI/bge-small-en-v1.5`) instead of frequency-based vectorization such as TF-IDF.
The hope is that giving the SVM classifier a better encoding of the input text will eventually yield better results.


In [1]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score, classification_report

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from src.reporting.visualize import save_experiment_results




In [2]:
import os
# Turn off Hugging Face tokenizers parallelism (commonly done to avoid fork-related warnings)
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

In [3]:
# Load data
train_df = pd.read_csv('../data/processed/clean_train.csv')
valid_df = pd.read_csv('../data/processed/clean_valid.csv')

train_usmpl_df = pd.read_csv('../data/processed/clean_train_upsampled.csv')


# Initialize Sentence Transformer Model
model = SentenceTransformer('BAAI/bge-small-en-v1.5')

# Transform 'clean_content' using Sentence Transformer
X_train = model.encode(train_df['clean_content'].to_list(), show_progress_bar=True)
X_valid = model.encode(valid_df['clean_content'].to_list(), show_progress_bar=True)
X_train_usmpl = model.encode(train_usmpl_df['clean_content'].to_list(), show_progress_bar=True)


# Prepare labels for multilabel classification
y_train = train_df[['cyber_label', 'environmental_issue']]
y_valid = valid_df[['cyber_label', 'environmental_issue']]
y_train_usmpl = train_usmpl_df[['cyber_label', 'environmental_issue']]


Batches: 100%|██████████| 32/32 [00:07<00:00,  4.10it/s]
Batches: 100%|██████████| 8/8 [00:01<00:00,  4.32it/s]
Batches: 100%|██████████| 91/91 [00:20<00:00,  4.41it/s]
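
As a quick sanity check on the dataset sizes, the batch counts in the encoding output (32, 8, and 91) can be turned into row-count bounds. The sketch below assumes `encode` ran with its default `batch_size=32`, which the call above does not override; `rows_for_batches` is a helper name introduced here for illustration only.

```python
import math

BATCH_SIZE = 32  # assumption: SentenceTransformer.encode default, not overridden above


def rows_for_batches(n_batches, batch_size=BATCH_SIZE):
    """Return the (min, max) number of texts that produce n_batches batches."""
    return (n_batches - 1) * batch_size + 1, n_batches * batch_size


# Batch counts read off the progress output: train, valid, upsampled train
for name, n_batches in [("train", 32), ("valid", 8), ("train_upsampled", 91)]:
    low, high = rows_for_batches(n_batches)
    print(f"{name}: between {low} and {high} rows")

# The validation reports list a support of 252, which needs ceil(252 / 32) batches
valid_batches = math.ceil(252 / BATCH_SIZE)
```

Under that assumption, the validation set's 252 rows fit in 8 batches, matching the progress output, and the upsampled training set is a little under three times the size of the cleaned one.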


In [4]:

# MultiOutput Classifier (cleaned dataset)
multioutput_classifier = MultiOutputClassifier(SVC(probability=True, random_state=42), n_jobs=-1)
multioutput_classifier.fit(X_train, y_train)


In [5]:
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="svc-bge-small-cleaned-dataset-default-prob",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.96      1.00      0.97       235
           1       0.86      0.35      0.50        17

    accuracy                           0.95       252
   macro avg       0.91      0.67      0.74       252
weighted avg       0.95      0.95      0.94       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.89      0.97      0.93       200
           1       0.85      0.56      0.67        52

    accuracy                           0.89       252
   macro avg       0.87      0.77      0.80       252
weighted avg       0.89      0.89      0.88       252

                     precision    recall  f1-score   support

        cyber_label       0.86      0.35      0.50        17
environmental_issue       0.85      0.56      0.67        52

          micro avg       0.85      0.51      0.64        69
          macr



All reports and plots have been generated and saved successfully.
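
The micro-averaged row pools true positives and predictions across both labels before dividing, so it leans toward the label with more support. A minimal sketch of that arithmetic follows. The raw counts are not printed in the report, so the ones below are reconstructed from the rounded precision, recall, and support values and should be treated as an assumption.

```python
# Counts reconstructed from the rounded report values (assumption, not logged output):
# cyber_label: 6 of 17 positives found, 7 predicted positive
# environmental_issue: 29 of 52 positives found, 34 predicted positive
counts = {
    "cyber_label": {"tp": 6, "predicted": 7, "support": 17},
    "environmental_issue": {"tp": 29, "predicted": 34, "support": 52},
}


def micro_average(per_label):
    """Pool counts over labels, then compute precision, recall, and F1."""
    tp = sum(c["tp"] for c in per_label.values())
    predicted = sum(c["predicted"] for c in per_label.values())
    support = sum(c["support"] for c in per_label.values())
    precision = tp / predicted
    recall = tp / support
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


precision, recall, f1 = micro_average(counts)
print(f"micro avg  P={precision:.2f}  R={recall:.2f}  F1={f1:.2f}")
```

With these reconstructed counts, the pooled values round to the 0.85 / 0.51 / 0.64 shown in the micro avg row.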


In [6]:

# MultiOutput Classifier (Upsampled cleaned dataset)
multioutput_classifier = MultiOutputClassifier(SVC(probability=True, random_state=42), n_jobs=-1)
multioutput_classifier.fit(X_train_usmpl, y_train_usmpl)


# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="svc-bge-small-upsampled-dataset-default-prob",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.98      0.96      0.97       235
           1       0.57      0.71      0.63        17

    accuracy                           0.94       252
   macro avg       0.77      0.83      0.80       252
weighted avg       0.95      0.94      0.95       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.94      0.89      0.92       200
           1       0.65      0.79      0.71        52

    accuracy                           0.87       252
   macro avg       0.80      0.84      0.81       252
weighted avg       0.88      0.87      0.87       252

                     precision    recall  f1-score   support

        cyber_label       0.57      0.71      0.63        17
environmental_issue       0.65      0.79      0.71        52

          micro avg       0.63      0.77      0.69        69
          macr



All reports and plots have been generated and saved successfully.
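
Upsampling trades precision for recall on both positive classes. Because F1 is the harmonic mean of the two, it rewards the more balanced pair. The snippet below recomputes F1 from the rounded precision and recall printed above, so the last digit can differ slightly from a report computed on raw counts; `f1_score_from` is a helper name introduced here for illustration.

```python
def f1_score_from(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall) of the positive class, copied from the upsampled SVC report
upsampled = {
    "cyber_label": (0.57, 0.71),
    "environmental_issue": (0.65, 0.79),
}

f1_by_label = {label: f1_score_from(p, r) for label, (p, r) in upsampled.items()}
for label, f1 in f1_by_label.items():
    print(f"{label}: F1 = {f1:.2f}")
```

For comparison, the cleaned-dataset run had precision 0.86 and recall 0.35 on `cyber_label`, which gives an F1 of about 0.50, so the recall gain outweighs the precision loss on that label.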


In [7]:
# MultiOutput Classifier: random forest (upsampled cleaned dataset)
multioutput_classifier = MultiOutputClassifier(RandomForestClassifier(n_estimators=100, random_state=42), n_jobs=-1)
multioutput_classifier.fit(X_train_usmpl, y_train_usmpl)


# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="rf-bge-small-upsampled-dataset-default",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.94      1.00      0.97       235
           1       1.00      0.12      0.21        17

    accuracy                           0.94       252
   macro avg       0.97      0.56      0.59       252
weighted avg       0.94      0.94      0.92       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.86      0.98      0.92       200
           1       0.87      0.38      0.53        52

    accuracy                           0.86       252
   macro avg       0.86      0.68      0.73       252
weighted avg       0.86      0.86      0.84       252

                     precision    recall  f1-score   support

        cyber_label       1.00      0.12      0.21        17
environmental_issue       0.87      0.38      0.53        52

          micro avg       0.88      0.32      0.47        69
          macr



All reports and plots have been generated and saved successfully.
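
The 0.94 accuracy on `cyber_label` sits next to a positive-class recall of 0.12, which is a reminder that accuracy is dominated by the majority class here. As a rough reference point, the sketch below computes the accuracy of a hypothetical classifier that always predicts 0, using the support counts from the reports.

```python
def majority_baseline_accuracy(n_negative, n_positive):
    """Accuracy of always predicting the more frequent class."""
    return max(n_negative, n_positive) / (n_negative + n_positive)


# Validation support per label: (class 0, class 1), taken from the reports
support = {
    "cyber_label": (235, 17),
    "environmental_issue": (200, 52),
}

baseline = {
    label: majority_baseline_accuracy(neg, pos) for label, (neg, pos) in support.items()
}
for label, acc in baseline.items():
    print(f"{label}: always-0 accuracy = {acc:.2f}")
```

The random forest's 0.94 on `cyber_label` is therefore barely above the always-0 baseline of about 0.93, which is why the positive-class recall and F1 are the more informative columns.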


In [8]:
# MultiOutput Classifier: logistic regression (upsampled cleaned dataset)
multioutput_classifier = MultiOutputClassifier(LogisticRegression(random_state=42), n_jobs=-1)
multioutput_classifier.fit(X_train_usmpl, y_train_usmpl)


# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="lr-bge-small-upsampled-dataset-default",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.98      0.89      0.93       235
           1       0.32      0.71      0.44        17

    accuracy                           0.88       252
   macro avg       0.65      0.80      0.69       252
weighted avg       0.93      0.88      0.90       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.95      0.84      0.89       200
           1       0.58      0.85      0.69        52

    accuracy                           0.84       252
   macro avg       0.77      0.84      0.79       252
weighted avg       0.88      0.84      0.85       252

                     precision    recall  f1-score   support

        cyber_label       0.32      0.71      0.44        17
environmental_issue       0.58      0.85      0.69        52

          micro avg       0.50      0.81      0.62        69
          macr



All reports and plots have been generated and saved successfully.


In [13]:
# MultiOutput Classifier: SVC with C=0.5 (upsampled cleaned dataset)
multioutput_classifier = MultiOutputClassifier(SVC(C=0.5, random_state=42), n_jobs=-1)
multioutput_classifier.fit(X_train_usmpl, y_train_usmpl)


# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="svc-bge-small-upsampled-dataset-c0.5",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.98      0.93      0.95       235
           1       0.43      0.71      0.53        17

    accuracy                           0.92       252
   macro avg       0.70      0.82      0.74       252
weighted avg       0.94      0.92      0.93       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.94      0.86      0.90       200
           1       0.60      0.79      0.68        52

    accuracy                           0.85       252
   macro avg       0.77      0.83      0.79       252
weighted avg       0.87      0.85      0.86       252

                     precision    recall  f1-score   support

        cyber_label       0.43      0.71      0.53        17
environmental_issue       0.60      0.79      0.68        52

          micro avg       0.55      0.77      0.64        69
          macr



All reports and plots have been generated and saved successfully.


In [14]:
# MultiOutput Classifier: SVC with C=0.8 (upsampled cleaned dataset)
multioutput_classifier = MultiOutputClassifier(SVC(C=0.8, random_state=42), n_jobs=-1)
multioutput_classifier.fit(X_train_usmpl, y_train_usmpl)


# Prediction and evaluation
y_pred = multioutput_classifier.predict(X_valid)

save_experiment_results(
    y_true_valid=y_valid,
    y_pred_valid=y_pred,
    label_names=y_train.columns,
    experiment_name="svc-bge-small-upsampled-dataset-c0.8",
)

Classification Report for cyber_label:
              precision    recall  f1-score   support

           0       0.98      0.96      0.97       235
           1       0.55      0.71      0.62        17

    accuracy                           0.94       252
   macro avg       0.76      0.83      0.79       252
weighted avg       0.95      0.94      0.94       252

Classification Report for environmental_issue:
              precision    recall  f1-score   support

           0       0.94      0.88      0.91       200
           1       0.62      0.79      0.69        52

    accuracy                           0.86       252
   macro avg       0.78      0.83      0.80       252
weighted avg       0.87      0.86      0.86       252

                     precision    recall  f1-score   support

        cyber_label       0.55      0.71      0.62        17
environmental_issue       0.62      0.79      0.69        52

          micro avg       0.60      0.77      0.68        69
          macr



All reports and plots have been generated and saved successfully.
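
To compare the six runs at a glance, one option is to rank them by the mean positive-class F1 across the two labels. This is only one possible summary (it weights both labels equally, unlike the micro average), and the values are copied from the rounded reports above.

```python
# Positive-class F1 per run: (cyber_label, environmental_issue), from the reports
positive_f1 = {
    "svc-bge-small-cleaned-dataset-default-prob": (0.50, 0.67),
    "svc-bge-small-upsampled-dataset-default-prob": (0.63, 0.71),
    "rf-bge-small-upsampled-dataset-default": (0.21, 0.53),
    "lr-bge-small-upsampled-dataset-default": (0.44, 0.69),
    "svc-bge-small-upsampled-dataset-c0.5": (0.53, 0.68),
    "svc-bge-small-upsampled-dataset-c0.8": (0.62, 0.69),
}

# Unweighted mean over the two labels, highest first
ranking = sorted(
    ((sum(scores) / len(scores), name) for name, scores in positive_f1.items()),
    reverse=True,
)
for mean_f1, name in ranking:
    print(f"{mean_f1:.3f}  {name}")

best_name = ranking[0][1]
```

By this measure the upsampled SVC with default settings comes out on top, with the `C=0.8` variant close behind and the random forest last.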
