In [1]:
from omegaconf import OmegaConf
import argparse
from models import get_model
from data import get_dataloaders
from optimizers import get_optimizer
from losses import get_loss
from schedulers import get_scheduler
from metrics import get_metrics
from callbacks import get_callbacks
from trainers.base_trainer import BaseTrainer
from loggers import setup_logger, get_output_logger
from utils.loggers import setup_logger
from utils.seed import seed_everything
from utils.get_experiment_id import get_experiment_id
from utils.load_checkpoint import load_checkpoint
from utils.wandb_login import wandb_login
from utils.filter_wrong_predictions import filter_wrong_predictions
import torch
import wandb

In [2]:
# Load the experiment configuration, derive a run identifier, seed the RNGs
# and create the logger used throughout the notebook.
config = OmegaConf.load('configs/example.yaml')
# NOTE(review): the id is computed before the device is resolved below.
# Presumably it is derived from the config contents, so the order is kept
# to avoid changing existing ids / checkpoint paths -- confirm.
config.experiment_id = get_experiment_id(config)
seed_everything(config.training.seed)
logger = setup_logger()
logger.info("Configuración cargada:")
logger.info(OmegaConf.to_yaml(config))

# Device resolution.
# Accepts plain "cuda" as well as indexed forms such as "cuda:1", so a
# specific GPU can be selected from the config file. When CUDA is requested
# but not available, fall back to CPU and say so in the log instead of
# switching silently.
requested_device = str(config.training.device)
if requested_device.startswith("cuda"):
    if torch.cuda.is_available():
        device = requested_device
    else:
        logger.warning(
            f"CUDA solicitado ({requested_device}) pero no disponible; se usará CPU."
        )
        device = "cpu"
    config.training.device = device

# Dataloaders
train_loader, val_loader, test_loader = get_dataloaders(config)

# Sanity check: show the preprocessing pipeline attached to the training data.
# The training dataset may be a wrapper exposing the real dataset through a
# `.dataset` attribute; fall back to the dataset itself when it is not wrapped
# so this debug print never raises AttributeError.
train_dataset = train_loader.dataset
base_train_dataset = getattr(train_dataset, "dataset", train_dataset)
print(getattr(base_train_dataset, "transform", None))

2025-05-09 09:49:40,359 - Configuración cargada:
2025-05-09 09:49:40,365 - dataset:
  name: CIFAR10
  root: ./data/datasets
  batch_size: 256
  num_workers: 4
model:
  name: resnet18
  weights: ResNet18_Weights.DEFAULT
  num_classes: 10
preprocessing:
- name: to_tensor
- name: random_horizontal_flip
  probability: 0.5
- name: random_crop
  padding: 4
  size: 32
  probability: 0.2
loss:
  name: cross_entropy
optimizer:
  name: adam
  lr: 0.001
  weight_decay: 0.0001
  scheduler: step
  step_size: 10
  gamma: 0.1
training:
  epochs: 1
  batch_size: 256
  num_workers: 4
  seed: 42
  device: cuda
  log_dir: logs
scheduler:
  name: reduce_on_plateau
  patience: 1
  factor: 0.1
metrics:
- name: f1_score
  average: weighted
callbacks:
- name: checkpoint
  dirpath: checkpoints/
  monitor: Val_loss
  mode: min
- name: wandb_logger
  project: my_project
  entity: inaki
output_logger:
  name: wandb_img_output_errors
experiment_id: resnet18_ab53919c



Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Compose([<data.preprocessing.to_tensor.ToTensor object at 0x00000191A6C76500>, <data.preprocessing.random_horizontal_flip.RandomHorizontalFlip object at 0x00000191A6C767D0>, <data.preprocessing.random_crop.RandomCrop object at 0x00000191A6C76830>])


In [3]:
 # Build the network from its config section, then move it to the configured device.
model = get_model(config.model)
model = model.to(config.training.device)

# Training components: each factory receives its own section of the config.
# The scheduler is built last among the optimisation pieces because it takes
# the optimizer instance as an argument.
criterion = get_loss(config.loss)
optimizer = get_optimizer(config.optimizer, model.parameters())
scheduler = get_scheduler(config.scheduler, optimizer)
callbacks = get_callbacks(config.callbacks)
metrics = get_metrics(config.metrics)

# Assemble the trainer and run training with the train / validation loaders.
trainer = BaseTrainer(
    model,
    criterion,
    optimizer,
    scheduler,
    config,
    logger,
    callbacks,
    metrics,
)
trainer.train(train_loader, val_loader)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Run name: resnet18_ab53919c


[34m[1mwandb[0m: Currently logged in as: [33minakitodc[0m ([33minaki[0m). Use [1m`wandb login --relogin`[0m to force relogin


                                                                                                    

KeyboardInterrupt: 

In [None]:
# Load the best validation checkpoint into `model`; the trainer already holds
# a reference to this same model object.
load_checkpoint(model, config)

# Run one validation pass, also collecting inputs, outputs and targets.
val_metrics, inputs, outputs, targets = trainer.run_epoch(val_loader, mode="Val", return_preds=True)
print("Métricas de validación:", val_metrics, sep="\n")

# Human-readable class names for CIFAR-10.
# NOTE(review): hard-coded for CIFAR-10 while the dataset name is configurable;
# presumably this should come from the dataset itself -- confirm.
label_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Narrow the collected tensors with filter_wrong_predictions and pass them,
# together with the class names, to the output logger selected in the config.
inputs, outputs, targets = filter_wrong_predictions(inputs, outputs, targets)
output_logger = get_output_logger(config.output_logger)
output_logger(inputs, outputs, targets, label_names)


  checkpoint = torch.load(f"checkpoints/{config.experiment_id}/best.pth", map_location=config.training.device)
                                                                                                    

Métricas de validación:
{'Val_loss': 0.8249986469745636, 'Val_f1': 0.7286274831820247}


In [None]:
# Finish the active Weights & Biases run once evaluation has been logged.
wandb.finish()

0,1
Val_f1,▁
Val_loss,▁

0,1
Train_loss,0.8685
Val_f1,0.62173
Val_loss,0.825
lr,0.001
