Skip to content

Commit

Permalink
Add option to set channels_per_group_list on ResNet.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 274159409
Change-Id: If31b33c236c5153aa1ad65b87293dfbca44674a7
  • Loading branch information
Sonnet Contributor authored and sonnet-copybara committed Oct 11, 2019
1 parent 1833bf6 commit 4cdd392
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 11 additions & 1 deletion sonnet/src/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(self,
num_classes: int,
bn_config: Optional[Mapping[Text, float]] = None,
resnet_v2: bool = False,
channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
name: Optional[Text] = None):
"""Constructs a ResNet model.
Expand All @@ -245,6 +246,8 @@ def __init__(self,
`0.9` and `eps` is `1e-5`.
resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
False.
channels_per_group_list: A sequence of length 4 that indicates the number
of channels used for each block in each group.
name: Name of the module.
"""
super(ResNet, self).__init__(name=name)
Expand All @@ -264,6 +267,13 @@ def __init__(self,
len(blocks_per_group_list)))
self._blocks_per_group_list = blocks_per_group_list

# Number of channels in each group for ResNet.
if len(channels_per_group_list) != 4:
raise ValueError(
"`channels_per_group_list` must be of length 4 not {}".format(
len(channels_per_group_list)))
self._channels_per_group_list = channels_per_group_list

self._initial_conv = conv.Conv2D(
output_channels=64,
kernel_shape=7,
Expand All @@ -283,7 +293,7 @@ def __init__(self,
for i in range(4):
self._block_groups.append(
BlockGroup(
channels=256 * 2**i,
channels=self._channels_per_group_list[i],
num_blocks=self._blocks_per_group_list[i],
stride=strides[i],
bn_config=bn_config,
Expand Down
12 changes: 11 additions & 1 deletion sonnet/src/nets/resnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,23 @@ def test_tf_function(self, resnet_v2):
self.assertAllEqual(model(image, is_training=True).numpy(), logits.numpy())

@parameterized.parameters(3, 5)
def test_error_incorrect_args(self, list_length):
def test_error_incorrect_args_block_list(self, list_length):
block_list = [i for i in range(list_length)]
with self.assertRaisesRegexp(
ValueError, "blocks_per_group_list` must be of length 4 not {}".format(
list_length)):
resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})

@parameterized.parameters(3, 5)
def test_error_incorrect_args_channel_list(self, list_length):
channel_list = [i for i in range(list_length)]
with self.assertRaisesRegexp(
ValueError,
"channels_per_group_list` must be of length 4 not {}".format(
list_length)):
resnet.ResNet([1, 1, 1, 1], 10, {"decay_rate": 0.9, "eps": 1e-5},
channels_per_group_list=channel_list)

def test_v2_throws(self):
resnet.TESTONLY_ENABLE_RESNET_V2 = False
with self.assertRaisesRegexp(NotImplementedError, "please use v1"):
Expand Down

0 comments on commit 4cdd392

Please sign in to comment.