
Add configs to run int4 inference #37

Open · wants to merge 5 commits into base: main

Conversation

RezaYazdaniAminabadi

Add some minor config changes to support int4 inference through DeepSpeed-Inference.

The Int4 support will be added to DeepSpeed through this PR.

cc: @stas00

Contributor

@stas00 stas00 left a comment


Amazing work with adding int4 support, Reza!

@@ -191,6 +191,7 @@ def write_checkponts_json():
mp_size=world_size,
base_dir=repo_root,
dtype=getattr(torch, infer_dtype),
+ quantization_bits=8 if args.dtype == 'int8' else 4,
Contributor


what happens with --dtype float16?

probably best to set this in kwargs only if quantization dtype is provided


The quantization_bits argument should not be used when running in half-precision. But I agree, we can set it in the kwargs only for the quantized inference mode.
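A minimal sketch of that conditional-kwargs idea (variable names follow the script above; the exact kwarg name is whatever DeepSpeed ends up accepting, so treat this as illustrative rather than the final API):

```python
# Sketch: only pass quantization settings when a quantized dtype was requested.
# `args`, `world_size`, `repo_root` and `infer_dtype` come from the surrounding script;
# `quantization_bits` mirrors the kwarg added in this PR and may still change.
import torch
import deepspeed

kwargs = dict(
    mp_size=world_size,
    base_dir=repo_root,
    dtype=getattr(torch, infer_dtype),
)
if args.dtype in ("int8", "int4"):
    kwargs["quantization_bits"] = 8 if args.dtype == "int8" else 4

model = deepspeed.init_inference(model, **kwargs)
```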

Contributor

@stas00 stas00 Nov 18, 2022


these demos are already used by many users, so let's make them nice and clean configuration-wise, so it's clear to the reader which bits should be enabled when.

@@ -227,7 +228,7 @@ def write_checkponts_json():
# dynamically extend to support larger bs by repetition
input_sentences *= math.ceil(args.batch_size / len(input_sentences))

- generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)
+ generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=True)
Contributor


this is already a very different type of change.

If int4 requires do_sample=True, then again, let's change it only if it's --dtype int4
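A one-line sketch of that conditional change, assuming the existing args.dtype flag (illustrative only):

```python
# Sketch: enable sampling only for the int4 path, keep greedy decoding otherwise.
generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=(args.dtype == "int4"))
```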


Sure, I will double-check do_sample=False again to see if the generated text makes sense. If not, I'll set it to True for int4.


I just checked with do_sample=False and I see the text is produced in the same way as for FP16 and INT8:

in=DeepSpeed is a machine learning framework
out=DeepSpeed is a machine learning framework for deep learning. It is a Python library, and it is also a framework. It is a framework, and it is a library. It is a framework, and it is a library. It is a framework, and it is a library. It is a framework, and it is a library. It is a framework, and it is a library. It is a framework, and it is a library. It is a framework, and it is a library. It is a framework, and

So I am going to leave it turned off (do_sample=False) for now.

@stas00
Contributor

stas00 commented Nov 18, 2022

Also, we should probably assert if int4 is attempted to be used w/o deepspeed>=xyz once the DS PR is merged... could tentatively set it to the next deepspeed version? Perhaps with an XXX marker so it can be enabled later, keeping the script usable against ds@master.

I can take care of that.
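A rough sketch of such a guard (the minimum version string below is a placeholder, since the matching DeepSpeed release doesn't exist yet):

```python
# Sketch: refuse to run int4 against a DeepSpeed build that doesn't support it.
from packaging import version
import deepspeed

MIN_DS_VERSION = "0.0.0"  # XXX: placeholder -- bump once the DeepSpeed int4 PR is released

if args.dtype == "int4":
    assert version.parse(deepspeed.__version__) >= version.parse(MIN_DS_VERSION), (
        f"--dtype int4 requires deepspeed>={MIN_DS_VERSION}, "
        f"found {deepspeed.__version__}"
    )
```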

@RezaYazdaniAminabadi
Author

> Also, we should probably assert if int4 is attempted to be used w/o deepspeed>=xyz once the DS PR is merged... could tentatively set it to the next deepspeed version? Perhaps with an XXX marker so it can be enabled later, keeping the script usable against ds@master.
>
> I can take care of that.

Sounds good to me. Thanks @stas00

@@ -100,7 +100,7 @@ def get_checkpoint_files(model_name_or_path):


model_name = args.name
- infer_dtype = args.dtype
+ infer_dtype = args.dtype if args.dtype != 'int4' else 'int8'
Contributor

@stas00 stas00 Nov 18, 2022


would it make for a more user-friendly API to

  1. keep the dtype intact
  2. drop quantization_bits
  3. let deepspeed.init_inference derive the number of bits from dtype?

Not only is the currently suggested override confusing, I also fail to see what purpose is served by carrying the same information twice, in both dtype and quantization_bits.

Contributor


oh, wait, torch.int4 still doesn't exist, does it?

let's find the feature request.

Contributor


still not implemented pytorch/pytorch#74627

so that's why you had to do the odd workarounds, right?

Collaborator


I guess we can drop it once it's implemented, @stas00?
For now, this might be the best way to do it.

Contributor


it's pointless to wait, since they won't have int3 and int12



> would it make for a more user-friendly API to
>
>   1. keep the dtype intact
>   2. drop quantization_bits
>   3. let deepspeed.init_inference derive the number of bits from dtype?
>
> Not only is the currently suggested override confusing, I also fail to see what purpose is served by carrying the same information twice, in both dtype and quantization_bits.

@stas00 and @RezaYazdaniAminabadi - just clarifying that we have introduced a new DeepSpeedInferenceConfig that can be passed to init_inference. We are keeping it backwards compatible, but if we are okay with making changes to this file, I would advocate for writing a config dictionary for DeepSpeed and passing that to init_inference instead of the various kwargs. Please see here for an example: https://gist.github.com/awan-10/6e3d5c756be3a876522e860c6bbf702d#file-bloom-ds-inference-py-L173

Also, see the docs for the new config: https://deepspeed.readthedocs.io/en/latest/inference-init.html
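A rough sketch of the config-dict style being suggested here; the key names below are assumptions based on the linked gist and docs rather than this PR, so double-check them against DeepSpeedInferenceConfig:

```python
# Sketch: collect the inference settings in one dict and hand it to init_inference,
# instead of passing individual kwargs. Key names are illustrative -- see the docs above.
import torch
import deepspeed

ds_config = {
    "dtype": torch.float16,                       # original model dtype
    "tensor_parallel": {"tp_size": world_size},   # tensor-parallel degree (from the script)
    "replace_with_kernel_inject": True,           # use DeepSpeed's fused inference kernels
}

model = deepspeed.init_inference(model, config=ds_config)
```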

Contributor


That definitely works.

@awan-10, may I suggest you make the inference config accept dict_or_path just like zero does? It might be easier for some users to write out a separate file.



@stas00 - thanks for the suggestion. Created an issue so we can track it: microsoft/DeepSpeed#2532. Mike and I will work on it.

Contributor


Thank you very much, @awan-10

@stas00
Contributor

stas00 commented Nov 19, 2022

OK, I think I understand the limitations of pytorch, and it'll only get worse when you try int3, etc., even if int4 is supported.
https://github.com/huggingface/transformers-bloom-inference/pull/37/files#r1026981222

I propose we break the currently proposed API and draw up a better one.

I propose to have only 2 user-configurable args related to how deepspeed-inference operates:

  1. dtype is the dtype of the original model - so only fp32, fp16 or bf16 - never intX (i.e. we drop int8)
  2. quantization_bits: [None, 12, 8, 4, 3]

Now the API is simple, unambiguous and future-proof (as in int12 or int3, or Mixture-of-Precisions support)

For back-compat deepspeed.init_inference can simply set quantization_bits=8 if dtype==torch.int8 is passed. So the API will be unbroken.
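A sketch of what the call site could look like under this proposal; argument names follow the two items above, args.quantization_bits is a hypothetical CLI flag, and the back-compat shim is pseudocode rather than DeepSpeed's actual behaviour:

```python
# Sketch of the proposed two-argument interface (illustrative only).
import torch
import deepspeed

kwargs = dict(
    mp_size=world_size,
    base_dir=repo_root,
    dtype=torch.float16,                 # always the real model dtype: fp32/fp16/bf16
)
if args.quantization_bits is not None:   # e.g. 12, 8, 4 or 3; None means no quantization
    kwargs["quantization_bits"] = args.quantization_bits

model = deepspeed.init_inference(model, **kwargs)

# Back-compat idea (inside deepspeed.init_inference, as pseudocode):
#   if dtype == torch.int8 and quantization_bits is None:
#       quantization_bits = 8
```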

What do you think, Reza?

@mayank31398
Collaborator

Huh?
Int4?
I will surely test this branch and let you know.
Thanks a lot for this :)

@RezaYazdaniAminabadi
Author

> Now the API is simple, unambiguous and future-proof (as in int12 or int3, or Mixture-of-Precisions support)

Hi @stas00,
I agree with what you said, and we are going down the same route, as you can see from my last commit here.
Thanks for the good suggestion :)
Best,
Reza

@RezaYazdaniAminabadi
Author


In this case, we can simply pass the bits to the DeepSpeed-inference config: kwargs['quant']['weight']['num_bits'] = quantization_bits

@stas00
Contributor

stas00 commented Nov 19, 2022

may I suggest that the just-added kwargs['quant']['weight']['num_bits'] isn't the most user-friendly API as far as kwargs go?

why not have a flat structure of simple key=value pairs? Once you've got the info on your side you can re-arrange it to any nesting level you want.
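A minimal sketch of that flat-to-nested idea; build_inference_config is a hypothetical helper, and the nested quant.weight.num_bits path is the one mentioned just above:

```python
# Sketch: the user passes flat key=value pairs; the library rearranges them
# into whatever nested config it needs internally.
def build_inference_config(dtype, quantization_bits=None):
    config = {"dtype": dtype}
    if quantization_bits is not None:
        # nested layout mentioned in this thread: quant -> weight -> num_bits
        config["quant"] = {"weight": {"num_bits": quantization_bits}}
    return config

# User-facing call stays flat and simple:
ds_config = build_inference_config(dtype="fp16", quantization_bits=4)
```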

@RezaYazdaniAminabadi
Author

> may I suggest that the just-added kwargs['quant']['weight']['num_bits'] isn't the most user-friendly API as far as kwargs go?
>
> why not have a flat structure of simple key=value pairs? Once you've got the info on your side you can re-arrange it to any nesting level you want.

I agree, let me work on that and fix it.

@awan-10

awan-10 commented Nov 19, 2022

> may I suggest that the just-added kwargs['quant']['weight']['num_bits'] isn't the most user-friendly API as far as kwargs go?
> why not have a flat structure of simple key=value pairs? Once you've got the info on your side you can re-arrange it to any nesting level you want.
>
> I agree, let me work on that and fix it.

@RezaYazdaniAminabadi -- please see my comment above. #37 (comment)

@RezaYazdaniAminabadi
Author

> may I suggest that the just-added kwargs['quant']['weight']['num_bits'] isn't the most user-friendly API as far as kwargs go?
> why not have a flat structure of simple key=value pairs? Once you've got the info on your side you can re-arrange it to any nesting level you want.
>
> I agree, let me work on that and fix it.
>
> @RezaYazdaniAminabadi -- please see my comment above. #37 (comment)

Thanks @awan-10. Please go ahead and push your changes.
