Run the LLaMA and Mistral RMS Layer Norm in float32 #1532
Conversation
Looks good! Style nits.
```diff
 def build(self, input_shape):
+    self._dim = input_shape[-1]
     self.weight = self.add_weight(
```
I don't think `self._dim` needs to be persisted anywhere, right? Why not just `dim = input_shape[-1]`?
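E.g., a minimal sketch of what that would look like (the exact weight arguments here are assumptions, not the PR's diff):

```python
def build(self, input_shape):
    # The feature dim is only needed while creating the weight,
    # so a local variable is enough; nothing has to live on `self`.
    dim = input_shape[-1]
    self.weight = self.add_weight(
        name="weight",
        shape=(dim,),
        initializer="ones",
        trainable=True,
    )
    self.built = True
```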
```diff
 def __init__(self, epsilon=1e-6, **kwargs):
     super().__init__(**kwargs)
-    self.epsilon = epsilon
+    self._epsilon = epsilon
```
Try to keep these init args as public attrs. Update mistral instead.
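I.e., something like (sketch):

```python
def __init__(self, epsilon=1e-6, **kwargs):
    super().__init__(**kwargs)
    # Public attribute, mirroring the constructor argument.
    self.epsilon = epsilon
```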
```python
def call(self, x):
    x = ops.cast(x, "float32")
    x = x * ops.rsqrt(
        ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + self._epsilon
    )
```
nit: use intermediate variables to keep things on one line and readable, e.g.

```python
var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
x = x * ops.rsqrt(var + self.epsilon)
```
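Putting the cast and that refactor together, the whole `call` might read roughly like this (a sketch; the cast back to the input dtype is an assumption about how the rest of the layer behaves, not something shown in the diff):

```python
def call(self, x):
    input_dtype = x.dtype
    # Compute the RMS statistics in float32 for numerical stability.
    x = ops.cast(x, "float32")
    var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
    x = x * ops.rsqrt(var + self.epsilon)
    # Cast back so downstream layers see the dtype they expect.
    return ops.cast(x, input_dtype) * self.weight
```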
```diff
 def build(self, input_shape):
+    self._dim = input_shape[-1]
-    self.weight = self.add_weight(
+    self._weight = self.add_weight(
```
Keep this public too. Update mistral instead.
Do our checkpoints still load fine if we call this `scale`? And name it `scale`? That's what Gemma does, and it's a better name.
> Do our checkpoints still load fine if we call this `scale`? And name it `scale`?

I think Keras should work when the variable name in Python changes. AFAIK, Keras loads weights using the `name` field of the variable, so changing that would break loading.
I believe it is actually order-based. Did you try it out? https://github.com/keras-team/keras/blob/97b082dfee2552fcad1a7c7ea0fac9c72943360c/keras/layers/layer.py#L1187-L1188
It's not the end of the world. Just for things like the optimize call here, we actually reference the `scale` name.
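For illustration, an order-based transfer with `get_weights`/`set_weights` ignores variable names entirely. A minimal sketch below; `OldNorm`/`NewNorm` are hypothetical stand-ins for the layer before and after the rename, and this doesn't exercise the saved-checkpoint path:

```python
import numpy as np
from keras import layers

class OldNorm(layers.Layer):
    def build(self, input_shape):
        self.weight = self.add_weight(
            name="weight", shape=(input_shape[-1],), initializer="ones"
        )

    def call(self, x):
        return x * self.weight

class NewNorm(layers.Layer):
    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones"
        )

    def call(self, x):
        return x * self.scale

x = np.ones((2, 4), "float32")
old, new = OldNorm(), NewNorm()
old(x), new(x)  # build both layers
# Positional matching: names don't matter, only order and shape.
new.set_weights(old.get_weights())
np.testing.assert_allclose(old(x), new(x))
```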
Oh cool. Thanks for the references. Will try to change and check if I can still load the preset!
I verified that weights load with the name change, so changes pushed. Let's see if the CI is also happy with it. Thanks!
- Change private variables to public vars
- Change `self._weight` to `self.scale`
- Don't persist the input dim
- Move the var computation to its own line for readability
@mattdangerw Addressed the review comments. Let me know if the diff looks good to you now!
Looks good besides that one potential name change. Thanks!
* Run the LLaMA RMS Layer Norm in float32
* Also use float32 in Mistral Layer Norm
* Address review comments
  - Change private variables to public vars
  - Change `self._weight` to `self.scale`
  - Don't persist the input dim
  - Move the var computation to its own line for readability
* Change weights to scale in layer norm
LLaMA and Mistral Layer Norm should always run in `float32`. This PR corrects this bug in our implementation.
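For reference, a rough sketch of the layer with all of the review feedback folded in (reconstructed from the diff excerpts above; the exact argument names and the final cast are assumptions, not the merged code):

```python
from keras import layers, ops

class LlamaLayerNorm(layers.Layer):
    """RMS layer norm that always computes its statistics in float32."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon  # public, matching the constructor arg

    def build(self, input_shape):
        dim = input_shape[-1]  # local only; not persisted on the instance
        self.scale = self.add_weight(
            name="scale",
            shape=(dim,),
            initializer="ones",
            trainable=True,
        )
        self.built = True

    def call(self, x):
        input_dtype = x.dtype
        # Compute the RMS statistics in float32 regardless of compute dtype.
        x = ops.cast(x, "float32")
        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
        x = x * ops.rsqrt(var + self.epsilon)
        return ops.cast(x, input_dtype) * self.scale
```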