# In [None]:
import torch.nn as nn
from ultralytics.nn.tasks import DetectionModel
from ultralytics.models.yolo.detect import DetectionTrainer

# --- Custom Model with Cross-Entropy Loss ---

class MyCustomModel(DetectionModel):
    """DetectionModel whose training loss adds a weighted cross-entropy class term.

    The stock criterion from ``DetectionModel.init_criterion`` is kept unchanged;
    a cross-entropy term on the class logits is added to its total loss.
    """

    def init_criterion(self, hyp=None, model=None):
        """
        Install the combined stock + cross-entropy criterion on this model.

        Args:
            hyp (dict | namespace | None): Hyperparameters; only ``ce_weight``
                (default 0.1) is read. May be omitted, which is how
                ``DetectionModel.loss()`` calls this method.
            model (nn.Module | None): Ignored; accepted so existing
                ``init_criterion(hyp, model)`` calls keep working. The loss is
                always built for ``self``.

        Returns:
            Callable: the new ``self.criterion``. It is also stored on the model,
            so both an explicit call and the lazy
            ``self.criterion = self.init_criterion()`` pattern work.
        """
        hyp = {} if hyp is None else hyp
        if isinstance(hyp, dict):
            ce_weight = hyp.get("ce_weight", 0.1)
        else:
            ce_weight = getattr(hyp, "ce_weight", 0.1)
        self.ce_weight = float(ce_weight)
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        # Built on first use (see _combined_loss): the stock YOLOv8 loss reads
        # model.args and the parameters' device when constructed, and a trainer
        # may only provide those after the model has been created.
        self.base_criterion = None
        # A bound method rather than a closure, so the model stays deep-copyable
        # and picklable (EMA copies, checkpoints). It also must not refer to
        # self.criterion, which would recurse into itself.
        self.criterion = self._combined_loss
        return self.criterion

    def _combined_loss(self, preds, batch):
        """Return ``(loss, loss_items)``: the stock loss plus the weighted CE term."""
        if self.base_criterion is None:
            self.base_criterion = super().init_criterion()
        loss, loss_items = self.base_criterion(preds, batch)
        # In eval mode the head returns (decoded, raw_feats); the loss needs raw_feats.
        feats = preds[1] if isinstance(preds, tuple) else preds
        # NOTE(review): scaled by batch size to match the stock criterion, which
        # multiplies its batch-mean losses by batch size — confirm for your version.
        extra = self.ce_weight * self._center_cell_ce(feats, batch) * feats[0].shape[0]
        # Depending on the ultralytics version, `loss` is either a scalar or a
        # per-component vector that the trainer sums. Dividing by numel() adds
        # `extra` exactly once to the summed total in both cases.
        # loss_items (used for logging) are left unchanged.
        return loss + extra / loss.numel(), loss_items

    def _center_cell_ce(self, feats, batch):
        """Cross-entropy between each GT box's class and the logits at its center cell.

        Uses the first (finest) feature map. Assumes a YOLOv8-style head whose
        training output is a list of (B, 4*reg_max + nc, H, W) maps with the class
        channels last, and that ``batch["bboxes"]`` holds normalized xywh boxes
        (ultralytics detection-dataset default) — TODO confirm for custom setups.
        """
        nc = self.model[-1].nc
        feat = feats[0]
        gt_cls = batch["cls"].view(-1).long().to(feat.device)
        if gt_cls.numel() == 0:
            # No boxes in this batch: return a zero that stays attached to the graph.
            return feat.sum() * 0.0
        # Batch label tensors may still be on the CPU; move them to the logits' device.
        img_idx = batch["batch_idx"].view(-1).long().to(feat.device)
        centers = batch["bboxes"][:, :2].to(feat.device)
        h, w = feat.shape[-2:]
        gx = (centers[:, 0] * w).long().clamp(0, w - 1)
        gy = (centers[:, 1] * h).long().clamp(0, h - 1)
        logits = feat[:, -nc:].permute(0, 2, 3, 1)[img_idx, gy, gx]  # (num_boxes, nc)
        return self.cross_entropy_loss(logits.float(), gt_cls)

    def forward(self, x, augment=False, profile=False, visualize=False, **kwargs):
        """Dispatch a dict batch to the loss and an image tensor to prediction.

        Options are forwarded by keyword: ``predict`` declares them in the order
        (profile, visualize, augment), so positional forwarding would scramble
        them. A dict batch must reach ``loss(batch, preds=None)`` without them.
        """
        if isinstance(x, dict):
            return super().forward(x, **kwargs)
        return super().forward(x, augment=augment, profile=profile, visualize=visualize, **kwargs)

# --- Training ---

# Trainer that trains MyCustomModel (the model carrying the custom loss)
class CustomTrainer(DetectionTrainer):
    """DetectionTrainer that trains a MyCustomModel instead of a stock DetectionModel."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Build the custom model for this trainer's dataset.

        Args:
            cfg: Model configuration supplied by the trainer (YAML path or dict).
            weights: Pretrained weights to transfer (e.g. from ``model='yolov8n.pt'``), or None.
            verbose (bool): Whether the model prints its summary when built.

        Returns:
            MyCustomModel: the new model, with ``weights`` loaded when given.
        """
        # Class count comes from the dataset config instead of a hard-coded value.
        model = MyCustomModel(cfg, nc=self.data["nc"], verbose=verbose)
        if weights:
            model.load(weights)
        # The loss criterion is deliberately not built here. At this point the
        # trainer has not yet attached model.args or moved the model to its
        # training device, and the stock YOLOv8 loss reads both when constructed.
        # DetectionModel.loss() initializes the criterion on first use.
        return model
    
# Training overrides handed to the trainer.
args = {
    "model": "yolov8n.pt",  # model spec (a .pt checkpoint)
    "data": "euroshop.yaml",  # dataset configuration file
    "epochs": 3,
}
trainer = CustomTrainer(overrides=args)
trainer.train()

