Skip to content

Commit

Permalink
[mlir][sparse][taco] Support more data types.
Browse files Browse the repository at this point in the history
Support int8, int16, int32 and int32. Also fix source code format in mlir_pytaco_utils.py.

Add tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D124925
  • Loading branch information
bixia1 committed May 4, 2022
1 parent b6c67c3 commit 1cd13e6
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 127 deletions.
@@ -0,0 +1,33 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s

import numpy as np
import os
import sys

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco_api as pt

compressed = pt.compressed
dense = pt.dense

passed = 0
all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float32, pt.float64]
for t in all_types:
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i, j] = A[i, j] + B[i, j]

indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20.0, 10.0, 70.0])

# CHECK: Number of passed: 18
print("Number of passed:", passed)
Expand Up @@ -67,6 +67,7 @@ class Type(enum.Enum):
We use numpy data types to implement the enum data types.
"""
INT8 = np.int8
INT16 = np.int16
INT32 = np.int32
INT64 = np.int64
Expand All @@ -78,10 +79,11 @@ class Type(enum.Enum):
# All floating point type enums.
_FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
# All integral type enums.
_INT_TYPES = (Type.INT16, Type.INT32, Type.INT64)
_INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
# Type alias for any numpy type used to implement the runtime support for the
# enum data types.
_AnyRuntimeType = Union[np.int16, np.int32, np.int64, np.float32, np.float64]
_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32,
np.float64]


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -117,6 +119,7 @@ def value(self) -> _AnyRuntimeType:
def _dtype_to_mlir_str(dtype: DType) -> str:
"""Returns the MLIR string for the given dtype."""
dtype_to_str = {
Type.INT16: "i8",
Type.INT16: "i16",
Type.INT32: "i32",
Type.INT64: "i64",
Expand All @@ -129,6 +132,7 @@ def _dtype_to_mlir_str(dtype: DType) -> str:
def _nptype_to_taco_type(ty: np.dtype) -> DType:
"""Returns the TACO type for the given numpy type."""
nptype_to_dtype = {
np.int8: Type.INT8,
np.int16: Type.INT16,
np.int32: Type.INT32,
np.int64: Type.INT64,
Expand All @@ -141,6 +145,7 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType:
def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
"""Returns the MLIR type corresponding to the given TACO type."""
dtype_to_irtype = {
Type.INT8: ir.IntegerType.get_signless(8),
Type.INT16: ir.IntegerType.get_signless(16),
Type.INT32: ir.IntegerType.get_signless(32),
Type.INT64: ir.IntegerType.get_signless(64),
Expand Down
Expand Up @@ -35,6 +35,7 @@
access = mlir_pytaco.Access

# Data type constants defined by PyTACO API.
int8 = mlir_pytaco.DType(mlir_pytaco.Type.INT8)
int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16)
int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32)
int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)
Expand Down

0 comments on commit 1cd13e6

Please sign in to comment.