In [None]:
# !pip install setfit

In [2]:
import json
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit

In [3]:
# Target column for this run; swap in one of the commented alternatives to
# train a classifier for a different label.
train_col = "Исполнитель"
# train_col = "Группа тем"
# train_col = "Тема"

train = pd.read_csv("prime99_train.csv")
train.reset_index(drop=True, inplace=True)

# Label <-> integer id lookups, built from the full training file (before the
# validation split) with ids assigned in sorted label order.
theme_groups = train[train_col].unique()
theme_groups_dict = {label: idx for idx, label in enumerate(sorted(theme_groups))}
reverse_groups = {idx: label for label, idx in theme_groups_dict.items()}

# Single 90/10 shuffle split, stratified on "Тема" regardless of `train_col`.
k = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_index, val_index = next(k.split(train, train["Тема"]))
# Both selections are taken from the unsplit frame before `train` is rebound.
val, train = train.loc[val_index], train.loc[train_index]
# train, val = train_test_split(train, random_state=42, test_size=0.1)
test = pd.read_csv("prime99_test.csv")

# Persist the label -> id mapping, one file per target column.
with open(f"setfit_classes_{train_col}.json", "w") as f:
    json.dump(theme_groups_dict, f)

In [4]:
# Show the GPU's current status and memory usage.
!nvidia-smi

Fri Nov 24 22:21:53 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 530.41.03    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-SXM2-16GB            On | 00000000:AF:00.0 Off |                    0 |
| N/A   61C    P0               51W / 300W|   4948MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [5]:
# Preview the training split (the rows left after the validation split).
train

Unnamed: 0.1,Unnamed: 0,Исполнитель,Группа тем,Текст инцидента,Тема
3401,11589,Министерство здравоохранения,Здравоохранение/Медицина,'Запись к врачу открывается каждый день в 21.0...,Технические проблемы с записью на прием к врачу
4466,17754,Министерство здравоохранения,Здравоохранение/Медицина,"'[id590307429|Городская-Больница], тема кардио...",Диспансеризация
10695,22323,Город Пермь,ЖКХ,'Аналогичный вопрос. В месяц до 10 раз отключа...,Ненадлежащее качество или отсутствие отопления
16072,17594,Лысьвенский городской округ,Общественный транспорт,"'Здравствуйте, я по поводу расписания маршруто...",График движения общественного транспорта
1232,8088,Министерство социального развития ПК,Социальное обслуживание и защита,'#ПособияИВыплаты@mothers_of_perm<br>Уважаемые...,Оказание гос. соц. помощи
...,...,...,...,...,...
89,10515,ИГЖН ПК,ЖКХ,"'Пермь, Уссурийская улица, 19А<br>кошмарное со...",Ремонт подъездов
15111,21743,Бардымский муниципальный округ Пермского края,Спецпроекты,'Добрый день!<br>У нас вышел материал о выплат...,Спецпроекты
1630,18917,Александровский муниципальный округ Пермского ...,Связь и телевидение,'Добрый день. В Яйве где место сбора? https://...,★ Информационно-техническая поддержка
14913,783,Министерство социального развития ПК,Социальное обслуживание и защита,"'Здравствуйте, хочу оформить выплату с 3 до 7,...",Оказание гос. соц. помощи


In [6]:
from datasets import Dataset

def _frame_to_dataset(frame):
    """Build a `Dataset` with `text` and integer `label` columns from one split.

    Texts come from the "Текст инцидента" column; labels are the `train_col`
    values encoded through `theme_groups_dict`.
    """
    texts = list(frame["Текст инцидента"].values)
    encoded = frame[train_col].apply(lambda name: theme_groups_dict[name])
    return Dataset.from_dict({"text": texts, "label": list(encoded.values)})


train_dataset = _frame_to_dataset(train)
eval_dataset = _frame_to_dataset(val)

In [None]:
from datasets import Dataset, load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset

# Sentence-transformer body plus a differentiable classification head sized to
# the number of distinct labels found in `train_col`.
model = SetFitModel.from_pretrained(
    # "cointegrated/rubert-tiny2",
    "cointegrated/LaBSE-en-ru",
    use_differentiable_head=True,
    # multi_target_strategy="multi-output",
    head_params={"out_features": len(theme_groups)},
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="f1",
    # The task is multiclass, so F1 needs an explicit averaging mode: the
    # metric's default ("binary") is rejected for multiclass targets, which
    # would make trainer.evaluate() fail only after all training has finished.
    # "weighted" matches the f1_score computed on the test set later on.
    metric_kwargs={"average": "weighted"},
    batch_size=12,
    num_iterations=10,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    # column_mapping={"text": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
)

# Phase 1: contrastive fine-tuning of the body while the head stays frozen.
trainer.freeze()  # Freeze the head
trainer.train()  # Train only the body

# Phase 2: unfreeze the head and keep the body frozen -> head-only training.
# Alternative (not used here): trainer.unfreeze(keep_body_frozen=False)
# unfreezes both head and body for end-to-end training.
trainer.unfreeze(keep_body_frozen=True)

trainer.train(
    num_epochs=25,  # The number of epochs to train the head or the whole model (body and head)
    batch_size=4,
    body_learning_rate=1e-5,  # The body's learning rate
    learning_rate=1e-2,  # The head's learning rate
    l2_weight=0.0,  # Weight decay on **both** the body and head. If `None`, will use 0.01.
)
# Score the model on the validation split with the metric configured above.
metrics = trainer.evaluate()

# Run inference
# Smoke test on two sample sentences; `preds` is overwritten by the test-set
# predictions later in the notebook.
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])

2023-11-24 22:21:56.261838: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-24 22:21:56.261889: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-24 22:21:56.263009: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-24 22:21:56.269716: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
model_head.pkl not found on HuggingFace Hub, 

Generating Training Pairs:   0%|          | 0/10 [00:00<?, ?it/s]

***** Running training *****
  Num examples = 374660
  Num epochs = 1
  Total optimization steps = 31222
  Total train batch size = 12


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/31222 [00:00<?, ?it/s]

In [None]:
# List the files in the current working directory.
ls

In [None]:
# Predict a class id for every test text, then translate ids back to labels.
test_texts = list(test["Текст инцидента"].values)
preds = list(model(test_texts).cpu().numpy())
pred_tags = list(map(reverse_groups.__getitem__, preds))

In [None]:
from sklearn.metrics import f1_score
# Weighted F1 of the predicted labels against the test set's true labels.
y_true = list(test[train_col].values)
f1_score(y_true, pred_tags, average="weighted")

In [None]:
from sklearn.metrics import classification_report

In [None]:
# Per-class precision/recall/F1 as a nested dict rather than a printed table.
report = classification_report(
    y_true=list(test[train_col].values),
    y_pred=pred_tags,
    output_dict=True,
)

In [None]:
# Display the classification report dictionary.
report

In [None]:
# Save the report in the working directory, one file per target column.
report_path = f"classification_report_{train_col}.json"
with open(report_path, "w") as f:
    f.write(json.dumps(report))

In [None]:
# import huggingface_hub
# huggingface_hub.login()

In [None]:
# import huggingface_hub
# huggingface_hub.login()
# Upload the trained model to the Hugging Face Hub.
# NOTE(review): the active target repo is the "theme" one, while `train_col`
# is set to "Исполнитель" earlier in the notebook and the commented-out line
# below names a "performer" repo. Confirm the repo matches the label column
# before running, since a push cannot be undone from here.
trainer.push_to_hub("denis-gordeev/citizen-request-theme-labse")
# trainer.push_to_hub("denis-gordeev/citizen-request-performer")

In [None]:
# grouped = train.groupby("Тема")["Группа тем"].apply(set).apply(list)
# theme_to_group = dict(zip(*[grouped.index, grouped.values]))

# with open("theme_to_group.json", "w") as f:
#     json.dump(theme_to_group, f)

# grouped = train.groupby("Группа тем")["Тема"].apply(set).apply(list)
# group_to_themes = dict(zip(*[grouped.index, grouped.values]))

# with open("group_to_themes.json", "w") as f:
#     json.dump(group_to_themes, f)

In [None]:
# group_to_themes

In [None]:
import numpy as np

In [None]:
# Probe the classifier with a single keyword and show the most probable label.
probe_probs = trainer.model.predict_proba(["Медицина"]).cpu().detach().numpy()
reverse_groups[np.argmax(probe_probs)]

In [None]:
# import os
# HF_TOKEN = os.environ.get("HF_TOKEN")
# model0 = SetFitModel.from_pretrained("denis-gordeev/citizen-request-theme-group", use_auth_token=HF_TOKEN)

In [None]:
# List the files in the current working directory.
ls