# Entrenamiento baseline
Modelo sencillo (ResNet18) que predice Dry_Clover_g, Dry_Green_g y Dry_Dead_g a partir de cada foto.

In [None]:
import os
import sys
import random
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Detect whether we are running in Google Colab: a successful import of
# `google.colab` is taken as the signal.
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if not IN_COLAB:
    # Outside Colab: add the parent directory to the import path
    # (presumably so the project-local imports below resolve).
    sys.path.append('../')
else:
    # In Colab: mount Google Drive and switch to the project folder on it.
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    project_path = '/content/drive/MyDrive/image2biomass'
    if not os.path.exists(project_path):
        # Only a warning is printed; the working directory stays unchanged.
        print(f"Advertencia: No se encontró el directorio {project_path}")
    else:
        os.chdir(project_path)
        print(f"Directorio de trabajo cambiado a: {os.getcwd()}")

from utils.paths import get_data_path

import torch

from src.utils.seed import set_seed
from src.utils.config import TrainingConfig
from src.utils.metrics import DEFAULT_TARGET_WEIGHTS, weighted_r2_score
from src.data.dataloader import make_dataloaders
from src.models.resnet import create_resnet
from src.training.trainer import Trainer
from src.inference.predictor import Predictor

# Run-wide settings: fixed seed, training configuration and compute device.
SEED = 42
cfg = TrainingConfig()
set_seed(SEED)

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # last expression of the cell, shown as the cell output

In [None]:
# Load the train/test tables from the data directory.
base_path = Path(get_data_path())
train_df, test_df = (
    pd.read_csv(base_path / csv_name) for csv_name in ('train.csv', 'test.csv')
)

# Targets predicted by the baseline; they become the columns kept below.
targets = ['Dry_Clover_g', 'Dry_Green_g', 'Dry_Dead_g']

# Reshape the long table (one row per image/target pair) into one row per
# image with one column per target name.
pivot = train_df.pivot_table(index='image_path', columns='target_name', values='target')
pivot = pivot.reset_index()

# Keep only the image path plus the selected targets, and drop images that
# lack a value for any of them.
wanted_columns = ['image_path', *targets]
pivot = pivot[wanted_columns].dropna().reset_index(drop=True)
print(f"Imagenes disponibles: {len(pivot)}")
pivot.head()

In [None]:
# Simple 80/20 hold-out split.
# The permutation comes from a dedicated generator seeded with SEED instead
# of the global NumPy RNG, so the partition does not depend on how much random
# state earlier cells consumed: re-running this cell (or running cells out of
# order) always yields the same train/validation rows, and the validation set
# cannot drift into rows a saved checkpoint was trained on.
split_rng = np.random.default_rng(SEED)
perm = split_rng.permutation(len(pivot))
split = int(len(pivot) * 0.8)  # the first 80% of the shuffled rows go to training
train_meta = pivot.iloc[perm[:split]].reset_index(drop=True)
val_meta = pivot.iloc[perm[split:]].reset_index(drop=True)

# Build the loaders; the third return value (validation transforms) is kept
# because the prediction step passes it as `transform`.
train_loader, val_loader, val_tfms = make_dataloaders(
    train_meta,
    val_meta,
    targets=targets,
    images_root=base_path,
    img_size=cfg.img_size,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)
len(train_meta), len(val_meta)

In [None]:
# One output per target, trained with plain MSE and Adam.
num_outputs = len(targets)
model = create_resnet(num_outputs)
model = model.to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

# Train; `fit` receives the checkpoint path and its return value is kept as
# the training history used for the loss curves.
trainer = Trainer(model, criterion, optimizer, device)
history = trainer.fit(
    train_loader,
    val_loader,
    epochs=cfg.epochs,
    checkpoint_path=cfg.checkpoint_path,
)

# Reload the weights stored at the checkpoint path before evaluating
# (presumably the best epoch — confirm against Trainer.fit).
checkpoint_state = torch.load(cfg.checkpoint_path, map_location=device)
model.load_state_dict(checkpoint_state)
model.eval()

In [None]:
# Training curves: train and validation loss per epoch.
history_df = pd.DataFrame(history)
plt.style.use("seaborn-v0_8-darkgrid")
fig, ax = plt.subplots(figsize=(7, 4))

# Draw both loss series against the epoch column with the same marker.
for loss_column, legend_label in (("train_loss", "Train loss"), ("val_loss", "Val loss")):
    ax.plot(history_df["epoch"], history_df[loss_column], marker="o", label=legend_label)

ax.set(xlabel="Época", ylabel="MSE", title="Evolución de la pérdida")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig

In [None]:
# Weighted metric and per-target scatter plots on the validation set.
model.eval()
y_true, y_pred, name_labels = [], [], []

# Collect flattened ground truth, predictions and the matching target name
# for every value; gradients are not needed here.
with torch.no_grad():
    for images, targets_batch in val_loader:
        images = images.to(device)
        preds = model(images).cpu().numpy()
        targets_np = targets_batch.numpy()
        batch_size = len(targets_np)
        y_true.append(targets_np.ravel())
        y_pred.append(preds.ravel())
        # Repeat the target names once per sample so the labels line up with
        # the flattened values (assumes batches are (samples, targets) — TODO confirm).
        name_labels.append(np.tile(targets, batch_size))

y_true_flat = np.concatenate(y_true)
y_pred_flat = np.concatenate(y_pred)
name_labels_flat = np.concatenate(name_labels)

val_weighted_r2 = weighted_r2_score(
    y_true=y_true_flat,
    y_pred=y_pred_flat,
    target_names=name_labels_flat,
    target_weights=DEFAULT_TARGET_WEIGHTS,
)
print(f"R2 ponderado validación: {val_weighted_r2:.4f}")

# One scatter panel per target, with independent axes.
n_panels = len(targets)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), sharex=False, sharey=False)
# plt.subplots returns a bare Axes for a single panel; atleast_1d lets the
# loop below treat both cases the same way.
for target_name, ax in zip(targets, np.atleast_1d(axes)):
    mask = name_labels_flat == target_name
    truth_values = y_true_flat[mask]
    pred_values = y_pred_flat[mask]
    ax.scatter(truth_values, pred_values, alpha=0.5, s=14, label="Pred vs GT")
    # Diagonal y = x spanning the combined range of both series.
    min_val = min(truth_values.min(), pred_values.min())
    max_val = max(truth_values.max(), pred_values.max())
    diagonal = [min_val, max_val]
    ax.plot(diagonal, diagonal, color="tab:red", linewidth=1, label="Ideal")
    ax.set(title=target_name, xlabel="Ground truth", ylabel="Predicción")
    ax.legend()
fig.suptitle("Dispersión predicciones vs verdad en validación", y=1.02, fontsize=12)
fig.tight_layout()
fig

In [None]:
# Predict on the test table with the trained model, using the validation
# transforms and the loader settings from the training configuration.
predictor = Predictor(model, device)
predict_options = {
    "images_root": base_path,
    "transform": val_tfms,
    "batch_size": cfg.batch_size,
    "num_workers": cfg.num_workers,
}
submission = predictor.predict(test_df, targets, **predict_options)

# Make sure the output directory exists, then write the submission file.
os.makedirs(cfg.model_dir, exist_ok=True)
submission_path = cfg.model_dir / 'submission_baseline.csv'
submission.to_csv(submission_path, index=False)
submission.head()