Skip to content

Commit

Permalink
mmclassification ConformerHead support (#1905)
Browse files Browse the repository at this point in the history
* mmclassification ConformerHead support

* add mmclassification ConformerHead test config

---------

Co-authored-by: lishengxi <mtdp@MacBook-Pro-8.local>
  • Loading branch information
xizi and lishengxi committed Mar 23, 2023
1 parent a14177c commit 032ce75
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmdeploy/codebase/mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def base_classifier__forward(
if self.head is not None:
output = self.head(output)

from mmcls.models.heads import MultiLabelClsHead
from mmcls.models.heads import ConformerHead, MultiLabelClsHead
if isinstance(self.head, MultiLabelClsHead):
output = torch.sigmoid(output)
elif isinstance(self.head, ConformerHead):
output = F.softmax(torch.add(output[0], output[1]), dim=1)
else:
output = F.softmax(output, dim=1)
return output
8 changes: 8 additions & 0 deletions tests/regression/mmcls.yml
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,11 @@ models:
- *pipeline_ort_static_fp32
- convert_image: *convert_image
deploy_config: configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py

- name: Conformer
metafile: configs/conformer/metafile.yml
model_configs:
- configs/conformer/conformer-tiny-p16_8xb128_in1k.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

0 comments on commit 032ce75

Please sign in to comment.