
Run the LLaMA and Mistral RMS Layer Norm in float32 #1532

Merged: 4 commits merged into keras-team:master on Mar 29, 2024

Conversation

@tirthasheshpatel (Contributor) commented Mar 29, 2024:

The LLaMA and Mistral RMS Layer Norm should always run in float32; this PR fixes that in our implementation.
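For context, the core of the fix is to cast the inputs to float32 before computing the RMS statistics and cast the result back afterwards. A minimal sketch of the idea (the helper name, `scale` argument, and `compute_dtype` default are illustrative, not the repo's exact API):

    from keras import ops

    def rms_norm_in_float32(x, scale, epsilon=1e-6, compute_dtype="bfloat16"):
        # Compute the RMS statistics in float32 so mean/rsqrt stay numerically
        # stable even when the model runs in float16/bfloat16 mixed precision.
        x = ops.cast(x, "float32")
        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
        x = x * ops.rsqrt(var + epsilon)
        # Cast the scaled, normalized output back to the model's compute dtype.
        return ops.cast(x * ops.cast(scale, "float32"), compute_dtype)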

@tirthasheshpatel changed the title from "Run the LLaMA RMS Layer Norm in float32" to "Run the LLaMA and Mistral RMS Layer Norm in float32" on Mar 29, 2024
@mattdangerw (Member) left a comment

Looks good! Style nits.


def build(self, input_shape):
self.weight = self.add_weight(
self._dim = input_shape[-1]
@mattdangerw (Member):

I don't think self._dim needs to be persisted anywhere, right? Why not just dim = input_shape[-1]?
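That is, something like this (a sketch using the pre-rename names in this diff; the initializer is assumed):

    def build(self, input_shape):
        dim = input_shape[-1]  # a local is enough; nothing reads it later
        self._weight = self.add_weight(
            shape=(dim,), initializer="ones", trainable=True
        )
        self.built = True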

     def __init__(self, epsilon=1e-6, **kwargs):
         super().__init__(**kwargs)
    -    self.epsilon = epsilon
    +    self._epsilon = epsilon
@mattdangerw (Member):

Try to keep these init args as public attrs. Update mistral instead.

    def call(self, x):
        x = ops.cast(x, "float32")
        x = x * ops.rsqrt(
            ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + self._epsilon
        )
@mattdangerw (Member):

nit: use an intermediate variable to keep each statement on one line and readable, e.g.

    var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
    x = x * ops.rsqrt(var + self.epsilon)


     def build(self, input_shape):
    -    self.weight = self.add_weight(
    +    self._dim = input_shape[-1]
    +    self._weight = self.add_weight(
@mattdangerw (Member):

Keep this public too. Update mistral instead.

Do our checkpoints still load fine if we call this scale? And name it "scale"? That's what Gemma does, and it's a better name.
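Roughly this, if the rename works out (a sketch; name="scale" mirrors Gemma, and the initializer is assumed):

    def build(self, input_shape):
        dim = input_shape[-1]
        self.scale = self.add_weight(
            name="scale",
            shape=(dim,),
            initializer="ones",
            trainable=True,
        )
        self.built = True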

@tirthasheshpatel (Contributor, Author):

> Do our checkpoints still load fine if we call this scale? And name it scale?

I think Keras should still work when the Python variable name changes. AFAIK, Keras loads weights using the name field of the variable, so changing that field is what would break loading.

@mattdangerw (Member):

I believe it is actually order-based. Did you try it out? https://github.com/keras-team/keras/blob/97b082dfee2552fcad1a7c7ea0fac9c72943360c/keras/layers/layer.py#L1187-L1188

It's not the end of the world. It's just that for things like the optimize call here, we actually reference the scale name.
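One quick, if indirect, way to see that matching is positional rather than tied to the Python attribute name is to move weights between two toy layers whose attributes are named differently (hypothetical layers; the real checkpoint path goes through load_weights, which the linked code covers):

    import keras

    class OldNorm(keras.layers.Layer):
        def build(self, input_shape):
            # old attribute name
            self._weight = self.add_weight(shape=(input_shape[-1],), initializer="ones")

    class NewNorm(keras.layers.Layer):
        def build(self, input_shape):
            # new attribute name
            self.scale = self.add_weight(shape=(input_shape[-1],), initializer="zeros")

    old, new = OldNorm(), NewNorm()
    old.build((None, 8))
    new.build((None, 8))
    new.set_weights(old.get_weights())  # matched by position, not attribute name
    print(new.get_weights()[0])  # all ones, copied from `old`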

@tirthasheshpatel (Contributor, Author):

Oh cool. Thanks for the references. Will try to change and check if I can still load the preset!

@tirthasheshpatel (Contributor, Author) commented Mar 29, 2024:

I verified that the weights load with the name change, so I've pushed the changes. Let's see if the CI is also happy with it. Thanks!

- Change private variables to public vars
- Change `self._weight` to `self.scale`
- Don't persist the input dim
- Move the var computation to its own line for readability
@tirthasheshpatel (Contributor, Author):

@mattdangerw Addressed the review comments. Let me know if the diff looks good to you now!

@mattdangerw (Member):

Looks good besides that one potential name change. Thanks!

@mattdangerw mattdangerw merged commit 78fdb2d into keras-team:master Mar 29, 2024
10 checks passed
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request Apr 2, 2024
* Run the LLaMA RMS Layer Norm in float32

* Also use float32 in Mistral Layer Norm

* Address review comments

- Change private variables to public vars
- Change `self._weight` to `self.scale`
- Don't persist the input dim
- Move the var computation to its own line for readability

* Change weights to scale in layer norm