
[fix] lm-sys/FastChat/issues/2295 #2328

Merged
5 commits merged into lm-sys:main on Aug 29, 2023

Conversation

vaxilicaihouxian (Contributor)

Why are these changes needed?

First, I ran python -m fastchat.serve.cli --model-path /my/mac/path/llm_models/chatglm2-6b --load-8bit --device mps
to run chatglm2-6b, but I got this error: "Trying to convert BFloat16 to the MPS backend but it does not have support for that dtype."
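
For context, the failure can be reproduced outside FastChat with a few lines of PyTorch. This is an illustrative sketch, not code from this PR, and whether the error actually appears depends on the PyTorch and macOS versions, since newer builds add bfloat16 support on MPS:

```python
import torch

# Illustrative sketch (not FastChat code): on older PyTorch/macOS builds,
# moving a bfloat16 tensor to the MPS backend fails.
x = torch.randn(4, 4, dtype=torch.bfloat16)

if torch.backends.mps.is_available():
    try:
        x.to("mps")  # may raise: "Trying to convert BFloat16 to the MPS
                     # backend but it does not have support for that dtype."
    except (TypeError, RuntimeError) as e:
        print(e)

    # Casting to a dtype MPS does support *before* the transfer avoids the error.
    y = x.to(torch.float16).to("mps")
    print(y.dtype, y.device)
```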

Related issue number (if applicable)

#2295

Checks

  • I've run format.sh to lint the changes in this PR.
  • I've included any doc changes needed.
  • I've made sure the relevant tests are passing (if applicable).

@merrymercy (Member)

It seems your modification is not related to the picture you posted in your issue.
In your issue, the picture shows "int32 vs. int64". However, in your code, you changed float and half.
Moreover, in the two "if" branches you added, you added ".half()" in one "if" branch and one "else" branch. Why is this?

@vaxilicaihouxian (Contributor, Author)

> It seems your modification is not related to the picture you posted in your issue. In your issue, the picture shows "int32 vs. int64". However, in your code, you changed float and half. Moreover, in the two "if" branches you added, you added ".half()" in one "if" branch and one "else" branch. Why is this?

On macOS (usually device mps), bfloat16 is not supported. Maybe that is not the exact reason, but with this change, python -m fastchat.serve.cli --model-path /my/mac/path/llm_models/chatglm2-6b --load-8bit --device mps works on my MacBook (M2).
If I don't change these two lines, it shows the error "Trying to convert BFloat16 to the MPS backend but it does not have support for that dtype."
BTW, I'm not very experienced with this, just reporting my experience. :)
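
If it helps, support can also be probed at runtime. The helper below is a hypothetical sketch (mps_supports_bfloat16 is not a FastChat function); it just tries to allocate a bfloat16 tensor on MPS and reports whether the current build accepts it:

```python
import torch

def mps_supports_bfloat16() -> bool:
    """Return True if this PyTorch/macOS build can place bfloat16 tensors on MPS."""
    if not torch.backends.mps.is_available():
        return False
    try:
        torch.zeros(1, dtype=torch.bfloat16, device="mps")
        return True
    except (TypeError, RuntimeError):
        return False

print(mps_supports_bfloat16())
```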

@merrymercy (Member)

In the two "if" branches you added, you added ".half()" in one "if" branch and one "else" branch. Why is this?

@vaxilicaihouxian (Contributor, Author)

vaxilicaihouxian commented Aug 28, 2023

> In the two "if" branches you added, you added ".half()" in one "if" branch and one "else" branch. Why is this?

Oh, that's my mistake. I had added .half() to both of those "if" branches under the mps device in my local code base. Sorry about that. It is fixed now.

@@ -167,12 +167,18 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
    tmp_state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
    for name in tmp_state_dict:
        if name in linear_weights:
            tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
merrymercy (Member)

Could you try something like this instead of if/else?
tensor = tmp_state_dict[name].to(torch_dtype).to(device).data

If it works, apply the same change to L178.
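
To spell out why the ordering matters (my reading of the suggestion, not wording from this PR): casting on the source device first means the MPS backend never has to hold a bfloat16 tensor, whereas transferring first asks MPS to store a dtype it may not support. A runnable sketch with toy stand-ins for the variables in load_compress_model:

```python
import torch

# Hypothetical stand-ins for the variables used in load_compress_model:
tmp_state_dict = {"layer.weight": torch.randn(4, 4, dtype=torch.bfloat16)}
torch_dtype = torch.float16
device = "mps" if torch.backends.mps.is_available() else "cpu"
name = "layer.weight"

# 1) Transfer first, cast second: the device must briefly hold the
#    checkpoint's original (possibly bfloat16) dtype, which can fail on
#    older PyTorch/macOS MPS builds.
# tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)

# 2) Cast first, transfer second (the reviewer's suggestion): the tensor is
#    already a supported dtype (float16 here) before it reaches the device.
tensor = tmp_state_dict[name].to(torch_dtype).to(device).data
print(tensor.dtype, tensor.device)
```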

vaxilicaihouxian (Contributor, Author)

I just fixed it by inlining the .to() call and removing the mps condition. It works for chatglm2-6b:
[screenshot]

There is no need to check the device type (mps); just use the dtype argument.
It works for the raw chatglm2-6b model from Hugging Face.
merrymercy merged commit 42be87e into lm-sys:main on Aug 29, 2023
1 check passed