Support autoTP with weight-only quantization in DS inference path #4750

Open · wants to merge 4 commits into base: master
105 changes: 105 additions & 0 deletions deepspeed/inference/quantization/layers.py
@@ -53,7 +53,67 @@ def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
device=pre_quant_layer.weight.device,
dtype=pre_quant_layer.weight.dtype)
self.config = config
self.quantizer = Quantizer(config=config)
self.bias = pre_quant_layer.bias
self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))

self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)

def forward(self, input: Tensor) -> Tensor:
quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
quant_min)

# !!! Do not use torch.nn.functional.linear(input, temp_dequantized_weight, self.bias) here: under ZeRO-3,
# torch.nn.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes the weight is not a
# temporary buffer. Passing a temporary dequantized weight to it would leak memory.
return torch._C._nn.linear(input, temp_dequantized_weight, self.bias)


class QuantizedLinearAllreduce(nn.Linear):

def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
super(QuantizedLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
out_features=pre_quant_layer.weight.shape[0],
bias=pre_quant_layer.bias is not None,
device=pre_quant_layer.weight.device,
dtype=pre_quant_layer.weight.dtype)
self.config = config
self.mp_group = pre_quant_layer.mp_group if hasattr(pre_quant_layer, 'mp_group') else None
self.quantizer = Quantizer(config=config, mp_group=self.mp_group)
self.bias = pre_quant_layer.bias
self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))

self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)

def forward(self, input: Tensor) -> Tensor:
quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
quant_min)

# !!! Do not use torch.nn.functional.linear(input, temp_dequantized_weight, self.bias) here: under ZeRO-3,
# torch.nn.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes the weight is not a
# temporary buffer. Passing a temporary dequantized weight to it would leak memory.
output = torch._C._nn.linear(input, temp_dequantized_weight)
if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output


class QuantizedLinearLayer(nn.Linear):

def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
super(QuantizedLinearLayer, self).__init__(in_features=pre_quant_layer.weight.shape[1],
out_features=pre_quant_layer.weight.shape[0],
bias=pre_quant_layer.bias is not None,
device=pre_quant_layer.weight.device,
dtype=pre_quant_layer.weight.dtype)
self.config = config
self.quantizer = Quantizer(config=config)
self.bias = pre_quant_layer.bias
self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
@@ -72,6 +132,46 @@ def forward(self, input: Tensor) -> Tensor:
return torch._C._nn.linear(input, temp_dequantized_weight, self.bias)


class QuantizedLmHeadLinearAllreduce(nn.Linear):

def __init__(self, config: Dict, pre_quant_layer: nn.Linear) -> None:
super(QuantizedLmHeadLinearAllreduce, self).__init__(in_features=pre_quant_layer.weight.shape[1],
out_features=pre_quant_layer.weight.shape[0],
bias=pre_quant_layer.bias is not None,
device=pre_quant_layer.weight.device,
dtype=pre_quant_layer.weight.dtype)
self.config = config
self.quantizer = Quantizer(config=config)
self.bias = pre_quant_layer.bias
self.rank = pre_quant_layer.rank
self.world_size = pre_quant_layer.world_size
# forward() needs mp_group for the all-reduce; take it from the pre-quantized layer when present.
self.mp_group = pre_quant_layer.mp_group if hasattr(pre_quant_layer, 'mp_group') else None
self.weight = get_quantized_weight_wrapper(self, pre_quant_layer.weight,
get_quantize_weight_fn(self.quantizer, pre_quant_layer.weight))

self.weight.dequantizer = DeQuantizer(config, pre_quant_layer.weight.dtype)

def forward(self, input: Tensor) -> Tensor:
quantized_weight, quant_scale, quant_min = self.weight.deconcat(self.weight)
temp_dequantized_weight = self.weight.dequantizer.dequantize(quantized_weight.view(torch.uint8), quant_scale,
quant_min)
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
input_shard_size = get_shard_size(input.shape[-1], self.world_size)
input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])

# !!! Do not use torch.nn.functional.linear(input, temp_dequantized_weight, self.bias) here: under ZeRO-3,
# torch.nn.functional.linear is replaced by LinearFunctionForZeroStage3, which assumes the weight is not a
# temporary buffer. Passing a temporary dequantized weight to it would leak memory.
output = torch._C._nn.linear(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
temp_dequantized_weight.transpose(-1, -2))

if self.mp_group is not None:
from deepspeed import comm as dist
dist.inference_all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output


class QuantizedEmbedding(nn.Embedding):

def __init__(self, config: Dict, pre_quant_layer: nn.Embedding) -> None:
@@ -108,7 +208,12 @@ def forward(self, input: Tensor) -> Tensor:
self.scale_grad_by_freq, self.sparse)


from ...module_inject import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce

QUANTIZATION_LAYER_MAPPINGS = {
nn.Linear: QuantizedLinear,
nn.Embedding: QuantizedEmbedding,
LinearAllreduce: QuantizedLinearAllreduce,
LinearLayer: QuantizedLinearLayer,
LmHeadLinearAllreduce: QuantizedLmHeadLinearAllreduce
}
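
For reference, the math behind `QuantizedLinearAllreduce` and `QuantizedLmHeadLinearAllreduce` can be checked on a single process: each rank multiplies its shard of the (dequantized) weight with the matching shard of the activation, the partial products are summed (the role of `dist.inference_all_reduce`), and the bias is added exactly once after the reduction. The sketch below is a standalone simulation, not the DeepSpeed implementation; the tensor names and the even split are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
world_size = 2
hidden, vocab = 8, 16

x = torch.randn(1, 3, hidden)        # [batch, seq, hidden]
weight = torch.randn(vocab, hidden)  # lm_head-style weight: [out_features, in_features]
bias = torch.randn(vocab)

# Reference: unsharded linear.
ref = torch.nn.functional.linear(x, weight, bias)

# Simulated tensor parallelism: shard the in_features dimension of the weight
# and the matching slice of the input, sum the partial GEMMs (the all-reduce),
# then add the bias once.
shard = hidden // world_size         # assumes hidden divides evenly
partial_sum = torch.zeros_like(ref)
for rank in range(world_size):
    off = rank * shard
    w_rank = weight[:, off:off + shard]
    x_rank = x[:, :, off:off + shard]
    partial_sum += torch.nn.functional.linear(x_rank, w_rank)  # no bias per rank

out = partial_sum + bias
print(torch.allclose(ref, out, atol=1e-5))  # True
```

Adding the bias on every rank before the reduction would count it `world_size` times, which is why the new forward methods defer it until after the all-reduce.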
9 changes: 7 additions & 2 deletions deepspeed/inference/quantization/quantization.py
@@ -71,15 +71,20 @@ def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> n
assert matched_key is None, f'{module_name} matched multiple quantization key word {matched_key} and {key}'
matched_key = key
matched_quantization_config = config
elif key == '*':
# The reserved key '*' matches all quantizable layers, which simplifies the ds_config
assert matched_key is None, f'{module_name} matched multiple quantization key word {matched_key} and the wildcard "*"'
matched_key = module_name
matched_quantization_config = config
else:
continue

if matched_key is None:
continue

if is_zero3_enabled:
module.weight.all_gather()

assert module.weight.dtype == torch.float16, 'Model weight is expected in half.'

new_module = QUANTIZATION_LAYER_MAPPINGS[type(module)](matched_quantization_config, module)

if is_zero3_enabled:
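
With the new wildcard handling, a single `'*'` entry covers every quantizable layer instead of listing module-name substrings one by one. Below is a minimal sketch of how the config is consumed, assuming a toy fp16 model (the helper asserts fp16 weights); the settings mirror the unit tests added in this PR.

```python
import torch.nn as nn
from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization

# Toy stand-in for the real model; weights must already be fp16.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 64)).half()

ds_config = {
    'weight_quantization': {
        'post_init_quant': {
            # The reserved '*' key applies these settings to all quantizable layers.
            # Alternatively, use module-name substrings (e.g. 'fc1', 'self_attn')
            # as keys for per-layer settings.
            '*': {
                'num_bits': 4,       # INT4 or INT8 only
                'group_size': 64,
                'group_dim': 0,
                'symmetric': False,  # only asymmetric quantization is supported
            },
        }
    }
}

model = _init_group_wise_weight_quantization(model, ds_config)
print(model)  # nn.Linear modules are replaced by their Quantized* counterparts
```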
3 changes: 2 additions & 1 deletion deepspeed/inference/quantization/utils.py
@@ -42,8 +42,9 @@ def tensor_round(tensor: Tensor) -> Tensor:

class Quantizer:

def __init__(self, config: Dict) -> None:
def __init__(self, config: Dict, mp_group=None) -> None:
self.config = config
self.mp_group = mp_group
assert self.config['num_bits'] == 4 or self.config[
'num_bits'] == 8, 'Only INT4 and INT8 quantization is supported.'
assert self.config['symmetric'] == False, 'Only asymmetric quantization is supported at this moment.'
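
The `Quantizer` now also carries the tensor-parallel group, while the supported scheme is unchanged: group-wise, asymmetric INT4/INT8. The sketch below illustrates the group-wise asymmetric quantize/dequantize math in isolation (per-group scale and minimum); it is not the DeepSpeed kernel and omits INT4 packing and `group_dim` handling.

```python
import torch

def quantize_groupwise_asym(w: torch.Tensor, num_bits: int = 4, group_size: int = 64):
    """Asymmetric group-wise quantization: each group stores uint8 codes,
    a scale, and a minimum value."""
    groups = w.reshape(-1, group_size)
    q_max = 2 ** num_bits - 1
    g_min = groups.min(dim=1, keepdim=True).values
    g_max = groups.max(dim=1, keepdim=True).values
    scale = (g_max - g_min).clamp(min=1e-8) / q_max
    q = torch.clamp(torch.round((groups - g_min) / scale), 0, q_max).to(torch.uint8)
    return q, scale, g_min

def dequantize_groupwise(q, scale, g_min, shape):
    return (q.to(scale.dtype) * scale + g_min).reshape(shape)

w = torch.randn(128, 64)
q, scale, g_min = quantize_groupwise_asym(w, num_bits=4, group_size=64)
w_hat = dequantize_groupwise(q, scale, g_min, w.shape)
print((w - w_hat).abs().max())  # small reconstruction error; larger for 4 bits than 8
```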
2 changes: 1 addition & 1 deletion deepspeed/module_inject/__init__.py
@@ -6,5 +6,5 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, EmbeddingLayer, Normalize
from .policy import DSPolicy
31 changes: 31 additions & 0 deletions tests/unit/inference/quantization/test_intX_quantization.py
@@ -314,6 +314,37 @@ def test_model_quantization(self, quantization_bits):
assert type(model.self_attn.v_proj) is QuantizedLinear
assert type(model.self_attn.out_proj) is QuantizedLinear

def test_wildcard_model_quantization(self, quantization_bits):
reset_random()

config = AutoConfig.from_pretrained('facebook/opt-125m')

with torch.no_grad():
model = OPTDecoderLayer(config).half().to(device)
bits = quantization_bits

ds_config = {
'weight_quantization': {
'post_init_quant': {
'*': {
'num_bits': bits,
'group_size': 64,
'group_dim': 0,
'symmetric': False
},
}
}
}

model = _init_group_wise_weight_quantization(model, ds_config)

assert type(model.fc1) is QuantizedLinear
assert type(model.fc2) is QuantizedLinear
assert type(model.self_attn.q_proj) is QuantizedLinear
assert type(model.self_attn.k_proj) is QuantizedLinear
assert type(model.self_attn.v_proj) is QuantizedLinear
assert type(model.self_attn.out_proj) is QuantizedLinear

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_quantized_linear(self, quantization_bits, group_dim):
reset_random()
52 changes: 52 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -459,6 +459,58 @@ def test(
#assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize("model_w_task", [("tiiuae/falcon-7b", "text-generation")], ids=["falcon"])
class TestAutoTPwithWeightQuant(DistributedTest):
world_size = 2

def test(
self,
model_w_task,
query,
inf_kwargs,
assert_fn,
):
# TODO: enable this test for H100 tests
pytest.skip("Not enough GPU memory for this on V100 runners")
model, task = model_w_task
dtype = torch.bfloat16
local_rank = int(os.getenv("LOCAL_RANK", "0"))

# We have to load these large models on CPU with pipeline because there is
# not enough GPU memory
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
pipe = pipeline(task,
model=model,
tokenizer=tokenizer,
torch_dtype=dtype,
trust_remote_code=True,
device=torch.device("cpu"),
framework="pt")

pipe.model = deepspeed.init_inference(pipe.model, mp_size=self.world_size, replace_with_kernel_inject=False)
ds_config = {
"weight_quantization": {
"post_init_quant": {
'*': {
'num_bits': 4,
'group_size': 32,
'group_dim': 1,
'symmetric': False
},
}
}
}
from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization
pipe.model = _init_group_wise_weight_quantization(pipe.model, ds_config)
pipe.device = torch.device(get_accelerator().device_name(local_rank))
ds_output = pipe(query, **inf_kwargs)

#print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
Contributor:
Hi @ftian1, I have run this test, but the result I got is 'deepspeed [{'generated_text': 'DeepSpeed is the greatest,,,,,,,,,,,,,,,'}]', which is not right. Can you figure out what's wrong with this test? BTW, I can pass all tests in test_intX_quantization.py.

Contributor Author:
@baodii may I know which device you are running on? cuda or cpu?

#assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize(
"model_w_task, injection_policy",