Skip to content

Commit

Permalink
1.refactor transform_kernel_size for ops.conv; 2.refactor max_value a…
Browse files Browse the repository at this point in the history
…nd min_value for CategoricalSpace; 3. refactor load_from_supernet for base_nas_network
  • Loading branch information
marsggbo committed Apr 27, 2023
1 parent bac0a1f commit 580457d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 13 deletions.
28 changes: 23 additions & 5 deletions hyperbox/mutables/ops/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,23 @@ def get_filters_by_groups(self, filters, in_channels, groups):
filters = torch.cat(filter_crops, dim=0)
return filters

def transform_kernel_size(self, filters):
def transform_kernel_size(self, filters, sub_kernel_size=None, max_kernel_size=None):
if self.KERNEL_TRANSFORM_MODE is None:
# print('vanilla transform_kernel_size')
sub_kernel_size = self.value_spaces['kernel_size'].value
if sub_kernel_size is None:
sub_kernel_size = self.value_spaces['kernel_size'].value
start, end = sub_filter_start_end(self.kernel_size, sub_kernel_size)
if self.conv_dim==1: filters = filters[:, :, start:end]
if self.conv_dim==2: filters = filters[:, :, start:end, start:end]
if self.conv_dim==3: filters = filters[:, :, start:end, start:end, start:end]
else:
max_kernel_size = self.kernel_size
if max_kernel_size is None:
max_kernel_size = self.kernel_size
if isinstance(max_kernel_size, (tuple, list)):
max_kernel_size = max(max_kernel_size)
sub_kernel_size = self.value_spaces['kernel_size'].value
ks_set = self.value_spaces['kernel_size'].candidates
if sub_kernel_size is None:
sub_kernel_size = self.value_spaces['kernel_size'].value
ks_set = self.value_spaces['kernel_size'].candidates_original
if sub_kernel_size < max_kernel_size:
start_filter = filters
for i in range(len(ks_set)-1, 0, -1):
Expand Down Expand Up @@ -496,3 +499,18 @@ def format_args(
# print(y.shape)
end = time()
print(f"testing 3d {op}, cost {end-start:.2f} s")

from hyperbox.networks.ofa import OFAMobileNetV3
from hyperbox.mutator import RandomMutator
net = OFAMobileNetV3()
rm = RandomMutator(net)
mask = rm.reset()
subnet = net.build_subnet(mask=mask)
net.eval()
subnet.eval()
with torch.no_grad():
for i in range(10):
x = torch.rand(2,3,32,32)
y1 = net(x)
y2 = subnet(x)
assert torch.allclose(y1, y2), 'error'
4 changes: 4 additions & 0 deletions hyperbox/mutables/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def value(self):

@property
def max_value(self):
if not self.is_search:
return self.value
try:
value = max(self.candidates_original)
return value
Expand All @@ -228,6 +230,8 @@ def max_value(self):

@property
def min_value(self):
if not self.is_search:
return self.value
try:
value = min(self.candidates_original)
return value
Expand Down
22 changes: 14 additions & 8 deletions hyperbox/networks/base_nas_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,20 @@ def sub_filter_start_end(kernel_size, sub_kernel_size):
if dim >= 3:
# e.g., conv weight
_out, _in, k = shape[:3]
k_larger = state_dict[key].shape[-1]
start, end = sub_filter_start_end(k_larger, k)
if dim == 3: # conv1d
model_dict[key].data = state_dict[key].data[:_out, :_in, start:end]
elif dim == 4: #conv2d
model_dict[key].data = state_dict[key].data[:_out, :_in, start:end, start:end]
else:
model_dict[key].data = state_dict[key].data[:_out, :_in, start:end, start:end, start:end]
weight = state_dict[key].data[:_out, :_in, ...]
if hasattr(module.value_spaces, 'kernel_size'):
ks_set = module.value_spaces['kernel_size'].candidates_original
for i in range(len(ks_set)-1, 0, -1):
src_ks = ks_set[i]
if src_ks <= k:
# if src_ks <= k, then no transformation is needed
break
target_ks = ks_set[i - 1]
transform_name = f"{src_ks}to{target_ks}_matrix"
matrix_name = f"{name}.{transform_name}"
getattr(module, transform_name).data = state_dict[matrix_name].data
weight = module.transform_kernel_size(weight, k, ks_set[-1])
model_dict[key].data = weight.data
super(BaseNASNetwork, self).load_state_dict(model_dict, **kwargs, strict=False)

def init_weights(self):
Expand Down

0 comments on commit 580457d

Please sign in to comment.