Skip to content

Commit

Permalink
[Enhance] Rewrite channel split operation in ShufflenetV2 (#632)
Browse files Browse the repository at this point in the history
* replace chunk op

* shufflenetv2 config
  • Loading branch information
Ezra-Yu committed Jan 25, 2022
1 parent b5bd87d commit bd397f7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion mmcls/models/backbones/shufflenet_v2.py
Expand Up @@ -115,7 +115,14 @@ def _inner_forward(x):
if self.stride > 1:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
else:
x1, x2 = x.chunk(2, dim=1)
# Channel Split operation. using these lines of code to replace
# ``chunk(x, 2, dim=1)`` can make it easier to deploy a
# shufflenetv2 model by using mmdeploy.
channels = x.shape[1]
c = channels // 2 + channels % 2
x1 = x[:, :c, :, :]
x2 = x[:, c:, :, :]

out = torch.cat((x1, self.branch2(x2)), dim=1)

out = channel_shuffle(out, 2)
Expand Down

0 comments on commit bd397f7

Please sign in to comment.