diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index a544e52c261318..6b4e5434d1d7e7 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2477,6 +2477,9 @@ class PyConcreteType : public BaseTy { static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName); cls.def(py::init(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py index 115ea40619b8f5..fdc6cfd9bab078 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py @@ -21,6 +21,7 @@ __all__ = [ "LinalgStructuredOpConfig", "LinalgOpConfig", + "TensorDefConfig", ] @@ -51,17 +52,17 @@ def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap): self.shape_map = shape_map self.indexing_map = None # type: Optional[_ir.AffineMap] - def to_yaml_custom_dict(self): - - def get_usage(): - if self.tensor_def.output: - return "output" - else: - return "input" + @property + def usage(self) -> str: + if self.tensor_def.output: + return "output" + else: + return "input" + def to_yaml_custom_dict(self): return dict( name=self.tensor_def.tensor_name, - usage=get_usage(), + usage=self.usage, shape=_serialize_affine_map(self.shape_map), element_type_var=self.tensor_def.type_var.name, ) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py index d367c5bdde077b..cbff41db2d889b 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -11,6 +11,8 @@ from mlir import ir from .comprehension import * +from .config import * +from .emitter import * _CONTEXT = threading.local() @@ -42,9 +44,34 @@ def __init__(self, op_name: str, model: LinalgOpDef): self.op_name = op_name self.model = model - def __call__(self, *args, **kwargs): - # TODO: Upstream the emitter and invoke here - raise NotImplementedError("Linalg generic emission not yet implemented") + def __call__(self, *args, emit_generic: bool = True, **kwargs): + """Emits the corresponding op definition as IR. + + Most arguments are passed through to the underlying emitter. The following + are interpreted here: + emit_generic: Emits a generic form as appropriate (default True). If + False, a named form is emitted (which must have been built in to the + compiler). + """ + op_configs = LinalgOpConfig.from_linalg_op_def(self.model, + context=ir.Context.current) + + if len(op_configs) != 1: + # TODO: Support composite ops. + raise NotImplementedError( + f"Emission of composite linalg ops not supported: {op_configs}") + + op_config = op_configs[0] + if op_config.structured_op: + if emit_generic: + return emit_generic_structured_op(op_config.structured_op, *args, + **kwargs) + else: + return emit_named_structured_op(op_config.structured_op, *args, + **kwargs) + + raise NotImplementedError( + f"Emission of linalg op type not supported: {op_config}") def linalg_structured_op(dsl_func=None, diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py new file mode 100644 index 00000000000000..9a18993e9f6272 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -0,0 +1,252 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Dict, Sequence + +from mlir.ir import * +from mlir.dialects import linalg +from mlir.dialects import std + +from .scalar_expr import * +from .config import * + +__all__ = [ + "emit_generic_structured_op", + "emit_named_structured_op", +] + + +def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value = ()): + all_arg_defs = op_config.ordered_tensor_args + in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] + out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] + + # Arity validation. + if len(ins) != len(in_arg_defs): + raise ValueError(f"Expected {len(in_arg_defs)} inputs but got " + f"{len(ins)} for {op_config}") + if outs and len(outs) != len(out_arg_defs): + raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " + f"{len(outs)} for {op_config}") + + outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, + out_arg_defs, outs) + + # Extract type vars for input/output based types. + type_mapping = dict() # type: Dict[str, Type] + for arg_def, arg_element_type in zip( + in_arg_defs + out_arg_defs, + _get_shaped_element_types_from_values(*ins, *outs)): + tv_name = arg_def.tensor_def.type_var.name + type_mapping[tv_name] = arg_element_type + + # Emit the generic op. + # TODO: Support emission of pure memref form. + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in op_config.indexing_maps]) + iterator_types_attr = ArrayAttr.get( + [StringAttr.get(s) for s in op_config.iterator_types]) + generic_op = linalg.GenericOp( + result_tensors=out_types, + inputs=ins, + outputs=outs, + indexing_maps=indexing_maps_attr, + iterator_types=iterator_types_attr, + doc=None, # TODO: Make optional. + library_call=None, # TODO: Make optional. + sparse=BoolAttr.get(False)) # TODO: Make optional. + + # Construct the body. + block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) + block_arg_types = _get_shaped_element_types_from_values(*ins, *outs) + block = generic_op.regions[0].blocks.append(*block_arg_types) + block_arg_mapping = dict(zip(block_arg_names, block.arguments)) + with InsertionPoint(block): + body_builder = _BodyBuilder(type_mapping, block_arg_mapping) + for assignment in op_config.assignments: + body_builder.assign(assignment) + body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) + + if len(out_arg_defs) == 1: + return generic_op.result + else: + return generic_op.results + + +def emit_named_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value = ()): + raise NotImplementedError( + f"Emission of named structured ops is not supported: {op_config}") + + +class _BodyBuilder: + """Constructs a structured op body by evaluating assignments.""" + + def __init__(self, type_mapping: Dict[str, Type], + block_arg_mapping: Dict[str, Value]): + self.type_mapping = type_mapping + self.block_arg_mapping = block_arg_mapping + self.yield_mapping = dict() # type: Dict[str, Value] + + def assign(self, assignment: ScalarAssign): + if assignment.arg in self.yield_mapping: + raise ValueError( + f"Multiple assignments to the same argument are forbidden: " + f"{assignment}") + self.yield_mapping[assignment.arg] = self.expression(assignment.value) + + def expression(self, expr: ScalarExpression) -> Value: + if expr.scalar_arg: + try: + return self.block_arg_mapping[expr.scalar_arg.arg] + except KeyError: + raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " + f"this structured op.") + elif expr.scalar_apply: + try: + fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") + except AttributeError: + raise ValueError( + f"Function '{expr.scalar_apply.fn_name}' is not a known " + "scalar body function") + operand_values = [ + self.expression(operand) for operand in expr.scalar_apply.operands + ] + return fn(*operand_values) + elif expr.symbolic_cast: + operand_value = self.expression(expr.symbolic_cast.operand) + return self.cast(expr.symbolic_cast.to_type.name, operand_value) + raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") + + def cast(self, type_var_name: str, operand: Value) -> Value: + try: + to_type = self.type_mapping[type_var_name] + except KeyError: + raise ValueError(f"Unbound type variable '{type_var_name}' (" + f"expected one of {self.type_mappings.keys()}") + if operand.type == to_type: + return operand + if _is_integer_type(to_type): + return self._cast_to_integer(to_type, operand) + elif _is_floating_point_type(to_type): + return self._cast_to_floating_point(to_type, operand) + + raise ValueError(f"Unable to cast body expression from {operand.type} to " + f"{to_type}") + + def _cast_to_integer(self, to_type: Type, operand: Value) -> Value: + to_width = IntegerType(to_type).width + operand_type = operand.type + if _is_floating_point_type(operand_type): + return std.FPToSIOp(to_type, operand).result + # Assume integer. + from_width = IntegerType(operand_type).width + if to_width > from_width: + return std.SignExtendIOp(to_type, operand).result + elif to_width < from_width: + return std.TruncateIOp(to_type, operand).result + raise ValueError(f"Unable to cast body expression from {operand_type} to " + f"{to_type}") + + def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value: + operand_type = operand.type + if _is_integer_type(operand_type): + return std.SIToFPOp(to_type, operand).result + # Assume FloatType. + to_width = _get_floating_point_width(to_type) + from_width = _get_floating_point_width(operand_type) + if to_width > from_width: + return std.FPExtOp(to_type, operand).result + elif to_width < from_width: + return std.FPTruncOp(to_type, operand).result + raise ValueError(f"Unable to cast body expression from {operand_type} to " + f"{to_type}") + + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError(f"Body assignments do not assign all outputs: " + f"missing '{n}'") + linalg.YieldOp(output_values) + + def _eval_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.AddFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type): + return std.AddIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operand: {lhs}") + + def _eval_mul(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.MulFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type): + return std.MulIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'mul' operand: {lhs}") + + +def _infer_structured_outs(op_config: LinalgStructuredOpConfig, + in_arg_defs: Sequence[TensorDefConfig], + ins: Sequence[Value], + out_arg_defs: Sequence[TensorDefConfig], + outs: Sequence[Value]): + """Infers implicit outs and output types. + + Respects existing contents of outs if not empty. + + Returns: + normalized outs, output types + """ + # If outs were explicitly provided, we accept them verbatim. + if outs: + return outs, [out.type for out in outs] + + raise NotImplementedError(f"Output tensor inference not yet supported for " + "structured ops") + + +def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]: + types = [] + for v in values: + try: + t = ShapedType(v.type) + except Exception as e: + raise ValueError(f"Expected ShapedType but got {v}") from e + types.append(t.element_type) + return types + + +def _get_tensor_def_names( + *tensor_def_configs: TensorDefConfig) -> Sequence[str]: + return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs] + + +def _is_floating_point_type(t: Type) -> bool: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + return (F64Type.isinstance(t) or F32Type.isinstance(t) or + F16Type.isinstance(t) or BF16Type.isinstance(t)) + + +def _is_integer_type(t: Type) -> bool: + return IntegerType.isinstance(t) + + +def _get_floating_point_width(t: Type) -> int: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py new file mode 100644 index 00000000000000..7f8c1167945736 --- /dev/null +++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py @@ -0,0 +1,194 @@ +# RUN: %PYTHON %s | FileCheck %s + +from typing import Optional, Sequence + +from mlir.ir import * +from mlir.dialects import builtin +from mlir.dialects import linalg +from mlir.dialects import std + +from mlir.dialects.linalg.opdsl.lang import * + + +# TODO: Find a home for this quality of life helper. +def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None): + """Decorator that emits a function in a more pythonic way. + + If result types are not specified, they are inferred from the function + returns. The `ReturnOp` is implicitly added upon the wrapped function return. + """ + + def decorator(f): + return_types = results + symbol_name = f.__name__ + function_type = FunctionType.get(inputs=inputs, results=results or []) + func_op = builtin.FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + return_values = f(*func_args) + if return_values is None: + return_values = [] + elif isinstance(return_values, Value): + return_values = [return_values] + else: + return_values = list(return_values) + std.ReturnOp(return_values) + if return_types is None: + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get(inputs=inputs, results=return_types) + # TODO: Have an API or a setter for this. + func_op.attributes["type"] = TypeAttr.get(function_type) + + # TODO: When turning this into a real facility, return a function that emits + # a `call` to the function instead of doing nothing. + wrapped = lambda: None + wrapped.__name__ = symbol_name + wrapped.func_op = func_op + return wrapped + + return decorator + + +@linalg_structured_op +def matmul_mono(A=TensorDef(T, S.M, S.K), + B=TensorDef(T, S.K, S.N), + C=TensorDef(T, S.M, S.N, output=True)): + C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n] + + +@linalg_structured_op +def matmul_poly(A=TensorDef(TV.T1, S.M, S.K), + B=TensorDef(TV.T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + + +with Context() as ctx, Location.unknown(): + module = Module.create() + f16 = F16Type.get() + f32 = F32Type.get() + f64 = F64Type.get() + i8 = IntegerType.get_signless(8) + i16 = IntegerType.get_signless(16) + i32 = IntegerType.get_signless(32) + with InsertionPoint.at_block_terminator(module.body): + + # Note that these all have the same indexing maps. We verify the first and + # then do more permutation tests on casting and body generation + # behavior. + # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> + # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> + # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + + # CHECK-LABEL: func @test_matmul_mono + # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32> + # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32> + + # CHECK: %[[INITC:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32> + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MAPA]], #[[$MAPB]], #[[$MAPC]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] + # CHECK-SAME: ins(%[[A]], %[[B]] + # CHECK-SAME: outs(%[[INITC]] + + @build_function(RankedTensorType.get((4, 16), f32), + RankedTensorType.get((16, 8), f32)) + def test_matmul_mono(lhs, rhs): + # TODO: Enable outs inference and add sugar for InitTensorOp + # construction. + init_result = linalg.InitTensorOp(result=RankedTensorType.get((4, 8), + f32), + static_sizes=ArrayAttr.get([ + IntegerAttr.get(IndexType.get(), 4), + IntegerAttr.get(IndexType.get(), 8) + ]), + sizes=[]) + return matmul_mono(lhs, rhs, outs=[init_result.result]) + + # CHECK-LABEL: @test_i8i8i32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32) + # CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32 + # CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32 + # CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32 + # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 + # CHECK-NEXT: linalg.yield %[[ADD]] : i32 + # CHECK-NEXT: -> tensor<4x8xi32> + @build_function(RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), i32)) + def test_i8i8i32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i8i16i32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) + # CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32 + # CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32 + # CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32 + # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 + # CHECK-NEXT: linalg.yield %[[ADD]] : i32 + # CHECK-NEXT: -> tensor<4x8xi32> + @build_function(RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i16), + RankedTensorType.get((4, 8), i32)) + def test_i8i16i32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i32i32i16_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16) + # CHECK-NEXT: %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16 + # CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16 + # CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16 + # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16 + # CHECK-NEXT: linalg.yield %[[ADD]] : i16 + # CHECK-NEXT: -> tensor<4x8xi16> + @build_function(RankedTensorType.get((4, 16), i32), + RankedTensorType.get((16, 8), i32), + RankedTensorType.get((4, 8), i16)) + def test_i32i32i16_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i8i8f32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32) + # CHECK-NEXT: %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32 + # CHECK-NEXT: %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32 + # CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 + # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 + # CHECK-NEXT: linalg.yield %[[ADD]] : f32 + # CHECK-NEXT: -> tensor<4x8xf32> + @build_function(RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), f32)) + def test_i8i8f32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_f16f16f32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) + # CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32 + # CHECK-NEXT: %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32 + # CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 + # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 + # CHECK-NEXT: linalg.yield %[[ADD]] : f32 + # CHECK-NEXT: -> tensor<4x8xf32> + @build_function(RankedTensorType.get((4, 16), f16), + RankedTensorType.get((16, 8), f16), + RankedTensorType.get((4, 8), f32)) + def test_f16f16f32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_f64f64f32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32) + # CHECK-NEXT: %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32 + # CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32 + # CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 + # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 + # CHECK-NEXT: linalg.yield %[[ADD]] : f32 + # CHECK-NEXT: -> tensor<4x8xf32> + @build_function(RankedTensorType.get((4, 16), f64), + RankedTensorType.get((16, 8), f64), + RankedTensorType.get((4, 8), f32)) + def test_f64f64f32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + +print(module) diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py index 59b4b50b533d89..ea05c1561f74bb 100644 --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -59,6 +59,21 @@ def testTypeEq(): run(testTypeEq) +# CHECK-LABEL: TEST: testTypeIsInstance +def testTypeIsInstance(): + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type.parse("f32", ctx) + # CHECK: True + print(IntegerType.isinstance(t1)) + # CHECK: False + print(F32Type.isinstance(t1)) + # CHECK: True + print(F32Type.isinstance(t2)) + +run(testTypeIsInstance) + + # CHECK-LABEL: TEST: testTypeEqDoesNotRaise def testTypeEqDoesNotRaise(): ctx = Context()