Add WOQ GPTQ frontend and example (#1107)
changwangss committed Jan 19, 2024
1 parent 63d9543 commit f4c58d0
Showing 11 changed files with 590 additions and 166 deletions.
23 changes: 23 additions & 0 deletions examples/.config/pytorch_optimize.json
@@ -1384,6 +1384,29 @@
}
}
},
"llama2_7b_gen_woq_gptq": {
"working_dir": "huggingface/pytorch/text-generation/quantization",
"tune":{
"cmd": "bash run_tuning.sh",
"params": {
"topology": "llama2_7b_int4_gptq",
"task": "generation",
"output_model": "saved_results"
}
},
"benchmark": {
"cmd": "bash run_benchmark.sh",
"params": {
"topology": "llama2_7b_int4_gptq",
"task": "generation",
"mode": "benchmark",
"batch_size": "112",
"iters": "100",
"int8": "false",
"config": "saved_results"
}
}
},
"gpt_j_6b_gen_woq_bab": {
"working_dir": "huggingface/pytorch/text-generation/quantization",
"tune":{
@@ -103,6 +103,8 @@ function run_benchmark {
model_name_or_path="bigscience/bloomz-3b"
elif [ "${topology}" = "llama_7b" ]; then
model_name_or_path="meta-llama/Llama-2-7b-chat-hf"
elif [ "${topology}" = "llama2_7b_int4_gptq" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
elif [ "${topology}" = "llama_13b" ]; then
model_name_or_path="meta-llama/Llama-2-13b-chat-hf"
elif [ "${topology}" = "dolly_v2_3b" ]; then
@@ -165,6 +167,11 @@ function run_benchmark {
extra_cmd=$extra_cmd" --load_in_8bit True"
elif [ "${topology}" = "gpt_j_mp" ]; then
extra_cmd=$extra_cmd" --mixed_precision"
elif [ "${topology}" = "llama2_7b_int4_gptq" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --woq --woq_weight_dtype int4_clip --woq_compute_dtype fp32"
extra_cmd=$extra_cmd" --woq_algo "GPTQ" --gptq_actorder --gptq_block_size 128 --gptq_pad_max_length 2048 --gptq_use_max_length"
pip install transformers==4.35.2
else
extra_cmd=$extra_cmd" --int8"
fi
@@ -95,7 +95,7 @@
parser.add_argument(
"--woq_algo",
default="RTN",
choices=["RTN", "AWQ", "TEQ"],
choices=["RTN", "AWQ", "TEQ", "GPTQ"],
help="Weight-only parameter.",
)
parser.add_argument(
@@ -127,6 +127,37 @@
)
parser.add_argument("--woq_group_size", type=int, default=32)
parser.add_argument("--woq_scheme", default="sym")
parser.add_argument(
"--gptq_actorder",
action="store_true",
help="Whether to apply the activation order GPTQ heuristic.",
)
parser.add_argument(
"--gptq_percdamp",
type=float,
default=0.01,
help="Percent of the average Hessian diagonal to use for dampening.",
)
parser.add_argument(
"--gptq_block_size",
type=int,
default=128,
help="Block size. sub weight matrix size to run GPTQ.",
)
parser.add_argument(
"--gptq_nsamples", type=int, default=128, help="Number of calibration data samples."
)
parser.add_argument(
"--gptq_use_max_length",
action="store_true",
help="Set all sequence length to be same length of args.gptq_pad_max_length",
)
parser.add_argument(
"--gptq_pad_max_length",
type=int,
default=2048,
help="Calibration dataset sequence max length, this should align with your model config",
)
# ============BitsAndBytes configs==============
parser.add_argument("--bitsandbytes", action="store_true")
# ============AutoModel parameters==============
@@ -240,13 +271,33 @@
calib_pad_val=args.calib_pad_val,
)
elif args.woq:
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.woq_compute_dtype,
scale_dtype=args.woq_scale_dtype,
weight_dtype=args.woq_weight_dtype,
scheme=args.woq_scheme,
group_size=args.woq_group_size,
) # default is A32W4G32
if args.woq_algo == "GPTQ":
gptq_recipes = {
"act_order": args.gptq_actorder,
"percdamp": args.gptq_percdamp,
"block_size": args.gptq_block_size,
"nsamples": args.gptq_nsamples,
"use_max_length": args.gptq_use_max_length,
"pad_max_length": args.gptq_pad_max_length,
}
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.woq_compute_dtype,
scale_dtype=args.woq_scale_dtype,
weight_dtype=args.woq_weight_dtype,
scheme=args.woq_scheme,
group_size=args.gptq_block_size,
algorithm=args.woq_algo,
tokenizer=tokenizer,
gptq_recipes=gptq_recipes,
)
else:
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.woq_compute_dtype,
scale_dtype=args.woq_scale_dtype,
weight_dtype=args.woq_weight_dtype,
scheme=args.woq_scheme,
group_size=args.woq_group_size,
) # default is A32W4G32
# bitsandbytes
elif args.bitsandbytes:
# GPU device is needed for `load_in_4bit` and `load_in_8bit`.
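
For reference, a minimal Python sketch of how the GPTQ branch above is exercised end to end. The flag values mirror the llama2_7b_int4_gptq settings in run_tuning.sh / run_benchmark.sh; the from_pretrained(..., quantization_config=...) call follows the usual intel_extension_for_transformers pattern and is assumed here rather than shown in this hunk.

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Mirrors the gptq_recipes dict built from the --gptq_* flags above.
gptq_recipes = {
    "act_order": True,       # --gptq_actorder
    "percdamp": 0.01,        # --gptq_percdamp default
    "block_size": 128,       # --gptq_block_size
    "nsamples": 128,         # --gptq_nsamples default
    "use_max_length": True,  # --gptq_use_max_length
    "pad_max_length": 2048,  # --gptq_pad_max_length
}
woq_config = WeightOnlyQuantConfig(
    compute_dtype="fp32",      # --woq_compute_dtype fp32
    weight_dtype="int4_clip",  # --woq_weight_dtype int4_clip
    scheme="sym",
    group_size=128,            # the script reuses gptq_block_size as group_size
    algorithm="GPTQ",
    tokenizer=tokenizer,
    gptq_recipes=gptq_recipes,
)
# Assumed ITREX usage: quantization happens inside from_pretrained.
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config)
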
@@ -327,7 +378,9 @@
trust_remote_code=args.trust_remote_code,
)


if args.benchmark:
user_model.eval()
prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun."

input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
@@ -339,7 +392,6 @@
num_warmup = args.num_warmup
total_token_num = 0
eos_token_id = tokenizer.eos_token_id

with torch.inference_mode(), torch.no_grad():
for i in range(num_iter):
tic = time.time()
@@ -383,6 +435,7 @@
print("Throughput: {} samples/sec".format(throughput))

if args.accuracy:

args.model = (
peft_config.base_model_name_or_path if args.peft_model_id else args.model
)
@@ -182,7 +182,7 @@ function run_tuning {
extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
extra_cmd=$extra_cmd" --trust_remote_code True"
extra_cmd=$extra_cmd" --_commit_hash f7bc352f27bb1c02ee371a4576942a7d96c8bb97"
pip install transformers==4.35.2
elif [ "${topology}" = "mistral_7b" ]; then
alpha=0.8
model_name_or_path="Intel/neural-chat-7b-v3"
@@ -195,14 +195,21 @@
extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
extra_cmd=$extra_cmd" --trust_remote_code True"
pip install transformers==4.36.1
elif [ "${topology}" = "phi_1_5b" ]; then
alpha=0.5
model_name_or_path="susnato/phi-1_5_dev"
extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
extra_cmd=$extra_cmd" --trust_remote_code True"
pip install transformers==4.36.1
elif [ "${topology}" = "llama2_7b_int4_gptq" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --woq --woq_weight_dtype int4_clip --woq_compute_dtype fp32"
extra_cmd=$extra_cmd" --woq_algo "GPTQ" --gptq_actorder --gptq_block_size 128 --gptq_pad_max_length 2048 --gptq_use_max_length"
extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
extra_cmd=$extra_cmd" --trust_remote_code True"
pip install transformers==4.35.2
fi

if [ ${script} = "run_generation.py" ];then
@@ -20,7 +20,7 @@

def convert_idx(g_idx, k, blocksize):
ret_idx = torch.zeros(k, dtype=int)
g_counter = torch.zeros(blocksize, dtype=int)
g_counter = torch.zeros((k+blocksize-1) // blocksize, dtype=int)
for i in range(k):
ret_idx[g_idx[i]*blocksize+g_counter[g_idx[i]]] = i
g_counter[g_idx[i]] += 1
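
A quick self-contained check of the fix above: g_counter needs one slot per group, i.e. ceil(k / blocksize), not blocksize slots; sizing it by blocksize only happens to work while the number of groups stays below blocksize. The trailing return is assumed, since the hunk is cut off here.

import torch

def convert_idx(g_idx, k, blocksize):
    # ret_idx[slot] = original column index placed in that slot after regrouping;
    # group g owns the contiguous slots [g*blocksize, (g+1)*blocksize).
    ret_idx = torch.zeros(k, dtype=int)
    g_counter = torch.zeros((k + blocksize - 1) // blocksize, dtype=int)  # one counter per group
    for i in range(k):
        ret_idx[g_idx[i] * blocksize + g_counter[g_idx[i]]] = i
        g_counter[g_idx[i]] += 1
    return ret_idx  # assumed; the diff cuts off before the return statement

g_idx = torch.tensor([0, 1, 0, 1, 1, 0, 1, 0])  # act-order style group index per column
print(convert_idx(g_idx, k=8, blocksize=4))
# tensor([0, 2, 5, 7, 1, 3, 4, 6]) -- group 0's columns first, then group 1's
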
@@ -22,12 +22,23 @@
from torch import Tensor
from typing import Tuple, Optional, List


def prod(iterable):
return reduce(operator.mul, iterable, 1)


class MatMulKBit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, bias=None, compute_dtype=None, weight_dtype=None, scale_dtype=None):
def forward(
ctx,
A,
B,
out=None,
bias=None,
compute_dtype=None,
weight_dtype=None,
scale_dtype=None,
):
# # 1. Dequantize
# B_dequant = torch.zeros(out.shape[-1], A.shape[-1], dtype=torch.float)
# torch.ops.bestlaop.woq_dequantize(
@@ -39,28 +50,49 @@ def forward(ctx, A, B, out=None, bias=None, compute_dtype=None, weight_dtype=Non
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B # B_dequant
ctx.bias = bias
B_shape = (out.shape[-1], A.shape[-1]) # B_dequant.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
return torch.empty(
A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device
)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
return torch.empty(
A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device
)

# 2. Matmul
# output = torch.nn.functional.linear(A, B_dequant, bias)
torch.ops.bestlaop.woq_linear(
A, B.data, bias, out, out.shape[-1], bias is not None, compute_dtype, weight_dtype, scale_dtype,
False)
A,
B.data,
bias,
out,
out.shape[-1],
bias is not None,
compute_dtype,
weight_dtype,
scale_dtype,
False,
)
output = out

# 3. Save state
ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype = compute_dtype, weight_dtype, scale_dtype
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype = (
compute_dtype,
weight_dtype,
scale_dtype,
)
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = (
A.dtype,
B.dtype,
None if bias is None else bias.dtype,
)
# B_dequant.dtype

if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, B) # B_dequant
else:
ctx.tensors = (None, None)

@@ -70,41 +102,64 @@ def forward(ctx, A, B, out=None, bias=None, compute_dtype=None, weight_dtype=Non
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
return (
torch.zeros_like(ctx.A),
torch.zeros_like(ctx.B),
None,
bias_grad,
None,
)

req_gradA, _, _, req_gradBias, _, _, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B, grad_bias = None, None, None

B_dequant = torch.zeros(grad_output.shape[-1], A.shape[-1], dtype=torch.float)

torch.ops.bestlaop.woq_dequantize(
B, B_dequant, True, ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype)
B, B_dequant, True, ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype
)

B = B_dequant

if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA: grad_A = torch.matmul(grad_output, B.to(grad_output.dtype))
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, B.to(grad_output.dtype))

return grad_A, grad_B, None, grad_bias, None, None, None


def matmul_kbit(A: Tensor,
B: Tensor,
bias,
out,
compute_dtype,
weight_dtype,
scale_dtype,
do_dequant=False):
def matmul_kbit(
A: Tensor,
B: Tensor,
bias,
out,
compute_dtype,
weight_dtype,
scale_dtype,
do_dequant=False,
):
if do_dequant:
return MatMulKBit.apply(A, B, out, bias, compute_dtype, weight_dtype,
scale_dtype)
return MatMulKBit.apply(
A, B, out, bias, compute_dtype, weight_dtype, scale_dtype
)
else:
torch.ops.bestlaop.woq_linear(A, B.data, bias, out, out.shape[-1], bias
is not None, compute_dtype, weight_dtype,
scale_dtype, False)
torch.ops.bestlaop.woq_linear(
A,
B.data,
bias,
out,
out.shape[-1],
bias is not None,
compute_dtype,
weight_dtype,
scale_dtype,
False,
)

return out
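
As a reading aid, a plain-torch sketch of the float math the fused woq_linear call performs in the do_dequant=False path (the commented-out reference in forward() above computes the same thing); W_dequant is an illustrative stand-in for the weight that the bestla op keeps in packed low-bit form.

import torch

def matmul_kbit_reference(A, W_dequant, bias=None):
    # Equivalent dense math: out = A @ W_dequant.T (+ bias).
    return torch.nn.functional.linear(A, W_dequant, bias)

A = torch.randn(2, 16)          # activations, shape [batch, in_features]
W_dequant = torch.randn(8, 16)  # stand-in for the dequantized int4 weight, [out_features, in_features]
out = matmul_kbit_reference(A, W_dequant, bias=torch.zeros(8))
print(out.shape)                # torch.Size([2, 8]), matching out.shape[-1] in the fused call
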
