
Do you support batch inference? #282

Closed
zhaoxuyan opened this issue Nov 9, 2019 · 42 comments

Comments

@zhaoxuyan

🚀 Feature

How to easily take advantage of batch processing during inference?

Motivation

facebookresearch/Detectron#84

@ppwwyyxx
Contributor

ppwwyyxx commented Nov 9, 2019

You just need to call model with a batch of inputs (https://detectron2.readthedocs.io/tutorials/models.html#model-input-format).

@ppwwyyxx ppwwyyxx added the usage label Nov 9, 2019
@zhaoxuyan
Author

zhaoxuyan commented Nov 9, 2019

You just need to call model with a batch of inputs (https://detectron2.readthedocs.io/tutorials/models.html#model-input-format).

@ppwwyyxx Thank you very much for your reply~
Can you tell me how to set up a batch of images in /demo/demo.py?

img = read_image(path, format="BGR")
demo.run_on_image(img)

It can only infer one image at a time; how do I infer a batch of images?
I have a lot of GPU memory left over when inferring one image at a time, and I want to get the results for multiple video streams (a batch of images) with a single inference call.

@ppwwyyxx
Contributor

ppwwyyxx commented Nov 9, 2019

demo.py does not support batch inference.

@zhaoxuyan
Author

demo.py does not support batch inference.

So how should I implement it please?

@ppwwyyxx
Contributor

ppwwyyxx commented Nov 9, 2019

https://detectron2.readthedocs.io/tutorials/models.html#use-models explains how to use a model.

@keko950

keko950 commented Nov 11, 2019

You just need to call model with a batch of inputs (https://detectron2.readthedocs.io/tutorials/models.html#model-input-format).

That link doesn't refer to batch inference in any way. A more detailed answer would be nice, since the API docs don't go any deeper.

@ppwwyyxx
Contributor

The link explains that the input to the model is a list[{"image": ...}, ...]. As long as the length of the list > 1, you're doing batch inference.
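Concretely, a minimal sketch of what calling the model with a batch looks like (the config and image paths here are placeholders; weight loading is shown in working examples further down the thread):

import cv2
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("path/to/config.yaml")        # placeholder config
cfg.MODEL.WEIGHTS = "path/to/model_weights.pkl"   # placeholder weights

model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
model.eval()                                      # inference mode

# One dict per image; tensors are (C, H, W) in the channel order given by cfg.INPUT.FORMAT.
images = [cv2.imread(p) for p in ["a.jpg", "b.jpg"]]  # placeholder image paths
inputs = [
    {"image": torch.as_tensor(im.astype("float32").transpose(2, 0, 1))}
    for im in images
]
with torch.no_grad():
    outputs = model(inputs)   # a list with one output dict per input image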

@keko950

keko950 commented Nov 11, 2019

I was thinking of DefaultPredictor, not the model class itself, my bad!
Working like a charm, thank you!

@zhaoxuyan
Author

zhaoxuyan commented Nov 15, 2019

The link explains that the input to the model is a list[{"image": ...}, ...]. As long as the length of the list > 1, you're doing batch inference.

@keko950 @ppwwyyxx
Thank you for your reply. Another question:
if my list[{"image": ...}, ...] has 8 images, how do I specify gpu_num = 4 in DefaultPredictor so that each GPU infers 8/4 = 2 images?

@ppwwyyxx
Contributor

DefaultPredictor uses 1 GPU.

@darkAlert

I create a predictor as follows:

cfg = get_cfg()
cfg.merge_from_file(path_to_cfg)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = path_to_weights
predictor = DefaultPredictor(cfg)

img = cv2.imread(img_path)
img_list = [{"image": img}]
predictor(img_list)

And then I get the following error:

File "/home/darkalert/builds/pytorch/build/lib.linux-x86_64-3.6/torch/autograd/grad_mode.py", line 49, in decorate_no_grad  return func(*args, **kwargs)
File "/home/darkalert/builds/detectron2/detectron2/engine/defaults.py", line 172, in __call__height, width = original_image.shape[:2] AttributeError: 'list' object has no attribute 'shape'

@ppwwyyxx
Contributor

@darkAlert as said above, you need to call the model, not the predictor, with a batch of inputs.
As the docs (https://detectron2.readthedocs.io/modules/engine.html#detectron2.engine.defaults.DefaultPredictor) say, DefaultPredictor does not support a batch of inputs.

@jiangkansg

jiangkansg commented Dec 4, 2019

I tried to follow the instructions, and this is the code:

import cv2
import numpy as np
import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

cfg = get_cfg()
cfg.merge_from_file("./detectron2_repo/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set threshold for this model
cfg.MODEL.WEIGHTS = "detectron2://COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x/137849621/model_final_a6e10b.pkl"

model = build_model(cfg)  # returns a torch.nn.Module
DetectionCheckpointer(model).load("detectron2://COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x/137849621/model_final_a6e10b.pkl")

img = cv2.imread("./input.jpg")
img = np.transpose(img, (2, 0, 1))  # change from HWC to CHW
img_tensor = torch.from_numpy(img)
inputs = [{"image": img_tensor}, {"image": img_tensor}]
outputs = model(inputs)

The error message is:
/content/detectron2_repo/detectron2/modeling/proposal_generator/rpn_outputs.py in _get_ground_truth(self)
    260     # Concatenate anchors from all feature maps into a single Boxes per image
    261     anchors = [Boxes.cat(anchors_i) for anchors_i in self.anchors]
--> 262     for image_size_i, anchors_i, gt_boxes_i in zip(self.image_sizes, anchors, self.gt_boxes):
    263         """
    264         image_size_i: (h, w) for the i-th image

TypeError: zip argument #3 must support iteration

Did I miss anything? Thank you.

@ppwwyyxx
Contributor

ppwwyyxx commented Dec 4, 2019

To do inference, you need to put the PyTorch module into inference (eval) mode.
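For example (a small sketch; in training mode a Detectron2 model expects ground-truth instances in its inputs, which is what the zip over self.gt_boxes above tripped on):

model.eval()           # same as model.train(False): switch to inference behavior
with torch.no_grad():  # also disable gradient tracking during inference
    outputs = model(inputs)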

@jiangkansg

Thank you. I tried adding model.train(False). However, batch inference of keypoints "sometimes" (roughly 1 in 5 times) causes an error. This is really strange.

Here is the complete code:

import cv2
import numpy as np
import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("./detectron2_repo/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set threshold for this model
cfg.MODEL.WEIGHTS = "detectron2://COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x/137849621/model_final_a6e10b.pkl"

model = build_model(cfg)  # returns a torch.nn.Module
model.train(False)

img = cv2.imread("./input.jpg")
img = np.transpose(img, (2, 0, 1))
img_tensor = torch.from_numpy(img)
inputs = [{"image": img_tensor}, {"image": img_tensor}]
outputs = model(inputs)  # error may happen here

/content/detectron2_repo/detectron2/structures/keypoints.py in heatmaps_to_keypoints(maps, rois)
    193     assert (
    194         roi_map_probs[keypoints_idx, y_int, x_int]
--> 195         == roi_map_probs.view(num_keypoints, -1).max(1)[0]
    196     ).all()
    197

AssertionError:

@ppwwyyxx
Contributor

ppwwyyxx commented Dec 5, 2019

Could you provide the image, or even better a colab notebook, that reproduces this error?

@jiangkansg

Sure. Here is the colab notebook modified based on your tutorial.
https://colab.research.google.com/drive/1K9-r0tmYk5q_Zfp-gc5a8wvHZJ9jdIj6
Thank you.

@ppwwyyxx
Contributor

ppwwyyxx commented Dec 5, 2019

The model has to be loaded following https://detectron2.readthedocs.io/tutorials/models.html.
When the weights are not loaded, the model sometimes produces NaN values that eventually lead to this assertion failure. I'll see if NaNs can be handled better.

@jiangkansg

I was using the config to load the weights. After your latest comments, I now use DetectionCheckpointer(model).load(file_path) to load the weights. It is working now. Lessons learnt! Thank you.

To summarize, this is a working example:
import cv2
import numpy as np
import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

cfg = get_cfg()
cfg.merge_from_file("./detectron2_repo/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set threshold for this model

model = build_model(cfg)  # returns a torch.nn.Module
DetectionCheckpointer(model).load('./some_model.pkl')  # must load weights this way; setting cfg.MODEL.WEIGHTS alone does not load them
model.train(False)  # inference mode

img = cv2.imread("./input.jpg")
img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
img_tensor = torch.from_numpy(img)
inputs = [{"image": img_tensor}, {"image": img_tensor}]  # inputs is ready

outputs = model(inputs)

@wiamadaya

wiamadaya commented Feb 25, 2020

I was using the config to load the weights. After your latest comments, I now use DetectionCheckpointer(model).load(file_path) to load the weights. It is working now. Lessons learnt! Thank you. […]

@jiangkansg can you help me understand this part of your code? Does it rearrange the channel order to BGR?

img = np.transpose(img,(2,0,1))
img_tensor = torch.from_numpy(img)

I have a color image with image_height=352, image_width=640.

Additionally, I observe an anomaly: if I infer my image using DefaultPredictor the prediction output is accurate, but if I use the model directly, one of the objects is not detected (less accurate). I have set cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 accordingly, so I might be missing some finer detail here.

@jiangkansg

@wiamadaya, a NumPy image is arranged as (H, W, C), i.e. the color channels are on the last axis. In PyTorch it is (C, H, W). That is why we need to rearrange the axes.
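For example (a tiny sketch with made-up dimensions; note the transpose only moves the channel axis, it does not swap BGR/RGB, which is governed by cfg.INPUT.FORMAT):

import numpy as np
import torch

img = np.zeros((352, 640, 3), dtype=np.uint8)           # OpenCV/NumPy layout: (H, W, C)
img_tensor = torch.from_numpy(img.transpose(2, 0, 1))   # PyTorch/Detectron2 layout: (C, H, W)
print(img.shape, tuple(img_tensor.shape))                # (352, 640, 3) (3, 352, 640)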

@scploeger

I was using the config to load the weights. After your latest comments, I now use DetectionCheckpointer(model).load(file_path) to load the weights. It is working now. Lessons learnt! Thank you. […]

@jiangkansg did you ever use this implementation and go further to visualize the output, ensuring that it worked as expected?

@sandeepsign

sandeepsign commented May 2, 2020

Using a .pth checkpoint for batch inference:

import cv2
import numpy as np
import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("model_config.yaml")
cfg.MODEL.WEIGHTS = "model_final.pth"
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.DEVICE = 'cpu'

# Set up the model
model = build_model(cfg)
model_dict = torch.load("model_final.pth", map_location=torch.device('cpu'))
model.load_state_dict(model_dict['model'])
model.train(False)

# Prepare a batch of 10 (identical) images
img = cv2.imread(image_path)
img = np.transpose(img, (2, 0, 1))
img_tensor = torch.from_numpy(img)
inputs = [{"image": img_tensor} for _ in range(10)]

outputs = model(inputs)

@cbasavaraj

Hi, we got batch inference to work, but inference time grows linearly with batch size. Any idea why?

@zmoki688

Hi, we got batch inference to work, but inference time grows linearly with batch size. Any idea why?

I have the same problem. I predicted 1 image 256 times; the inference times with batch size 8 and batch size 1 were 28s and 31s respectively. There is no significant difference in time, but the difference in GPU memory usage is more than 3x (including the model's usage).

@cbasavaraj

@riokaa You mean with batch size 8 it's 28s per frame, not 28s for the entire batch, right? (By the way, are you on CPU and not GPU? Because 28s is super slow!) We did some more experiments. It looks like gains from batching come only at higher batch sizes. See this link: apache/mxnet#6396
So for real-time inference, if you have to process frames as fast as they arrive (say 5 or 10 fps) and it's not really possible to wait and build a big batch, batching doesn't really help. This is my conclusion so far.

@zmoki688

@riokaa You mean with batch size 8 it's 28s per frame, not 28s for the entire batch, right? (By the way, are you on CPU and not GPU? Because 28s is super slow!) We did some more experiments. It looks like gains from batching come only at higher batch sizes. See this link: apache/incubator-mxnet#6396
So for real-time inference, if you have to process frames as fast as they arrive (say 5 or 10 fps) and it's not really possible to wait and build a big batch, batching doesn't really help. This is my conclusion so far.

Sorry for my unclear expression. Let me convert my experiment results to a table:

| Batch size | Epoch | Total frame count | Time cost | FPS |
| --- | --- | --- | --- | --- |
| 1 | 256 | 256 | 31s | 0.12 |
| 8 | 32 | 256 | 28s | 0.11 |

It shows that batching only increases the speed slightly. As you say, batching has little effect on real-time inference; it mainly increases speed for training and validation on a fixed dataset.

@cbasavaraj

Nice table. The last column heading should be Time / Frame (preferably in milliseconds) rather than FPS.

@Odstrecon

Hi, we got batch inference to work, but inference time grows linearly with batch size. Any idea why?

Hey, what was the final code that you used for batch inference? Thank you!

@cbasavaraj

It can be as simple as this:

class BatchPredictor(DefaultPredictor):
    """Run d2 on a list of images."""

    def __call__(self, images):
        """Run d2 on a list of images.

        Args:
            images (list): BGR images of the expected shape: 720x1280
        """
        images = [
            {'image': torch.from_numpy(image.astype("float32").transpose(2, 0, 1))}
            for image in images
        ]
        with torch.no_grad():
            preds = self.model(images)
        return preds

You could modify it a bit to add transforms etc. as in DefaultPredictor.
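A hypothetical usage sketch for the class above (the model-zoo config is an arbitrary choice and the image paths are placeholders; note that this class skips DefaultPredictor's resizing, so images should already be a suitable size):

import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

predictor = BatchPredictor(cfg)                          # the subclass defined above
frames = [cv2.imread(p) for p in ["f1.jpg", "f2.jpg"]]   # placeholder image paths
preds = predictor(frames)                                # one output dict per image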

@xQsM3

xQsM3 commented Feb 5, 2021

Hi, I would like to do batch inference, but my model's output is empty when I use model(inputs). It works fine with the default predictor, but not when I use model(inputs) for batch inference. Here is my code for creating both the model and the predictor object:

def load_model(conf_thresh):
    # load config
    cfg = get_cfg()
    cfg.merge_from_file("./cfg/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
    cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32*0.75, 64*0.75, 128*0.75, 256*0.75, 512*0.75]]
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_thresh  # set threshold for this model
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    # load trained weights
    model = build_model(cfg)  # returns a torch.nn.Module
    cfg.MODEL.WEIGHTS = "./models/bb_rcnn.pth"
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    model.train(False)  # inference mode
    # create predictor
    predictor = DefaultPredictor(cfg)
    return model, predictor

As I said, the predictor works fine for one image:

model, predictor = load_model(conf_thresh=0.95)
image = cv.imread("/image.jpg")
outputs = predictor(image)

with output:

{'instances': Instances(num_instances=2, image_height=1710, image_width=1696, fields=[pred_boxes: Boxes(tensor([[ 580.7190, 1111.2963, 629.8201, 1160.1562], [ 666.4766, 1595.6615, 719.6202, 1648.0587]], device='cuda:0')), scores: tensor([0.9961, 0.9881], device='cuda:0'), pred_classes: tensor([0, 0], device='cuda:0')])}

If I now run model(inputs) as suggested in https://detectron2.readthedocs.io/en/latest/tutorials/models.html#use-models and here in the comments:

img = np.transpose(image, (2, 0, 1))
img_tensor = torch.from_numpy(img)
inputs = [{"image": img_tensor}]  # inputs is ready

outputs = model(inputs)

I get no instances:

{'instances': Instances(num_instances=0, image_height=1710, image_width=1696, fields=[pred_boxes: Boxes(tensor([], device='cuda:0', size=(0, 4), grad_fn=)), scores: tensor([], device='cuda:0', grad_fn=), pred_classes: tensor([], device='cuda:0', dtype=torch.int64)])}

EDIT: It appears that with a lower confidence threshold, model(inputs) finds two instances as well (with a big offset). So it looks like the model somehow does not load all my custom-trained weights, while the predictor does?
Any help?

@bavo96

bavo96 commented Jul 29, 2021

Hi, we got batch inference to work, but inference time grows linearly with batch size. Any idea why?

So is anyone else still experiencing this? For me, prediction on a batch of 4 images and prediction on 1 image 4 times take nearly the same time :(

@sandeepsign

Are you sure that at inference time the input tensor and the model are on a CUDA device and not on the CPU?
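One quick way to check (a small sketch; model here is the built Detectron2 model):

import torch

print(torch.cuda.is_available())         # True if PyTorch can see a GPU
print(next(model.parameters()).device)   # device the weights live on, e.g. cuda:0 or cpu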

@bavo96

bavo96 commented Jul 29, 2021

Yeah, it's on the GPU. I checked with watch -n 1 nvidia-smi and the model was using half of the GPU memory.

@ucsky

ucsky commented Sep 14, 2021

Hello,

With GPU or CPU, I found a speedup of less than 2x. What kind of speedup are we supposed to get from batch inference?

@bavo96 I got the same kind of problem with build_model: the outputs are not right. Using the BatchPredictor class it works fine.

Some logs below:

Benchmarking on NVIDIA GeForce RTX 2070 with Max-Q Design  with batch size 10
- one-by-one: 1.2202270030975342
- batch inference using build_model 0.6105613708496094 (need some fix)
- batch inference using BatchPredictor 0.6670536994934082
Speed up: 1.8292785183924947

Benchmarking on CPU batch size 10
- one-by-one: 11.404393911361694
- batch inference using build_model 6.993898868560791 (need some fix)
- batch inference using BatchPredictor 6.854889631271362
Speed up: 1.663687458852134

Benchmarking on CPU batch size 100
- one-by-one: 112.25274920463562
- batch inference using build_model 72.49595379829407 (need some fix)
- batch inference using BatchPredictor 75.87662625312805
Speed up: 1.4794114439162738

@fromm1990

Just thought I would share my attempt at a BatchPredictor. There is no novelty; I'm just using a torch DataLoader with a very simple Dataset implementation. The DataLoader is important so that batching is handled in a multi-process manner, which cuts down on GPU idle time.

I hope it can be useful for some of you.

from pathlib import Path
from typing import Iterable, List, NamedTuple

import cv2
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, get_cfg
from detectron2.modeling import build_model
from detectron2.structures import Instances
from numpy import ndarray
from torch.utils.data import DataLoader, Dataset


class Prediction(NamedTuple):
    x: float
    y: float
    width: float
    height: float
    score: float
    class_name: str


class ImageDataset(Dataset):

    def __init__(self, imagery: List[Path]):
        self.imagery = imagery

    def __getitem__(self, index) -> ndarray:
        return cv2.imread(self.imagery[index].as_posix())

    def __len__(self):
        return len(self.imagery)


class BatchPredictor:
    def __init__(self, cfg: CfgNode, classes: List[str], batch_size: int, workers: int):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.classes = classes
        self.batch_size = batch_size
        self.workers = workers
        self.model = build_model(self.cfg)
        self.model.eval()

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __collate(self, batch):
        data = []
        for image in batch:
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                image = image[:, :, ::-1]
            height, width = image.shape[:2]

            image = self.aug.get_transform(image).apply_image(image)
            image = image.astype("float32").transpose(2, 0, 1)
            image = torch.as_tensor(image)
            data.append({"image": image, "height": height, "width": width})
        return data

    def __call__(self, imagery: List[Path]) -> Iterable[List[Prediction]]:
        """[summary]

        :param imagery: [description]
        :type imagery: List[Path]
        :yield: Predictions for each image
        :rtype: [type]
        """
        dataset = ImageDataset(imagery)
        loader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=self.__collate,
            pin_memory=True
        )
        with torch.no_grad():
            for batch in loader:
                results: List[Instances] = self.model(batch)
                yield from [self.__map_predictions(result['instances']) for result in results]

    def __map_predictions(self, instances: Instances):
        instance_predictions = zip(
            instances.get('pred_boxes'),
            instances.get('scores'),
            instances.get('pred_classes')
        )

        predictions = []
        for box, score, class_index in instance_predictions:
            x1 = box[0].item()
            y1 = box[1].item()
            x2 = box[2].item()
            y2 = box[3].item()
            width = x2 - x1
            height = y2 - y1
            prediction = Prediction(
                x1, y1, width, height, score.item(), self.classes[class_index])
            predictions.append(prediction)
        return predictions

@justlike-prog

Did anyone manage to visualize the results of a batch with the Visualizer? Any help would be appreciated.

@jrdalenberg

Hmm...

I tried out batch inference as implemented by @fromm1990 using COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml as base model.

The gain in processing speed is rather minimal compared to running images individually through the DefaultPredictor, even when I try a range of batch sizes and worker counts.

However, I do get a 48% speed increase if I simply run inference with the DefaultPredictor in parallel with 3 threads on an RTX 2070 Super using joblib.

Package Versions:

torch==1.9.0+cu111
torchvision==0.10.0+cu111
detectron2==0.5+cu111
joblib==1.0.1
opencv-python==4.5.3.56

@fromm1990

Hi @jrdalenberg I'm having trouble recognizing the minimal improvement you mention. Could you share how you compared batch processing against sequential processing?

I did a test using 2,748 images. I ran the experiment three times and compared the average run time.
Sequential image inference took 536.7 seconds and batching took 111.5 seconds on average.
This means, in my case, the presented batching technique is 4.8 times faster than processing the imagery sequentially, which I would consider significant.

For the experiment I used an 8-core machine with an RTX 2080 Ti graphics card. I used 8 workers and a batch size of 64. Each image is about 2.4 MB (4000x3000 JPEG).

@jrdalenberg

Hey @fromm1990, thanks for the quick reply! Given your concern, I wanted to double-check whether anything in my code was causing delays. And it was: instead of bounding boxes, I calculate polygons from the masks on the CPU, and the time gain I reported above was mainly due to the gain in the mask-to-polygon calculation.

Below I benchmarked your code with our model and data.

Images: 1920x1080, ~500 kb jpeg.
Machine: RTX 2070 super (8GB), i7-8700K (6 cores)
Model: Retrained COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml.

I tested the first 1000 images.

Batch Predictor

Code

from pathlib import Path
from typing import Iterable, List, NamedTuple

import cv2
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, get_cfg
from detectron2.modeling import build_model
from detectron2.structures import Instances
from numpy import ndarray
from torch.utils.data import DataLoader, Dataset

import os
import json
import time


class Prediction(NamedTuple):
    x: float
    y: float
    width: float
    height: float
    score: float
    class_name: str


class ImageDataset(Dataset):

    def __init__(self, imagery: List[Path]):
        self.imagery = imagery

    def __getitem__(self, index) -> ndarray:
        return cv2.imread(self.imagery[index].as_posix())

    def __len__(self):
        return len(self.imagery)


class BatchPredictor:
    def __init__(self, cfg: CfgNode, classes: List[str], batch_size: int, workers: int):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.classes = classes
        self.batch_size = batch_size
        self.workers = workers
        self.model = build_model(self.cfg)
        self.model.eval()

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __collate(self, batch):
        data = []
        for image in batch:
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                image = image[:, :, ::-1]
            height, width = image.shape[:2]

            image = self.aug.get_transform(image).apply_image(image)
            image = image.astype("float32").transpose(2, 0, 1)
            image = torch.as_tensor(image)
            data.append({"image": image, "height": height, "width": width})
        return data

    def __call__(self, imagery: List[Path]) -> Iterable[List[Prediction]]:
        """[summary]

        :param imagery: [description]
        :type imagery: List[Path]
        :yield: Predictions for each image
        :rtype: [type]
        """
        dataset = ImageDataset(imagery)
        loader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=self.__collate,
            pin_memory=True
        )
        with torch.no_grad():
            for batch in loader:
                results: List[Instances] = self.model(batch)
                yield from [self.__map_predictions(result['instances']) for result in results]

    def __map_predictions(self, instances: Instances):
        instance_predictions = zip(
            instances.get('pred_boxes'),
            instances.get('scores'),
            instances.get('pred_classes')
        )

        predictions = []
        for box, score, class_index in instance_predictions:
            x1 = box[0].item()
            y1 = box[1].item()
            x2 = box[2].item()
            y2 = box[3].item()
            width = x2 - x1
            height = y2 - y1
            prediction = Prediction(
                x1, y1, width, height, score.item(), self.classes[class_index])
            predictions.append(prediction)
        return predictions


if __name__ == "__main__":
    model = 'classifier/11-11-2021-model/model_final.pth'
    model_config = 'classifier/11-11-2021-model/config.yaml'
    classes = 'classifier/11-11-2021-model/classes.json'
    root = '/data/images'
    
    filelist = []
    for root, dirs, files in os.walk(root):
        for file in files:
            filelist.append(Path(os.path.join(root, file)))

    with open(classes) as classes_file:
        class_names = json.load(classes_file)['classes']

    cfg = get_cfg()
    cfg.merge_from_file(cfg_filename=model_config)
    cfg.MODEL.WEIGHTS = model

    pred = BatchPredictor(cfg=cfg, classes=class_names, batch_size=1, workers=1)

    t1 = time.time()
    predictions = list(pred(filelist[0:1000]))
    t2 = time.time()
    print(f'Time used: {t2-t1} seconds')

Results

batch_size=1, workers=1: Time used: 70.42968511581421 seconds
batch_size=4, workers=4: Time used: 68.36700916290283 seconds
batch_size=8, workers=8: Time used: 66.61947345733643 seconds
batch_size=9, workers=9: Time used: 65.94470310211182 seconds

batch_size=9, workers=1: Time used: 64.47477769851685 seconds
batch_size=9, workers=2: Time used: 63.85893440246582 seconds
batch_size=9, workers=3: Time used: 64.85213160514832 seconds
batch_size=9, workers=4: Time used: 64.46873760223389 seconds
batch_size=9, workers=5: Time used: 64.94616794586182 seconds
batch_size=9, workers=6: Time used: 65.31316900253296 seconds
batch_size=9, workers=7: Time used: 65.46668720245361 seconds
batch_size=9, workers=8: Time used: 65.55049848556519 seconds

A higher batch size makes the RTX 2070 super run out of memory.

Default Predictor combined with joblib

Code

from pathlib import Path
from typing import List, NamedTuple

import cv2
from detectron2.config import CfgNode, get_cfg

import os
import json
import time

from joblib import Parallel, delayed
from more_itertools import divide
from detectron2.engine import DefaultPredictor


class Prediction(NamedTuple):
    x: float
    y: float
    width: float
    height: float
    score: float
    class_name: str


def classify_job(file_list: List[str], cfg: CfgNode, classes: list):
    predictor = DefaultPredictor(cfg)

    predictions = []
    for file in file_list:
        im = cv2.imread(file)
        instances = predictor(im)['instances']

        instance_predictions = zip(
            instances.get('pred_boxes'),
            instances.get('scores'),
            instances.get('pred_classes')
        )

        for box, score, class_index in instance_predictions:
            x1 = box[0].item()
            y1 = box[1].item()
            x2 = box[2].item()
            y2 = box[3].item()
            width = x2 - x1
            height = y2 - y1
            prediction = Prediction(
                x1, y1, width, height, score.item(), classes[class_index])
            predictions.append(prediction)

    return predictions


if __name__ == "__main__":
    n_jobs = 1

    model = 'classifier/11-11-2021-model/model_final.pth'
    model_config = 'classifier/11-11-2021-model/config.yaml'
    classes = 'classifier/11-11-2021-model/classes.json'
    root = '/data/images'

    filelist = []
    for root, dirs, files in os.walk(root):
        for file in files:
            filelist.append(os.path.join(root, file))

    with open(classes) as classes_file:
        class_names = json.load(classes_file)['classes']

    cfg = get_cfg()
    cfg.merge_from_file(cfg_filename=model_config)
    cfg.MODEL.WEIGHTS = model

    data_parts = divide(n_jobs, filelist[0:1000])

    t1 = time.time()
    output = Parallel(n_jobs=n_jobs)(delayed(classify_job)(list(file_list), cfg=cfg, classes=class_names) for file_list in data_parts)
    t2 = time.time()
    print(f'Time used: {t2-t1} seconds')

Results

n_jobs = 1: Time used: 104.95291018486023 seconds
n_jobs = 2: Time used: 91.95392084121704 seconds
n_jobs = 3: Time used: 87.25907635688782 seconds

Conclusions

So all in all, for this model and data set, the gain between the lowest batch size and the optimal batch size is around 9-10%.
Compared to the default predictor running a single process, your implementation saves about 40% for this model and data set.

I am looking forward to any comments and suggestions for improvement :)

@preetom-saha-arko

@jrdalenberg I tried your code, but I get the following error:
AttributeError: 'BatchPredictor' object has no attribute '__collate'
I'm not sure why I am getting this error. Your code has a function named __collate inside the BatchPredictor class.

In addition to this error, I sometimes also get:
BrokenPipeError: [WinError 232] The pipe is being closed

@Robotatron

Why does batch inference take the same time as a loop over single images? Doesn't make sense.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Jan 25, 2024