diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py new file mode 100644 index 0000000000000..723d39f9700fa --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py @@ -0,0 +1,31 @@ +# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s +import numpy as np +import os +import sys + +_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(_SCRIPT_PATH) +from tools import mlir_pytaco_api as pt + +compressed = pt.compressed + +passed = 0 +all_types = [pt.complex64, pt.complex128] +for t in all_types: + i, j = pt.get_index_vars(2) + A = pt.tensor([2, 3], dtype=t) + B = pt.tensor([2, 3], dtype=t) + C = pt.tensor([2, 3], compressed, dtype=t) + A.insert([0, 1], 10 + 20j) + A.insert([1, 2], 40 + 0.5j) + B.insert([0, 0], 20) + B.insert([1, 2], 30 + 15j) + C[i, j] = A[i, j] + B[i, j] + + indices, values = C.get_coordinates_and_values() + passed += isinstance(values[0], t.value) + passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) + passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j]) + +# CHECK: Number of passed: 6 +print("Number of passed:", passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py index 9bab366fcbe82..48b6b552bb110 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -75,16 +75,20 @@ class Type(enum.Enum): # numpy _ctype_from_dtype_scalar can't handle np.float16 yet. FLOAT32 = np.float32 FLOAT64 = np.float64 + COMPLEX64 = np.complex64 + COMPLEX128 = np.complex128 # All floating point type enums. _FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64) # All integral type enums. _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64) +# All complex type enums. +_COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128) # Type alias for any numpy type used to implement the runtime support for the # enum data types. _AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32, - np.float64] + np.float64, np.complex64, np.complex128] @dataclasses.dataclass(frozen=True) @@ -111,6 +115,10 @@ def is_int(self) -> bool: """Returns whether the data type represents an integral value.""" return self.kind in _INT_TYPES + def is_complex(self) -> bool: + """Returns whether the data type represents a complex value.""" + return self.kind in _COMPLEX_TYPES + @property def value(self) -> _AnyRuntimeType: """Returns the numpy dtype for the data type.""" @@ -125,7 +133,9 @@ def _dtype_to_mlir_str(dtype: DType) -> str: Type.INT32: "i32", Type.INT64: "i64", Type.FLOAT32: "f32", - Type.FLOAT64: "f64" + Type.FLOAT64: "f64", + Type.COMPLEX64: "complex", + Type.COMPLEX128: "complex" } return dtype_to_str[dtype.kind] @@ -138,7 +148,9 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType: np.int32: Type.INT32, np.int64: Type.INT64, np.float32: Type.FLOAT32, - np.float64: Type.FLOAT64 + np.float64: Type.FLOAT64, + np.complex64: Type.COMPLEX64, + np.complex128: Type.COMPLEX128 } return DType(nptype_to_dtype[ty]) @@ -151,7 +163,9 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type: Type.INT32: ir.IntegerType.get_signless(32), Type.INT64: ir.IntegerType.get_signless(64), Type.FLOAT32: ir.F32Type.get(), - Type.FLOAT64: ir.F64Type.get() + Type.FLOAT64: ir.F64Type.get(), + Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()), + Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()) } return dtype_to_irtype[dtype.kind] @@ -1004,8 +1018,8 @@ def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], raise ValueError(f"Invalid format argument: {fmt}.") def __init__(self, - value_or_shape: Optional[Union[List[int], Tuple[int, ...], float, - int]] = None, + value_or_shape: Optional[Union[List[int], Tuple[int, ...], + complex, float, int]] = None, fmt: Optional[Union[ModeFormat, List[ModeFormat], Format]] = None, dtype: Optional[DType] = None, @@ -1059,7 +1073,7 @@ def __init__(self, self._values = [] self._stats = _Stats() if value_or_shape is None or isinstance(value_or_shape, int) or isinstance( - value_or_shape, float): + value_or_shape, float) or isinstance(value_or_shape, complex): # Create a scalar tensor and ignore the fmt parameter. self._shape = [] self._format = _make_format([], []) @@ -1108,7 +1122,7 @@ def __repr__(self) -> str: return (f"Tensor(_name={repr(self._name)} " f"_dtype={repr(self._dtype)} : ") + value_str - def insert(self, coords: List[int], val: Union[float, int]) -> None: + def insert(self, coords: List[int], val: Union[complex, float, int]) -> None: """Inserts a value to the given coordinate. Args: @@ -1134,7 +1148,8 @@ def insert(self, coords: List[int], val: Union[float, int]) -> None: raise ValueError("Invalid coordinate for rank: " f"{self.order}, {coords}.") - if not isinstance(val, int) and not isinstance(val, float): + if not isinstance(val, int) and not isinstance( + val, float) and not isinstance(val, complex): raise ValueError(f"Value is neither int nor float: {val}.") self._coords.append(tuple(coords)) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py index 7573d6655360f..8300dfef5bc63 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py @@ -41,6 +41,8 @@ int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64) float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32) float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64) +complex64 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX64) +complex128 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX128) # Storage format constants defined by the PyTACO API. In PyTACO, each storage # format constant has two aliasing names. diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py index ce6d3c70bd50c..f5ec14aa80b03 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py @@ -92,7 +92,11 @@ def _get_support_func_locator() -> _SupportFuncLocator: (np.float32, c_lib.convertToMLIRSparseTensorF32, c_lib.convertFromMLIRSparseTensorF32), (np.float64, c_lib.convertToMLIRSparseTensorF64, - c_lib.convertFromMLIRSparseTensorF64)] + c_lib.convertFromMLIRSparseTensorF64), + (np.complex64, c_lib.convertToMLIRSparseTensorC32, + c_lib.convertFromMLIRSparseTensorC32), + (np.complex128, c_lib.convertToMLIRSparseTensorC64, + c_lib.convertFromMLIRSparseTensorC64)] except Exception as e: raise ValueError(f"Missing supporting function: {e}") from e for i, info in enumerate(support_types): @@ -134,14 +138,15 @@ def sparse_tensor_to_coo_tensor( rank = ctypes.c_ulonglong(0) nse = ctypes.c_ulonglong(0) shape = ctypes.POINTER(ctypes.c_ulonglong)() - values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))() + + values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))() indices = ctypes.POINTER(ctypes.c_ulonglong)() convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse), ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices)) # Convert the returned values to the corresponding numpy types. shape = np.ctypeslib.as_array(shape, shape=[rank.value]) - values = np.ctypeslib.as_array(values, shape=[nse.value]) + values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value])) indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value]) return rank.value, nse.value, shape, values, indices @@ -175,7 +180,7 @@ def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray, nse = ctypes.c_ulonglong(len(np_values)) shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) values = np_values.ctypes.data_as( - ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype))) + ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))) indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))