Gemma bug fixes - Approx GELU, Layernorms, Sqrt(hd) #29402

Closed · wants to merge 13 commits
7 changes: 4 additions & 3 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -57,8 +57,9 @@ class GemmaConfig(PretrainedConfig):
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
-     The non-linear activation function (function or string) in the decoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+     The non-linear activation function (function or string) in the decoder. "gelu_pytorch_tanh" uses an
+     approximation of the more exact "gelu" activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
@@ -108,7 +109,7 @@ def __init__(
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu",
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
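For reference (not part of the diff): exact GELU computes x * Φ(x) with the Gaussian CDF, while "gelu_pytorch_tanh" computes the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), which is what Gemma expects. A minimal PyTorch sketch comparing the two (assumes only torch; not taken from this PR):

    import torch

    x = torch.linspace(-4.0, 4.0, steps=9)

    # Exact GELU: x * Phi(x), computed via erf.
    gelu_exact = torch.nn.GELU()

    # Tanh approximation; this is what transformers' ACT2FN["gelu_pytorch_tanh"] resolves to.
    gelu_tanh = torch.nn.GELU(approximate="tanh")

    # The two curves are close but not identical, so using the wrong one shifts Gemma's activations.
    print((gelu_exact(x) - gelu_tanh(x)).abs().max())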
11 changes: 10 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -169,7 +169,16 @@ def __init__(self, config):
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
+ hidden_act = config.hidden_act
+ if hidden_act != "gelu_pytorch_tanh":
Contributor:

This way, even if a model has the old gelu in its config, we're force-setting hidden_act to "gelu_pytorch_tanh", right?
I think we should either use a new config name or create a new attribute in the config, force_use_exact_gelu, initialized to False, so that users keep the flexibility to switch back to the old activation function in case they fine-tuned with the old GeLU. What do you think?

Contributor Author:

Hmm, I like that approach, i.e. getattr(config, "force_use_exact_gelu", False): it returns True when the attribute is set to True, False when it is set to False, and False when the attribute is absent.

Contributor:

Yes! I think we can add that directly to the GemmaConfig class and default it to False.

Contributor Author:

I added force_use_exact_gelu!

+     logger.warning_once(
+         "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"
+         "Please edit your model config to use `gelu_pytorch_tanh` and not `gelu`.\n"
+         "For now, we shall use `gelu_pytorch_tanh` temporarily.\n"
+         "See https://github.com/huggingface/transformers/pull/29402 for more details."
+     )
+     hidden_act = "gelu_pytorch_tanh"
+ self.act_fn = ACT2FN[hidden_act]
Collaborator:

I don't mind automatically switching, but it's best if users still have a way to use the legacy gelu! Either a big warning or another config name.

Collaborator:

So we need a self.hidden_activation attribute that defaults to None: if it is None, warn that we will use the new approximate GeLU; otherwise, use whatever was given.

Contributor Author:

Oh OK, good point! Sorry, I didn't work on this in the meantime. I found a few more issues and will push them here tomorrow :)


def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
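Following the thread above, here is a rough sketch of the None-defaulting config attribute the reviewers suggest; the attribute name hidden_activation and the warning text are illustrative assumptions based on the comments, not the code merged in this PR:

    from torch import nn
    from transformers.activations import ACT2FN
    from transformers.utils import logging

    logger = logging.get_logger(__name__)


    class GemmaMLP(nn.Module):
        # Sketch only: a config attribute (assumed name `hidden_activation`) defaulting to None,
        # with a warn-and-fallback to the approximate GeLU when it is unset.
        def __init__(self, config):
            super().__init__()
            self.hidden_size = config.hidden_size
            self.intermediate_size = config.intermediate_size
            self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
            hidden_activation = getattr(config, "hidden_activation", None)
            if hidden_activation is None:
                logger.warning_once(
                    "`config.hidden_activation` is not set; falling back to the approximate GeLU "
                    "(`gelu_pytorch_tanh`) that Gemma expects. Set `hidden_activation` explicitly "
                    "to override this."
                )
                hidden_activation = "gelu_pytorch_tanh"
            self.act_fn = ACT2FN[hidden_activation]

        def forward(self, x):
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

With an approach like this, users who fine-tuned with the exact GeLU could keep their behavior by setting the attribute to "gelu" in their config, while unmodified configs get the warning and the approximate activation.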