[LLM example] add calib_shuffle args for text-generation example (#1087)
Signed-off-by: Wang, Chang1 <chang1.wang@intel.com>
changwangss committed Dec 28, 2023
1 parent e8170aa commit a4aba8d
Showing 5 changed files with 24 additions and 6 deletions.
@@ -11,6 +11,7 @@
     AutoModel,
 )
 from transformers.utils import check_min_version
+from intel_extension_for_transformers.transformers.utils import str2bool
 from optimum.intel.generation.modeling import TSModelForCausalLM
 from intel_extension_for_transformers.transformers import (
     MixedPrecisionConfig,
@@ -67,6 +68,12 @@
 parser.add_argument(
     "--calib_padding", action="store_true", help="Calibration dataset do padding."
 )
+parser.add_argument(
+    "--calib_shuffle",
+    default=True,
+    type=str2bool,
+    help="Calibration dataset do shuffle.",
+)
 parser.add_argument(
     "--calib_pad_val", default=1, type=int, help="Calibration dataset padding value."
 )
@@ -126,16 +133,14 @@
 parser.add_argument("--load_in_4bit", type=bool, default=False)
 parser.add_argument("--load_in_8bit", type=bool, default=False)
 parser.add_argument("--_commit_hash", default="main", type=str)
-parser.add_argument("--trust_remote_code", default=False)
+parser.add_argument("--trust_remote_code", type=bool, default=False)
 parser.add_argument("--use_llm_runtime", action="store_true")
 # =======================================
 args = parser.parse_args()
-
 # transformers version >= 4.32.0 contained the mpt modeling definition.
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
 # 4.31.0 for ipex.optimize_transformers
 check_min_version("4.31.0")
-
 # get model config
 if args.peft_model_id:
     from peft import PeftConfig
@@ -228,6 +233,7 @@
         op_type_dict=op_type_dict,  # default is {}
         excluded_precisions=excluded_precisions,  # default is []
         num_beams=generate_kwargs["num_beams"],
+        calib_shuffle=args.calib_shuffle,
         calib_iters=args.calib_iters,
         calib_padding=args.calib_padding,
         calib_len=args.calib_len,
@@ -257,7 +263,6 @@
         trust_remote_code=args.trust_remote_code,
         _commit_hash=args._commit_hash,
         use_llm_runtime=args.use_llm_runtime,
-
     )
 elif args.load_in_4bit or args.load_in_8bit:
     # CPU device usage is provided by intel-extension-for-transformers.
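Unlike a plain `action="store_true"` flag, a `str2bool`-typed option can default to True and still be switched off explicitly from the command line. A minimal sketch of the resulting behavior (the standalone parser here is illustrative, not part of the commit):

```python
import argparse

from intel_extension_for_transformers.transformers.utils import str2bool

parser = argparse.ArgumentParser()
parser.add_argument(
    "--calib_shuffle",
    default=True,
    type=str2bool,
    help="Calibration dataset do shuffle.",
)

print(parser.parse_args([]).calib_shuffle)                            # True (default)
print(parser.parse_args(["--calib_shuffle", "false"]).calib_shuffle)  # False
```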
@@ -380,6 +380,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         from torch.utils.data import DataLoader

         calib_dataset = quantization_config.calib_dataset
+        calib_shuffle = quantization_config.calib_shuffle
         calib_iters = quantization_config.calib_iters
         calib_padding = quantization_config.calib_padding
         calib_len = quantization_config.calib_len
@@ -392,7 +393,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 if calib_dataset in ["mbpp", "openai_humaneval"]
                 else "train",
             )
-            calib_dataset = calib_dataset.shuffle(seed=42)
+            if calib_shuffle:
+                calib_dataset = calib_dataset.shuffle(seed=42)

             def tokenize_function(examples):
                 if "prompt" in examples:
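The seed-42 shuffle is now applied only when `calib_shuffle` is enabled, so disabling it keeps calibration samples in their original order. A self-contained sketch of the gating with a toy dataset (not the repository's calibration data):

```python
from datasets import Dataset

calib_shuffle = False  # mirrors quantization_config.calib_shuffle
calib_dataset = Dataset.from_dict({"prompt": [f"sample {i}" for i in range(5)]})

if calib_shuffle:
    # Fixed seed, as in the code above, keeps shuffling reproducible.
    calib_dataset = calib_dataset.shuffle(seed=42)

print(calib_dataset["prompt"])  # original order preserved when shuffling is off
```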
@@ -24,4 +24,4 @@
     SparsityConfig,
     WeightOnlyQuantConfig,
 )
-from .utility import LazyImport, logger
+from .utility import LazyImport, logger, str2bool
@@ -390,6 +390,7 @@ class SmoothQuantConfig:
     tokenizer: Any = None
     calib_func: Any = None
     calib_dataset: str = "NeelNanda/pile-10k"
+    calib_shuffle: bool = True
     calib_iters: int = 100
     calib_padding: bool = False
     calib_len: int = 512
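With the new dataclass field, callers can disable calibration shuffling through the config. A hedged usage sketch (import path assumed from the example script above; the other values restate the defaults shown in the dataclass):

```python
from intel_extension_for_transformers.transformers import SmoothQuantConfig

sq_config = SmoothQuantConfig(
    calib_dataset="NeelNanda/pile-10k",
    calib_shuffle=False,  # field added by this commit; defaults to True
    calib_iters=100,
)
```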
10 changes: 10 additions & 0 deletions intel_extension_for_transformers/transformers/utils/utility.py
@@ -17,6 +17,7 @@

 """Utils for pytorch framework."""

+import argparse
 import os
 from typing import Optional, Tuple
 from neural_compressor.utils import logger
@@ -36,6 +37,15 @@

 torch = LazyImport("torch")

+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')

 def distributed_init(
     backend="gloo",
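A quick check of the spellings the helper accepts (REPL-style sketch; assumes the package is installed):

```python
import argparse

from intel_extension_for_transformers.transformers.utils import str2bool

assert str2bool("YES") is True and str2bool("t") is True and str2bool("1") is True
assert str2bool("no") is False and str2bool("F") is False
assert str2bool(True) is True  # bool inputs pass through unchanged
try:
    str2bool("maybe")
except argparse.ArgumentTypeError as err:
    print(err)  # Boolean value expected.
```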
