mixed precision policies cause bumblebee models to fail #544
Also attempting to narrow this down. Experimenting showed that applying the bf16 policy to layer #7 alone also causes the issue (this is an embedding layer, not an RMS-norm layer). Will try disabling both

```elixir
auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})

# model = {:hf, "mistralai/Mistral-7B-Instruct-v0.1"}
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}

{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)

# Store params, run computation and emit outputs in bfloat16
bf = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf, compute: bf, output: bf)

# Apply the policy only to the node with id 7 (an embedding layer)
filter = fn layer ->
  layer.id == 7
end

mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, filter)
m2 = %{m | model: mp_model}

serving =
  Bumblebee.Text.generation(m2, t, g,
    defn_options: [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]]
  )

%{results: [%{text: text}]} =
  Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")

text |> dbg()
```
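To bisect in the other direction, casting everything except layer 7, the filter from the script can be inverted. This is a minimal sketch: `excluded_ids` is a hypothetical name for the node ids to leave out of the policy, and the filter is assumed to receive each Axon node with an `id` field, exactly as the script's filter does.

```elixir
# Hypothetical list of node ids to keep OUT of the bf16 policy while bisecting
excluded_ids = [7]

# Returns true (apply the policy) for every node whose id is not excluded
filter = fn layer -> layer.id not in excluded_ids end
```

Pass it as the third argument to `Axon.MixedPrecision.apply_policy/3` in place of the original filter; nodes for which it returns `false` are presumably left at their default precision.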
We were experimenting with Llama 2-based models and noticed some problems. Llama 2 was trained in bf16, so (probably?) this should work:

Base code (this works)

output:

```
"[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] Hello! How can I assist you today?"
```
Code with mixed precision policies:

output:

```
[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] pl\nA van Lloydns wicked plan' a serious unwleftP8 including NewNewhhellilitiesathed ux- behindRES val Orange County IL months Meister=bool PA
```

and so on.