update sparseGPT example (#1408)
Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
Co-authored-by: kevinintel <hanwen.chang@intel.com>
WeiweiZhang1 and kevinintel committed Mar 25, 2024
1 parent ef0882f commit 3ae0cd0
Showing 2 changed files with 42 additions and 24 deletions.
@@ -22,7 +22,7 @@ def skip(*args, **kwargs):
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

import re
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import logging
@@ -69,7 +69,7 @@ def parse_args():
parser.add_argument(
"--calibration_dataset_name",
type=str,
default="wikitext-2-raw-v1",
default="NeelNanda/pile-10k", # e.g. wikitext-2-raw-v1
help="The name of the pruning dataset to use (via the datasets library).",
)
parser.add_argument(
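Not part of the commit: a minimal sketch of how the new default calibration set can be loaded with the datasets library (the helper name and the slash heuristic are illustrative; NeelNanda/pile-10k is a standalone dataset repo with a single train split, while wikitext-2-raw-v1 is a configuration of the wikitext dataset).

from datasets import load_dataset

def load_calibration_split(name: str):
    # Hub-style names such as "NeelNanda/pile-10k" are standalone repos with a
    # single train split; wikitext-style names are configurations of "wikitext".
    if "/" in name:
        return load_dataset(name, split="train")
    return load_dataset("wikitext", name, split="train")

# usage: raw_dataset = load_calibration_split(args.calibration_dataset_name)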
@@ -128,6 +128,12 @@ def parse_args():
default=16,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--calib_size",
type=int,
default=128,
help="sample size for the calibration dataset.",
)
parser.add_argument(
"--learning_rate",
type=float,
@@ -268,10 +274,12 @@ def parse_args():
parser.add_argument("--tasks", default=["lambada_openai"],
help="Usually chosen with ['lambada_openai','hellaswag','winogrande','piqa']",
)
parser.add_argument("--eval_fp16", action='store_true',
help=" fp16")
parser.add_argument("--use_accelerate", action='store_true',
help="Usually use to accelerate evaluation for large models")
help="Usually use to accelerate evaluation for large models"
)
parser.add_argument("--eval_dtype", default='fp32',
help="choose in bf16, fp16 and fp32"
)

args = parser.parse_args()
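Not part of the commit: a minimal sketch of how the new --eval_dtype choice maps to a torch dtype and to the dtype string used later in the lm-eval model_args (the helper name is hypothetical; the fp32 fallback mirrors the else branch further down).

import torch

_EVAL_DTYPES = {
    "bf16": (torch.bfloat16, "bfloat16"),
    "fp16": (torch.float16, "float16"),
    "fp32": (torch.float32, "float32"),
}

def resolve_eval_dtype(flag: str):
    # Unrecognized values fall back to fp32, as in the script's else branch.
    return _EVAL_DTYPES.get(flag, _EVAL_DTYPES["fp32"])

# usage: torch_dtype, dtype_name = resolve_eval_dtype(args.eval_dtype)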

@@ -376,34 +384,33 @@ def main():
logger.warning("You are instantiating a new config instance from scratch.")

is_llama = bool("llama" in args.model_name_or_path)
is_t5 = bool("t5" in args.model_name_or_path)
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
elif args.model_name_or_path:
if is_llama:
tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path)
else :
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path,
use_fast=not args.use_slow_tokenizer, trust_remote_code=True)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

if args.model_name_or_path:
if is_t5:
model = T5ForConditionalGeneration.from_pretrained(
args.model_name_or_path,
config=config,
)
if re.search("chatglm", args.model_name_or_path.lower()):
model = AutoModel.from_pretrained(args.model_name_or_path,
trust_remote_code=args.trust_remote_code) # .half()
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
trust_remote_code=args.trust_remote_code,
low_cpu_mem_usage=args.low_cpu_mem_usage,
low_cpu_mem_usage=args.low_cpu_mem_usage
)


else:
logger.info("Training new model from scratch")
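Not part of the commit: a standalone sketch of the updated loading branch, which drops the T5 path and routes ChatGLM checkpoints through AutoModel (the function name and the simplified argument handling are illustrative).

import re
from transformers import AutoModel, AutoModelForCausalLM

def load_base_model(model_name_or_path: str, trust_remote_code: bool = False):
    # ChatGLM repos ship their architecture as remote code, so the generic
    # AutoModel entry point is used; all other checkpoints load as causal LMs.
    if re.search("chatglm", model_name_or_path.lower()):
        return AutoModel.from_pretrained(model_name_or_path,
                                         trust_remote_code=trust_remote_code)
    return AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                trust_remote_code=trust_remote_code)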
@@ -492,7 +499,7 @@ def group_texts(examples):
train_dataset = lm_datasets["train"]

# DataLoaders creation:
train_dataset = train_dataset.shuffle(seed=42).select(range(128))
train_dataset = train_dataset.shuffle(seed=42).select(range(args.calib_size))
total_batch_size = args.per_device_train_batch_size
if local_rank != -1:
total_batch_size *= WORLD_SIZE
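Not part of the commit: a minimal sketch of building the calibration dataloader around the new --calib_size argument, assuming a tokenized datasets Dataset (the collator choice and the size guard are illustrative).

from torch.utils.data import DataLoader
from transformers import default_data_collator

def build_calib_dataloader(train_dataset, calib_size: int, batch_size: int):
    # Deterministic sampling (seed=42, as in the script), guarded against
    # datasets smaller than the requested calibration size.
    calib_size = min(calib_size, len(train_dataset))
    calib_set = train_dataset.shuffle(seed=42).select(range(calib_size))
    return DataLoader(calib_set, collate_fn=default_data_collator, batch_size=batch_size)

# usage: train_dataloader = build_calib_dataloader(train_dataset, args.calib_size,
#                                                  args.per_device_train_batch_size)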
@@ -543,8 +550,10 @@ def group_texts(examples):
torch.backends.cudnn.allow_tf32 = False
use_cache = model.config.use_cache
model.config.use_cache = False

import time
s = time.time()
pruning = prepare_pruning(model, configs, dataloader=train_dataloader, device=device)
logger.info(f"cost time: {time.time() - s}")
model.config.use_cache = use_cache

if args.output_dir is not None:
@@ -555,20 +564,28 @@ def group_texts(examples):
logger.info(f"The model has been exported to {output_dir}")

if device != 'cpu':
model = model.to(device)
if not args.use_accelerate:
model = model.to(device)
else:
model = model.cpu()
logger.info(f"***** Evaluation in GPU mode. *****")
else:
logger.info(f"***** Evaluation in CPU mode. *****")
model.eval()

model_name = args.model_name_or_path
dtype = 'float32'
if args.eval_fp16:
if (hasattr(model, 'config') and model.config.torch_dtype is torch.bfloat16):
dtype = 'bfloat16'
else:
dtype='float16'
model_args = f'pretrained={model_name},tokenizer={model_name},dtype={dtype},use_accelerate={args.use_accelerate}'
dtype = None
if args.eval_dtype == 'bf16':
model = model.to(dtype=torch.bfloat16)
dtype = 'bfloat16'
elif args.eval_dtype == 'fp16':
dtype = 'float16'
model = model.to(dtype=torch.float16)
else:
dtype = 'float32'
model = model.to(dtype=torch.float32)

model_args = f'pretrained={model_name},tokenizer={model_name},dtype={dtype},use_accelerate={args.use_accelerate},trust_remote_code={args.trust_remote_code}'
eval_batch = args.per_device_eval_batch_size
user_model = None if args.use_accelerate else model
results = evaluate(
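Not part of the commit: a consolidated sketch of the evaluation placement these hunks introduce (the helper name is hypothetical). The assumption is that with --use_accelerate the harness loads and shards its own copy of the checkpoint, so the local model stays on CPU and user_model is left as None; otherwise the in-memory pruned model is moved to the target device and evaluated directly.

def place_for_eval(model, device, use_accelerate: bool):
    # Keep the model on CPU when accelerate will manage device placement itself;
    # otherwise move the pruned model to the requested device for evaluation.
    if device != "cpu" and not use_accelerate:
        model = model.to(device)
    else:
        model = model.cpu()
    return model.eval()

# usage: model = place_for_eval(model, device, args.use_accelerate)
#        user_model = None if args.use_accelerate else model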
@@ -11,10 +11,11 @@ export CUBLAS_WORKSPACE_CONFIG=':4096:8'
#cd intel-extension-for-transformers
python examples/huggingface/pytorch/language-modeling/pruning/run_clm_sparsegpt.py \
--model_name_or_path /PATH/TO/LLM/ \
--calibration_dataset_name wikitext-2-raw-v1 \
--do_prune \
--device=0 \
--output_dir=/PATH/TO/SAVE/ \
--eval_dtype 'bf16' \
--per_device_eval_batch_size 16 \
--target_sparsity 0.5 \
--pruning_pattern 1x1
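Not part of the commit: a hedged sanity check that can be run after the command above, measuring the realized sparsity of the exported model's Linear layers (the path reuses the --output_dir placeholder).

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/PATH/TO/SAVE/")  # same path as --output_dir
zeros, total = 0, 0
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        zeros += int((module.weight == 0).sum())
        total += module.weight.numel()
print(f"Linear-weight sparsity: {zeros / total:.2%}")  # roughly 0.5 for --target_sparsity 0.5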
