[Operators] preliminary symmetric weight quantization #298

Merged · 10 commits · Jul 4, 2023
Changes from 5 commits
21 changes: 21 additions & 0 deletions python/hidet/graph/nn/linear.py
@@ -9,6 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from hidet.graph import ops
from hidet.graph.nn.module import Module
from hidet.graph.tensor import Tensor, empty
@@ -61,3 +62,23 @@ def forward(self, x: Tensor) -> Tensor:
        if self.bias is not None:
            x = ops.add(x, self.bias)
        return x


+class SymQuantLinearTransposed(Module):
+    def __init__(self, weight: Tensor, bias: Optional[Tensor] = None, quant_type: str = 'int8'):
+        super().__init__()
+        self.in_features = weight.shape[0]
+        self.out_features = weight.shape[1]
+        qweight, scale = ops.symmetric_quantize(weight, quant_type=quant_type, dims=[-1])
+        self.qweight = qweight
+        self.scale = scale
+        self.bias = bias
+
+    def extra_str(self) -> str:
+        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = ops.matmul(x, ops.symmetric_dequantize(ops.barrier(self.qweight), self.scale, dims=[-1]))
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        return x
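A usage sketch of the new module (not part of this diff; the shapes and the hidet.randn / hidet.symbol helpers are illustrative):

    import hidet
    from hidet.graph.nn.linear import SymQuantLinearTransposed

    # the weight is stored transposed: [in_features, out_features]
    weight = hidet.randn([512, 1024], device='cuda')
    layer = SymQuantLinearTransposed(weight, quant_type='int8')

    x = hidet.symbol([8, 512], device='cuda')  # symbolic input for tracing
    y = layer(x)                               # matmul against the dequantized int8 weight

The weight is quantized once at construction time, so only the int8 qweight and the per-column scale are kept on the module.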
54 changes: 30 additions & 24 deletions python/hidet/graph/operator.py
@@ -42,7 +42,7 @@ def __init__(self, inputs: List[Tensor], attributes: Dict[str, Any], task: Optio
        # cache
        self._compiled_task: Optional[CompiledTask] = None

-        self.outputs = self._run()
+        self.outputs = self.symbolic_run()

    def __str__(self):
        arguments = ['{}: {}{}'.format(i, t.dtype.name, t.shape) for i, t in enumerate(self.inputs)]
@@ -93,32 +93,38 @@ def compiled_task(self) -> CompiledTask:
            self._compiled_task = self.task.build(target=self.build_target)
        return self._compiled_task

-    def _run(self) -> List[Tensor]:
-        from hidet.ir.tools import rewrite, simplify, collect
-
-        if all(t.storage is not None for t in self.inputs) and len(collect(self.task, SymbolVar)) == 0:
-            return self.imperative_run(self.inputs)
-        else:
-            output_types: List[TensorType] = [output_node.type for output_node in self.task.outputs]
-            outputs: List[Tensor] = []
-            remap: Dict[Var, Constant] = {}
-            for i, (a, b) in enumerate(zip(self.task.inputs, self.inputs)):
-                for d1, d2 in zip(a.type.shape, b.shape):
-                    if isinstance(d1, Var) and not (d1 in remap and isinstance(remap[d1], Constant)):
-                        remap[d1] = d2
-            for i, output_type in enumerate(output_types):
-                shape = [simplify(rewrite(d, remap)) for d in output_type.shape]
-                outputs.append(
-                    Tensor(shape=shape, dtype=output_type.dtype.name, device=self.device, storage=None, trace=(self, i))
-                )
-            return outputs
+    def symbolic_run(self) -> List[Tensor]:
+        from hidet.ir.tools import rewrite, simplify
+
+        output_types: List[TensorType] = [output_node.type for output_node in self.task.outputs]
+        outputs: List[Tensor] = []
+        remap: Dict[Var, Constant] = {}
+        for i, (a, b) in enumerate(zip(self.task.inputs, self.inputs)):
+            for d1, d2 in zip(a.type.shape, b.shape):
+                if isinstance(d1, Var) and not (d1 in remap and isinstance(remap[d1], Constant)):
+                    remap[d1] = d2
+        for i, output_type in enumerate(output_types):
+            shape = [simplify(rewrite(d, remap)) for d in output_type.shape]
+            outputs.append(
+                Tensor(shape=shape, dtype=output_type.dtype.name, device=self.device, storage=None, trace=(self, i))
+            )
+        return outputs

    def get_output(self, idx: int) -> Tensor:
-        if self.outputs is None:
-            outputs = self._run()
+        from hidet.ir.tools import collect
+
+        could_imperative_run = (
+            all(t.storage is not None for t in self.inputs) and len(collect(self.task, SymbolVar)) == 0
+        )
+
+        if could_imperative_run:
+            if self.outputs is None or any(t.storage is None for t in self.outputs):
+                self.outputs = self.imperative_run(self.inputs)
        else:
-            outputs = self.outputs
-        return outputs[idx]
+            if self.outputs is None:
+                self.outputs = self.symbolic_run()
+
+        return self.outputs[idx]
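For context, a sketch of how the reworked laziness behaves from the user's side (a minimal illustration, assuming the usual hidet front-end helpers):

    import hidet
    from hidet.graph import ops

    # concrete inputs: get_output() compiles and runs the kernel eagerly
    a = hidet.randn([4, 4])
    b = ops.relu(a)    # b.storage is materialized (imperative run)

    # symbolic inputs: get_output() only builds trace tensors; nothing is
    # executed until the traced graph itself is run
    x = hidet.symbol([4, 4])
    y = ops.relu(x)    # y.storage is None, y.trace records the operator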

    def _imperative_run_prepare_outputs(self) -> List[Tensor]:
        from hidet.ir.tools import simplify, collect, rewrite
1 change: 1 addition & 0 deletions python/hidet/graph/ops/__init__.py
@@ -44,6 +44,7 @@
from .complex import real, imag, conj, make_complex
from .compare import equal, not_equal, less, greater, less_equal, greater_equal
from .compare import logical_not, logical_and, logical_or, logical_xor
+from .quant import symmetric_quantize, symmetric_dequantize
from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
from .cumulative import cumsum
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/activation.py
@@ -217,7 +217,7 @@ def relu6(x: Tensor) -> Tensor:


def gelu(x: Tensor, approximate: bool = False) -> Tensor:
-    return GeluOp(x, approximate).outputs[0]
+    return GeluOp(x, approximate).get_output(0)


def silu(x: Tensor) -> Tensor:
8 changes: 4 additions & 4 deletions python/hidet/graph/ops/complex.py
@@ -42,16 +42,16 @@ def __init__(self, real: Tensor, imag: Tensor):


def real(x: Tensor) -> Tensor:
-    return RealOperator(x).outputs[0]
+    return RealOperator(x).get_output(0)


def imag(x: Tensor) -> Tensor:
-    return ImagOperator(x).outputs[0]
+    return ImagOperator(x).get_output(0)


def conj(x: Tensor) -> Tensor:
-    return ConjOperator(x).outputs[0]
+    return ConjOperator(x).get_output(0)


def make_complex(real: Tensor, imag: Tensor) -> Tensor:
-    return MakeComplexOperator(real, imag).outputs[0]
+    return MakeComplexOperator(real, imag).get_output(0)
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/create.py
@@ -205,4 +205,4 @@ def tri(
) -> Tensor:
    if m is None:
        m = n
-    return TriOp(n, m, k, dtype, device).outputs[0]
+    return TriOp(n, m, k, dtype, device).get_output(0)
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/distributed.py
@@ -60,7 +60,7 @@ def __init__(self, x: Tensor, op: str, comm_id: int):
def all_reduce(x: Tensor, op: str, comm_id: int = 0) -> Tensor:
    if x.device.kind != 'cuda':
        raise RuntimeError("NCCL only supports CUDA tensors")
-    return AllReduceOp(x, op, comm_id).outputs[0]
+    return AllReduceOp(x, op, comm_id).get_output(0)


def broadcast(x: Tensor, root: int, comm_id: int = 0) -> Tensor:
1 change: 1 addition & 0 deletions python/hidet/graph/ops/quant/__init__.py
@@ -0,0 +1 @@
from .symmetric import symmetric_quantize, symmetric_dequantize
77 changes: 77 additions & 0 deletions python/hidet/graph/ops/quant/symmetric.py
@@ -0,0 +1,77 @@
from typing import Union
from hidet import ir
from hidet.ir.expr import cast, if_then_else
from hidet.ir.compute.primitives import TensorNode, compute
from hidet.ir import primitives as prim
from hidet.ir.compute import cops

from hidet.graph.ops.utils import Task, Operator, Tensor, input_like


# pylint: disable=dangerous-default-value
class SymmetricQuantizationTask(Task):
    def __init__(self, w: TensorNode, quant_type: ir.type.DataType, dims=[-1]):
        # normalize negative dims, then check that each dim is in bounds
        dims = [i if i >= 0 else len(w.shape) + i for i in dims]
        self._assert(all(0 <= i < len(w.shape) for i in dims), "dims are out of bounds")

        wm = compute(
            name='abs', shape=w.shape, fcompute=lambda *indices: if_then_else(w[indices] >= 0, w[indices], -w[indices])
        )
        scale = cops.reduce_cop(wm, dims, keep_dim=False, reduce_type='max')
        scale = compute(
            name='scaling', shape=scale.shape, fcompute=lambda *indices: quant_type.max_value / scale[indices]
        )

        def scale_weight(*indices):
            scale_indices = [indices[i] for i in range(len(indices)) if i not in dims]
            return cast(prim.round(w[indices] * scale[scale_indices]), quant_type)

        wq = compute(name='quantize', shape=w.shape, fcompute=scale_weight)
        super().__init__(
            name='symmetric_quantization_task',
            inputs=[w],
            outputs=[wq, scale],
            attributes={'dims': dims, 'quant_type': quant_type},
        )
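The stored scale is the multiplier quant_type.max_value / max|w| over each reduction group, so quantization is round(w * scale) and dequantization is wq / scale. A NumPy sketch of the same arithmetic (illustrative only; a single reduction axis is assumed):

    import numpy as np

    def sym_quantize(w: np.ndarray, axis: int = -1, qmax: int = 127):
        # per-group multiplier: qmax / max(|w|) along the reduced axis
        scale = qmax / np.abs(w).max(axis=axis, keepdims=True)
        wq = np.round(w * scale).astype(np.int8)
        return wq, scale.squeeze(axis)

    def sym_dequantize(wq: np.ndarray, scale: np.ndarray, axis: int = -1):
        return wq.astype(np.float32) / np.expand_dims(scale, axis)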


class SymmetricDeQuantizationTask(Task):
    def __init__(self, wq: TensorNode, scale: TensorNode, dims=[-1]):
        dims = [i if i >= 0 else len(wq.shape) + i for i in dims]
        self._assert(all(0 <= i < len(wq.shape) for i in dims), "dims are out of bounds")

        def unscale_weight(*indices):
            scale_indices = [indices[i] for i in range(len(indices)) if i not in dims]
            return cast(wq[indices], scale.type.dtype) / scale[scale_indices]

        w = compute(name='dequantize', shape=wq.shape, fcompute=unscale_weight)
        super().__init__(
            name='symmetric_dequantization_task', inputs=[wq, scale], outputs=[w], attributes={'dims': dims}
        )


class SymmetricQuantizationOp(Operator):
    def __init__(self, w: Tensor, quant_type: ir.type.DataType, dims=[-1]):
        super().__init__(
            inputs=[w],
            attributes={'dims': dims, 'quant_type': quant_type},
            task=SymmetricQuantizationTask(input_like(w, 'w'), quant_type=quant_type, dims=dims),
        )


class SymmetricDeQuantizationOp(Operator):
    def __init__(self, wq: Tensor, scale: Tensor, dims=[-1]):
        super().__init__(
            inputs=[wq, scale],
            attributes={'dims': dims},
            task=SymmetricDeQuantizationTask(input_like(wq, 'wq'), input_like(scale, 'scale'), dims=dims),
        )


def symmetric_quantize(w: Tensor, quant_type: Union[str, ir.type.DataType] = 'int8', dims=[-1]):
    op = SymmetricQuantizationOp(w, ir.type.data_type(quant_type), dims=dims)
    return op.get_output(0), op.get_output(1)


def symmetric_dequantize(wq: Tensor, scale: Tensor, dims=[-1]):
    return SymmetricDeQuantizationOp(wq, scale, dims).get_output(0)
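A quick round-trip check of these entry points (not part of the diff; shapes are illustrative):

    import hidet
    from hidet.graph import ops

    w = hidet.randn([128, 256])
    wq, scale = ops.symmetric_quantize(w, quant_type='int8', dims=[-1])
    w_hat = ops.symmetric_dequantize(wq, scale, dims=[-1])

    # round-off error is bounded by 0.5 / scale, i.e. about max|w| / 254 per group
    err = abs((w - w_hat).numpy()).max()
    print(err)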
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/reduce/__init__.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .reduce import reduce, ReduceBaseOp, ReduceTask
+from .reduce import ReduceBaseOp, ReduceTask
from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
from .reduce import ReduceSumOp, ReduceMeanOp
from .reduce_f16 import reduce_f16
35 changes: 3 additions & 32 deletions python/hidet/graph/ops/reduce/reduce.py
@@ -11,47 +11,18 @@
# limitations under the License.
from typing import List, Union, Optional, Sequence

+from hidet.ir.compute import cops
from ..arithmetic import square, sqrt
from ..utils import Task, Operator, Tensor, TensorNode, IRModule, ReduceType
-from ..utils import compute, reduce, input_like, normalize_dim, arg_reduce
+from ..utils import compute, input_like, normalize_dim, arg_reduce


class ReduceTask(Task):
    def __init__(
        self, x: TensorNode, dims: List[int], keep_dim: bool, reduce_type: str, accumulate_dtype: str = 'float32'
    ):
-        y_shape = []
-        for i in range(len(x.shape)):
-            if i in dims:
-                if keep_dim:
-                    y_shape.append(1)
-            else:
-                y_shape.append(x.shape[i])
-
-        def fcompute(*indices):
-            def reduce_fcompute(*reduce_indices):
-                x_indices = []
-                p = 0
-                q = 0
-                for i in range(len(x.shape)):
-                    if i not in dims:
-                        x_indices.append(indices[p])
-                        p += 1
-                    else:
-                        x_indices.append(reduce_indices[q])
-                        q += 1
-                        if keep_dim:
-                            p += 1
-                assert p == len(indices) and q == len(reduce_indices)
-                return x[x_indices]
-
-            reduce_shape = [x.shape[i] for i in dims]
-            return reduce(
-                shape=reduce_shape, fcompute=reduce_fcompute, reduce_type=reduce_type, accumulate_dtype=accumulate_dtype
-            )
-
-        y = compute(name='y', shape=y_shape, fcompute=fcompute)
-
+        y = cops.reduce_cop(x, dims, keep_dim, reduce_type, accumulate_dtype)
        self.dims: List[int] = dims
        self.keep_dim: bool = keep_dim
        self.reduce_type: str = reduce_type
7 changes: 6 additions & 1 deletion python/hidet/graph/ops/special.py
@@ -9,9 +9,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.task import Task
from hidet.graph.tensor import Tensor
from .utils import Task, Operator, Tensor, TensorNode, compute, input_like


# todo: add GraphInput and GraphOutput special operators here.


@@ -25,6 +26,10 @@ class BarrierOp(Operator):
    def __init__(self, x: Tensor):
        super().__init__(inputs=[x], attributes={}, task=BarrierTask(input_like(x, 'x')))

+    def get_output(self, idx: int) -> Tensor:
+        self.outputs = super().symbolic_run()
+        return self.outputs[idx]
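The get_output override makes BarrierOp always produce symbolic outputs, even when its input is a constant tensor, so the op is never executed imperatively. In this PR that appears to be what keeps the int8 qweight from being eagerly dequantized and folded back into a full-precision constant. A sketch of the pattern (mirroring SymQuantLinearTransposed.forward above; qweight and scale as produced by ops.symmetric_quantize):

    from hidet.graph import ops

    # barrier() keeps the dequantize in the traced graph instead of letting
    # the constant weight be dequantized at trace time
    w = ops.symmetric_dequantize(ops.barrier(qweight), scale, dims=[-1])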


def barrier(x: Tensor) -> Tensor:
    """
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/transform.py
@@ -633,7 +633,7 @@ def take(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor:


def gather(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor:
-    return GatherOp(data, indices, axis).outputs[0]
+    return GatherOp(data, indices, axis).get_output(0)


def strided_slice(
12 changes: 12 additions & 0 deletions python/hidet/ir/compute/cops/__init__.py
@@ -1,2 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from .matmul import matmul
from .pad import pad
+from .reduce import reduce_cop
11 changes: 11 additions & 0 deletions python/hidet/ir/compute/cops/matmul.py
@@ -1,3 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import Sequence
from hidet.ir.expr import Expr, Var, is_constant
from hidet.ir.compute.primitives import TensorNode, compute, reduce
11 changes: 11 additions & 0 deletions python/hidet/ir/compute/cops/pad.py
@@ -1,3 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import List
from hidet.ir.expr import if_then_else, logical_and, convert
from hidet.ir.compute.primitives import TensorNode, compute