[DO NOT MERGE] Experimental implementation of CausalLM with a Keras Functional backbone_with_cache #1598

Draft
martin-gorner wants to merge 4 commits into base: master

Conversation

martin-gorner (Author)

This is a proof-of-concept PR for the new layer graph cloning API in Keras (keras-team/keras#19600).
It is not meant to be merged as such, but to provide a tangible use case for the design of the new layer graph cloning API.

The problem to solve was:

  • XXXCausalLM(backbone) accepted a Functional backbone but did not use the backbone's graph of layers.
  • Instead, XXXCausalLM implemented its call_with_cache by calling selected layers from the backbone directly.
  • If the user provided a backbone with extra layers between TransformerDecoder blocks (or any other modified layer graph), the backbone's graph would be disregarded in call_with_cache.
  • Several LLM techniques rely on a modified backbone (e.g. control vectors, https://arxiv.org/abs/2310.01405) and could therefore not be used in KerasNLP.

To let users implement what they want in the backbone and still have call_with_cache work in XXXCausalLM, the caching must be added to the backbone in a Keras Functional way, respecting the backbone's Functional layer graph.
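
For context, here is a hedged sketch of the status-quo pattern described above, loosely modeled on the existing GPT2CausalLM.call_with_cache (layer attribute names and call signatures vary per model and are assumptions here; ops refers to keras.ops):

# Sketch only: the CausalLM task re-implements the forward pass by calling
# selected backbone layers by hand, so any extra layers in the backbone's
# graph are silently skipped during cached generation.
def call_with_cache(self, token_ids, cache, cache_update_index):
    x = self.backbone.token_embedding(token_ids)
    # ... position embedding, dropout, etc. are also re-invoked by hand ...
    caches = []
    for i, transformer_layer in enumerate(self.backbone.transformer_layers):
        current_cache = cache[:, i, ...]
        x, next_cache = transformer_layer(
            x,
            self_attention_cache=current_cache,
            self_attention_cache_update_index=cache_update_index,
        )
        caches.append(next_cache)
    cache = ops.stack(caches, axis=1)
    hidden_states = self.backbone.layer_norm(x)
    logits = self.backbone.token_embedding(hidden_states, reverse=True)
    return logits, hidden_states, cache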

The new layer graph cloning API can be used:

  • externally, to provide an easy UX for light-weight edits of pre-trained models, for example to implement control vectors.
  • internally, to wire caches into the backbone in order to implement call_with_cache in a Keras Functional way.

This PR implements a Keras Functional call_with_cache for GPT2 and Gemma.

github-actions bot added the Gemma (Gemma model specific issues) label on Apr 25, 2024
@martin-gorner (Author)

For a demo of the functionality, see this Colab: Model rewiring demo with LLMs.ipynb.

For example, you can insert control vectors into an LLM backbone with this clone_fn applied to the backbone:

def clone_fn(layer, *args, **kwargs):
  if isinstance(layer, keras_nlp.layers.TransformerDecoder):
    x = layer(*args, **kwargs)
    x = ControlVectorLayer()(x)  # insert a control vector after each decoder block
    return x
  else:
    return layer(*args, **kwargs)  # all other layers are called unchanged
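
ControlVectorLayer is not defined in this PR. A minimal sketch of what such a layer might look like, assuming the control vector is simply a learned offset added to the hidden states (purely illustrative, not the demo's actual code):

import keras

class ControlVectorLayer(keras.layers.Layer):
    """Adds a steering vector to the hidden states of a decoder block."""

    def build(self, input_shape):
        # One value per hidden dimension, broadcast over batch and sequence.
        self.control_vector = self.add_weight(
            name="control_vector",
            shape=(input_shape[-1],),
            initializer="zeros",
            trainable=True,
        )

    def call(self, x):
        return x + self.control_vector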

Before/after visualization of the backbone:
[image: before/after plot of the backbone layer graph]

@martin-gorner (Author)

And here is what a re-wired backbone with caches looks like. Since it is now a proper Keras Functional model, it can be plotted. The layout is not the best but you can see the cache input fed into all layers and an updated cache fed out and collected at the end.
[image: plot of the rewired backbone showing cache inputs and outputs]

@martin-gorner (Author)

Known issue: the max_length parameter in generate(prompt, max_length=64) does not work.

@martin-gorner (Author)

I have changed the implementation to use the new new_model = clone_model(model, clone_function=lambda x: x, call_function=...) API instead of the previously suggested output = clone_layer_graph(input, output, clone_fn=...).

For this use case, i.e. rewiring a language model backbone with KV caches, the new API is a bit awkward, as it forces the user to use an intermediate model. In simplified code:

rewired_backbone = clone_model(backbone,
                               clone_function=lambda x:x, # no cloning
                               call_function=rewire_fn)

# Build a new backbone with caches in inputs and outputs.
input = {
    "token_ids": rewired_backbone.input["token_ids"],
    "cache": cache_input, # new input
    "cache_update_index": cache_update_index_input, # new input
}

# During the rewiring process, next_caches were collected, add them as a new output
next_cache = ops.stack(next_caches, axis=1)
output = (rewired_backbone.output, next_cache)

# create a new backbone that now uses caches in its forward pass
real_rewired_backbone = keras.Model(input, output, name=backbone.name + "_with_cache")
return real_rewired_backbone

The intermediate model rewired_backbone is "wrong": it still has the original inputs of backbone, i.e. token_ids and padding_mask, while its layer graph no longer uses padding_mask and now uses the additional inputs cache_input and cache_update_index. The user has to create a new model, real_rewired_backbone, to fix those issues. It is also surprising that these graph connectedness issues were not caught when rewired_backbone was constructed; this code might fail in the future if graph connectedness checks are improved.
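
The rewire_fn passed to call_function above is not shown; a hedged sketch of its general shape (cache_input, cache_update_index_input and next_caches are the variables from the snippet above; the backbone.transformer_layers attribute and the per-block cache call signature are assumptions that vary per model):

next_caches = []

def rewire_fn(layer, *args, **kwargs):
    # Decoder blocks are re-called with their slice of the cache, and the
    # updated cache is collected so it can become a new model output.
    if layer in backbone.transformer_layers:
        layer_index = backbone.transformer_layers.index(layer)
        x, next_cache = layer(
            args[0],
            cache=cache_input[:, layer_index, ...],
            cache_update_index=cache_update_index_input,
        )
        next_caches.append(next_cache)
        return x
    return layer(*args, **kwargs)  # every other layer is called unchanged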

The previously suggested API did not have this awkwardness as it did not involve an intermediate Model. In simplified code:

# Build a new backbone with caches in inputs and outputs.
input = {
    "token_ids": backbone.input["token_ids"],
    "cache": cache_input, # new input
    "cache_update_index": cache_update_index_input, # new input
}

# This call can check for graph connectedness without failing
new_output = clone_layer_graph(input, backbone.output, clone_fn=rewire_fn)

# During the rewiring process, next_caches were collected, add them as a new output
next_cache = ops.stack(next_caches, axis=1)
output = (new_output, next_cache)

# create a new backbone that now uses caches in its forward pass
rewired_backbone = keras.Model(input, output, name=backbone.name + "_with_cache")
return rewired_backbone

Additional note: I also noticed that the new API clones input tensors. For example, rewired_backbone.input["token_ids"] and backbone.input["token_ids"] are different tensors. The previously suggested clone_layer_graph API kept input layers identical, as they do not need to be cloned. The new behavior might be surprising for users wondering why the token_ids input changed during the rewiring process and whether it is a bug.

mattdangerw self-requested a review on April 30, 2024.
@mattdangerw (Member) left a comment:

Thanks! Took a look!

Stopgap for now

Below is a colab that would work today and is decently readable (it patches the call method of each transformer layer):

https://colab.research.google.com/gist/mattdangerw/72df1cbe743208bd69137be4f2142203/patched-call-control-vectors.ipynb

This is maybe not the most conceptually beautiful approach, but it's short, practical, and works for modifying attention and key/query/value projections. It's unclear to me whether the approach in this PR could grow to include those types of surgery without substantive rewrites and breaking all weight compatibility.

I think, for better or for worse, "layer patching" is the best cohesive approach we have for surgery today. It seems prudent to point users in this direction for now, as we think through new forms of model surgery and work on a more flexible form of our current generation code.
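
A hedged sketch of the layer-patching approach described above (not the exact code from the linked colab; it assumes a decoder block returns either hidden states or a (hidden_states, cache) tuple):

def add_control_vector(decoder_block, control_vector):
    """Wrap a decoder block's call method to add a control vector to its output."""
    original_call = decoder_block.call

    def patched_call(*args, **kwargs):
        outputs = original_call(*args, **kwargs)
        if isinstance(outputs, tuple):  # (hidden_states, cache) during cached generation
            return (outputs[0] + control_vector,) + outputs[1:]
        return outputs + control_vector

    # Shadow the bound method on this instance only; the layer graph is untouched.
    decoder_block.call = patched_call

# Hypothetical usage, assuming the backbone exposes its decoder blocks as a list:
# for block in causal_lm.backbone.transformer_layers:
#     add_control_vector(block, control_vector)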

Forward-looking thoughts

What is on this PR is not general for all models (see comments below), and the general form might require pushing the rewiring into individual model implementations to handle different layer types, input types, arg names, and overall structures (e.g. seq2seq).

I think it'd be doable in the technical sense, if a good bit more code. The rewiring code is a bit clunky and not super readable.

I'd be interested in scoping out support for optional inputs in Keras functional models. If we did this, we could write a backbone that supports caching out of the box, as well as other common "power user" inputs (e.g. attention_mask, token_positions); that would cover a lot of other important CUJs without any need for cloning or surgery.

Optional inputs could allow a rewrite of our generative task code to treat the backbone as a functional black box, without any assumptions about its internal layer structure. That could enable a lot of the types of functional surgery you are interested in, with a much smaller blast radius on our model implementation code.

@@ -503,3 +503,48 @@ def get_config(self):

    def compute_output_shape(self, decoder_sequence_shape):
        return decoder_sequence_shape

    def compute_output_spec(
Member:

This is kinda clunky, but might be a good idea to add regardless. Can we just **kwargs the args we don't actually care about here?

I'm not sure if we need compute_output_shape if we do this.

Author:

Yes, compute_output_shape is probably not required when compute_output_spec is implemented. Here, compute_output_spec was necessary because the layer returns differently shaped outputs depending on its inputs.
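
For illustration, a hedged sketch of what such a compute_output_spec could look like on TransformerDecoder (not the exact code from the diff; it assumes import keras and the existing self_attention_cache argument name):

def compute_output_spec(
    self, decoder_sequence, self_attention_cache=None, **kwargs
):
    # With a cache the layer returns (hidden_states, cache); without one it
    # returns only the hidden states, so the output structure depends on inputs.
    output_spec = keras.KerasTensor(
        decoder_sequence.shape, dtype=self.compute_dtype
    )
    if self_attention_cache is not None:
        cache_spec = keras.KerasTensor(
            self_attention_cache.shape, dtype=self.compute_dtype
        )
        return output_spec, cache_spec
    return output_spec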

    return output

def rewire_fn(layer, *args, **kwargs):
    if isinstance(layer, PositionEmbedding):
Member:

This would not work for a few models where the position embedding is part of a composite layer (TokenAndPositionEmbedding).

def rewire_fn(layer, *args, **kwargs):
    if isinstance(layer, PositionEmbedding):
        return rewire_positionembedding(layer, *args, **kwargs)
    elif isinstance(layer, TransformerDecoder):
Member:

This would not work for most decoder models (as most decoder models write their own decoder block).

Author:

We can test for membership in the set of backbone.transformer_blocks rather than for a specific type. This can be solved with a convention for what a "backbone" should contain (which makes sense: not just any backbone works for cached text generation).
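
A hedged sketch of that convention-based check (the transformer_blocks attribute is the convention being proposed here, not an existing API, and rewire_decoder_block is a hypothetical helper):

def rewire_fn(layer, *args, **kwargs):
    # Rewire any layer the backbone declares as a decoder block, regardless of its type.
    if layer in backbone.transformer_blocks:
        return rewire_decoder_block(layer, *args, **kwargs)
    return layer(*args, **kwargs)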

def _rewire_backbone_with_cache(backbone, cache_shape, cache_dtype):

    # Define new inputs for caches.
    cache_update_index_input = keras.Input(
Member:

Our cache might change to be a tuple of individual layer caches to avoid stacking/concatenating, as described in #1562.

And further down the road, we might want to add more cache shape options, e.g. for things like token attention.

Interestingly, a cache of tuples would invalidate our current restriction on functional model inputs. We'd want a nested structure where one dictionary key contains a tuple of inputs, which would break here: https://github.com/keras-team/keras/blob/9f4da5159a098256dfbccd2c926107953a6812e5/keras/src/models/functional.py#L134-L141

So we may need to do more thinking here if we "unstack our cache".

Author:

Expanding functional models to arbitrary pytree inputs and outputs (as long as the leaves are KerasTensors) is on the roadmap (look under "Modeling").

)
# cache_update_index_input is always a scalar. We must force the
# shape to scalar because keras.Input assumes a batch dim.
cache_update_index_input.shape = ()
Member:

This seems kinda hacky. Is this something we want to support generally in Keras, i.e. unbatched functional model inputs? And if so, is this the way we would like to do it?

# === Backbone with cache ===
# The backbone with a cache is used in call_with_cache
cache_shape = self._compute_cache_shape(backbone, preprocessor)
self.backbone_with_cache = self._rewire_backbone_with_cache(
Member:

This might fail with Keras 2 saving for obscure reasons. Basically, we might try to save backbone_with_cache before the internal model layers, invalidating the whole checkpoint structure. (Just yet another reason to try to ditch Keras 2 asap).

Author:

+1

{
    "token_ids": token_ids,
    "cache": cache,
    "cache_update_index": ops.convert_to_tensor(cache_update_index),
Member:

Why do we need convert_to_tensor here?

Author:

Not sure why. It would not work without it.

@mattdangerw (Member)

Update here... @fchollet has added support for optional functional inputs.

So what I think we can do is write a backbone that allows two optional inputs, cache and index. Then we can write a causal LM that needs zero knowledge of the internals of the backbone, just its inputs and outputs during generation. So the entire "rewire" code can go away.

I think this is the right solution abstraction-wise, and it will allow a lot more aggressive model surgeries.

But landing this will still take some effort, as we will need to drop Keras 2 code paths in the library (Keras 2 will not support optional inputs).
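
For illustration, a hedged sketch of what the optional-inputs direction could look like (Keras 3 functional inputs accept optional=True; all names and shapes below are illustrative assumptions, not this PR's code):

import keras

# Illustrative cache geometry (num_layers, 2 for key/value, num_heads, head_dim).
num_layers, num_heads, head_dim = 12, 12, 64

token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")

# Optional inputs: None is passed when calling the backbone without a cache.
cache = keras.Input(
    shape=(num_layers, 2, None, num_heads, head_dim),
    name="cache",
    optional=True,
)
cache_update_index = keras.Input(
    shape=(), dtype="int32", name="cache_update_index", optional=True
)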
