In [16]:
from tqdm import tqdm

import pandas as pd

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate

## Load the data (the CSV files under `data/` are gitignored, so they are not part of the repository)

In [51]:
# Article text and article metadata are stored in separate CSV files. The
# metadata file names its article-id column "ID на статия"; it is renamed to
# "ID" here (presumably to line up with the id column of the text file -- confirm).
text_df = pd.read_csv("data/content_only.csv")
metadata_df = pd.read_csv("data/metadata_only.csv").rename(columns={"ID на статия": "ID"})

# The left merge has no explicit `on=`, so pandas joins on every column name the
# two frames share. Rows repeating the same (ID, Text) pair are then collapsed.
merged = text_df.merge(metadata_df, how='left')
full_data = merged.drop_duplicates(subset=['ID', 'Text'])  # note duplicate entries per id

# Several target definitions are possible; this notebook uses the root category
# ("rootnode нс основна рубрика") restricted to its most frequent values.
text = full_data["Text"].tolist()

top_n = 6
label_column = full_data["rootnode нс основна рубрика"]

# The top_n most frequent codes keep their own label; every other row
# (including rows with a missing code) is folded into the catch-all label 0.
# NOTE(review): assumes 0 is not itself a real category code -- confirm.
top_codes = label_column.value_counts().nlargest(top_n).index
y = label_column.where(label_column.isin(top_codes), 0)

In [52]:
# Candidate multilingual embedding models, benchmarked in the order listed.
embedding_models = [
    "sentence-transformers/static-similarity-mrl-multilingual-v1",
    "intfloat/multilingual-e5-small",  # inputs get a "passage: " prefix (added in the loop)
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    # very slow
    # "Snowflake/snowflake-arctic-embed-m-v2.0"
]

# The metric list does not depend on the model, so it is built once up front.
scoring = ["accuracy", "balanced_accuracy", "f1_weighted"]

# Per-model cross-validation frames are collected here and concatenated once
# after the loop; concatenating inside the loop re-copies all accumulated rows
# on every iteration. If the loop is interrupted, the finished models are still
# available in this list.
fold_reports = []

for model_name in tqdm(embedding_models):
    # NOTE(review): trust_remote_code=True lets the model repository run its own
    # Python code on load. Only the commented-out Snowflake model looks like it
    # could need it -- confirm and drop the flag if the active models do not.
    embedder = SentenceTransformer(model_name, trust_remote_code=True)
    if "static" not in model_name:
        # Models without "static" in their name get an explicit 512-token limit.
        embedder.max_seq_length = 512
    if model_name == "intfloat/multilingual-e5-small":
        # E5 models expect a task prefix on every input.
        # NOTE(review): the E5 model card appears to recommend "query: " when the
        # embeddings are used as classification features -- confirm before
        # changing, since it alters the reported scores.
        X = embedder.encode(["passage: " + t for t in text],convert_to_numpy=True)
    else:
        X = embedder.encode(text,convert_to_numpy=True)
    print(f"Completed embeddings with {model_name}")

    # Class-balanced logistic regression on the frozen embeddings, scored with
    # 10x repeated stratified 10-fold CV (100 fits per embedding model).
    cls = LogisticRegression(random_state=0, class_weight='balanced', max_iter=1000)
    cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=10, random_state=42)
    cv_results = pd.DataFrame(
        cross_validate(
            cls,
            X,
            y,
            scoring=scoring,
            cv=cv,
            return_train_score=True,
            return_estimator=True,
        )
    ).assign(model_name = model_name)

    fold_reports.append(cv_results)
    print(f"Done with {model_name}")

# One row per CV fit, tagged with the model that produced it. An empty model
# list yields an empty frame, matching what the incremental concat produced.
report = pd.concat(fold_reports) if fold_reports else pd.DataFrame()

# Score columns to summarise: a train/test pair for each metric, in display order
# (train_accuracy, test_accuracy, train_balanced_accuracy, ...).
cv_scores = [
    f"{split}_{metric}"
    for metric in ("accuracy", "balanced_accuracy", "f1_weighted")
    for split in ("train", "test")
]

# Mean and standard deviation of every score per embedding model, transposed so
# that models become columns and (score, statistic) pairs become rows.
scores_by_model = report[[*cv_scores, "model_name"]].groupby("model_name")
summary = scores_by_model.agg(["mean", "std"]).transpose()

# Define a function to apply styling
def highlight_test_rows(row):
    """Return per-cell CSS that renders test-metric rows in bold.

    Meant to be passed to ``Styler.apply(..., axis=1)``, which calls it once
    per row with that row as a Series.

    Parameters
    ----------
    row : pandas.Series
        One table row. ``row.name`` is the row label: a tuple whose first
        element is the metric name when the index is a MultiIndex, or the
        metric name itself when the index is flat.

    Returns
    -------
    list of str
        One CSS string per cell: ``'font-weight: bold'`` for every cell when
        the metric name contains ``'test'``, otherwise empty strings.
    """
    label = row.name
    # Indexing a plain string label with [0] would yield only its first
    # character, so unwrap the label only when it really is a tuple.
    metric = label[0] if isinstance(label, tuple) else label
    if 'test' in str(metric):
        return ['font-weight: bold'] * len(row)
    return [''] * len(row)

# Bold the test-metric rows first, then print every number with 3 decimals.
bold_test_rows = summary.style.apply(highlight_test_rows, axis=1)
styled_summary = bold_test_rows.format(precision=3)

# Last expression of the cell, so the notebook renders the styled table.
styled_summary

  0%|                                                                                                                                       | 0/4 [00:00<?, ?it/s]

Complete embeddings with sentence-transformers/static-similarity-mrl-multilingual-v1


 25%|███████████████████████████████▌                                                                                              | 1/4 [05:26<16:18, 326.27s/it]

Done with sentence-transformers/static-similarity-mrl-multilingual-v1
Complete embeddings with intfloat/multilingual-e5-small


 50%|███████████████████████████████████████████████████████████████                                                               | 2/4 [18:09<19:26, 583.11s/it]

Done with intfloat/multilingual-e5-small
Complete embeddings with sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2


 75%|██████████████████████████████████████████████████████████████████████████████████████████████▌                               | 3/4 [30:47<11:03, 663.13s/it]

Done with sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
Complete embeddings with sentence-transformers/paraphrase-multilingual-mpnet-base-v2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [1:08:05<00:00, 1021.27s/it]

Done with sentence-transformers/paraphrase-multilingual-mpnet-base-v2





Unnamed: 0,model_name,intfloat/multilingual-e5-small,sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2,sentence-transformers/paraphrase-multilingual-mpnet-base-v2,sentence-transformers/static-similarity-mrl-multilingual-v1
train_accuracy,mean,0.729,0.754,0.75,0.99
train_accuracy,std,0.005,0.004,0.005,0.002
test_accuracy,mean,0.694,0.702,0.704,0.76
test_accuracy,std,0.032,0.029,0.035,0.026
train_balanced_accuracy,mean,0.875,0.898,0.886,0.997
train_balanced_accuracy,std,0.004,0.003,0.004,0.001
test_balanced_accuracy,mean,0.781,0.765,0.762,0.716
test_balanced_accuracy,std,0.05,0.052,0.054,0.054
train_f1_weighted,mean,0.743,0.767,0.76,0.99
train_f1_weighted,std,0.005,0.004,0.005,0.002
