In [1]:
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("bestofbests9/icdar2015")

# print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/icdar2015


In [None]:
# ============================================================================
# ИМПОРТЫ И НАСТРОЙКА ПЛАТФОРМЫ
# ============================================================================

import os
import sys
import json
import glob
import warnings
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any, Union

import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from PIL import Image
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, average_precision_score

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import transforms as T
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.retinanet import RetinaNet
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign, box_iou, nms

import math
from functools import partial
from collections import Counter, defaultdict
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchmetrics.detection import MeanAveragePrecision

from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings('ignore')

# ============================================================================
# КОНФИГУРАЦИЯ
# ============================================================================

class Config:
    """Single container for every tunable setting of the notebook.

    On construction it detects the runtime platform and then populates, as
    plain UPPER_CASE attributes, the data paths, the training hyperparameters,
    the run-mode settings and the parameters of the custom 'myArch' detector.
    """

    def __init__(self):
        self.platform = self._detect_platform()
        # Paths must be set before anything that may depend on them.
        for setup_step in (self.setup_paths,
                           self.setup_hyperparameters,
                           self.setup_training_params,
                           self.setup_myArch_params):
            setup_step()

    def setup_myArch_params(self):
        """Settings for the custom 'myArch' detector."""
        self.MYARCH_BACKBONE = 'resnet50'  # alternative: 'efficientnet_b0'
        self.MYARCH_INPUT_SIZE = (640, 640)
        self.MYARCH_NECK_CHANNELS = 256
        # Text anchors: small and strongly elongated.
        self.MYARCH_ANCHOR_SIZES = (16, 32, 64, 128, 256)
        self.MYARCH_ANCHOR_RATIOS = (0.1, 0.5, 1.0, 2.0, 10.0)
        self.MYARCH_SCORE_THRESH = 0.05
        self.MYARCH_NMS_THRESH = 0.3

    def _detect_platform(self) -> str:
        """Return 'colab', 'kaggle' or 'local' depending on where the code runs."""
        if 'google.colab' in str(sys.modules):
            return 'colab'
        if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
            return 'kaggle'
        return 'local'

    def setup_paths(self):
        """Set the input/working directories and the ICDAR2015 locations for the platform.

        Layout per platform: (data dir, working dir, whether the processed
        annotations live under the data dir instead of the working dir).
        Any platform not listed falls back to the local layout.
        """
        layouts = {
            'kaggle': ('/kaggle/input', '/kaggle/working', False),
            'colab': ('/content/drive/MyDrive', '/content', True),
        }
        data_dir, working_dir, processed_in_data_dir = layouts.get(
            self.platform, ('./data', './output', False))

        self.BASE_DATA_DIR = data_dir
        self.BASE_WORKING_DIR = working_dir
        self.ICDAR2015_ROOT = Path(data_dir) / 'icdar2015'
        processed_parent = data_dir if processed_in_data_dir else working_dir
        self.ICDAR2015_PROCESSED = Path(processed_parent) / 'icdar2015_processed'

    def setup_hyperparameters(self):
        """Model and optimisation hyperparameters."""
        self.NUM_CLASSES = 2  # background + text
        self.NUM_EPOCHS = 20
        self.BATCH_SIZE = 4
        self.LEARNING_RATE = 0.001
        self.WEIGHT_DECAY = 0.0001
        self.PATIENCE = 6  # waiting longer than this is pointless
        self.MIN_SIZE = 600
        self.MAX_SIZE = 1000
        self.CONFIDENCE_THRESHOLD = 0.5
        self.NMS_THRESHOLD = 0.3

    def setup_training_params(self):
        """Run-mode settings: which architecture/backbone to use and what to do."""
        self.ARCHITECTURE = 'faster_rcnn'  # options: 'retinanet', 'faster_rcnn', 'myArch'
        self.BACKBONE = 'resnet50'
        self.MODE = 'train'  # options: 'train', 'compare', 'visualize'
        self.MODEL_PATH = None  # checkpoint path used for testing
        self.COMPARE_ARCHITECTURES = ['retinanet', 'faster_rcnn']
        self.VISUALIZE_SAMPLES = 5

    def get_train_paths(self) -> Dict[str, Path]:
        """Directories of the ICDAR2015 training images and ground-truth files."""
        root = self.ICDAR2015_ROOT
        return {
            'images': root / 'ch4_training_images',
            'labels': root / 'ch4_training_localization_transcription_gt',
        }

    def get_test_paths(self) -> Dict[str, Path]:
        """Directories of the ICDAR2015 test images and ground-truth files."""
        root = self.ICDAR2015_ROOT
        return {
            'images': root / 'ch4_test_images',
            'labels': root / 'ch4_test_localization_transcription_gt',
        }

    def get_processed_paths(self) -> Dict[str, Path]:
        """Files written by the preprocessing step (COCO JSONs and a config dump)."""
        processed = self.ICDAR2015_PROCESSED
        return {
            'train_annotations': processed / 'train_annotations.json',
            'val_annotations': processed / 'val_annotations.json',
            'config': processed / 'config.json',
        }

    def check_data_structure(self) -> bool:
        """Print how many .jpg/.txt/.png files each dataset folder holds.

        Returns:
            True only if all four train/test image/label folders exist.
        """
        train, test = self.get_train_paths(), self.get_test_paths()
        checks = (
            ('Тренировочные изображения', train['images']),
            ('Тренировочные аннотации', train['labels']),
            ('Тестовые изображения', test['images']),
            ('Тестовые аннотации', test['labels']),
        )

        everything_found = True
        for label, folder in checks:
            if not folder.exists():
                print(f" {label}: не найдено")
                everything_found = False
                continue
            total = sum(len(list(folder.glob(pattern)))
                        for pattern in ('*.jpg', '*.txt', '*.png'))
            print(f" {label}: {total} файлов")
        return everything_found

    def print_info(self):
        """Print a short banner summarising platform, mode, architecture, data and device."""
        divider = "=" * 60
        device_name = 'CUDA' if torch.cuda.is_available() else 'CPU'
        for text in (divider,
                     "КОНФИГУРАЦИЯ СИСТЕМЫ",
                     divider,
                     f"Платформа: {self.platform.upper()}",
                     f"Режим: {self.MODE.upper()}",
                     f"Архитектура: {self.ARCHITECTURE}",
                     f"Данные: {self.ICDAR2015_ROOT}",
                     f"Устройство: {device_name}",
                     divider):
            print(text)

# ============================================================================
# ОБРАБОТКА ДАННЫХ
# ============================================================================

class ICDAR2015ToCOCOConverter:
    """Convert ICDAR2015 text-localisation ground truth into a COCO-style JSON file.

    Each ground-truth line ``x1,y1,x2,y2,x3,y3,x4,y4,transcription`` becomes one
    COCO annotation holding the axis-aligned bounding box of the quadrilateral
    (``[x, y, w, h]``), ``category_id`` 0 and the transcription in an extra,
    non-standard ``"text"`` field.
    """

    def __init__(self, images_dir: Path, labels_dir: Path, class_names: List[str]):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.class_names = class_names
        self.coco_data = {
            "images": [],
            "annotations": [],
            "categories": []
        }
        self.annotation_id = 0

        # One COCO category per class name, ids starting at 0.
        for i, name in enumerate(self.class_names):
            self.coco_data["categories"].append({
                "id": i,
                "name": name,
                "supercategory": "object"
            })

    def convert(self, output_path: Path) -> Dict[str, List]:
        """Convert every image in ``images_dir``, write the JSON and return the COCO dict.

        Images are globbed as ``*.jpg`` (falling back to ``*.png`` if there are
        none). Annotations for an image are read from ``gt_<stem>.txt`` or, failing
        that, ``<stem>.txt`` in ``labels_dir``. Images that cannot be read are
        skipped with a message; ``img_id`` is the enumeration index, so ids may
        have gaps.
        """
        image_files = list(self.images_dir.glob("*.jpg"))
        if not image_files:
            image_files = list(self.images_dir.glob("*.png"))

        print(f"Найдено {len(image_files)} изображений")

        for img_id, img_path in enumerate(tqdm(image_files, desc="Конвертация ICDAR2015 → COCO")):
            try:
                img = cv2.imread(str(img_path))
                if img is None:
                    print(f"Не удалось загрузить: {img_path}")
                    continue

                height, width = img.shape[:2]

                self.coco_data["images"].append({
                    "id": img_id,
                    "file_name": img_path.name,
                    "width": width,
                    "height": height
                })

                base_name = img_path.stem
                anno_file = self.labels_dir / f"gt_{base_name}.txt"

                if anno_file.exists():
                    self._process_annotation(anno_file, img_id)
                else:
                    # Alternative naming scheme without the "gt_" prefix.
                    alt_anno_file = self.labels_dir / f"{base_name}.txt"
                    if alt_anno_file.exists():
                        self._process_annotation(alt_anno_file, img_id)

            except Exception as e:
                print(f"Ошибка при обработке {img_path}: {e}")

        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w') as f:
            json.dump(self.coco_data, f, indent=2)

        print(f"\n COCO аннотации сохранены в {output_path}")
        print(f"  Изображений: {len(self.coco_data['images'])}")
        print(f"  Аннотаций: {len(self.coco_data['annotations'])}")

        return self.coco_data

    @staticmethod
    def _read_lines(anno_file: Path) -> List[str]:
        """Read an annotation file as UTF-8 (BOM stripped), falling back to latin-1."""
        try:
            with open(anno_file, 'r', encoding='utf-8-sig') as f:
                return f.readlines()
        except UnicodeDecodeError:
            # latin-1 maps every byte to a code point, so this fallback always decodes.
            with open(anno_file, 'r', encoding='latin-1') as f:
                return f.readlines()

    @staticmethod
    def _split_fields(line: str) -> Optional[List[str]]:
        """Split one ground-truth line into 8 coordinates plus the transcription.

        Tries comma, tab and then arbitrary whitespace as the separator. Each split
        is capped at 8 so that the ninth field (the transcription) stays whole even
        if it contains the separator (e.g. ``"1,000"``). Returns None if no
        separator yields at least 8 fields.
        """
        for separator in (',', '\t', None):
            parts = line.split(separator, 8)
            if len(parts) >= 8:
                return parts
        return None

    def _process_annotation(self, anno_file: Path, image_id: int):
        """Append one COCO annotation per valid quadrilateral in ``anno_file``.

        Lines that cannot be parsed are skipped. So are boxes whose width or height
        is <= 2 px or > 10000 px.
        """
        # NOTE(review): ICDAR2015 marks illegible regions with the transcription
        # "###"; they are kept here as ordinary text boxes — consider skipping them
        # or flagging them as iscrowd.
        for raw_line in self._read_lines(anno_file):
            line = raw_line.strip()
            if not line:
                continue

            parts = self._split_fields(line)
            if parts is None:
                continue

            try:
                coords = [int(float(part.strip())) for part in parts[:8]]
            except (ValueError, OverflowError):
                # Non-numeric, NaN or infinite coordinate: skip the malformed line.
                continue

            x_coords = coords[0::2]
            y_coords = coords[1::2]

            # Axis-aligned bounding box of the (possibly rotated) quadrilateral.
            x_min, x_max = min(x_coords), max(x_coords)
            y_min, y_max = min(y_coords), max(y_coords)
            width, height = x_max - x_min, y_max - y_min

            if width <= 2 or height <= 2 or width > 10000 or height > 10000:
                continue

            self.coco_data["annotations"].append({
                "id": self.annotation_id,
                "image_id": image_id,
                "category_id": 0,  # single class: text
                "bbox": [float(x_min), float(y_min), float(width), float(height)],
                "area": float(width * height),
                "iscrowd": 0,
                "segmentation": [],
                "text": parts[8] if len(parts) > 8 else ""
            })
            self.annotation_id += 1

# ============================================================================
# ДАТАСЕТ
# ============================================================================

class ICDAR2015Dataset(Dataset):
    """PyTorch dataset over the COCO-style JSON written by ICDAR2015ToCOCOConverter.

    Each item is ``(image, target)``. ``target`` holds:
    - 'boxes' [N, 4] float32, (x1, y1, x2, y2) in image pixels, clipped to the image;
    - 'labels' [N] int64, equal to ``category_id + 1`` (0 is left for background);
    - 'image_id', 'area' and 'iscrowd' (always 0).
    ``transform`` is applied to the image only; boxes are never transformed.
    """

    def __init__(self, images_dir: Path, annotations_path: Path,
                 transform: Optional[callable] = None,
                 max_size: Optional[int] = None):
        """
        Args:
            images_dir: folder containing the image files named in the JSON.
            annotations_path: COCO-style JSON with "images" and "annotations".
            transform: callable applied to the PIL image (e.g. a torchvision Compose).
            max_size: if given, keep only the first ``max_size`` images (for debugging).
        """
        self.images_dir = images_dir
        self.transform = transform

        with open(annotations_path, 'r') as f:
            self.coco_data = json.load(f)

        self._create_indices()

        if max_size and max_size < len(self.image_ids):
            self.image_ids = self.image_ids[:max_size]

    def _create_indices(self):
        """Build image_id -> image info and image_id -> [annotations] lookups."""
        self.image_id_to_info = {img["id"]: img for img in self.coco_data["images"]}
        self.image_id_to_annotations = {}

        for ann in self.coco_data["annotations"]:
            self.image_id_to_annotations.setdefault(ann["image_id"], []).append(ann)

        self.image_ids = list(self.image_id_to_info.keys())

    def __len__(self) -> int:
        return len(self.image_ids)

    @staticmethod
    def _annotations_to_boxes(annotations: List[Dict[str, Any]], img_w: float,
                              img_h: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn COCO ``[x, y, w, h]`` annotations into clipped (x1, y1, x2, y2) boxes.

        Clipping is done on the corners, so a box with a negative origin is
        trimmed rather than shifted. Boxes narrower or shorter than 1 px after
        clipping (including boxes entirely outside the image) are dropped.

        Returns:
            (boxes [N, 4] float32, labels [N] int64); both are empty tensors when
            nothing survives.
        """
        boxes, labels = [], []
        for ann in annotations:
            x_min, y_min, width, height = ann["bbox"]
            x1 = max(0.0, float(x_min))
            y1 = max(0.0, float(y_min))
            x2 = min(float(img_w), float(x_min) + float(width))
            y2 = min(float(img_h), float(y_min) + float(height))
            if x2 - x1 < 1 or y2 - y1 < 1:
                continue
            boxes.append([x1, y1, x2, y2])
            labels.append(ann["category_id"] + 1)  # +1: label 0 is background

        if not boxes:
            return torch.zeros((0, 4), dtype=torch.float32), torch.zeros(0, dtype=torch.int64)
        return (torch.as_tensor(boxes, dtype=torch.float32),
                torch.as_tensor(labels, dtype=torch.int64))

    def _resolve_image_path(self, file_name: str) -> Path:
        """Return the image path, trying .jpg and then .png if the recorded name is missing."""
        img_path = self.images_dir / file_name
        if img_path.exists():
            return img_path
        stem = file_name.split('.')[0]
        jpg_path = self.images_dir / f"{stem}.jpg"
        if jpg_path.exists():
            return jpg_path
        return self.images_dir / f"{stem}.png"

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        image_id = self.image_ids[idx]
        image_info = self.image_id_to_info[image_id]

        img_path = self._resolve_image_path(image_info["file_name"])
        try:
            image = Image.open(img_path).convert("RGB")
        except OSError as e:
            # Best effort: keep the sample with a grey placeholder of the recorded
            # size (so the boxes stay in range), but report the failure.
            print(f"Не удалось загрузить: {img_path} ({e})")
            image = Image.new("RGB", (image_info["width"], image_info["height"]), (128, 128, 128))

        annotations = self.image_id_to_annotations.get(image_id, [])
        boxes, labels = self._annotations_to_boxes(
            annotations, image_info["width"], image_info["height"])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([image_id]),
            "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            "iscrowd": torch.zeros(len(boxes), dtype=torch.int64)
        }

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                # Fall back to a plain tensor conversion, but do not hide the error:
                # this sample will not get the transform's normalisation.
                print(f"Ошибка при обработке {img_path}: {e}")
                image = T.ToTensor()(image)

        return image, target

    def get_image_info(self, idx: int) -> Dict[str, Any]:
        """Return the COCO "images" entry for dataset index ``idx``."""
        image_id = self.image_ids[idx]
        return self.image_id_to_info[image_id]

# ============================================================================
# АУГМЕНТАЦИЯ И ТРАНСФОРМАЦИИ
# ============================================================================

def get_train_transforms():
    """Training-time image preprocessing: tensor conversion, colour jitter, normalisation.

    ICDAR2015Dataset applies this transform to the image only; the target boxes
    are never passed through it. Any geometric augmentation here (flip, crop,
    resize) would move the text without moving its boxes and corrupt the labels.
    The former RandomHorizontalFlip did exactly that, so only photometric
    augmentation is used.
    """
    # NOTE(review): torchvision's detection models (FasterRCNN, RetinaNet) normalise
    # their inputs internally with their own image_mean/image_std; if they receive
    # this output, images may be normalised twice — confirm.
    return T.Compose([
        T.ToTensor(),
        # Photometric only: pixel positions are unchanged, so boxes stay valid.
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def get_val_transforms():
    """Validation-time preprocessing without augmentation.

    Converts the PIL image to a float tensor and normalises it with the
    ImageNet channel mean/std (the same values as the training pipeline).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = [
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(pipeline)


# ============================================================================
# СВОЯ АРХИТЕКТУРА (основа взята с семинара)
# ============================================================================

class MyArchDetector(nn.Module):
    """Minimal single-level, anchor-based detector on a ResNet-18 backbone.

    Each image is resized to ``input_size``, passed through the ResNet trunk,
    projected by the 1x1 conv ``fpn['c5']`` and fed to two conv heads. For every
    cell of the final feature map and each of 9 anchors, the heads predict
    ``num_classes + 1`` class logits (index 0 = background) and 4 box deltas
    (dx, dy, log dw, log dh).

    Ground-truth boxes in ``targets`` and the predicted boxes returned at
    inference are in the pixel coordinates of the image passed to ``forward``.
    Internally they are rescaled to the ``input_size`` frame in which the anchors
    are defined.

    The original author noted that this deliberately simple version did not
    work for text detection.
    """

    def __init__(self, backbone_name="resnet18", num_classes=2,
                 input_size=(640, 640), neck_channels=512,
                 score_thresh=0.1, nms_thresh=0.5):
        """
        Args:
            backbone_name: accepted for interface compatibility but currently
                ignored; ResNet-18 is always used because the neck hard-codes its
                channel widths (512/256/128).
            num_classes: inference reports labels 1..num_classes; the heads predict
                ``num_classes + 1`` logits per anchor (index 0 is background).
            input_size: (height, width) that every image is resized to. The anchor
                grid (``grid_size`` = 20) assumes the default 640x640, i.e. an
                output stride of 32.
            neck_channels: width of the projection and head convolutions.
            score_thresh: minimum class probability kept at inference.
            nms_thresh: IoU threshold of the per-class NMS at inference.
        """
        super().__init__()

        self.num_classes = num_classes
        self.input_size = input_size
        self.grid_size = 20  # anchor grid side; must match the backbone feature map
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh

        # Pretrained ImageNet ResNet-18 (same weights as the former pretrained=True).
        resnet = torchvision.models.resnet18(weights='DEFAULT')

        # Freeze layer1/layer2 *before* wrapping in nn.Sequential. Afterwards the
        # children are renamed '4', '5', ..., so matching parameter names against
        # 'layer1'/'layer2' (as the previous version did) froze nothing.
        for frozen_stage in (resnet.layer1, resnet.layer2):
            for param in frozen_stage.parameters():
                param.requires_grad = False

        # Drop avgpool and fc; keep the convolutional trunk (512 output channels).
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Simplified FPN projections. Only 'c5' is used in forward; 'c3', 'c4' and
        # fpn_upsample are kept so existing checkpoints load unchanged.
        self.fpn = nn.ModuleDict({
            'c5': nn.Conv2d(512, neck_channels, 1),
            'c4': nn.Conv2d(256, neck_channels, 1),
            'c3': nn.Conv2d(128, neck_channels, 1),
        })

        self.fpn_upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # Classification head: 9 anchors x (num_classes + 1) logits per cell.
        self.cls_head = nn.Sequential(
            nn.Conv2d(neck_channels, neck_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(neck_channels, 9 * (num_classes + 1), 3, padding=1)
        )

        # Regression head: 9 anchors x 4 deltas per cell.
        self.reg_head = nn.Sequential(
            nn.Conv2d(neck_channels, neck_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(neck_channels, 9 * 4, 3, padding=1)
        )

        # 3 sizes x 3 aspect ratios = 9 anchors per grid cell.
        self.anchor_sizes = [32, 64, 128]  # in pixels of the input_size frame
        self.anchor_ratios = [0.5, 1.0, 2.0]
        self.num_anchors = len(self.anchor_sizes) * len(self.anchor_ratios)

        self.anchors = self._generate_anchors()

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise only the newly added layers.

        Iterating over all of ``self.modules()`` (as before) also re-initialised
        the pretrained backbone convolutions, discarding the ImageNet weights.
        """
        for new_part in (self.fpn, self.cls_head, self.reg_head):
            for m in new_part.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight, std=0.01)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def _generate_anchors(self):
        """Return anchors [grid_size * grid_size * 9, 4] as (x1, y1, x2, y2).

        Ordering is row-major over the grid, then anchor size, then aspect ratio,
        which matches the permute/reshape of the head outputs. The stride is
        derived from ``input_size[0]`` and used for both axes.
        """
        anchors = []
        stride = self.input_size[0] // self.grid_size

        for i in range(self.grid_size):
            for j in range(self.grid_size):
                center_x = (j + 0.5) * stride
                center_y = (i + 0.5) * stride

                for size in self.anchor_sizes:
                    for ratio in self.anchor_ratios:
                        w = size * math.sqrt(ratio)
                        h = size / math.sqrt(ratio)
                        anchors.append([center_x - w / 2, center_y - h / 2,
                                        center_x + w / 2, center_y + h / 2])

        return torch.tensor(anchors)

    def forward(self, images, targets=None):
        """Return a loss dict in training mode with targets, otherwise a list of predictions."""
        device = images[0].device

        if self.training and targets is not None:
            return self._forward_train(images, targets, device)
        else:
            return self._forward_inference(images, device)

    def _predict_single(self, image):
        """Run backbone and heads on one image (``image.shape[-2:]`` = height, width).

        Returns:
            cls_logits: [A, num_classes + 1] raw logits, one row per anchor.
            reg_preds: [A, 4] predicted deltas, one row per anchor.
            scale: float32 tensor [sx, sy, sx, sy] mapping original-image
                coordinates to the resized ``input_size`` frame.
        """
        orig_h, orig_w = image.shape[-2:]
        img_tensor = F.interpolate(image.unsqueeze(0), size=self.input_size,
                                   mode='bilinear', align_corners=False)

        features = self.backbone(img_tensor)
        neck_features = self.fpn['c5'](features)

        cls_logits = self.cls_head(neck_features)
        reg_preds = self.reg_head(neck_features)

        _, _, grid_h, grid_w = cls_logits.shape
        num_preds = grid_h * grid_w * self.num_anchors
        cls_logits = cls_logits.permute(0, 2, 3, 1).reshape(num_preds, self.num_classes + 1)
        reg_preds = reg_preds.permute(0, 2, 3, 1).reshape(num_preds, 4)

        sx = self.input_size[1] / orig_w
        sy = self.input_size[0] / orig_h
        scale = torch.tensor([sx, sy, sx, sy], dtype=torch.float32, device=image.device)
        return cls_logits, reg_preds, scale

    def _assign_targets(self, anchors, target_boxes, target_labels):
        """Match anchors to ground-truth boxes (both in the ``input_size`` frame).

        An anchor is positive if its best IoU is > 0.5, negative if it is < 0.3,
        and ignored otherwise. Additionally, each ground-truth box's best
        overlapping anchor is made positive. This is a low-quality match, similar
        to torchvision's Matcher with allow_low_quality_matches=True. It
        guarantees that small or elongated text boxes, which may never reach
        IoU 0.5 with any anchor, still get a positive sample.

        Returns:
            labels: [A] long; class label for positives, 0 for negatives, -1 for ignored.
            deltas: [A, 4] regression targets (dx, dy, log dw, log dh); zero
                except for positive anchors.
        """
        num_anchors = anchors.shape[0]
        device = anchors.device
        assigned_boxes = torch.zeros((num_anchors, 4), device=device)

        if len(target_boxes) == 0:
            # No objects: every anchor is background.
            return torch.zeros(num_anchors, dtype=torch.long, device=device), assigned_boxes

        ious = box_iou(anchors, target_boxes)  # [A, G]
        best_ious, best_idx = ious.max(dim=1)

        pos_mask = best_ious > 0.5
        neg_mask = best_ious < 0.3

        gt_best_ious, gt_best_anchor = ious.max(dim=0)
        forced = gt_best_anchor[gt_best_ious > 0]
        pos_mask[forced] = True
        neg_mask[forced] = False

        assigned_labels = torch.full((num_anchors,), -1, dtype=torch.long, device=device)
        assigned_labels[neg_mask] = 0
        assigned_labels[pos_mask] = target_labels[best_idx[pos_mask]]

        if pos_mask.any():
            pos_anchors = anchors[pos_mask]
            pos_targets = target_boxes[best_idx[pos_mask]]

            eps = 1e-8  # avoids division by zero / log(0) for degenerate boxes
            anchor_centers = (pos_anchors[:, :2] + pos_anchors[:, 2:]) / 2
            anchor_sizes = (pos_anchors[:, 2:] - pos_anchors[:, :2]).clamp(min=eps)
            target_centers = (pos_targets[:, :2] + pos_targets[:, 2:]) / 2
            target_sizes = (pos_targets[:, 2:] - pos_targets[:, :2]).clamp(min=eps)

            dxdy = (target_centers - anchor_centers) / anchor_sizes
            dwh = torch.log(target_sizes / anchor_sizes)
            assigned_boxes[pos_mask] = torch.cat([dxdy, dwh], dim=1)

        return assigned_labels, assigned_boxes

    def _forward_train(self, images, targets, device):
        """Compute training losses for a list of images and their target dicts.

        Only 'loss_classifier' (cross-entropy over non-ignored anchors) and
        'loss_box_reg' (smooth-L1 over positive anchors, weighted x10) are
        computed. The other two keys are zero placeholders that mirror the key
        names of torchvision's Faster R-CNN loss dict.
        """
        losses = {
            'loss_classifier': torch.tensor(0.0, device=device),
            'loss_box_reg': torch.tensor(0.0, device=device),
            'loss_objectness': torch.tensor(0.0, device=device),
            'loss_rpn_box_reg': torch.tensor(0.0, device=device),
        }

        cls_losses = []
        reg_losses = []

        anchors = self.anchors.to(device)

        for image, target in zip(images, targets):
            cls_logits, reg_preds, scale = self._predict_single(image)

            # Ground truth is in original-image pixels; anchors live in the
            # resized input_size frame, so bring the boxes into that frame.
            target_boxes = target['boxes'].to(device) * scale
            target_labels = target['labels'].to(device)

            assigned_labels, assigned_boxes = self._assign_targets(
                anchors, target_boxes, target_labels)

            valid_mask = assigned_labels != -1
            if valid_mask.any():
                cls_loss = F.cross_entropy(
                    cls_logits[valid_mask],
                    assigned_labels[valid_mask],
                    reduction='mean'
                )
            else:
                cls_loss = torch.tensor(0.0, device=device)

            pos_mask = assigned_labels > 0
            if pos_mask.any():
                reg_loss = F.smooth_l1_loss(
                    reg_preds[pos_mask],
                    assigned_boxes[pos_mask],
                    reduction='mean'
                )
            else:
                reg_loss = torch.tensor(0.0, device=device)

            cls_losses.append(cls_loss)
            reg_losses.append(reg_loss)

        if cls_losses and reg_losses:
            losses['loss_classifier'] = torch.stack(cls_losses).mean() * 1.0
            losses['loss_box_reg'] = torch.stack(reg_losses).mean() * 10.0  # regression weighted higher

        return losses

    def _forward_inference(self, images, device):
        """Predict boxes for each image.

        Returns:
            One dict per image with 'boxes' [K, 4] (x1, y1, x2, y2 in the pixel
            coordinates of that image, clipped to it), 'labels' [K] int64 and
            'scores' [K]. Per class, predictions with probability > score_thresh
            are kept and then reduced by NMS at nms_thresh.
        """
        predictions = []
        anchors = self.anchors.to(device)

        for image in images:
            cls_logits, reg_preds, scale = self._predict_single(image)
            orig_h, orig_w = image.shape[-2:]

            # Decode in the input_size frame, map back to the original image, clip.
            boxes = self._decode_boxes(anchors, reg_preds) / scale
            limits = torch.tensor([orig_w, orig_h, orig_w, orig_h],
                                  dtype=boxes.dtype, device=boxes.device)
            boxes = torch.min(boxes.clamp(min=0), limits)

            scores = F.softmax(cls_logits, dim=-1)

            final_boxes, final_labels, final_scores = [], [], []

            for cls_idx in range(1, self.num_classes + 1):  # skip background (index 0)
                cls_scores = scores[:, cls_idx]

                score_mask = cls_scores > self.score_thresh
                if not score_mask.any():
                    continue

                cls_boxes = boxes[score_mask]
                cls_scores = cls_scores[score_mask]

                keep = nms(cls_boxes, cls_scores, self.nms_thresh)

                if len(keep) > 0:
                    final_boxes.append(cls_boxes[keep])
                    final_labels.append(torch.full((len(keep),), cls_idx,
                                                   device=device, dtype=torch.int64))
                    final_scores.append(cls_scores[keep])

            if final_boxes:
                pred_dict = {
                    'boxes': torch.cat(final_boxes, dim=0),
                    'labels': torch.cat(final_labels, dim=0),
                    'scores': torch.cat(final_scores, dim=0)
                }
            else:
                pred_dict = {
                    'boxes': torch.zeros((0, 4), device=device),
                    'labels': torch.zeros(0, dtype=torch.int64, device=device),
                    'scores': torch.zeros(0, device=device)
                }

            predictions.append(pred_dict)

        return predictions

    def _decode_boxes(self, anchors, reg_preds):
        """Apply deltas (dx, dy, log dw, log dh) to anchors; returns (x1, y1, x2, y2) boxes."""
        anchor_centers = (anchors[:, :2] + anchors[:, 2:]) / 2
        anchor_sizes = anchors[:, 2:] - anchors[:, :2]

        # Clamp deltas so exp() cannot explode early in training.
        reg_preds = torch.clamp(reg_preds, -4.0, 4.0)

        centers = anchor_centers + reg_preds[:, :2] * anchor_sizes
        sizes = anchor_sizes * torch.exp(reg_preds[:, 2:])

        return torch.cat([
            centers - sizes / 2,
            centers + sizes / 2
        ], dim=1)

# ============================================================================
# ФАБРИКА МОДЕЛЕЙ
# ============================================================================

class ModelFactory:
    """Factory that builds detection models by architecture name."""

    @staticmethod
    def create_model(architecture: str, num_classes: int = 2, **kwargs) -> nn.Module:
        """Build a detection model by architecture name.

        Args:
            architecture: one of ``'faster_rcnn'``, ``'retinanet'``, ``'myArch'``.
            num_classes: number of classes passed to the model constructor.
            **kwargs: forwarded unchanged to the architecture-specific builder.

        Returns:
            The constructed PyTorch model.

        Raises:
            ValueError: if ``architecture`` is not one of the names above.
        """
        builders = {
            'faster_rcnn': ModelFactory.create_faster_rcnn,
            'retinanet': ModelFactory.create_retinanet,
            'myArch': ModelFactory.create_myArch,
        }
        builder = builders.get(architecture)
        if builder is None:
            raise ValueError(f"Неизвестная архитектура: {architecture}")
        return builder(num_classes, **kwargs)

    @staticmethod
    def create_faster_rcnn(num_classes: int = 2, **kwargs):
        """Build a Faster R-CNN with an ImageNet-pretrained ResNet-FPN backbone.

        Recognised kwargs (defaults in parentheses):
            backbone ('resnet50'), trainable_layers (5),
            anchor_sizes (((32,), (64,), (128,), (256,), (512,))),
            aspect_ratios ((0.5, 1.0, 2.0) repeated once per anchor-size level),
            min_size (600), max_size (1000),
            box_score_thresh (0.05), box_nms_thresh (0.3), box_detections_per_img (100).
        """
        backbone = resnet_fpn_backbone(
            backbone_name=kwargs.get('backbone', 'resnet50'),
            weights='DEFAULT',
            trainable_layers=kwargs.get('trainable_layers', 5)
        )

        # One anchor size per feature level; ratios default to the same triple on every level.
        anchor_sizes = kwargs.get('anchor_sizes', ((32,), (64,), (128,), (256,), (512,)))
        aspect_ratios = kwargs.get('aspect_ratios', ((0.5, 1.0, 2.0),) * len(anchor_sizes))

        anchor_generator = AnchorGenerator(
            sizes=anchor_sizes,
            aspect_ratios=aspect_ratios
        )

        # RoIAlign over FPN feature maps '0'..'3'.
        roi_pooler = MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2
        )

        return FasterRCNN(
            backbone,
            num_classes=num_classes,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
            min_size=kwargs.get('min_size', 600),
            max_size=kwargs.get('max_size', 1000),
            box_score_thresh=kwargs.get('box_score_thresh', 0.05),
            box_nms_thresh=kwargs.get('box_nms_thresh', 0.3),
            box_detections_per_img=kwargs.get('box_detections_per_img', 100)
        )

    @staticmethod
    def create_retinanet(num_classes: int = 2, **kwargs):
        """Build a RetinaNet with an ImageNet-pretrained ResNet-FPN backbone.

        Recognised kwargs (defaults in parentheses):
            backbone ('resnet18'), trainable_layers (4),
            anchor_sizes (all five sizes 32..512 on each of 5 levels),
            aspect_ratios ((0.5, 1.0, 2.0) on each of 5 levels),
            min_size (800), max_size (1333), score_thresh (0.05), nms_thresh (0.5).

        ``anchor_sizes`` may also be a flat tuple of ints or a single-element tuple of
        sizes; both are expanded so every one of the 5 levels gets the same sizes.
        """
        backbone = resnet_fpn_backbone(
            backbone_name=kwargs.get('backbone', 'resnet18'),
            weights='DEFAULT',
            trainable_layers=kwargs.get('trainable_layers', 4)
        )

        anchor_sizes = kwargs.get('anchor_sizes', ((32, 64, 128, 256, 512),) * 5)
        aspect_ratios = kwargs.get('aspect_ratios', ((0.5, 1.0, 2.0),) * 5)

        # Normalise shorthand forms into one size tuple per feature level.
        if isinstance(anchor_sizes[0], int):
            anchor_sizes = (tuple(anchor_sizes),) * 5
        elif isinstance(anchor_sizes[0], (list, tuple)) and len(anchor_sizes) == 1:
            anchor_sizes = (tuple(anchor_sizes[0]),) * 5

        anchor_generator = AnchorGenerator(
            sizes=anchor_sizes,
            aspect_ratios=aspect_ratios
        )

        return RetinaNet(
            backbone,
            num_classes=num_classes,
            anchor_generator=anchor_generator,
            min_size=kwargs.get('min_size', 800),
            max_size=kwargs.get('max_size', 1333),
            score_thresh=kwargs.get('score_thresh', 0.05),
            nms_thresh=kwargs.get('nms_thresh', 0.5)
        )

    @staticmethod
    def create_myArch(num_classes: int = 2, **kwargs):
        """Build the custom MyArchDetector.

        Only ``input_size`` (default ``(640, 640)``) is read from kwargs; every other
        keyword (backbone, anchors, ...) is accepted and ignored.
        """
        return MyArchDetector(
            num_classes=num_classes,
            input_size=kwargs.get('input_size', (640, 640))
        )
    
# ============================================================================
# МЕТРИКИ
# ============================================================================

class MetricsCalculator:
    """Single-class detection metrics: precision, recall, VOC-style AP, mAP and F1.

    Class labels are ignored. Each predicted box is matched only against the
    ground-truth boxes of the same image, whatever its label.
    """

    @staticmethod
    def calculate_metrics(predictions: List[Dict], targets: List[Dict],
                          iou_thresholds: Tuple[float, ...] = (0.5, 0.75)) -> Dict[str, float]:
        """Compute detection metrics over a dataset.

        Args:
            predictions: one dict per image with ``'boxes'`` (x1, y1, x2, y2) and
                ``'scores'`` tensors.
            targets: one dict per image, in the same order, with a ``'boxes'`` tensor.
            iou_thresholds: IoU thresholds at which precision/recall/AP are computed.

        Returns:
            ``precision_{t}``, ``recall_{t}`` and ``ap_{t}`` for every threshold ``t``, plus
            ``'map'`` (mean AP over thresholds) and ``'f1'`` (from precision/recall at 0.5).
            Precision/recall are taken at the last (lowest-score) ranked prediction.
            All keys are present (as 0.0) even when either input list is empty.
        """
        iou_thresholds = tuple(iou_thresholds)
        has_data = bool(predictions) and bool(targets)
        metrics: Dict[str, float] = {}

        for iou_thresh in iou_thresholds:
            if has_data:
                precision, recall, ap = MetricsCalculator._calculate_ap(
                    predictions, targets, iou_thresh
                )
            else:
                # Emit every key anyway so callers can index 'map' / 'ap_*' unconditionally.
                precision = recall = ap = 0.0
            metrics[f'precision_{iou_thresh}'] = precision
            metrics[f'recall_{iou_thresh}'] = recall
            metrics[f'ap_{iou_thresh}'] = ap

        ap_values = [metrics[f'ap_{iou}'] for iou in iou_thresholds]
        metrics['map'] = float(np.mean(ap_values)) if ap_values else 0.0

        # F1 from precision/recall at IoU=0.5 (0.0 when 0.5 is not among the thresholds).
        p, r = metrics.get('precision_0.5', 0.0), metrics.get('recall_0.5', 0.0)
        metrics['f1'] = 2 * p * r / (p + r + 1e-10) if (p + r) > 0 else 0.0

        return metrics

    @staticmethod
    def _calculate_ap(predictions: List[Dict], targets: List[Dict],
                      iou_threshold: float = 0.5) -> Tuple[float, float, float]:
        """Return ``(precision, recall, AP)`` at one IoU threshold.

        Matching is greedy and per image (VOC style): predictions of an image are
        visited in descending score order. A prediction is a true positive when its
        best-IoU ground-truth box of that image reaches ``iou_threshold`` and has not
        been claimed yet; otherwise it is a false positive. All predictions are then
        ranked globally by score to build the precision/recall curve.
        """
        scored_hits: List[Tuple[float, int]] = []
        num_gt = 0

        for pred, target in zip(predictions, targets):
            gt_boxes = target['boxes'].cpu().numpy().reshape(-1, 4)
            num_gt += len(gt_boxes)

            if len(pred['boxes']) == 0:
                continue

            pred_boxes = pred['boxes'].cpu().numpy().reshape(-1, 4)
            pred_scores = pred['scores'].cpu().numpy().reshape(-1)
            order = np.argsort(-pred_scores, kind='stable')
            pred_boxes, pred_scores = pred_boxes[order], pred_scores[order]

            # IoU only against this image's ground truth -- never across images.
            ious = MetricsCalculator._calculate_iou_matrix(pred_boxes, gt_boxes)
            claimed = set()

            for i, score in enumerate(pred_scores):
                hit = 0
                if ious.shape[1] > 0:
                    best = int(np.argmax(ious[i]))
                    if ious[i, best] >= iou_threshold and best not in claimed:
                        claimed.add(best)
                        hit = 1
                scored_hits.append((float(score), hit))

        if not scored_hits or num_gt == 0:
            return 0.0, 0.0, 0.0

        scored_hits.sort(key=lambda item: item[0], reverse=True)
        hits = np.array([hit for _, hit in scored_hits], dtype=np.float64)

        tp_cumsum = np.cumsum(hits)
        fp_cumsum = np.cumsum(1.0 - hits)
        precision = tp_cumsum / (tp_cumsum + fp_cumsum)  # denominator >= 1 at every rank
        recall = tp_cumsum / num_gt                      # num_gt > 0 checked above

        ap = MetricsCalculator._calculate_average_precision(precision, recall)
        return float(precision[-1]), float(recall[-1]), float(ap)

    @staticmethod
    def _calculate_iou_matrix(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
        """Pairwise IoU between two sets of ``(x1, y1, x2, y2)`` boxes.

        Returns an array of shape ``(len(boxes1), len(boxes2))``. Pairs whose union
        area is not positive (degenerate boxes) get IoU 0 instead of NaN.
        """
        boxes1 = np.asarray(boxes1, dtype=np.float64).reshape(-1, 4)
        boxes2 = np.asarray(boxes2, dtype=np.float64).reshape(-1, 4)
        if len(boxes1) == 0 or len(boxes2) == 0:
            return np.zeros((len(boxes1), len(boxes2)))

        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        top_left = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2])
        bottom_right = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])
        overlap_wh = np.clip(bottom_right - top_left, 0.0, None)
        inter = overlap_wh[..., 0] * overlap_wh[..., 1]
        union = area1[:, None] + area2[None, :] - inter

        return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

    @staticmethod
    def _calculate_average_precision(precision: np.ndarray, recall: np.ndarray) -> float:
        """Area under the precision-recall curve (VOC all-point interpolation).

        The curve is padded with (recall 0, precision 0) and (recall 1, precision 0).
        Precision is replaced by its running maximum from the right, and the area is
        summed over the points where recall changes.
        """
        if len(precision) == 0 or len(recall) == 0:
            return 0.0

        precision = np.concatenate(([0.0], precision, [0.0]))
        recall = np.concatenate(([0.0], recall, [1.0]))

        # Monotone envelope: precision[i] = max(precision[i:]).
        precision = np.maximum.accumulate(precision[::-1])[::-1]

        steps = np.where(recall[1:] != recall[:-1])[0]
        return float(np.sum((recall[steps + 1] - recall[steps]) * precision[steps + 1]))

# ============================================================================
# ТРЕНИРОВКА
# ============================================================================

class Trainer:
    """Train a detection model, validate after every epoch and write checkpoints.

    The model is called the way torchvision detection models are. In train mode,
    ``model(images, targets)`` must return a dict of loss tensors. In eval mode,
    ``model(images)`` must return one prediction dict per image.

    Optional config overrides (read with ``getattr``; defaults keep the old behaviour):
        LR_MILESTONES (20, 40, 60): epochs at which the LR is multiplied by 0.1.
        VAL_BATCHES_PER_EPOCH (20): validation batches evaluated after each epoch.
    """

    def __init__(self, config: Config, model: nn.Module,
                 train_loader: DataLoader, val_loader: DataLoader):
        self.config = config
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # SGD with Nesterov momentum; LR and weight decay come from the config.
        self.optimizer = optim.SGD(
            model.parameters(),
            lr=config.LEARNING_RATE,
            momentum=0.9,
            weight_decay=config.WEIGHT_DECAY,
            nesterov=True
        )

        # Step decay: LR *= 0.1 at every milestone epoch.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=list(getattr(config, 'LR_MILESTONES', (20, 40, 60))),
            gamma=0.1
        )

        # Validation after each epoch uses only this many batches to keep epochs short.
        self.val_batches_per_epoch = getattr(config, 'VAL_BATCHES_PER_EPOCH', 20)

        self.train_losses = []
        self.val_metrics_history = []
        self.best_map = 0.0

    def train_one_epoch(self, epoch: int) -> float:
        """Run one training epoch.

        Batches with a non-finite loss, or that raise ``RuntimeError``, are skipped
        and do not count towards the average.

        Returns:
            Mean loss over successful batches, or NaN if none succeeded.
        """
        self.model.train()
        total_loss = 0.0
        batch_count = 0

        progress_bar = tqdm(self.train_loader, desc=f"Эпоха {epoch+1}")

        for batch_idx, (images, targets) in enumerate(progress_bar):
            if len(images) == 0:
                continue

            images = [img.to(self.device) for img in images]
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

            try:
                loss_dict = self.model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                # Do not step on a NaN/Inf loss; it would poison the weights.
                if not torch.isfinite(losses):
                    print(f"   NaN/Inf loss в батче {batch_idx}")
                    self.optimizer.zero_grad()
                    continue

                self.optimizer.zero_grad()
                losses.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
                self.optimizer.step()

                total_loss += losses.item()
                batch_count += 1

                progress_bar.set_postfix({"loss": f"{losses.item():.4f}"})

            except RuntimeError as e:
                # Best-effort: report and skip the batch (e.g. a failing forward pass).
                print(f"   Ошибка в батче {batch_idx}: {e}")
                continue

        return total_loss / batch_count if batch_count > 0 else float('nan')

    def evaluate(self, num_samples: int = None) -> Dict[str, float]:
        """Evaluate the model on the validation loader.

        Args:
            num_samples: maximum number of validation *batches* to use (None or 0 = all).

        Returns:
            Metrics dict from ``MetricsCalculator.calculate_metrics``. The model is
            switched back to train mode afterwards.
        """
        self.model.eval()
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for i, (images, targets) in enumerate(tqdm(self.val_loader, desc="Валидация")):
                if num_samples and i >= num_samples:
                    break

                if len(images) == 0:
                    continue

                images = [img.to(self.device) for img in images]
                predictions = self.model(images)

                all_predictions.extend(predictions)
                all_targets.extend(targets)

        metrics = MetricsCalculator.calculate_metrics(all_predictions, all_targets)
        self.model.train()

        return metrics

    def train(self, num_epochs: int) -> Dict[str, Any]:
        """Full training loop with best-model checkpointing and early stopping.

        Early stopping triggers after ``config.PATIENCE`` consecutive epochs without
        a new best validation mAP. ``final_model.pth`` records the last epoch that
        actually ran.

        Returns:
            Dict with ``'train_losses'``, ``'val_metrics'`` (per-epoch metric dicts)
            and ``'best_metrics'`` (best validation mAP).
        """
        print(f"\nНачало тренировки на {num_epochs} эпох")
        print(f"Устройство: {self.device}")
        print(f"Количество тренировочных батчей: {len(self.train_loader)}")
        print(f"Количество валидационных батчей: {len(self.val_loader)}")

        patience_counter = 0
        # Overwritten every epoch, so an early stop records the real last epoch.
        last_epoch = num_epochs - 1

        for epoch in range(num_epochs):
            last_epoch = epoch
            print(f"\nЭпоха {epoch + 1}/{num_epochs}")
            print(f"  LR: {self.optimizer.param_groups[0]['lr']:.6f}")

            train_loss = self.train_one_epoch(epoch)
            self.train_losses.append(train_loss)

            val_metrics = self.evaluate(num_samples=self.val_batches_per_epoch)
            self.val_metrics_history.append(val_metrics)

            self.scheduler.step()

            # .get(): a metrics dict may lack keys (e.g. when no validation batches ran).
            current_map = val_metrics.get('map', 0.0)
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val mAP: {current_map:.4f}")
            print(f"  AP@0.5: {val_metrics.get('ap_0.5', 0.0):.4f}")
            print(f"  AP@0.75: {val_metrics.get('ap_0.75', 0.0):.4f}")

            if current_map > self.best_map:
                self.best_map = current_map
                self.save_checkpoint(epoch, is_best=True)
                patience_counter = 0
                print(f"   Лучшая модель сохранена (mAP: {self.best_map:.4f})")
            else:
                patience_counter += 1
                print(f"  Patience: {patience_counter}/{self.config.PATIENCE}")

            if patience_counter >= self.config.PATIENCE:
                print(f"  Ранняя остановка на эпохе {epoch + 1}")
                break

        self.save_checkpoint(last_epoch, is_best=False, final=True)

        return {
            'train_losses': self.train_losses,
            'val_metrics': self.val_metrics_history,
            'best_metrics': self.best_map
        }

    def save_checkpoint(self, epoch: int, is_best: bool = False, final: bool = False):
        """Save model/optimizer/scheduler state plus training history.

        The file name is ``best_model.pth`` (is_best), ``final_model.pth`` (final) or
        ``checkpoint_epoch_{epoch}.pth``, inside ``config.BASE_WORKING_DIR``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_map': self.best_map,
            'train_losses': self.train_losses,
            'val_metrics': self.val_metrics_history
        }

        Path(self.config.BASE_WORKING_DIR).mkdir(parents=True, exist_ok=True)

        if is_best:
            path = Path(self.config.BASE_WORKING_DIR) / 'best_model.pth'
        elif final:
            path = Path(self.config.BASE_WORKING_DIR) / 'final_model.pth'
        else:
            path = Path(self.config.BASE_WORKING_DIR) / f'checkpoint_epoch_{epoch}.pth'

        torch.save(checkpoint, path)
        print(f"  Чекпоинт сохранен: {path}")

    def load_checkpoint(self, checkpoint_path: Path):
        """Restore model/optimizer/scheduler state and history from a checkpoint.

        Uses ``weights_only=False`` (full unpickling), so only load trusted files.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.train_losses = checkpoint.get('train_losses', [])
        self.val_metrics_history = checkpoint.get('val_metrics', [])
        self.best_map = checkpoint.get('best_map', 0.0)

        print(f"Загружен чекпоинт эпохи {checkpoint['epoch']}")

# ============================================================================
# ВИЗУАЛИЗАЦИЯ
# ============================================================================

class Visualizer:
    """Matplotlib helpers for training curves and prediction overlays."""

    @staticmethod
    def plot_training_curves(train_losses: List[float], val_maps: List[float],
                             model_name: str = "Model"):
        """Plot per-epoch training loss (left) and validation mAP (right).

        The right panel is left empty when ``val_maps`` is empty.
        """
        fig, (ax_loss, ax_map) = plt.subplots(1, 2, figsize=(12, 4))

        ax_loss.plot(range(1, len(train_losses) + 1), train_losses, 'b-', linewidth=2)
        ax_loss.set_xlabel('Эпоха')
        ax_loss.set_ylabel('Loss')
        ax_loss.set_title(f'{model_name} - Кривая обучения')
        ax_loss.grid(True, alpha=0.3)

        if val_maps:
            ax_map.plot(range(1, len(val_maps) + 1), val_maps, 'g-', linewidth=2)
            ax_map.set_xlabel('Эпоха')
            ax_map.set_ylabel('mAP')
            ax_map.set_title(f'{model_name} - mAP по эпохам')
            ax_map.grid(True, alpha=0.3)

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # free the figure so repeated calls do not accumulate open figures

    @staticmethod
    def visualize_predictions(model: nn.Module, dataset: ICDAR2015Dataset,
                              config: Config, num_samples: int = 3):
        """Show ground truth next to predictions for the first ``num_samples`` items.

        Predictions with score <= ``config.CONFIDENCE_THRESHOLD`` are hidden, and
        scores above 0.7 are written next to their box. The model is moved to the
        available device and left in eval mode.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        model.eval()

        # Undo input normalisation for display. Assumes the dataset transform
        # normalises with ImageNet mean/std -- TODO confirm against get_val_transforms().
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        for i in range(min(num_samples, len(dataset))):
            image, target = dataset[i]

            # A list of 3-D image tensors: the input form Trainer.evaluate / test_model
            # use for every architecture.
            with torch.no_grad():
                prediction = model([image.to(device)])[0]

            image_np = (image.cpu() * std + mean).permute(1, 2, 0).numpy()
            image_np = np.clip(image_np, 0, 1)

            fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(15, 6))

            # Ground truth
            ax_gt.imshow(image_np)
            if len(target['boxes']) > 0:
                for x1, y1, x2, y2 in target['boxes'].cpu().numpy():
                    ax_gt.add_patch(patches.Rectangle(
                        (x1, y1), x2 - x1, y2 - y1,
                        linewidth=2, edgecolor='green', facecolor='none', alpha=0.7
                    ))
            ax_gt.set_title(f"Ground Truth ({len(target['boxes'])} объектов)")
            ax_gt.axis('off')

            # Predictions above the confidence threshold
            ax_pred.imshow(image_np)
            boxes = prediction['boxes'].cpu().numpy()
            scores = prediction['scores'].cpu().numpy()

            keep = scores > config.CONFIDENCE_THRESHOLD
            boxes = boxes[keep]
            scores = scores[keep]

            for (x1, y1, x2, y2), score in zip(boxes, scores):
                color = (score, 1 - score, 0) if score > 0.5 else (1, score, 0)
                ax_pred.add_patch(patches.Rectangle(
                    (x1, y1), x2 - x1, y2 - y1,
                    linewidth=2, edgecolor=color, facecolor='none', alpha=0.7
                ))
                if score > 0.7:
                    ax_pred.text(x1, y1 - 5, f"{score:.2f}",
                                 color=color, fontsize=8, fontweight='bold')

            ax_pred.set_title(f"Предсказания ({len(boxes)} объектов)")
            ax_pred.axis('off')

            plt.suptitle(f"Пример {i + 1}", fontsize=14)
            plt.tight_layout()
            plt.show()
            plt.close(fig)  # avoid accumulating one open figure per sample

            print(f"Пример {i + 1}: {len(boxes)} обнаружено, {len(target['boxes'])} GT")
            if len(scores) > 0:
                print(f"  Max confidence: {scores.max():.3f}, Mean: {scores.mean():.3f}")

# ============================================================================
# ВСПОМОГАТЕЛЬНЫЕ ФУНКЦИИ
# ============================================================================

def _detection_collate(batch):
    """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Detection images can differ in size, so they are kept as tuples instead of
    being stacked. Unlike a lambda, a named function can be pickled for DataLoader
    workers started with ``spawn``, as long as it is importable from a module
    (functions defined interactively in a notebook may still not be).
    """
    return tuple(zip(*batch))


def create_dataloaders(config: Config, batch_size: int = None):
    """Build the training and validation DataLoaders.

    Args:
        config: project configuration. ``config.BATCH_SIZE`` is used when
            ``batch_size`` is None; ``config.NUM_WORKERS`` (default 2) sets the
            worker count for both loaders.
        batch_size: optional batch-size override.

    Returns:
        ``(train_loader, val_loader)``, or ``(None, None)`` when either processed
        annotation file is missing.
    """
    if batch_size is None:
        batch_size = config.BATCH_SIZE

    processed_paths = config.get_processed_paths()
    train_paths = config.get_train_paths()
    test_paths = config.get_test_paths()

    # Check both annotation files up front, not just the training one.
    for key in ('train_annotations', 'val_annotations'):
        if not processed_paths[key].exists():
            print(f"Не найден файл аннотаций: {processed_paths[key]}")
            return None, None

    train_dataset = ICDAR2015Dataset(
        train_paths['images'],
        processed_paths['train_annotations'],
        transform=get_train_transforms()
    )

    val_dataset = ICDAR2015Dataset(
        test_paths['images'],
        processed_paths['val_annotations'],
        transform=get_val_transforms()
    )

    print(f"Тренировочный датасет: {len(train_dataset)} изображений")
    print(f"Валидационный датасет: {len(val_dataset)} изображений")

    # Settings shared by both loaders; only shuffling differs.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=getattr(config, 'NUM_WORKERS', 2),
        collate_fn=_detection_collate,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

def prepare_data(config: Config) -> bool:
    """Convert raw ICDAR2015 annotations to COCO JSON and write a dataset config file.

    Creates ``config.ICDAR2015_PROCESSED`` if needed. When
    ``config.check_data_structure()`` fails, prints the expected directory layout
    and returns False. Otherwise converts the training and test splits, writes
    ``processed_paths['config']``, prints a summary and returns True.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("ПОДГОТОВКА ДАННЫХ ICDAR2015")
    print(banner)

    config.ICDAR2015_PROCESSED.mkdir(parents=True, exist_ok=True)

    if not config.check_data_structure():
        expected_layout = (
            "\n Не все необходимые данные найдены!",
            "\nОжидаемая структура:",
            f"{config.ICDAR2015_ROOT}/",
            "  ├── ch4_training_images/",
            "  ├── ch4_training_localization_transcription_gt/",
            "  ├── ch4_test_images/",
            "  └── ch4_test_localization_transcription_gt/",
        )
        for line in expected_layout:
            print(line)
        return False

    train_paths = config.get_train_paths()
    test_paths = config.get_test_paths()
    processed_paths = config.get_processed_paths()

    # Convert each split in turn (train first, then test/val), keeping the COCO dicts.
    split_jobs = (
        ('train', "\nКонвертация тренировочных данных...", train_paths, 'train_annotations'),
        ('val', "\nКонвертация тестовых данных...", test_paths, 'val_annotations'),
    )
    coco_by_split = {}
    for split_name, banner_line, source_paths, output_key in split_jobs:
        print(banner_line)
        converter = ICDAR2015ToCOCOConverter(
            source_paths['images'],
            source_paths['labels'],
            ["text"]
        )
        coco_by_split[split_name] = converter.convert(processed_paths[output_key])

    dataset_config = {
        "platform": config.platform,
        "train": {
            "images": str(train_paths['images']),
            "annotations": str(processed_paths['train_annotations'])
        },
        "val": {
            "images": str(test_paths['images']),
            "annotations": str(processed_paths['val_annotations'])
        },
        "categories": ["text"],
        "num_classes": 2,
        "processed_dir": str(config.ICDAR2015_PROCESSED)
    }

    with open(processed_paths['config'], 'w') as config_file:
        json.dump(dataset_config, config_file, indent=2)

    print(f"\n Конфигурация сохранена: {processed_paths['config']}")
    print("\nСВОДКА ДАННЫХ:")
    for label, split_name in (("Тренировочные", 'train'), ("Тестовые", 'val')):
        coco = coco_by_split[split_name]
        print(f"  {label}: {len(coco['images'])} изображений, {len(coco['annotations'])} аннотаций")

    return True

def run_training(config: Config, architecture: str, **kwargs):
    """Build a model, train it with ``Trainer`` and plot its learning curves.

    For ``'myArch'`` the ``MYARCH_*`` config values override any same-named kwargs
    (backbone, input_size, neck_channels, anchor_sizes, anchor_ratios).

    Returns:
        The dict returned by ``Trainer.train``, or None if the data loaders could
        not be created.
    """
    print(f"\nОбучение модели {architecture}...")

    model_kwargs = dict(kwargs)
    if architecture == 'myArch':
        model_kwargs['backbone'] = config.MYARCH_BACKBONE
        model_kwargs['input_size'] = config.MYARCH_INPUT_SIZE
        model_kwargs['neck_channels'] = config.MYARCH_NECK_CHANNELS
        model_kwargs['anchor_sizes'] = config.MYARCH_ANCHOR_SIZES
        model_kwargs['anchor_ratios'] = config.MYARCH_ANCHOR_RATIOS

    model = ModelFactory.create_model(
        architecture=architecture,
        num_classes=config.NUM_CLASSES,
        **model_kwargs
    )

    train_loader, val_loader = create_dataloaders(config)
    if any(loader is None for loader in (train_loader, val_loader)):
        print("Не удалось создать DataLoader'ы")
        return None

    trainer = Trainer(config, model, train_loader, val_loader)
    results = trainer.train(config.NUM_EPOCHS)

    history = results['val_metrics']
    if history:
        epoch_maps = [epoch_metrics['map'] for epoch_metrics in history]
        Visualizer.plot_training_curves(
            results['train_losses'],
            epoch_maps,
            model_name=architecture
        )

    return results

def test_model(config: Config, model_path: str, architecture: str = None):
    """Load a trained checkpoint, visualise predictions and report validation metrics.

    Args:
        config: project configuration (data paths, MIN_SIZE / MAX_SIZE / BACKBONE, ...).
        model_path: checkpoint path (``str`` or ``Path``). It holds either a dict with
            ``'model_state_dict'`` (as written by ``Trainer.save_checkpoint``) or a
            bare state dict.
        architecture: model architecture. When None it is guessed from the file name
            (case-insensitive), falling back to ``'faster_rcnn'``.

    Returns:
        Metrics dict from ``MetricsCalculator.calculate_metrics``, or None when the
        validation loader cannot be built.
    """
    print(f"\nТестирование модели из {model_path}")

    # Guess the architecture from the file name. The path is lower-cased, so every
    # needle must be lower-case as well.
    if architecture is None:
        lowered_path = str(model_path).lower()
        if 'faster_rcnn' in lowered_path:
            architecture = 'faster_rcnn'
        elif 'retinanet' in lowered_path:
            architecture = 'retinanet'
        elif 'myarch' in lowered_path:
            architecture = 'myArch'
        else:
            architecture = 'faster_rcnn'

    # Rebuild with the same constructor kwargs that main()/run_training() use for
    # training; otherwise a non-default config.BACKBONE gives a model whose
    # state-dict keys/shapes do not match the checkpoint.
    model_kwargs = {
        'min_size': config.MIN_SIZE,
        'max_size': config.MAX_SIZE,
        'backbone': config.BACKBONE,
    }
    if architecture == 'myArch':
        model_kwargs.update({
            'backbone': config.MYARCH_BACKBONE,
            'input_size': config.MYARCH_INPUT_SIZE,
            'neck_channels': config.MYARCH_NECK_CHANNELS,
            'anchor_sizes': config.MYARCH_ANCHOR_SIZES,
            'anchor_ratios': config.MYARCH_ANCHOR_RATIOS
        })

    model = ModelFactory.create_model(
        architecture=architecture,
        num_classes=config.NUM_CLASSES,
        **model_kwargs
    )

    # Trainer checkpoints also pickle the metric history (which may contain NumPy
    # scalars), which weights_only=True loading rejects. This unpickles arbitrary
    # objects: only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    _, val_loader = create_dataloaders(config, batch_size=1)
    if val_loader is None:
        print("Не удалось создать DataLoader для тестирования")
        return None

    val_dataset = val_loader.dataset

    Visualizer.visualize_predictions(model, val_dataset, config, num_samples=config.VISUALIZE_SAMPLES)

    print("\nОЦЕНКА МОДЕЛИ НА ВАЛИДАЦИОННОМ НАБОРЕ...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc="Тестирование"):
            if len(images) == 0:
                continue

            images = [img.to(device) for img in images]
            predictions = model(images)

            all_predictions.extend(predictions)
            all_targets.extend(targets)

    metrics = MetricsCalculator.calculate_metrics(all_predictions, all_targets)

    print("\n" + "="*60)
    print("РЕЗУЛЬТАТЫ ТЕСТИРОВАНИЯ")
    print("="*60)
    print(f"mAP: {metrics['map']:.4f}")
    print(f"AP@0.5: {metrics['ap_0.5']:.4f}")
    print(f"AP@0.75: {metrics['ap_0.75']:.4f}")
    print(f"F1-Score: {metrics['f1']:.4f}")

    return metrics

# ============================================================================
# MAIN FUNCTION
# ============================================================================

def _find_latest_model(search_dir) -> Optional[str]:
    """Return the most recently modified ``*.pth`` checkpoint in ``search_dir``.

    ``Path.glob`` yields entries in directory order, which is not guaranteed to
    be chronological, so the newest checkpoint is chosen by modification time
    rather than by position in the glob result.

    Args:
        search_dir: Directory (str or Path) searched non-recursively.

    Returns:
        The checkpoint path as a string, or None if no ``*.pth`` file exists.
    """
    model_files = list(Path(search_dir).glob("*.pth"))
    if not model_files:
        return None
    return str(max(model_files, key=lambda p: p.stat().st_mtime))


def _infer_architecture_from_path(model_path: str) -> str:
    """Guess the detector architecture from a checkpoint file name.

    Returns ``'retinanet'`` when the (case-insensitive) path mentions it, and
    ``'faster_rcnn'`` otherwise (including when the path mentions faster_rcnn).
    """
    lowered = model_path.lower()
    if 'faster_rcnn' in lowered:
        return 'faster_rcnn'
    if 'retinanet' in lowered:
        return 'retinanet'
    return 'faster_rcnn'


def main():
    """Config-driven entry point: check the data, prepare it if needed, then run ``config.MODE``.

    Modes (selected by ``config.MODE``):
        * ``'train'``     - train ``config.ARCHITECTURE`` and write a JSON summary
          (best/final metrics) into ``config.BASE_WORKING_DIR``.
        * ``'test'``      - evaluate a checkpoint with ``test_model``.
        * ``'compare'``   - train each architecture in ``config.COMPARE_ARCHITECTURES``
          for at most 10 epochs and print them ranked by best metric.
        * ``'visualize'`` - draw predictions of a checkpoint on up to
          ``config.VISUALIZE_SAMPLES`` images.

    In ``'test'`` and ``'visualize'`` mode, if ``config.MODEL_PATH`` is None the most
    recently modified ``*.pth`` file in ``config.BASE_WORKING_DIR`` is used.

    Returns:
        None. Results are printed and/or written to disk.
    """
    # Configuration initialisation
    config = Config()
    config.print_info()

    # The raw dataset must be present; otherwise print setup instructions and stop.
    if not config.ICDAR2015_ROOT.exists():
        print(f"\nДанные не найдены по пути: {config.ICDAR2015_ROOT}")
        print("\nИНСТРУКЦИЯ ПО ПОДГОТОВКЕ ДАННЫХ:")
        print("1. Скачайте ICDAR2015 с https://rrc.cvc.uab.es/?ch=4")
        print("2. Распакуйте в следующую структуру:")
        print(f"   {config.ICDAR2015_ROOT}/")
        print("   ├── ch4_training_images/")
        print("   ├── ch4_training_localization_transcription_gt/")
        print("   ├── ch4_test_images/")
        print("   └── ch4_test_localization_transcription_gt/")
        print("\n3. Перезапустите скрипт")
        return

    # Preprocessing is skipped when the processed config file already exists.
    processed_paths = config.get_processed_paths()
    if not processed_paths['config'].exists():
        print("\n Обработанные данные не найдены")
        if not prepare_data(config):
            return

    mode = config.MODE

    if mode == 'train':
        print(f"\nРежим: обучение модели {config.ARCHITECTURE}")
        params = {
            'min_size': config.MIN_SIZE,
            'max_size': config.MAX_SIZE,
            'backbone': config.BACKBONE
        }
        results = run_training(config, config.ARCHITECTURE, **params)

        if results:
            # Persist a small, timestamped JSON summary of the run.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            results_file = Path(config.BASE_WORKING_DIR) / f"results_{config.ARCHITECTURE}_{timestamp}.json"
            val_metrics = results['val_metrics']
            summary = {
                'architecture': config.ARCHITECTURE,
                'config': {
                    'num_epochs': config.NUM_EPOCHS,
                    'batch_size': config.BATCH_SIZE,
                    'learning_rate': config.LEARNING_RATE
                },
                'results': {
                    'best_map': results['best_metrics'],
                    'final_map': val_metrics[-1]['map'] if val_metrics else 0
                }
            }

            with open(results_file, 'w') as f:
                json.dump(summary, f, indent=2)

            print(f"\n Результаты сохранены: {results_file}")

    elif mode == 'test':
        print("\nРежим: тестирование модели")
        # Fall back to the newest saved checkpoint when no path is configured.
        if config.MODEL_PATH is None:
            latest = _find_latest_model(config.BASE_WORKING_DIR)
            if latest is None:
                print("Сохраненных моделей не найдено!")
                return
            config.MODEL_PATH = latest
            print(f"Используется модель: {config.MODEL_PATH}")

        test_model(config, config.MODEL_PATH, config.ARCHITECTURE)

    elif mode == 'compare':
        print("\nРежим: сравнение архитектур")
        results = {}
        for arch in config.COMPARE_ARCHITECTURES:
            print(f"\nОбучение {arch}...")
            try:
                # Comparison runs are capped at 10 epochs to keep them affordable.
                result = run_training(
                    config,
                    arch,
                    num_epochs=min(10, config.NUM_EPOCHS)
                )
                if result:
                    results[arch] = result['best_metrics']
                    print(f" {arch}: mAP = {result['best_metrics']:.4f}")
            except Exception as e:
                # Best effort: one failing architecture should not abort the comparison.
                print(f"Ошибка при обучении {arch}: {e}")

        if results:
            print("\nРЕЗУЛЬТАТЫ СРАВНЕНИЯ:")
            for arch, score in sorted(results.items(), key=lambda x: x[1], reverse=True):
                print(f"  {arch}: {score:.4f}")

    elif mode == 'visualize':
        print("\nРежим: визуализация предсказаний")
        # Re-read paths: prepare_data() may have just produced them.
        processed_paths = config.get_processed_paths()
        test_paths = config.get_test_paths()

        dataset = ICDAR2015Dataset(
            test_paths['images'],
            processed_paths['val_annotations'],
            transform=get_val_transforms(),
            max_size=config.VISUALIZE_SAMPLES
        )

        if config.MODEL_PATH is None:
            latest = _find_latest_model(config.BASE_WORKING_DIR)
            if latest is None:
                print("Сохраненных моделей не найдено!")
                return
            config.MODEL_PATH = latest

        print(f"Используется модель: {config.MODEL_PATH}")

        # Infer the architecture from the file name when it is not configured.
        if config.ARCHITECTURE is None:
            config.ARCHITECTURE = _infer_architecture_from_path(config.MODEL_PATH)

        model = ModelFactory.create_model(
            architecture=config.ARCHITECTURE,
            num_classes=config.NUM_CLASSES
        )

        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(config.MODEL_PATH, map_location='cpu', weights_only=False)
        # Accept both full training checkpoints and bare state dicts.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        model.eval()

        Visualizer.visualize_predictions(model, dataset, config, num_samples=config.VISUALIZE_SAMPLES)

    else:
        print(f"\nНеизвестный режим: {mode!r}. Допустимые: train, test, compare, visualize")

In [None]:
# ============================================================================
# ENTRY POINT
# ============================================================================
# Runs the config-driven pipeline (see main) when this code is executed
# directly, e.g. as a script or as a notebook cell. It does not run when the
# exported file is imported as a module.

if __name__ == "__main__":
    main()