Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Merge pull request #22 from wkcn/fix_conv_bn_fuse
Browse files Browse the repository at this point in the history
[fix] Fix Conv2d_BN fuse bug when groups > 1
  • Loading branch information
btgraham committed Mar 25, 2022
2 parents d373b82 + 3bc3ec2 commit d000b74
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion levit.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def fuse(self):
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1), w.size(
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
Expand Down
2 changes: 1 addition & 1 deletion levit_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def fuse(self):
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1), w.size(
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
Expand Down

0 comments on commit d000b74

Please sign in to comment.