In [1]:
# =============================================================================
# Imports
# =============================================================================

import json
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Union

# xDSL IR and dialect primitives
from xdsl.ir import Block, Region
from xdsl.context import Context
from xdsl.dialects.builtin import ModuleOp, IntegerAttr, i32
from xdsl.dialects.func import FuncOp
from xdsl.dialects.arith import ConstantOp, SubiOp, AddiOp, MuliOp, DivSIOp
from xdsl.printer import Printer


In [2]:

# =============================================================================
# AST Dataclasses - Minimal Subset for Integer Operations
# =============================================================================

# --------------------
# Literal Integer Node
# --------------------
@dataclass
class IntegerLiteral:
    value: int  # Constant integer value

# --------------------
# Binary Operation Node
# --------------------
@dataclass
class BinaryOperator:
    opcode: str                  # e.g., 'add', 'sub', 'mul', 'div'
    lhs: 'Expression'            # Left operand (can be another BinaryOperator or literal)
    rhs: 'Expression'            # Right operand

# Define general expression type: either literal or binary operator
Expression = Union[IntegerLiteral, BinaryOperator]

# --------------------
# Variable Declaration
# --------------------
@dataclass
class VarDecl:
    name: str                                # Variable name
    init: Optional[Expression] = None        # Optional initializer (expression)

# --------------------
# Return Statement
# --------------------
@dataclass
class ReturnStmt:
    value: Optional[Expression] = None       # Optional return value

# --------------------
# Compound Statement Block
# --------------------
@dataclass
class CompoundStmt:
    stmts: List[Union[VarDecl, ReturnStmt]] = field(default_factory=list)
    # A function body consisting of a list of statements

# --------------------
# Function Declaration
# --------------------
@dataclass
class FunctionDecl:
    name: str             # Function name
    body: CompoundStmt    # Function body (compound block)

# --------------------
# Root Node - Translation Unit
# --------------------
@dataclass
class TranslationUnit:
    decls: List[FunctionDecl] = field(default_factory=list)
    # Top-level container for the entire program


In [3]:
# =============================================================================
# MLIR Code Generator (from Dataclass AST to xDSL IR)
# =============================================================================

class MLIRGenerator:
    def __init__(self):
        # Symbol table to track SSA values for declared variables
        self.symbol_table: Dict[str, Union[ConstantOp, SubiOp, AddiOp, MuliOp, DivSIOp]] = {}

        # Current block being populated (used to emit ops)
        self.current_block: Optional[Block] = None

        # SSA name counter (not used here directly, placeholder for later)
        self.ssa_counter = 0

    # --------------------
    # Emit constant value
    # --------------------
    def emit_constant(self, value: int) -> ConstantOp:
        # Create a constant operation with 32-bit width
        const_op = ConstantOp.from_int_and_width(value, 32)

        # Add it to the current block
        self.current_block.add_op(const_op)

        # Return the constant operation (for use in expressions)
        return const_op

    # --------------------
    # Recursive Expression Lowering
    # --------------------
    def process_expression(self, expr: Expression):
        # Case: Literal integer → emit constant
        if isinstance(expr, IntegerLiteral):
            return self.emit_constant(expr.value)

        # Case: Binary operator → recursively lower children and apply op
        elif isinstance(expr, BinaryOperator):
            lhs_op = self.process_expression(expr.lhs)
            rhs_op = self.process_expression(expr.rhs)

            # Match on opcode
            if expr.opcode == '+':
                op = AddiOp(lhs_op.results[0], rhs_op.results[0])
            elif expr.opcode == '-':
                op = SubiOp(lhs_op.results[0], rhs_op.results[0])
            elif expr.opcode == '*':
                op = MuliOp(lhs_op.results[0], rhs_op.results[0])
            elif expr.opcode == '/':
                op = DivSIOp(lhs_op.results[0], rhs_op.results[0])
            else:
                raise ValueError(f"Unsupported operator: {expr.opcode}")

            # Add to block and return result
            self.current_block.add_op(op)
            return op

        # Catch-all for unknown expressions
        raise TypeError(f"Unsupported expression type: {type(expr)}")

    # --------------------
    # Generate FuncOp from FunctionDecl
    # --------------------
    def generate_function(self, func: FunctionDecl) -> FuncOp:
        # Start new block for function body
        block = Block()
        self.current_block = block

        # Reset symbol table and SSA counter
        self.symbol_table.clear()
        self.ssa_counter = 0

        # Lower each statement in the compound body
        for stmt in func.body.stmts:
            # Variable declaration with initializer → compute and store SSA value
            if isinstance(stmt, VarDecl) and stmt.init:
                op = self.process_expression(stmt.init)
                self.symbol_table[stmt.name] = op

            # Return statement → compute value but no explicit return yet
            elif isinstance(stmt, ReturnStmt) and stmt.value:
                op = self.process_expression(stmt.value)
                # NOTE: no return op emitted — this would be handled downstream

        # Return a FuncOp with one i32 return (simplified assumption)
        func_type = [i32]
        return FuncOp(func.name, ([], func_type), Region([block]))


In [4]:
# =============================================================================
# JSON → AST Dataclass Parser (Minimal C-style main-only subset)
# =============================================================================

def parse_ast(ast_json: dict) -> TranslationUnit:
    # Create top-level container for all function declarations
    tu = TranslationUnit()

    # Loop through top-level declarations in the JSON
    for decl in ast_json.get("inner", []):
        # Only process the 'main' function (skip others for now)
        if decl["kind"] == "FunctionDecl" and decl["name"] == "main":
            compound_stmt = None

            # Search for the CompoundStmt inside 'main'
            for inner in decl.get("inner", []):
                if inner["kind"] == "CompoundStmt":
                    compound_stmt = CompoundStmt()

                    # Iterate over each statement in the body of 'main'
                    for stmt in inner.get("inner", []):
                        
                        # Case: Variable declaration statement (DeclStmt)
                        if stmt["kind"] == "DeclStmt":
                            for var_decl in stmt.get("inner", []):
                                if var_decl["kind"] == "VarDecl":
                                    init_expr = None

                                    # Look for initialization expression
                                    if "inner" in var_decl and var_decl["inner"]:
                                        expr_node = var_decl["inner"][0]

                                        # Case: Binary expression (e.g. 3 + 4)
                                        if expr_node["kind"] == "BinaryOperator":
                                            lhs = IntegerLiteral(int(expr_node["inner"][0]["value"]))
                                            rhs = IntegerLiteral(int(expr_node["inner"][1]["value"]))
                                            init_expr = BinaryOperator(
                                                expr_node["opcode"],
                                                lhs,
                                                rhs
                                            )

                                        # Case: Single integer initializer (e.g. int x = 3)
                                        elif expr_node["kind"] == "IntegerLiteral":
                                            init_expr = IntegerLiteral(int(expr_node["value"]))

                                    # Append parsed VarDecl to compound body
                                    compound_stmt.stmts.append(
                                        VarDecl(var_decl["name"], init_expr)
                                    )

                        # Case: Return statement (e.g. return 0;)
                        elif stmt["kind"] == "ReturnStmt":
                            return_expr = None
                            if "inner" in stmt and stmt["inner"]:
                                expr_node = stmt["inner"][0]
                                if expr_node["kind"] == "IntegerLiteral":
                                    return_expr = IntegerLiteral(int(expr_node["value"]))
                            
                            # Append ReturnStmt to compound body
                            compound_stmt.stmts.append(ReturnStmt(return_expr))

            # Once entire compound body parsed, register the function
            if compound_stmt:
                tu.decls.append(FunctionDecl(decl["name"], compound_stmt))

    # Return fully constructed TranslationUnit
    return tu


In [5]:

# === Main Execution (Notebook-compatible) ===
json_path = "json_out/try.json"

with open(json_path) as f:
    ast_json = json.load(f)

translation_unit = parse_ast(ast_json)

if not translation_unit.decls:
    print("No functions found in AST")
else:
    mlir_generator = MLIRGenerator()
    ctx = Context()
    module = ModuleOp([])

    for func in translation_unit.decls:
        func_op = mlir_generator.generate_function(func)
        module.regions[0].blocks[0].add_op(func_op)

    printer = Printer()
    printer.print_op(module)


KeyError: 'value'