# Treinamento do RefineDetLite para Detecção de Semáforos

Este notebook implementa o treinamento do modelo RefineDetLite para detecção binária de semáforos.

In [1]:
import torch
import sys
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as T
from pathlib import Path
sys.path.append('..')

from models.refinedlite.model import RefineDetLite
from models.refinedlite.loss import RefineDetLiteLoss
from utils.dataset import BrazilianTrafficLightDataset

# Training configuration
CONFIG = {
    'num_classes': 2,  # traffic light (1) and background (0) only
    'input_size': 320,
    'batch_size': 8,
    'num_epochs': 50,
    'learning_rate': 1e-3,
    'train_split': 0.8,  # 80% for training, 20% for validation
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# Training transforms (with data augmentation).
# NOTE(review): how BrazilianTrafficLightDataset applies `transform` is not
# visible here. If it is applied to the image only, RandomHorizontalFlip would
# desynchronize the bounding boxes from the pixels - TODO confirm.
transform_train = T.Compose([
    T.Resize((CONFIG['input_size'], CONFIG['input_size'])),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation transforms: deterministic, no augmentation.
transform_val = T.Compose([
    T.Resize((CONFIG['input_size'], CONFIG['input_size'])),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Build one dataset object per transform pipeline.
# Setting `subset.dataset.transform` after random_split does not work: both
# subsets wrap the SAME dataset object, so the last assignment wins and the
# training split would silently run with the validation transforms.
dataset = BrazilianTrafficLightDataset(
    images_dir='../data/openimages/traffic-light/images',
    pascal_dir='../data/openimages/traffic-light/pascal',
    transform=transform_val
)
train_source = BrazilianTrafficLightDataset(
    images_dir='../data/openimages/traffic-light/images',
    pascal_dir='../data/openimages/traffic-light/pascal',
    transform=transform_train
)
# NOTE(review): this assumes both instances enumerate the samples in the same
# order; otherwise train/val indices could overlap - confirm in the dataset.
assert len(train_source) == len(dataset), 'train/val dataset views differ in size'

# Split into train and validation index sets (seeded for reproducibility)
train_size = int(CONFIG['train_split'] * len(dataset))
val_size = len(dataset) - train_size

train_split, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Re-point the training indices at the augmented view of the data;
# val_dataset already wraps `dataset`, which carries transform_val.
train_dataset = torch.utils.data.Subset(train_source, train_split.indices)

# Create the dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=4,
    collate_fn=dataset.collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=4,
    collate_fn=dataset.collate_fn
)

print(f'Dataset total: {len(dataset)} imagens')
print(f'Dataset treino: {len(train_dataset)} imagens')
print(f'Dataset validação: {len(val_dataset)} imagens')

Dataset total: 200 imagens
Dataset treino: 160 imagens
Dataset validação: 40 imagens


In [2]:
# Build the detector on the configured device, then its loss and optimizer
model = RefineDetLite(
    num_classes=CONFIG['num_classes'],
    input_size=CONFIG['input_size'],
).to(CONFIG['device'])
criterion = RefineDetLiteLoss(num_classes=CONFIG['num_classes'])
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Batches whose loss computation, backward pass or optimizer step raises
    are reported on stdout and skipped; they do not count toward the mean.

    Args:
        model: module called as ``model(images)``; must return a
            ``(cls_preds, reg_preds)`` pair.
        loader: iterable yielding ``(images, targets)``, where ``targets`` is
            a sequence of dicts whose values support ``.to(device)``.
        criterion: called as ``criterion((cls_preds, reg_preds), targets)``;
            must return a loss supporting ``backward()`` and ``item()``.
        optimizer: optimizer stepping the parameters of ``model``.
        device: device the images and target values are moved to.

    Returns:
        float: mean loss over the batches that completed an optimizer step.

    Raises:
        RuntimeError: if no batch completed (empty loader or every batch
            failed), instead of silently returning NaN.
    """
    model.train()
    # Accumulate on-device as a tensor so no host sync is needed per batch
    total_loss = torch.tensor(0., device=device)
    num_batches = 0
    last_error = None

    for batch_idx, (images, targets) in enumerate(loader):
        images = images.to(device)
        batch_targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        cls_preds, reg_preds = model(images)

        try:
            loss = criterion((cls_preds, reg_preds), batch_targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.detach()
            num_batches += 1

            if (batch_idx + 1) % 5 == 0:
                print(f'Batch [{batch_idx + 1}/{len(loader)}], Loss: {loss.item():.4f}')
        except Exception as e:
            # Best effort: report the failing batch and keep training
            print(f"Erro no batch {batch_idx}:")
            print(f"Exception: {str(e)}")
            last_error = e
            continue

    if num_batches == 0:
        # 0 / 0 on tensors yields NaN without raising; fail loudly instead
        raise RuntimeError(
            'train_epoch: no batch was processed successfully'
        ) from last_error

    # Convert to a Python float only once, after averaging
    return (total_loss / num_batches).item()

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return the mean batch loss.

    This is the single definition of ``validate``; the notebook previously
    defined it twice, with the later definition shadowing the earlier one.

    Args:
        model: module called as ``model(images)``; must return a
            ``(cls_preds, reg_preds)`` pair. It is switched to eval mode.
        loader: iterable yielding ``(images, targets)``, where ``targets`` is
            a sequence of dicts whose values support ``.to(device)``.
        criterion: called as ``criterion((cls_preds, reg_preds), targets)``;
            must return a scalar loss tensor.
        device: device the images and target values are moved to.

    Returns:
        float: mean loss over all batches of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    # Accumulate on-device as a tensor so no host sync is needed per batch
    total_loss = torch.tensor(0., device=device)
    num_batches = 0

    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device)
            batch_targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            cls_preds, reg_preds = model(images)
            loss = criterion((cls_preds, reg_preds), batch_targets)

            total_loss += loss.detach()
            num_batches += 1

    if num_batches == 0:
        # Avoid a bare ZeroDivisionError / NaN on an empty validation set
        raise ValueError('validate: loader produced no batches')

    # Convert to a Python float only once, after averaging
    return (total_loss / num_batches).item()



In [3]:
# Main training loop
best_val_loss = float('inf')
train_losses = []
val_losses = []

# torch.save does not create missing directories; make sure the target
# directory exists up front so the first checkpoint cannot fail after a
# full epoch of training.
# NOTE(review): the data paths in this notebook are relative to '..', while
# this path goes up two levels ('../../') - confirm this is intended.
CHECKPOINT_PATH = '../../checkpoints/refinedlite_best.pth'
Path(CHECKPOINT_PATH).parent.mkdir(parents=True, exist_ok=True)

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch [{epoch+1}/{CONFIG['num_epochs']}]")

    # Train for one epoch
    train_loss = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    train_losses.append(train_loss)

    # Evaluate on the validation split
    val_loss = validate(model, val_loader, criterion, CONFIG['device'])
    val_losses.append(val_loss)

    print(f'Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}')

    # Keep only the checkpoint with the lowest validation loss so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, CHECKPOINT_PATH)

print("\nTreinamento concluído!")


Epoch [1/50]
Feature shape 6: torch.Size([8, 32, 40, 40])
Feature shape 13: torch.Size([8, 96, 20, 20])
Feature shape 17: torch.Size([8, 320, 10, 10])
After TCB 0: torch.Size([8, 256, 10, 10])
After TCB 1: torch.Size([8, 256, 20, 20])
After TCB 2: torch.Size([8, 256, 40, 40])
Feature shape 6: torch.Size([8, 32, 40, 40])
Feature shape 13: torch.Size([8, 96, 20, 20])
Feature shape 17: torch.Size([8, 320, 10, 10])
After TCB 0: torch.Size([8, 256, 10, 10])
After TCB 1: torch.Size([8, 256, 20, 20])
After TCB 2: torch.Size([8, 256, 40, 40])
Feature shape 6: torch.Size([8, 32, 40, 40])
Feature shape 13: torch.Size([8, 96, 20, 20])
Feature shape 17: torch.Size([8, 320, 10, 10])
After TCB 0: torch.Size([8, 256, 10, 10])
After TCB 1: torch.Size([8, 256, 20, 20])
After TCB 2: torch.Size([8, 256, 40, 40])


KeyboardInterrupt: 