In [1]:
import sys
import os
from pathlib import Path
sys.path.append(str(Path().resolve().parent))
sys.path.insert(0,  os.path.dirname(os.path.dirname(os.getcwd()) ))
from Dataset import SROIEDataset2
from torch.utils.data import DataLoader
from tqdm import tqdm
from ctpn_model import *

In [2]:
train_dataset = SROIEDataset2('train','../../data/train/img','../../data/train/box',  )
val_dataset = SROIEDataset2('val','../../data/train/img','../../data/train/box',  )
test_dataset = SROIEDataset2(None, '../../data/test/img','../../data/test/box', )

sub_train_dataset = [train_dataset[i] for i in range(5)]
sub_val_dataset = [train_dataset[i] for i in range(3)]

train_loader = DataLoader(sub_train_dataset, batch_size=1,)
val_loader = DataLoader(sub_val_dataset, batch_size=1,)


In [3]:
len(train_dataset), len(val_dataset), len(test_dataset)

(563, 63, 347)

In [None]:
def train_epoch(model, dataloader, critetion_cls, critetion_regr,
                optimizer, device, epoch):
    """Одна эпоха обучения для полной CTPN модели"""
    model.train()

    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch}')

    epoch_size = len(dataloader) // 1
    epoch_loss_cls = 0
    epoch_loss_regr = 0
    epoch_loss = 0

    for batch_idx, (imgs, clss, regrs) in enumerate(progress_bar):
            # Перенос изображений на устройство
            imgs = imgs.to(device)
            clss = clss.to(device)
            regrs = regrs.to(device)

            out_cls, out_regr = model(imgs)
            loss_cls = critetion_cls(out_cls, clss)
            loss_regr = critetion_regr(out_regr, regrs)

            loss = loss_cls + loss_regr  # total loss
            loss.backward()
            optimizer.step()
    
            epoch_loss_cls += loss_cls.item()
            epoch_loss_regr += loss_regr.item()
            epoch_loss += loss.item()
            mmp = batch_idx + 1
    epoch_loss_cls /= epoch_size
    epoch_loss_regr /= epoch_size
    epoch_loss /= epoch_size

    
            
    # Обновление progress bar
    progress_bar.set_postfix({
                'loss': epoch_loss,
                'cls': epoch_loss_cls,
                'reg': epoch_loss_regr
    })
            

    return epoch_loss_cls, epoch_loss_regr, epoch_loss


def validate_epoch(model, val_loader, critetion_cls, critetion_regr, device):
    """Валидационная эпоха по аналогии с train_epoch"""
    model.eval()
    
    progress_bar = tqdm(val_loader, desc='Validation')
    
    epoch_size = len(val_loader) // 1
    epoch_loss_cls = 0
    epoch_loss_regr = 0
    epoch_loss = 0
    
    with torch.no_grad():
        for batch_idx, (imgs, clss, regrs) in enumerate(progress_bar):
            # Перенос данных на устройство
            imgs = imgs.to(device)
            clss = clss.to(device)
            regrs = regrs.to(device)
            
            # Forward pass
            out_cls, out_regr = model(imgs)
            
            # Вычисление лоссов
            loss_cls = critetion_cls(out_cls, clss)
            loss_regr = critetion_regr(out_regr, regrs)
            loss = loss_cls + loss_regr
            
            # Накопление статистики
            epoch_loss_cls += loss_cls.item()
            epoch_loss_regr += loss_regr.item()
            epoch_loss += loss.item()
            
            # Обновление progress bar
            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'cls': f"{loss_cls.item():.4f}",
                'reg': f"{loss_regr.item():.4f}"
            })
    
    # Средние значения за эпоху
    mmp = batch_idx + 1 if batch_idx >= 0 else 1
    epoch_loss_cls /= mmp
    epoch_loss_regr /= mmp
    epoch_loss /= mmp
    

    return epoch_loss_cls, epoch_loss_regr, epoch_loss


def train_model(model, train_loader, val_loader, critetion_cls, critetion_regr, optimizer, 
                device, num_epochs=10):
    """
    Обучение модели CTPN с сохранением лучшей модели по валидации
    """
    # Перемещаем модель на устройство
    model = model.to(device)

    # История обучения
    history = {
        'train_total': [], 
        'train_cls': [], 
        'train_reg': [], 
        'val_total': [], 'val_cls': [], 'val_reg': [], 
        'best_val_loss': float('inf'),

    }
    
    # Общий прогресс-бар для эпох
    epoch_pbar = tqdm(range(num_epochs), desc='Training', unit='epoch')
    
    for epoch in epoch_pbar:
        epoch_num = epoch + 1
        
        # Обучение
        train_losses = train_epoch(model, train_loader, critetion_cls, critetion_regr,
                optimizer, device, epoch)
        
        epoch_loss_cls, epoch_loss_regr, epoch_loss = train_losses
        
        # Валидация
        val_losses = validate_epoch(model, val_loader, critetion_cls, critetion_regr, device)

        v_epoch_loss_cls, v_epoch_loss_regr, v_epoch_loss  = val_losses
        
        # Сохраняем историю
        history['train_total'].append(epoch_loss)
        history['train_cls'].append(epoch_loss_cls)
        history['train_reg'].append(epoch_loss_regr)

        
        history['val_total'].append(v_epoch_loss)
        history['val_cls'].append(v_epoch_loss_cls)
        history['val_reg'].append(v_epoch_loss_regr)

        
        # Сохраняем лучшую модель
        if v_epoch_loss < history['best_val_loss']:
            history['best_val_loss'] = v_epoch_loss
            history['best_epoch'] = epoch_num
            
            # Сохраняем модель
            torch.save({
                'epoch': epoch_num,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': epoch_loss,
                'val_loss': v_epoch_loss,
                'history': history
            }, 'best_ctpn_model.pth')
        
        # Обновляем описание progress bar
        epoch_pbar.set_postfix({
            'train_loss': f'{v_epoch_loss:.4f}',
            'val_loss': f'{v_epoch_loss:.4f}',
            'best_val': f'{history["best_val_loss"]:.4f}',
            'best_ep': history['best_epoch']
        })



In [None]:
model = CTPN_Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
critetion_cls = RPN_CLS_Loss(device)
model.to(device)
critetion_regr = RPN_REGR_Loss(device)

In [15]:
for i in range(1, 3):
    train_epoch(model, train_loader, critetion_cls, critetion_regr,
                optimizer, device, i)

Epoch 1: 100%|██████████| 5/5 [00:12<00:00,  2.57s/it]


Ep:1/2--Batch:4/5
batch: loss_cls:0.5123--loss_regr:0.2164--loss:0.7287
Epoch: loss_cls:0.1156--loss_regr:0.0552--loss:0.1708



Epoch 2: 100%|██████████| 5/5 [00:12<00:00,  2.42s/it]

Ep:2/2--Batch:4/5
batch: loss_cls:0.4089--loss_regr:0.1758--loss:0.5847
Epoch: loss_cls:0.0817--loss_regr:0.0404--loss:0.1221




