mixed precision policies cause bumblebee models to fail #544

Closed
ityonemo opened this issue Nov 13, 2023 · 2 comments · Fixed by elixir-nx/bumblebee#280 or #547

ityonemo commented Nov 13, 2023

We were experimenting with Llama 2-based models and noticed some problems. Llama 2 is trained in bf16, so (probably?) this should work:

Base code (this works):

auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}

{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)

serving = Bumblebee.Text.generation(m, t, g, defn_options: 
  [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]])

Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")
|> dbg

output:

"[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] Hello! How can I assist you today?"

Code with mixed precision policies (this fails):

auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})

model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}

{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)

policy = Axon.MixedPrecision.create_policy(compute: {:bf, 16})
mp_model = Axon.MixedPrecision.apply_policy(m.model, policy)
m2 = %{m | model: mp_model}

serving = Bumblebee.Text.generation(m2, t, g, defn_options: 
  [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]])

Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")

output:

[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST] pl\nA van Lloydns wicked plan' a serious unwleftP8 including NewNewhhellilitiesathed ux- behindRES val Orange County IL months Meister=bool PA and so on


ityonemo commented Nov 13, 2023

also attempting {:f, 16} failed. However, {:f, 64} performs ~correctly.

output:

Hello there! *giggles* I'm just an AI assistant, here to help you with any questions or tasks you may have! *winks* Is there something specific you'd like to chat about or ask me to do? 😃
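
For reference, presumably the only change between these attempts is the policy line; a minimal sketch of the {:f, 64} variant, assuming the same m/t/g loading code as the snippets above:

# same loading code as above; only the policy differs.
# swapping compute: {:f, 16} in here reproduces the failure.
policy = Axon.MixedPrecision.create_policy(compute: {:f, 64})
mp_model = Axon.MixedPrecision.apply_policy(m.model, policy)
m2 = %{m | model: mp_model}

serving = Bumblebee.Text.generation(m2, t, g, defn_options:
  [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]])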


ityonemo commented Nov 13, 2023

experimenting showed that quantizing layer #7 alone also causes the issue (this is an embedding layer, not an rms-norm layer). Will try disabling both:

auth_token = System.get_env("HF_AUTH_TOKEN")
Nx.default_backend({EXLA.Backend, client: :host})
# model = {:hf, "mistralai/Mistral-7B-Instruct-v0.1"}
model = {:hf, "meta-llama/Llama-2-7b-chat-hf", auth_token: auth_token}

{:ok, m} = Bumblebee.load_model(model)
{:ok, t} = Bumblebee.load_tokenizer(model)
{:ok, g} = Bumblebee.load_generation_config(model)

bf = {:bf, 16}
policy = Axon.MixedPrecision.create_policy(params: bf, compute: bf, output: bf)

# down-cast only the layer with id 7 (the embedding layer)
filter = fn layer ->
  layer.id == 7
end

mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, filter)
m2 = %{m | model: mp_model}

serving =
  Bumblebee.Text.generation(m2, t, g,
    defn_options: [compiler: EXLA, compiler_options: [client: :cuda, lazy_transfers: :always]]
  )

%{results: [%{text: text}]} =
  Nx.Serving.run(serving, "[INST] <<SYS>>\nYou are a bot.\n<</SYS>>\n\nHi, bot![/INST]")

text |> dbg
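
For the "disabling both" follow-up, a sketch of a filter that leaves the embedding and rms-norm layers in full precision and down-casts everything else. Matching on :op_name is an assumption here (the filter above only relies on :id), and the exact atoms would need checking against the traced graph:

# hypothetical filter: down-cast every layer *except* embeddings and
# rms-norms; :op_name and the atom names are unverified assumptions
filter = fn layer ->
  layer.op_name not in [:embedding, :rms_norm]
end

mp_model = Axon.MixedPrecision.apply_policy(m.model, policy, filter)
m2 = %{m | model: mp_model}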
