Do you support batch inference? #282
Comments
You just need to call the model with a batch of inputs (https://detectron2.readthedocs.io/tutorials/models.html#model-input-format).
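For anyone looking for a concrete starting point, here is a minimal sketch of what "calling the model with a batch of inputs" looks like (not an official snippet; the model-zoo config name and image paths are only examples):

```python
import cv2
import torch
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

model = build_model(cfg)                              # a plain torch.nn.Module
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # build_model does not load weights by itself
model.eval()

# One dict per image; "image" is a (C, H, W) tensor in cfg.INPUT.FORMAT (BGR by default).
inputs = []
for path in ["img1.jpg", "img2.jpg"]:                 # example paths
    img = cv2.imread(path)                            # (H, W, C), BGR, uint8
    tensor = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))
    inputs.append({"image": tensor, "height": img.shape[0], "width": img.shape[1]})

with torch.no_grad():
    outputs = model(inputs)                           # list of dicts, one "instances" entry per image
```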
@ppwwyyxx Thank you very much for your reply~

```python
img = read_image(path, format="BGR")
demo.run_on_image(img)
```

This can only infer one image at a time; how can I infer a batch of images?
demo.py does not support batch inference.
So how should I implement it, please?
https://detectron2.readthedocs.io/tutorials/models.html#use-models explains how to use a model.
That link doesn't refer to batch inference in any way; a more detailed answer would be nice, since the API docs don't go deeper.
The link explains that the input to the model is a `list[dict]`, with one dict per image, so a batch is simply a longer list.
I was thinking of DefaultPredictor, not the model class itself, my bad!
@keko950 @ppwwyyxx I create a predictor as follows:

And then I get the following error:
@darkAlert As said above, you need to call the model, not the predictor, with a batch of inputs.
I tried to follow the instructions, and this is the code:

```python
cfg = get_cfg()
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
img = cv2.imread("./input.jpg")
```

The error message is:

```
TypeError: zip argument #3 must support iteration
```

Did I miss anything? Thank you.
To do inference, you need to set the PyTorch module to inference (evaluation) mode.
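In code terms, a tiny sketch (assuming `cfg` and `inputs` are prepared as discussed above):

```python
model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

model.eval()               # same effect as model.train(False): switch to inference behaviour
with torch.no_grad():      # optional but recommended: no gradient bookkeeping during inference
    outputs = model(inputs)
```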
Thank you. I tried adding `model.train(False)`. However, batch inference of keypoints "sometimes" (roughly 1 in 5 times) causes an error. This is really strange. Here is the complete code:

The error ends with:

```
/content/detectron2_repo/detectron2/structures/keypoints.py in heatmaps_to_keypoints(maps, rois)
AssertionError:
```
Could you provide the image, or even better a Colab notebook, that reproduces this error?
Sure. Here is the Colab notebook, modified based on your tutorial.
The model has to be loaded following https://detectron2.readthedocs.io/tutorials/models.html.
I was using the config to load the weights. After your latest comments, I now use `DetectionCheckpointer(model).load(file_path)` to load the weights. It is working now. Lessons learnt! Thank you. To summarize, this is a working example:

```python
model = build_model(cfg)  # returns a torch.nn.Module
img = cv2.imread("./input.jpg")
outputs = model(inputs)
```
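Filling in the missing steps, a complete version of such an example might look like the following (a sketch, not the original post; the tensor conversion mirrors what DefaultPredictor does, and the config/weight paths are placeholders):

```python
import cv2
import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

cfg = get_cfg()
cfg.merge_from_file("path/to/config.yaml")            # placeholder config
cfg.MODEL.WEIGHTS = "path/to/model_final.pth"         # placeholder weights

model = build_model(cfg)                              # returns a torch.nn.Module
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # the step that actually loads the weights
model.eval()

img = cv2.imread("./input.jpg")
tensor = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))  # (H, W, C) BGR -> (C, H, W)
inputs = [{"image": tensor, "height": img.shape[0], "width": img.shape[1]}]

with torch.no_grad():
    outputs = model(inputs)
```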
@jiangkansg Can you help me understand this part of your code: does it rearrange the channel order to BGR?
I have a color image with [...]. Additionally, I observe an anomaly if I infer my image using [...]
@wiamadaya A numpy image is arranged as (H, W, C), i.e. the color channels are on the last axis. In PyTorch it is (C, H, W). That is why we need to rearrange the axes.
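A quick illustration of the axis order (shapes are just an example):

```python
import cv2
import numpy as np
import torch

img = cv2.imread("input.jpg")        # numpy array, shape (H, W, C), e.g. (480, 640, 3), BGR
chw = np.transpose(img, (2, 0, 1))   # shape (C, H, W), e.g. (3, 480, 640)
tensor = torch.as_tensor(chw.astype("float32"))
print(img.shape, tensor.shape)       # (480, 640, 3) torch.Size([3, 480, 640])
```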
@jiangkansg Did you ever use this implementation and go further to visualize the output, ensuring that it worked as expected?
Using .pth for batch inference:

```python
from detectron2.modeling import build_model

img = cv2.imread(image_path)
cfg = get_cfg()

# Setup predictor
model = build_model(cfg)
model_dict = torch.load("model_final.pth", map_location=torch.device('cpu'))
model.train(False)

img = np.transpose(img, (2, 0, 1))
inputs = [{"image": img_tensor}, {"image": img_tensor}, {"image": img_tensor}, {"image": img_tensor}, {"image": img_tensor},
          {"image": img_tensor}, {"image": img_tensor}, {"image": img_tensor}, {"image": img_tensor}, {"image": img_tensor}]
outputs = model(inputs)
```
Hi, we got batch inference to work, but inference time grows linearly with batch size. Any idea why?
I have the same problem. I predicted 1 image 256 times; the inference times with batch size 8 and batch size 1 are 28 s and 31 s respectively. There is no significant difference, but the difference in GPU memory usage is more than 3x (including the model's usage).
@riokaa You mean with batch size 8 it's 28 s per frame and not 28 s for the entire batch, right? (By the way, are you on CPU and not GPU? Because 28 s is super slow!) We did some more experiments. It looks like gains from batching come only at higher batch sizes. See this link: apache/mxnet#6396
Sorry for my unclear expression. I have converted my experiment results into a table:

It shows that batching only increases the speed slightly. As you say, batching has little effect on real-time inference; it mainly speeds up training and validation on a fixed dataset.
Nice table. The last column heading should be Time / Frame (preferably in milliseconds) rather than FPS.
Hey, what was the final code that you used for batch inference? Thank you!
It can be as simple as this:

You could modify it a bit to add transforms etc. as in DefaultPredictor; see the sketch below.
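Presumably something along these lines (a sketch rather than the author's original snippet; the class and variable names are mine), with the DefaultPredictor-style resize added:

```python
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

class SimpleBatchPredictor:
    """Bare-bones batched inference with DefaultPredictor-like preprocessing."""

    def __init__(self, cfg):
        self.cfg = cfg.clone()
        self.model = build_model(self.cfg)
        self.model.eval()
        DetectionCheckpointer(self.model).load(cfg.MODEL.WEIGHTS)
        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

    def __call__(self, images):
        """images: list of (H, W, C) BGR numpy arrays; returns the model's list of output dicts."""
        inputs = []
        for img in images:
            h, w = img.shape[:2]
            resized = self.aug.get_transform(img).apply_image(img)
            tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
            inputs.append({"image": tensor, "height": h, "width": w})
        with torch.no_grad():
            return self.model(inputs)
```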
Hi, I would like to do batch inference, but my model's output is empty when I use `model(inputs)`. It works fine with the DefaultPredictor, but not when I use `model(inputs)` for batch inference. Here is my code for creating both the model and the predictor object:

As I said, the predictor on one image works fine:

```python
image = cv.imread("/image.jpg")
outputs = predictor(image)
```

with output:

If I run `model(inputs)` now, as suggested in https://detectron2.readthedocs.io/en/latest/tutorials/models.html#use-models and here in the comments:

```python
img = np.transpose(image, (2, 0, 1))
outputs = model(inputs)
```

I get no instances.

EDIT: It appears that with a lower confidence threshold, `model(inputs)` also finds two detections (with a big offset). So it looks like the model somehow does not load all my custom-trained weights, but the predictor does?
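Going by the earlier comments in this thread, a likely cause is that `build_model` alone neither loads `cfg.MODEL.WEIGHTS` nor applies the test-time resize that DefaultPredictor applies. A hedged sketch of the raw-model path that matches what the predictor does (the threshold value is only an example):

```python
import cv2
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5            # example value; match what you use with the predictor
model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)   # without this, the model keeps randomly initialized weights
model.eval()

aug = T.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
image = cv2.imread("/image.jpg")                       # BGR, as in the snippet above
height, width = image.shape[:2]
resized = aug.get_transform(image).apply_image(image)
tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))

# "height"/"width" tell detectron2 to scale the predictions back to the original image size.
inputs = [{"image": tensor, "height": height, "width": width}]
with torch.no_grad():
    outputs = model(inputs)
```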
So, is anyone else still experiencing this? For me, prediction on a batch of 4 images and prediction on 1 image 4 times take nearly the same time :(
Are you sure that at inference time the input tensor and the model are on a CUDA device and not on the CPU?
Yeah, it's on the GPU; I checked.
Hello, with either GPU or CPU I found a speedup of less than 2x. What kind of speedup are we supposed to get using batch inference? @bavo96 I got the same kind of problem. Some logs below:
Just thought I would share my attempt at a BatchPredictor. There is no novelty; I'm just using a torch DataLoader with a very simple Dataset implementation. The DataLoader is important so that batching is handled in a multi-processed manner, which cuts down on GPU idle time. I hope it can be useful for some of you.

```python
from pathlib import Path
from typing import Iterable, List, NamedTuple

import cv2
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, get_cfg
from detectron2.modeling import build_model
from detectron2.structures import Instances
from numpy import ndarray
from torch.utils.data import DataLoader, Dataset


class Prediction(NamedTuple):
    x: float
    y: float
    width: float
    height: float
    score: float
    class_name: str


class ImageDataset(Dataset):

    def __init__(self, imagery: List[Path]):
        self.imagery = imagery

    def __getitem__(self, index) -> ndarray:
        return cv2.imread(self.imagery[index].as_posix())

    def __len__(self):
        return len(self.imagery)


class BatchPredictor:

    def __init__(self, cfg: CfgNode, classes: List[str], batch_size: int, workers: int):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.classes = classes
        self.batch_size = batch_size
        self.workers = workers
        self.model = build_model(self.cfg)
        self.model.eval()

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST
        )
        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __collate(self, batch):
        data = []
        for image in batch:
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                image = image[:, :, ::-1]
            height, width = image.shape[:2]
            image = self.aug.get_transform(image).apply_image(image)
            image = image.astype("float32").transpose(2, 0, 1)
            image = torch.as_tensor(image)
            data.append({"image": image, "height": height, "width": width})
        return data

    def __call__(self, imagery: List[Path]) -> Iterable[List[Prediction]]:
        """[summary]

        :param imagery: [description]
        :type imagery: List[Path]
        :yield: Predictions for each image
        :rtype: [type]
        """
        dataset = ImageDataset(imagery)
        loader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=self.__collate,
            pin_memory=True
        )

        with torch.no_grad():
            for batch in loader:
                results: List[Instances] = self.model(batch)
                yield from [self.__map_predictions(result['instances']) for result in results]

    def __map_predictions(self, instances: Instances):
        instance_predictions = zip(
            instances.get('pred_boxes'),
            instances.get('scores'),
            instances.get('pred_classes')
        )

        predictions = []
        for box, score, class_index in instance_predictions:
            x1 = box[0].item()
            y1 = box[1].item()
            x2 = box[2].item()
            y2 = box[3].item()
            width = x2 - x1
            height = y2 - y1
            prediction = Prediction(
                x1, y1, width, height, score.item(), self.classes[class_index])
            predictions.append(prediction)

        return predictions
```
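A possible usage example for this class (the model-zoo config, class list, and image directory are placeholders):

```python
from pathlib import Path
from detectron2 import model_zoo
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")

class_names = ["person", "bicycle", "car"]            # placeholder; supply the full class list for your model
predictor = BatchPredictor(cfg, classes=class_names, batch_size=8, workers=4)

image_paths = sorted(Path("images/").glob("*.jpg"))   # placeholder directory
for path, predictions in zip(image_paths, predictor(image_paths)):
    print(path.name, len(predictions), "detections")
```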
Did anyone manage to visualize the results of a batch with the Visualizer? Any help would be appreciated.
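One way to do it, assuming you kept the raw per-image output dicts from `model(inputs)` (a sketch; the metadata/dataset name depends on your setup):

```python
from pathlib import Path
import cv2
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])       # or whichever dataset your model was trained on

for path, output in zip(image_paths, outputs):              # outputs = model(inputs) for the same batch
    img = cv2.imread(str(path))
    vis = Visualizer(img[:, :, ::-1], metadata=metadata)    # Visualizer expects RGB
    drawn = vis.draw_instance_predictions(output["instances"].to("cpu"))
    cv2.imwrite(f"vis_{Path(path).name}", drawn.get_image()[:, :, ::-1])  # back to BGR for cv2
```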
Hmm... I tried out batch inference as implemented by @fromm1990 using the BatchPredictor above. The gain in processing speed is rather minimal compared to running images individually through the DefaultPredictor. However, I do get a 48% speed increase if I simply run inference with the [...]. Package versions: [...]
Hi @jrdalenberg, I'm having trouble reproducing the minimal improvement you mention. Could you share how you compared batch processing against sequential processing? I did a test using 2,748 images. I ran the experiment three times and compared the average run time. For the experiment I used an 8-core machine with an RTX 2080 Ti graphics card, 8 workers, and a batch size of 64. Each image is about 2.4 MB in size (4000x3000 JPEG).
Hey @fromm1990, thanks for the quick reply! Given your concern, I wanted to double check whether it was anything in my code that caused delays. And it was. Instead of bounding boxes, I calculate polygons from the masks on the CPU, and the gain in time that I reported above was mainly due to the gain in the mask-to-polygon calculation. Below I benchmarked your code with our model and data. Images: 1920x1080, ~500 kB JPEG. I tested the first 1000 images.

**Batch Predictor**

Code:

```python
from pathlib import Path
from typing import Iterable, List, NamedTuple

import cv2
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, get_cfg
from detectron2.modeling import build_model
from detectron2.structures import Instances
from numpy import ndarray
from torch.utils.data import DataLoader, Dataset
import os
import json
import time


class Prediction(NamedTuple):
    x: float
    y: float
    width: float
    height: float
    score: float
    class_name: str


class ImageDataset(Dataset):

    def __init__(self, imagery: List[Path]):
        self.imagery = imagery

    def __getitem__(self, index) -> ndarray:
        return cv2.imread(self.imagery[index].as_posix())

    def __len__(self):
        return len(self.imagery)


class BatchPredictor:

    def __init__(self, cfg: CfgNode, classes: List[str], batch_size: int, workers: int):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.classes = classes
        self.batch_size = batch_size
        self.workers = workers
        self.model = build_model(self.cfg)
        self.model.eval()

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST
        )
        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __collate(self, batch):
        data = []
        for image in batch:
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                image = image[:, :, ::-1]
            height, width = image.shape[:2]
            image = self.aug.get_transform(image).apply_image(image)
            image = image.astype("float32").transpose(2, 0, 1)
            image = torch.as_tensor(image)
            data.append({"image": image, "height": height, "width": width})
        return data

    def __call__(self, imagery: List[Path]) -> Iterable[List[Prediction]]:
        """[summary]

        :param imagery: [description]
        :type imagery: List[Path]
        :yield: Predictions for each image
        :rtype: [type]
        """
        dataset = ImageDataset(imagery)
        loader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=False,
            num_workers=self.workers,
            collate_fn=self.__collate,
            pin_memory=True
        )

        with torch.no_grad():
            for batch in loader:
                results: List[Instances] = self.model(batch)
                yield from [self.__map_predictions(result['instances']) for result in results]

    def __map_predictions(self, instances: Instances):
        instance_predictions = zip(
            instances.get('pred_boxes'),
            instances.get('scores'),
            instances.get('pred_classes')
        )

        predictions = []
        for box, score, class_index in instance_predictions:
            x1 = box[0].item()
            y1 = box[1].item()
            x2 = box[2].item()
            y2 = box[3].item()
            width = x2 - x1
            height = y2 - y1
            prediction = Prediction(
                x1, y1, width, height, score.item(), self.classes[class_index])
            predictions.append(prediction)

        return predictions


if __name__ == "__main__":
    model = 'classifier/11-11-2021-model/model_final.pth'
    model_config = 'classifier/11-11-2021-model/config.yaml'
    classes = 'classifier/11-11-2021-model/classes.json'
    root = '/data/images'

    filelist = []
    for root, dirs, files in os.walk(root):
        for file in files:
            filelist.append(Path(os.path.join(root, file)))

    with open(classes) as classes_file:
        class_names = json.load(classes_file)['classes']

    cfg = get_cfg()
    cfg.merge_from_file(cfg_filename=model_config)
    cfg.MODEL.WEIGHTS = model

    pred = BatchPredictor(cfg=cfg, classes=class_names, batch_size=1, workers=1)
    t1 = time.time()
    predictions = list(pred(filelist[0:1000]))
    t2 = time.time()
    print(f'Time used: {t2-t1} seconds')
```

Results:

- batch_size=1, workers=1: …
- batch_size=9, workers=1: …

A higher batch size makes the RTX 2070 Super run out of memory.

**Default Predictor combined with joblib**

Code:

```python
from pathlib import Path
from typing import List, NamedTuple

import cv2
from detectron2.config import CfgNode, get_cfg
import os
import json
import time

from joblib import Parallel, delayed
from more_itertools import divide
from detectron2.engine import DefaultPredictor


class Prediction(NamedTuple):
    x: float
    y: float
    width: float
    height: float
    score: float
    class_name: str


def classify_job(file_list: List[str], cfg: CfgNode, classes: list):
    predictor = DefaultPredictor(cfg)
    predictions = []
    for file in file_list:
        im = cv2.imread(file)
        instances = predictor(im)['instances']
        instance_predictions = zip(
            instances.get('pred_boxes'),
            instances.get('scores'),
            instances.get('pred_classes')
        )
        for box, score, class_index in instance_predictions:
            x1 = box[0].item()
            y1 = box[1].item()
            x2 = box[2].item()
            y2 = box[3].item()
            width = x2 - x1
            height = y2 - y1
            prediction = Prediction(
                x1, y1, width, height, score.item(), classes[class_index])
            predictions.append(prediction)
    return predictions


if __name__ == "__main__":
    n_jobs = 1
    model = 'classifier/11-11-2021-model/model_final.pth'
    model_config = 'classifier/11-11-2021-model/config.yaml'
    classes = 'classifier/11-11-2021-model/classes.json'
    root = '/data/images'

    filelist = []
    for root, dirs, files in os.walk(root):
        for file in files:
            filelist.append(os.path.join(root, file))

    with open(classes) as classes_file:
        class_names = json.load(classes_file)['classes']

    cfg = get_cfg()
    cfg.merge_from_file(cfg_filename=model_config)
    cfg.MODEL.WEIGHTS = model

    data_parts = divide(n_jobs, filelist[0:1000])
    t1 = time.time()
    output = Parallel(n_jobs=n_jobs)(delayed(classify_job)(list(file_list), cfg=cfg, classes=class_names) for file_list in data_parts)
    t2 = time.time()
    print(f'Time used: {t2-t1} seconds')
```

Results:

- n_jobs = 1: …

**Conclusions**

So, all in all, for this model and data set, the gain between the lowest batch size and the optimal batch size is around 9-10%. I am looking forward to any comments and suggestions for improvement :)
@jrdalenberg I tried your code, but I get the following error:

In addition to this error, I sometimes get an additional error saying this:
Why does batch inference take the same time as looping over single images? It doesn't make sense.
🚀 Feature
How to easily take advantage of batch processing during inference?
Motivation
facebookresearch/Detectron#84