Skip to content

Commit

Permalink
Merge branch 'main' into justinchu/flag-trace
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed May 24, 2024
2 parents 6129749 + c41ded5 commit ab73a54
Show file tree
Hide file tree
Showing 31 changed files with 1,006 additions and 481 deletions.
38 changes: 13 additions & 25 deletions docs/intermediate_representation/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,17 @@ In the following scenario, we show how to go from a `TensorProto` to an `ir.Tens

## Working with non-native NumPy dtypes: bfloat16, float8, int4

`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, the value is the bit representation for the dtype:
`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, we use dtypes from the `ml_dtypes` package.

`uint4`/`int4` is always unpacked; **`tobytes()` produces a packed representation** as expected.

Initialization of `ir.Tensor` requires the NumPy array to satisfy the following typing constraints, or to have an `ml_dtypes` dtype.

- `int8` for (unpacked) int4, with the sign bit extended to 8 bits.
- `uint8` for (unpacked) uint4.
- `uint8` for 8-bit data types like float8.
- `uint16` for bfloat16.

uint4/int4 is always unpacked; `tobytes()` produces a packed representation as expected.

Initialization of `ir.Tensor` requires the NumPy array to follow these typing constraints as well.

:::{tip}
You can use the [ml_dtypes package](https://github.com/jax-ml/ml_dtypes) to extend NumPy and work with these values.

```bash
pip install --upgrade ml_dtypes
```

:::

The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its values, and create a new tensor to store the transformed values.

```{eval-rst}
Expand All @@ -170,24 +161,21 @@ The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its
import numpy as np
array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([1, 3], dtype=uint8), name='')
print("tensor.numpy():", tensor.numpy()) # array([1, 3], dtype=uint8)
# You can use the ml_dtypes package to work with these values in NumPy
import ml_dtypes
float8_array = tensor.numpy().view(ml_dtypes.float8_e4m3fn)
print("float8_array:", float8_array) # array([0.00195312, 0.00585938], dtype='float8_e4m3fn')
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy()) # [0.00195312 0.00585938]
# Compute
times_100 = float8_array * 100
times_100 = tensor.numpy() * 100
print("times_100:", times_100)
# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([36, 49], dtype=uint8), name='')
print("new_tensor == times_100", new_tensor.numpy().view(ml_dtypes.float8_e4m3fn) == times_100) # array([ True, True])
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100) # array([ True, True])
```

## Advanced Usage
Expand Down
71 changes: 42 additions & 29 deletions docs/tutorial/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import logging

import numpy as np
import onnx

import onnxscript
Expand Down Expand Up @@ -65,71 +64,81 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
def check_if_not_need_reshape(
context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
"""If matmul broadcasting is enough, then we don't need the reshapes.
"""Condition to check if we need to replace the pattern.
If matmul broadcasting is enough, then we don't need the reshapes.
To validate this, we need to check the following:
1. Input shapes check: input_a and input_b should be broadcastable
2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
If the above are true, then we don't need the reshapes.
Returns:
True if we need to replace the pattern, False otherwise.
"""
del context # Reserved for future extensions

input_a_shape = input_a.shape
input_b_shape = input_b.shape
# TODO: Get a helper func to get const_value
shape_c_value = _ir_utils.propagate_const_value(shape_c)
shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr]
if shape_c is None:
return False
if not isinstance(shape_c, np.ndarray):
logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
_ir_utils.propagate_const_value(shape_c)
shape_c_tensor = shape_c.const_value
if shape_c_tensor is None:
logger.info("The value 'shape_c' is not statically known.")
return False
if len(shape_c.shape) != 1:

if len(shape_c_tensor.shape) != 1:
logger.info(
"Unexpected final shape. The shape of 'shape' value is %s",
shape_c.shape,
shape_c_tensor.shape,
)
return False
shape_c_list = shape_c.tolist()

# NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
# information. So, we need to check if the shape is None and return False.
if input_a_shape is None or input_b_shape is None or shape_c is None:
if input_a_shape is None or input_b_shape is None:
logger.info("Shape information is not available for the inputs and outputs.")
return False
input_a_shape = list(input_a_shape)
input_b_shape = list(input_b_shape)
input_a_shape = input_a_shape.numpy()
input_b_shape = input_b_shape.numpy()
shape_c = shape_c_tensor.numpy().tolist()

a_rank = len(input_a_shape)
b_rank = len(input_b_shape)

dim_a = len(input_a_shape)
dim_b = len(input_b_shape)
# TODO(justinchuby): Check shape size

# 1. Check if input shapes are broadcastable
# 1.a. If the first input is 1-D, check whether
# the dim matches the last second dim of the second input.
mimic_matmul_broadcast_behavior = False
if dim_a < 2:
if a_rank < 2:
if b_rank < 2:
logger.info("Optimization of dot product is not supported yet.")
return False
if input_a_shape[-1] != input_b_shape[-2]:
logger.info("Original shape is not MatMul compatible.")
return False
else:
input_a_shape = [1, *input_a_shape]
dim_a = len(input_a_shape)
a_rank = len(input_a_shape)
mimic_matmul_broadcast_behavior = True
# 1.b. If the second input is 1-D, check whether
# the dim matches the last dim of the first input.
if dim_b < 2:
if b_rank < 2:
if input_b_shape[-1] != input_a_shape[-1]:
logger.info("Original shape is not MatMul compatible.")
return False
else:
input_b_shape = [*input_b_shape, 1]
dim_b = len(input_b_shape)
b_rank = len(input_b_shape)
mimic_matmul_broadcast_behavior = True
# 1.c. If both inputs are at least 2-D, check whether
# the last dimension of the first input matches the second
# last dimension of the second input, and shape[:-2] are
# broadcastable.
input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]]
input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
input_b_shape_except_last_dim = input_b_shape[:-1]
broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
for idx, (dim_from_a, dim_from_b) in enumerate(
Expand All @@ -149,23 +158,27 @@ def check_if_not_need_reshape(

# 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
# Prepend the broadcast_matmul_output_shape with the longer shape of input
if dim_a > dim_b:
if a_rank > b_rank:
longer_shape = input_a_shape
shorter_shape = input_b_shape
else:
longer_shape = input_b_shape
shorter_shape = input_a_shape
broadcast_matmul_output_shape = (
longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
)
if mimic_matmul_broadcast_behavior and dim_b == 2:
broadcast_matmul_output_shape = [
*longer_shape[: -len(shorter_shape)],
*broadcast_matmul_output_shape,
]
if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
# If input_b is expanded to 2-D, then we need to remove the last dimension
broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
if mimic_matmul_broadcast_behavior and dim_a == 2:
if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
# If input_a is expanded to 2-D, then we need to remove the first dimension
# of input_a, which would be the -2nd dimension of the output shape.
broadcast_matmul_output_shape.pop(-2)
if shape_c_list != broadcast_matmul_output_shape:
if shape_c != broadcast_matmul_output_shape:
logger.info(
"Final output shape is not the same. Expected %s vs actual %s",
shape_c_list,
shape_c,
broadcast_matmul_output_shape,
)
return False
Expand Down
10 changes: 3 additions & 7 deletions examples/pattern_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import onnx.numpy_helper as onh

from onnxscript import ir
from onnxscript.rewriter import generic_pattern
from onnxscript.rewriter import pattern


def get_rotary_model(bad_model=False):
Expand Down Expand Up @@ -99,9 +99,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
#
# The rule is easy to create.

rule = generic_pattern.make_pattern_rule(
rotary_match_pattern, rotary_apply_pattern, verbose=10
)
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

##########################
# Let's apply it.
Expand Down Expand Up @@ -136,9 +134,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
# The match did not happen.
# Let's increase the verbosity.

rule = generic_pattern.make_pattern_rule(
rotary_match_pattern, rotary_apply_pattern, verbose=10
)
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

Expand Down
18 changes: 18 additions & 0 deletions onnxscript/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,24 @@ def sum(n: INT64) -> INT64:
self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64))
self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64))

def test_function_opset_import(self):
    """Check that a model picks up the opset version used by a function it calls.

    ``double`` is scripted against opset 19; ``model`` only calls ``double``.
    The exported model must import the default ("") domain exactly once, at
    version 19.
    """
    from onnxscript import opset19

    @script()
    def double(x):
        return opset19.Add(x, x)

    @script()
    def model(x):
        return double(x)

    # Keep only the imports of the default ONNX domain (empty domain string).
    default_domain_imports = [
        entry for entry in model.to_model_proto().opset_import if entry.domain == ""
    ]

    self.assertEqual(len(default_domain_imports), 1)
    self.assertEqual(default_domain_imports[0].version, 19)


if __name__ == "__main__":
unittest.main(verbosity=2)
6 changes: 5 additions & 1 deletion onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,15 @@
# Conversion functions
"from_proto",
"to_proto",
# IR Tensor initializer
"tensor",
# Pass infrastructure
"passes",
"traversal",
]

from onnxscript.ir import passes, serde
from onnxscript.ir import passes, serde, traversal
from onnxscript.ir._convenience import tensor
from onnxscript.ir._core import (
Attr,
AttrFloat32,
Expand Down
85 changes: 85 additions & 0 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
"replace_all_uses_with",
]

import typing
from typing import Mapping, Sequence, Union

import numpy as np
import onnx

from onnxscript.ir import _core, _enums, _protocols, serde

if typing.TYPE_CHECKING:
import numpy.typing as npt

SupportedAttrTypes = Union[
str,
int,
Expand Down Expand Up @@ -285,3 +290,83 @@ def replace_all_uses_with(
for value, replacement in zip(values, replacements):
for user_node, index in tuple(value.uses()):
user_node.replace_input_with(index, replacement)


def tensor(
    value: npt.ArrayLike
    | onnx.TensorProto
    | _protocols.DLPackCompatible
    | _protocols.ArrayCompatible,
    dtype: _enums.DataType | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> _protocols.TensorProtocol:
    """Create a tensor value from an ArrayLike object or a TensorProto.

    The dtype must match the value. Reinterpretation of the value is
    not supported, unless the value is a plain Python object, in which case
    it is converted to a numpy array with the given dtype.

    ``value`` can be a numpy array, a plain Python object, or a TensorProto.

    Example::

        >>> from onnxscript import ir
        >>> import numpy as np
        >>> import ml_dtypes
        >>> import onnx
        >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
        Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
        >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
        Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
        >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
        >>> tp_tensor.numpy()
        array(0.5, dtype=float32)

    Args:
        value: The object to create the tensor from. If it already satisfies
            the tensor protocol it is returned unchanged (``name`` and
            ``doc_string`` are not applied in that case).
        dtype: The data type of the tensor.
        name: The name of the tensor.
        doc_string: The documentation string of the tensor.

    Returns:
        A tensor value.

    Raises:
        ValueError: If the dtype does not match the value when value is not a plain Python
            object like ``list[int]``.
    """
    if isinstance(value, _protocols.TensorProtocol):
        if dtype is not None and dtype != value.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
                "You do not have to specify the dtype when value is a Tensor."
            )
        return value

    if isinstance(value, onnx.TensorProto):
        tensor_ = serde.deserialize_tensor(value)
        if name is not None:
            tensor_.name = name
        if doc_string is not None:
            tensor_.doc_string = doc_string
        if dtype is not None and dtype != tensor_.dtype:
            # The two sentences were previously concatenated without a separator.
            raise ValueError(
                f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}. "
                "You do not have to specify the dtype when value is a TensorProto."
            )
        return tensor_

    if isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
        # Forward the caller's doc_string (previously `name` was passed here by mistake).
        return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)

    # Plain Python object: let NumPy build the array, using the requested dtype
    # when one is given so that e.g. list[int] can become BFLOAT16.
    numpy_dtype = dtype.numpy() if dtype is not None else None
    array = np.array(value, dtype=numpy_dtype)
    return _core.Tensor(
        array,
        dtype=dtype,
        shape=_core.Shape(array.shape),
        name=name,
        # Forward the caller's doc_string (previously `name` was passed here by mistake).
        doc_string=doc_string,
    )
Loading

0 comments on commit ab73a54

Please sign in to comment.