In [1]:
from model import *
from loss import *


import sys
import os

sys.path.insert(0,  os.path.dirname(os.path.dirname(os.getcwd()) ))

from DataLoader import SROIEDataset
from torch.utils.data import DataLoader


import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
from tqdm import tqdm


In [2]:
def ctpn_collate_fn(batch):
    """Collate функция для CTPN"""
    images = []
    boxes_list = []
    original_quads_list = []
    texts_list = []
    image_ids = []
    original_sizes = []
    scale_factors = []
    
    for item in batch:
        images.append(item['image'])
        boxes_list.append(torch.tensor(item['boxes'], dtype=torch.float32))
        original_quads_list.append(item['original_quads'])
        texts_list.append(item['texts'])
        image_ids.append(item['image_id'])
        original_sizes.append(item['original_size'])
        scale_factors.append(item['scale_factors'])
    
    # Стекируем изображения одинакового размера
    images = torch.stack(images, dim=0)
    
    return {
        'images': images,  # [B, 3, H, W]
        'boxes': boxes_list,  # список тензоров [N_i, 4]
        'original_quads': original_quads_list,
        'texts': texts_list,
        'image_ids': image_ids,
        'original_sizes': original_sizes,
        'scale_factors': scale_factors
    }


In [3]:
train_dataset = SROIEDataset('../../data/test/img','../../data/test/box', target_size=(640, 640))

sub_train_dataset = [train_dataset[i] for i in range(20)]

train_loader = DataLoader(sub_train_dataset, batch_size=5, collate_fn=ctpn_collate_fn,)


In [4]:
batch = next(iter(train_loader))

In [5]:
batch.keys()

dict_keys(['images', 'boxes', 'original_quads', 'texts', 'image_ids', 'original_sizes', 'scale_factors'])

In [6]:
batch['images'].shape

torch.Size([5, 3, 640, 640])

In [7]:
# Пример использования с полными таргетами
def prepare_full_ctpn_targets(batch, num_anchors=10, anchor_scales=None, feature_stride=16):
    """Исправленная версия с правильными размерами таргетов"""
    if anchor_scales is None:
        anchor_scales = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
    
    images = batch['images']
    boxes_list = batch['boxes']
    
    B, C, H, W = images.shape
    feat_h, feat_w = H // feature_stride, W // feature_stride
    
    # Только 2 канала для регрессии (dy, dh)
    cls_targets = torch.zeros(B, feat_h, feat_w, num_anchors, 2, dtype=torch.float32)
    reg_targets = torch.zeros(B, feat_h, feat_w, num_anchors, 2, dtype=torch.float32)  # dy, dh
    side_targets = torch.zeros(B, feat_h, feat_w, num_anchors, 2, dtype=torch.float32)  # dx_left, dx_right
    
    for b_idx in range(B):
        boxes = boxes_list[b_idx]
        
        for box in boxes:
            if len(box) == 0:
                continue
            
            x_min, y_min, x_max, y_max = box
            gt_height = y_max - y_min
            gt_center_y = (y_min + y_max) / 2
            gt_center_x = (x_min + x_max) / 2
            
            # Координаты на feature map
            feat_x = min(int(gt_center_x // feature_stride), feat_w - 1)
            feat_y = min(int(gt_center_y // feature_stride), feat_h - 1)
            
            # Для каждого anchor
            for a_idx, anchor_h in enumerate(anchor_scales):
                anchor_center_y = (feat_y + 0.5) * feature_stride
                
                # Vertical IoU
                anchor_y_min = anchor_center_y - anchor_h / 2
                anchor_y_max = anchor_center_y + anchor_h / 2
                
                y_min_inter = max(anchor_y_min, y_min)
                y_max_inter = min(anchor_y_max, y_max)
                inter_h = max(0, y_max_inter - y_min_inter)
                
                union_h = (anchor_y_max - anchor_y_min) + (y_max - y_min) - inter_h
                vertical_iou = inter_h / union_h if union_h > 0 else 0
                
                if vertical_iou > 0.7:  # Positive
                    cls_targets[b_idx, feat_y, feat_x, a_idx] = torch.tensor([0., 1.])
                    
                    dy = (gt_center_y - anchor_center_y) / anchor_h
                    dh = torch.log(torch.tensor(gt_height / anchor_h, dtype=torch.float32))
                    
                    reg_targets[b_idx, feat_y, feat_x, a_idx, 0] = dy
                    reg_targets[b_idx, feat_y, feat_x, a_idx, 1] = dh
                    
                    dx_left = (x_min - feat_x * feature_stride) / feature_stride
                    dx_right = (x_max - feat_x * feature_stride) / feature_stride
                    
                    side_targets[b_idx, feat_y, feat_x, a_idx, 0] = dx_left
                    side_targets[b_idx, feat_y, feat_x, a_idx, 1] = dx_right
                    
                elif vertical_iou < 0.3:  # Negative
                    cls_targets[b_idx, feat_y, feat_x, a_idx] = torch.tensor([1., 0.])
    
    # Reshape
    cls_targets = cls_targets.view(B, feat_h, feat_w, -1)
    reg_targets = reg_targets.view(B, feat_h, feat_w, -1)
    side_targets = side_targets.view(B, feat_h, feat_w, -1)
    
    return {
        'cls_targets': cls_targets,
        'reg_targets': reg_targets,
        'side_targets': side_targets
    }


def train_epoch(model, dataloader, criterion, optimizer, device, epoch, 
                anchor_scales=None, num_anchors=10):
    """Одна эпоха обучения для полной CTPN модели"""
    model.train()
    total_loss = 0
    cls_loss_sum = 0
    reg_loss_sum = 0
    side_loss_sum = 0
    
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch}')
    
    for batch_idx, batch in enumerate(progress_bar):
        try:
            # 1. Перенос изображений на устройство
            images = batch['images'].to(device)
            
            # 2. Подготовка полных таргетов для CTPN
            gt_data = prepare_full_ctpn_targets(
                batch, 
                num_anchors=num_anchors,
                anchor_scales=anchor_scales,
                feature_stride=16  # VGG16 stride
            )
            
            # 3. Forward pass - получаем 3 выхода
            optimizer.zero_grad()
            cls_pred, reg_pred, side_pred = model(images)
            
            # 4. Проверка размеров (для отладки)
            B, C_cls, H_cls, W_cls = cls_pred.shape
            B, C_reg, H_reg, W_reg = reg_pred.shape
            B, C_side, H_side, W_side = side_pred.shape
            
            cls_target = gt_data['cls_targets']
            reg_target = gt_data['reg_targets']
            side_target = gt_data['side_targets']
            
            print(f"\nBatch {batch_idx} - Shape Debug:")
            print(f"  Images: {images.shape}")
            print(f"  cls_pred: {cls_pred.shape}, cls_target: {cls_target.shape}")
            print(f"  reg_pred: {reg_pred.shape}, reg_target: {reg_target.shape}")
            print(f"  side_pred: {side_pred.shape}, side_target: {side_target.shape}")
            
            # Проверяем соответствие размеров
            expected_h = images.shape[2] // 16  # feature map height
            expected_w = images.shape[3] // 16  # feature map width
            
            if H_cls != expected_h or W_cls != expected_w:
                print(f"  WARNING: cls_pred имеет размер {H_cls}x{W_cls}, ожидалось {expected_h}x{expected_w}")
                # Делаем ресайз если нужно
                cls_pred = F.interpolate(cls_pred, size=(expected_h, expected_w), mode='bilinear')
                reg_pred = F.interpolate(reg_pred, size=(expected_h, expected_w), mode='bilinear')
                side_pred = F.interpolate(side_pred, size=(expected_h, expected_w), mode='bilinear')
            
            # 5. Вычисление loss
            losses = criterion(cls_pred, reg_pred, side_pred, gt_data)
            
            # 6. Backward pass
            losses['total'].backward()
            
            # 7. Gradient clipping (важно для RNN)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            
            # 8. Optimizer step
            optimizer.step()
            
            # 9. Статистика
            total_loss += losses['total'].item()
            cls_loss_sum += losses['cls'].item()
            reg_loss_sum += losses['reg'].item()
            side_loss_sum += losses['side'].item()
            
            # 10. Обновление progress bar
            progress_bar.set_postfix({
                'loss': losses['total'].item(),
                'cls': losses['cls'].item(),
                'reg': losses['reg'].item(),
                'side': losses['side'].item()
            })
            
            # 11. Логирование каждые 10 батчей
            if batch_idx % 10 == 0:
                # Проверяем количество позитивных anchor'ов
                pos_anchors = (gt_data['cls_targets'][..., 1::2].sum() / B).item()
                total_anchors = cls_target.shape[1] * cls_target.shape[2] * num_anchors
                
                print(f"\n  Batch {batch_idx} Stats:")
                print(f"    Positive anchors: {pos_anchors:.1f} per image")
                print(f"    Anchor ratio: {pos_anchors/total_anchors*100:.2f}%")
                print(f"    Losses - Total: {losses['total'].item():.4f}, "
                      f"Cls: {losses['cls'].item():.4f}, "
                      f"Reg: {losses['reg'].item():.4f}, "
                      f"Side: {losses['side'].item():.4f}")
                
                # Проверяем градиенты
                total_norm = 0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** 0.5
                print(f"    Gradient norm: {total_norm:.4f}")
            
        except Exception as e:
            print(f"\nError in batch {batch_idx}: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    # 12. Вычисляем средние losses за эпоху
    num_batches = max(len(dataloader), 1)
    avg_total_loss = total_loss / num_batches
    avg_cls_loss = cls_loss_sum / num_batches
    avg_reg_loss = reg_loss_sum / num_batches
    avg_side_loss = side_loss_sum / num_batches
    
    return avg_total_loss, avg_cls_loss, avg_reg_loss, avg_side_loss

# Конфигурация обучения
config = {
    'num_epochs': 2,
    'lr': 0.001,
    'weight_decay': 0.0001,
    'batch_size': 16,
    'step_size': 10,
    'gamma': 0.1,

}

# Создание модели
model = CustomCTPN(2)

# Разделение датасета на train/val
from torch.utils.data import random_split




criterion = CTPNLoss()
    
optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('lr', 0.001),
        weight_decay=config.get('weight_decay', 0.0001)
    )
# Запуск обучения
history = []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



for i in range(3):
     avg_total_loss, avg_cls_loss, avg_reg_loss, avg_side_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch=i+1)
     print(avg_total_loss)
     history.append((avg_total_loss, avg_reg_loss))


NameError: name 'torchvision' is not defined

In [31]:
train_dataset = SROIEDataset('../../data/test/img','../../data/test/box', target_size=(640, 640))
data = DataLoader(train_dataset, batch_size=16, collate_fn=debug_collate_fn,)
#batch = next(iter(data))

In [35]:
progress_bar = tqdm(data, desc=f'Epoch {1}')
for batch_idx, batch in enumerate(progress_bar):
    print(batch_idx)

Epoch 1:   9%|▉         | 2/22 [00:00<00:02,  7.47it/s]

0
1


Epoch 1:  18%|█▊        | 4/22 [00:00<00:02,  7.97it/s]

2
3


Epoch 1:  27%|██▋       | 6/22 [00:00<00:02,  7.88it/s]

4
5


Epoch 1:  36%|███▋      | 8/22 [00:01<00:01,  7.55it/s]

6
7


Epoch 1:  45%|████▌     | 10/22 [00:01<00:01,  7.90it/s]

8
9


Epoch 1:  50%|█████     | 11/22 [00:01<00:01,  8.15it/s]

10


Epoch 1:  64%|██████▎   | 14/22 [00:02<00:01,  4.39it/s]

11
12
13


Epoch 1:  64%|██████▎   | 14/22 [00:02<00:01,  5.64it/s]


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa3 in position 407: invalid start byte

In [12]:
batch['images']

tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9922, 0.9922,  ..., 1.0000, 1.0000, 0.9961],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9922, 0.9922,  ..., 1.0000, 1.0000, 0.9961],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9922, 0.9922,  ..., 1

In [None]:
import torch

t = torch.rand(1, 3, 640 , 640)

In [4]:
m(t)

(tensor([[[[-0.1326, -0.0598, -0.0072,  ..., -0.1162, -0.0494, -0.0500],
           [-0.0024, -0.0079, -0.0175,  ..., -0.0939, -0.0204, -0.1438],
           [-0.1270, -0.0676, -0.0760,  ..., -0.1082, -0.1070, -0.1507],
           ...,
           [-0.2015, -0.0295, -0.0656,  ..., -0.1975, -0.0983, -0.0997],
           [-0.0518, -0.0766,  0.0254,  ..., -0.0168,  0.0392, -0.0329],
           [-0.1488, -0.0762, -0.0569,  ..., -0.1711, -0.0426, -0.2019]],
 
          [[ 0.1659,  0.0420,  0.0331,  ..., -0.0267, -0.0200, -0.0032],
           [ 0.1623,  0.0652,  0.0162,  ...,  0.0999,  0.0639,  0.0758],
           [ 0.0593,  0.0622,  0.1048,  ...,  0.0652, -0.0144, -0.0039],
           ...,
           [ 0.1122,  0.1004,  0.1603,  ...,  0.0776,  0.0340,  0.0548],
           [ 0.1279,  0.0685,  0.1564,  ...,  0.0705,  0.0901,  0.0615],
           [ 0.1180,  0.0681,  0.0249,  ...,  0.0204,  0.0867,  0.0599]]]],
        grad_fn=<ConvolutionBackward0>),
 tensor([[[[ 0.2196,  0.1969,  0.0748,  ..., 