Skip to content

Commit

Permalink
fixbug of ops.conv
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Apr 27, 2023
1 parent 492dafb commit bac0a1f
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions hyperbox/mutables/ops/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,21 @@ def __init__(
nn.Conv2d.__init__(self, **self.conv_kwargs)
self.init_kernel_transform_matrix(kernel_size)

def forward(self, x):
out = None
if not self.is_search:
padding = self.padding
if self.auto_padding:
kernel_size = self.weight.shape[2:]
padding = []
for k in kernel_size:
padding.append(k//2)
self.padding = padding
out = nn.Conv2d.forward(self, x)
else:
out = self.forward_conv(x)
return out

def format_args(
self,
kernel_size: _size_2_t,
Expand Down Expand Up @@ -398,6 +413,21 @@ def __init__(
nn.Conv3d.__init__(self, **self.conv_kwargs)
self.init_kernel_transform_matrix(kernel_size)

def forward(self, x):
out = None
if not self.is_search:
padding = self.padding
if self.auto_padding:
kernel_size = self.weight.shape[2:]
padding = []
for k in kernel_size:
padding.append(k//2)
self.padding = padding
out = nn.Conv3d.forward(self, x)
else:
out = self.forward_conv(x)
return out

def format_args(
self,
kernel_size: _size_3_t,
Expand Down

0 comments on commit bac0a1f

Please sign in to comment.