Skip to content

Commit

Permalink
cast quant linear output to model dtype
Browse files (browse the repository at this point in the history)
  • Loading branch information
chu-tianxiang committed Aug 26, 2023
1 parent 3f501ac commit b633191
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions vllm/model_executor/quantize.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
outfeatures // world_size, bias, **kwargs)

def forward(self, input_):
- output_parallel = super().forward(input_)
+ output_parallel = super().forward(input_).to(input_.dtype)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
Expand Down Expand Up @@ -100,7 +100,7 @@ def forward(self, input_):
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
- output_parallel = super().forward(input_parallel)
+ output_parallel = super().forward(input_parallel).to(input_.dtype)
if self.reduce_results and self.world_size > 1:
output = reduce_from_tensor_model_parallel_region(output_parallel)
else:
Expand All @@ -126,7 +126,7 @@ def forward(self, input_):
# All-gather across the partitions.
if self.input_is_parallel:
input_ = gather_from_tensor_model_parallel_region(input_)
- output = super().forward(input_)
+ output = super().forward(input_).to(input_.dtype)
return output, None

if isinstance(module, QuantLinear):
Expand Down

0 comments on commit b633191

Please sign in to comment.