Skip to content

Commit

Permalink
Fix ConvNext V2 parameter naming issue (#23122)
Browse files Browse the repository at this point in the history
Fixes the parameter naming issue in ConvNextV2GRN module
  • Loading branch information
alaradirik committed May 3, 2023
1 parent b53004f commit 56b8d49
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def rename_key(name):
if "stages" in name and "downsampling_layer" not in name:
# stages.0.0. for instance should be renamed to stages.0.layers.0.
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
if "gamma" in name:
name = name.replace("gamma", "weight")
if "beta" in name:
name = name.replace("beta", "bias")
if "stages" in name:
name = name.replace("stages", "encoder.stages")
if "norm" in name:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/convnextv2/modeling_convnextv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ class ConvNextV2GRN(nn.Module):

def __init__(self, dim: int):
    """Initialize the GRN (Global Response Normalization) layer.

    Args:
        dim: Number of channels; the learnable scale and shift are created
            with shape ``(1, 1, 1, dim)`` so they broadcast over a
            channels-last ``(batch, height, width, channels)`` input.
    """
    super().__init__()
    # Named `weight`/`bias` (not the paper's `gamma`/`beta`) so the
    # parameters follow the standard nn.Module naming convention; both
    # start at zero, making the layer an identity at initialization
    # (see forward: weight * (...) + bias + residual).
    self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
    self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))

def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
    """Apply Global Response Normalization with a residual connection.

    Args:
        hidden_states: Input tensor; indexing over ``dim=(1, 2)`` implies a
            channels-last layout ``(batch, height, width, channels)``.

    Returns:
        Tensor of the same shape as ``hidden_states``.
    """
    # L2 norm over the spatial dims (1, 2); keepdim preserves the rank so
    # the result broadcasts back over the input.
    global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)
    # Normalize by the mean response across channels; the 1e-6 epsilon
    # guards against division by zero.
    norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
    # Learned affine (weight/bias — renamed from gamma/beta in this commit)
    # on the calibrated features, plus an identity residual.
    hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states

    return hidden_states

Expand Down

0 comments on commit 56b8d49

Please sign in to comment.