Notes, 

* Batch size 8 used by all models for consistency sake.

In [1]:
# disable persistant warning shown by tokenizers
%set_env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [2]:
import gc
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
)
from sklearn.multiclass import OneVsRestClassifier

  from tqdm.autonotebook import tqdm, trange
2024-05-31 03:16:59.244781: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
SEED = 42
BATCH_SIZE = 8  # adjust depends on GPU's memory capacity
MODELS = [
    'jinaai/jina-embeddings-v2-small-en',
    'jinaai/jina-embeddings-v2-base-en',
    'nomic-ai/nomic-embed-text-v1.5',
    'Alibaba-NLP/gte-large-en-v1.5',
    'sentence-transformers/all-mpnet-base-v2',
    'mixedbread-ai/mxbai-embed-large-v1',
    'WhereIsAI/UAE-Large-V1'
]

In [4]:
df_train = pd.read_csv('../../dataset/v1/train.csv')
df_test = pd.read_csv('../../dataset/v1/test.csv')

labels = df_train.columns[3:].to_list()
y_train = df_train[labels].to_numpy()
y_test = df_test[labels].to_numpy()

In [5]:
def evaluate(X_train, y_train, X_test, y_test, labels):
    clf = LogisticRegression(
        random_state=SEED,
        max_iter=100
    )
    ovr = OneVsRestClassifier(clf, n_jobs=-1)

    ovr.fit(X_train, y_train)
    y_pred = ovr.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    print(f'Overall accuracy: {accuracy}')
    for idx, label in enumerate(labels):
        label_accuracy = accuracy_score(y_test[:, idx], y_pred[:, idx])
        print(f'Accuracy {label}: {label_accuracy}')

    f1 = f1_score(y_test, y_pred, average='macro')
    print(f'F1 macro: {f1}')
    print(
        classification_report(y_test, y_pred, target_names=labels, digits=4, zero_division=0)
    )

In [6]:
def run(model_name: str):
    print(f'Model name: {model_name}')

    model = SentenceTransformer(
        model_name, trust_remote_code=True
    )
    X_train = np.asarray(
        model.encode(df_train['cleaned_review'].to_list(), batch_size=BATCH_SIZE)
    )
    X_test = np.asarray(
        model.encode(df_test['cleaned_review'].to_list(), batch_size=BATCH_SIZE)
    )

    evaluate(X_train, y_train, X_test, y_test, labels)
    # help prevent GPU OOM
    del model, X_train, X_test
    gc.collect()
    torch.cuda.empty_cache()

    print('='*50)

In [7]:
for model in MODELS:
    run(model)

Model name: jinaai/jina-embeddings-v2-small-en




Overall accuracy: 0.18625
Accuracy label_recommended: 0.845
Accuracy label_story: 0.78
Accuracy label_gameplay: 0.825
Accuracy label_visual: 0.74
Accuracy label_audio: 0.815
Accuracy label_technical: 0.83
Accuracy label_price: 0.79
Accuracy label_suggestion: 0.885
F1 macro: 0.6436656125334517
                   precision    recall  f1-score   support

label_recommended     0.8589    0.9459    0.9003       148
      label_story     0.7711    0.7191    0.7442        89
   label_gameplay     0.8742    0.9026    0.8882       154
     label_visual     0.7108    0.6782    0.6941        87
      label_audio     0.6944    0.4902    0.5747        51
  label_technical     0.7447    0.6140    0.6731        57
      label_price     0.6000    0.3191    0.4167        47
 label_suggestion     0.4000    0.1905    0.2581        21

        micro avg     0.7937    0.7355    0.7635       654
        macro avg     0.7068    0.6075    0.6437       654
     weighted avg     0.7747    0.7355    0.7482       

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Overall accuracy: 0.176875
Accuracy label_recommended: 0.88
Accuracy label_story: 0.805
Accuracy label_gameplay: 0.815
Accuracy label_visual: 0.77
Accuracy label_audio: 0.805
Accuracy label_technical: 0.855
Accuracy label_price: 0.795
Accuracy label_suggestion: 0.86
F1 macro: 0.6618134122186495
                   precision    recall  f1-score   support

label_recommended     0.8924    0.9527    0.9216       148
      label_story     0.8049    0.7416    0.7719        89
   label_gameplay     0.8726    0.8896    0.8810       154
     label_visual     0.7303    0.7471    0.7386        87
      label_audio     0.6111    0.6471    0.6286        51
  label_technical     0.7692    0.7018    0.7339        57
      label_price     0.5882    0.4255    0.4938        47
 label_suggestion     0.1818    0.0952    0.1250        21

        micro avg     0.7912    0.7706    0.7808       654
        macro avg     0.6813    0.6501    0.6618       654
     weighted avg     0.7769    0.7706    0.7718     

<All keys matched successfully>
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the document

Overall accuracy: 0.175625
Accuracy label_recommended: 0.905
Accuracy label_story: 0.78
Accuracy label_gameplay: 0.83
Accuracy label_visual: 0.755
Accuracy label_audio: 0.84
Accuracy label_technical: 0.8
Accuracy label_price: 0.84
Accuracy label_suggestion: 0.845
F1 macro: 0.6755646403834055
                   precision    recall  f1-score   support

label_recommended     0.9448    0.9257    0.9352       148
      label_story     0.7922    0.6854    0.7349        89
   label_gameplay     0.8947    0.8831    0.8889       154
     label_visual     0.7317    0.6897    0.7101        87
      label_audio     0.7209    0.6078    0.6596        51
  label_technical     0.6393    0.6842    0.6610        57
      label_price     0.7143    0.5319    0.6098        47
 label_suggestion     0.2222    0.1905    0.2051        21

        micro avg     0.8042    0.7538    0.7782       654
        macro avg     0.7075    0.6498    0.6756       654
     weighted avg     0.8001    0.7538    0.7749       6

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Overall accuracy: 0.1825
Accuracy label_recommended: 0.91
Accuracy label_story: 0.75
Accuracy label_gameplay: 0.83
Accuracy label_visual: 0.735
Accuracy label_audio: 0.82
Accuracy label_technical: 0.805
Accuracy label_price: 0.815
Accuracy label_suggestion: 0.875
F1 macro: 0.672943278595528
                   precision    recall  f1-score   support

label_recommended     0.9221    0.9595    0.9404       148
      label_story     0.7294    0.6966    0.7126        89
   label_gameplay     0.8896    0.8896    0.8896       154
     label_visual     0.7125    0.6552    0.6826        87
      label_audio     0.6744    0.5686    0.6170        51
  label_technical     0.6731    0.6140    0.6422        57
      label_price     0.6250    0.5319    0.5747        47
 label_suggestion     0.3750    0.2857    0.3243        21

        micro avg     0.7901    0.7538    0.7715       654
        macro avg     0.7001    0.6501    0.6729       654
     weighted avg     0.7804    0.7538    0.7659       65

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Overall accuracy: 0.185
Accuracy label_recommended: 0.815
Accuracy label_story: 0.81
Accuracy label_gameplay: 0.845
Accuracy label_visual: 0.715
Accuracy label_audio: 0.81
Accuracy label_technical: 0.81
Accuracy label_price: 0.82
Accuracy label_suggestion: 0.895
F1 macro: 0.5853242860338892
                   precision    recall  f1-score   support

label_recommended     0.8136    0.9730    0.8862       148
      label_story     0.8312    0.7191    0.7711        89
   label_gameplay     0.8436    0.9805    0.9069       154
     label_visual     0.6923    0.6207    0.6545        87
      label_audio     0.8824    0.2941    0.4412        51
  label_technical     0.8065    0.4386    0.5682        57
      label_price     0.7895    0.3191    0.4545        47
 label_suggestion     0.0000    0.0000    0.0000        21

        micro avg     0.8097    0.7156    0.7597       654
        macro avg     0.7074    0.5431    0.5853       654
     weighted avg     0.7838    0.7156    0.7227       65

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Overall accuracy: 0.16375
Accuracy label_recommended: 0.925
Accuracy label_story: 0.815
Accuracy label_gameplay: 0.84
Accuracy label_visual: 0.76
Accuracy label_audio: 0.835
Accuracy label_technical: 0.84
Accuracy label_price: 0.805
Accuracy label_suggestion: 0.87
F1 macro: 0.6989039220337027
                   precision    recall  f1-score   support

label_recommended     0.9404    0.9595    0.9498       148
      label_story     0.7955    0.7865    0.7910        89
   label_gameplay     0.8720    0.9286    0.8994       154
     label_visual     0.7407    0.6897    0.7143        87
      label_audio     0.7250    0.5686    0.6374        51
  label_technical     0.7193    0.7193    0.7193        57
      label_price     0.6111    0.4681    0.5301        47
 label_suggestion     0.3684    0.3333    0.3500        21

        micro avg     0.8082    0.7859    0.7969       654
        macro avg     0.7215    0.6817    0.6989       654
     weighted avg     0.7999    0.7859    0.7911       

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Overall accuracy: 0.166875
Accuracy label_recommended: 0.905
Accuracy label_story: 0.795
Accuracy label_gameplay: 0.845
Accuracy label_visual: 0.75
Accuracy label_audio: 0.83
Accuracy label_technical: 0.86
Accuracy label_price: 0.805
Accuracy label_suggestion: 0.875
F1 macro: 0.6945915681273175
                   precision    recall  f1-score   support

label_recommended     0.9216    0.9527    0.9369       148
      label_story     0.7727    0.7640    0.7684        89
   label_gameplay     0.8773    0.9286    0.9022       154
     label_visual     0.7467    0.6437    0.6914        87
      label_audio     0.7073    0.5686    0.6304        51
  label_technical     0.7636    0.7368    0.7500        57
      label_price     0.6176    0.4468    0.5185        47
 label_suggestion     0.3889    0.3333    0.3590        21

        micro avg     0.8086    0.7752    0.7916       654
        macro avg     0.7245    0.6718    0.6946       654
     weighted avg     0.7982    0.7752    0.7843     