Add softmax to cls model (#1573)
* Add softmax to cls model

* fix cls ci

* multihead

* update classification_model.py
grimoire committed Dec 30, 2022
1 parent baa86aa commit deaefac
Showing 4 changed files with 61 additions and 9 deletions.
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmcls/deploy/classification.py
@@ -333,8 +333,8 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
         if 'topk' not in postprocess:
             topk = (1, )
             logger = get_root_logger()
-            logger.warning('no topk in postprocess config, using default \
-topk value.')
+            logger.warning('no topk in postprocess config, using default '
+                           'topk value.')
         else:
             topk = postprocess.topk
         postprocess.topk = max(topk)
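A note on this change: with a backslash continuation inside a string literal, the continued line is spliced in verbatim, so it must start at column 0 (breaking the surrounding indentation) or its leading whitespace leaks into the logged message; concatenating adjacent string literals avoids the dilemma. A minimal sketch in plain Python (not code from this repository):

# Backslash continuation: the continued line is spliced in verbatim,
# so any leading whitespace becomes part of the string.
msg_a = 'using default \
    topk value.'  # -> 'using default     topk value.' (indent leaked in)

# Implicit adjacent-string concatenation: no stray whitespace, and the
# second fragment can be indented normally.
msg_b = ('using default '
         'topk value.')  # -> 'using default topk value.'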
49 changes: 44 additions & 5 deletions mmdeploy/codebase/mmcls/deploy/classification_model.py
@@ -35,6 +35,7 @@ def __init__(self,
                  backend: Backend,
                  backend_files: Sequence[str],
                  device: str,
+                 model_cfg: Union[str, Config] = None,
                  deploy_cfg: Union[str, Config] = None,
                  data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                  **kwargs):
@@ -46,8 +47,18 @@ def __init__(self,
             backend_files=backend_files,
             device=device,
             **kwargs)
+        self.model_cfg = model_cfg
+        self.head = None
+        if model_cfg is not None:
+            self.head = self._get_head()
         self.device = device
 
+    def _get_head(self):
+        from mmcls.models import build_head
+        head_config = self.model_cfg['model']['head']
+        head = build_head(head_config)
+        return head
+
     def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
                       device: str, **kwargs):
         output_names = self.output_names
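For context, _get_head re-instantiates only the head section of an mmcls model config so the wrapper can later read its thr/topk attributes. A hypothetical config of the shape this code expects (field values are illustrative, not taken from this commit):

model_cfg = dict(
    model=dict(
        head=dict(
            type='MultiLabelLinearClsHead',  # a registered mmcls head type
            num_classes=10,
            in_channels=512,
            thr=0.5,  # threshold read later in forward()
        )))

# build_head looks the 'type' string up in mmcls's model registry:
# from mmcls.models import build_head
# head = build_head(model_cfg['model']['head'])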
@@ -84,11 +95,38 @@ def forward(self,
         cls_score = self.wrapper({self.input_name:
                                   inputs})[self.output_names[0]]
 
-        from mmcls.models.heads.cls_head import ClsHead
-        predict = ClsHead._get_predictions(
-            None, cls_score, data_samples=data_samples)
-
-        return predict
+        from mmcls.models.heads import MultiLabelClsHead
+        from mmcls.structures import ClsDataSample
+        pred_scores = cls_score
+
+        if self.head is None or not isinstance(self.head, MultiLabelClsHead):
+            pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
+
+            if data_samples is not None:
+                for data_sample, score, label in zip(data_samples, pred_scores,
+                                                     pred_labels):
+                    data_sample.set_pred_score(score).set_pred_label(label)
+            else:
+                data_samples = []
+                for score, label in zip(pred_scores, pred_labels):
+                    data_samples.append(ClsDataSample().set_pred_score(
+                        score).set_pred_label(label))
+        else:
+            if data_samples is None:
+                data_samples = [
+                    ClsDataSample() for _ in range(cls_score.size(0))
+                ]
+
+            for data_sample, score in zip(data_samples, pred_scores):
+                if self.head.thr is not None:
+                    # a label is predicted positive if larger than thr
+                    label = torch.where(score >= self.head.thr)[0]
+                else:
+                    # top-k labels will be predicted positive for any example
+                    _, label = score.topk(self.head.topk)
+                data_sample.set_pred_score(score).set_pred_label(label)
+
+        return data_samples
 
 
 @__BACKEND_MODEL.register_module('sdk')
@@ -204,6 +242,7 @@ def build_classification_model(
             backend=backend,
             backend_files=model_files,
             device=device,
+            model_cfg=model_cfg,
             deploy_cfg=deploy_cfg,
             data_preprocessor=data_preprocessor,
             **kwargs))
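Taken together, these changes let the backend model rebuild the head described by model_cfg['model']['head'] and branch on it when wrapping raw backend scores into data samples: plain heads take an argmax, while a MultiLabelClsHead marks every class whose score clears head.thr as positive, or falls back to the head.topk highest classes. A self-contained sketch of that branching with plain torch tensors, where dicts stand in for ClsDataSample and thr/topk mirror the head attributes read above:

import torch


def postprocess(cls_score, multi_label=False, thr=None, topk=1):
    # cls_score: (N, num_classes), already softmax/sigmoid activated.
    results = []
    if not multi_label:
        # Single-label: one argmax label per sample.
        labels = cls_score.argmax(dim=1, keepdim=True).detach()
        for score, label in zip(cls_score, labels):
            results.append({'pred_score': score, 'pred_label': label})
    else:
        for score in cls_score:
            if thr is not None:
                # A class is predicted positive if it clears the threshold.
                label = torch.where(score >= thr)[0]
            else:
                # Otherwise the top-k classes are predicted positive.
                _, label = score.topk(topk)
            results.append({'pred_score': score, 'pred_label': label})
    return results


# Multi-label with a 0.5 threshold:
preds = postprocess(torch.sigmoid(torch.randn(2, 5)), multi_label=True, thr=0.5)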
8 changes: 8 additions & 0 deletions mmdeploy/codebase/mmcls/models/classifiers/base.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List, Optional
 
+import torch
 from mmengine.structures import BaseDataElement
 from torch import Tensor
+from torch.nn import functional as F
 
 from mmdeploy.core import FUNCTION_REWRITER
 
@@ -32,4 +34,10 @@ def base_classifier__forward(
     output = self.extract_feat(batch_inputs)
     if self.head is not None:
         output = self.head(output)
+
+        from mmcls.models.heads import MultiLabelClsHead
+        if isinstance(self.head, MultiLabelClsHead):
+            output = torch.sigmoid(output)
+        else:
+            output = F.softmax(output, dim=1)
     return output
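The rewrite above bakes the activation into the exported graph: a single-label head needs a probability distribution over mutually exclusive classes, so softmax normalizes each row to sum to one, while a multi-label head scores every class independently, so sigmoid maps each logit to (0, 1) with no sum constraint. A quick illustration:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])

probs = F.softmax(logits, dim=1)  # rows sum to 1: exactly one class wins
indep = torch.sigmoid(logits)     # each entry in (0, 1): any subset positive

print(probs.sum(dim=1))  # tensor([1.])
print(indep)             # tensor([[0.8808, 0.7311, 0.6225]])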
9 changes: 7 additions & 2 deletions tests/test_codebase/test_mmcls/test_mmcls_models.py
@@ -9,6 +9,10 @@
 from mmdeploy.utils import Backend, Codebase
 from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
 
+try:
+    from torch.testing import assert_close as torch_assert_close
+except Exception:
+    from torch.testing import assert_allclose as torch_assert_close
 try:
     import_codebase(Codebase.MMCLS)
 except ImportError:
@@ -77,6 +81,7 @@ def __init__(self, backbone):
         def extract_feat(self, batch_inputs: torch.Tensor):
             return batch_inputs
 
+    input = torch.rand(1, 1000)
     backbone_cfg = dict(
         type='ResNet',
         depth=18,
@@ -90,8 +95,8 @@ def extract_feat(self, batch_inputs: torch.Tensor):
     with RewriterContext({}):
         backend_output = model(input)
 
-    assert model_output == input
-    assert backend_output == input
+    torch_assert_close(model_output, input)
+    torch_assert_close(backend_output, torch.nn.functional.softmax(input, -1))
 
 
 @pytest.mark.parametrize(
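The test edits address two separate issues: comparing float tensors with == returns an element-wise bool tensor rather than a bool (and the values are no longer bit-identical once a softmax is applied), and torch.testing.assert_close only exists in newer torch releases, hence the import shim that falls back to the older assert_allclose. Roughly why the bare asserts had to go:

import torch

a = torch.rand(1, 3)
print(a == a)  # element-wise: tensor([[True, True, True]]), not a bool
# assert a == a   # raises RuntimeError: Boolean value of Tensor with
#                 # more than one element is ambiguous
torch.testing.assert_close(a, a)  # passes: tolerance-based comparison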
