In [None]:
import json
import os
import re
import torch
from symspellpy import SymSpell, Verbosity
from ultralytics import YOLO
from transformers import BertForSequenceClassification, BertTokenizer
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

In [None]:
# Spell checker used by correct_text: suggestions up to edit distance 2,
# prefix length 7. Dictionary file format: "<term> <count>" per line.
sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
try:
    # load_dictionary reports a missing file by returning False rather than
    # raising, so the return value must be checked as well as exceptions.
    dictionary_loaded = sym_spell.load_dictionary("ru_full.txt", term_index=0, count_index=1, encoding='utf-8')
except Exception as e:
    print(f"Ошибка загрузки словаря: {e}")
    # SystemExit instead of exit(): IPython's `exit` in a kernel does not raise,
    # so "Run All" would carry on with an unusable spell checker. In a plain
    # script this still terminates the process with status 1.
    raise SystemExit(1) from e
if not dictionary_loaded:
    print("Ошибка загрузки словаря: файл ru_full.txt не найден")
    raise SystemExit(1)

In [None]:
# Load the YOLO detector whose class ids identify task regions on a page.
try:
    yolo_model = YOLO("best.pt")
except Exception as e:
    print(f"Ошибка загрузки модели YOLO: {e}")
    # SystemExit instead of exit(): IPython's `exit` in a kernel does not raise,
    # so later cells would run with `yolo_model` undefined. In a plain script
    # this still terminates the process with status 1.
    raise SystemExit(1) from e

In [None]:
class EgeEvaluator:
    """Scores a text with a BERT sequence-classification checkpoint.

    The tokenizer and model are both loaded from ``model_path`` via
    ``from_pretrained``. The model is moved to ``device`` and put into eval mode.
    If no device is given (or it is falsy), CUDA is used when available,
    otherwise CPU.
    """

    def __init__(self, model_path, device=None):
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        if not device:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text):
        """Classify ``text``.

        The input is padded/truncated to exactly 512 tokens.

        Returns:
            ``(label, probs)``: ``label`` is the argmax over the *flattened*
            softmax output, which equals the class index for a single string.
            ``probs`` is the softmax as a NumPy array; presumably shaped
            ``(1, num_labels)`` for one string, but confirm against the checkpoint.
        """
        encoded = self.tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        encoded = encoded.to(self.device)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = self.model(**encoded).logits

        probabilities = torch.softmax(logits, dim=-1)
        best_index = probabilities.argmax().item()
        return best_index, probabilities.cpu().numpy()

In [None]:
def correct_text(text, max_edit_distance=2):
    """Spell-correct ``text`` word by word using the module-level ``sym_spell``.

    Each run of non-whitespace characters is corrected on its own, and the
    original whitespace is kept. Passing a whole multi-word string to
    ``SymSpell.lookup`` treats it as one dictionary term, so it either goes
    uncorrected or collapses into a single unrelated word. Tokens with no
    letters (numbers, bare punctuation) are left unchanged, so that "24" is
    never "corrected" into a short dictionary word.

    Args:
        text: OCR text fragment.
        max_edit_distance: maximum edit distance for suggestions. It must not
            exceed the ``max_dictionary_edit_distance`` that ``sym_spell`` was
            built with (2 in this notebook).

    Returns:
        The corrected text. A token with no suggestion is kept as-is.
    """
    def _correct_token(match):
        # Replaces one whitespace-free token with its closest dictionary term.
        token = match.group(0)
        if not any(ch.isalpha() for ch in token):
            return token
        suggestions = sym_spell.lookup(token, Verbosity.CLOSEST, max_edit_distance=max_edit_distance)
        return suggestions[0].term if suggestions else token

    return re.sub(r'\S+', _correct_token, text)

In [None]:
def extract_text_from_webres(webres_path):
    """Collect spell-corrected Russian text boxes from a ``.webRes`` JSON file.

    The JSON tree is walked depth-first. Any dict carrying a ``'languages'``
    list contributes one box per non-blank ``texts[*].text`` entry whose
    language is ``'rus'``. The box geometry comes from that dict's
    ``x``/``y``/``w``/``h`` keys, each defaulting to 0. Nested containers are
    searched as well, so boxes appear in pre-order.

    Returns:
        List of ``{'x', 'y', 'w', 'h', 'text'}`` dicts. An empty list is
        returned (after printing the error) if the file cannot be read or parsed.
    """
    try:
        with open(webres_path, 'r', encoding='utf-8', errors='replace') as handle:
            document = json.load(handle)
    except Exception as exc:
        print(f"Ошибка чтения {webres_path}: {exc}")
        return []

    def _walk(node):
        # Yields text boxes from `node` and everything nested inside it.
        if isinstance(node, list):
            for child in node:
                yield from _walk(child)
            return
        if not isinstance(node, dict):
            return

        if 'languages' in node:
            for language in node['languages']:
                if language.get('lang') != 'rus':
                    continue
                for item in language.get('texts', []):
                    raw = item.get('text', '').strip()
                    if not raw:
                        continue
                    fixed = correct_text(raw)
                    yield {
                        'x': node.get('x', 0),
                        'y': node.get('y', 0),
                        'w': node.get('w', 0),
                        'h': node.get('h', 0),
                        'text': fixed,
                    }

        for child in node.values():
            yield from _walk(child)

    return list(_walk(document))

In [None]:
def get_yolo_boxes(image_path):
    """Run the module-level ``yolo_model`` on an image and flatten its detections.

    Each detection is printed (class id, class name, confidence, xyxy box).

    Returns:
        List of ``{'x1', 'y1', 'x2', 'y2', 'label'}`` dicts, where ``label`` is
        the integer class id. An empty list is returned (after printing the
        error) if inference raises.
    """
    try:
        predictions = yolo_model(image_path)
    except Exception as exc:
        print(f"Ошибка обработки изображения {image_path}: {exc}")
        return []

    print(f"\nДетекция для {image_path}:")
    detections = []
    for prediction in predictions:
        for detected in prediction.boxes:
            coords = detected.xyxy[0].tolist()
            class_id = int(detected.cls)
            confidence = detected.conf.item()
            print(f"  Класс: {class_id}, Метка: {yolo_model.names[class_id]}, Conf: {confidence:.2f}, BBox: {coords}")
            detections.append(dict(
                x1=coords[0],
                y1=coords[1],
                x2=coords[2],
                y2=coords[3],
                label=class_id,
            ))
    return detections

In [None]:
def match_text_to_tasks(webres_boxes, yolo_boxes):
    """Group OCR text fragments by the YOLO-detected task region they fall into.

    A text box is assigned to a region when its centre point
    ``(x + w/2, y + h/2)`` lies inside the region's ``[x1, x2] x [y1, y2]``
    rectangle, edges inclusive. NOTE(review): this assumes webRes and YOLO
    boxes share one pixel coordinate system; confirm.

    Args:
        webres_boxes: dicts with ``'x'``, ``'y'``, ``'w'``, ``'h'``, ``'text'``.
        yolo_boxes: dicts with ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'`` and an
            integer class id in ``'label'``.

    Returns:
        One ``{'task': 'Task_<n>', 'text': ...}`` dict per region that matched
        at least one fragment, in detection order. Fragments are joined with
        single spaces in ``webres_boxes`` order. A fragment can be counted in
        several overlapping regions.
    """
    task_data = []
    print("\nСопоставление боксов:")

    # Text centres do not depend on the region, so compute them once.
    text_centres = [
        (box['x'] + box['w'] / 2, box['y'] + box['h'] / 2, box['text'])
        for box in webres_boxes
    ]

    for yolo_box in yolo_boxes:
        # Class id 0 maps to task 22, 1 to 23, and so on. Use arithmetic rather
        # than prefixing "2" to (id + 2), which yields "Task_210" for id 8.
        task_key = f"Task_{yolo_box['label'] + 22}"
        matched_texts = []

        for centre_x, centre_y, text in text_centres:
            if (yolo_box['x1'] <= centre_x <= yolo_box['x2'] and
                    yolo_box['y1'] <= centre_y <= yolo_box['y2']):
                matched_texts.append(text)
                print(f"  Найдено совпадение: {task_key} -> '{text}'")

        if matched_texts:
            task_data.append({
                'task': task_key,
                'text': " ".join(matched_texts)
            })

    return task_data


In [None]:
def find_pairs(webres_dir, images_dir):
    """Pair each ``.webRes`` OCR file with its source PNG image.

    The image is looked up as ``<stem>.png`` in ``images_dir``. ``<stem>`` is
    the webRes file name without its extension, cut at the first ``'__'``
    (e.g. ``scan01__p1.webRes`` -> ``scan01.png``).

    Args:
        webres_dir: directory scanned (non-recursively) for ``*.webRes`` files.
        images_dir: directory holding the PNG images.

    Returns:
        List of ``(webres_path, image_path, webres_file_name)`` tuples, sorted
        by webRes file name. Files without a matching image are reported and
        skipped. An unreadable ``webres_dir`` yields an empty list.
    """
    pairs = []
    try:
        # os.listdir order is filesystem-dependent; sort so the pair order (and
        # hence the order of the saved results) is reproducible.
        webres_files = sorted(f for f in os.listdir(webres_dir) if f.endswith('.webRes'))
    except Exception as e:
        print(f"Ошибка чтения директории {webres_dir}: {e}")
        return pairs

    for webres_file in webres_files:
        base_name = os.path.splitext(webres_file)[0].split('__')[0]
        image_path = os.path.join(images_dir, f"{base_name}.png")

        # isfile (not exists) so a directory that happens to be named *.png is not paired.
        if os.path.isfile(image_path):
            pairs.append((
                os.path.join(webres_dir, webres_file),
                image_path,
                webres_file
            ))
        else:
            print(f"Предупреждение: Не найден .png файл для {webres_file}")

    print(f"Найдено {len(pairs)} пар файлов")
    return pairs


In [None]:
def clean_text(text):
    """Normalise text before classification and ground-truth lookup.

    Whitespace runs are collapsed to a single space first. Then every character
    that is not a word character (Unicode-aware, so Cyrillic is kept),
    whitespace, or one of ``. , ! ? ; :`` is removed. This drops hyphens,
    quotes and brackets: "кто-то" becomes "ктото". Finally the result is stripped.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    filtered = re.sub(r'[^\w\s.,!?;:]', '', collapsed)
    return filtered.strip()

In [None]:
def load_ground_truth(ground_truth_path):
    """Load reference labels keyed by cleaned text.

    The file is expected to hold a JSON list of ``{'text': ..., 'label': ...}``
    records; only the first record is checked for a ``'label'`` key. Keys are
    ``clean_text(record['text'])``, and a later duplicate overwrites an earlier one.

    Returns:
        ``{cleaned_text: label}``, or ``None`` when the path is empty or missing,
        the data does not look labelled, or any error occurs (the error is printed).
    """
    if not (ground_truth_path and os.path.exists(ground_truth_path)):
        return None

    try:
        with open(ground_truth_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)

        looks_labelled = isinstance(records, list) and records and 'label' in records[0]
        if not looks_labelled:
            return None

        mapping = {}
        for record in records:
            key = clean_text(record['text'])
            mapping[key] = record['label']
        return mapping
    except Exception as exc:
        print(f"Ошибка загрузки истинных оценок: {exc}")
        return None

In [None]:
def _collect_task_data(pairs):
    """Run OCR extraction, YOLO detection and box matching for every file pair."""
    collected = []
    for webres_path, image_path, source_file in pairs:
        print(f"\nОбработка: {source_file} + {os.path.basename(image_path)}")
        try:
            pair_tasks = match_text_to_tasks(
                extract_text_from_webres(webres_path),
                get_yolo_boxes(image_path),
            )
            collected.extend(pair_tasks)
            print(f"Найдено задач в текущей паре: {len(pair_tasks)}")
        except Exception as exc:
            # A failing pair is reported and skipped; the rest still run.
            print(f"Ошибка обработки пары: {exc}")
    return collected


def _classify_tasks(evaluator, task_data, ground_truth):
    """Classify each task text; return (entries, true_labels, pred_labels).

    The label lists only include tasks whose cleaned text is a key of ``ground_truth``.
    """
    entries, y_true, y_pred = [], [], []
    for task in task_data:
        try:
            normalized = clean_text(task['text'])
            label, probabilities = evaluator.predict(normalized)

            entry = {
                'task': task['task'],
                'text': task['text'],
                'predicted_label': label,
                'probabilities': probabilities.tolist()
            }

            if ground_truth and normalized in ground_truth:
                expected = ground_truth[normalized]
                entry['true_label'] = expected
                y_true.append(expected)
                y_pred.append(label)

            entries.append(entry)
        except Exception as exc:
            print(f"Ошибка классификации текста для {task['task']}: {exc}")
    return entries, y_true, y_pred


def _report_metrics(y_true, y_pred):
    """Print accuracy, weighted F1, the classification report and the confusion matrix."""
    separator = "=" * 50
    print("\n" + separator)
    print("Результаты валидации модели")
    print(separator)
    print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}")
    print(f"F1 Score (weighted): {f1_score(y_true, y_pred, average='weighted'):.4f}")

    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, digits=4))

    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred))


def _save_and_preview(results, output_file, preview_count=5):
    """Write ``results`` as UTF-8 JSON and print the first few entries.

    Any error while writing or previewing is printed, not raised.
    """
    try:
        with open(output_file, 'w', encoding='utf-8') as handle:
            json.dump(results, handle, ensure_ascii=False, indent=2)

        print("\nПримеры результатов:")
        for entry in results[:preview_count]:
            print(f"\nЗадача: {entry['task']}")
            print(f"Текст: {entry['text'][:100]}...")
            print(f"Предсказанная оценка: {entry['predicted_label']}")
            if 'true_label' in entry:
                print(f"Истинная оценка: {entry['true_label']}")
            print(f"Вероятности: {entry['probabilities']}")

        print(f"\nВсего обработано {len(results)} задач. Результаты сохранены в {output_file}")
    except Exception as exc:
        print(f"Ошибка сохранения результатов: {exc}")


def process_pipeline(webres_dir, images_dir, bert_model_path,
                     ground_truth_path=None,
                     output_file="results.json"):
    """End-to-end run: pair files, extract task texts, classify them, report and save.

    Args:
        webres_dir: directory with ``.webRes`` OCR files.
        images_dir: directory with the matching PNG images.
        bert_model_path: path passed to ``EgeEvaluator`` (``from_pretrained``).
        ground_truth_path: optional JSON of labelled texts. When some task texts
            match it, validation metrics are printed.
        output_file: where the per-task results are written as JSON.

    Returns:
        List of result dicts (``task``, ``text``, ``predicted_label``,
        ``probabilities`` and, when known, ``true_label``).
    """
    task_data = _collect_task_data(find_pairs(webres_dir, images_dir))
    ground_truth = load_ground_truth(ground_truth_path)
    evaluator = EgeEvaluator(bert_model_path)

    results, true_labels, pred_labels = _classify_tasks(evaluator, task_data, ground_truth)

    if true_labels:
        _report_metrics(true_labels, pred_labels)

    _save_and_preview(results, output_file)
    return results

In [None]:
if __name__ == "__main__":
    # Paths are relative to the current working directory.
    WEBRES_DIR = "school (1)"                 # folder with *.webRes OCR files
    IMAGES_DIR = "dataset/val/images"         # folder with the matching *.png images
    BERT_MODEL_PATH = "ege_bert_model"        # BERT checkpoint for EgeEvaluator
    GROUND_TRUTH_PATH = "bert_dataset.json"   # labelled texts used for validation
    OUTPUT_FILE = "final_results.json"

    process_pipeline(
        WEBRES_DIR,
        IMAGES_DIR,
        BERT_MODEL_PATH,
        ground_truth_path=GROUND_TRUTH_PATH,
        output_file=OUTPUT_FILE,
    )