Skip to content

Commit

Permalink
refactor resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Jan 31, 2023
1 parent 141c1bd commit 98be324
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions hyperbox/networks/resnet/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from hyperbox.mutables.ops import Conv2d, Linear, BatchNorm2d
from hyperbox.mutables.spaces import ValueSpace
from hyperbox.utils.utils import load_json, hparams_wrapper

from hyperbox.networks.base_nas_network import BaseNASNetwork


__all__ = [
"ResNet",
"resnet18",
Expand Down Expand Up @@ -267,19 +267,7 @@ def load_state_dict(self, state_dict, **kwargs):
model_dict[key] = state_dict[key]
super(ResNet, self).load_state_dict(model_dict, **kwargs)

def build_subnet(self, mask):
hparams = self.hparams.copy()
new_mask = {}
len_mask = len(mask)
for key in mask:
_id = key.split('ValueSpace')[-1]
new_id = int(_id) + len_mask * self.counter_subnet
new_key = f"ValueSpace{new_id}"
new_mask[new_key] = mask[key].clone().detach()
hparams['mask'] = new_mask
subnet = ResNet(**hparams)
self.counter_subnet += 1
return subnet


def load_from_supernet(self, state_dict, **kwargs):
def sub_filter_start_end(kernel_size, sub_kernel_size):
Expand Down Expand Up @@ -371,9 +359,13 @@ def resnet50(pretrained=False, progress=True, device="cpu", **kwargs):
# rm = DartsMutator(net) # ValueSpace-based operations are not compatible with DartsMutator
# rm = OnehotMutator(net)
x = torch.rand(2, 3, 32, 32)
for i in range(10):
for i in range(2):
rm.reset()
y = net(x)
print(y.shape)
arch = net.arch
print(arch, len(rm._cache))
mask = rm.export()
subnet = net.build_subnet(mask)
y = subnet(x)
print(subnet.arch)

0 comments on commit 98be324

Please sign in to comment.