@@ -27,7 +27,7 @@ pip install -r requirements.txt
### Demo (`MXFP4`, `MXFP8`, `NVFP4`, `uNVFP4`)

```bash
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --accuracy
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype MXFP4 --batch_size 8 --accuracy --enable_torch_compile
```
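
The same demo covers the other schemes by swapping `--dtype`; a hedged variant, assuming the scheme names in the heading are accepted verbatim (the script maps `uNVFP4`/`NVFP4+` to a custom scheme internally):

```bash
# Sketch only: the demo command with a different scheme selected.
python quantize.py --model_name_or_path facebook/opt-125m --quantize --dtype uNVFP4 --batch_size 8 --accuracy --enable_torch_compile
```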

### Mix-precision Quantization (`MXFP4 + MXFP8`)
@@ -41,10 +41,11 @@ python quantize.py \
--use_recipe \
--recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
--accuracy \
--batch_size 32
--batch_size 32 \
--enable_torch_compile

# Llama 3.3 70B
deepspeed --include="localhost:4,5,6,7" --master_port=29500 python quantize.py \
deepspeed --include="localhost:0,1,2,3" --master_port=29500 quantize.py \
--model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
--quantize \
--dtype MXFP4 \
@@ -111,13 +112,13 @@ Model with mixed precision is not supported in vLLM, but supported in transformers
python quantize.py \
--model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
--quantize \
--iters 0 \
--dtype MXFP4 \
--use_recipe \
--recipe_file recipes/Meta-Llama-3.1-8B-Instruct_7bits.json \
--save \
--save_format auto_round \
--save_path Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR
--save_path Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR \
--enable_torch_compile

# Command to inference with transformer:
python run_hf_inf.py Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR
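
For reference, a minimal sketch of loading the saved checkpoint straight through the transformers API (this is not the contents of `run_hf_inf.py`; the prompt and generation settings are illustrative, and `auto-round` must be installed so the quantized layers deserialize):

```python
# Sketch only: reload the mixed-precision checkpoint saved above and generate.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Llama-3.1-8B-Instruct-MXFP4-MXFP8-AR"  # --save_path from the command above
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```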
@@ -71,12 +71,14 @@ def initialize_model_and_tokenizer(model_name_or_path):
parser.add_argument("--device_map", type=str, default=None, help="device map for model")
parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
parser.add_argument("--mem_per_param_scale", default=13, type=int, help="memory per param scale factor")
**Contributor:** I don't see this arg used in the example. Was it added for further tuning? Is there any guideline on how a user should set the value?

**Contributor (author):** Yes, it's for the Llama 3.3 70B pipeline-parallel run; it was added in case a user wants to run 70B without TP. It's not the suggested way, though: the suggested way is to use the main branch with my fix for torch.compile, so I intend not to introduce it in the example.

parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
parser.add_argument("--seqlen", default=2048, type=int, help="sequence length for autoround.")
parser.add_argument("--nsamples", default=128, type=int, help="number of samples for autoround.")
parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
parser.add_argument("--save_format", type=str, default="auto_round", help="format to save the quantized model")
parser.add_argument("--enable_torch_compile", action="store_true", help="whether to enable torch.compile")
parser.add_argument("--quant_lm_head", action="store_true", help="whether to quantize lm_head")
parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="Local process rank.")
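
For readers who do want the non-recommended pipeline-parallel path discussed in the thread above, a hedged sketch of passing `--mem_per_param_scale` (the value 13 is simply the default added here; how far to raise it for a given GPU budget is a judgment call this PR does not document):

```bash
# Sketch only: 70B without tensor parallelism, nudging AutoRound's memory estimate.
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct/ \
    --quantize \
    --dtype MXFP4 \
    --mem_per_param_scale 13 \
    --enable_torch_compile
```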
@@ -103,21 +105,24 @@ def initialize_model_and_tokenizer(model_name_or_path):
device="hpu" if is_hpex_available() else "cuda"

if args.quantize:
autoround_dtype_mapping = {
"MXFP4": "mx_fp4",
"MXFP8": "mx_fp8",
"NVFP4": "nv_fp4",
"uNVFP4": "fp4_v2",
"NVFP4+": "fp4_v2",
}
args.dtype = autoround_dtype_mapping[args.dtype]
if args.dtype in ["uNVFP4", "NVFP4+"]:
from auto_round.schemes import QuantizationScheme

uNVFP4 = QuantizationScheme.from_dict(
{
"bits": 4,
"group_size": 16,
"data_type": "fp4_v2",
"act_bits": 4,
"act_data_type": "fp4_v2",
"act_group_size": 16,
"act_sym": True,
}
)
args.dtype = uNVFP4

if args.quant_lm_head:
lm_head_config = {
"group_size": 32 if "mx" in args.dtype else 16,
"data_type": args.dtype,
"act_data_type": "fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
}
layer_config = {"lm_head": lm_head_config}
layer_config = {"lm_head": args.dtype}

autoround = AutoRound(
model,
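
The net effect of this hunk: `--dtype` now feeds AutoRound's `scheme=` argument directly, either as a named scheme string or, for `uNVFP4`/`NVFP4+`, as a `QuantizationScheme` built from a dict. A condensed sketch using the demo model; the `AutoRound(model, tokenizer, ...)` ordering and the standalone usage are assumptions, not code from this PR:

```python
# Sketch only: string scheme vs. custom QuantizationScheme, both passed via scheme=.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
from auto_round.schemes import QuantizationScheme

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Custom micro-scaled FP4 scheme, mirroring the uNVFP4 definition in the diff above.
uNVFP4 = QuantizationScheme.from_dict(
    {
        "bits": 4,
        "group_size": 16,
        "data_type": "fp4_v2",
        "act_bits": 4,
        "act_data_type": "fp4_v2",
        "act_group_size": 16,
        "act_sym": True,
    }
)

scheme = "MXFP4"  # or: scheme = uNVFP4
autoround = AutoRound(model, tokenizer, scheme=scheme, enable_torch_compile=True)
autoround.quantize()
quantized_model = autoround.model
```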
@@ -128,10 +133,10 @@ def initialize_model_and_tokenizer(model_name_or_path):
seqlen=args.seqlen,
nsamples=args.nsamples,
low_gpu_mem_usage=True,
group_size=32 if "mx" in args.dtype else 16,
data_type=args.dtype,
act_data_type="fp4_v2_with_global_scale" if "fp4_v2" in args.dtype else args.dtype,
scheme=args.dtype,
layer_config=layer_config if args.quant_lm_head else None,
enable_torch_compile=args.enable_torch_compile,
mem_per_param_scale=args.mem_per_param_scale,
)

if args.use_recipe:
@@ -140,20 +145,16 @@ def load_recipe_results(file_path):
import json
with open(file_path, "r") as f:
return json.load(f)

layer_config = load_recipe_results(args.recipe_file)
if args.quant_lm_head:
mxfp8_config = {
"bits": 8,
"group_size": 32,
"data_type": "mx_fp8",
"act_data_type": "mx_fp8",
}
# ensure lm_head is quantized with mxfp8_config
layer_config.update({"lm_head": mxfp8_config})
layer_config.update({"lm_head": "MXFP8"})
print("In recipe mode, lm_head is quantized with MXFP8.")
autoround.layer_config = layer_config

# A placeholder, to pass assertion in AutoRound
autoround.formats = "auto_round"
autoround.quantize()
model = autoround.model
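
For orientation, a hedged sketch of what a recipe file such as `recipes/Meta-Llama-3.1-8B-Instruct_7bits.json` is assumed to look like: a JSON object mapping layer names to per-layer settings, loaded verbatim into `autoround.layer_config`. The layer names and field choices below are illustrative, not copied from the shipped recipes:

```python
# Sketch only: write a miniature recipe in the assumed layer_config shape.
import json

recipe = {
    # hypothetical layer names; real recipes enumerate the model's actual modules
    "model.layers.0.self_attn.q_proj": {"bits": 8, "group_size": 32, "data_type": "mx_fp8", "act_data_type": "mx_fp8"},
    "model.layers.0.mlp.down_proj": {"bits": 4, "group_size": 32, "data_type": "mx_fp4", "act_data_type": "mx_fp4"},
    "lm_head": "MXFP8",  # a plain scheme name per layer; the script adds this entry itself when --quant_lm_head is set
}

with open("recipes/example_recipe.json", "w") as f:
    json.dump(recipe, f, indent=2)
```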

@@ -192,7 +193,6 @@ def load_recipe_results(file_path):
else:
# CUDA evaluation support all tasks.
# gsm8k requires add_bos_token=False for better accuracy for llama model.
# model = torch.compile(model)
args.tasks = ["piqa", "hellaswag", "mmlu", "gsm8k"]
all_accuracy = {}
test_gsm8k = False
Expand Down Expand Up @@ -243,7 +243,7 @@ def load_recipe_results(file_path):
print(f"Overall accuracy: {sum(all_accuracy.values())/len(all_accuracy):.4f}")

if args.save:
if args.dtype == "nv_fp4":
if args.dtype == "NVFP4":
# using llm_compressor format to save nv_fp4 model
autoround.save_quantized(args.save_path, format=args.save_format)
else:
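
The `else:` branch is cut off below; the branch shown here saves `NVFP4` checkpoints via `save_quantized`, and the inline comment points at the llm_compressor layout. A hedged end-to-end command; the `--save_format llm_compressor` value is an assumption inferred from that comment, not shown in the visible diff:

```bash
# Sketch only: quantize to NVFP4 and save.
python quantize.py \
    --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
    --quantize \
    --dtype NVFP4 \
    --save \
    --save_format llm_compressor \
    --save_path Llama-3.1-8B-Instruct-NVFP4 \
    --enable_torch_compile
```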