Multi-backend keras implementation of RWKV.
This is a port of the models in the RWKV-LM package to keras. I claim no credit for the network design, though I do offer a novel implementation based on a cumsum variant.
See also:
- theory for how this formulation works; and
- performance for how these implementations compare to the implementation shipped with the rwkv pip package.
This repository has no affiliation with keras-team - it's just a keras-core / keras-nlp implementation.
This package could admittedly do with a cleaner installation process - that's a non-trivial amount of work though, because the required packages depend on what backend you want to use. For the time being, the following should be enough to cover most cases:
- all will require keras-core, keras-nlp and (for the moment) tensorflow (even if using other backends)
- to use tensorflow's parallel implementation you'll need tensorflow-probability
- to use torch's original_cuda implementation you'll need rwkv
- to use jax implementations wrapped with the torch backend you'll need jax2torch
- to use torch's parallel triton implementation you'll need triton-nightly (see here for installation instructions)
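The requirements listed above can be checked before running anything. The following is a minimal sketch, not part of this package: the feature labels and the `missing_modules` helper are hypothetical, and the import names (for example `keras_core` for keras-core, `triton` for triton-nightly) are assumed to follow the usual pip-name-to-module convention.

```python
import importlib.util

# Import names assumed for the pip packages listed above; the feature
# labels on the left are made up for this sketch.
REQUIREMENTS = {
    "base": ["keras_core", "keras_nlp", "tensorflow"],
    "tensorflow parallel": ["tensorflow_probability"],
    "torch original_cuda": ["rwkv"],
    "jax wrapped with torch": ["jax2torch"],
    "torch parallel triton": ["triton"],
}


def missing_modules(modules):
    """Return the names in `modules` that cannot be located for import."""
    # find_spec looks a top-level module up without importing it.
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    for feature, modules in REQUIREMENTS.items():
        missing = missing_modules(modules)
        status = "ok" if not missing else "missing: " + ", ".join(missing)
        print(f"{feature}: {status}")
```

A module being found only means it is installed; it does not guarantee that the installed version is compatible with the others.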
If errors occur with the tensorflow backend, try installing nightly versions of the relevant packages.
Getting all backends to work in the same environment is non-trivial. I had success using conda to install jax and pip for tensorflow/torch (following conda installation instructions for tensorflow/torch tends to break jax installations).
Installing this package can be done via
git clone https://github.com/jackd/keras-rwkv.git
pip install -e keras-rwkv
Note there are fully independent backend implementations for wkv and exponentially weighted (ew) cumsum. The standard keras implementation only requires custom backend implementations of ew.cumsum (see ops/wkv.py) - the rest can be done via keras. The wkv implementations in the individual backends are provided mostly as a convenience for anyone who wants to take them and use them externally in a non-keras environment.
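For readers unfamiliar with the term, here is a minimal pure-Python sketch of one common definition of an exponentially weighted cumulative sum. The recurrence `y[t] = decay[t] * y[t-1] + x[t]` and both function names are assumptions chosen for illustration; they are not taken from this package, whose backend implementations may differ in definition and in numerical handling.

```python
from itertools import accumulate
import operator


def ew_cumsum_sequential(x, decay):
    """Reference loop: y[t] = decay[t] * y[t-1] + x[t], starting from 0."""
    y, acc = [], 0.0
    for x_t, d_t in zip(x, decay):
        acc = d_t * acc + x_t
        y.append(acc)
    return y


def ew_cumsum_via_cumsum(x, decay):
    """Same values expressed with an ordinary cumulative sum.

    With scale[t] = decay[0] * ... * decay[t], the recurrence unrolls to
    y[t] = scale[t] * sum_{s <= t} x[s] / scale[s].
    Requires non-zero decays and loses precision on long sequences,
    because scale[t] shrinks towards zero.
    """
    scale = list(accumulate(decay, operator.mul))
    sums = accumulate(x_t / s_t for x_t, s_t in zip(x, scale))
    return [s_t * c_t for s_t, c_t in zip(scale, sums)]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]
    decay = [0.5, 0.5, 0.5]
    print(ew_cumsum_sequential(x, decay))
    print(ew_cumsum_via_cumsum(x, decay))
```

The second form shows why a cumsum-style formulation can be run in parallel, while the naive rescaling used here is only safe for short sequences.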
See the examples directory for basic usage. It is strongly recommended that you set the KERAS_BACKEND environment variable - failure to do so will revert to using tf.keras, which isn't nearly as well tested.
KERAS_BACKEND=jax python examples/generate.py
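The backend can also be selected from inside a script, provided the variable is set before keras is first imported. The snippet below is a hedged sketch: `select_backend` is a hypothetical helper, not something this package provides, and the list of backend names is an assumption based on the backends discussed above.

```python
import os

# Backend names assumed from the backends mentioned in this README.
SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_backend(name):
    """Set KERAS_BACKEND, rejecting names outside the assumed list."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")

# Import keras only after the variable is set, since the backend is
# chosen at import time, e.g.:
# import keras_core as keras
```

Setting the variable after keras has already been imported has no effect on the running process.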
This package uses pre-commit to ensure commits meet minimum criteria. To install, use
pip install pre-commit
pre-commit install
This will ensure git hooks are run before each commit. While it is not advised to do so, you can skip these hooks with
git commit --no-verify -m "commit message"