In [7]:
import pandas as pd
import mlflow
from sklearn.model_selection import train_test_split, ParameterGrid
import torch
from catboost import CatBoostClassifier
from sklearn.metrics import classification_report, f1_score, accuracy_score

from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments, default_data_collator, DebertaV2Tokenizer, PegasusForConditionalGeneration, PegasusTokenizer
from tqdm.notebook import tqdm

In [2]:
train = pd.read_csv("data/data_generated/train_with_embeddings.csv")
test = pd.read_csv("data/data_generated/test_with_embeddings.csv")

#train = pd.read_csv("data/data_generated/fewshot_embed_train.csv") + tf
#test = pd.read_csv("data/data_generated/fewshot_embed_test.csv") + tf

In [3]:
train.head(1)

Unnamed: 0,Исполнитель,Группа тем,Текст инцидента,Тема,0,1,2,3,4,5,...,nli_Строительство и архитектура,nli_Экономика и бизнес,nli_Физическая культура и спорт,nli_Связь и телевидение,nli_Газ и топливо,nli_Государственная собственность,nli_Торговля,nli_Памятники и объекты культурного наследия,nli_Погребение и похоронное дело,nli_Мобилизация
0,Город Пермь,Погребение и похоронное дело,Погребения - это серьезная проблема в нашей ст...,Погребение и похоронное дело,-0.294947,-0.055111,-0.917189,0.076128,-0.168686,0.458902,...,0.021046,0.004908,0.044621,0.049851,0.026472,0.016191,0.014407,0.049282,0.13109,0.009004


In [4]:
param_grid = {
    'iterations': [100, 250, 500]
}
grid = list(ParameterGrid(param_grid))

In [5]:
mlflow.set_experiment('embedding_mean')

<Experiment: artifact_location='file:///workspace/mlruns/548987682160929580', creation_time=1700788368162, experiment_id='548987682160929580', last_update_time=1700788368162, lifecycle_stage='active', name='embedding_mean', tags={}>

In [6]:

model_name, model_length = ('sberbank-ai/sbert_large_nlu_ru', 512)
    
for params in grid:
    try:
        with mlflow.start_run(nested=True):
            catboost_model = CatBoostClassifier(**params, verbose=0, random_seed=42)
            catboost_model.fit(train.drop(columns=["Текст инцидента", "Группа тем", "Исполнитель", "Тема"]), train["Группа тем"])

            predictions = catboost_model.predict(test.drop(columns=["Текст инцидента", "Группа тем", "Исполнитель", "Тема"]))
            
            accuracy = accuracy_score(test["Группа тем"], predictions)
            f1 = f1_score(test["Группа тем"], predictions, average='weighted') 
            report = classification_report(test["Группа тем"], predictions, output_dict=True)
            report_text =  classification_report(test["Группа тем"], predictions)

            
            mlflow.log_metric("report_accuracy", report['accuracy'])
            mlflow.log_metric("macro avg_precision", report['macro avg']['precision'])
            mlflow.log_metric("macro avg_recall", report['macro avg']['recall'])
            mlflow.log_metric("macro avg_f1-score", report['macro avg']['f1-score'])
            mlflow.log_metric("weighted avg_precision", report['weighted avg']['precision'])
            mlflow.log_metric("weighted avg_recall", report['weighted avg']['recall'])
            mlflow.log_metric("weighted avg_f1-score", report['weighted avg']['f1-score'])

            mlflow.log_text(report_text, "classification_report.txt")
            
            mlflow.set_tag("embedding_name", model_name)
            mlflow.set_tag("embedding_size", model_length)
            mlflow.set_tag("embedding_type", "mean")
            mlflow.set_tag("dataset_name", "clear_v2_generated_few_shot_tf_idf")
            mlflow.set_tag("model_name", "catboost")
            mlflow.catboost.log_model(catboost_model, "model")
            mlflow.log_params(params)
            mlflow.log_metrics({'accuracy': accuracy, 'f1-weighted': f1})
    except Exception as e:
        error_name = type(e).__name__
        print(f"Caught an error: {error_name}")

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
