Gemma bug fixes - Approx GELU, Layernorms, Sqrt(hd) #29402
Changes from 2 commits:

```diff
@@ -169,7 +169,16 @@ def __init__(self, config):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        hidden_act = config.hidden_act
+        if hidden_act != "gelu_pytorch_tanh":
+            logger.warning_once(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU.\n"\
+                "Please edit your model config to use `gelu_pytorch_tanh` and not `gelu`.\n"\
+                "For now, we shall use `gelu_pytorch_tanh` temporarily.\n"\
+                "See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_act = "gelu_pytorch_tanh"
+        self.act_fn = ACT2FN[hidden_act]

     def forward(self, x):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
```
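For context on why this switch matters: the exact (erf-based) GeLU and the tanh approximation that Gemma was trained with produce close but not identical outputs, and the mismatch compounds across layers. A minimal sketch (not part of the PR) comparing the two variants through transformers' `ACT2FN` mapping:

```python
import torch
from transformers.activations import ACT2FN

x = torch.linspace(-4.0, 4.0, steps=101)
exact = ACT2FN["gelu"](x)                # erf-based GeLU: 0.5 * x * (1 + erf(x / sqrt(2)))
approx = ACT2FN["gelu_pytorch_tanh"](x)  # tanh approximation used by Gemma

# Small but nonzero disagreement between the two activations.
print((exact - approx).abs().max())
```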
Review discussion:

Comment: I don't mind automatically switching, but it's best if users still have a way to use the legacy GeLU! Either a big warning or another config name.

Comment: So we need a […]

Comment: Oh ok, good point! Sorry, I didn't get to work on this in the meantime. I found a few more issues and will push them here tomorrow :)

Comment: This way, even if a model config still specifies the old GeLU, we're force-setting hidden_act to "gelu_pytorch_tanh", right? I think we should either use a new config name or add a new attribute to the config, force_use_exact_gelu, initialized to False, so that users who fine-tuned with the old GeLU keep the flexibility to switch back to it. What do you think?

Comment: Hmm, I like that approach, i.e. getattr(config, "force_use_exact_gelu", False): it is True when force_use_exact_gelu = True, False when force_use_exact_gelu = False, and False when the attribute is absent.

Comment: Yes! I think we can add that directly to the GemmaConfig class and default it to False.

Comment: I added force_use_exact_gelu!
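A minimal sketch of what the opt-out discussed above could look like, assuming a force_use_exact_gelu attribute on the config as proposed in this thread; this is an illustration, not the exact code that was merged:

```python
from types import SimpleNamespace
from transformers.activations import ACT2FN

def resolve_act_fn(config):
    # Switch legacy "gelu" configs to the tanh approximation unless the user
    # explicitly opts out via the (proposed) force_use_exact_gelu flag.
    hidden_act = config.hidden_act
    if hidden_act != "gelu_pytorch_tanh" and not getattr(config, "force_use_exact_gelu", False):
        hidden_act = "gelu_pytorch_tanh"
    return ACT2FN[hidden_act]

legacy = SimpleNamespace(hidden_act="gelu", force_use_exact_gelu=True)  # keeps exact GeLU
default = SimpleNamespace(hidden_act="gelu")                            # auto-switched

print(type(resolve_act_fn(legacy)).__name__)   # GELUActivation
print(type(resolve_act_fn(default)).__name__)  # PytorchGELUTanh
```

Defaulting the flag to False preserves the new behavior (auto-switch plus warning) for everyone, while users who fine-tuned with exact GeLU can set it to True to keep their checkpoints faithful.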