In [1]:
import flair
flair.set_seed(2)

from flair.data import Corpus, Sentence
from flair.datasets import TREC_6, CSVClassificationCorpus
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

import torch
import argparse
import json
import csv
import re
import time

import pandas as pd

In [2]:
learning_rate = 5.0e-5
mini_batch_size = 4
max_epochs = 10
model_name = "deepvk/USER-bge-m3"

In [3]:

        
column_name_map = {0: "text", 1: "label"}
corpus: Corpus = CSVClassificationCorpus("data/flair_data",
                                            column_name_map,
                                            skip_header=False,
                                            delimiter='\t',    # tab-separated files
                                            label_type='label')



label_dict = corpus.make_label_dictionary(label_type='label')

2024-11-16 19:24:45,923 Reading data from data\flair_data
2024-11-16 19:24:45,923 Train: data\flair_data\train.csv
2024-11-16 19:24:45,924 Dev: data\flair_data\dev.csv
2024-11-16 19:24:45,925 Test: data\flair_data\test.csv
2024-11-16 19:24:45,938 Computing label dictionary. Progress:


0it [00:00, ?it/s]
4068it [00:01, 3985.26it/s]

2024-11-16 19:24:46,967 Dictionary created for label 'label' with 9 values: FOOD_GOODS (seen 922 times), NON_FOOD_GOODS (seen 895 times), SERVICE (seen 884 times), LEASING (seen 380 times), LOAN (seen 380 times), REALE_STATE (seen 256 times), BANK_SERVICE (seen 215 times), NOT_CLASSIFIED (seen 117 times), TAX (seen 19 times)





In [4]:
document_embeddings = TransformerDocumentEmbeddings(model_name, fine_tune=True)
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type='label')

trainer = ModelTrainer(classifier, corpus)

trainer.fine_tune(f'./models/gpt_data_v2/',
                    learning_rate=learning_rate,
                    mini_batch_size=mini_batch_size,
                    max_epochs=max_epochs,
                    monitor_test=True, train_with_dev=False, save_model_each_k_epochs=1)


2024-11-16 19:24:50,923 ----------------------------------------------------------------------------------------------------
2024-11-16 19:24:50,925 Model: "TextClassifier(
  (embeddings): TransformerDocumentEmbeddings(
    (model): XLMRobertaModel(
      (embeddings): XLMRobertaEmbeddings(
        (word_embeddings): Embedding(46167, 1024, padding_idx=1)
        (position_embeddings): Embedding(8194, 1024, padding_idx=1)
        (token_type_embeddings): Embedding(1, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): XLMRobertaEncoder(
        (layer): ModuleList(
          (0-23): 24 x XLMRobertaLayer(
            (attention): XLMRobertaAttention(
              (self): XLMRobertaSdpaSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linea

100%|██████████| 32/32 [00:09<00:00,  3.28it/s]

2024-11-16 19:31:56,601 DEV : loss 0.4902525246143341 - f1-score (micro avg)  0.922



100%|██████████| 32/32 [00:11<00:00,  2.76it/s]

2024-11-16 19:32:08,390 TEST : loss 0.4902525246143341 - f1-score (micro avg)  0.922





2024-11-16 19:32:08,713 ----------------------------------------------------------------------------------------------------
2024-11-16 19:32:41,688 epoch 2 - iter 101/1017 - loss 0.11498597 - time (sec): 32.97 - samples/sec: 12.25 - lr: 0.000049 - momentum: 0.000000
2024-11-16 19:33:22,549 epoch 2 - iter 202/1017 - loss 0.15942758 - time (sec): 73.83 - samples/sec: 10.94 - lr: 0.000049 - momentum: 0.000000
2024-11-16 19:34:06,986 epoch 2 - iter 303/1017 - loss 0.21001581 - time (sec): 118.27 - samples/sec: 10.25 - lr: 0.000048 - momentum: 0.000000
2024-11-16 19:34:48,703 epoch 2 - iter 404/1017 - loss 0.23103512 - time (sec): 159.99 - samples/sec: 10.10 - lr: 0.000048 - momentum: 0.000000
2024-11-16 19:35:32,404 epoch 2 - iter 505/1017 - loss 0.21175898 - time (sec): 203.69 - samples/sec: 9.92 - lr: 0.000047 - momentum: 0.000000
2024-11-16 19:36:13,295 epoch 2 - iter 606/1017 - loss 0.23226221 - time (sec): 244.58 - samples/sec: 9.91 - lr: 0.000047 - momentum: 0.000000
2024-11-16 19:3

100%|██████████| 32/32 [00:14<00:00,  2.27it/s]

2024-11-16 19:40:18,248 DEV : loss 0.07244223356246948 - f1-score (micro avg)  0.974



100%|██████████| 32/32 [00:13<00:00,  2.31it/s]

2024-11-16 19:40:32,290 TEST : loss 0.07244223356246948 - f1-score (micro avg)  0.974





2024-11-16 19:40:32,585 ----------------------------------------------------------------------------------------------------
2024-11-16 19:41:36,624 epoch 3 - iter 101/1017 - loss 0.12431403 - time (sec): 64.04 - samples/sec: 6.31 - lr: 0.000044 - momentum: 0.000000
2024-11-16 19:42:32,641 epoch 3 - iter 202/1017 - loss 0.16756819 - time (sec): 120.05 - samples/sec: 6.73 - lr: 0.000043 - momentum: 0.000000
2024-11-16 19:43:40,468 epoch 3 - iter 303/1017 - loss 0.14886173 - time (sec): 187.88 - samples/sec: 6.45 - lr: 0.000043 - momentum: 0.000000
2024-11-16 19:45:35,559 epoch 3 - iter 404/1017 - loss 0.18702503 - time (sec): 302.97 - samples/sec: 5.33 - lr: 0.000042 - momentum: 0.000000
2024-11-16 19:46:23,517 epoch 3 - iter 505/1017 - loss 0.26844292 - time (sec): 350.93 - samples/sec: 5.76 - lr: 0.000042 - momentum: 0.000000
2024-11-16 19:47:14,228 epoch 3 - iter 606/1017 - loss 0.55129664 - time (sec): 401.64 - samples/sec: 6.04 - lr: 0.000041 - momentum: 0.000000
2024-11-16 19:48:0

100%|██████████| 32/32 [00:09<00:00,  3.27it/s]

2024-11-16 19:51:00,286 DEV : loss 2.38578462600708 - f1-score (micro avg)  0.21



100%|██████████| 32/32 [00:11<00:00,  2.85it/s]

2024-11-16 19:51:11,762 TEST : loss 2.38578462600708 - f1-score (micro avg)  0.21





2024-11-16 19:51:11,962 ----------------------------------------------------------------------------------------------------
2024-11-16 19:52:05,316 epoch 4 - iter 101/1017 - loss 1.53851904 - time (sec): 53.35 - samples/sec: 7.57 - lr: 0.000038 - momentum: 0.000000
2024-11-16 19:53:02,600 epoch 4 - iter 202/1017 - loss 0.90714725 - time (sec): 110.63 - samples/sec: 7.30 - lr: 0.000038 - momentum: 0.000000
2024-11-16 19:54:04,036 epoch 4 - iter 303/1017 - loss 0.69619727 - time (sec): 172.07 - samples/sec: 7.04 - lr: 0.000037 - momentum: 0.000000
2024-11-16 19:54:58,025 epoch 4 - iter 404/1017 - loss 0.56985694 - time (sec): 226.06 - samples/sec: 7.15 - lr: 0.000037 - momentum: 0.000000
2024-11-16 19:56:10,547 epoch 4 - iter 505/1017 - loss 0.48337150 - time (sec): 298.58 - samples/sec: 6.77 - lr: 0.000036 - momentum: 0.000000
2024-11-16 19:57:06,679 epoch 4 - iter 606/1017 - loss 0.41400839 - time (sec): 354.71 - samples/sec: 6.83 - lr: 0.000036 - momentum: 0.000000
2024-11-16 19:58:0

100%|██████████| 32/32 [00:09<00:00,  3.55it/s]

2024-11-16 20:01:24,490 DEV : loss 0.016033632680773735 - f1-score (micro avg)  0.998



100%|██████████| 32/32 [00:09<00:00,  3.56it/s]

2024-11-16 20:01:33,990 TEST : loss 0.016033632680773735 - f1-score (micro avg)  0.998





2024-11-16 20:01:34,162 ----------------------------------------------------------------------------------------------------
2024-11-16 20:02:23,309 epoch 5 - iter 101/1017 - loss 0.04357185 - time (sec): 49.14 - samples/sec: 8.22 - lr: 0.000033 - momentum: 0.000000
2024-11-16 20:03:14,731 epoch 5 - iter 202/1017 - loss 0.03028550 - time (sec): 100.57 - samples/sec: 8.03 - lr: 0.000032 - momentum: 0.000000
2024-11-16 20:04:08,463 epoch 5 - iter 303/1017 - loss 0.03590309 - time (sec): 154.30 - samples/sec: 7.85 - lr: 0.000032 - momentum: 0.000000
2024-11-16 20:04:55,976 epoch 5 - iter 404/1017 - loss 0.04397541 - time (sec): 201.81 - samples/sec: 8.01 - lr: 0.000031 - momentum: 0.000000
2024-11-16 20:05:47,254 epoch 5 - iter 505/1017 - loss 0.04124751 - time (sec): 253.09 - samples/sec: 7.98 - lr: 0.000031 - momentum: 0.000000
2024-11-16 20:06:35,673 epoch 5 - iter 606/1017 - loss 0.04685602 - time (sec): 301.51 - samples/sec: 8.04 - lr: 0.000030 - momentum: 0.000000
2024-11-16 20:07:2

100%|██████████| 32/32 [00:22<00:00,  1.43it/s]

2024-11-16 20:11:18,319 DEV : loss 0.05989818274974823 - f1-score (micro avg)  0.994



100%|██████████| 32/32 [00:16<00:00,  1.90it/s]

2024-11-16 20:11:35,804 TEST : loss 0.05989818274974823 - f1-score (micro avg)  0.994





2024-11-16 20:11:35,990 ----------------------------------------------------------------------------------------------------
2024-11-16 20:12:34,658 epoch 6 - iter 101/1017 - loss 0.12197717 - time (sec): 58.66 - samples/sec: 6.89 - lr: 0.000027 - momentum: 0.000000
2024-11-16 20:14:31,882 epoch 6 - iter 202/1017 - loss 0.09733798 - time (sec): 175.89 - samples/sec: 4.59 - lr: 0.000027 - momentum: 0.000000
2024-11-16 20:15:18,425 epoch 6 - iter 303/1017 - loss 0.06563281 - time (sec): 222.43 - samples/sec: 5.45 - lr: 0.000026 - momentum: 0.000000
2024-11-16 20:16:06,623 epoch 6 - iter 404/1017 - loss 0.07044911 - time (sec): 270.63 - samples/sec: 5.97 - lr: 0.000026 - momentum: 0.000000
2024-11-16 20:16:56,629 epoch 6 - iter 505/1017 - loss 0.07142041 - time (sec): 320.64 - samples/sec: 6.30 - lr: 0.000025 - momentum: 0.000000
2024-11-16 20:17:43,725 epoch 6 - iter 606/1017 - loss 0.06335902 - time (sec): 367.73 - samples/sec: 6.59 - lr: 0.000024 - momentum: 0.000000
2024-11-16 20:18:3

100%|██████████| 32/32 [00:11<00:00,  2.74it/s]

2024-11-16 20:21:24,520 DEV : loss 0.4188661277294159 - f1-score (micro avg)  0.976



100%|██████████| 32/32 [00:11<00:00,  2.73it/s]

2024-11-16 20:21:36,566 TEST : loss 0.4188661277294159 - f1-score (micro avg)  0.976





2024-11-16 20:21:36,730 ----------------------------------------------------------------------------------------------------
2024-11-16 20:22:22,282 epoch 7 - iter 101/1017 - loss 0.07931006 - time (sec): 45.55 - samples/sec: 8.87 - lr: 0.000022 - momentum: 0.000000
2024-11-16 20:23:10,687 epoch 7 - iter 202/1017 - loss 0.06320913 - time (sec): 93.95 - samples/sec: 8.60 - lr: 0.000021 - momentum: 0.000000
2024-11-16 20:23:59,014 epoch 7 - iter 303/1017 - loss 0.06611996 - time (sec): 142.28 - samples/sec: 8.52 - lr: 0.000021 - momentum: 0.000000
2024-11-16 20:24:47,157 epoch 7 - iter 404/1017 - loss 0.06483480 - time (sec): 190.42 - samples/sec: 8.49 - lr: 0.000020 - momentum: 0.000000
2024-11-16 20:25:35,860 epoch 7 - iter 505/1017 - loss 0.06054937 - time (sec): 239.13 - samples/sec: 8.45 - lr: 0.000019 - momentum: 0.000000
2024-11-16 20:26:23,516 epoch 7 - iter 606/1017 - loss 0.05048654 - time (sec): 286.78 - samples/sec: 8.45 - lr: 0.000019 - momentum: 0.000000
2024-11-16 20:27:11

100%|██████████| 32/32 [00:13<00:00,  2.37it/s]

2024-11-16 20:31:07,369 DEV : loss 0.14267194271087646 - f1-score (micro avg)  0.974



100%|██████████| 32/32 [00:18<00:00,  1.78it/s]

2024-11-16 20:31:25,840 TEST : loss 0.14267194271087646 - f1-score (micro avg)  0.974





2024-11-16 20:31:26,035 ----------------------------------------------------------------------------------------------------
2024-11-16 20:32:19,288 epoch 8 - iter 101/1017 - loss 0.03482657 - time (sec): 53.25 - samples/sec: 7.59 - lr: 0.000016 - momentum: 0.000000
2024-11-16 20:33:21,779 epoch 8 - iter 202/1017 - loss 0.04252398 - time (sec): 115.74 - samples/sec: 6.98 - lr: 0.000016 - momentum: 0.000000
2024-11-16 20:34:27,909 epoch 8 - iter 303/1017 - loss 0.02839157 - time (sec): 181.87 - samples/sec: 6.66 - lr: 0.000015 - momentum: 0.000000
2024-11-16 20:35:55,741 epoch 8 - iter 404/1017 - loss 0.03197704 - time (sec): 269.70 - samples/sec: 5.99 - lr: 0.000014 - momentum: 0.000000
2024-11-16 20:37:00,296 epoch 8 - iter 505/1017 - loss 0.03760401 - time (sec): 334.26 - samples/sec: 6.04 - lr: 0.000014 - momentum: 0.000000
2024-11-16 20:38:02,652 epoch 8 - iter 606/1017 - loss 0.04006775 - time (sec): 396.61 - samples/sec: 6.11 - lr: 0.000013 - momentum: 0.000000
2024-11-16 20:39:1

100%|██████████| 32/32 [00:15<00:00,  2.08it/s]

2024-11-16 20:43:29,541 DEV : loss 0.013090575113892555 - f1-score (micro avg)  0.998



100%|██████████| 32/32 [00:13<00:00,  2.36it/s]

2024-11-16 20:43:43,322 TEST : loss 0.013090575113892555 - f1-score (micro avg)  0.998





2024-11-16 20:43:43,520 ----------------------------------------------------------------------------------------------------
2024-11-16 20:44:37,023 epoch 9 - iter 101/1017 - loss 0.05217003 - time (sec): 53.50 - samples/sec: 7.55 - lr: 0.000011 - momentum: 0.000000
2024-11-16 20:45:51,300 epoch 9 - iter 202/1017 - loss 0.03338340 - time (sec): 127.78 - samples/sec: 6.32 - lr: 0.000010 - momentum: 0.000000
2024-11-16 20:46:30,149 epoch 9 - iter 303/1017 - loss 0.02688604 - time (sec): 166.63 - samples/sec: 7.27 - lr: 0.000009 - momentum: 0.000000
2024-11-16 20:47:15,342 epoch 9 - iter 404/1017 - loss 0.02845709 - time (sec): 211.82 - samples/sec: 7.63 - lr: 0.000009 - momentum: 0.000000
2024-11-16 20:48:09,159 epoch 9 - iter 505/1017 - loss 0.02889815 - time (sec): 265.64 - samples/sec: 7.60 - lr: 0.000008 - momentum: 0.000000
2024-11-16 20:49:05,504 epoch 9 - iter 606/1017 - loss 0.02412729 - time (sec): 321.98 - samples/sec: 7.53 - lr: 0.000008 - momentum: 0.000000
2024-11-16 20:49:5

100%|██████████| 32/32 [00:13<00:00,  2.44it/s]

2024-11-16 20:52:42,518 DEV : loss 0.009175065904855728 - f1-score (micro avg)  0.998



100%|██████████| 32/32 [00:12<00:00,  2.51it/s]

2024-11-16 20:52:55,446 TEST : loss 0.009175065904855728 - f1-score (micro avg)  0.998





2024-11-16 20:52:55,598 ----------------------------------------------------------------------------------------------------
2024-11-16 20:53:59,663 epoch 10 - iter 101/1017 - loss 0.01194265 - time (sec): 64.06 - samples/sec: 6.31 - lr: 0.000005 - momentum: 0.000000
2024-11-16 20:54:43,432 epoch 10 - iter 202/1017 - loss 0.01586806 - time (sec): 107.83 - samples/sec: 7.49 - lr: 0.000004 - momentum: 0.000000
2024-11-16 20:55:33,748 epoch 10 - iter 303/1017 - loss 0.01074929 - time (sec): 158.15 - samples/sec: 7.66 - lr: 0.000004 - momentum: 0.000000
2024-11-16 20:56:25,789 epoch 10 - iter 404/1017 - loss 0.02239072 - time (sec): 210.19 - samples/sec: 7.69 - lr: 0.000003 - momentum: 0.000000
2024-11-16 20:57:20,297 epoch 10 - iter 505/1017 - loss 0.01797796 - time (sec): 264.70 - samples/sec: 7.63 - lr: 0.000003 - momentum: 0.000000
2024-11-16 20:58:13,356 epoch 10 - iter 606/1017 - loss 0.01499590 - time (sec): 317.76 - samples/sec: 7.63 - lr: 0.000002 - momentum: 0.000000
2024-11-16 2

RuntimeError: [enforce fail at inline_container.cc:424] . unexpected pos 3408256 vs 3408150

In [2]:
model = TextClassifier.load("models/gpt_data/model_epoch_5.pt")

In [4]:
result = model.evaluate(corpus.test, gold_label_type="label", mini_batch_size=2)

100%|██████████| 250/250 [00:18<00:00, 13.66it/s]


In [11]:
sentences = [sentence for sentence in corpus.test]

In [12]:
real_tags = [sentence.tag for sentence in sentences]
predicted_tags = []

In [13]:
for i in range(0, len(sentences), 8):
    sents = sentences[i:i+8]
    model.predict(sents, mini_batch_size=8)
    for sent in sents:
        predicted_tags.append(sent.tag)

In [18]:
for index in range(len(predicted_tags)):
    if predicted_tags[index] != real_tags[index]:
        print(index)
        print(real_tags[index], predicted_tags[index])

259
NON_FOOD_GOODS FOOD_GOODS


In [19]:
corpus.test[259]

Sentence[15]: "Оплата за Уход за одеждой и обувью по счету 11837472833255495630 от 21.06.2024г. Сумма 4110-00" → NON_FOOD_GOODS (1.0)

In [5]:
print(result.detailed_results)


Results:
- F-score (micro) 0.998
- F-score (macro) 0.9988
- Accuracy 0.998

By class:
                precision    recall  f1-score   support

NON_FOOD_GOODS     1.0000    0.9896    0.9948        96
    FOOD_GOODS     0.9890    1.0000    0.9945        90
       SERVICE     1.0000    1.0000    1.0000        88
  BANK_SERVICE     1.0000    1.0000    1.0000        49
           TAX     1.0000    1.0000    1.0000        48
          LOAN     1.0000    1.0000    1.0000        41
       LEASING     1.0000    1.0000    1.0000        38
   REALE_STATE     1.0000    1.0000    1.0000        27
NOT_CLASSIFIED     1.0000    1.0000    1.0000        23

      accuracy                         0.9980       500
     macro avg     0.9988    0.9988    0.9988       500
  weighted avg     0.9980    0.9980    0.9980       500



In [1]:
GPT, BERT 
GPT != BERT -> вручную

SyntaxError: invalid syntax (43358497.py, line 2)

In [None]:
1) Квантизация 8
2) оннх
3) пайторч тредс