In [2]:
# To import local version of xdsl
import sys
sys.path.insert(0, '../src')

In [3]:
# Sample Kaleidoscope program used as input for the lexer and parser cells below.
# NOTE: it uses if/then/else, which the parser in this notebook does not handle
# yet (see the output of the top-level parsing cell).
fib = '''
# Compute the x'th fibonacci number.
def fib(x)
  if x < 3 then
    1
  else
    fib(x-1)+fib(x-2)

# This expression will compute the 40th number.
fib(40)
'''

## Lexer

In [50]:
from enum import Enum
from dataclasses import dataclass

from typing import Iterable, Iterator


@dataclass
class Token:
    '''Base class for every token produced by the lexer.'''


class EOFToken(Token):
    '''Marks the end of the input; the lexer always yields exactly one, last.'''


class DefinitionToken(Token):
    '''The `def` keyword.'''


class ExternToken(Token):
    '''The `extern` keyword.'''


@dataclass
class IdentifierToken(Token):
    '''Any other word that starts with a letter (so `if`/`then`/`else` land here too).'''
    identifier: str


@dataclass
class NumberToken(Token):
    '''A numeric literal made of digits and dots, stored as a float.'''
    value: float


@dataclass
class CharToken(Token):
    '''Any other single non-whitespace character, e.g. operators and parentheses.'''
    char: str


def lexer(program: Iterable[str]) -> Iterator[Token]:
    '''
    Generates a stream of tokens given an input program in Kaleidoscope.

    `program` is any iterable of single characters (typically a str).
    Whitespace and `#` line comments are skipped. The stream always ends with
    exactly one EOFToken, and a word or number that runs right up to the end
    of the input is still emitted before it.

    Raises ValueError if a run of digits and dots is not a valid float
    (e.g. "1.2.3").
    '''
    it = iter(program)

    def advance():
        # None signals end of input, so EOF is an ordinary loop condition
        # instead of an exception that would skip the pending token.
        return next(it, None)

    char = advance()
    while char is not None:
        if char.isspace():
            # Ignore whitespace
            char = advance()
        elif char == '#':
            # Ignore characters until end of line (or end of input)
            while char is not None and char not in ('\n', '\r'):
                char = advance()
        elif char.isalpha():
            # A letter followed by letters/digits: keyword or identifier
            word = char
            char = advance()
            while char is not None and char.isalnum():
                word += char
                char = advance()
            if word == 'def':
                yield DefinitionToken()
            elif word == 'extern':
                yield ExternToken()
            else:
                yield IdentifierToken(identifier=word)
        elif char.isdigit() or char == '.':
            # A run of digits and dots, converted with float()
            digits = char
            char = advance()
            while char is not None and (char.isdigit() or char == '.'):
                digits += char
                char = advance()
            yield NumberToken(value=float(digits))
        else:
            # Anything else becomes a single-character token
            yield CharToken(char)
            char = advance()

    yield EOFToken()

# Tokenize the sample program; the cell displays the resulting token list.
[token for token in lexer(fib)]

[DefinitionToken(),
 IdentifierToken(identifier='fib'),
 CharToken(char='('),
 IdentifierToken(identifier='x'),
 CharToken(char=')'),
 IdentifierToken(identifier='if'),
 IdentifierToken(identifier='x'),
 CharToken(char='<'),
 NumberToken(value=3.0),
 IdentifierToken(identifier='then'),
 NumberToken(value=1.0),
 IdentifierToken(identifier='else'),
 IdentifierToken(identifier='fib'),
 CharToken(char='('),
 IdentifierToken(identifier='x'),
 CharToken(char='-'),
 NumberToken(value=1.0),
 CharToken(char=')'),
 CharToken(char='+'),
 IdentifierToken(identifier='fib'),
 CharToken(char='('),
 IdentifierToken(identifier='x'),
 CharToken(char='-'),
 NumberToken(value=2.0),
 CharToken(char=')'),
 IdentifierToken(identifier='fib'),
 CharToken(char='('),
 NumberToken(value=40.0),
 CharToken(char=')'),
 EOFToken()]

## Abstract Syntax Tree

AST Nodes

In [116]:
from typing import List

@dataclass
class Expr:
    """Common base class of every AST node produced by the parser."""


@dataclass
class NumberExpr(Expr):
    """Numeric literal such as ``1.0``."""

    value: float


@dataclass
class VariableExpr(Expr):
    """Reference to a named variable such as ``a``."""

    name: str


@dataclass
class BinaryExpr(Expr):
    """Binary operation ``lhs op rhs``; ``op`` is the operator character."""

    op: str
    lhs: Expr
    rhs: Expr


@dataclass
class CallExpr(Expr):
    """Call of the function named ``callee`` with argument expressions ``args``."""

    callee: str
    args: List[Expr]


@dataclass
class Prototype(Expr):
    """
    Function signature: the function ``name`` plus its argument names
    ``args``, which implicitly fixes how many arguments the function takes.
    Produced both for ``def`` headers and for ``extern`` declarations.
    """

    name: str
    args: List[str]


@dataclass
class Function(Expr):
    """Function definition: a ``proto`` signature together with its ``body``."""

    proto: Prototype
    body: Expr

In [139]:
from typing import Tuple

class ParseError(Exception):
    """Base class for the parser's error types."""


class UnexpectedTokenError(ParseError):
    """A token did not match what the grammar required at this point."""

    def __init__(self, token, expected):
        message = f'Unexpected token: {token}, expected: {expected}'
        super().__init__(message)
        self.token, self.expected, self.message = token, expected, message


class UnknownBinopError(ParseError):
    """A character was used as a binary operator but has no known precedence."""
    # NOTE(review): the Parser class in this notebook never raises this yet.

    def __init__(self, char):
        message = f'Unknown binop: {char}'
        super().__init__(message)
        self.char, self.message = char, message


class UnexpectedEOFError(ParseError):
    """The input ended where more tokens were required."""

    def __init__(self):
        message = 'Unexpected end of file.'
        super().__init__(message)
        self.message = message


class Parser:
    '''
    Recursive-descent parser that turns a token stream into AST nodes, using
    operator-precedence climbing for binary expressions.

    Every parseXxx method takes the current token `head` and an iterator
    `tail` over the remaining tokens, and returns a pair (node, head) where
    the new `head` is the first token the production did not consume.
    '''

    def __init__(self, binopPrecedence):
        '''
        binopPrecedence maps a binary-operator character to its precedence;
        higher numbers bind more tightly. Characters missing from the map are
        not treated as binary operators. Precedences below 0 are never
        consumed, because parseExpression starts climbing at 0.
        '''
        self.binopPrecedence = binopPrecedence

    @staticmethod
    def _advance(tail) -> Token:
        '''
        Returns the next token from `tail`. Raises UnexpectedEOFError when the
        stream is exhausted instead of leaking StopIteration, which Python
        would turn into a RuntimeError inside the `parse` generator.
        '''
        try:
            return next(tail)
        except StopIteration:
            raise UnexpectedEOFError() from None

    def parse(self, tokens):
        '''
        Parses a stream of tokens into an abstract syntax tree.

        Yields one node per top-level item: a Function for each `def`, a
        Prototype for each `extern`, and the bare expression for anything
        else (parseTopLevelExpr can wrap one in an anonymous Function).
        Top-level ';' tokens are skipped. The stream must end with an
        EOFToken, as the lexer's output does; otherwise UnexpectedEOFError.
        '''
        tail = iter(tokens)
        head = self._advance(tail)
        while not isinstance(head, EOFToken):
            if isinstance(head, CharToken) and head.char == ';':
                # Ignore top-level semicolons
                head = self._advance(tail)
            elif isinstance(head, DefinitionToken):
                expr, head = self.parseDefinition(head, tail)
                yield expr
            elif isinstance(head, ExternToken):
                expr, head = self.parseExtern(head, tail)
                yield expr
            else:
                expr, head = self.parseExpression(head, tail)
                yield expr

    def check(self, token, expected, **kwargs):
        '''
        Raises UnexpectedTokenError unless `token` is an instance of
        `expected` and, for every keyword argument, getattr(token, key) == value.
        '''
        if not isinstance(token, expected):
            raise UnexpectedTokenError(token=token, expected=expected)

        for key, value in kwargs.items():
            if getattr(token, key) != value:
                raise UnexpectedTokenError(token=token, expected=f'{key}={value}')

    def getTokenPrecedence(self, token) -> int:
        '''Precedence of `token` as a binary operator, or -1 if it is not one.'''
        if isinstance(token, CharToken):
            return self.binopPrecedence.get(token.char, -1)
        return -1

    def parseNumberExpr(self, head, tail) -> Tuple[NumberExpr, Token]:
        '''
        numberexpr ::= number
        '''
        self.check(head, NumberToken)
        return NumberExpr(value=head.value), self._advance(tail)

    def eatChar(self, char, head, tail) -> Token:
        '''Checks that `head` is the CharToken `char` and returns the token after it.'''
        self.check(head, CharToken, char=char)
        return self._advance(tail)

    def parseParenExpr(self, head, tail) -> Tuple[Expr, Token]:
        '''
        parenexpr ::= '(' expression ')'
        '''
        head = self.eatChar('(', head, tail)
        expr, head = self.parseExpression(head, tail)
        head = self.eatChar(')', head, tail)
        return expr, head

    def parseIdentifierExpr(self, head, tail) -> Tuple[Expr, Token]:
        '''
        identifierexpr
            ::= identifier
            ::= identifier '(' expression* ')'

        A bare identifier yields a VariableExpr; one followed by '(' yields a
        CallExpr whose arguments are separated by ',' (a trailing ',' before
        the closing ')' is accepted).
        '''
        self.check(head, IdentifierToken)
        identifier = head.identifier
        head = self._advance(tail)

        if head != CharToken('('):
            return VariableExpr(name=identifier), head

        # Call: parse comma-separated arguments up to the closing ')'
        head = self._advance(tail)
        args = []
        while head != CharToken(')'):
            arg, head = self.parseExpression(head, tail)
            args.append(arg)

            if head == CharToken(','):
                head = self._advance(tail)
            elif head != CharToken(')'):
                raise UnexpectedTokenError(head, '"," or ")"')
        head = self.eatChar(')', head, tail)
        return CallExpr(callee=identifier, args=args), head

    def parsePrimary(self, head, tail) -> Tuple[Expr, Token]:
        '''
        primary
           ::= identifierexpr
           ::= numberexpr
           ::= parenexpr

        Raises UnexpectedEOFError at end of input and ParseError for any other
        token that cannot start a primary expression.
        '''
        if isinstance(head, IdentifierToken):
            return self.parseIdentifierExpr(head, tail)
        if isinstance(head, NumberToken):
            return self.parseNumberExpr(head, tail)
        if isinstance(head, CharToken) and head.char == '(':
            return self.parseParenExpr(head, tail)
        if isinstance(head, EOFToken):
            raise UnexpectedEOFError()
        raise ParseError(f'Cannot parse primary expression starting with {head}')

    def parseBinOpRHS(self, lhsPrec: int, lhs: Expr, head: Token, tail) -> Tuple[Expr, Token]:
        '''
        binoprhs
            ::= (binop primary)*

        Precedence climbing: keeps folding `lhs binop primary` into `lhs` while
        the next operator's precedence is at least `lhsPrec`. Operators of
        equal precedence associate to the left.
        '''
        while True:
            tokPrec = self.getTokenPrecedence(head)

            # If this is a binop that binds at least as tightly as the current
            # binop, consume it, otherwise we are done.
            if tokPrec < lhsPrec:
                return lhs, head

            # tokPrec >= lhsPrec >= 0 implies head is a CharToken operator.
            binOp = head.char
            head = self._advance(tail)

            rhs, head = self.parsePrimary(head, tail)

            # If binOp binds less tightly with rhs than the operator after rhs,
            # let the pending operator take rhs as its lhs.
            nextPrec = self.getTokenPrecedence(head)
            if tokPrec < nextPrec:
                rhs, head = self.parseBinOpRHS(tokPrec + 1, rhs, head, tail)

            lhs = BinaryExpr(op=binOp, lhs=lhs, rhs=rhs)

    def parseExpression(self, head: Token, tail) -> Tuple[Expr, Token]:
        '''
        expression
            ::= primary binoprhs
        '''
        lhs, head = self.parsePrimary(head, tail)
        return self.parseBinOpRHS(0, lhs, head, tail)

    def parsePrototype(self, head, tail) -> Tuple[Expr, Token]:
        '''
        prototype
            ::= id '(' id* ')'

        Argument names are separated by whitespace only (no commas).
        Returns a Prototype node.
        '''
        self.check(head, IdentifierToken)
        functionName = head.identifier

        head = self.eatChar('(', self._advance(tail), tail)

        argNames = []
        while isinstance(head, IdentifierToken):
            argNames.append(head.identifier)
            head = self._advance(tail)

        head = self.eatChar(')', head, tail)
        return Prototype(functionName, argNames), head

    def parseDefinition(self, head, tail) -> Tuple[Expr, Token]:
        '''
        definition
            ::= 'def' prototype expression
        '''
        self.check(head, DefinitionToken)
        head = self._advance(tail)
        proto, head = self.parsePrototype(head, tail)
        body, head = self.parseExpression(head, tail)
        return Function(proto, body), head

    def parseTopLevelExpr(self, head, tail) -> Tuple[Expr, Token]:
        '''
        toplevelexpr ::= expression

        Wraps the expression in a Function with an anonymous, argument-less
        prototype named '__anon_expr'. (parse() does not call this.)
        '''
        e, head = self.parseExpression(head, tail)
        # Make an anonymous proto
        proto = Prototype('__anon_expr', [])
        return Function(proto, e), head

    def parseExtern(self, head, tail) -> Tuple[Expr, Token]:
        '''
        external ::= 'extern' prototype
        '''
        self.check(head, ExternToken)
        head = self._advance(tail)
        return self.parsePrototype(head, tail)


# Binary-operator precedence table: a higher number binds more tightly, so
# '*' binds before '+'/'-', which bind before '<'.
binopPrecedence = dict([
    ('<', 10),
    ('+', 20),
    ('-', 20),
    ('*', 40),
])

parser = Parser(binopPrecedence)

# Quick sanity check of the precedence lookup; prints 10.
print(parser.getTokenPrecedence(CharToken('<')))

def test_parse(fun, tokens, output):
    '''
    Checks a single parse method in isolation.

    Feeds `tokens` followed by an EOFToken to `fun` (a bound Parser.parseXxx
    method), then asserts that it returns `output` and that the lookahead it
    hands back is the EOFToken, i.e. every input token was consumed.
    `tokens` may be any iterable of tokens, not only a list.
    '''
    stream = iter([*tokens, EOFToken()])
    result, head = fun(next(stream), stream)
    assert result == output, f'{result} != {output}'
    assert head == EOFToken(), f'unconsumed input starting at {head}'

# One row per check: (parse method, input tokens without the trailing EOF,
# expected AST). Rows are run in order through test_parse.
for parseMethod, tokens, expected in [
    # numberexpr / parenexpr
    (parser.parseNumberExpr, [NumberToken(value=1.0)], NumberExpr(1.0)),
    (parser.parseParenExpr,
     [CharToken('('), NumberToken(value=1.0), CharToken(')')],
     NumberExpr(value=1.0)),
    # identifierexpr: plain variable and call
    (parser.parseIdentifierExpr,
     [IdentifierToken(identifier='foo')],
     VariableExpr(name='foo')),
    (parser.parseIdentifierExpr,
     [IdentifierToken(identifier='foo'), CharToken('('), NumberToken(value=1.0), CharToken(')')],
     CallExpr(callee='foo', args=[NumberExpr(value=1.0)])),
    # primary dispatches on the first token
    (parser.parsePrimary, [IdentifierToken(identifier='foo')], VariableExpr(name='foo')),
    (parser.parsePrimary, [NumberToken(value=1.0)], NumberExpr(1.0)),
    (parser.parsePrimary,
     [IdentifierToken(identifier='foo'), CharToken('('), NumberToken(value=1.0), CharToken(')')],
     CallExpr(callee='foo', args=[NumberExpr(value=1.0)])),
    # expression: a + b, left-associative a + b + c, precedence in a + b * c
    (parser.parseExpression,
     [IdentifierToken(identifier='a'), CharToken('+'), IdentifierToken(identifier='b')],
     BinaryExpr('+', VariableExpr('a'), VariableExpr('b'))),
    (parser.parseExpression,
     [IdentifierToken(identifier='a'), CharToken('+'), IdentifierToken(identifier='b'),
      CharToken(char='+'), IdentifierToken(identifier='c')],
     BinaryExpr('+', BinaryExpr('+', VariableExpr('a'), VariableExpr('b')), VariableExpr('c'))),
    (parser.parseExpression,
     [IdentifierToken(identifier='a'), CharToken('+'), IdentifierToken(identifier='b'),
      CharToken(char='*'), IdentifierToken(identifier='c')],
     BinaryExpr('+', VariableExpr('a'), BinaryExpr('*', VariableExpr('b'), VariableExpr('c')))),
    # prototype and definition
    (parser.parsePrototype,
     [IdentifierToken('foo'), CharToken('('), IdentifierToken('arg0'),
      IdentifierToken('arg1'), CharToken(')')],
     Prototype('foo', ['arg0', 'arg1'])),
    (parser.parseDefinition,
     [DefinitionToken(), IdentifierToken('id'), CharToken('('),
      IdentifierToken('x'), CharToken(')'), IdentifierToken('x')],
     Function(Prototype('id', ['x']), VariableExpr('x'))),
]:
    test_parse(parseMethod, tokens, expected)

10


## Top-Level Parsing

In [140]:
# Parse the whole sample program. NOTE: the parser has no if/then/else support
# yet, so `if`, `then` and `else` come back as plain variables and the `def`
# body ends right after `if` (see the output below).
[node for node in parser.parse(lexer(fib))]

[Function(proto=Prototype(name='fib', args=['x']), body=VariableExpr(name='if')),
 BinaryExpr(op='<', lhs=VariableExpr(name='x'), rhs=NumberExpr(value=3.0)),
 VariableExpr(name='then'),
 NumberExpr(value=1.0),
 VariableExpr(name='else'),
 BinaryExpr(op='+', lhs=CallExpr(callee='fib', args=[BinaryExpr(op='-', lhs=VariableExpr(name='x'), rhs=NumberExpr(value=1.0))]), rhs=CallExpr(callee='fib', args=[BinaryExpr(op='-', lhs=VariableExpr(name='x'), rhs=NumberExpr(value=2.0))])),
 CallExpr(callee='fib', args=[NumberExpr(value=40.0)])]

## xDSL

In [3]:
from xdsl import *
from xdsl.ir import *
from xdsl.irdl import *
from xdsl.dialects.func import *
from xdsl.dialects.arith import *
from xdsl.dialects.builtin import *
from xdsl.parser import *
from xdsl.printer import *
from xdsl.util import *

# MLContext, containing information about the registered dialects
context = MLContext()

# Some useful dialects.
# NOTE(review): the objects bound below are not used in the cells shown here;
# presumably constructing each dialect with `context` registers it there, so
# keep these calls even if the names look unused -- confirm against the xDSL
# version on sys.path ('../src').
arith = Arith(context)
func = Func(context)
builtin = Builtin(context)

# Printer used to pretty-print MLIR data structures
printer = Printer()

In [5]:
import arpeggio