@@ -0,0 +1,136 @@
# Quantization Aware Training (QAT)

Quantization-Aware Training (QAT) is a technique designed to bridge the accuracy gap often observed with Post-Training Quantization (PTQ). Unlike PTQ, which applies quantization after model training, QAT simulates the effects of low-precision arithmetic during the training process itself. This allows the model to adapt its weights and activations to quantization constraints, significantly reducing accuracy degradation. As a result, QAT is particularly effective in preserving model performance even at extremely low precisions, such as MXFP8 or MXFP4, making it a critical approach for deploying efficient, high-performance models on resource-constrained hardware.
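
To make the mechanism concrete, the following is a minimal, self-contained PyTorch sketch of fake quantization with a straight-through estimator. It uses an illustrative int8-style affine scheme and made-up names; it is not the MXFP8/MXFP4 implementation used by this example.

```python
import torch

class FakeQuantize(torch.autograd.Function):
    """Quantize-dequantize in the forward pass; pass gradients straight through."""

    @staticmethod
    def forward(ctx, x, scale, qmin, qmax):
        # Round onto the low-precision grid, then map back to the float domain,
        # so the rest of the network "sees" quantization error during training.
        q = torch.clamp(torch.round(x / scale), qmin, qmax)
        return q * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat the non-differentiable rounding as identity.
        return grad_output, None, None, None

x = torch.randn(4, requires_grad=True)
y = FakeQuantize.apply(x, torch.tensor(0.05), -128, 127)
y.sum().backward()  # gradients still reach x despite the rounding step
```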

## Prerequisites

Install the requirements for the example:

```bash
pip install -r requirements.txt
```

## Getting Started

In QAT, a model quantized using `prepare_qat()` can be directly fine-tuned with the original training pipeline. During QAT, the scaling factors inside quantizers are frozen and the model weights are fine-tuned.
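
The snippet below condenses how `main.py` in this example wires that together: load the model, call `prepare_qat()` in place, fine-tune with the usual `Trainer`, and export a vLLM-loadable checkpoint with `export_hf2compressored_model` (the same call `main.py` makes). Paths, hyperparameters, and the dataset wiring are illustrative and must be filled in from your own pipeline.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from neural_compressor.torch.quantization.quantize import prepare_qat
from neural_compressor.torch.export.export_hf import export_hf2compressored_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# Insert fake-quantization modules in place; the existing training pipeline stays unchanged.
prepare_qat(model)

trainer = Trainer(
    model=model,
    processing_class=tokenizer,
    args=TrainingArguments(output_dir="./qat-out", learning_rate=1e-5, bf16=True),
    # train_dataset=..., eval_dataset=...,  # fill in before calling trainer.train()
)
trainer.train()

# Export a compressed-tensors checkpoint that vLLM can serve.
export_hf2compressored_model(model, "./qat-out", "MXFP8")
tokenizer.save_pretrained("./qat-out")
```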

### Hugging Face QAT

#### QAT

##### Step 1:

Start by training or fine-tuning your model in its original precision (e.g., BF16). This establishes a strong baseline before introducing quantization.

```bash
accelerate launch --config-file accelerate_config/fsdp1.yaml \
--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
main.py \
--model_name_or_path meta-llama/Llama-3.1-8B \
--model_max_length 4096 \
--dataloader_drop_last True \
--do_train True \
--do_eval True \
--output_dir ./llama3.1-finetuned \
--dataset Daring-Anteater \
--num_train_epochs 2.0 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--save_strategy steps \
--save_steps 3000 \
--eval_strategy steps \
--eval_steps 3000 \
--load_best_model_at_end True \
--save_total_limit 2 \
--learning_rate 1e-5 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type linear \
--logging_steps 1 \
--report_to tensorboard
```

##### Step 2:

Quantize the trained model using `prepare_qat()` by setting `--quant_scheme MXFP8 --do_train False`. This inserts fake-quantization modules into the model without starting training. Saving the model at this point yields a post-training quantization (PTQ) model.

```bash
accelerate launch --config-file accelerate_config/fsdp1.yaml \
--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
main.py \
--model_name_or_path ./llama3.1-finetuned \
--model_max_length 4096 \
--dataloader_drop_last True \
--do_train False \
--do_eval False \
--quant_scheme MXFP8 \
--output_dir ./llama3.1-finetuned-ptq \
--dataset Daring-Anteater \
--num_train_epochs 2.0 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--save_strategy steps \
--save_steps 3000 \
--eval_strategy steps \
--eval_steps 3000 \
--load_best_model_at_end True \
--save_total_limit 2 \
--learning_rate 1e-5 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type linear \
--logging_steps 1 \
--report_to tensorboard
```

##### Step 3:

Train/fine-tune the quantized model with a small learning rate, e.g., 1e-5 for the Adam optimizer, by setting `--quant_scheme MXFP8 --do_train True`:

```bash
accelerate launch --config-file accelerate_config/fsdp1.yaml \
--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
main.py \
--model_name_or_path ./llama3.1-finetuned \
--model_max_length 4096 \
--dataloader_drop_last True \
--do_train True \
--do_eval True \
--quant_scheme MXFP8 \
--output_dir ./llama3.1-finetuned-qat \
--dataset Daring-Anteater \
--max_steps 1000 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--save_strategy steps \
--save_steps 3000 \
--eval_strategy steps \
--eval_steps 3000 \
--load_best_model_at_end True \
--save_total_limit 2 \
--learning_rate 1e-5 \
--weight_decay 0.0 \
--warmup_ratio 0.03 \
--lr_scheduler_type linear \
--logging_steps 1 \
--report_to tensorboard
```

#### Evaluation

Once QAT is complete, the saved quantized model can be deployed using vLLM for efficient inference. For example, to evaluate on GSM8K:

```bash
lm_eval \
--model vllm \
--model_args pretrained=./llama3.1-finetuned-qat,tensor_parallel_size=1,data_parallel_size=1,gpu_memory_utilization=0.3,max_model_len=32768,enforce_eager=True \
--tasks gsm8k \
--batch_size 8
```
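
The same exported checkpoint can also be loaded with vLLM's offline API for a quick smoke test. A minimal sketch (prompt, sampling settings, and context length are illustrative):

```python
from vllm import LLM, SamplingParams

# Load the compressed-tensors checkpoint produced by the QAT run.
llm = LLM(model="./llama3.1-finetuned-qat", max_model_len=4096)
outputs = llm.generate(
    ["Question: What is 15 * 12? Answer with just the number."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```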
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -0,0 +1,187 @@
import logging
import os
import sys
from dataclasses import dataclass, field
from warnings import warn

import torch
import transformers
from transformers.trainer_utils import get_last_checkpoint
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    default_data_collator,
    set_seed,
    TrainerCallback,
)

from utils import (
    get_metrics_with_perplexity,
    make_supervised_data_module,
)

logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="meta-llama/Llama-3.1-8B")

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: str | None = field(default=None)
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": (
                "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
            )
        },
    )
    dataloader_drop_last: bool = field(default=True)
    bf16: bool = field(default=True)

@dataclass
class DataArguments:
    dataset: str = field(
        default="Daring-Anteater",
        metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater"]},
    )
    train_size: int = field(
        default=0,
        metadata={"help": "Number of training samples to use. If `0`, use default training size."},
    )
    eval_size: int = field(
        default=0,
        metadata={
            "help": "Number of evaluation samples to use. If `0`, use default evaluation size."
        },
    )

@dataclass
class QuantizationArguments:
    quant_scheme: str | None = field(
        default=None,
        metadata={
            "help": (
                "Specify the quantization format for PTQ/QAT. If specified, PTQ/QAT will be"
                " enabled with the specified quantization format."
            ),
            "choices": ["MXFP8"],
        },
    )


def train():
    parser = HfArgumentParser(
        (ModelArguments, TrainingArguments, DataArguments, QuantizationArguments)
    )

    model_args, training_args, data_args, quant_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    # Set seed before initializing model.
    set_seed(training_args.seed)

    logger.info(f"arguments: {model_args}, {training_args}, {data_args}, {quant_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        logger.info(f"Last checkpoint detected: {last_checkpoint}")

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch.bfloat16,
    )
    model.generation_config.do_sample = True
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, model_max_length=training_args.model_max_length
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # We set model.config.use_cache to False for training when gradient_checkpointing=False.
    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.
    model.config.use_cache = False

    # Prepare the model for quantization-aware training.
    if quant_args.quant_scheme is not None:
        from neural_compressor.torch.quantization.quantize import prepare_qat

        # prepare_qat modifies the model in place; MXFP8 is the default scheme.
        prepare_qat(model)

        logger.info("Finished model preparation for QAT.")

    logger.info("Loading dataset......")

    # reuse the dataset function, TODO: preprocess a new dataset
    data_module = make_supervised_data_module(
        dataset=data_args.dataset,
        tokenizer=tokenizer,
        train_size=data_args.train_size,
        eval_size=data_args.eval_size,
    )

    # Ensure calibration size doesn't exceed evaluation dataset size
    eval_dataset_size = len(data_module["eval_dataset"])

    # Training
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Torch >= 2.4 throws an error if `use_reentrant` is not set explicitly
    if training_args.gradient_checkpointing and training_args.gradient_checkpointing_kwargs is None:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        **data_module,
    )

    if training_args.do_train:
        logger.info("Starting Train...")
        trainer.train(resume_from_checkpoint=checkpoint)
        logger.info("Training completed.")

    if training_args.do_eval:
        logger.info("Starting Evaluation...")
        metrics = trainer.evaluate()
        metrics = get_metrics_with_perplexity(metrics)
        logger.info(f"Evaluation results: \n{metrics}")

    if training_args.do_train and quant_args.quant_scheme is None:
        logger.info("Saving the model...")
        trainer.save_model(training_args.output_dir)
    elif quant_args.quant_scheme is not None:
        from neural_compressor.torch.export.export_hf import export_hf2compressored_model

        # export quantized model for vllm inference using llm-compressor and compressed_tensor
        export_hf2compressored_model(model, training_args.output_dir, quant_args.quant_scheme)
    if tokenizer is not None:
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    train()
@@ -0,0 +1,9 @@
auto-round==0.8.0
neural-compressor-pt==3.6
transformers==4.52.4
datasets
sentencepiece>=0.2.0
tensorboardX
peft
accelerate >= 0.12.0
lm-eval==0.4.9.1