## Tutorial 05: infer stages

Running an image through the model is probably the easiest part of this whole process, so there isn't a ton of code for it in `electricmayhem`.

The main class is `em.ModelWrapper`, which wraps a PyTorch model so it can be used as a pipeline stage. It can also handle separate training and evaluation models, as well as dictionaries of training and/or evaluation models (see the next notebook for an example!).
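To make the train/eval idea concrete, here is a minimal sketch. It is not `em.ModelWrapper`'s actual implementation. The class name `TrainEvalWrapper` and the `evaluate` flag are hypothetical, and the sketch assumes you want the model weights frozen so that gradients only reach the input image (and therefore the patch).

```python
import torch


class TrainEvalWrapper(torch.nn.Module):
    """Hypothetical sketch: route inputs to a training model or an evaluation model."""

    def __init__(self, train_model, eval_model=None):
        super().__init__()
        self.train_model = train_model
        # fall back to the training model if no separate eval model is given
        self.eval_model = eval_model if eval_model is not None else train_model
        # assumption: freeze all weights (this modifies the models passed in)
        # so only the input receives gradients
        for p in self.parameters():
            p.requires_grad_(False)

    def forward(self, x, evaluate=False):
        model = self.eval_model if evaluate else self.train_model
        return model(x)
```

Freezing the weights does not block gradients with respect to the input, which is what a patch optimizer needs.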

### YOLO object detectors

The specific use case I've worked with most often has been YOLO models, so I've implemented a subclass of the model wrapper, `em.YOLOWrapper`. It adds two things:

* Some functions to convert different versions of the YOLO codebase to the same output format
* Diagnostic images, logged to TensorBoard, that show detections on the same image with and without the patch.
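A diagnostic image like this boils down to drawing boxes on a clean and a patched copy of the same image and placing them side by side. The sketch below does that in plain PyTorch; `draw_boxes` and `side_by_side` are hypothetical helpers, not part of `electricmayhem`, and boxes are assumed to be `(x1, y1, x2, y2)` pixel coordinates on a `[3, H, W]` image with values in [0, 1]. The result could then be logged with `SummaryWriter.add_image` from `torch.utils.tensorboard`, which expects CHW images by default.

```python
import torch


def draw_boxes(img, boxes, color=(1.0, 0.0, 0.0), width=2):
    """Hypothetical sketch: draw box outlines on a copy of a [3, H, W] image.

    boxes is a [K, 4] tensor of (x1, y1, x2, y2) pixel coordinates.
    """
    out = img.clone()
    _, h, w = out.shape
    for x1, y1, x2, y2 in boxes.round().long().tolist():
        # clamp to the image so partially off-screen boxes stay in range
        x1, x2 = max(x1, 0), min(x2, w - 1)
        y1, y2 = max(y1, 0), min(y2, h - 1)
        for ch, value in enumerate(color):
            out[ch, y1:y1 + width, x1:x2 + 1] = value                      # top edge
            out[ch, max(y2 - width + 1, 0):y2 + 1, x1:x2 + 1] = value      # bottom edge
            out[ch, y1:y2 + 1, x1:x1 + width] = value                      # left edge
            out[ch, y1:y2 + 1, max(x2 - width + 1, 0):x2 + 1] = value      # right edge
    return out


def side_by_side(clean, patched, clean_boxes, patched_boxes):
    """Concatenate annotated clean and patched images along the width axis."""
    return torch.cat(
        [draw_boxes(clean, clean_boxes), draw_boxes(patched, patched_boxes)], dim=2
    )
```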

#### YOLOv4

At this point in history, the official YOLOv4 implementation still relies on Darknet. There are unofficial PyTorch conversions out there; they're functional but finicky.

The YOLOv4 codebase at https://github.com/Tianxiaomo/pytorch-YOLOv4 outputs an unusual format, a list of two tensors:

* The first has shape `[batch_size, num_boxes, 1, 4]` and contains the bounding boxes in normalized coordinates.
* The second has shape `[batch_size, num_boxes, num_classes]` and contains the class detection probabilities.

To use with `electricmayhem`:

```
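import electricmayhem.whitebox as em
# run from the root of the pytorch-YOLOv4 repo so `tool` is importable
from tool.darknet2pytorch import Darknet

model = Darknet("path_to_config_file", inference=True).eval()
model.load_weights("path_to_weight_file")

yolo = em.YOLOWrapper(model, yolo_version=4)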
```

The stage outputs will be in the YOLOv5 format.
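For intuition, here is a minimal sketch of what converting the pytorch-YOLOv4 outputs into the YOLOv5 layout (`[batch_size, num_boxes, 5+num_classes]` with xywh pixel boxes, an objectness score, and class scores) involves. It is not `em.YOLOWrapper`'s actual code: the function name is hypothetical, and it assumes the boxes are corner coordinates `(x1, y1, x2, y2)` normalized to [0, 1] and uses the best class probability as a stand-in for YOLOv5's objectness score.

```python
import torch


def yolov4_to_yolov5(boxes, confs, img_h, img_w):
    """Hypothetical sketch: reformat pytorch-YOLOv4 outputs as a YOLOv5-style tensor.

    boxes: [B, N, 1, 4] (x1, y1, x2, y2), assumed normalized to [0, 1]
    confs: [B, N, C] per-class probabilities
    """
    x1, y1, x2, y2 = boxes.squeeze(2).unbind(-1)  # each [B, N]
    # corners -> center/size, scaled to pixels
    xywh = torch.stack(
        [(x1 + x2) / 2 * img_w, (y1 + y2) / 2 * img_h,
         (x2 - x1) * img_w, (y2 - y1) * img_h],
        dim=-1,
    )  # [B, N, 4]
    # assumption: best class probability as the objectness score
    obj = confs.max(dim=-1, keepdim=True).values  # [B, N, 1]
    return torch.cat([xywh, obj, confs], dim=-1)  # [B, N, 5 + C]
```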

#### YOLOv5

This was the first "official" YOLO codebase written in PyTorch. The output is a list with two elements:

* The first is a tensor of shape `[batch_size, num_boxes, 5+num_classes]`; the last dimension holds the box in xywh pixel format (center x, center y, width, height), an overall detection (objectness) score, and the per-class detection scores.
* The second is a list of 3 tensors giving unaggregated results from the different detection heads.
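Working with this format usually means slicing that last dimension apart. The helper below is hypothetical (not part of `electricmayhem`); it splits the tensor and ranks each box by objectness times class score, which is how YOLOv5's non-max suppression computes box confidence.

```python
import torch


def split_yolov5_output(preds):
    """Hypothetical helper: split a [B, N, 5 + C] YOLOv5-style tensor into its parts."""
    boxes = preds[..., :4]         # (x_center, y_center, width, height) in pixels
    objectness = preds[..., 4]     # overall detection score, [B, N]
    class_scores = preds[..., 5:]  # per-class scores, [B, N, C]
    # per-box confidence = objectness * class score, maximized over classes
    conf, cls = (class_scores * objectness.unsqueeze(-1)).max(dim=-1)
    return boxes, conf, cls
```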

Be careful when loading the model: by default, in inference mode it performs some in-place tensor updates that break gradient computation.
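To see why in-place updates are a problem, the toy example below (not YOLOv5 code) overwrites part of a sigmoid output with a slice assignment, the kind of update a detection head might do. Autograd saved that output for the backward pass, so the in-place version raises a `RuntimeError` while the out-of-place version works. The function name `backward_succeeds` is purely illustrative.

```python
import torch


def backward_succeeds(inplace: bool) -> bool:
    x = torch.randn(2, 3, requires_grad=True)
    y = x.sigmoid()  # sigmoid's backward pass reuses its output y
    if inplace:
        # slice assignment overwrites the saved output in place
        y[..., :2] = y[..., :2] * 2
    else:
        # build a new tensor instead, leaving the saved output untouched
        y = torch.cat([y[..., :2] * 2, y[..., 2:]], dim=-1)
    try:
        y.sum().backward()
    except RuntimeError:
        # "one of the variables needed for gradient computation has been
        # modified by an inplace operation"
        return False
    return x.grad is not None
```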

To use with `electricmayhem`:

```
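import torch
import electricmayhem.whitebox as em
# run from the root of the yolov5 repo so `models` is importable
from models.yolo import Model

# load weights checkpoint to CPU; the checkpoint stores a pickled model object,
# so weights_only=False is needed on PyTorch >= 2.6 (where the default became True)
ckpt = torch.load("path_to_model_weights", map_location="cpu", weights_only=False)
# turn off in-place updates
config = ckpt["model"].yaml
config["inplace"] = False
# number of classes the checkpoint was trained on
num_classes = config["nc"]
# initialize model
model = Model(config, ch=3, nc=num_classes)
# load weights into model
csd = ckpt["model"].float().state_dict()
model.load_state_dict(csd, strict=False)
# swap to eval mode to freeze batchnorm
model.eval()

yolo = em.YOLOWrapper(model, yolo_version=5)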
```

The stage outputs will be in the YOLOv5 format.

#### Later versions

The `ultralytics` library is much more convenient to work with than the earlier codebases and is conda-installable. It does have a **lot** of unnecessary automation, though, which can get in the way if you're doing something weird with it. So far I've had good luck pulling out the underlying PyTorch module (the `.model` attribute of the `YOLO` object) and wrapping that directly.

```
import ultralytics
import electricmayhem.whitebox as em

# grab the underlying torch.nn.Module and put it in eval mode
yolov8n = ultralytics.YOLO("yolov8n.pt").model.eval()
yolo = em.YOLOWrapper(yolov8n, yolo_version=8)
```

This stage will return results in the YOLOv5 format. Note that YOLOv10 specifically has a different output format not yet implemented in `electricmayhem`.
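The raw `ultralytics` output differs from YOLOv5: as far as I know, for detection models the first output is a `[batch_size, 4+num_classes, num_boxes]` tensor with xywh pixel boxes and no separate objectness score. A hypothetical sketch of reshaping that into the YOLOv5 layout is below; it is not what `em.YOLOWrapper` necessarily does, and using the best class score as objectness is an assumption.

```python
import torch


def ultralytics_to_yolov5(raw):
    """Hypothetical sketch: reformat a [B, 4 + C, N] tensor as [B, N, 5 + C]."""
    preds = raw.transpose(1, 2)    # -> [B, N, 4 + C]
    boxes = preds[..., :4]         # xywh, assumed already in pixels
    class_scores = preds[..., 4:]  # [B, N, C]
    # no objectness score in this format; use the best class score (assumption)
    obj = class_scores.max(dim=-1, keepdim=True).values
    return torch.cat([boxes, obj, class_scores], dim=-1)
```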

#### `ultralytics` models exported to ONNX files

If you have to start from a model that was exported to ONNX, you can load it back into memory directly with the `ultralytics` library, but I haven't figured out how to get a differentiable function out of the resulting object.

I've had mixed success using a third-party library (such as `onnx2torch` or `onnx2pytorch`) to convert the ONNX file into a PyTorch module. Sometimes the outputs are in a slightly different format, so you may need to wrap the model to reformat them (or subclass `em.YOLOWrapper`). I've avoided adding this capability to `electricmayhem` because what works seems to depend on the specific versions of `ultralytics`, `onnx`, the conversion library, and the YOLO model.

For example:

```
import torch
import ultralytics
import onnx2torch
import electricmayhem.whitebox as em

# export the model to ONNX; export() returns the path of the file it wrote
model = ultralytics.YOLO("yolo11n.pt")
filepath = model.export(format="onnx")  # creates 'yolo11n.onnx'

# convert the ONNX file back into a PyTorch module
onnx_model = onnx2torch.convert(filepath)
onnx_model.eval()

# output is a single [batch_size, 4+num_classes, num_boxes] tensor, so wrap it to return
# that tensor as the zeroth element of a list so it'll be consistent with other YOLO models
class ONNXWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return [self.model(x)]

wrapped = ONNXWrapper(onnx_model)
yolo = em.YOLOWrapper(wrapped, yolo_version=11)
```
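Since the point of the conversion is a differentiable function, it's worth checking that gradients actually reach the input before plugging the model into a pipeline. The helper below is hypothetical (not part of `electricmayhem`); it unwraps list or tuple outputs, like the one produced by `ONNXWrapper` above, and reports whether a nonzero gradient flows back to a random input. The default input shape of `(1, 3, 640, 640)` is an assumption; use whatever size the model was exported with.

```python
import torch


def check_input_gradients(model, input_shape=(1, 3, 640, 640)):
    """Hypothetical check: True if a nonzero gradient flows from output back to input."""
    x = torch.rand(*input_shape, requires_grad=True)
    out = model(x)
    # unwrap list/tuple outputs down to the first tensor
    while isinstance(out, (list, tuple)):
        out = out[0]
    # an output with no autograd history means the graph was broken somewhere
    if not out.requires_grad:
        return False
    out.sum().backward()
    return x.grad is not None and bool(x.grad.abs().sum() > 0)
```

For example, `check_input_gradients(wrapped)` would run one forward and backward pass through the converted model.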