Run the LLaMA and Mistral RMS Layer Norm in float32 (#1532)
* Run the LLaMA RMS Layer Norm in float32

* Also use float32 in Mistral Layer Norm

* Address review comments

- Change private variables to public vars
- Change `self._weight` to `self.scale`
- Don't persist the input dim
- Move the var computation to its own line for readability

* Change weights to scale in layer norm
tirthasheshpatel committed Mar 29, 2024
1 parent e674fd2 commit 78fdb2d
Showing 2 changed files with 32 additions and 20 deletions.
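
For background on the float32 change: RMS normalization divides activations by the root mean square of their values, and that statistic is fragile in half precision. With float16, whose largest finite value is 65504, squaring moderately large activations already overflows to inf. A minimal numpy sketch (not part of the commit; the values are made up):

import numpy as np

# Activations of moderate magnitude, stored in half precision (made-up values).
x = np.full((8,), 300.0, dtype="float16")

# Squaring in float16 overflows (300 ** 2 = 90000 > 65504), so the mean is inf.
var_fp16 = np.mean(np.square(x), axis=-1)

# Upcasting first, as the updated layers now do, keeps the statistic finite.
var_fp32 = np.mean(np.square(x.astype("float32")), axis=-1)

print(var_fp16, var_fp32)  # inf 90000.0

Computing the statistics in float32 and casting the result back to the compute dtype avoids this while keeping inputs and outputs in the model's working precision.
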
31 changes: 21 additions & 10 deletions keras_nlp/models/llama/llama_layernorm.py
@@ -14,24 +14,35 @@
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 
 
-# TODO: Should be replaced with LayerNormalization with `rms_scaling` param
-# https://github.com/keras-team/keras-core/pull/726
+# TODO: Deprecate this in favor of
+# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is
+# removed.
 class LlamaLayerNorm(keras.layers.Layer):
+    """A normalization layer for Llama that implements RMS normalization."""
+
     def __init__(self, epsilon=1e-6, **kwargs):
         super().__init__(**kwargs)
         self.epsilon = epsilon
 
     def build(self, input_shape):
-        self.weight = self.add_weight(
-            name="weight",
-            shape=(input_shape[-1],),
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
             initializer="ones",
+            dtype=self.variable_dtype,
         )
         self.built = True
 
-    def call(self, hidden_states):
-        variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
-        hidden_states = hidden_states * 1 / ops.sqrt(variance + self.epsilon)
-        return self.weight * hidden_states
+    def call(self, x):
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x, self.compute_dtype) * self.scale
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
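
A hypothetical usage sketch (not part of the diff) of how the rebuilt layer is expected to behave under a mixed-precision dtype policy: the scale weight is created in the layer's variable_dtype, the statistics run in float32 inside call, and the output comes back in the compute_dtype.

import numpy as np
from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm

# Under "mixed_float16", compute_dtype is float16 and variable_dtype is float32.
layer = LlamaLayerNorm(epsilon=1e-6, dtype="mixed_float16")
out = layer(np.random.rand(2, 8).astype("float16"))

print(layer.scale.dtype)  # expected: float32 (the variable dtype)
print(out.dtype)          # expected: float16 (the compute dtype)
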
21 changes: 11 additions & 10 deletions keras_nlp/models/mistral/mistral_layer_norm.py
@@ -23,25 +23,26 @@ class MistralLayerNormalization(keras.layers.Layer):
 
     def __init__(self, epsilon=1e-6, **kwargs):
         super().__init__(**kwargs)
-        self._epsilon = epsilon
+        self.epsilon = epsilon
 
     def build(self, input_shape):
-        self._dim = input_shape[-1]
-        self._weight = self.add_weight(
-            name="weight",
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
             trainable=True,
-            shape=(self._dim,),
+            shape=(dim,),
             initializer="ones",
+            dtype=self.variable_dtype,
         )
         self.built = True
 
     def call(self, x):
-        x = x * ops.rsqrt(
-            ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + self._epsilon
-        )
-        return x * self._weight
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x, self.compute_dtype) * self.scale
 
     def get_config(self):
         config = super().get_config()
-        config.update({"epsilon": self._epsilon})
+        config.update({"epsilon": self.epsilon})
         return config
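
Both layers now share the same call pattern. A standalone sketch of that pattern (assuming Keras 3 style keras.ops; the rms_norm function and its arguments are illustrative, not from the diff):

import numpy as np
from keras import ops

def rms_norm(x, scale, epsilon=1e-6, compute_dtype="float16"):
    # Accumulate the mean of squares in float32 regardless of the input dtype.
    x = ops.cast(x, "float32")
    var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
    x = x * ops.rsqrt(var + epsilon)
    # Cast back to the compute dtype before applying the learned scale.
    return ops.cast(x, compute_dtype) * scale

x = np.random.rand(2, 8).astype("float16")
out = rms_norm(x, scale=np.ones((8,), dtype="float16"))  # float16 out, float32 math inside
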
