A question on your implementation of decoder phase of llama #79

Open
wangtianxia-sjtu opened this issue Jul 1, 2024 · 0 comments


Recently I have been studying your code. However, it seems to me that your implementation does not expand the KV cache during the decoding phase. The following code is excerpted from the function _concatenate_to_cache in llama.py.

if query.shape[1] == 1:
    mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
    def fn(cached_key, cached_value, key, value, cur_index):
        assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
        sp_size = max_length // mesh.shape['sp']
        axis_index = jax.lax.axis_index('sp')
        cur_index = cur_index - axis_index * sp_size
        key, value = jax.lax.cond(
            jnp.logical_and(cur_index >= 0, cur_index < sp_size),
            lambda: (
                cached_key.at[:, cur_index].set(key[:, -1]),
                cached_value.at[:, cur_index].set(value[:, -1]),
            ),
            lambda: (cached_key, cached_value),
        )
        return key, value

In this function, the decoding phase only overwrites a single position of cached_key and cached_value (the slot at cur_index) with the newly generated key/value, instead of appending them to cached_key and cached_value. However, it seems to me that a correct KV cache implementation should make the cache grow longer with each decoded token.
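To make the two strategies I am comparing concrete, here is a minimal sketch in plain Python (no JAX, so it is only an analogy for the excerpt above). Strategy A appends each new key/value so the cache grows; strategy B writes into a preallocated buffer at cur_index, which is roughly what `cached_key.at[:, cur_index].set(...)` appears to do. The names `MAX_LENGTH`, `attend`, and `valid_len` are hypothetical, and the assumption that attention only reads the slots filled so far is mine; I have not confirmed how the repository masks the unused slots.

```python
import math

def attend(query, keys, values, valid_len):
    """Softmax attention of one query over the first valid_len cache slots."""
    scores = [sum(q * k for q, k in zip(query, keys[i])) for i in range(valid_len)]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(weights[i] * values[i][d] for i in range(valid_len)) / total
        for d in range(dim)
    ]

MAX_LENGTH = 8  # hypothetical preallocated cache length

# Keys/values produced by three consecutive decoding steps (made-up numbers).
new_keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
new_values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
query = [0.5, -0.5]

# Strategy A: the cache grows by one entry per decoded token.
grown_keys, grown_values = [], []

# Strategy B: the cache has a fixed size; each step overwrites one slot.
fixed_keys = [[0.0, 0.0] for _ in range(MAX_LENGTH)]
fixed_values = [[0.0, 0.0] for _ in range(MAX_LENGTH)]

for cur_index, (k, v) in enumerate(zip(new_keys, new_values)):
    grown_keys.append(k)
    grown_values.append(v)
    fixed_keys[cur_index] = k      # analogous to cached_key.at[:, cur_index].set(key)
    fixed_values[cur_index] = v    # analogous to cached_value.at[:, cur_index].set(value)

out_grown = attend(query, grown_keys, grown_values, len(grown_keys))
out_fixed = attend(query, fixed_keys, fixed_values, len(new_keys))

print(len(grown_keys), len(fixed_keys))
print(out_grown == out_fixed)
```

In this sketch the two strategies give the same attention output as long as the unused slots of the fixed buffer are ignored, so my question is really whether the implementation relies on a preallocated buffer plus masking in this way.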

Maybe I do not fully understand your code, but I am looking forward to your reply.
