Adds the score API to LlamaCausalLM #1534

Merged (2 commits) on Mar 29, 2024
127 changes: 127 additions & 0 deletions keras_nlp/models/llama/llama_causal_lm.py
@@ -212,3 +212,130 @@ def next(prompt, cache, index):
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def score(
self,
token_ids,
padding_mask=None,
scoring_mode="logits",
layer_intercept_fn=None,
target_ids=None,
):
"""Score a generation represented by the provided token ids.

Args:
token_ids: An <int>[batch_size, num_tokens] tensor containing tokens
to score. Typically, this tensor captures the output from a call
to `LlamaCausalLM.generate()`, i.e., tokens for both the input
text and the model-generated text.
padding_mask: A <bool>[batch_size, num_tokens] tensor indicating the
tokens that should be preserved during generation. This is an
artifact required by the `LlamaBackbone` and does not influence
the computation of this function. If omitted, this function
uses `keras.ops.ones()` to create a tensor of the appropriate
shape.
scoring_mode: The type of scores to return, either "logits" or
"loss"; both are computed per input token.
layer_intercept_fn: An optional function for augmenting activations
with additional computation, for example, as part of
interpretability research. This function will be passed the
activations as its first parameter and a numeric index
associated with that backbone layer. This index is *not* an
index into `self.backbone.layers`. The index -1 accompanies the
embeddings returned by calling `self.backbone.token_embedding()`
on `token_ids` in the forward direction. All subsequent indexes
will be 0-based indices for the activations returned by each of
the transformer layers in the backbone. This function must
return a <float>[batch_size, num_tokens, hidden_dims] tensor
that can be passed as an input to the next layer in the model.
target_ids: An <int>[batch_size, num_tokens] tensor containing the
predicted tokens against which the loss should be computed. If a
span of tokens is provided (sequential truthy values along
axis=1 in the tensor), the loss will be computed as the
aggregate across those tokens.

Raises:
ValueError: If an unsupported `scoring_mode` is provided, or if
`target_ids` is not provided when `scoring_mode` is "loss".

Returns:
The per-token scores as a tensor of size
<float>[batch_size, num_tokens, vocab_size] in "logits" mode, or
<float>[batch_size, num_tokens] in "loss" mode.

Example:

Compute gradients between embeddings and loss scores with TensorFlow:
```python
llama_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en")
generations = llama_lm.generate(
["This is a", "Where are you"],
max_length=30
)
preprocessed = llama_lm.preprocessor.generate_preprocess(generations)
generation_ids = preprocessed["token_ids"]
padding_mask = preprocessed["padding_mask"]
target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)

embeddings = None
with tf.GradientTape(watch_accessed_variables=True) as tape:
def layer_intercept_fn(x, i):
if i == -1:
nonlocal embeddings, tape
embeddings = x
tape.watch(embeddings)
return x

losses = llama_lm.score(
token_ids=generation_ids,
padding_mask=padding_mask,
scoring_mode="loss",
layer_intercept_fn=layer_intercept_fn,
target_ids=target_ids,
)

grads = tape.gradient(losses, embeddings)
```
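
A minimal sketch of "logits" mode scoring (the prompt text is illustrative
and assumes the same `"llama2_7b_en"` preset is available):
```python
llama_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en")
prompts = ["The quick brown fox"]
preprocessed = llama_lm.preprocessor.generate_preprocess(prompts)
# Per-token logits over the vocabulary,
# shape <float>[batch_size, num_tokens, vocab_size].
logits = llama_lm.score(
    token_ids=preprocessed["token_ids"],
    padding_mask=preprocessed["padding_mask"],
    scoring_mode="logits",
)
```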
"""
if scoring_mode not in ("logits", "loss"):
raise ValueError(
"Unsupported scoring_mode. Must be one of 'logits' or 'loss'."
)

if scoring_mode == "loss" and target_ids is None:
raise ValueError(
"Cannot compute loss without targets. Please provide target "
"token ids via the target_ids parameter."
)

batch_shape = ops.shape(token_ids)[:2]
assert len(batch_shape) == 2

if padding_mask is None:
padding_mask = ops.ones(shape=batch_shape)

if layer_intercept_fn is None:

def default_layer_intercept_fn(x, unused_i):
return x

layer_intercept_fn = default_layer_intercept_fn

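# Embed the tokens, then run the decoder stack, letting
# `layer_intercept_fn` observe or modify each intermediate activation.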
token_embeddings = self.backbone.token_embedding(token_ids)
x = layer_intercept_fn(token_embeddings, -1)

for i, transformer_layer in enumerate(self.backbone.transformer_layers):
x = transformer_layer(x, decoder_padding_mask=padding_mask)
x = layer_intercept_fn(x, i)

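# Apply the final layer norm and project back into vocabulary space to
# obtain per-token logits.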
x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)

if scoring_mode == "logits":
return logits

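# In "loss" mode, compute per-token sparse cross-entropy against the
# provided targets, with no reduction across tokens.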
per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction="none"
)
per_token_loss = per_token_loss_fn(target_ids, logits)
return per_token_loss
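
A hedged usage sketch building on the new API: reducing the per-token losses
from `score(..., scoring_mode="loss")` to one score per sequence by masking
out padding. The `per_sequence_loss` helper below is illustrative only and is
not part of this PR:

```python
from keras import ops

def per_sequence_loss(causal_lm, token_ids, padding_mask, target_ids):
    # Per-token cross-entropy, shape <float>[batch_size, num_tokens].
    per_token_loss = causal_lm.score(
        token_ids=token_ids,
        padding_mask=padding_mask,
        scoring_mode="loss",
        target_ids=target_ids,
    )
    mask = ops.cast(padding_mask, "float32")
    # Average the masked per-token losses over the real (non-padding) tokens.
    return ops.sum(per_token_loss * mask, axis=1) / ops.maximum(
        ops.sum(mask, axis=1), 1.0
    )
```

With preprocessing like the docstring example above, `target_ids` would
typically be `ops.roll(token_ids, shift=-1, axis=1)`.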
85 changes: 85 additions & 0 deletions keras_nlp/models/llama/llama_causal_lm_test.py
@@ -128,3 +128,88 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)

def test_score_logits(self):
# Setup prompts, models, and associated expected shapes.
prompts = ["the quick brown fox", "the quick brown fox"]
causal_lm = LlamaCausalLM(**self.init_kwargs)
expected_score_shape = (2, 8, 10)

# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
prompts
)
token_ids = preprocessed_prompts["token_ids"]
padding_mask = preprocessed_prompts["padding_mask"]

# Get the scores and assert their shape.
scores = causal_lm.score(
token_ids=token_ids,
padding_mask=padding_mask,
scoring_mode="logits",
)

self.assertEqual(ops.shape(scores), expected_score_shape)

def test_score_loss(self):
# Setup prompts, models, and associated expected shapes.
prompts = ["the quick brown fox", "the quick brown fox"]
causal_lm = LlamaCausalLM(**self.init_kwargs)
expected_score_shape = (2, 8)

# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
prompts
)
token_ids = preprocessed_prompts["token_ids"]
padding_mask = preprocessed_prompts["padding_mask"]
target_ids = ops.roll(token_ids, shift=-1, axis=1)

# Get the scores and assert their shape.
scores = causal_lm.score(
token_ids=token_ids,
padding_mask=padding_mask,
scoring_mode="loss",
target_ids=target_ids,
)

self.assertEqual(ops.shape(scores), expected_score_shape)

def test_score_layer_intercept_fn_exfiltration(self):
# Setup prompts, models, and associated expected shapes.
prompts = ["the quick brown fox", "the quick brown fox"]
causal_lm = LlamaCausalLM(**self.init_kwargs)
expected_embedded_shape = (2, 8, 8)
expected_score_shape = (2, 8, 10)

# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
prompts
)
token_ids = preprocessed_prompts["token_ids"]
padding_mask = preprocessed_prompts["padding_mask"]

# Setup a custom intercept function that extracts the embeddings to a
# a variable from the embeddings layer and otherwise asserts on shapes.
embedded_prompts = None

def layer_intercept_fn_for_testing(x, i):
if i == -1:
nonlocal embedded_prompts
embedded_prompts = x
else:
nonlocal expected_embedded_shape
self.assertEqual(ops.shape(x), expected_embedded_shape)
return x

# Get the scores.
scores = causal_lm.score(
token_ids=token_ids,
padding_mask=padding_mask,
scoring_mode="logits",
layer_intercept_fn=layer_intercept_fn_for_testing,
)

# Assert shapes for info exfiltrated into the parent context.
self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
self.assertEqual(ops.shape(scores), expected_score_shape)
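
Another hedged sketch building on the API above: turning the "logits" output
into per-token log-probabilities of the generated tokens. The helper name is
illustrative and not part of this PR:

```python
from keras import ops

def per_token_log_probs(causal_lm, token_ids, padding_mask):
    # Per-token logits over the vocabulary,
    # shape <float>[batch_size, num_tokens, vocab_size].
    logits = causal_lm.score(
        token_ids=token_ids,
        padding_mask=padding_mask,
        scoring_mode="logits",
    )
    log_probs = ops.log_softmax(logits, axis=-1)
    # Each position predicts the next token, so shift the targets left by one.
    target_ids = ops.roll(token_ids, shift=-1, axis=1)
    gathered = ops.take_along_axis(
        log_probs, ops.expand_dims(target_ids, axis=-1), axis=-1
    )
    return ops.squeeze(gathered, axis=-1)
```

Note that the last position wraps around after the roll, so its value should
be ignored or masked in downstream use.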