# loss

> An implementation of the loss function for training the [YOLOX](https://arxiv.org/abs/2107.08430) object detection model based on [OpenMMLab](https://github.com/open-mmlab)’s implementation in the [mmdetection](https://github.com/open-mmlab/mmdetection) library.

In [None]:
#| default_exp loss

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from typing import Any, Type, List, Optional, Callable, Tuple, Union, Dict
from functools import partial
from dataclasses import dataclass, field

In [None]:
#| export
import numpy as np

import torch
import torch.nn.functional as F
import torchvision

In [None]:
#| export
from cjm_yolox_pytorch.utils import multi_apply, generate_output_grids
from cjm_yolox_pytorch.simota import AssignResult, SimOTAAssigner

In [None]:
#| export
@dataclass
class SamplingResult:
    """
    Bounding box sampling result.
    
    Based on OpenMMLab's implementation in the mmdetection library:
    
    - [OpenMMLab's Implementation](https://github.com/open-mmlab/mmdetection/blob/d64e719172335fa3d7a757a2a3636bd19e9efb62/mmdet/core/bbox/samplers/sampling_result.py#L7)
    """
    positive_indices: torch.Tensor # Indices of the positive samples.
    negative_indices: torch.Tensor # Indices of the negative samples.
    bboxes: torch.Tensor # Candidate boxes that the sample indices refer to (the output grid boxes in YOLOXLoss).
    ground_truth_bboxes: torch.Tensor # Tensor containing all ground truth bounding boxes.
    assignment_result: AssignResult # Object that contains the ground truth indices and labels corresponding to each sample.
    ground_truth_flags: torch.Tensor # Per-sample flags indicating which samples are ground truth boxes.

    def __post_init__(self):
        # Bounding boxes for positive and negative samples
        self.positive_bboxes = self.bboxes[self.positive_indices]
        self.negative_bboxes = self.bboxes[self.negative_indices]
        
        # Ground truth flags and count, plus 0-based indices of the ground truth boxes assigned to positive samples
        # (the assigner stores 1-based indices, where 0 means background)
        self.is_positive_ground_truth = self.ground_truth_flags[self.positive_indices]
        self.number_of_ground_truths = self.ground_truth_bboxes.shape[0]
        self.positive_assigned_ground_truth_indices = self.assignment_result.ground_truth_box_indices[self.positive_indices] - 1

        # Check the consistency of ground truth bounding boxes and assigned indices
        if self.ground_truth_bboxes.numel() == 0:
            if self.positive_assigned_ground_truth_indices.numel() != 0:
                raise ValueError('Mismatch between ground truth bounding boxes and positive assigned ground truth indices.')
            self.positive_ground_truth_bboxes = torch.empty_like(self.ground_truth_bboxes).reshape(-1, 4)
        else:
            if len(self.ground_truth_bboxes.shape) < 2:
                self.ground_truth_bboxes = self.ground_truth_bboxes.reshape(-1, 4)
            self.positive_ground_truth_bboxes = self.ground_truth_bboxes[self.positive_assigned_ground_truth_indices, :]

        # If labels are assigned, assign labels for positive samples. Otherwise, set it as None
        if self.assignment_result.category_labels is not None:
            self.positive_ground_truth_labels = self.assignment_result.category_labels[self.positive_indices]
        else:
            self.positive_ground_truth_labels = None
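
`SamplingResult` relies on the assigner's convention that `ground_truth_box_indices` is 1-based for matched samples (0 means background), which is why `__post_init__` subtracts 1 before indexing the ground truth boxes. The sketch below reproduces that gather step with plain Python lists instead of tensors; `gather_positive_targets` is a hypothetical name, and the `-1` labels at background positions are made-up placeholders that are never read.

```python
from typing import List, Optional, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

def gather_positive_targets(
    gt_box_indices: Sequence[int],             # per-sample assignment: 0 = background, k > 0 = ground truth k - 1
    gt_boxes: Sequence[Box],                   # ground truth boxes for one image
    assigned_labels: Optional[Sequence[int]],  # per-sample class labels, or None if the assigner produced none
    positive_indices: Sequence[int],           # indices of the positive samples
) -> Tuple[List[Box], Optional[List[int]]]:
    # Convert the 1-based assignment into 0-based indices into gt_boxes
    assigned = [gt_box_indices[i] - 1 for i in positive_indices]
    positive_boxes = [gt_boxes[j] for j in assigned]
    # Labels are only gathered when the assignment result carries them
    positive_labels = None if assigned_labels is None else [assigned_labels[i] for i in positive_indices]
    return positive_boxes, positive_labels

# Sample 1 is matched to ground truth 1 (0-based), sample 3 to ground truth 0;
# the -1 labels sit at background positions and are never read
boxes, labels = gather_positive_targets(
    [0, 2, 0, 1], [(0, 0, 10, 10), (5, 5, 20, 20)], [-1, 3, -1, 7], [1, 3])
print(boxes, labels)  # [(5, 5, 20, 20), (0, 0, 10, 10)] [3, 7]
```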

In [None]:
#| export
class YOLOXLoss:
    """
    The callable YOLOXLoss class implements the loss function for training a YOLOX model.
    
    A YOLOXLoss instance takes the class scores, predicted bounding boxes, objectness scores, ground truth bounding boxes, and ground truth labels. It then goes through the following steps:

    1. Generate box coordinates for the output grids based on the input dimensions and stride values.
    2. Flatten and concatenate class predictions, bounding box predictions, and objectness scores.
    3. Decode box predictions.
    4. Compute targets for each image in the batch.
    5. Concatenate all positive masks, class targets, objectness targets, and bounding box targets.
    6. Compute the bounding box loss, objectness loss, and classification loss, scale them by their respective weights, and normalize them by the total number of positive samples in the batch (at least 1).
    7. If using L1 loss, concatenate the L1 targets, compute the L1 loss, scale it by its weight, and normalize it the same way.
    8. Return a dictionary containing the computed losses.
    
    Based on OpenMMLab's implementation in the mmdetection library:
    
    - [OpenMMLab's Implementation](https://github.com/open-mmlab/mmdetection/blob/d64e719172335fa3d7a757a2a3636bd19e9efb62/mmdet/models/dense_heads/yolox_head.py#L321)
        
    """
    def __init__(self, 
                 num_classes:int, # The number of target classes.
                 bbox_loss_weight:float=5.0, # The weight applied to the bounding box regression loss.
                 class_loss_weight:float=1.0, # The weight applied to the classification loss.
                 objectness_loss_weight:float=1.0, # The weight applied to the objectness loss.
                 l1_loss_weight:float=1.0, # The weight applied to the L1 loss.
                 use_l1:bool=False, # Whether to use L1 loss in the calculation.
                 strides:List[int]=[8,16,32] # The list of strides.
                ):
        
        """
        The `__init__` method stores the loss weights and strides, creates the SimOTA assigner, 
        and initializes the loss functions: 
        a squared [Generalized IoU](https://pytorch.org/vision/stable/generated/torchvision.ops.generalized_box_iou_loss.html) loss (`1 - GIoU²`, summed over positive samples) for the bounding box loss, 
        [binary cross entropy with logits](https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html) for the classification and objectness losses, 
        and [L1 loss](https://pytorch.org/docs/stable/generated/torch.nn.functional.l1_loss.html#torch.nn.functional.l1_loss) for the optional L1 loss.
        """
        
        self.num_classes = num_classes
        
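        # generalized_box_iou_loss returns 1 - GIoU per box pair, so 1 - loss recovers GIoU;
        # the bounding box loss is the sum of 1 - GIoU² over all positive pairs
        giou_loss_partial = partial(torchvision.ops.generalized_box_iou_loss, reduction='none', eps=1e-16)
        self.bbox_loss_func = lambda bx1, bx2 : (1-(1-giou_loss_partial(boxes1=bx1, boxes2=bx2))**2).sum()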
        self.class_loss_func = partial(F.binary_cross_entropy_with_logits, reduction='sum')
        self.objectness_loss_func = partial(F.binary_cross_entropy_with_logits, reduction='sum')
        self.l1_loss_func = partial(F.l1_loss, reduction='sum')
        
        self.bbox_loss_weight = bbox_loss_weight
        self.class_loss_weight = class_loss_weight
        self.objectness_loss_weight = objectness_loss_weight
        self.l1_loss_weight = l1_loss_weight
        
        self.use_l1 = use_l1
        
        # Initialize the assigner
        self.assigner = SimOTAAssigner(center_radius=2.5)
        
        self.strides = strides
        
        
    def bbox_decode(self, 
                    output_grid_boxes:torch.Tensor, # The output grid boxes.
                    predicted_boxes:torch.Tensor # The predicted bounding boxes.
                   ) -> torch.Tensor: # The decoded bounding boxes.
        """
        Decodes the raw bounding box predictions into `(x1, y1, x2, y2)` boxes in input-image pixels. 
        The first two predicted values are center offsets that are multiplied by the stride of each output grid box and added to its position, 
        and the last two are log-scale width and height, which are exponentiated and multiplied by the stride.
        """
        # Calculate box centroids (geometric centers) and sizes
        box_centroids = (predicted_boxes[..., :2] * output_grid_boxes[:, 2:]) + output_grid_boxes[:, :2]
        box_sizes = torch.exp(predicted_boxes[..., 2:]) * output_grid_boxes[:, 2:]

        # Calculate corners of bounding boxes
        top_left = box_centroids - box_sizes / 2
        bottom_right = box_centroids + box_sizes / 2

        # Concatenate the corners into (x1, y1, x2, y2) boxes
        decoded_boxes = torch.cat((top_left, bottom_right), -1)

        return decoded_boxes


    def sample(self, 
               assignment_result:AssignResult, # The assignment result obtained from assigner.
               bboxes:torch.Tensor, # The candidate boxes to sample from (the output grid boxes when called from get_target_single).
               ground_truth_boxes:torch.Tensor, # The ground truth boxes.
              ) -> SamplingResult: # The sampling result containing positive and negative indices.
        """
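        Samples positive and negative indices based on the assignment result. 
        Positive indices are those where the ground truth box index is greater than zero (the sample is matched to a ground truth object), 
        and negative indices are those where it is zero (the sample is not matched to any ground truth object).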
        """
        positive_indices = torch.nonzero(
            assignment_result.ground_truth_box_indices > 0, as_tuple=False).squeeze(-1).unique()
        negative_indices = torch.nonzero(
            assignment_result.ground_truth_box_indices == 0, as_tuple=False).squeeze(-1).unique()
        ground_truth_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
        sampling_result = SamplingResult(positive_indices, negative_indices, bboxes, ground_truth_boxes,
                                         assignment_result, ground_truth_flags)
        return sampling_result


    def get_l1_target(self, 
                      l1_target:torch.Tensor, # The L1 target tensor.
                      ground_truth_boxes:torch.Tensor, # The ground truth boxes.
                      output_grid_boxes:torch.Tensor, # The output grid boxes.
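                      epsilon:float=1e-8 # A small value to prevent taking the log of zero.
                     ) -> torch.Tensor: # The updated L1 target.
        """
        Encodes the ground truth boxes as L1 regression targets in the same space as the raw box predictions: 
        the offset of each box center from its output grid box divided by the stride, and `log(width / stride)` and `log(height / stride)` for the box size. 
        The L1 loss then measures the absolute difference between these targets and the raw predictions.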
        """
        ground_truth_centroid_and_wh = torchvision.ops.box_convert(ground_truth_boxes, 'xyxy', 'cxcywh')
        l1_target[:, :2] = (ground_truth_centroid_and_wh[:, :2] - output_grid_boxes[:, :2]) / output_grid_boxes[:, 2:]
        l1_target[:, 2:] = torch.log(ground_truth_centroid_and_wh[:, 2:] / output_grid_boxes[:, 2:] + epsilon)
        return l1_target


    def get_target_single(self, 
                          class_preds:torch.Tensor, # The predicted class logits.
                          objectness_score:torch.Tensor, # The predicted objectness logits.
                          output_grid_boxes:torch.Tensor, # The output grid boxes.
                          decoded_bboxes:torch.Tensor, # The decoded bounding boxes.
                          ground_truth_bboxes:torch.Tensor, # The ground truth boxes.
                          ground_truth_labels:torch.Tensor # The ground truth labels.
                         ) -> Tuple: # The targets for classification, objectness, bounding boxes, and L1 (if applicable), along with the foreground mask and the number of positive samples.
        """
        Calculates the targets for a single image. 
        It assigns ground truth objects to output grid boxes and samples output grid boxes based on the assignment results. 
        It then generates class targets, objectness targets, bounding box targets, and, optionally, L1 targets.
        """
        # Get the number of prior boxes and ground truth labels
        num_output_grid_boxes = output_grid_boxes.size(0)
        num_ground_truths = ground_truth_labels.size(0)

        # Match dtype of ground truth bounding boxes to the dtype of decoded bounding boxes
        ground_truth_bboxes = ground_truth_bboxes.to(decoded_bboxes.dtype)

        # Check if there are no ground truth labels (objects) in the image
        if num_ground_truths == 0:
            # Initialize targets as zero tensors, and foreground_mask as a boolean tensor with False values
            class_targets = class_preds.new_zeros((0, self.num_classes))
            bbox_targets = class_preds.new_zeros((0, 4))
            l1_targets = class_preds.new_zeros((0, 4))
            objectness_targets = class_preds.new_zeros((num_output_grid_boxes, 1))
            foreground_mask = class_preds.new_zeros(num_output_grid_boxes).bool()
            return (foreground_mask, class_targets, objectness_targets, bbox_targets,
                    l1_targets, 0)  # Return zero for num_positive_per_image

        # Shift the output grid boxes by half a stride so each one points at the center of its grid cell
        offset_output_grid_boxes = torch.cat([output_grid_boxes[:, :2] + output_grid_boxes[:, 2:] * 0.5, output_grid_boxes[:, 2:]], dim=-1)

        
        # Assign ground truth objects to prior boxes and get assignment results
        assignment_result = self.assigner.assign(
            class_preds.sigmoid() * objectness_score.unsqueeze(1).sigmoid(),
            offset_output_grid_boxes, decoded_bboxes, ground_truth_bboxes, ground_truth_labels)
        
        # Use assignment results to sample prior boxes
        sampling_result = self.sample(assignment_result, output_grid_boxes, ground_truth_bboxes)
        
        # Get the indices of positive (object-containing) samples
        positive_indices = sampling_result.positive_indices
        num_positive_per_image = positive_indices.size(0)

        # Get the maximum IoU values for the positive samples
        positive_ious = assignment_result.max_iou_values[positive_indices]

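        # Generate IoU-aware class targets: one-hot labels scaled by the IoU between each positive prediction and its assigned ground truth box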
        class_targets = F.one_hot(sampling_result.positive_ground_truth_labels, self.num_classes) * positive_ious.unsqueeze(-1)

        # Initialize objectness targets as zeros and set the values at positive_indices to 1
        objectness_targets = torch.zeros_like(objectness_score).unsqueeze(-1)
        objectness_targets[positive_indices] = 1

        # Generate bounding box targets
        bbox_targets = sampling_result.positive_ground_truth_bboxes

        # Initialize L1 targets as zeros
        l1_targets = class_preds.new_zeros((num_positive_per_image, 4))

        # If use_l1 is True, calculate L1 targets
        if self.use_l1:
            l1_targets = self.get_l1_target(l1_targets, bbox_targets, output_grid_boxes[positive_indices])

        # Initialize foreground_mask as zeros and set the values at positive_indices to True
        foreground_mask = torch.zeros_like(objectness_score).to(torch.bool)
        foreground_mask[positive_indices] = 1

        # Return the computed targets, the foreground mask, and the number of positive samples per image
        return (foreground_mask, class_targets, objectness_targets, bbox_targets, l1_targets, num_positive_per_image)
    
    
    def flatten_and_concat(self, 
                           tensors:List[torch.Tensor], # A list of tensors to flatten and concatenate.
                           batch_size:int, # The batch size used to reshape the concatenated tensor.
                           reshape_dims:Optional[int]=None # The size of the last dimension; if None, each tensor is flattened to (batch_size, -1).
                          ) -> torch.Tensor: # The concatenated tensor
        """
        Flatten and concatenate a list of tensors.
        """
        new_shape = (batch_size, -1, reshape_dims) if reshape_dims else (batch_size, -1)
        return torch.cat([t.permute(0, 2, 3, 1).reshape(*new_shape) for t in tensors], dim=1)

    
    def __call__(self, 
                 class_scores:List[torch.Tensor], # A list of class scores for each scale.
                 predicted_bboxes:List[torch.Tensor], # A list of predicted bounding boxes for each scale.
                 objectness_scores:List[torch.Tensor], # A list of objectness scores for each scale.
                 ground_truth_bboxes:List[torch.Tensor], # A list of ground truth bounding boxes for each image.
                 ground_truth_labels:List[torch.Tensor] # A list of ground truth labels for each image.
                ) -> Dict: # A dictionary with the classification, bounding box, objectness, and optionally, L1 loss.
        """
        The `__call__` method computes the loss values. 
        It first generates box coordinates for the output grids based on the input dimensions and stride values. 
        It then flattens and concatenates class predictions, bounding box predictions, and objectness scores. 
        Next, it decodes the bounding box predictions, computes targets for each image in the batch, 
        and finally computes the bounding box loss, objectness loss, and classification loss (and L1 loss, optionally). 
        These losses are scaled by their respective weights and normalized by the total number of positive samples in the batch.
        """
        
        # Get the number of images in the batch
        batch_size = class_scores[0].shape[0]
        
        # Generate box coordinates for all grid priors.
        output_grid_boxes = generate_output_grids(*[s*self.strides[0] for s in class_scores[0].shape[-2:]], self.strides)
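        # Scale the grid coordinates by their stride, then append the stride as box width and height: (x, y, stride, stride)
        output_grid_boxes[:, :2] *= output_grid_boxes[:, 2].unsqueeze(1)
        flatten_output_grid_boxes = torch.cat([output_grid_boxes, output_grid_boxes[:, 2:].clone()], dim=1)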
        
        # Flatten and concatenate class predictions, bounding box predictions, and objectness scores
        flatten_class_preds = self.flatten_and_concat(class_scores, batch_size, self.num_classes)
        flatten_bbox_preds = self.flatten_and_concat(predicted_bboxes, batch_size, 4)
        flatten_objectness_scores = self.flatten_and_concat(objectness_scores, batch_size)
                    
        # Move the grid boxes to the predictions' device and decode the box predictions
        flatten_output_grid_boxes = flatten_output_grid_boxes.to(flatten_bbox_preds.device)
        flatten_decoded_bboxes = self.bbox_decode(flatten_output_grid_boxes, flatten_bbox_preds)

        # Compute targets
        (positive_masks, class_targets, objectness_targets, bbox_targets, l1_targets,
         num_positive_images) = multi_apply(
             self.get_target_single, flatten_class_preds.detach(),
             flatten_objectness_scores.detach(),
             flatten_output_grid_boxes.unsqueeze(0).repeat(batch_size, 1, 1),
             flatten_decoded_bboxes.detach(), ground_truth_bboxes, ground_truth_labels)

        # Concatenate all positive masks, class targets, objectness targets, and bounding box targets
        positive_masks = torch.cat(positive_masks, 0)
        class_targets = torch.cat(class_targets, 0)
        objectness_targets = torch.cat(objectness_targets, 0)
        bbox_targets = torch.cat(bbox_targets, 0)

        # Compute bounding box loss
        loss_bbox = self.bbox_loss_func(flatten_decoded_bboxes.reshape(-1, 4)[positive_masks], bbox_targets)

        # Compute objectness loss
        loss_obj = self.objectness_loss_func(flatten_objectness_scores.reshape(-1, 1), objectness_targets)

        # Compute class loss
        loss_cls = self.class_loss_func(flatten_class_preds.reshape(-1, self.num_classes)[positive_masks], class_targets)
        
        # Normalize by the total number of positive samples in the batch (at least 1 to avoid division by zero)
        num_total_samples = max(sum(num_positive_images), 1)
        
        # Scale losses
        loss_bbox = (loss_bbox * self.bbox_loss_weight) / num_total_samples
        loss_obj = (loss_obj * self.objectness_loss_weight) / num_total_samples
        loss_cls = (loss_cls * self.class_loss_weight) / num_total_samples
        
        # Initialize loss dictionary
        loss_dict = dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)

        # If use_l1 is True, concatenate l1 targets, compute L1 loss and add it to the loss dictionary
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)
            loss_l1 = self.l1_loss_func(
                flatten_bbox_preds.reshape(-1, 4)[positive_masks],
                l1_targets) / num_total_samples
            loss_l1 *= self.l1_loss_weight
            loss_dict.update(loss_l1=loss_l1)

        # Return loss dictionary
        return loss_dict

In [None]:
show_doc(YOLOXLoss.__init__)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L88){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.__init__

>      YOLOXLoss.__init__ (num_classes:int, bbox_loss_weight:float=5.0,
>                          class_loss_weight:float=1.0,
>                          objectness_loss_weight:float=1.0,
>                          l1_loss_weight:float=1.0, use_l1:bool=False,
>                          strides:List[int]=[8, 16, 32])

*The `__init__` method stores the loss weights and strides, creates the SimOTA assigner, 
and initializes the loss functions: 
a squared [Generalized IoU](https://pytorch.org/vision/stable/generated/torchvision.ops.generalized_box_iou_loss.html) loss (`1 - GIoU²`, summed over positive samples) for the bounding box loss, 
[binary cross entropy with logits](https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html) for the classification and objectness losses, 
and [L1 loss](https://pytorch.org/docs/stable/generated/torch.nn.functional.l1_loss.html#torch.nn.functional.l1_loss) for the optional L1 loss.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| num_classes | int |  | The number of target classes. |
| bbox_loss_weight | float | 5.0 | The weight applied to the bounding box regression loss. |
| class_loss_weight | float | 1.0 | The weight applied to the classification loss. |
| objectness_loss_weight | float | 1.0 | The weight applied to the objectness loss. |
| l1_loss_weight | float | 1.0 | The weight applied to the L1 loss. |
| use_l1 | bool | False | Whether to use L1 loss in the calculation. |
| strides | List | [8, 16, 32] | The list of strides. |
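
The bounding box loss built in `__init__` is a squared variant of GIoU loss: `generalized_box_iou_loss` returns `1 - GIoU`, so the lambda computes `1 - GIoU²` per positive pair and sums the results. (The linked mmdetection head defaults to `IoULoss` in `'square'` mode, i.e. `1 - IoU²`.) Because GIoU ranges over [-1, 1], `1 - GIoU²` is symmetric around zero: a half-overlapping pair with GIoU = 1/3 and a disjoint pair with GIoU = -1/3 both receive a loss of 8/9. The pure-Python sketch below makes the formula concrete for axis-aligned `xyxy` boxes; it is an illustration rather than torchvision's code, and `giou` / `squared_giou_loss` are hypothetical names.

```python
from typing import Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

def box_area(box: Box) -> float:
    x1, y1, x2, y2 = box
    # Degenerate (empty) boxes get zero area
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)

def giou(a: Box, b: Box, eps: float = 1e-16) -> float:
    # Intersection of the two boxes (zero area if they do not overlap)
    inter = box_area((max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])))
    union = box_area(a) + box_area(b) - inter
    iou = inter / (union + eps)
    # Smallest box enclosing both boxes
    enclosing = box_area((min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])))
    return iou - (enclosing - union) / (enclosing + eps)

def squared_giou_loss(preds: Sequence[Box], targets: Sequence[Box], eps: float = 1e-16) -> float:
    # Sum of 1 - GIoU^2 over matched (prediction, target) pairs, mirroring bbox_loss_func
    return sum(1.0 - giou(p, t, eps) ** 2 for p, t in zip(preds, targets))

overlapping = ((0, 0, 2, 2), (1, 0, 3, 2))  # GIoU = 1/3
disjoint = ((0, 0, 1, 1), (2, 0, 3, 1))     # GIoU = -1/3
print(squared_giou_loss([overlapping[0]], [overlapping[1]]))
print(squared_giou_loss([disjoint[0]], [disjoint[1]]))
```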

In [None]:
show_doc(YOLOXLoss.bbox_decode)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L127){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.bbox_decode

>      YOLOXLoss.bbox_decode (output_grid_boxes:torch.Tensor,
>                             predicted_boxes:torch.Tensor)

*Decodes the raw bounding box predictions into `(x1, y1, x2, y2)` boxes in input-image pixels. 
The first two predicted values are center offsets that are multiplied by the stride of each output grid box and added to its position, 
and the last two are log-scale width and height, which are exponentiated and multiplied by the stride.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| output_grid_boxes | Tensor | The output grid boxes. |
| predicted_boxes | Tensor | The predicted bounding boxes. |
| **Returns** | **Tensor** | **The decoded bounding boxes.** |
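
For a single grid cell, `bbox_decode` performs the arithmetic below. This scalar sketch assumes each output grid box is laid out as `(x * stride, y * stride, stride, stride)`, as built in `__call__`; `decode_one` is a hypothetical helper, not part of the library.

```python
import math
from typing import Tuple

Box4 = Tuple[float, float, float, float]

def decode_one(grid_box: Box4, pred: Box4) -> Box4:
    gx, gy, stride_w, stride_h = grid_box  # grid cell position in pixels and its stride
    dx, dy, dw, dh = pred                  # raw network outputs for this cell
    # Center: offset in stride units, added to the grid position
    cx = dx * stride_w + gx
    cy = dy * stride_h + gy
    # Size: exponentiated log-scale prediction, in stride units
    w = math.exp(dw) * stride_w
    h = math.exp(dh) * stride_h
    # Convert center/size to corners (x1, y1, x2, y2)
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

print(decode_one((16.0, 8.0, 8.0, 8.0), (0.5, 0.5, 0.0, 0.0)))  # (16.0, 8.0, 24.0, 16.0)
```

A zero size prediction therefore decodes to a box exactly one stride wide, and the exponent keeps decoded widths and heights positive.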

In [None]:
show_doc(YOLOXLoss.sample)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L150){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.sample

>      YOLOXLoss.sample (assignment_result:cjm_yolox_pytorch.simota.AssignResult,
>                        bboxes:torch.Tensor, ground_truth_boxes:torch.Tensor)

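*Samples positive and negative indices based on the assignment result. 
Positive indices are those where the ground truth box index is greater than zero (the sample is matched to a ground truth object), 
and negative indices are those where it is zero (the sample is not matched to any ground truth object).*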

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| assignment_result | AssignResult | The assignment result obtained from assigner. |
| bboxes | Tensor | The candidate boxes to sample from (the output grid boxes when called from `get_target_single`). |
| ground_truth_boxes | Tensor | The ground truth boxes. |
| **Returns** | **SamplingResult** | **The sampling result containing positive and negative indices.** |
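
The split performed by `sample` (`torch.nonzero(...).squeeze(-1).unique()`, which returns sorted, unique indices) can be sketched in plain Python as follows. `split_samples` is a hypothetical name; the `-1` in the example is made-up input showing that any value that is neither positive nor zero lands in neither list (whether `SimOTAAssigner` ever emits such values is not covered here).

```python
from typing import List, Sequence, Tuple

def split_samples(gt_box_indices: Sequence[int]) -> Tuple[List[int], List[int]]:
    # Positives: matched to a ground truth box (1-based index > 0)
    positive = sorted({i for i, g in enumerate(gt_box_indices) if g > 0})
    # Negatives: assigned to background (index == 0)
    negative = sorted({i for i, g in enumerate(gt_box_indices) if g == 0})
    return positive, negative

print(split_samples([0, 2, 0, 1, -1]))  # ([1, 3], [0, 2])
```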

In [None]:
show_doc(YOLOXLoss.get_l1_target)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L168){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.get_l1_target

>      YOLOXLoss.get_l1_target (l1_target:torch.Tensor,
>                               ground_truth_boxes:torch.Tensor,
>                               output_grid_boxes:torch.Tensor,
>                               epsilon:float=1e-08)

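*Encodes the ground truth boxes as L1 regression targets in the same space as the raw box predictions: 
the offset of each box center from its output grid box divided by the stride, and `log(width / stride)` and `log(height / stride)` for the box size. 
The L1 loss then measures the absolute difference between these targets and the raw predictions.*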

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| l1_target | Tensor |  | The L1 target tensor. |
| ground_truth_boxes | Tensor |  | The ground truth boxes. |
| output_grid_boxes | Tensor |  | The output grid boxes. |
| epsilon | float | 1e-08 | A small value to prevent taking the log of zero. |
| **Returns** | **Tensor** |  | **The updated L1 target.** |
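
Because the L1 targets live in the same space as the raw predictions, encoding a ground truth box and then decoding it with the `bbox_decode` arithmetic should recover the box, up to the effect of `epsilon`. The scalar sketch below checks that round trip for one box; `encode_l1_target` and `decode_one` are hypothetical helpers that mirror the tensor code, assuming the `(x * stride, y * stride, stride, stride)` grid box layout from `__call__`.

```python
import math
from typing import Tuple

Box4 = Tuple[float, float, float, float]

def encode_l1_target(gt_box: Box4, grid_box: Box4, eps: float = 1e-8) -> Box4:
    x1, y1, x2, y2 = gt_box
    gx, gy, stride_w, stride_h = grid_box
    # Convert the ground truth corners to center/size (like box_convert 'xyxy' -> 'cxcywh')
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1
    # Center offsets in stride units, log-scale sizes in stride units
    return ((cx - gx) / stride_w, (cy - gy) / stride_h,
            math.log(w / stride_w + eps), math.log(h / stride_h + eps))

def decode_one(grid_box: Box4, pred: Box4) -> Box4:
    # Inverse mapping, matching the arithmetic in bbox_decode
    gx, gy, stride_w, stride_h = grid_box
    cx, cy = pred[0] * stride_w + gx, pred[1] * stride_h + gy
    w, h = math.exp(pred[2]) * stride_w, math.exp(pred[3]) * stride_h
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

target = encode_l1_target((10.0, 20.0, 50.0, 40.0), (8.0, 16.0, 8.0, 8.0))
print(target[:2])                                  # (2.75, 1.75)
print(decode_one((8.0, 16.0, 8.0, 8.0), target))  # close to (10.0, 20.0, 50.0, 40.0)
```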

In [None]:
show_doc(YOLOXLoss.get_target_single)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L184){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.get_target_single

>      YOLOXLoss.get_target_single (class_preds:torch.Tensor,
>                                   objectness_score:torch.Tensor,
>                                   output_grid_boxes:torch.Tensor,
>                                   decoded_bboxes:torch.Tensor,
>                                   ground_truth_bboxes:torch.Tensor,
>                                   ground_truth_labels:torch.Tensor)

*Calculates the targets for a single image. 
It assigns ground truth objects to output grid boxes and samples output grid boxes based on the assignment results. 
It then generates class targets, objectness targets, bounding box targets, and, optionally, L1 targets.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| class_preds | Tensor | The predicted class logits. |
| objectness_score | Tensor | The predicted objectness logits. |
| output_grid_boxes | Tensor | The output grid boxes. |
| decoded_bboxes | Tensor | The decoded bounding boxes. |
| ground_truth_bboxes | Tensor | The ground truth boxes. |
| ground_truth_labels | Tensor | The ground truth labels. |
| **Returns** | **Tuple** | **The targets for classification, objectness, bounding boxes, and L1 (if applicable), along with the foreground mask and the number of positive samples.** |
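
After assignment, `get_target_single` builds IoU-aware class targets (a one-hot row per positive sample, scaled by that sample's IoU with its assigned ground truth box), an objectness target of 1 at every positive grid box, and a matching foreground mask. Before assignment it also shifts each grid box by half a stride. A list-based sketch of both steps, with hypothetical helper names and made-up labels and IoUs:

```python
from typing import List, Sequence, Tuple

Box4 = Tuple[float, float, float, float]

def to_cell_center(grid_box: Box4) -> Box4:
    # Shift (x, y) by half a stride, as done before calling the assigner
    gx, gy, stride_w, stride_h = grid_box
    return (gx + stride_w * 0.5, gy + stride_h * 0.5, stride_w, stride_h)

def build_targets(num_boxes: int, num_classes: int, positive_indices: Sequence[int],
                  positive_labels: Sequence[int], positive_ious: Sequence[float]
                  ) -> Tuple[List[bool], List[List[float]], List[float]]:
    # IoU-aware class targets: a one-hot row per positive sample, scaled by its IoU
    class_targets = []
    for label, iou in zip(positive_labels, positive_ious):
        row = [0.0] * num_classes
        row[label] = iou
        class_targets.append(row)
    # Objectness targets and the foreground mask cover every grid box; only positives are set
    objectness_targets = [0.0] * num_boxes
    foreground_mask = [False] * num_boxes
    for i in positive_indices:
        objectness_targets[i] = 1.0
        foreground_mask[i] = True
    return foreground_mask, class_targets, objectness_targets

mask, cls_t, obj_t = build_targets(5, 3, [1, 3], [2, 0], [0.8, 0.6])
print(mask)   # [False, True, False, True, False]
print(cls_t)  # [[0.0, 0.0, 0.8], [0.6, 0.0, 0.0]]
print(obj_t)  # [0.0, 1.0, 0.0, 1.0, 0.0]
```

Note that the class targets have one row per positive sample, while the objectness targets have one entry per grid box, which is why `__call__` masks the class predictions but not the objectness predictions.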

In [None]:
show_doc(YOLOXLoss.flatten_and_concat)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L259){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.flatten_and_concat

>      YOLOXLoss.flatten_and_concat (tensors:List[torch.Tensor], batch_size:int,
>                                    reshape_dims:Optional[int]=None)

*Flatten and concatenate a list of tensors.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| tensors | List |  | A list of tensors to flatten and concatenate. |
| batch_size | int |  | The batch size used to reshape the concatenated tensor. |
| reshape_dims | Optional | None | The size of the last dimension; if None, each tensor is flattened to (batch_size, -1). |
| **Returns** | **Tensor** |  | **The concatenated tensor** |
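
`flatten_and_concat` turns each `(B, C, H, W)` feature map into `(B, H * W, C)` by moving channels last and walking the grid row by row, then concatenates the levels along the cell dimension. For decoding to make sense, this cell order has to line up with the order of the output grid boxes. The nested-list sketch below mirrors the `permute(0, 2, 3, 1)` + `reshape` ordering; the function names are hypothetical, and the single-channel objectness case (`reshape_dims=None`) corresponds to dropping the trailing length-1 lists.

```python
from typing import List

def flatten_level(feature: List, batch_size: int) -> List:
    # feature is a nested list shaped [B][C][H][W]; output is [B][H * W][C]
    out = []
    for b in range(batch_size):
        channels, height, width = len(feature[b]), len(feature[b][0]), len(feature[b][0][0])
        # Walk rows then columns (row-major), gathering all channels of each cell,
        # which mirrors permute(0, 2, 3, 1) followed by reshape
        out.append([[feature[b][c][y][x] for c in range(channels)]
                    for y in range(height) for x in range(width)])
    return out

def flatten_and_concat(levels: List, batch_size: int) -> List:
    # Concatenate the flattened levels along the cell dimension (dim=1)
    flat = [flatten_level(level, batch_size) for level in levels]
    return [[cell for level in flat for cell in level[b]] for b in range(batch_size)]

level_a = [[[[0, 1], [2, 3]], [[10, 11], [12, 13]]]]  # B=1, C=2, H=2, W=2
level_b = [[[[100]], [[200]]]]                        # B=1, C=2, H=1, W=1
print(flatten_and_concat([level_a, level_b], 1))
# [[[0, 10], [1, 11], [2, 12], [3, 13], [100, 200]]]
```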

In [None]:
show_doc(YOLOXLoss.__call__)

---

[source](https://github.com/cj-mills/cjm-yolox-pytorch/blob/main/cjm_yolox_pytorch/loss.py#L271){target="_blank" style="float:right; font-size:smaller"}

### YOLOXLoss.__call__

>      YOLOXLoss.__call__ (class_scores:List[torch.Tensor],
>                          predicted_bboxes:List[torch.Tensor],
>                          objectness_scores:List[torch.Tensor],
>                          ground_truth_bboxes:List[torch.Tensor],
>                          ground_truth_labels:List[torch.Tensor])

*The `__call__` method computes the loss values. 
It first generates box coordinates for the output grids based on the input dimensions and stride values. 
It then flattens and concatenates class predictions, bounding box predictions, and objectness scores. 
Next, it decodes the bounding box predictions, computes targets for each image in the batch, 
and finally computes the bounding box loss, objectness loss, and classification loss (and L1 loss, optionally). 
These losses are scaled by their respective weights and normalized by the total number of positive samples in the batch.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| class_scores | List | A list of class scores for each scale. |
| predicted_bboxes | List | A list of predicted bounding boxes for each scale. |
| objectness_scores | List | A list of objectness scores for each scale. |
| ground_truth_bboxes | List | A list of ground truth bounding boxes for each image. |
| ground_truth_labels | List | A list of ground truth labels for each image. |
| **Returns** | **Dict** | **A dictionary with the classification, bounding box, objectness, and optionally, L1 loss.** |
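
Two pieces of bookkeeping in `__call__` are easy to check by hand: how many output grid boxes there are, and how the sum-reduced losses are weighted and normalized. The sketch below assumes the input height and width are divisible by every stride; `num_grid_boxes` and `normalize_losses` are hypothetical helpers, and the loss sums are made-up numbers used with the default weights from `__init__`.

```python
from typing import Dict, Sequence

def num_grid_boxes(height: int, width: int, strides: Sequence[int] = (8, 16, 32)) -> int:
    # One grid box per cell on every output level
    return sum((height // s) * (width // s) for s in strides)

def normalize_losses(loss_sums: Dict[str, float], weights: Dict[str, float],
                     num_positives_per_image: Sequence[int]) -> Dict[str, float]:
    # Sum-reduced losses are weighted, then divided by the batch's positive count (at least 1)
    num_total_samples = max(sum(num_positives_per_image), 1)
    return {name: value * weights[name] / num_total_samples for name, value in loss_sums.items()}

print(num_grid_boxes(640, 640))  # 8400
print(normalize_losses({"loss_bbox": 10.0, "loss_obj": 4.0},
                       {"loss_bbox": 5.0, "loss_obj": 1.0}, [3, 1]))
# {'loss_bbox': 12.5, 'loss_obj': 1.0}
```

Normalizing by the positive count rather than by the number of grid boxes keeps the loss scale roughly independent of input resolution, and the `max(..., 1)` guard avoids dividing by zero for batches without any objects.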

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()