
DBRX #628

Merged
merged 6 commits into main from dbrx
Mar 29, 2024
Conversation

awni
Member

@awni awni commented Mar 27, 2024

Quantize

python -m mlx_lm.convert --hf-path databricks/dbrx-base -q 
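
The same conversion can also be run from the Python API; a minimal sketch, assuming mlx_lm exposes convert with a quantize keyword matching the -q flag above:

from mlx_lm import convert

# download databricks/dbrx-base, quantize it to 4-bit, and write the result
# to ./mlx_model (the same thing the CLI command above does)
convert("databricks/dbrx-base", quantize=True)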

Generate

python -m mlx_lm.generate --model mlx_model --prompt "Write a quicksort in Python" --trust-remote-code --max-tokens 500

Sample output:

- The partition operation will use a pair of indices that start at the two ends of the array being partitioned, then move toward each other, until they detect an inversion: a pair of elements, one larger than the pivot and one smaller, that are in the wrong order relative to each other. The inverted elements are then swapped. When
==========
Prompt: 3.415 tokens-per-sec
Generation: 14.071 tokens-per-sec

QLoRA

 python -m mlx_lm.lora --model mlx_model --data ../lora/data --batch-size 1 --train --iters 200 --lora-layers 8 --steps-per-report 10 --trust-remote-code

Log:

Iter 100: Saved adapter weights to checkpoints/100_adapters.npz.
Iter 110: Train loss 2.264, Learning Rate 1.000e-05, It/sec 0.341, Tokens/sec 26.124, Trained Tokens 8757, Peak mem 71.735 GB
Iter 120: Train loss 2.190, Learning Rate 1.000e-05, It/sec 0.340, Tokens/sec 25.812, Trained Tokens 9516, Peak mem 71.735 GB
Iter 130: Train loss 2.105, Learning Rate 1.000e-05, It/sec 0.336, Tokens/sec 27.973, Trained Tokens 10349, Peak mem 71.735 GB
Iter 140: Train loss 2.120, Learning Rate 1.000e-05, It/sec 0.334, Tokens/sec 26.627, Trained Tokens 11147, Peak mem 71.735 GB
Iter 150: Train loss 1.829, Learning Rate 1.000e-05, It/sec 0.340, Tokens/sec 23.525, Trained Tokens 11838, Peak mem 71.735 GB

@awni awni requested a review from angeloskath March 27, 2024 16:40

@mustafaaljadery mustafaaljadery left a comment


This is my first time contributing, so this might be the wrong practice; very sorry if it is!

llms/mlx_lm/models/dbrx.py
else:
    y = []
    for xt, st, it in zip(x, scores, inds.tolist()):
        yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)


[:, None] is added to insert a dimension, but isn't that handled automatically by mx.concatenate?

Member Author


Actually, concatenate does not insert a new dimension. But that's still a good point: we should use stack here, which does. We added this code before we had stack :).
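
For illustration, a minimal sketch of the difference (not code from this PR):

import mlx.core as mx

xs = [mx.ones((3,)), mx.zeros((3,))]

# concatenate joins along an existing axis, so each (3,) array needs an
# explicit new axis first: two (3, 1) arrays become (3, 2)
a = mx.concatenate([x[:, None] for x in xs], axis=1)

# stack inserts the new axis itself, so the [:, None] is unnecessary:
# two (3,) arrays become (3, 2)
b = mx.stack(xs, axis=1)

print(a.shape, b.shape)  # both are 3 x 2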

@mzbac
Contributor

mzbac commented Mar 27, 2024

Wondering what the memory requirement is for converting the model? I thought even 192GB wouldn't be enough.

@awni
Member Author

awni commented Mar 28, 2024

@mustafaaljadery thank you for the comments, those were nice!

@my-other-github-account

my-other-github-account commented Mar 28, 2024

This is awesome, works great!

Looks like it must be running Q4: I'm able to run it on an M3 MBP with 96GB using the above prompt, and free memory doesn't drop under 15GB even with some other stuff running on the system.

Makes sense: as a rough estimate, 2x the parameter count in billions gives the FP16 RAM requirement in GB, so 132B * 2 = 264GB at FP16; at Q4 you cut that to roughly a quarter, so it can run in approximately 66GB of RAM.
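
The estimate above as a tiny sketch (rough numbers only; activations, cache, and quantization metadata add some overhead on top):

# ~2 bytes per parameter at fp16, ~0.5 bytes per parameter at 4-bit
params_billions = 132
fp16_gb = params_billions * 2   # about 264 GB
q4_gb = fp16_gb / 4             # about 66 GB
print(fp16_gb, q4_gb)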

Great work!

@lin72h

lin72h commented Mar 28, 2024

I think the Mac Studio M3 Ultra would be perfect for this f16 fine-tuning task.

@awni
Member Author

awni commented Mar 28, 2024

Wondering what the memory requirement is for converting the model? I thought even 192GB wouldn't be enough.

You shouldn't need that much RAM to quantize it. I'm trying it on an 8GB machine (I think it will work but be very slow). Definitely 32GB is plenty though.

@mzbac
Contributor

mzbac commented Mar 28, 2024

Wondering what the memory requirement is for converting the model? I thought even 192GB wouldn't be enough.

You shouldn't need that much RAM to quantize it. I'm trying it on an 8GB machine (I think it will work but be very slow). Definitely 32GB is plenty though.

That's awesome! 🚀 🚀

@my-other-github-account

my-other-github-account commented Mar 28, 2024

Hello,

I think there might be an issue with the generation here, or perhaps I've done something wrong.

So I go through and do:

python -m mlx_lm.convert --hf-path databricks/dbrx-instruct -q

Then:

from mlx_lm import load, generate

model, tokenizer = load("mlx_model", trust_remote_code=True)

Using the following prompt (which I get from tokenizer.apply_chat_template() applied to the first problem from HumanEval):

<|im_start|>system
You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.
YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.
You assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).
(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)
This is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.
YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.<|im_end|>
<|im_start|>user
Complete the following code:
from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
<|im_end|>
<|im_start|>assistant
from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """

I run the mlx version as follows:

generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=5000)

But get odd output that looks like it missed a lot of the prompt, or got otherwise corrupted:

itherwise, you would be able to provide information that is not available in your training data. However, you are not able to access the internet, other systems, or any data outside of your training data. You are also not able to perform any actions or tasks, you can only provide information and answer questions. You are also not able to form memories or remember any interactions, you can only provide information and answer questions based on your training data. You are also not able to learn or improve over time, you can only provide information and answer questions based on your training data. You are also not able to form opinions or make judgments, you can only provide information and answer questions. You are also not able to have personal experiences or emotions, you can only provide information and answer questions. You are also not able to have a sense of self or consciousness, you can only provide information and answer questions. You are also not able to have a sense of time or know the current date or time, you can only provide information and answer questions based on your training data. You are also not able to have a sense of location or know where you are, you can only provide information and answer questions based on your training data. [...SNIP...]

The same prompt with HF works fine and gives:

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    numbers.sort() # sort the list of numbers
    for i in range(len(numbers) - 1):
        if numbers[i+1] - numbers[i] < threshold:
            return True
    return False 

This holds true for every problem in HumanEval: good results from HF, unintelligible outputs from MLX Q4. This is well beyond quantization error, which is typically minimal for HumanEval at Q4; it looks similar to cases where I have accidentally decoded logits incorrectly.

Any ideas where I might have gone wrong, or if there is another issue that could be causing this?

Thank you!

@mezmon2

mezmon2 commented Mar 28, 2024

Hello,

I think there might be an issue with the generation here, or perhaps I've done something wrong.

So I go through and do:

python -m mlx_lm.convert --hf-path databricks/dbrx-instruct -q

Did you have any bias-related errors when running the initial command with the instruct model? I am getting this:

self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: LayerNorm.__init__() got an unexpected keyword argument 'bias'

@M-I

M-I commented Mar 28, 2024

@my-other-github-account Got similar output when trying to apply the chat template.

In case it helps, here's some output with no template:

response = generate(model, tokenizer, prompt="Here's the code in markdown for a quicksort in Python:", verbose=True, max_tokens=200)

==========

Prompt: Here's the code in markdown for a quicksort in Python:

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)

print(quicksort([3,6,8,10,1,2,1]))
# Returns [1, 1, 2, 3, 6, 8, 10]

And here's the same code in a code block:

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x

==========
Prompt: 4.942 tokens-per-sec
Generation: 4.681 tokens-per-sec

@M-I

M-I commented Mar 28, 2024

Hello,
I think there might be an issue with the generation here, or perhaps I've done something wrong.
So I go through and do:
python -m mlx_lm.convert --hf-path databricks/dbrx-instruct -q

Did you have any bias related errors when running the inital command w/ the instruct model? I am getting this

self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: LayerNorm.__init__() got an unexpected keyword argument 'bias'

Just in case, make sure you're on the right branch (dbrx), and that your env doesn't have mlx-lm installed from pip.

@awni
Member Author

awni commented Mar 28, 2024

@mezmon2 use the latest MLX (python -m pip install -U mlx).

@ivanfioravanti
Contributor

Great PR!

@my-other-github-account

my-other-github-account commented Mar 28, 2024

@my-other-github-account Got similar output when trying to apply the chat template.

In case it helps, here's some output with no template:

Not using the template sort of defeats the purpose of an instruct-tuned model unfortunately, as it won't work as a traditional chatbot. So you can't say e.g. "please give me the code for quicksort in Python" or speak to it naturally. At that point you are better off with the base model :/

Given that most people are using LLMs as chatbots, not working as a chatbot seems like a pretty serious issue for normal use!

Additionally, I did make sure to reinstall mlx from latest, checked out pr-628 (which is the dbrx branch), made sure mlx-lm was uninstalled, then installed from the PR. No change there unfortunately :(

Is there another way we should be applying the template so it works properly and can be used for IFT models in a chat setting?

@awni
Member Author

awni commented Mar 28, 2024

@mzbac @Blaizzy the instruct model here uses the default template but tokenizer.chat_template==None

@M-I

M-I commented Mar 28, 2024

Not using the template sort of defeats the purpose of an instruct tuned model unfortunately, as it won't work as a traditional chatbot. So you can't say e.g. "please give me the code for quicksort in Python" or speak to it naturally. At that point you are better off with the base model :/

Sorry for the confusion, it wasn't a suggestion to try without the template. Since you had it working with HF, and this was showing reasonable text generation with MLX, it was to indicate that the issue is probably not in the quantization process but maybe in the code that applies the template.

@Blaizzy
Contributor

Blaizzy commented Mar 28, 2024

@mzbac @Blaizzy the instruct model here uses the default template but tokenizer.chat_template==None

@awni
Yes, I see it.

This usually happens because the model code is not yet in the transformers GitHub repository.
I believe it will be fixed once the PR is merged. They have yet to convert the model to the proper template:

huggingface/transformers#29921 (review)

The PreTrainedTokenizerBase code falls back to default_chat_template when chat_template is None.

PreTrainedTokenizerBase link:
https://github.com/huggingface/transformers/blob/536ea2aca234fb48c5c69769431d643b0d93b233/src/transformers/tokenization_utils_base.py#L1749-L1771
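
In the meantime, a minimal sketch of a possible workaround, assuming a transformers version that still exposes default_chat_template (that attribute and this fallback are assumptions, not something this PR does):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", trust_remote_code=True)

# if no explicit template is set, fall back to the tokenizer's default one
if tokenizer.chat_template is None:
    tokenizer.chat_template = getattr(tokenizer, "default_chat_template", None)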

@awni
Member Author

awni commented Mar 28, 2024

Indeed, the instruct model doesn't respond well to the template. I wonder if there is a bug in the template; the text itself looks somewhat reasonable. Unfortunately it's difficult to check whether it's related to the quantization, given that the fp16 model doesn't fit in RAM.

@awni awni requested review from andresy and jagrit06 March 28, 2024 18:01
@my-other-github-account

my-other-github-account commented Mar 28, 2024

In case it helps, here are my two quick reproducers:

HF: (Works)

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from evalplus.data import get_human_eval_plus

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-instruct", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

input_prompt = list(get_human_eval_plus().items())[0][1]["prompt"]
entry_point = list(get_human_eval_plus().items())[0][1]["entry_point"]

def get_prompt(input_prompt):
    input_text = "Complete the following code:\n" + input_prompt
    messages = [{"role": "user", "content": input_text}, {"role": "assistant", "content": input_prompt}]
    prompt = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=False, add_generation_prompt=False, return_tensors="pt")
    prompt = prompt[:-10]
    return prompt

prompt = get_prompt(input_prompt)
outputs = model.generate(tokenizer.encode(prompt, return_tensors="pt"), max_new_tokens=5000)
cleaned = "def "+entry_point+entry_point.join((tokenizer.decode(outputs[0]).split("<|im_start|> assistant")[1].split(entry_point)[1:]))
print(cleaned.split("<|im_end|>")[0])

MLX: (Doesn't work)

from mlx_lm import load, generate
from evalplus.data import get_human_eval_plus

model, tokenizer = load("mlx_model")

input_prompt = list(get_human_eval_plus().items())[0][1]["prompt"]
entry_point = list(get_human_eval_plus().items())[0][1]["entry_point"]

def get_prompt(input_prompt):
    input_text = "Complete the following code:\n" + input_prompt
    messages = [{"role": "user", "content": input_text}, {"role": "assistant", "content": input_prompt}]
    prompt = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=False, add_generation_prompt=False, return_tensors="pt")
    prompt = prompt[:-10]
    return prompt

prompt = get_prompt(input_prompt)
result = generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=5000)
print(result)

@awni
Member Author

awni commented Mar 28, 2024

I'm curious @my-other-github-account , how did you run the HF model? Is it on a single GPU or sharded? Is it quantized?

@Blaizzy
Contributor

Blaizzy commented Mar 28, 2024

@my-other-github-account could you share what the output of get_prompt(input_prompt) is on MLX and HF?

@awni
Member Author

awni commented Mar 28, 2024

I checked the prompts @Blaizzy, and they are indeed the same; both of them give:

<|im_start|>system
You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.
YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.
You assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).
(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)
This is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.
YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.<|im_end|>
<|im_start|>user
Complete the following code:
from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
<|im_end|>
<|im_start|>assistant
from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """

@Blaizzy
Contributor

Blaizzy commented Mar 28, 2024

Thanks @awni!

What is wrong then? Is it something with the model?

@awni
Member Author

awni commented Mar 28, 2024

What is wrong then? is it something with the model?

Not sure.. could be quantization, could be something else.

@Blaizzy
Contributor

Blaizzy commented Mar 28, 2024

Could you please describe the issue?

@awni
Member Author

awni commented Mar 28, 2024

The issue is that the model is not generating good outputs when given the chat template. Also, we think we have dialed in on a bug with the quantized gate matrices; we should have an update on that soon.

@Blaizzy
Contributor

Blaizzy commented Mar 28, 2024

Makes sense; judging by the fact that the Transformers version works, it could be it.

Please tag me when you update the gate matrices 👌🏽

@my-other-github-account

my-other-github-account commented Mar 28, 2024

I have been running it sharded on 4xA100 80GB - and indeed the above are the prompts I get with either.

I could quantize it to 8-bit, but llama.cpp doesn't support less than 8-bit quantization for DBRX yet, as it is slightly customized. Right now I am running in FP16. Would it help to try 8-bit?

For what it's worth, I have benchmarked somewhere in the neighborhood of ~50 IFT LLMs against HumanEval, and in none of those cases did dropping to 4-bit affect performance substantially (3-bit did a bit, though not to the point of being unusable). So FWIW, quantization error alone, absent a bug, isn't sufficient to explain such drastic results IME; dropping to a 0 on HumanEval, which is what happens here, would be unheard of.

@awni
Member Author

awni commented Mar 28, 2024

Hotfix here ml-explore/mlx#923 (review) will be in today's release.

@mzbac
Contributor

mzbac commented Mar 28, 2024

Is it worth not quantizing the gate (router)? I can't be 100% sure, but I feel that non-quantized seemed to perform better when I tried a custom MoE model in the past.

@awni
Member Author

awni commented Mar 28, 2024

I can't be 100% sure, but I feel that non-quantized seemed to perform better when I tried a custom MoE model in the past.

Could be related to the bug we just fixed.

Is it worth not quantizing the gate (router)?

Maybe not worth it from a runtime perspective; it is a small part of the model. Either way, I'm glad not to have a bug there. It's also simpler from an implementation standpoint if we can reliably quantize any linear layer, which keeps the model homogeneous.
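
For anyone who wants to experiment with leaving the router un-quantized anyway, a minimal sketch on a toy module (this assumes mlx.nn.quantize accepts a class_predicate callback receiving a module path and module, and that the router's path contains "router"; both are assumptions, and this is not what the PR does):

import mlx.nn as nn

# toy stand-in for an MoE block with a router plus expert projections
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.router = nn.Linear(64, 8)
        self.w1 = nn.Linear(64, 256)
        self.w2 = nn.Linear(256, 64)

model = Block()

# quantize every Linear except the router (path-based predicate; the real
# DBRX module names may differ)
nn.quantize(
    model,
    group_size=64,
    bits=4,
    class_predicate=lambda path, m: isinstance(m, nn.Linear) and "router" not in path,
)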

@my-other-github-account

I can confirm the latest changes do indeed fix the issue!

Now I get:

def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """

    numbers.sort()
    for i in range(len(numbers) - 1):
        if numbers[i + 1] - numbers[i] < threshold:
            return True
    return False <|im_end|>

As expected!

Great work, thank you all again!

@awni
Member Author

awni commented Mar 28, 2024

Thanks for surfacing this issue @my-other-github-account and for the reproducible test case.

@M-I

M-I commented Mar 29, 2024

@my-other-github-account did you requantize, or was just updating mlx-lm enough?

@awni
Member Author

awni commented Mar 29, 2024

You shouldn’t need to requantize

Member

@angeloskath angeloskath left a comment


🚀

@awni awni merged commit b80adbc into main Mar 29, 2024
2 checks passed
@awni awni deleted the dbrx branch March 29, 2024 04:03