# inference

> Inference wrapper that adds input preprocessing and output postprocessing around a YOLOX model.

In [None]:
#| default_exp inference

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
from typing import Any, Type, List, Optional, Callable, Tuple
from functools import partial

from pathlib import Path

In [None]:
#| export
import torch
import torch.nn as nn

import torch.nn.init as init

In [None]:
#| export
from cjm_yolox_pytorch.model import build_model, NORM_STATS
from cjm_yolox_pytorch.utils import generate_output_grids

In [None]:
#| export
class YOLOXInferenceWrapper(nn.Module):
    """
    This is a wrapper for the YOLOX <https://arxiv.org/abs/2107.08430> object detection model.
    The class handles preprocessing of the input, postprocessing of the model output, and calculation of bounding boxes and their probabilities.
    """

    def __init__(self, 
                 model:nn.Module, # The YOLOX model.
                 normalize_mean:torch.Tensor=torch.tensor([[[0.]]]*3)[None], # The mean values for normalization.
                 normalize_std:torch.Tensor=torch.tensor([[[1.]]]*3)[None], # The standard deviation values for normalization.
                 strides:Optional[List[int]]=None, # The strides for the model. Uses `[8, 16, 32]` when `None`.
                 scale_inp:bool=False, # Whether to scale the input by dividing by 255.
                 channels_last:bool=False, # Whether the input tensor has channels last (`[batch, height, width, channels]`).
                 run_box_and_prob_calculation:bool=True # Whether to calculate the bounding boxes and their probabilities.
                ):
        """
        Constructor for the YOLOXInferenceWrapper class.
        """
        super().__init__()
        self.model = model

        # `None` is allowed by the type hint, so fall back to the standard strides
        # instead of letting `torch.tensor(None)` raise.
        if strides is None:
            strides = [8, 16, 32]

        # Register copies of the normalization tensors. The signature defaults are
        # created once at definition time; registering them directly would make every
        # instance share the same storage, so an in-place update of one wrapper's
        # buffer would leak into all other wrappers.
        self.register_buffer("normalize_mean", normalize_mean.clone())
        self.register_buffer("normalize_std", normalize_std.clone())
        self.register_buffer("strides", torch.tensor(strides))

        self.scale_inp = scale_inp
        self.channels_last = channels_last
        self.run_box_and_prob_calculation = run_box_and_prob_calculation

        # Position of the two spatial dimensions inside the *raw* input shape:
        # dims 1-2 for channels-last input, dims 2-3 for channels-first input.
        self.input_dim_slice = slice(1, 3) if self.channels_last else slice(2, 4)

    def preprocess_input(self, x:torch.Tensor) -> torch.Tensor:
        """
        Preprocess the input for the model.

        Parameters:
        x (torch.Tensor): The input tensor.

        Returns:
        torch.Tensor: The preprocessed input tensor.
        """
        # Scale the input if required
        if self.scale_inp:
            x = x / 255.0

        # Permute the dimensions of the input to bring the channels to the front if required
        if self.channels_last:
            x = x.permute(0, 3, 1, 2)

        # Normalize the input (the buffers broadcast against the channels-first tensor)
        x = (x - self.normalize_mean) / self.normalize_std
        return x

    def process_output(self, model_output) -> torch.Tensor:
        """
        Postprocess the output of the model.

        Parameters:
        model_output (tuple): The output of the model, unpacked as
            `(cls_scores, bbox_preds, objectness)`; each element is indexed once per stride.

        Returns:
        torch.Tensor: The postprocessed output tensor. For `[batch, channels, height, width]`
            head outputs, the result is `[batch, positions, values]`, where the values are
            ordered as box predictions, objectness, class scores.
        """
        cls_scores, bbox_preds, objectness = model_output

        stride_flats = []
        # Iterate over the output strides
        for i in range(self.strides.shape[0]):
            cls = torch.sigmoid(cls_scores[i])  # Apply sigmoid to the class scores
            bbox = bbox_preds[i]  # Box predictions are kept raw; they are decoded later
            obj = torch.sigmoid(objectness[i])  # Apply sigmoid to the objectness scores
            cat = torch.cat((bbox, obj, cls), dim=1)  # Concatenate the bounding boxes, objectness, and class scores
            flat = torch.flatten(cat, start_dim=2)  # Collapse every dimension from index 2 onward into one
            stride_flats.append(flat)

        # Concatenate all the flattened tensors along the flattened (position) dimension
        full_cat = torch.cat(stride_flats, dim=2)
        # Swap the last two dimensions so each position holds its own vector of values
        full_cat_out = full_cat.permute(0, 2, 1)
        return full_cat_out

    def calculate_boxes_and_probs(self, model_output:torch.Tensor, output_grids:torch.Tensor) -> torch.Tensor:
        """
        Calculate the bounding boxes and their probabilities.

        Parameters:
        model_output (torch.Tensor): The postprocessed model output, with the last dimension
            ordered as 4 box values, objectness, class scores.
        output_grids (torch.Tensor): The output grids. The first two values of the last dimension
            are added to the predicted offsets and the remaining value(s) are used as the scale
            (presumably grid x, grid y, and stride as produced by `generate_output_grids` - confirm).

        Returns:
        torch.Tensor: The tensor containing, along the last dimension, the top-left corner
            `(x0, y0)`, the size `(w, h)`, the class label (as float), and the maximum probability.
        """
        # Calculate the bounding box centers and sizes in scaled coordinates
        box_centroids = (model_output[..., :2] + output_grids[..., :2]) * output_grids[..., 2:]
        box_sizes = torch.exp(model_output[..., 2:4]) * output_grids[..., 2:]

        # Convert centers to top-left corners, then split into per-coordinate tensors
        x0, y0 = [t.squeeze(dim=2) for t in torch.split(box_centroids - box_sizes / 2, 1, dim=2)]
        w, h = [t.squeeze(dim=2) for t in torch.split(box_sizes, 1, dim=2)]

        # Calculate the probabilities for each class (objectness * class score)
        box_objectness = model_output[..., 4]
        box_cls_scores = model_output[..., 5:]
        box_probs = box_objectness.unsqueeze(-1) * box_cls_scores

        # Get the maximum probability and corresponding class for each proposal
        max_probs, labels = torch.max(box_probs, dim=-1)

        # Labels are cast to float so everything can live in a single tensor
        return torch.stack([x0, y0, w, h, labels.float(), max_probs], dim=-1)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """
        The forward method for the YOLOXInferenceWrapper class.

        Parameters:
        x (torch.Tensor): The input tensor.

        Returns:
        torch.Tensor: The output tensor. This is the result of `calculate_boxes_and_probs` when
            `run_box_and_prob_calculation` is enabled, otherwise the result of `process_output`.
        """
        # Read the spatial dimensions from the raw input, before any permutation
        input_dims = x.shape[self.input_dim_slice]

        # Preprocess the input
        x = self.preprocess_input(x)
        # Pass the input through the model
        x = self.model(x)
        # Postprocess the model output
        x = self.process_output(x)

        if self.run_box_and_prob_calculation:
            # Generate output grids
            output_grids = generate_output_grids(*input_dims, self.strides).to(x.device)
            # Calculate the bounding boxes and their probabilities
            x = self.calculate_boxes_and_probs(x, output_grids)

        return x

In [None]:
# Name of the YOLOX variant to build; reused later as the key into NORM_STATS
model_type = 'yolox_tiny'

# Build the model. The second argument (19) matches the channel count of the
# cls_scores shapes printed below, so it is presumably the number of classes - confirm.
model = build_model(model_type, 19, pretrained=True)

# Random test input of shape (1, 3, 256, 256); presumably (batch, channels, height, width)
test_inp = torch.randn(1, 3, 256, 256)

# Run the raw model without gradient tracking; its output unpacks into three collections
with torch.no_grad():
    cls_scores, bbox_preds, objectness = model(test_inp)
    
# Print the shape of every tensor in each of the three outputs.
# NOTE: the last comprehension reuses the name `objectness` as its loop variable; the
# comprehension has its own scope, so the outer `objectness` is left untouched.
print(f"cls_scores: {[cls_score.shape for cls_score in cls_scores]}")
print(f"bbox_preds: {[bbox_pred.shape for bbox_pred in bbox_preds]}")
print(f"objectness: {[objectness.shape for objectness in objectness]}")

The file ./pretrained_checkpoints/yolox_tiny.pth already exists and overwrite is set to False.
cls_scores: [torch.Size([1, 19, 32, 32]), torch.Size([1, 19, 16, 16]), torch.Size([1, 19, 8, 8])]
bbox_preds: [torch.Size([1, 4, 32, 32]), torch.Size([1, 4, 16, 16]), torch.Size([1, 4, 8, 8])]
objectness: [torch.Size([1, 1, 32, 32]), torch.Size([1, 1, 16, 16]), torch.Size([1, 1, 8, 8])]


  state_dict = torch.load(checkpoint_path, map_location='cpu')


In [None]:
# Collect the normalization statistics registered for this model type
norm_stats = list(NORM_STATS[model_type].values())

# Turn the first two entries (used as mean and std) into (1, 3, 1, 1) tensors
mean_tensor = torch.tensor(norm_stats[0]).view(1, 3, 1, 1)
std_tensor = torch.tensor(norm_stats[1]).view(1, 3, 1, 1)

# Switch the model to evaluation mode (the trailing `;` hides the notebook cell output)
model.eval();

# Add the preprocessing and post-processing steps around the model
wrapped_model = YOLOXInferenceWrapper(
    model=model,
    normalize_mean=mean_tensor,
    normalize_std=std_tensor,
    scale_inp=False,
    channels_last=False,
)

# Run the wrapped model without gradient tracking and display the output shape
with torch.no_grad():
    model_output = wrapped_model(test_inp)
model_output.shape

torch.Size([1, 1344, 6])

In [None]:
#| hide
# Export the notebook cells marked with `#| export` to the library module
import nbdev
nbdev.nbdev_export()