DBRX #628
Conversation
This is my first time contributing, so this might be the wrong practice; very sorry if it is!
llms/mlx_lm/models/dbrx.py
Outdated
```python
else:
    y = []
    for xt, st, it in zip(x, scores, inds.tolist()):
        yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
```
`[:, None]` is added for an extra dimension, but isn't that handled automatically by `mx.concatenate`?
Actually, `concatenate` does not insert a new dimension. But that's still a good point; we should use `stack` here, which does. We added this code before we had `stack` :).
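For illustration, a minimal sketch of the difference (array shapes chosen arbitrarily):

```python
import mlx.core as mx

a = mx.zeros((4,))
b = mx.ones((4,))

# concatenate joins arrays along an existing axis, so each 1-D output needs
# an explicit new axis before it can be joined:
joined = mx.concatenate([a[:, None], b[:, None]], axis=-1)  # shape (4, 2)

# stack inserts the new axis itself, so the [:, None] is unnecessary:
stacked = mx.stack([a, b], axis=-1)                         # shape (4, 2)

assert joined.shape == stacked.shape == (4, 2)
```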
Wondering what the memory requirement is for converting the model? I thought even 192GB wouldn't be enough.
@mustafaaljadery thank you for the comments, those were nice!
This is awesome, works great! Looks like it must be running Q4: I'm able to run it on an M3 MBP with 96GB with the above prompt, and free memory doesn't drop under 15GB even with some other stuff running on the system. Makes sense, because a rough estimate of 2x the param count in billions gives you the FP16 RAM requirement in GB, so 132B*2 = 264GB if running in FP16; at Q4 you cut that to 1/4th and it can run in approx 66GB of RAM. Great work!
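As a quick back-of-the-envelope check of that estimate (assuming roughly 2 bytes per parameter at FP16 and a quarter of that at 4-bit, ignoring cache and overhead):

```python
params_billions = 132              # DBRX total parameter count, in billions
fp16_gb = params_billions * 2      # ~2 bytes/param at FP16 -> ~264 GB
q4_gb = fp16_gb / 4                # 4-bit weights are ~1/4 of FP16 -> ~66 GB
print(fp16_gb, q4_gb)              # 264 66.0
```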
I see, the Mac Studio M3 Ultra would be perfect for this f16 finetune task.
You shouldn't need that much RAM to quantize it. I'm trying it on an 8GB machine (I think it will work but be very slow). Definitely 32GB is plenty though.
That's awesome! 🚀 🚀
Hello, I think there might be an issue with the generation here, or perhaps I've done something wrong. So I go through and do `python -m mlx_lm.convert --hf-path databricks/dbrx-instruct -q`. Then:
Using the following prompt (that I get from `tokenizer.apply_chat_template()` applied to the first problem from HumanEval):
I run the MLX version as follows: `generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=5000)`. But I get odd output that looks like it missed a lot of the prompt, or got otherwise corrupted:
The same prompt with HF works fine and gives:
This holds true for every problem in HumanEval: good results from HF, unintelligible outputs from MLX Q4. This is well beyond quantization error, which is typically minimal for HumanEval at Q4; it looks similar to times where I have accidentally decoded logits incorrectly. Any ideas where I might have gone wrong, or if there is another issue that could be causing this? Thank you!
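For reference, a minimal sketch of the reported steps (the message content below is a placeholder, not the exact HumanEval prompt used):

```python
# Conversion step described above:
#   python -m mlx_lm.convert --hf-path databricks/dbrx-instruct -q
from mlx_lm import load, generate

model, tokenizer = load("mlx_model")  # default output folder of mlx_lm.convert

# Placeholder chat message; the original prompt came from applying the chat
# template to the first HumanEval problem.
messages = [{"role": "user", "content": "Complete the function has_close_elements(numbers, threshold)."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

response = generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=5000)
```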
Did you have any bias-related errors when running the initial command with the instruct model? I am getting this:
`TypeError: LayerNorm.__init__() got an unexpected keyword argument 'bias'`
@my-other-github-account Got similar output when trying to apply the chat template. In case it helps, here's some output with no template:

```python
response = generate(model, tokenizer, prompt="Here's the code in markdown for a quicksort in Python:", verbose=True, max_tokens=200)
```

==========
==========
Just in case, make sure you're on the right branch (dbrx), and that your env doesn't have mlx-lm installed from pip.
@mezmon2 use the latest MLX (
Great PR!
Not using the template sort of defeats the purpose of an instruct-tuned model unfortunately, as it won't work as a traditional chatbot. So you can't say e.g. "please give me the code for quicksort in Python" or speak to it naturally. At that point you are better off with the base model :/ Given that most people are using LLMs as chatbots, not working as a chatbot seems like a pretty serious issue for normal use!

Additionally, I did make sure to reinstall mlx from latest, checked out pr-628 (which is the dbrx branch), made sure mlx-lm was uninstalled, then installed from the PR. No change there unfortunately :( Is there another way we should be applying the template so it works properly and can be used for IFT models in a chat setting?
Sorry for the confusion, it wasn't a suggestion to try without the template. Since you had it working with HF, and this was showing reasonable text generation with MLX, it was to indicate that the issue was probably not the quantization process but maybe in the code that applies the template.
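One way to check that (a hedged sketch; it assumes the converted `mlx_model` folder contains the copied tokenizer files, and the message content is illustrative):

```python
from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", trust_remote_code=True)
mlx_tok = AutoTokenizer.from_pretrained("mlx_model", trust_remote_code=True)

messages = [{"role": "user", "content": "Write a quicksort in Python."}]

hf_prompt = hf_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
mlx_prompt = mlx_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# If these differ, the problem is in how the template is applied,
# not in the quantization itself.
print(hf_prompt == mlx_prompt)
```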
@awni This usually happens because the model code is not yet in the transformers GitHub repository. huggingface/transformers#29921 (review) The PretrainedTokenizerBase link:
Indeed, the instruct model doesn't respond well to the template. I wonder if there is a bug in the template; the text itself looks slightly reasonable. Unfortunately it's difficult to check whether it's related to the quantization, given the fp16 model doesn't fit in RAM.
In case it helps, here are my two quick reproducers:
HF: (works)
MLX: (doesn't work)
I'm curious @my-other-github-account, how did you run the HF model? Is it on a single GPU or sharded? Is it quantized?
@my-other-github-account could you share what the output of `get_prompt(input_prompt)` is on MLX and HF?
I checked the prompts @Blaizzy, they are indeed the same; both of them give:
Thanks @awni! What is wrong then? Is it something with the model?
Not sure; could be quantization, could be something else.
Could you please describe the issue?
The issue is that the model is not generating good outputs when given the chat template. Also, we think we have dialed in on a bug with the quantized gate matrices; we should have an update on that soon.
Makes sense; judging by the fact that the Transformers version works, it could be. Please tag me when you update the gate matrices 👌🏽
I have been running it sharded on 4x A100 80GB, and indeed the above are the prompts I get with either. I could 8-bit quant it, but llama.cpp doesn't support less than 8-bit quantization for DBRX yet as it is slightly customized. Right now I am in FP16. Would it help to try 8-bit? For what it's worth, I have benchmarked somewhere in the neighborhood of ~50 IFT LLMs against HumanEval, and in none of the cases did dropping to 4-bit affect performance substantially (3-bit did a bit, though, but not to the point of being unusable). So FWIW, quantization error alone, absent a bug, isn't sufficient to explain such drastic results IME; i.e. dropping to a 0 on HumanEval, which is what happens here, would be unheard of.
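For reference, loading the HF model sharded across several GPUs typically looks roughly like this (a sketch of the setup described above, not the exact script that was run):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    torch_dtype=torch.float16,   # FP16, as in the run above
    device_map="auto",           # shard the weights across the available GPUs
    trust_remote_code=True,
)

messages = [{"role": "user", "content": "Write a quicksort in Python."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=500)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```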
Hotfix here ml-explore/mlx#923 (review) will be in today's release.
Is it worth not quantizing the gate (router)? I can't be 100% sure, but I feel that non-quantized seemed to perform better when I tried custom MoE models in the past.
Could be related to the bug we just fixed.
Maybe not worth it from a runtime perspective; it is a small part of the model. Either way, I'm glad not to have a bug there. It's also simpler from an implementation standpoint if we can reliably quantize any linear layer; it keeps the model homogeneous.
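If someone did want to experiment with leaving the router unquantized, `mlx.nn.quantize` accepts a `class_predicate`; a rough sketch (the "router" path check is an assumption about the DBRX module naming, and the weight path is hypothetical):

```python
import mlx.nn as nn
from mlx_lm import load

# Hypothetical local fp16 weights; not the quantized mlx_model folder.
model, tokenizer = load("path/to/dbrx-fp16")

def quantize_all_but_router(path: str, module: nn.Module) -> bool:
    # Quantize every linear layer except modules whose path mentions the MoE router.
    # NOTE: "router" is an assumed substring of the DBRX gate/router path.
    return isinstance(module, nn.Linear) and "router" not in path

nn.quantize(model, group_size=64, bits=4, class_predicate=quantize_all_but_router)
```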
I can confirm the latest changes do indeed fix the issue! Now I get:
As expected! Great work, thank you all again!
Thanks for surfacing this issue @my-other-github-account and for the reproducible test case.
@my-other-github-account did you requantize, or was just updating mlx-lm enough?
You shouldn't need to requantize.
🚀
Quantize
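(Presumably the same conversion command used earlier in the thread:)

```shell
python -m mlx_lm.convert --hf-path databricks/dbrx-instruct -q
```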
Generate
```shell
python -m mlx_lm.generate --model mlx_model --prompt "Write a quicksort in Python" --trust-remote-code --max-tokens 500
```
Sample output:
QLoRA
Log: