In [1]:
import flair
flair.set_seed(2)

from flair.data import Corpus, Sentence
from flair.datasets import TREC_6, CSVClassificationCorpus
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

import torch
import argparse
import json
import csv
import re
import time

import pandas as pd

In [2]:
# Run configuration: the pretrained checkpoint to fine-tune and the
# optimisation settings handed to the trainer further down.
model_name = "sergeyzh/rubert-tiny-turbo"
max_epochs = 8
mini_batch_size = 16
learning_rate = 5e-5

In [3]:

        
# Column layout of the split files: column 0 holds the text,
# column 1 holds the class label.
column_name_map = dict(enumerate(("text", "label")))

# Options for the CSV corpus reader, kept in one place so the label type
# is spelled exactly once.
corpus_options = {
    "label_type": "label",
    "skip_header": False,
    "delimiter": "\t",  # tab-separated files
}

# Load the train/dev/test splits found in the "flair_data" folder.
corpus: Corpus = CSVClassificationCorpus("flair_data", column_name_map, **corpus_options)

# Collect the set of class labels for the classifier's output layer.
label_dict = corpus.make_label_dictionary(label_type=corpus_options["label_type"])

2024-11-16 23:02:49,738 Reading data from flair_data
2024-11-16 23:02:49,746 Train: flair_data/train.csv
2024-11-16 23:02:49,751 Dev: flair_data/dev.csv
2024-11-16 23:02:49,755 Test: flair_data/test.csv
2024-11-16 23:02:50,057 Computing label dictionary. Progress:


0it [00:00, ?it/s]
4068it [00:33, 121.38it/s]

2024-11-16 23:03:23,953 Dictionary created for label 'label' with 9 values: FOOD_GOODS (seen 922 times), NON_FOOD_GOODS (seen 895 times), SERVICE (seen 884 times), LEASING (seen 380 times), LOAN (seen 380 times), REALE_STATE (seen 256 times), BANK_SERVICE (seen 215 times), NOT_CLASSIFIED (seen 117 times), TAX (seen 19 times)





In [4]:
# Transformer encoder producing one embedding per document; fine_tune=True
# is passed so the encoder is trained together with the classifier.
document_embeddings = TransformerDocumentEmbeddings(model_name, fine_tune=True)

# Classification model over the labels collected in `label_dict`.
classifier = TextClassifier(
    document_embeddings,
    label_type='label',
    label_dictionary=label_dict,
)

trainer = ModelTrainer(classifier, corpus)

# Settings forwarded to the fine-tuning run; values come from the
# configuration cell.
fine_tune_options = dict(
    learning_rate=learning_rate,
    mini_batch_size=mini_batch_size,
    max_epochs=max_epochs,
    monitor_test=True,
    train_with_dev=True,
    save_model_each_k_epochs=1,
)

# Each run gets its own output directory, suffixed with the current Unix
# time in whole seconds.
run_dir = f'./models/rubert_{int(time.time())}'

# Kept as the last expression so the notebook displays the returned result.
trainer.fine_tune(run_dir, **fine_tune_options)

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.41M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/732 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/712 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

2024-11-16 23:04:58,173 ----------------------------------------------------------------------------------------------------
2024-11-16 23:04:58,247 Model: "TextClassifier(
  (embeddings): TransformerDocumentEmbeddings(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(83829, 312, padding_idx=0)
        (position_embeddings): Embedding(2048, 312)
        (token_type_embeddings): Embedding(2, 312)
        (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-2): 3 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=312, out_features=312, bias=True)
                (key): Linear(in_features=312, out_features=312, bias=True)
                (value): Linear(in_features=312, out_features=312, bias=True)
             

100%|██████████| 32/32 [00:00<00:00, 60.11it/s]

2024-11-16 23:06:20,268 TEST : loss 0.1925896853208542 - f1-score (micro avg)  0.946





2024-11-16 23:06:20,435 ----------------------------------------------------------------------------------------------------
2024-11-16 23:06:21,270 epoch 2 - iter 26/266 - loss 0.05012252 - time (sec): 0.83 - samples/sec: 499.38 - lr: 0.000048 - momentum: 0.000000
2024-11-16 23:06:22,098 epoch 2 - iter 52/266 - loss 0.06440232 - time (sec): 1.66 - samples/sec: 501.06 - lr: 0.000047 - momentum: 0.000000
2024-11-16 23:06:22,808 epoch 2 - iter 78/266 - loss 0.06861015 - time (sec): 2.37 - samples/sec: 526.44 - lr: 0.000047 - momentum: 0.000000
2024-11-16 23:06:23,401 epoch 2 - iter 104/266 - loss 0.07366403 - time (sec): 2.96 - samples/sec: 561.39 - lr: 0.000046 - momentum: 0.000000
2024-11-16 23:06:24,084 epoch 2 - iter 130/266 - loss 0.06932443 - time (sec): 3.65 - samples/sec: 570.27 - lr: 0.000045 - momentum: 0.000000
2024-11-16 23:06:34,654 epoch 2 - iter 156/266 - loss 0.06884417 - time (sec): 14.22 - samples/sec: 175.56 - lr: 0.000045 - momentum: 0.000000
2024-11-16 23:06:44,550 e

100%|██████████| 32/32 [00:12<00:00,  2.50it/s]


2024-11-16 23:07:33,944 TEST : loss 0.028102189302444458 - f1-score (micro avg)  0.978
2024-11-16 23:07:37,245 ----------------------------------------------------------------------------------------------------
2024-11-16 23:07:48,050 epoch 3 - iter 26/266 - loss 0.02856504 - time (sec): 10.80 - samples/sec: 38.52 - lr: 0.000041 - momentum: 0.000000
2024-11-16 23:07:58,542 epoch 3 - iter 52/266 - loss 0.03410903 - time (sec): 21.29 - samples/sec: 39.07 - lr: 0.000040 - momentum: 0.000000
2024-11-16 23:08:08,341 epoch 3 - iter 78/266 - loss 0.02299275 - time (sec): 31.09 - samples/sec: 40.14 - lr: 0.000040 - momentum: 0.000000
2024-11-16 23:08:16,751 epoch 3 - iter 104/266 - loss 0.02746639 - time (sec): 39.50 - samples/sec: 42.12 - lr: 0.000039 - momentum: 0.000000
2024-11-16 23:08:24,373 epoch 3 - iter 130/266 - loss 0.02735022 - time (sec): 47.12 - samples/sec: 44.14 - lr: 0.000038 - momentum: 0.000000
2024-11-16 23:08:26,493 epoch 3 - iter 156/266 - loss 0.02310641 - time (sec): 49

100%|██████████| 32/32 [00:00<00:00, 65.52it/s]

2024-11-16 23:08:31,502 TEST : loss 0.010993684642016888 - f1-score (micro avg)  0.998





2024-11-16 23:08:31,941 ----------------------------------------------------------------------------------------------------
2024-11-16 23:08:33,034 epoch 4 - iter 26/266 - loss 0.03474860 - time (sec): 1.09 - samples/sec: 381.90 - lr: 0.000034 - momentum: 0.000000
2024-11-16 23:08:34,023 epoch 4 - iter 52/266 - loss 0.02982021 - time (sec): 2.08 - samples/sec: 400.27 - lr: 0.000033 - momentum: 0.000000
2024-11-16 23:08:35,115 epoch 4 - iter 78/266 - loss 0.03202887 - time (sec): 3.17 - samples/sec: 393.65 - lr: 0.000033 - momentum: 0.000000
2024-11-16 23:08:36,041 epoch 4 - iter 104/266 - loss 0.03307392 - time (sec): 4.10 - samples/sec: 406.21 - lr: 0.000032 - momentum: 0.000000
2024-11-16 23:08:36,890 epoch 4 - iter 130/266 - loss 0.03480223 - time (sec): 4.95 - samples/sec: 420.57 - lr: 0.000031 - momentum: 0.000000
2024-11-16 23:08:37,734 epoch 4 - iter 156/266 - loss 0.03444546 - time (sec): 5.79 - samples/sec: 431.16 - lr: 0.000031 - momentum: 0.000000
2024-11-16 23:08:38,536 ep

100%|██████████| 32/32 [00:00<00:00, 66.77it/s]

2024-11-16 23:08:42,093 TEST : loss 0.01877407729625702 - f1-score (micro avg)  0.996





2024-11-16 23:08:42,267 ----------------------------------------------------------------------------------------------------
2024-11-16 23:08:43,120 epoch 5 - iter 26/266 - loss 0.01588987 - time (sec): 0.85 - samples/sec: 489.27 - lr: 0.000027 - momentum: 0.000000
2024-11-16 23:08:44,215 epoch 5 - iter 52/266 - loss 0.00802344 - time (sec): 1.95 - samples/sec: 427.69 - lr: 0.000026 - momentum: 0.000000
2024-11-16 23:08:45,032 epoch 5 - iter 78/266 - loss 0.01988553 - time (sec): 2.76 - samples/sec: 451.81 - lr: 0.000026 - momentum: 0.000000
2024-11-16 23:08:45,855 epoch 5 - iter 104/266 - loss 0.02262911 - time (sec): 3.59 - samples/sec: 464.06 - lr: 0.000025 - momentum: 0.000000
2024-11-16 23:08:46,681 epoch 5 - iter 130/266 - loss 0.01829372 - time (sec): 4.41 - samples/sec: 471.49 - lr: 0.000024 - momentum: 0.000000
2024-11-16 23:08:47,516 epoch 5 - iter 156/266 - loss 0.01992066 - time (sec): 5.25 - samples/sec: 475.71 - lr: 0.000024 - momentum: 0.000000
2024-11-16 23:08:48,364 ep

100%|██████████| 32/32 [00:00<00:00, 64.17it/s]

2024-11-16 23:08:51,892 TEST : loss 0.007578937336802483 - f1-score (micro avg)  0.996





2024-11-16 23:08:52,074 ----------------------------------------------------------------------------------------------------
2024-11-16 23:08:52,949 epoch 6 - iter 26/266 - loss 0.01247400 - time (sec): 0.87 - samples/sec: 476.93 - lr: 0.000020 - momentum: 0.000000
2024-11-16 23:08:53,782 epoch 6 - iter 52/266 - loss 0.00628657 - time (sec): 1.71 - samples/sec: 487.84 - lr: 0.000020 - momentum: 0.000000
2024-11-16 23:08:54,607 epoch 6 - iter 78/266 - loss 0.01387852 - time (sec): 2.53 - samples/sec: 493.08 - lr: 0.000019 - momentum: 0.000000
2024-11-16 23:08:55,743 epoch 6 - iter 104/266 - loss 0.02104696 - time (sec): 3.67 - samples/sec: 453.81 - lr: 0.000018 - momentum: 0.000000
2024-11-16 23:08:56,594 epoch 6 - iter 130/266 - loss 0.01749062 - time (sec): 4.52 - samples/sec: 460.38 - lr: 0.000018 - momentum: 0.000000
2024-11-16 23:08:57,387 epoch 6 - iter 156/266 - loss 0.01738614 - time (sec): 5.31 - samples/sec: 470.02 - lr: 0.000017 - momentum: 0.000000
2024-11-16 23:08:58,205 ep

100%|██████████| 32/32 [00:00<00:00, 60.06it/s]

2024-11-16 23:09:01,992 TEST : loss 0.014418857172131538 - f1-score (micro avg)  0.996





2024-11-16 23:09:02,173 ----------------------------------------------------------------------------------------------------
2024-11-16 23:09:02,886 epoch 7 - iter 26/266 - loss 0.00016161 - time (sec): 0.71 - samples/sec: 585.12 - lr: 0.000013 - momentum: 0.000000
2024-11-16 23:09:03,743 epoch 7 - iter 52/266 - loss 0.00817073 - time (sec): 1.57 - samples/sec: 530.43 - lr: 0.000013 - momentum: 0.000000
2024-11-16 23:09:04,883 epoch 7 - iter 78/266 - loss 0.00548269 - time (sec): 2.71 - samples/sec: 460.76 - lr: 0.000012 - momentum: 0.000000
2024-11-16 23:09:05,870 epoch 7 - iter 104/266 - loss 0.00596403 - time (sec): 3.70 - samples/sec: 450.32 - lr: 0.000011 - momentum: 0.000000
2024-11-16 23:09:06,888 epoch 7 - iter 130/266 - loss 0.01844916 - time (sec): 4.71 - samples/sec: 441.29 - lr: 0.000011 - momentum: 0.000000
2024-11-16 23:09:08,099 epoch 7 - iter 156/266 - loss 0.01981378 - time (sec): 5.92 - samples/sec: 421.34 - lr: 0.000010 - momentum: 0.000000
2024-11-16 23:09:08,947 ep

100%|██████████| 32/32 [00:00<00:00, 68.73it/s]

2024-11-16 23:09:12,530 TEST : loss 0.015175777487456799 - f1-score (micro avg)  0.996





2024-11-16 23:09:12,694 ----------------------------------------------------------------------------------------------------
2024-11-16 23:09:13,651 epoch 8 - iter 26/266 - loss 0.00051260 - time (sec): 0.96 - samples/sec: 435.36 - lr: 0.000006 - momentum: 0.000000
2024-11-16 23:09:14,601 epoch 8 - iter 52/266 - loss 0.01078906 - time (sec): 1.91 - samples/sec: 436.68 - lr: 0.000006 - momentum: 0.000000
2024-11-16 23:09:15,599 epoch 8 - iter 78/266 - loss 0.01460983 - time (sec): 2.90 - samples/sec: 429.88 - lr: 0.000005 - momentum: 0.000000
2024-11-16 23:09:16,483 epoch 8 - iter 104/266 - loss 0.01450712 - time (sec): 3.79 - samples/sec: 439.37 - lr: 0.000004 - momentum: 0.000000
2024-11-16 23:09:17,363 epoch 8 - iter 130/266 - loss 0.01565372 - time (sec): 4.67 - samples/sec: 445.69 - lr: 0.000004 - momentum: 0.000000
2024-11-16 23:09:18,245 epoch 8 - iter 156/266 - loss 0.01306274 - time (sec): 5.55 - samples/sec: 449.78 - lr: 0.000003 - momentum: 0.000000
2024-11-16 23:09:19,138 ep

100%|██████████| 32/32 [00:00<00:00, 72.60it/s]

2024-11-16 23:09:22,855 TEST : loss 0.016288593411445618 - f1-score (micro avg)  0.996





2024-11-16 23:09:23,270 ----------------------------------------------------------------------------------------------------
2024-11-16 23:09:23,272 Testing using last state of model ...


100%|██████████| 32/32 [00:00<00:00, 62.98it/s]

2024-11-16 23:09:23,812 
Results:
- F-score (micro) 0.996
- F-score (macro) 0.9976
- Accuracy 0.996

By class:
                precision    recall  f1-score   support

NON_FOOD_GOODS     1.0000    0.9792    0.9895        96
    FOOD_GOODS     0.9890    1.0000    0.9945        90
       SERVICE     0.9888    1.0000    0.9944        88
  BANK_SERVICE     1.0000    1.0000    1.0000        49
           TAX     1.0000    1.0000    1.0000        48
          LOAN     1.0000    1.0000    1.0000        41
       LEASING     1.0000    1.0000    1.0000        38
   REALE_STATE     1.0000    1.0000    1.0000        27
NOT_CLASSIFIED     1.0000    1.0000    1.0000        23

      accuracy                         0.9960       500
     macro avg     0.9975    0.9977    0.9976       500
  weighted avg     0.9960    0.9960    0.9960       500

2024-11-16 23:09:23,813 ----------------------------------------------------------------------------------------------------





{'test_score': 0.996}

In [5]:
# Score the test split with the fine-tuned classifier and collect gold and
# predicted class names in parallel lists (same order as `sentences`).
sentences = list(corpus.test)

# Read the gold class explicitly from the 'label' annotation layer rather
# than relying on "whatever label happens to be first" on the sentence.
real_tags = [sentence.get_label('label').value for sentence in sentences]

# predict() already splits its input into mini-batches, so a single call over
# the whole list replaces the manual slicing loop. Predictions are written to
# a separate 'predicted' layer so the gold 'label' annotations on these
# sentence objects are not overwritten.
classifier.predict(sentences, mini_batch_size=8, label_name='predicted')

# A sentence that received no prediction yields flair's placeholder value
# here instead of silently echoing its gold label.
predicted_tags = [sentence.get_label('predicted').value for sentence in sentences]

In [6]:
# Print every test example whose predicted class differs from its gold class,
# showing the text followed by both tags.
for index, predicted in enumerate(predicted_tags):
    expected = real_tags[index]
    if predicted == expected:
        continue
    print(corpus.test[index].text)
    print(f"| Real tag: | {expected} | Predicted tag: | {predicted} |\n")

Оплата за Уход за одеждой и обувью по счету 11837472833255495630 от 21.06.2024г. Сумма 4110-00
| Real tag: | NON_FOOD_GOODS | Predicted tag: | SERVICE |

Оплата за Крем Бархатные ручки Питательный для рук с маслом ши 80мл по счету 97745000424439727136 от 28 января 2023 Сумма 4890,00
| Real tag: | NON_FOOD_GOODS | Predicted tag: | FOOD_GOODS |

