# Chapter 2: Emitting Basic MLIR

This is an xDSL version of the Toy compiler, as described in the [MLIR tutorial](https://mlir.llvm.org/docs/Tutorials/Toy/)

This tutorial will be illustrated with a toy language that we’ll call “Toy” (naming is hard…). Toy is a tensor-based language that allows you to define functions, perform some math computation, and print results.

In [2]:
# This adds the locally checked-out version of xdsl to path.
import sys
sys.path.append('../src')

In [3]:
from xdsl import *
from xdsl.ir import *
from xdsl.irdl import *
from xdsl.dialects.func import *
from xdsl.dialects.arith import *
from xdsl.dialects.builtin import *
from xdsl.parser import *
from xdsl.printer import *
from xdsl.util import *

# MLContext, containing information about the registered dialects
context = MLContext()

# Some useful dialects
arith = Arith(context)
func = Func(context)
builtin = Builtin(context)

# Printer used to pretty-print MLIR data structures
printer = Printer(target=Printer.Target.MLIR)

@dataclass
class Toy:
    ctx: MLContext

    def __post_init__(self):
        self.ctx.register_op(ConstantOp)
        # self.ctx.register_op(Call)
        # self.ctx.register_op(Return)

TensorTypeF64 = TensorType.from_type_and_list(f64)

@irdl_op_definition
class ConstantOp(Operation):
    '''
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

    ```mlir
      %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>
                        : tensor<2x3xf64>
    ```
    '''
    name: str = "toy.constant"
    result = ResultDef(AnyAttr()) #TensorType[F64])
    value = AttributeDef(AnyAttr()) #F64ElementsAttr

    # TODO verify that the result and value type are equal

    @staticmethod
    def from_list(data: List[float], shape: List[int] = []):
        value = DenseIntOrFPElementsAttr.tensor_from_list(data, f64, shape)
        return ConstantOp.create(result_types=[value.type], attributes={"value": value})


AddOpT = TypeVar('AddOpT', bound='AddOp')

@irdl_op_definition
class AddOp(Operation):
    '''
    The "add" operation performs element-wise addition between two tensors.
    The shapes of the tensor operands are expected to match.
    '''
    name: str = 'toy.add'
    arguments = VarOperandDef(AnyAttr())
    result = ResultDef(AnyAttr()) # F64ElementsAttr

    @classmethod
    def from_summands(cls: type[AddOpT], lhs: OpResult, rhs: SSAValue) -> AddOpT:
        return super().create(result_types=[lhs.typ.element_type], operands=[lhs, rhs])


@irdl_op_definition
class FuncOp(Operation):
    '''
    The "toy.func" operation represents a user defined function. These are
    callable SSA-region operations that contain toy computations.

    Example:

    ```mlir
    toy.func @main() {
      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
      toy.print %1 : tensor<2x2xf64>
      toy.return
    }
    ```
    '''
    name = 'toy.func'
    body = RegionDef()
    sym_name = AttributeDef(SymbolNameAttr)
    function_type = AttributeDef(FunctionType)

    @staticmethod
    def from_region(name: str, ftype: FunctionType, region: Region):
        return FuncOp.create(result_types=[ftype.outputs], attributes={
            "sym_name": SymbolNameAttr.from_str(name),
            "function_type": ftype
        }, regions=[region])

    @staticmethod
    def from_callable(name: str, input_types: List[Attribute],
                      return_types: List[Attribute],
                      func: Block.BlockCallback):
        type_attr = FunctionType.from_lists(input_types, return_types)
        op = FuncOp.build(attributes={
            "sym_name": name,
            "function_type": type_attr,
        },
                          regions=[
                              Region.from_block_list(
                                  [Block.from_callable(input_types, func)])
                          ])
        return op

COpT = TypeVar('COpT', bound='GenericCallOp')

@irdl_op_definition
class GenericCallOp(Operation):
    name: str = "toy.generic_call"
    arguments = VarOperandDef(AnyAttr())
    callee = AttributeDef(FlatSymbolRefAttr)

    # Note: naming this results triggers an ArgumentError
    res = VarResultDef(AnyAttr())
    # TODO how do we verify that the types are correct?

    @classmethod
    def get(cls: type[COpT], callee: Union[str, FlatSymbolRefAttr],
            operands: List[Union[SSAValue, Operation]],
            return_types: List[Attribute]) -> COpT:
        if isinstance(callee, str):
            callee = FlatSymbolRefAttr.from_str(callee)
        
        return super().create(operands=operands, result_types=return_types, attributes={"callee": callee})


MulOpT = TypeVar('MulOpT', bound='MulOp')

@irdl_op_definition
class MulOp(Operation):
    '''
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
    '''
    name: str = 'toy.mul'
    arguments = VarOperandDef(AnyAttr())
    result = ResultDef(AnyAttr()) # F64ElementsAttr

    @classmethod
    def from_summands(cls: type[MulOpT], lhs: OpResult, rhs: SSAValue) -> MulOpT:
        return super().create(result_types=[lhs.typ.element_type], operands=[lhs, rhs])



PrintOpT = TypeVar('PrintOpT', bound='PrintOp')

@irdl_op_definition
class PrintOp(Operation):
    '''
    The "print" builtin operation prints a given input tensor, and produces
    no results.
    '''
    name: str = 'toy.print'
    arguments = VarOperandDef(AnyAttr())

    @classmethod
    def from_input(cls: type[PrintOpT], input: SSAValue) -> PrintOpT:
        return super().create(operands=[input])


ReturnOpT = TypeVar('ReturnOpT', bound='ReturnOp')

@irdl_op_definition
class ReturnOp(Operation):
    '''
    The "return" operation represents a return operation within a function.
    The operation takes an optional tensor operand and produces no results.
    The operand type must match the signature of the function that contains
    the operation. For example:

    ```mlir
      func @foo() -> tensor<2xf64> {
        ...
        toy.return %0 : tensor<2xf64>
      }
    ```
    '''
    name: str = 'toy.return'
    arguments = OptOperandDef(AnyAttr())

    @classmethod
    def from_input(cls: type[ReturnOpT], input: SSAValue) -> ReturnOpT:
        return super().create(operands=[input])


ROpT = TypeVar('ROpT', bound='ReshapeOp')

@irdl_op_definition
class ReshapeOp(Operation):
    '''
    Reshape operation is transforming its input tensor into a new tensor with
    the same number of elements but different shapes. For example:

    ```mlir
       %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64>
    ```
    '''
    name: str = 'toy.reshape'
    arguments = VarOperandDef(AnyAttr())
    # We expect that the reshape operation returns a statically shaped tensor.
    result = ResultDef(AnyAttr()) # F64ElementsAttr 

    @classmethod
    def from_input(cls: type[ROpT], input: SSAValue, shape: List[int]) -> ROpT:
        t = AnyTensorType.from_type_and_list(input.typ.element_type, shape)
        return super().create(result_types=[t], operands=[input])


@irdl_op_definition
class TransposeOp(Operation):
    name: str = 'toy.transpose'
    arguments = OperandDef(AnyAttr())
    result = ResultDef(AnyAttr()) # F64ElementsAttr

    @staticmethod
    def from_input(input: SSAValue):
        input_type = input.typ
        if not isinstance(input_type, TensorType):
            assert False, f'{input_type}: {type(input_type)}'
        
        output_type = TensorType.from_type_and_list(input_type.element_type, list(reversed(input_type.shape.data)))

        return TransposeOp.create(operands=[input], result_types=[output_type])


def func_body(arg0: SSAValue, arg1: SSAValue) -> List[Operation]:
    f0 = TransposeOp.from_input(arg0)
    f1 = TransposeOp.from_input(arg1)
    f2 = MulOp.from_summands(f0.results[0], f1.results[0])
    f3 = ReturnOp.from_input(f2.results[0])
    return [f0, f1, f2, f3]

def main_body() -> List[Operation]:
    m0 = ConstantOp.from_list([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
    m1 = ReshapeOp.from_input(m0.results[0], [2,3])
    m2 = ConstantOp.from_list([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
    m3 = ReshapeOp.from_input(m2.results[0], [2,3])
    m4 = GenericCallOp.get('multiply_transpose', [m1.results[0], m3.results[0]], [TensorType.from_type_and_list(f64, [2, 3])])
    m5 = GenericCallOp.get('multiply_transpose', [m3.results[0], m1.results[0]], [TensorType.from_type_and_list(f64, [2, 3])])
    m6 = PrintOp.from_input(m5.results[0])
    m7 = ReturnOp.from_input(m5.results[0])
    return [
        m0, m1, m2, m3, m4, m5, m6, m7
    ]

multiply_transpose = FuncOp.from_callable('multiply_transpose', [TensorTypeF64, TensorTypeF64], [TensorTypeF64], func_body)
main = FuncOp.from_callable('main', [], [], main_body)

module = ModuleOp.from_region_or_ops([multiply_transpose, main])

printer.print(module)

"builtin.module"() ({
  "toy.func"() ({
  ^0(%0 : tensor<1xf64>, %1 : tensor<1xf64>):
    %2 = "toy.transpose"(%0) : (tensor<1xf64>) -> tensor<1xf64>
    %3 = "toy.transpose"(%1) : (tensor<1xf64>) -> tensor<1xf64>
    %4 = "toy.mul"(%2, %3) : (tensor<1xf64>, tensor<1xf64>) -> f64
    "toy.return"(%4) : (f64) -> ()
  }) {"sym_name" = #symbol_name<"multiply_transpose">, "function_type" = (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64>} : () -> ()
  "toy.func"() ({
    %5 = "toy.constant"() {"value" = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %6 = "toy.reshape"(%5) : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %7 = "toy.constant"() {"value" = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %8 = "toy.reshape"(%7) : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %9 = "toy.generic_call"(%6, %8) {"callee" = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
    %10 = "toy.generic_call"(%8,

In [3]:


const_op = ConstantOp.from_list([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
printer.print_op(const_op)

%11 = "toy.constant"() {"value" = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>


In [None]:
# %0 : !tensor<[2 : !index, 3 : !index], !f64> = toy.constant() ["value" = !dense<!tensor<[2 : !index, 3 : !index], !f64>, [1.0 : !f64, 2.0 : !f64, 3.0 : !f64, 4.0 : !f64, 5.0 : !f64, 6.0 : !f64]>]
# %0 = "toy.constant"() {"value" = dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]> : tensor<2x3x!f64>} : () -> tensor<2x3x!f64>


In [3]:
my_int = IntAttr.from_int(42)
printer.print_attribute(my_int)

!int<42>

In [4]:
from pathlib import Path

ast_toy = Path() / 'toy' / 'ast.toy'

In [5]:
with open(ast_toy, 'r') as f:
    print(f.read())

# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s

# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];
  # b is identical to a, the literal array is implicitly reshaped: defining new
  # variables is the way to reshape arrays (element count in literal must match
  # the size of specified shape).
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # This call will specialize `multiply_transpose` with <2, 3> for both
  # arguments and deduce a return type of <2, 2> in initialization of `c`.
  var c = multiply_transpose(a, b);
  # A second call to `multiply_transpose` with <2, 3> for both arguments will
  # reuse the previously specialized and inferred version and return `<2, 2>`
  var d = multiply_transpose(b, a);
  # A new call w

In [6]:
from toy.Parser import Parser

with open(ast_toy, 'r') as f:
    parser = Parser(ast_toy, f.read())

print(parser.parseModule().dump())

Module:
  Function 
    Proto 'multiply_transpose' toy/ast.toy:4:1'
    Params: [a, b]
    Block {
      Return
        BinOp: * toy/ast.toy:5:25
          Call 'transpose' [ toy/ast.toy:5:10
            var: a toy/ast.toy:5:20
          ]
          Call 'transpose' [ toy/ast.toy:5:25
            var: b toy/ast.toy:5:35
          ]
    } // Block
  Function 
    Proto 'main' toy/ast.toy:8:1'
    Params: []
    Block {
      VarDecl a<> toy/ast.toy:11:3
        Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] toy/ast.toy:11:11
      VarDecl b<2, 3> toy/ast.toy:15:3
        Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] toy/ast.toy:15:17
      VarDecl c<> toy/ast.toy:19:3
        Call 'multiply_transpose' [ toy/ast.toy:19:11
          var: a toy/ast.toy:19:30
          var: b toy/ast.toy:19:33
        ]
      VarDecl d<> toy/ast.toy:22:3
        Call 'multiply_transpose' [ t