-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up and fix numpy_helper and subbyte #6124
base: main
Are you sure you want to change the base?
Conversation
Test Results 3 files ±0 3 suites ±0 2m 15s ⏱️ +8s For more details on these failures, see this check. Results for commit 64be57c. ± Comparison against base commit 013eb5e. ♻️ This comment has been updated with latest results. |
Maybe it is worth adding a unit test. |
9ba8ff1
to
b688548
Compare
""" | ||
single_func = lambda x: subbyte.unpack_single_4bitx2(x, signed) # noqa: E731 | ||
func = np.frompyfunc(single_func, 1, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
frompyfunc is not performant
if tensor_dtype in (TensorProto.COMPLEX64, TensorProto.COMPLEX128): | ||
data = combine_pairs_to_complex(data) # type: ignore[assignment,arg-type] | ||
if tensor_dtype in (onnx.TensorProto.COMPLEX64, onnx.TensorProto.COMPLEX128): | ||
return np.asarray(combine_pairs_to_complex(data)).astype(np_dtype).reshape(dims) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
if tensor_dtype in (TensorProto.COMPLEX64, TensorProto.COMPLEX128): | ||
data = combine_pairs_to_complex(data) # type: ignore[assignment,arg-type] | ||
if tensor_dtype in (onnx.TensorProto.COMPLEX64, onnx.TensorProto.COMPLEX128): | ||
return np.asarray(combine_pairs_to_complex(data)).astype(np_dtype).reshape(dims) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
clip_high = INT4_MAX if signed else UINT4_MAX | ||
if not isinstance(x, np.ndarray): | ||
x = np.asarray(x) | ||
return np.rint(np.clip(x, INT4_MIN, INT4_MAX)).astype(np.int8) |
Check failure
Code scanning / lintrunner
MYPY/no-any-return Error
Returns: | ||
An ndarray with a single int4 element. | ||
""" | ||
return np.rint(np.clip(x, UINT4_MIN, UINT4_MAX)).astype(np.uint8) |
Check failure
Code scanning / lintrunner
MYPY/no-any-return Error
else: | ||
i8_low = cast_uint4(val_low) | ||
i8_high = cast_uint4(val_high) | ||
i8_high <<= 4 |
Check failure
Code scanning / lintrunner
MYPY/assignment Error
x_low = x & np.uint8(0x0F) | ||
x_high = (x >> 4).astype(np.uint8) | ||
if signed: | ||
x_low = _int4_to_int8(x_low) |
Check failure
Code scanning / lintrunner
MYPY/assignment Error
x_high = (x >> 4).astype(np.uint8) | ||
if signed: | ||
x_low = _int4_to_int8(x_low) | ||
x_high = _int4_to_int8(x_high) |
Check failure
Code scanning / lintrunner
MYPY/assignment Error
# if mantissa > 0: | ||
# exponent = 127 - exponent_bias | ||
# if mantissa & 0b100 == 0: | ||
# mantissa &= 0b011 | ||
# mantissa <<= 1 | ||
# exponent -= 1 | ||
# if mantissa & 0b100 == 0: | ||
# mantissa &= 0b011 | ||
# mantissa <<= 1 | ||
# exponent -= 1 | ||
# result |= (mantissa & 0b011) << 21 | ||
# result |= exponent << 23 |
Check notice
Code scanning / CodeQL
Commented-out code Note
return f | ||
|
||
|
||
_float8e4m3_to_float32 = np.vectorize( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed use of vectorize because it is a for loop and is not performant
result[normal_mask] |= exponents[normal_mask] << 23 | ||
result = result.view(np.float32) | ||
if is_scalar: | ||
return result[0] |
Check failure
Code scanning / lintrunner
MYPY/no-any-return Error
result = result.view(np.float32) | ||
if is_scalar: | ||
return result[0] | ||
return result |
Check failure
Code scanning / lintrunner
MYPY/return-value Error
result[normal_mask] |= exponents[normal_mask] << 23 | ||
result = result.view(np.float32) | ||
if is_scalar: | ||
return result[0] |
Check failure
Code scanning / lintrunner
MYPY/no-any-return Error
result = result.view(np.float32) | ||
if is_scalar: | ||
return result[0] | ||
return result |
Check failure
Code scanning / lintrunner
MYPY/return-value Error
# if exponent == 0: | ||
# # Subnormal number | ||
# if mantissa > 0: | ||
# exponent = 127 - exponent_bias | ||
# if mantissa & 0b10 == 0: | ||
# mantissa &= 0b01 | ||
# mantissa <<= 1 | ||
# exponent -= 1 | ||
# result |= (mantissa & 0b01) << 22 | ||
# result |= exponent << 23 |
Check notice
Code scanning / CodeQL
Commented-out code Note
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
15aec39
to
9fd4909
Compare
y[i] = d | ||
return y.reshape(shape) | ||
data = np.array(tensor.int32_data, dtype=np.uint8) | ||
data = data.view(dtype_mapping[elem_type]) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
y[i] = d | ||
return y.reshape(shape) | ||
data = np.array(tensor.int32_data, dtype=np.uint8) | ||
data = data.view(dtype_mapping[elem_type]) |
Check failure
Code scanning / lintrunner
MYPY/index Error
for i, d in enumerate(data): | ||
y[i] = d | ||
dtype_mapping = {TensorProto.INT4: int4, TensorProto.UINT4: uint4} | ||
dtype = dtype_mapping[elem_type] |
Check failure
Code scanning / lintrunner
MYPY/index Error
y[i] = d | ||
dtype_mapping = {TensorProto.INT4: int4, TensorProto.UINT4: uint4} | ||
dtype = dtype_mapping[elem_type] | ||
return subbyte.unpack_int4(data, dims=tensor.dims, signed=signed).view(dtype) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
cd51e75
to
4c18a15
Compare
Signed-off-by: Justin Chu <justinchu@microsoft.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>
return result | ||
|
||
|
||
def _small_endian_dtype(dtype) -> np.dtype: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _small_endian_dtype(dtype) -> np.dtype: | |
def _little_endian_dtype(dtype) -> np.dtype: |
For float8 usage, we may be better of using https://github.com/jax-ml/ml_dtypes? |
return shift(data.astype(np.int32)).reshape(dims).view(np.float32) # type: ignore[no-any-return] | ||
|
||
|
||
def _float8e4m3_to_float32_scalar(ival: int, fn: bool, uz: bool) -> np.float32: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would keep the code of the old function in the documentation. The logic is easier to read so that the new code can be more easily understood.
Is this active? It is still marked "draft". Maybe we should get this into 1.17 release? |
I am personally fine with missing the release. The IR has more efficient handling of numpy arrays and does not rely on the helper right now so we are not blocked. |
complex number handling in numpy_helper
Otherwise the line
raises
TypeError: float() argument must be a string or a real number, not 'complex'
, becausestorage_np_dtype
is float butdata
is complex already.Vectorize float8 conversion functions and improve readability: Speed up
float8e4m3_to_float32
by 10.3x (1000x1000 input, 10 iterations, 34.829s -> 3.11s)Clean up int4 numpy helpers to make them more useful and performant with np native vectorization. Move all int4 related functions to the subbyte module.
Improve handling of big-endian systems
Remove the
dims
parameter in numpy helper functions to simplify the implementation.Improve reference evaluator to_array_extended
@galagam for int4 updates, @AlexandreEichenberger for big-endian handling @xadupre for float8 functions and reference evaluator. Thanks!
Float 8 util speed test
TODO: Unit tests
Fixes #6126