26 changes: 21 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
@@ -52,6 +52,10 @@ class ReversibleEmbedding(keras.layers.Embedding):
"padding" value that should be masked out.
reverse_dtype: The dtype for the reverse projection computation.
Defaults to the `compute_dtype` of the layer.
logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
output logits will be scaled by
`tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
range of output logits and can improve training.
**kwargs: other keyword arguments passed to `keras.layers.Embedding`,
including `name`, `trainable`, `dtype` etc.

@@ -93,6 +97,7 @@ def __init__(
embeddings_constraint=None,
mask_zero=False,
reverse_dtype=None,
logit_soft_cap=None,
**kwargs,
):
super().__init__(
@@ -106,6 +111,7 @@ def __init__(
)
self.tie_weights = tie_weights
self.reverse_dtype = reverse_dtype
self.logit_soft_cap = logit_soft_cap

def build(self, inputs_shape=None):
super().build(inputs_shape)
@@ -129,7 +135,12 @@ def call(self, inputs, reverse=False):
if self.reverse_dtype is not None:
inputs = ops.cast(inputs, self.reverse_dtype)
kernel = ops.cast(kernel, self.reverse_dtype)
return ops.matmul(inputs, kernel)
logits = ops.matmul(inputs, kernel)
# Optionally soft-cap logits.
if self.logit_soft_cap is not None:
soft_cap = self.logit_soft_cap
logits = ops.tanh(logits / soft_cap) * soft_cap
return logits

return super().call(inputs)

@@ -139,6 +150,7 @@ def get_config(self):
{
"tie_weights": self.tie_weights,
"reverse_dtype": self.reverse_dtype,
"logit_soft_cap": self.logit_soft_cap,
}
)
return config
@@ -227,11 +239,15 @@ def _int8_call(self, inputs, reverse=False):
kernel = self.reverse_embeddings
scale = self.reverse_embeddings_scale
inputs, inputs_scale = self.inputs_quantizer(inputs)
outputs = ops.matmul(inputs, kernel)
logits = ops.matmul(inputs, kernel)
# De-scale logits.
outputs = ops.cast(outputs, self.compute_dtype)
outputs = ops.divide(outputs, ops.multiply(inputs_scale, scale))
return outputs
logits = ops.cast(logits, self.compute_dtype)
logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
# Optionally soft-cap logits.
if self.logit_soft_cap is not None:
soft_cap = self.logit_soft_cap
logits = ops.tanh(logits / soft_cap) * soft_cap
return logits

return super()._int8_call(inputs)

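For reference, the capping added to `call()` and `_int8_call()` above is a smooth, element-wise squashing of the reverse-projection logits into the interval (-`logit_soft_cap`, `logit_soft_cap`). A minimal standalone sketch of the same math (plain NumPy, not the layer itself):

```python
import numpy as np

def soft_cap(logits, cap):
    # tanh(logits / cap) * cap: near-identity for |logits| << cap,
    # saturating smoothly toward +/-cap for large magnitudes.
    return np.tanh(logits / cap) * cap

x = np.array([-500.0, -5.0, 0.0, 5.0, 500.0])
print(soft_cap(x, 50.0))  # ~[-50.  -4.98  0.    4.98  50.]
```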
7 changes: 7 additions & 0 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
@@ -40,6 +40,7 @@ def test_layer_behaviors_tied(self, tie_weights):
"output_dim": 32,
"tie_weights": tie_weights,
"embeddings_initializer": "HeNormal",
"logit_soft_cap": 50,
},
input_data=random.randint(minval=0, maxval=100, shape=(4, 10)),
expected_output_shape=(4, 10, 32),
@@ -81,6 +82,12 @@ def test_correctness(self):
out = layer(np.array(([[1.0, 1.0]])), reverse=True)
self.assertAllClose(out, np.array([[0.0, 4.0, 6.0]]))

layer = ReversibleEmbedding(input_dim=3, output_dim=2, logit_soft_cap=5)
layer.build()
layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))
out = layer(np.array(([[1.0, 1.0]])), reverse=True)
self.assertAllClose(out, np.array([[0.0, 3.320184, 4.168273]]))

def test_reverse_dtype(self):
embedding = ReversibleEmbedding(100, 16, reverse_dtype="float32")
input_data = ops.ones(shape=(4, 10, 16))
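The expected values in the new `test_correctness` case follow directly from the capping formula; a quick recomputation outside the test suite (assuming NumPy):

```python
import numpy as np

# Reverse projection of [1., 1.] against [[0., 0.], [2., 2.], [3., 3.]] is [0., 4., 6.];
# capping with logit_soft_cap=5 gives the values asserted above.
uncapped = np.array([0.0, 4.0, 6.0])
print(np.tanh(uncapped / 5.0) * 5.0)  # ~[0.       3.320184 4.168273]
```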
1 change: 1 addition & 0 deletions keras_nlp/src/models/gemma/gemma_backbone.py
@@ -132,6 +132,7 @@ def __init__(
seed=None,
),
dtype=dtype,
logit_soft_cap=final_logit_soft_cap,
name="token_embedding",
)
self.transformer_layers = []
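With this one-line change, Gemma's `final_logit_soft_cap` is applied inside the shared token embedding whenever it is called with `reverse=True`. A hedged usage sketch of the layer in isolation (the sizes and cap value below are illustrative, not Gemma's actual configuration):

```python
import numpy as np
from keras_nlp.src.layers.modeling.reversible_embedding import ReversibleEmbedding

embedding = ReversibleEmbedding(
    input_dim=256,        # illustrative vocabulary size
    output_dim=8,         # illustrative hidden size
    logit_soft_cap=30.0,  # illustrative cap value
)
token_ids = np.random.randint(0, 256, size=(2, 4))
hidden = embedding(token_ids)             # forward pass: ids -> embeddings, no capping
logits = embedding(hidden, reverse=True)  # reverse pass: embeddings -> soft-capped logits
```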
13 changes: 0 additions & 13 deletions keras_nlp/src/models/gemma/gemma_causal_lm.py
@@ -227,13 +227,6 @@ def call_with_cache(
cache = ops.stack(caches, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)

if self.backbone.final_logit_soft_cap is not None:
logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
logits = ops.multiply(
ops.tanh(logits), self.backbone.final_logit_soft_cap
)

return logits, hidden_states, cache

def _build_cache(self, token_ids):
@@ -445,12 +438,6 @@ def default_layer_intercept_fn(x, unused_i):
x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)

if self.backbone.final_logit_soft_cap is not None:
logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
logits = ops.multiply(
ops.tanh(logits), self.backbone.final_logit_soft_cap
)

if scoring_mode == "logits":
return logits

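The blocks removed here applied the cap inline with `ops.divide`/`ops.multiply`; the embedding layer now performs the same transformation, so the model's outputs should be unchanged. A small equivalence check (NumPy sketch with arbitrary sample values):

```python
import numpy as np

cap = 30.0
logits = np.random.uniform(-100.0, 100.0, size=(2, 8))

old = np.multiply(np.tanh(np.divide(logits, cap)), cap)  # removed inline formulation
new = np.tanh(logits / cap) * cap                        # formulation inside ReversibleEmbedding
assert np.allclose(old, new)
```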