In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from models.dataset import EDCopilotDataset
from models.ed_copilot_sft import EDCopilotSFT
from tqdm import tqdm
import wandb
from pathlib import Path

# Training hyperparameters (as listed in Table 8 of the paper).
# Key order is kept stable because the dict is logged to W&B and saved in checkpoints.
CONFIG = dict(
    model_name='microsoft/BioGPT',        # passed to EDCopilotDataset and EDCopilotSFT
    batch_size=32,                        # DataLoader batch size (train and val)
    learning_rate=1e-5,                   # optimizer learning rate
    epochs=15,                            # number of passes over the training set
    warmup_percentage=0.1,                # fraction of total steps used for LR warmup
    weight_decay=0.01,                    # optimizer weight decay
    class_weight=10.0,                    # forwarded to model.compute_loss
    max_length=656,                       # forwarded to EDCopilotDataset (presumably max token length)
    save_dir='models/checkpoints/sft',    # where best_model.pt is written
)

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        # Move to device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v 
                 for k, v in batch.items()}
        
        # Forward
        outputs = model(batch['input_ids'], batch['attention_mask'])
        loss_dict = model.compute_loss(outputs, batch, CONFIG['class_weight'])
        
        # Backward
        optimizer.zero_grad()
        loss_dict['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        # Log
        total_loss += loss_dict['loss'].item()
        progress_bar.set_postfix({
            'loss': loss_dict['loss'].item(),
            'loss_lab': loss_dict['loss_lab'],
            'loss_outcome': loss_dict['loss_outcome']
        })
        
        wandb.log({
            'train/loss': loss_dict['loss'].item(),
            'train/loss_lab': loss_dict['loss_lab'],
            'train/loss_outcome': loss_dict['loss_outcome'],
            'train/lr': scheduler.get_last_lr()[0]
        })
    
    return total_loss / len(dataloader)

def validate(model, dataloader, device, class_weight=None):
    """Compute the mean batch loss over ``dataloader`` without updating ``model``.

    The model is switched to eval mode and gradients are disabled for the pass;
    the model is left in eval mode afterwards.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` whose
            ``compute_loss(outputs, batch, class_weight)`` returns a dict with a
            ``'loss'`` tensor.
        dataloader: iterable of dict batches containing at least
            ``'input_ids'`` and ``'attention_mask'``; tensor values are moved
            to ``device``, other values are passed through.
        device: target ``torch.device`` for the batch tensors.
        class_weight: weight forwarded to ``compute_loss``; when ``None`` the
            notebook-level ``CONFIG['class_weight']`` is used.

    Returns:
        Mean of the per-batch ``'loss'`` values as a float, or ``float('nan')``
        if ``dataloader`` yields no batches.
    """
    if class_weight is None:
        class_weight = CONFIG['class_weight']
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            outputs = model(batch['input_ids'], batch['attention_mask'])
            loss_dict = model.compute_loss(outputs, batch, class_weight)
            total_loss += loss_dict['loss'].item()
            num_batches += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / num_batches if num_batches else float('nan')

In [None]:
# Reproducibility: seed torch's global RNG (all devices) before any stochastic
# step. DataLoader(shuffle=True) below draws its permutation from it; presumably
# so does initialisation of newly created model weights in later cells.
SEED = CONFIG.get('seed', 42)
torch.manual_seed(SEED)

# Record the seed alongside the hyperparameters without mutating CONFIG.
wandb.init(project="ed-copilot-tcc", config={**CONFIG, 'seed': SEED})

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üöÄ Using device: {device}")

# Datasets
print("üìÇ Loading datasets...")
train_dataset = EDCopilotDataset(
    'data/processed/linearized/train.parquet',
    CONFIG['model_name'],
    CONFIG['max_length']
)
val_dataset = EDCopilotDataset(
    'data/processed/linearized/val.parquet',
    CONFIG['model_name'],
    CONFIG['max_length']
)

# Only the training loader shuffles; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'])

In [None]:
# Model
print("üèóÔ∏è Building model...")
model = EDCopilotSFT(CONFIG['model_name'])

# Resize the backbone's token embeddings to the dataset tokenizer's vocabulary
# (covers special tokens such as [EOS] that the dataset may have added).
# This runs before .to(device), so any newly created embedding rows move with
# the rest of the model. It also runs before the optimizer is built, so the
# optimizer tracks the resized parameters.
model.backbone.resize_token_embeddings(len(train_dataset.tokenizer))
model = model.to(device)

# Optimizer: torch.optim.AdamW (decoupled weight decay). transformers.AdamW is
# deprecated in favour of this class; all hyperparameters are passed explicitly.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    betas=(0.9, 0.999),
    eps=1e-8
)

# Scheduler: linear warmup over the first `warmup_percentage` of all optimizer
# steps, then linear decay. train_epoch steps it once per batch, so the total
# number of steps is batches-per-epoch * epochs.
total_steps = len(train_loader) * CONFIG['epochs']
warmup_steps = int(CONFIG['warmup_percentage'] * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

In [None]:
# Training loop: train and validate each epoch, and checkpoint whenever the
# validation loss strictly improves on the best seen so far.
print("\nüéØ Starting training...")
best_val_loss = float('inf')
save_dir = Path(CONFIG['save_dir'])
save_dir.mkdir(parents=True, exist_ok=True)

try:
    for epoch in range(CONFIG['epochs']):
        print(f"\nüìÖ Epoch {epoch+1}/{CONFIG['epochs']}")

        train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
        val_loss = validate(model, val_loader, device)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        wandb.log({
            'epoch': epoch + 1,
            'train/epoch_loss': train_loss,
            'val/epoch_loss': val_loss
        })

        # Save best model (a NaN val_loss never compares smaller, so it is never saved).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,  # 0-based, unlike the 1-based 'epoch' logged to W&B
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Needed to resume the warmup/decay LR schedule at the right step.
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
                'config': CONFIG
            }, save_dir / 'best_model.pt')
            print(f"‚úÖ Saved best model (val_loss: {val_loss:.4f})")
except BaseException:
    # Close the W&B run as failed instead of leaving it open in the kernel
    # (or later reporting it as successful), then surface the original error.
    wandb.finish(exit_code=1)
    raise

print("\nüéâ Training complete!")
wandb.finish()