add property to new supported ops for counting #parameters
marsggbo committed Apr 1, 2023
1 parent 682cbcb commit 22754e4
Showing 5 changed files with 102 additions and 2 deletions.
8 changes: 8 additions & 0 deletions hyperbox/mutables/ops/embedding.py
@@ -50,6 +50,14 @@ def isSearchEmbedding(self):
        self.search_embedding_dim = True
        return True

    @property
    def params(self):
        weight = self.weight
        if self.search_embedding_dim:
            weight = weight[:, :self.value_spaces['embedding_dim'].value]
        size = weight.numel()
        return size


if __name__ == "__main__":
    from hyperbox.mutator import RandomMutator
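
A minimal plain-PyTorch sketch of the counting logic the new property implements: slice the embedding table to the sampled embedding_dim, then count the remaining elements. The 100-token, 64-dim table and the sampled width of 32 are assumed values for illustration only.

import torch.nn as nn

# Super-net embedding table: 100 tokens x 64 dims (assumed sizes).
emb = nn.Embedding(num_embeddings=100, embedding_dim=64)

# Pretend the mutator sampled embedding_dim = 32 for the current sub-net.
sampled_embedding_dim = 32

# Slice the weight the same way the `params` property does, then count.
active_weight = emb.weight[:, :sampled_embedding_dim]
print(active_weight.numel())  # 100 * 32 = 3200, vs. 6400 in the full table
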
11 changes: 11 additions & 0 deletions hyperbox/mutables/ops/groupnorm.py
@@ -37,6 +37,17 @@ def forward(self, input):
        n_channels = self.value_spaces['num_channels'].value if self.search_num_channels else self.num_channels
        return F.group_norm(input, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)

    @property
    def params(self):
        weight = self.weight
        bias = self.bias
        if self.search_num_channels:
            weight = weight[:self.value_spaces['num_channels'].value]
            bias = bias[:self.value_spaces['num_channels'].value] if bias is not None else None
        parameters = [weight, bias]
        params = sum([p.numel() for p in parameters if p is not None])
        return params


if __name__ == '__main__':
    from hyperbox.mutator import RandomMutator
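
For GroupNorm the count is just the sliced affine weight plus bias. A quick sanity check with plain tensors, assuming 64 channels in the super-net and a sampled num_channels of 32 (both values are made up for the example):

import torch

# Affine parameters of a super-net GroupNorm over 64 channels (assumed).
weight = torch.ones(64)
bias = torch.zeros(64)

# Sampled num_channels for the current sub-net (assumed).
sampled_channels = 32

# Same slice-and-count logic as the new `params` property:
parameters = [weight[:sampled_channels], bias[:sampled_channels]]
print(sum(p.numel() for p in parameters if p is not None))  # 32 + 32 = 64
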
11 changes: 11 additions & 0 deletions hyperbox/mutables/ops/layernorm.py
@@ -86,6 +86,17 @@ def isSearchLayerNorm(self):
        self.search_normalized_shape = True
        return True

    @property
    def params(self):
        weight = self.weight
        bias = self.bias
        if self.search_normalized_shape:
            weight = self.weight[:self.value_spaces["normalized_shape"].value]
            bias = self.bias[:self.value_spaces["normalized_shape"].value] if self.bias is not None else None
        parameters = [weight, bias]
        size = sum([p.numel() for p in parameters if p is not None])
        return size


if __name__ == "__main__":
    # NLP Searchable Example
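
LayerNorm follows the same pattern: for a 1-D normalized_shape the count is twice the sampled width (weight plus bias). A short sketch with an assumed super-net width of 768 and a sampled width of 512:

import torch.nn as nn

ln = nn.LayerNorm(768)   # super-net LayerNorm (assumed width)
sampled_width = 512      # assumed sampled normalized_shape

weight = ln.weight[:sampled_width]
bias = ln.bias[:sampled_width] if ln.bias is not None else None
print(sum(p.numel() for p in (weight, bias) if p is not None))  # 512 + 512 = 1024
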
20 changes: 20 additions & 0 deletions hyperbox/mutables/ops/multihead_attention.py
@@ -368,6 +368,26 @@ def _reset_parameters(self):
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    @property
    def params(self):
        in_proj_weight = self.in_proj_weight
        in_proj_bias = self.in_proj_bias
        out_proj_weight = self.out_proj.weight
        out_proj_bias = self.out_proj.bias
        bias_k = self.bias_k
        bias_v = self.bias_v
        if self.search_embed_dim:
            embed_dim = self.value_spaces['embed_dim'].value
            in_proj_weight = in_proj_weight[:3 * embed_dim, :embed_dim]
            in_proj_bias = in_proj_bias[:3 * embed_dim] if in_proj_bias is not None else None
            out_proj_weight = out_proj_weight[:embed_dim, :embed_dim]
            out_proj_bias = out_proj_bias[:embed_dim] if out_proj_bias is not None else None
            bias_k = bias_k[:, :, :embed_dim] if bias_k is not None else None
            bias_v = bias_v[:, :, :embed_dim] if bias_v is not None else None
        parameters = [in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, bias_k, bias_v]
        size = sum([p.numel() for p in parameters if p is not None])
        return size


if __name__ == '__main__':
    import timeit
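
For MultiheadAttention the property sums the sliced input/output projections plus the optional bias_k/bias_v. Without bias_k/bias_v that works out to 3*d*d + 3*d (in_proj) plus d*d + d (out_proj) for a sampled embed_dim d, which can be checked against a plain nn.MultiheadAttention; d = 256 and 8 heads are assumed values:

import torch.nn as nn

d = 256                                    # assumed sampled embed_dim
mha = nn.MultiheadAttention(embed_dim=d, num_heads=8)

# in_proj: (3d x d) weight + 3d bias; out_proj: (d x d) weight + d bias.
expected = 3 * d * d + 3 * d + d * d + d
actual = sum(p.numel() for p in mha.parameters())
print(expected, actual)                    # both 263168
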
54 changes: 52 additions & 2 deletions hyperbox/utils/calc_model_size.py
@@ -34,14 +34,14 @@ def count_FG_convNd(m, _, y):
    cin = m.value_spaces['in_channels'].value if m.search_in_channel else m.in_channels
    kernel_size = m.value_spaces['kernel_size'].value if m.search_kernel_size else m.kernel_size
    if isinstance(kernel_size, int):
-        dim = int(m.conv_dim[0])
+        dim = int(m.conv_dim)
        kernel_ops = kernel_size**dim
    elif isinstance(kernel_size, (list, tuple)):
        kernel_ops = torch.prod(torch.Tensor(kernel_size))
    ops_per_element = kernel_ops
    groups = m.value_spaces['groups'].value if m.search_groups else m.groups
    output_elements = y.nelement()
-    total_ops = cin * output_elements * ops_per_element // groups  # cout x oW x oH
+    total_ops = torch.div(cin * output_elements * ops_per_element, groups)
    m.total_ops = torch.Tensor([int(total_ops)])
    m.module_used = torch.tensor([1])

@@ -55,6 +55,52 @@ def count_FG_linear(m, _, __):
    m.total_ops = torch.Tensor([int(total_ops)])
    m.module_used = torch.tensor([1])

# Todo: add more FG ops
# def count_FG_Embedding(m, _, __):
#     num_embeddings = m.value_spaces['num_embeddings'].value if m.search_num_embeddings \
#         else m.num_embeddings
#     embedding_dim = m.value_spaces['embedding_dim'].value if m.search_embedding_dim \
#         else m.embedding_dim
#     total_ops = num_embeddings * embedding_dim
#     m.total_ops = torch.Tensor([int(total_ops)])
#     m.module_used = torch.tensor([1])


# def count_FG_GroupNorm(m, _, __):
#     num_channels = m.value_spaces['num_channels'].value if m.search_num_channels \
#         else m.num_channels
#     num_groups = m.value_spaces['num_groups'].value if m.search_num_groups \
#         else m.num_groups
#     total_ops = num_channels * num_groups
#     m.total_ops = torch.Tensor([int(total_ops)])
#     m.module_used = torch.tensor([1])


# def count_FG_LayerNorm(m, _, __):
#     normalized_shape = m.value_spaces['normalized_shape'].value if m.search_normalized_shape \
#         else m.normalized_shape
#     total_ops = torch.prod(torch.Tensor(normalized_shape))
#     m.total_ops = torch.Tensor([int(total_ops)])
#     m.module_used = torch.tensor([1])


# def count_FG_MultiheadAttention(m, _, __):
#     total_ops = 0
#     if m.search_in_features:
#         in_features = m.value_spaces['in_features'].value
#         total_ops += in_features
#     if m.search_out_features:
#         out_features = m.value_spaces['out_features'].value
#         total_ops += out_features
#     if m.search_num_heads:
#         num_heads = m.value_spaces['num_heads'].value
#         total_ops += num_heads
#     if m.search_dropout:
#         dropout = m.value_spaces['dropout'].value
#         total_ops += dropout
#     m.total_ops = torch.Tensor([int(total_ops)])
#     m.module_used = torch.tensor([1])


register_hooks = {
    nn.Conv1d: count_convNd,
@@ -65,6 +111,10 @@ def count_FG_linear(m, _, __):
    hyperbox.mutables.ops.conv.Conv2d: count_FG_convNd,
    hyperbox.mutables.ops.conv.Conv3d: count_FG_convNd,
    hyperbox.mutables.ops.linear.Linear: count_FG_linear,
    # hyperbox.mutables.ops.embedding.Embedding: count_FG_Embedding,
    # hyperbox.mutables.ops.GroupNorm: count_FG_GroupNorm,
    # hyperbox.mutables.ops.LayerNorm: count_FG_LayerNorm,
    # hyperbox.mutables.ops.MultiheadAttention: count_FG_MultiheadAttention,
}


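
The count_FG_convNd change reads conv_dim directly and replaces the integer floor division with torch.div; the counting formula itself (cin * output_elements * kernel_ops / groups) is unchanged. A standalone sketch of that formula with an assumed 3x3 Conv2d, 16 input channels, and a 32x32 input:

import torch
import torch.nn as nn

# Same per-output-element multiply count that count_FG_convNd uses:
#   total_ops = cin * output_elements * kernel_size**dim / groups
conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
x = torch.randn(1, 16, 32, 32)
y = conv(x)                      # shape (1, 32, 32, 32) -> 32768 elements

cin, dim, kernel_size, groups = 16, 2, 3, 1
kernel_ops = kernel_size ** dim  # 9 multiplies per output element per input channel
total_ops = cin * y.nelement() * kernel_ops // groups
print(total_ops)                 # 16 * 32768 * 9 = 4718592
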
