In [3]:
# IoU and YOLOv3 loss function

import torch
import torch.nn as nn

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Compute the element-wise intersection-over-union of two box tensors.

    Args:
        boxes_preds: Tensor whose last dimension holds the four box values.
        boxes_labels: Tensor laid out like ``boxes_preds``; the two must be
            broadcastable against each other.
        box_format: ``"midpoint"`` when boxes are (x_center, y_center, width,
            height), ``"corners"`` when they are (x1, y1, x2, y2).

    Returns:
        Tensor of IoU values whose last dimension has size 1. A constant of
        1e-6 is added to the union so the division never hits zero.

    Raises:
        ValueError: If ``box_format`` is neither ``"midpoint"`` nor ``"corners"``.
    """
    if box_format not in ("midpoint", "corners"):
        raise ValueError("box_format debe ser 'midpoint' o 'corners'")

    def _as_corners(boxes):
        # Slices (not integer indices) keep the trailing dimension of size 1.
        first, second = boxes[..., 0:1], boxes[..., 1:2]
        third, fourth = boxes[..., 2:3], boxes[..., 3:4]
        if box_format == "corners":
            return first, second, third, fourth
        half_w = third / 2
        half_h = fourth / 2
        return first - half_w, second - half_h, first + half_w, second + half_h

    p_x1, p_y1, p_x2, p_y2 = _as_corners(boxes_preds)
    l_x1, l_y1, l_x2, l_y2 = _as_corners(boxes_labels)

    # Overlap rectangle; negative extents (disjoint boxes) are clamped to zero.
    overlap_w = torch.clamp(torch.min(p_x2, l_x2) - torch.max(p_x1, l_x1), min=0)
    overlap_h = torch.clamp(torch.min(p_y2, l_y2) - torch.max(p_y1, l_y1), min=0)
    intersection = overlap_w * overlap_h

    area_preds = ((p_x2 - p_x1) * (p_y2 - p_y1)).abs()
    area_labels = ((l_x2 - l_x1) * (l_y2 - l_y1)).abs()

    return intersection / (area_preds + area_labels - intersection + 1e-6)

class YOLOv3Loss(nn.Module):
    """Multi-scale YOLOv3 training loss.

    For every prediction scale, each ground-truth box is assigned to the grid
    cell that contains its centre and to the anchor of that scale whose shape
    has the highest IoU with the box. The loss sums over scales:

    * a box term (BCE-with-logits on the x/y cell offsets, MSE on the
      log-space width/height),
    * an objectness term (assigned anchors should predict 1),
    * a no-object term (unassigned, non-ignored anchors should predict 0),
    * a multi-label class term (BCE-with-logits on a one-hot target).

    Args:
        anchors: One sequence of (width, height) anchor pairs per scale,
            expressed in the same units as ``img_size``.
        num_classes: Number of object classes.
        img_size: Image size used to convert normalised box sizes to the
            anchors' units.
        lambda_coord: Weight of the box term.
        lambda_noobj: Weight of the no-object term.
        lambda_obj: Weight of the objectness term.
        lambda_class: Weight of the class term.
        ignore_iou_threshold: Non-best anchors whose shape IoU with a
            ground-truth box exceeds this value are excluded from the
            no-object term.
        device: Device on which target tensors are built; defaults to CUDA
            when available, otherwise CPU.
    """

    def __init__(self, anchors, num_classes, img_size=(416, 416),
                 lambda_coord=1.0, lambda_noobj=1.0, lambda_obj=1.0, lambda_class=1.0,
                 ignore_iou_threshold=0.5, device=None):
        super().__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.img_size = img_size
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.lambda_obj = lambda_obj
        self.lambda_class = lambda_class
        self.ignore_iou_threshold = ignore_iou_threshold

        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _masked_loss(criterion, prediction, target, mask):
        """Apply ``criterion`` to the masked entries, or return 0 when the mask is empty."""
        # A mean-reduced loss over zero elements is NaN, which would poison the total.
        if not mask.any():
            return prediction.new_zeros(())
        return criterion(prediction[mask], target[mask])

    def _build_targets(self, targets, scale_anchors, batch_size, grid_h, grid_w):
        """Build the masks and regression/class targets for one prediction scale."""
        num_anchors = scale_anchors.shape[0]
        grid_shape = (batch_size, grid_h, grid_w, num_anchors)

        obj_mask = torch.zeros(grid_shape, dtype=torch.bool, device=self.device)
        noobj_mask = torch.ones(grid_shape, dtype=torch.bool, device=self.device)
        tx = torch.zeros(grid_shape, device=self.device)
        ty = torch.zeros(grid_shape, device=self.device)
        tw = torch.zeros(grid_shape, device=self.device)
        th = torch.zeros(grid_shape, device=self.device)
        target_class = torch.zeros(
            grid_shape + (self.num_classes,), dtype=torch.float32, device=self.device
        )
        built = (obj_mask, noobj_mask, tx, ty, tw, th, target_class)

        if targets.numel() == 0:
            return built

        img_ids = targets[:, 0].long()
        class_ids = targets[:, 1].long()
        x_center_grid = targets[:, 2] * grid_w
        y_center_grid = targets[:, 3] * grid_h
        # floor (not truncation) so slightly negative centres fall outside the
        # grid and are filtered below instead of landing in cell 0.
        cell_x = torch.floor(x_center_grid).long()
        cell_y = torch.floor(y_center_grid).long()

        # Drop boxes whose centre cell or image index is out of range.
        keep = (
            (cell_x >= 0) & (cell_x < grid_w)
            & (cell_y >= 0) & (cell_y < grid_h)
            & (img_ids >= 0) & (img_ids < batch_size)
        )
        img_ids = img_ids[keep]
        class_ids = class_ids[keep]
        cell_x = cell_x[keep]
        cell_y = cell_y[keep]
        x_center_grid = x_center_grid[keep]
        y_center_grid = y_center_grid[keep]

        # NOTE(review): widths are scaled by img_size[0] and heights by
        # img_size[1], i.e. img_size is read as (width, height) here —
        # confirm callers agree for non-square inputs.
        w_gt_pix = targets[:, 4][keep] * self.img_size[0]
        h_gt_pix = targets[:, 5][keep] * self.img_size[1]

        # Shape-only IoU: place every box and anchor with its corner at the origin.
        origin = torch.zeros_like(w_gt_pix)
        gt_boxes = torch.stack([origin, origin, w_gt_pix, h_gt_pix], dim=1).unsqueeze(1)  # (B, 1, 4)
        anchor_boxes = torch.zeros((1, num_anchors, 4), device=self.device)
        anchor_boxes[0, :, 2:] = scale_anchors  # (1, A, 4)
        ious = intersection_over_union(gt_boxes, anchor_boxes, box_format="corners").squeeze(-1)  # (B, A)
        best_anchor = torch.argmax(ious, dim=1)  # (B,)

        # Per-box regression targets, computed once outside the assignment loop.
        tx_all = x_center_grid - cell_x
        ty_all = y_center_grid - cell_y
        tw_all = torch.log(w_gt_pix / scale_anchors[best_anchor, 0] + 1e-16)
        th_all = torch.log(h_gt_pix / scale_anchors[best_anchor, 1] + 1e-16)
        ignore = ious > self.ignore_iou_threshold  # (B, A)

        # Sequential assignment keeps "last box wins" when two boxes share a
        # cell and anchor.
        rows = zip(img_ids.tolist(), cell_y.tolist(), cell_x.tolist(),
                   best_anchor.tolist(), class_ids.tolist())
        for idx, (i, cy, cx, a, c) in enumerate(rows):
            obj_mask[i, cy, cx, a] = True
            noobj_mask[i, cy, cx, a] = False
            # Anchors that also match this box well are neither positive nor negative.
            noobj_mask[i, cy, cx][ignore[idx]] = False
            tx[i, cy, cx, a] = tx_all[idx]
            ty[i, cy, cx, a] = ty_all[idx]
            tw[i, cy, cx, a] = tw_all[idx]
            th[i, cy, cx, a] = th_all[idx]
            target_class[i, cy, cx, a, c] = 1.0

        return built

    def forward(self, predictions, targets):
        """Compute the total loss and its components.

        Args:
            predictions: Iterable with one raw output tensor per scale, shaped
                (N, num_anchors * (5 + num_classes), H, W).
            targets: Tensor of shape (num_boxes, 6) whose columns are
                (image index, class id, x_center, y_center, width, height);
                the box values are multiplied by the grid / image size, so
                they are expected to be normalised.

        Returns:
            Tuple ``(total_loss, parts)`` where ``parts`` is a dict with the
            unweighted ``box_loss``, ``obj_loss``, ``noobj_loss`` and
            ``class_loss`` summed over scales.
        """
        obj_loss = 0
        noobj_loss = 0
        box_loss = 0
        class_loss = 0

        anchors_tensor = [
            torch.tensor(a, dtype=torch.float32, device=self.device) for a in self.anchors
        ]

        for scale_idx, prediction in enumerate(predictions):
            scale_anchors = anchors_tensor[scale_idx]  # (num_anchors, 2)
            num_anchors = scale_anchors.shape[0]
            batch_size, _, grid_h, grid_w = prediction.shape

            # (N, A*(5+C), H, W) -> (N, H, W, A, 5+C)
            prediction = prediction.permute(0, 2, 3, 1).reshape(
                batch_size, grid_h, grid_w, num_anchors, self.num_classes + 5
            )
            pred_x = prediction[..., 0]
            pred_y = prediction[..., 1]
            pred_w = prediction[..., 2]
            pred_h = prediction[..., 3]
            pred_obj = prediction[..., 4]
            pred_class = prediction[..., 5:]

            obj_mask, noobj_mask, tx, ty, tw, th, target_class = self._build_targets(
                targets, scale_anchors, batch_size, grid_h, grid_w
            )

            box_loss += (
                self._masked_loss(self.bce, pred_x, tx, obj_mask)
                + self._masked_loss(self.bce, pred_y, ty, obj_mask)
                + self._masked_loss(self.mse, pred_w, tw, obj_mask)
                + self._masked_loss(self.mse, pred_h, th, obj_mask)
            )

            # Objectness target is 1 on assigned anchors and 0 everywhere else,
            # so the no-object term pushes background logits towards 0.
            obj_target = obj_mask.to(pred_obj.dtype)
            obj_loss += self._masked_loss(self.bce, pred_obj, obj_target, obj_mask)
            noobj_loss += self._masked_loss(self.bce, pred_obj, obj_target, noobj_mask)

            class_loss += self._masked_loss(self.bce, pred_class, target_class, obj_mask)

        total_loss = (
            self.lambda_coord * box_loss
            + self.lambda_obj * obj_loss
            + self.lambda_noobj * noobj_loss
            + self.lambda_class * class_loss
        )
        return total_loss, {"box_loss": box_loss, "obj_loss": obj_loss, "noobj_loss": noobj_loss, "class_loss": class_loss}

In [5]:

import torch

# --- Test parameters ---
NUM_CLASSES = 3
IMG_SIZE = (416, 416)
BATCH_SIZE = 2

# User-supplied anchors: one row of three (width, height) pairs per prediction scale.
ANCHORS = [
    [(227, 210), (179, 155), (124, 111)],
    [(105, 113), (104, 96), (80, 109)],
    [(112, 75), (87, 82), (39, 38)],
]

# --- Pick the device at runtime: CUDA when available, otherwise CPU ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Usando device: {device}")

# --- Instantiate the loss function ---

# IoU and YOLOv3 loss function

import torch
import torch.nn as nn

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Compute the element-wise intersection-over-union of two box tensors.

    Args:
        boxes_preds: Tensor whose last dimension holds the four box values.
        boxes_labels: Tensor laid out like ``boxes_preds``; the two must be
            broadcastable against each other.
        box_format: ``"midpoint"`` when boxes are (x_center, y_center, width,
            height), ``"corners"`` when they are (x1, y1, x2, y2).

    Returns:
        Tensor of IoU values whose last dimension has size 1. A constant of
        1e-6 is added to the union so the division never hits zero.

    Raises:
        ValueError: If ``box_format`` is neither ``"midpoint"`` nor ``"corners"``.
    """
    if box_format not in ("midpoint", "corners"):
        raise ValueError("box_format debe ser 'midpoint' o 'corners'")

    def _as_corners(boxes):
        # Slices (not integer indices) keep the trailing dimension of size 1.
        first, second = boxes[..., 0:1], boxes[..., 1:2]
        third, fourth = boxes[..., 2:3], boxes[..., 3:4]
        if box_format == "corners":
            return first, second, third, fourth
        half_w = third / 2
        half_h = fourth / 2
        return first - half_w, second - half_h, first + half_w, second + half_h

    p_x1, p_y1, p_x2, p_y2 = _as_corners(boxes_preds)
    l_x1, l_y1, l_x2, l_y2 = _as_corners(boxes_labels)

    # Overlap rectangle; negative extents (disjoint boxes) are clamped to zero.
    overlap_w = torch.clamp(torch.min(p_x2, l_x2) - torch.max(p_x1, l_x1), min=0)
    overlap_h = torch.clamp(torch.min(p_y2, l_y2) - torch.max(p_y1, l_y1), min=0)
    intersection = overlap_w * overlap_h

    area_preds = ((p_x2 - p_x1) * (p_y2 - p_y1)).abs()
    area_labels = ((l_x2 - l_x1) * (l_y2 - l_y1)).abs()

    return intersection / (area_preds + area_labels - intersection + 1e-6)

class YOLOv3Loss(nn.Module):
    """Multi-scale YOLOv3 training loss.

    For every prediction scale, each ground-truth box is assigned to the grid
    cell that contains its centre and to the anchor of that scale whose shape
    has the highest IoU with the box. The loss sums over scales:

    * a box term (BCE-with-logits on the x/y cell offsets, MSE on the
      log-space width/height),
    * an objectness term (assigned anchors should predict 1),
    * a no-object term (unassigned, non-ignored anchors should predict 0),
    * a multi-label class term (BCE-with-logits on a one-hot target).

    Args:
        anchors: One sequence of (width, height) anchor pairs per scale,
            expressed in the same units as ``img_size``.
        num_classes: Number of object classes.
        img_size: Image size used to convert normalised box sizes to the
            anchors' units.
        lambda_coord: Weight of the box term.
        lambda_noobj: Weight of the no-object term.
        lambda_obj: Weight of the objectness term.
        lambda_class: Weight of the class term.
        ignore_iou_threshold: Non-best anchors whose shape IoU with a
            ground-truth box exceeds this value are excluded from the
            no-object term.
        device: Device on which target tensors are built; defaults to CUDA
            when available, otherwise CPU.
    """

    def __init__(self, anchors, num_classes, img_size=(416, 416),
                 lambda_coord=1.0, lambda_noobj=1.0, lambda_obj=1.0, lambda_class=1.0,
                 ignore_iou_threshold=0.5, device=None):
        super().__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.img_size = img_size
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.lambda_obj = lambda_obj
        self.lambda_class = lambda_class
        self.ignore_iou_threshold = ignore_iou_threshold

        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _masked_loss(criterion, prediction, target, mask):
        """Apply ``criterion`` to the masked entries, or return 0 when the mask is empty."""
        # A mean-reduced loss over zero elements is NaN, which would poison the total.
        if not mask.any():
            return prediction.new_zeros(())
        return criterion(prediction[mask], target[mask])

    def _build_targets(self, targets, scale_anchors, batch_size, grid_h, grid_w):
        """Build the masks and regression/class targets for one prediction scale."""
        num_anchors = scale_anchors.shape[0]
        grid_shape = (batch_size, grid_h, grid_w, num_anchors)

        obj_mask = torch.zeros(grid_shape, dtype=torch.bool, device=self.device)
        noobj_mask = torch.ones(grid_shape, dtype=torch.bool, device=self.device)
        tx = torch.zeros(grid_shape, device=self.device)
        ty = torch.zeros(grid_shape, device=self.device)
        tw = torch.zeros(grid_shape, device=self.device)
        th = torch.zeros(grid_shape, device=self.device)
        target_class = torch.zeros(
            grid_shape + (self.num_classes,), dtype=torch.float32, device=self.device
        )
        built = (obj_mask, noobj_mask, tx, ty, tw, th, target_class)

        if targets.numel() == 0:
            return built

        img_ids = targets[:, 0].long()
        class_ids = targets[:, 1].long()
        x_center_grid = targets[:, 2] * grid_w
        y_center_grid = targets[:, 3] * grid_h
        # floor (not truncation) so slightly negative centres fall outside the
        # grid and are filtered below instead of landing in cell 0.
        cell_x = torch.floor(x_center_grid).long()
        cell_y = torch.floor(y_center_grid).long()

        # Drop boxes whose centre cell or image index is out of range.
        keep = (
            (cell_x >= 0) & (cell_x < grid_w)
            & (cell_y >= 0) & (cell_y < grid_h)
            & (img_ids >= 0) & (img_ids < batch_size)
        )
        img_ids = img_ids[keep]
        class_ids = class_ids[keep]
        cell_x = cell_x[keep]
        cell_y = cell_y[keep]
        x_center_grid = x_center_grid[keep]
        y_center_grid = y_center_grid[keep]

        # NOTE(review): widths are scaled by img_size[0] and heights by
        # img_size[1], i.e. img_size is read as (width, height) here —
        # confirm callers agree for non-square inputs.
        w_gt_pix = targets[:, 4][keep] * self.img_size[0]
        h_gt_pix = targets[:, 5][keep] * self.img_size[1]

        # Shape-only IoU: place every box and anchor with its corner at the origin.
        origin = torch.zeros_like(w_gt_pix)
        gt_boxes = torch.stack([origin, origin, w_gt_pix, h_gt_pix], dim=1).unsqueeze(1)  # (B, 1, 4)
        anchor_boxes = torch.zeros((1, num_anchors, 4), device=self.device)
        anchor_boxes[0, :, 2:] = scale_anchors  # (1, A, 4)
        ious = intersection_over_union(gt_boxes, anchor_boxes, box_format="corners").squeeze(-1)  # (B, A)
        best_anchor = torch.argmax(ious, dim=1)  # (B,)

        # Per-box regression targets, computed once outside the assignment loop.
        tx_all = x_center_grid - cell_x
        ty_all = y_center_grid - cell_y
        tw_all = torch.log(w_gt_pix / scale_anchors[best_anchor, 0] + 1e-16)
        th_all = torch.log(h_gt_pix / scale_anchors[best_anchor, 1] + 1e-16)
        ignore = ious > self.ignore_iou_threshold  # (B, A)

        # Sequential assignment keeps "last box wins" when two boxes share a
        # cell and anchor.
        rows = zip(img_ids.tolist(), cell_y.tolist(), cell_x.tolist(),
                   best_anchor.tolist(), class_ids.tolist())
        for idx, (i, cy, cx, a, c) in enumerate(rows):
            obj_mask[i, cy, cx, a] = True
            noobj_mask[i, cy, cx, a] = False
            # Anchors that also match this box well are neither positive nor negative.
            noobj_mask[i, cy, cx][ignore[idx]] = False
            tx[i, cy, cx, a] = tx_all[idx]
            ty[i, cy, cx, a] = ty_all[idx]
            tw[i, cy, cx, a] = tw_all[idx]
            th[i, cy, cx, a] = th_all[idx]
            target_class[i, cy, cx, a, c] = 1.0

        return built

    def forward(self, predictions, targets):
        """Compute the total loss and its components.

        Args:
            predictions: Iterable with one raw output tensor per scale, shaped
                (N, num_anchors * (5 + num_classes), H, W).
            targets: Tensor of shape (num_boxes, 6) whose columns are
                (image index, class id, x_center, y_center, width, height);
                the box values are multiplied by the grid / image size, so
                they are expected to be normalised.

        Returns:
            Tuple ``(total_loss, parts)`` where ``parts`` is a dict with the
            unweighted ``box_loss``, ``obj_loss``, ``noobj_loss`` and
            ``class_loss`` summed over scales.
        """
        obj_loss = 0
        noobj_loss = 0
        box_loss = 0
        class_loss = 0

        anchors_tensor = [
            torch.tensor(a, dtype=torch.float32, device=self.device) for a in self.anchors
        ]

        for scale_idx, prediction in enumerate(predictions):
            scale_anchors = anchors_tensor[scale_idx]  # (num_anchors, 2)
            num_anchors = scale_anchors.shape[0]
            batch_size, _, grid_h, grid_w = prediction.shape

            # (N, A*(5+C), H, W) -> (N, H, W, A, 5+C)
            prediction = prediction.permute(0, 2, 3, 1).reshape(
                batch_size, grid_h, grid_w, num_anchors, self.num_classes + 5
            )
            pred_x = prediction[..., 0]
            pred_y = prediction[..., 1]
            pred_w = prediction[..., 2]
            pred_h = prediction[..., 3]
            pred_obj = prediction[..., 4]
            pred_class = prediction[..., 5:]

            obj_mask, noobj_mask, tx, ty, tw, th, target_class = self._build_targets(
                targets, scale_anchors, batch_size, grid_h, grid_w
            )

            box_loss += (
                self._masked_loss(self.bce, pred_x, tx, obj_mask)
                + self._masked_loss(self.bce, pred_y, ty, obj_mask)
                + self._masked_loss(self.mse, pred_w, tw, obj_mask)
                + self._masked_loss(self.mse, pred_h, th, obj_mask)
            )

            # Objectness target is 1 on assigned anchors and 0 everywhere else,
            # so the no-object term pushes background logits towards 0.
            obj_target = obj_mask.to(pred_obj.dtype)
            obj_loss += self._masked_loss(self.bce, pred_obj, obj_target, obj_mask)
            noobj_loss += self._masked_loss(self.bce, pred_obj, obj_target, noobj_mask)

            class_loss += self._masked_loss(self.bce, pred_class, target_class, obj_mask)

        total_loss = (
            self.lambda_coord * box_loss
            + self.lambda_obj * obj_loss
            + self.lambda_noobj * noobj_loss
            + self.lambda_class * class_loss
        )
        return total_loss, {"box_loss": box_loss, "obj_loss": obj_loss, "noobj_loss": noobj_loss, "class_loss": class_loss}

# Build the loss with unit weights for every term and move it to the selected device.
loss_fn = YOLOv3Loss(
    ANCHORS,
    NUM_CLASSES,
    img_size=IMG_SIZE,
    lambda_coord=1.0,
    lambda_noobj=1.0,
    lambda_obj=1.0,
    lambda_class=1.0,
    ignore_iou_threshold=0.5,
    device=device,
)
loss_fn = loss_fn.to(device)

# --- Dummy predictions: one random tensor per scale (strides 32, 16 and 8) ---
pred0_grid_h = IMG_SIZE[0] // 32
pred0_grid_w = IMG_SIZE[1] // 32
pred1_grid_h = IMG_SIZE[0] // 16
pred1_grid_w = IMG_SIZE[1] // 16
pred2_grid_h = IMG_SIZE[0] // 8
pred2_grid_w = IMG_SIZE[1] // 8

predictions_dummy = [
    torch.randn(BATCH_SIZE, 3 * (5 + NUM_CLASSES), grid_h, grid_w, device=device)
    for grid_h, grid_w in (
        (pred0_grid_h, pred0_grid_w),
        (pred1_grid_h, pred1_grid_w),
        (pred2_grid_h, pred2_grid_w),
    )
]

# --- Dummy targets: rows of (image index, class id, x, y, w, h) ---
targets_dummy = torch.tensor(
    [
        [0, 0, 0.5, 0.5, 0.1, 0.1],
        [0, 1, 0.2, 0.8, 0.05, 0.05],
        [0, 2, 0.7, 0.3, 0.2, 0.2],
        [1, 0, 0.1, 0.1, 0.3, 0.3],
        [1, 1, 0.9, 0.9, 0.08, 0.08],
    ],
    dtype=torch.float32,
    device=device,
)

# --- Run the loss function and report every component ---
print("\n--- Ejecutando la función de pérdida con datos dummy ---")
total_loss, individual_losses = loss_fn(predictions_dummy, targets_dummy)

print(f"\nPérdida Total: {total_loss.item():.4f}")
print(f"  Pérdida de Cajas (box_loss): {individual_losses['box_loss'].item():.4f}")
print(f"  Pérdida de Objeto (obj_loss): {individual_losses['obj_loss'].item():.4f}")
print(f"  Pérdida de No-objeto (noobj_loss): {individual_losses['noobj_loss'].item():.4f}")
print(f"  Pérdida de Clase (class_loss): {individual_losses['class_loss'].item():.4f}")

# A NaN or infinite total signals a numerical problem in the loss.
if not torch.isnan(total_loss) and not torch.isinf(total_loss):
    print("\nLa función de pérdida se ejecutó exitosamente con valores numéricos válidos.")
else:
    print("\n¡Advertencia! La pérdida total es NaN o Inf. Esto indica un problema numérico.")

print("\n--- Fin de la prueba ---")

Usando device: cpu

--- Ejecutando la función de pérdida con datos dummy ---

Pérdida Total: 23.2065
  Pérdida de Cajas (box_loss): 15.4137
  Pérdida de Objeto (obj_loss): 2.6371
  Pérdida de No-objeto (noobj_loss): 2.4228
  Pérdida de Clase (class_loss): 2.7329

La función de pérdida se ejecutó exitosamente con valores numéricos válidos.

--- Fin de la prueba ---
