In [2]:
import os
import torch
from transformers import BertTokenizer
from striprtf.striprtf import rtf_to_text
import re

# Функция для очистки текста
def clean_text(text):
    """Normalize raw text before tokenization.

    Every character that is not a Latin letter, a digit, or a Cyrillic letter
    in the ranges ``а-я`` / ``А-Я`` is replaced by a single space (one space
    per removed character, so runs are not collapsed), and the result is then
    lower-cased.

    Note: ``ё`` / ``Ё`` lie outside the ``а-я`` / ``А-Я`` code-point ranges,
    so they are replaced by spaces as well.

    Args:
        text: Input string.

    Returns:
        The cleaned, lower-cased string (same length as the regex-substituted
        input).
    """
    # Substitute first, lower-case second: the character class is defined on
    # the original casing.
    without_punctuation = re.compile(r'[^а-яА-Яa-zA-Z0-9]').sub(' ', text)
    return without_punctuation.lower()

# Функция предсказания
def predict(model, tokenizer, text, device=None):
    """Classify a single document and return the index of the top-scoring class.

    The text is normalized with ``clean_text``, tokenized to a fixed length of
    512 tokens (padded / truncated), moved to the target device and passed
    through the model without gradient tracking.

    Args:
        model: Model invoked as ``model(**inputs)``; its result must expose a
            ``logits`` tensor with classes along dimension 1.
        tokenizer: Callable tokenizer returning a mapping of tensors.
        text: Raw document text.
        device: Optional ``torch.device`` for the input tensors. When omitted
            the device of the model's first parameter is used (CPU if the model
            has no parameters), so the function no longer depends on a
            notebook-global ``device`` defined in a later cell.

    Returns:
        The arg-max class index of the first (and only) row of the batch, as a
        NumPy integer scalar.
    """
    if device is None:
        # Follow the model: inputs must live on the same device as the weights.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    cleaned_text = clean_text(text)
    encoded = tokenizer(
        cleaned_text,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Only one text is encoded per call, so take row 0 of the batch.
    return torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]

# Функция для чтения и обработки RTF файлов
def parse_rtf_header(rtf_file):
    """Convert the contents of an open RTF file to plain text.

    Despite the name, the whole file is read and converted, not only a header.

    Args:
        rtf_file: Readable text file object; its full contents are consumed
            with ``read()``.

    Returns:
        The text produced by ``rtf_to_text`` with leading and trailing
        whitespace removed.
    """
    return rtf_to_text(rtf_file.read()).strip()

In [35]:
# Keyword lists per document type, consulted by find_by_name_frequency.
# Key order matters: find_by_name_frequency returns a key's position in this
# dict as the class index.
# NOTE(review): presumably these positions are meant to line up with the
# model's class ids — confirm against the training label mapping.
# NOTE(review): "счёт" is spelled with "ё"; text that writes "счет" will not
# match this keyword.
DOC_TYPES_DICT: dict[str, list[str]] = {
    "proxy": ["доверенность"],
    "contract": ["договор"],
    "act": ["акт"],
    "application": ["заявление"],
    "order": ["приказ"],
    "invoice": ["счёт"],
    "bill": ["приложение"],
    "arrangement": ["соглашение"],
    "contractoffer": ["договор оферты"],
    "statute": ["устав"],
    "determination": ["решение"]
}

In [51]:
 def find_by_name_frequency(text):
     """Guess the document type from how often its keywords occur in the text.

     For every entry of ``DOC_TYPES_DICT`` the whole-word, case-insensitive
     occurrences of all its keywords are summed and divided by the number of
     keywords. The best-scoring type wins only if it leads the runner-up by at
     least half of its own score; otherwise the result is inconclusive.

     Compared with a plain ``str.count(word + ' ')`` this:
       * accumulates hits over all keywords of a type instead of keeping only
         the last keyword's count;
       * matches whole words only, so "акт" is not found inside "контракт";
       * also counts a keyword followed by a newline, punctuation or the end
         of the text (e.g. a title on its own line).

     NOTE(review): "договор оферты" contains "договор" as a whole word, so the
     "contractoffer" score can never exceed the "contract" score.

     Args:
         text: Document text to inspect.

     Returns:
         The position of the winning type within ``DOC_TYPES_DICT``, or
         ``None`` when nothing matches or the lead is below 50 %.
     """
     lowered = text.lower()

     scores = []
     for keywords in DOC_TYPES_DICT.values():
         if not keywords:
             # No keywords means no evidence; also avoids dividing by zero.
             scores.append(0.0)
             continue
         hits = sum(
             len(re.findall(rf'(?<!\w){re.escape(keyword.lower())}(?!\w)', lowered))
             for keyword in keywords
         )
         scores.append(hits / len(keywords))

     if not scores:
         return None

     # max() returns the first maximal index, i.e. dict order breaks ties.
     best_index = max(range(len(scores)), key=scores.__getitem__)
     best = scores[best_index]
     if not best:
         return None

     # With a single document type there is no runner-up; treat it as zero.
     runner_up = max(
         (score for index, score in enumerate(scores) if index != best_index),
         default=0,
     )
     margin = (best - runner_up) / best
     return best_index if margin >= 0.5 else None

In [46]:
# Обработка файлов и оценка точности
def evaluate_model(directory, model, tokenizer):
    """Run the classifier over labelled RTF files and report its accuracy.

    Only files named ``<digits>_<anything>.rtf`` (exactly one underscore) are
    evaluated; the digits before the underscore are the expected class. Each
    file is converted with ``parse_rtf_header`` and classified with
    ``predict``. When the model answers class 7, ``find_by_name_frequency`` is
    consulted and, if it returns a class, that class replaces the prediction.
    One line per file and a final summary are printed.

    Args:
        directory: Folder to scan (non-recursive).
        model: Model forwarded to ``predict``.
        tokenizer: Tokenizer forwarded to ``predict``.

    Returns:
        Fraction of evaluated files predicted correctly, or ``0`` when no file
        qualified.
    """
    correct = 0
    total = 0

    for filename in os.listdir(directory):
        if not filename.endswith(".rtf"):
            continue

        name_parts = filename.split('_')
        if len(name_parts) != 2 or not name_parts[1].endswith('.rtf'):
            continue

        label = name_parts[0]
        if not label.isdigit():
            continue

        with open(os.path.join(directory, filename), 'r') as handle:
            text = parse_rtf_header(handle)

        expected_class = int(label)
        predicted_class = predict(model, tokenizer, text)

        # Class 7 is second-guessed with the keyword-frequency heuristic.
        if predicted_class == 7:
            keyword_class = find_by_name_frequency(text)
            if keyword_class is not None:
                predicted_class = keyword_class

        # Log the prediction next to the true class
        print(f"File: {filename}, Predict: {predicted_class}, True: {expected_class}")

        if predicted_class == expected_class:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"Total files processed: {total}, Model Accuracy: {accuracy:.2f}")
    return accuracy

In [47]:
# Load the tokenizer and the model, and select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
model_path = "models/full_model_epoch_10.pt"  # Path to the saved model file
# The checkpoint is a whole pickled module (the cell output below is a
# BertForSequenceClassification), not a state_dict, so weights_only=False is
# requested explicitly: PyTorch 2.6+ defaults to weights_only=True, which
# refuses to unpickle such files.
# SECURITY: full unpickling can execute arbitrary code — only load checkpoint
# files from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
# Switch to inference mode (disables dropout); as the last expression this also
# displays the model structure in the cell output.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [52]:
# Evaluate the loaded model on the labelled RTF files in ./res.
# evaluate_model already prints a per-file log and a summary line; the final
# print repeats the accuracy figure.
cls_directory = os.path.join(os.curdir, 'res')  # Path to the folder with RTF files
accuracy = evaluate_model(cls_directory, model, tokenizer)
print(f"Model Accuracy: {accuracy:.2f}")

File: 1_0.rtf, Predict: 8, True: 1
File: 1_1.rtf, Predict: 8, True: 1
File: 1_10.rtf, Predict: 7, True: 1
File: 1_100.rtf, Predict: 1, True: 1
File: 1_101.rtf, Predict: 8, True: 1
File: 1_102.rtf, Predict: 8, True: 1
File: 1_103.rtf, Predict: 1, True: 1
File: 1_104.rtf, Predict: 1, True: 1
File: 1_105.rtf, Predict: 8, True: 1
File: 1_106.rtf, Predict: 10, True: 1
File: 1_107.rtf, Predict: 8, True: 1
File: 1_11.rtf, Predict: 7, True: 1
File: 1_12.rtf, Predict: 7, True: 1
File: 1_13.rtf, Predict: 7, True: 1
File: 1_14.rtf, Predict: 7, True: 1
File: 1_15.rtf, Predict: 1, True: 1
File: 1_16.rtf, Predict: 7, True: 1
File: 1_17.rtf, Predict: 1, True: 1
File: 1_18.rtf, Predict: 7, True: 1
File: 1_19.rtf, Predict: 7, True: 1
File: 1_2.rtf, Predict: 7, True: 1
File: 1_20.rtf, Predict: 7, True: 1
File: 1_21.rtf, Predict: 7, True: 1
File: 1_22.rtf, Predict: 10, True: 1
File: 1_23.rtf, Predict: 7, True: 1
File: 1_24.rtf, Predict: 7, True: 1
File: 1_25.rtf, Predict: 7, True: 1
File: 1_26.rtf, Predi