[llm] Create separate Predictor for LLMs and enable flash attention on CUDA #3409
Conversation
Unit Test Results: 6 files ±0, 6 suites ±0, 1h 28m 20s ⏱️ +16m 55s. For more details on these failures, see this check. Results for commit bd0ddc2; comparison against base commit cb37535. ♻️ This comment has been updated with latest results.
class DictWrapper:
    """Wrapper for a LudwigFeatureDict module that allows for iteration over keys.

    The purpose of this class is to avoid exposing input and output features as modules of the LLM. This is because we
    only wish to train the underlying model, and having these additional modules can confuse systems like DeepSpeed.
    """
Very clever way of getting around initializing all of the other modules
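For context, a minimal sketch of what such a wrapper might look like is shown below. This is not the PR's actual code: the wrapped object type and the attribute and method names (`obj`, `keys`, `values`, etc.) are assumptions for illustration. The point is that the wrapper is a plain Python object rather than an `nn.Module`, so the wrapped features are never registered as submodules of the LLM.

```python
# Hypothetical sketch, not the PR's implementation. Because DictWrapper does not
# inherit from torch.nn.Module, assigning it as an attribute of the LLM does not
# register the wrapped features as submodules, so systems like DeepSpeed only see
# the underlying model's parameters.
import torch


class DictWrapper:
    """Plain-object wrapper around a module dict that still supports iteration over keys."""

    def __init__(self, obj: torch.nn.ModuleDict):
        self.obj = obj

    def keys(self):
        return self.obj.keys()

    def values(self):
        return self.obj.values()

    def items(self):
        return self.obj.items()

    def __iter__(self):
        return iter(self.obj)

    def __len__(self):
        return len(self.obj)
```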
LGTM 🚀
Failed test looks like a transient issue; it should be safe to merge.
This PR introduces a new `LlmPredictor` class used only during batch prediction, for the purpose of running text generation instead of simply outputting logits (as during training) from the forward pass. This is because at predict time we want the fully generated sequence, not just the logits.

Other fixes included in this PR:

- Enables flash attention on CUDA for the `LLM` module.

This PR also disables `prompt_tuning` as an adapter type for now, as generation mode does not currently work when this adapter is applied (see `tests/integration_tests/test_llm.py::test_llm_finetuning_strategies[prompt_tuning_init_random-local]`).
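To illustrate the distinction the new predictor addresses, here is a minimal sketch (not the PR's `LlmPredictor` implementation) using a Hugging Face-style causal LM: the training-style forward pass returns per-token logits, while prediction calls `generate()` to obtain the fully decoded sequence. The model name and generation parameters below are placeholders.

```python
# Illustrative sketch only; not the LlmPredictor implementation from this PR.
# Shows why prediction needs generate() rather than a plain forward pass.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Translate to French: Hello, world!", return_tensors="pt")

# Training-style forward pass: returns logits of shape
# (batch_size, sequence_length, vocab_size) -- one distribution per input token.
with torch.no_grad():
    logits = model(**inputs).logits

# Prediction-style generation: autoregressively decodes new tokens, so the
# output is the fully generated sequence rather than per-token logits.
with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=32)

print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```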