[Operators] preliminary symmetric weight quantization #298

Merged (10 commits) on Jul 4, 2023
2 changes: 1 addition & 1 deletion gallery/how-to-guides/add-new-operator-rule-based.py
@@ -94,7 +94,7 @@ def __init__(self, a: Tensor, b: Tensor):

def batch_matmul(a: Tensor, b: Tensor) -> Tensor:
# .outputs[0] gives the first output tensor of the operator
return BatchMatmulOp(a, b).get_output(0)
return BatchMatmulOp(a, b).outputs[0]


# %%
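Across these files, the change is mechanical: Operator.get_output(i) is replaced by indexing the Operator.outputs list directly. A minimal sketch of reaching the producing operator and its outputs from a traced tensor; this assumes Tensor.trace holds the (operator, output-index) pair, as the Tensor constructor call later in this diff suggests:

```python
import hidet

# symbolic inputs so the matmul is traced instead of executed eagerly
a = hidet.symbol([2, 8, 16], dtype='float32')
b = hidet.symbol([2, 16, 4], dtype='float32')
y = hidet.ops.matmul(a, b)

op, idx = y.trace        # producing operator and its output index
first = op.outputs[idx]  # replaces the removed op.get_output(idx)
print(type(op).__name__, first.shape)
```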
2 changes: 1 addition & 1 deletion gallery/how-to-guides/add-new-operator-template-based.py
@@ -236,7 +236,7 @@ def __init__(self, a: Tensor, b: Tensor):


def batch_matmul_fp16(a: Tensor, b: Tensor) -> Tensor:
return BatchMatmulFp16Op(a, b).get_output(0)
return BatchMatmulFp16Op(a, b).outputs[0]


def demo_usage():
2 changes: 1 addition & 1 deletion python/hidet/graph/flow_graph.py
@@ -251,7 +251,7 @@ def forward(self, inputs: List[Tensor]) -> List[Tensor]:
GraphForwardContext._before_operator(node, node_inputs)
logger.debug('[%4d/%d] run operator %s, %s', idx, len(self.nodes), node.name, node.task)
logger.debug(' inputs: %s', [x.signature() for x in node_inputs])
node_outputs = node.imperative_run(node_inputs)
node_outputs = node.compiled_task.run_async(node_inputs)
logger.debug(' outputs: %s', [x.signature() for x in node_outputs])
GraphForwardContext._after_operator(node, node_inputs, node_outputs)

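From the user side, this FlowGraph.forward change is transparent: a traced graph still executes node by node, now through each node's compiled_task.run_async. A small sketch of tracing and running a graph, assuming tiny cpu tensors so the kernels compile quickly:

```python
import hidet

x = hidet.symbol([8, 8], dtype='float32')
y = hidet.ops.relu(hidet.ops.matmul(x, x))
graph = hidet.trace_from(y, inputs=[x])  # build a FlowGraph from the traced output

out = graph(hidet.randn([8, 8]))         # __call__ dispatches to FlowGraph.forward
print(out.shape)
```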
13 changes: 6 additions & 7 deletions python/hidet/graph/nn/linear.py
@@ -25,18 +25,17 @@ def __init__(self, in_features, out_features, bias: bool = True):
else:
self.bias = None

self._transposed = False

def extra_str(self) -> str:
return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)

def forward(self, x: Tensor) -> Tensor:
# x = ops.matmul(x, ops.transpose(self.weight)) # will duplicate weight memory consumption
# workaround: use ops.matmul with some transformations
# todo: use matmul(..., trans_a=False, trans_b=True) when we have support for transposed matmul

out_shape = list(x.shape[:-1]) + [self.out_features]
x = ops.reshape(x, [-1, self.in_features]).transpose(0, 1)
x = ops.matmul(self.weight, x).transpose(0, 1).reshape(out_shape)
if not self._transposed:
self._transposed = True
self.weight = ops.transpose(self.weight, [1, 0])

x = ops.matmul(x, self.weight)
if self.bias is not None:
x = ops.add(x, self.bias)
return x
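The nn.Linear change drops the reshape/transpose round-trip on every call and avoids keeping a second, transposed copy of the weight: the stored weight is transposed once, lazily, on the first forward and then reused. A standalone sketch of the same pattern; the class name and random initialization below are illustrative, not hidet's module:

```python
import hidet
from hidet import ops


class LazyTransposedLinear:
    def __init__(self, in_features: int, out_features: int):
        # weight starts as [out_features, in_features], matching the original module
        self.weight = hidet.randn([out_features, in_features])
        self._transposed = False

    def __call__(self, x: hidet.Tensor) -> hidet.Tensor:
        if not self._transposed:
            # transpose the cached weight on first use; later calls reuse it,
            # so there is no per-call transpose and no duplicated weight memory
            self.weight = ops.transpose(self.weight, [1, 0])
            self._transposed = True
        return ops.matmul(x, self.weight)


y = LazyTransposedLinear(16, 4)(hidet.randn([8, 16]))
print(y.shape)  # [8, 4]
```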
117 changes: 47 additions & 70 deletions python/hidet/graph/operator.py
@@ -10,12 +10,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Dict, Any
from hidet.ir.type import TensorType, DataType
from hidet.ir.expr import Var, Constant
from hidet.ir.type import TensorType
from hidet.ir.task import Task
from hidet.runtime.compiled_task import CompiledTask
from hidet.graph.tensor import empty, Tensor, SymbolVar
from hidet.ffi.ffi import get_last_error, BackendException
from hidet.graph.tensor import Tensor, SymbolVar
from hidet.runtime.device import Device, instantiate_device


@@ -42,7 +40,7 @@ def __init__(self, inputs: List[Tensor], attributes: Dict[str, Any], task: Optio
# cache
self._compiled_task: Optional[CompiledTask] = None

self.outputs = self._run()
self.outputs = self.run()

def __str__(self):
arguments = ['{}: {}{}'.format(i, t.dtype.name, t.shape) for i, t in enumerate(self.inputs)]
@@ -58,6 +56,17 @@ def device(self) -> Device:
ret: Device
The device of the output tensor of this operator.
"""
# Some notes about self.device:
#
# Each hidet operator has a device property, which is the device of the output tensor of this operator.
# For common operators, the device property is inferred from the device of the input tensors. For these
# operators, the devices of all input tensors and the output tensor must be the same.
# There are two exceptions:
# 1. operators that create a tensor (e.g., hidet.full) have no input tensors;
# 2. for transfer operators (e.g., hidet.ops.transfer), the output device differs from the input's.
# Such operators must explicitly set the 'device' attribute, which is used to determine the device
# of the output tensor.
#
if 'device' in self.attrs:
# this is an operator that creates a tensor (like hidet.full) or a transfer operator
# get the device from the operator attributes
@@ -82,6 +91,16 @@ def build_target(self) -> str:
"""
from hidet.graph.ops.transfer import TransferOp

# Some notes about self.build_target:
#
# Each hidet operator has a build_target property, which is the build target of this operator; it determines
# how the task is scheduled and the compilation target of the scheduled tensor program.
# For common operators, the build_target property is inferred from the device of the output tensor.
# There is one exception:
# 1. for transfer operators (e.g., hidet.ops.transfer), the build target is always 'cuda' because the
# current transfer operators always move data between cpu and cuda. Even if the output tensor is on cpu (in
# this case, the transfer operator copies a tensor from cuda to cpu), the build target is still 'cuda'.

if isinstance(self, TransferOp):
return 'cuda'
else:
@@ -93,73 +112,31 @@ def compiled_task(self) -> CompiledTask:
self._compiled_task = self.task.build(target=self.build_target)
return self._compiled_task

def _run(self) -> List[Tensor]:
from hidet.ir.tools import rewrite, simplify, collect

if all(t.storage is not None for t in self.inputs) and len(collect(self.task, SymbolVar)) == 0:
return self.imperative_run(self.inputs)
else:
output_types: List[TensorType] = [output_node.type for output_node in self.task.outputs]
outputs: List[Tensor] = []
remap: Dict[Var, Constant] = {}
for i, (a, b) in enumerate(zip(self.task.inputs, self.inputs)):
for d1, d2 in zip(a.type.shape, b.shape):
if isinstance(d1, Var) and not (d1 in remap and isinstance(remap[d1], Constant)):
remap[d1] = d2
for i, output_type in enumerate(output_types):
shape = [simplify(rewrite(d, remap)) for d in output_type.shape]
outputs.append(
Tensor(shape=shape, dtype=output_type.dtype.name, device=self.device, storage=None, trace=(self, i))
)
return outputs

def get_output(self, idx: int) -> Tensor:
if self.outputs is None:
outputs = self._run()
else:
outputs = self.outputs
return outputs[idx]

def _imperative_run_prepare_outputs(self) -> List[Tensor]:
from hidet.ir.tools import simplify, collect, rewrite
from hidet.ffi import runtime_api

# get the mapping from size var to the actual size
symbolic_shapes = tuple(tuple(d for d in output.type.shape) for output in self.task.outputs)
used_symbols: List[SymbolVar] = collect(symbolic_shapes, SymbolVar)
remap: Dict[SymbolVar, Constant] = {}
for used_symbol in used_symbols:
try:
dtype: DataType = used_symbol.type.as_data_type()
remap[used_symbol] = dtype(runtime_api.get_symbol_value(used_symbol.name))
except BackendException as e:
raise RuntimeError('Failed to get the symbol value of "{}"'.format(used_symbol)) from e
output_shapes = simplify(rewrite(symbolic_shapes, remap))

# check if all the output shapes are constant
for shape in output_shapes:
for d in shape:
if not isinstance(d, Constant):
raise RuntimeError(
'The output shape "{}" of "{}" can not be reduced to a constant'.format(d, self.name)
)

# create the output tensors
output_dtypes: List[DataType] = [output.type.dtype for output in self.task.outputs]
outputs: List[Tensor] = [
empty(shape=shape, dtype=dtype, device=self.device) for shape, dtype in zip(output_shapes, output_dtypes)
]

return outputs

def imperative_run(self, inputs: List[Tensor]) -> List[Tensor]:
outputs = self.compiled_task.run_async(inputs)
def run(self) -> List[Tensor]:
from hidet.ir.tools import collect

status = get_last_error()
if status is not None:
msg = 'Kernel for operator {} failed. Error:\n{}'.format(self.name, status)
raise BackendException(msg)
# we imperatively run the operator if
# 1. all inputs are concrete tensors (i.e., t.storage is not None)
# 2. there is no symbol variable in the task
could_imperative_run = (
all(t.storage is not None for t in self.inputs) and len(collect(self.task, SymbolVar)) == 0
)

if could_imperative_run:
return self.compiled_task.run_async(self.inputs)
else:
return self.symbolic_run()

def symbolic_run(self) -> List[Tensor]:
from hidet.ir.tools import simplify

output_types: List[TensorType] = [output_node.type for output_node in self.task.outputs]
outputs: List[Tensor] = []
for i, output_type in enumerate(output_types):
shape = [simplify(d) for d in output_type.shape]
outputs.append(
Tensor(shape=shape, dtype=output_type.dtype.name, device=self.device, storage=None, trace=(self, i))
)
return outputs

def reforward(self, inputs: List[Tensor], update_attributes: Optional[Dict[str, Any]] = None) -> List[Tensor]:
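The rewritten Operator.run collapses the old _run / imperative_run / get_output paths into two cases: with concrete inputs the operator executes immediately through compiled_task.run_async, and with symbolic inputs it falls through to symbolic_run, which only derives output shapes and records the trace. A small sketch of how the two paths look from the user side (small cpu tensors assumed):

```python
import hidet

# concrete inputs: run() dispatches to compiled_task.run_async, so the output owns storage
a = hidet.randn([4, 4])
y_eager = hidet.ops.matmul(a, a)
print(y_eager.storage is not None)  # True: computed eagerly

# symbolic inputs: run() falls back to symbolic_run; only shape/dtype are derived
xs = hidet.symbol([4, 4], dtype='float32')
y_sym = hidet.ops.matmul(xs, xs)
print(y_sym.storage is None)        # True: a graph node with no data attached
```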
1 change: 1 addition & 0 deletions python/hidet/graph/ops/__init__.py
@@ -44,6 +44,7 @@
from .complex import real, imag, conj, make_complex
from .compare import equal, not_equal, less, greater, less_equal, greater_equal
from .compare import logical_not, logical_and, logical_or, logical_xor
from .quant import symmetric_quantize, symmetric_dequantize
from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
from .cumulative import cumsum
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape
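This import exposes the new quantization operators at hidet.ops. For orientation, symmetric quantization maps a float tensor to signed integers with a single scale and no zero point. The reference sketch below in numpy shows the arithmetic only; the exact signatures of hidet.ops.symmetric_quantize / symmetric_dequantize (per-channel axes, output layout, supported dtypes) are not reproduced here:

```python
import numpy as np


def symmetric_quantize_ref(w: np.ndarray, bits: int = 8):
    # one scale for the whole tensor; the zero point is implicitly 0
    qmax = 2 ** (bits - 1) - 1                        # 127 for int8
    scale = np.abs(w).max() / qmax
    wq = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return wq, scale


def symmetric_dequantize_ref(wq: np.ndarray, scale: float) -> np.ndarray:
    return wq.astype(np.float32) * scale


w = np.random.randn(64, 64).astype(np.float32)
wq, scale = symmetric_quantize_ref(w)
err = np.abs(w - symmetric_dequantize_ref(wq, scale)).max()
print(err <= scale / 2 + 1e-6)  # True: reconstruction error is bounded by half a quantization step
```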
42 changes: 21 additions & 21 deletions python/hidet/graph/ops/activation.py
@@ -193,23 +193,23 @@ def __init__(self, x: Tensor, axis: int = 1):


def relu(x) -> Tensor:
return ReluOp(x).get_output(0)
return ReluOp(x).outputs[0]


def leaky_relu(x: Tensor, alpha: float) -> Tensor:
return LeakyReluOp(x, alpha).get_output(0)
return LeakyReluOp(x, alpha).outputs[0]


def sigmoid(x: Tensor) -> Tensor:
return SigmoidOp(x).get_output(0)
return SigmoidOp(x).outputs[0]


def hardsigmoid(x: Tensor) -> Tensor:
return HardSigmoidOp(x).get_output(0)
return HardSigmoidOp(x).outputs[0]


def clip(x: Tensor, min_val: Optional[float], max_val: Optional[float]) -> Tensor:
return ClipOp(x, min_val, max_val).get_output(0)
return ClipOp(x, min_val, max_val).outputs[0]


def relu6(x: Tensor) -> Tensor:
@@ -221,64 +221,64 @@ def gelu(x: Tensor, approximate: bool = False) -> Tensor:


def silu(x: Tensor) -> Tensor:
return SiluOp(x).get_output(0)
return SiluOp(x).outputs[0]


def prelu(x: Tensor, slope: Tensor) -> Tensor:
return PReluOp(x, slope).get_output(0)
return PReluOp(x, slope).outputs[0]


def hardswish(x: Tensor) -> Tensor:
return HardSwishOp(x).get_output(0)
return HardSwishOp(x).outputs[0]


def threshold(x: Tensor, threshold_val: float, value: float) -> Tensor:
return ThresholdOp(x, threshold_val, value).get_output(0)
return ThresholdOp(x, threshold_val, value).outputs[0]


def hardtanh(x: Tensor, min_val: float, max_val: float) -> Tensor:
return HardTanhOp(x, min_val, max_val).get_output(0)
return HardTanhOp(x, min_val, max_val).outputs[0]


def elu(x: Tensor, alpha: float) -> Tensor:
return EluOp(x, alpha).get_output(0)
return EluOp(x, alpha).outputs[0]


def selu(x: Tensor, alpha: float, scale: float) -> Tensor:
return SeluOp(x, alpha, scale).get_output(0)
return SeluOp(x, alpha, scale).outputs[0]


def celu(x: Tensor, alpha: float) -> Tensor:
return CeluOp(x, alpha).get_output(0)
return CeluOp(x, alpha).outputs[0]


def logsigmoid(x: Tensor) -> Tensor:
return LogSigmoidOp(x).get_output(0)
return LogSigmoidOp(x).outputs[0]


def hardshrink(x: Tensor, lambda_val: float) -> Tensor:
return HardShrinkOp(x, lambda_val).get_output(0)
return HardShrinkOp(x, lambda_val).outputs[0]


def tanhshrink(x: Tensor) -> Tensor:
return TanhShrinkOp(x).get_output(0)
return TanhShrinkOp(x).outputs[0]


def softsign(x: Tensor) -> Tensor:
return SoftSignOp(x).get_output(0)
return SoftSignOp(x).outputs[0]


def softplus(x: Tensor, beta: int, threshold_val: int) -> Tensor:
return SoftPlusOp(x, beta, threshold_val).get_output(0)
return SoftPlusOp(x, beta, threshold_val).outputs[0]


def softshrink(x: Tensor, lambda_val: float) -> Tensor:
return SoftShrinkOp(x, lambda_val).get_output(0)
return SoftShrinkOp(x, lambda_val).outputs[0]


def softmax(x: Tensor, axis=1) -> Tensor:
return SoftmaxOp(x, axis).get_output(0)
return SoftmaxOp(x, axis).outputs[0]


def softmin(x: Tensor, axis: int) -> Tensor:
return SoftmaxOp(-x, axis).get_output(0)
return SoftmaxOp(-x, axis).outputs[0]