Skip to content

Commit

Permalink
[Feature] Test batch (#511)
Browse files Browse the repository at this point in the history
* resolve comments

* update changelog

* add test_batch

* add testing for `test_batch`

* fix mmcv version

* add test_batch

* add testing for `test_batch`

* enlarge test_input to pass unittest

* update names

* update changelog & faq

* update name
  • Loading branch information
kennymckormick committed Jan 13, 2021
1 parent f9eb562 commit 431004e
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 6 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

### master

**Improvements**

- Support setting `max_testing_views` for extremely large models to reduce GPU memory usage ([#511](https://github.com/open-mmlab/mmaction2/pull/511))

### 0.10.0 (31/12/2020)

**Highlights**
Expand Down
4 changes: 4 additions & 0 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ If the contents here do not cover your issue, please create an issue using the [

change this in the config, make `test_cfg = dict(average_clips='prob')`.

- **What if the model is too large and the GPU memory cannot fit even a single testing sample?**

  By default, the 3D models are tested with 10 clips x 3 crops, i.e. 30 views in total. For extremely large models, the GPU memory may not fit even a single testing sample (because all 30 views must be processed at once). To handle this, you can set `max_testing_views=n` in the `test_cfg` of the config file. If set, n views will be processed as a batch during forwarding to reduce the GPU memory used.

## Deploying

- **Why is the onnx model converted by mmaction2 throwing error when converting to other frameworks such as TensorRT?**
Expand Down
5 changes: 5 additions & 0 deletions mmaction/models/recognizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def __init__(self,
self.aux_info = []
if train_cfg is not None and 'aux_info' in train_cfg:
self.aux_info = train_cfg['aux_info']
# max_testing_views should be int
self.max_testing_views = None
if test_cfg is not None and 'max_testing_views' in test_cfg:
self.max_testing_views = test_cfg['max_testing_views']
assert isinstance(self.max_testing_views, int)

self.init_weights()

Expand Down
28 changes: 23 additions & 5 deletions mmaction/models/recognizers/recognizer3d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torch

from ..registry import RECOGNIZERS
from .base import BaseRecognizer

Expand Down Expand Up @@ -29,13 +31,29 @@ def _do_test(self, imgs):
num_segs = imgs.shape[1]
imgs = imgs.reshape((-1, ) + imgs.shape[2:])

x = self.extract_feat(imgs)
if hasattr(self, 'neck'):
x, _ = self.neck(x)
if self.max_testing_views is not None:
total_views = imgs.shape[0]
assert num_segs == total_views, (
'max_testing_views is only compatible '
'with batch_size == 1')
view_ptr = 0
cls_scores = []
while view_ptr < total_views:
batch_imgs = imgs[view_ptr:view_ptr + self.max_testing_views]
x = self.extract_feat(batch_imgs)
if hasattr(self, 'neck'):
x, _ = self.neck(x)
cls_score = self.cls_head(x)
cls_scores.append(cls_score)
view_ptr += self.max_testing_views
cls_score = torch.cat(cls_scores)
else:
x = self.extract_feat(imgs)
if hasattr(self, 'neck'):
x, _ = self.neck(x)
cls_score = self.cls_head(x)

cls_score = self.cls_head(x)
cls_score = self.average_clip(cls_score, num_segs)

return cls_score

def forward_test(self, imgs):
Expand Down
11 changes: 10 additions & 1 deletion tests/test_models/test_recognizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_slowfast():
recognizer = build_recognizer(
model, train_cfg=train_cfg, test_cfg=test_cfg)

input_shape = (1, 3, 3, 8, 32, 32)
input_shape = (1, 3, 3, 16, 32, 32)
demo_inputs = generate_demo_inputs(input_shape, '3D')

imgs = demo_inputs['imgs']
Expand Down Expand Up @@ -271,6 +271,15 @@ def test_slowfast():
for one_img in img_list:
recognizer(one_img, gradcam=True)

# Test the feature max_testing_views
test_cfg['max_testing_views'] = 1
recognizer = build_recognizer(
model, train_cfg=train_cfg, test_cfg=test_cfg)
with torch.no_grad():
img_list = [img[None, :] for img in imgs]
for one_img in img_list:
recognizer(one_img, None, return_loss=False)


def test_tsm():
model, train_cfg, test_cfg = _get_recognizer_cfg(
Expand Down

0 comments on commit 431004e

Please sign in to comment.