In [1]:
import sys
from collections import defaultdict

In [2]:
class Peekable():
    def __init__(self, iterable):
        self.index = 0
        self.iterable = iterable
        self.length = len(iterable)

    def eof(self):
        return self.index >= self.length
    
    def peek(self):
        return self.iterable[self.index] if not self.eof() else None
        
    def advance(self):
        thing = self.peek()
        self.index += 1
        return thing
    
    def get_iterable(self):
        return self.iterable[self.index:] if not self.eof() else None

    def to_string(self):
        return str(self.get_iterable())

    def __repr__(self):
        return self.to_string()

In [3]:
class Token():
    def __init__(self, *args, **kwargs):
        self.type = kwargs.pop('type', None)
        self.val  = kwargs.pop('val',  None)
    
    def to_string(self):
        infoList = []
        if self.type:
            infoList.append(f'type = {self.type}')
        if self.val:
            infoList.append(f'val = {self.val}')
        infoString = ', '.join(infoList)
        return f'Token({infoString})'
    
    def __repr__(self):
        return self.to_string()

In [4]:
class Scanner():
    digits = ['0','1','2','3','4','5','6','7','8','9']
    letters = ['a','b','c','d','e','g','h','j','k','l','m','o','q','r','s','t','u','v','w','x','y','z']
    keywords = ['f','i','p']
    whitespace = [' ','\n','\t']

    def __init__(self, program):
        self.program = Peekable(program)

    def scan(self):
        tokens = []
        while not self.program.eof():
            tokens.append(self.scan_next_character())
        return tokens

    def scan_next_character(self):
        while self.program.peek() in self.whitespace:
            self.program.advance()
        
        if self.program.eof():
            return Token(type='$')
        
        if self.program.peek() in self.digits:
            return self.scan_digits()
        
        ch = self.program.advance()
        if ch in self.letters:
            return Token(type='id', val=ch)
        
        match ch:
            case 'f':
                return Token(type='floatdcl')
            case 'i':
                return Token(type='intdcl')
            case 'p':
                return Token(type='print')
            case '=':
                return Token(type='assign')
            case '+':
                return Token(type='plus')
            case '-':
                return Token(type='minus')
            case _:
                self.lexical_error()

    def scan_digits(self):
        tok = Token(val='')
        while self.program.peek() in self.digits:
            nextThing = self.program.advance()
            tok.val += nextThing
        
        if self.program.peek() != '.':
            tok.type = 'inum'
        else:
            tok.type = 'fnum'
            tok.val += self.program.advance()
            while self.program.peek() in self.digits:
                tok.val += self.program.advance()
        
        return tok
    
    def lexical_error(self):
        sys.exit('Lexical error!')

In [5]:
class Node():
    def __init__(self, type, val=None):
        self.type = type
        self.val = val
        self.children = []
        self.id = None
    
    def add_child(self, node):
        self.children.append(node)

    def add_parent(self, node):
        self.parent = node
        self.parent.add_child(self)
     
    def get_id(self):
        return self.id

    def get_type(self):
        return self.type
    
    def get_val(self):
        return self.val

    def get_child(self, index):
        children = self.get_children()
        return children[index] if index < self.num_children() else None

    def get_children(self):
        return self.children

    def is_leaf(self):
        return self.num_children() == 0
    
    def num_children(self):
        return len(self.children)

    def set_id(self, id):
        self.id = id

    def to_string(self):
        return f'Node({self.get_type()})'
    
    def __repr__(self):
        return self.to_string()

class Tree():
    def __init__(self):
        self.nodes = {}
        self.root = None

    def add_node(self, node, parent=None):
        id = self.num_nodes()

        if parent:
            parentId = parent.get_id()
            if parentId not in self.nodes:
                print('You provided a parent node that is not in the graph, not adding this node.')
                return
        
            node.add_parent(parent)
        
        node.set_id(id)
        self.nodes[id] = node

        # By convention, the first node added will be the root
        if id == 0:
            self.root = node

    def get_root(self):
        return self.root

    def num_nodes(self):
        return len(self.nodes)

    def render_tree(self, node=None, preamble=''):
        if not node:
            node = self.root

        string = preamble + node.get_type() 
        if node.get_val():
            string += f' ({node.get_val()})'
        string += '\n'

        children = node.get_children()
        for i, child in enumerate(children):
            nextPreamble = preamble
            nextPreamble = nextPreamble.replace('├','│').replace('─',' ').replace('└',' ')
            nextPreamble += '└── ' if i == len(children) - 1 else '├── '
            string += self.render_tree(child, nextPreamble)
        return string
    
    def __repr__(self):
        return self.render_tree()

In [6]:
class Parser():
    def __init__(self, tokens):
        # We don't need to token 'val' field to check syntax
        self.tokens = Peekable(tokens)
        self.verbose = True
        self.depth = 0
        self.callStack = []
        self.parseTree = Tree()
        self.ast = Tree()

    def parse(self):
        self.prog()
        self.convert_parse_tree_to_ast()
        print('Your program is syntactically valid!')

    def prog(self):
        self.append_to_call_stack('prog')
        
        node = Node('Prog')
        self.parseTree.add_node(node)

        if self.tokens.peek().type in ['floatdcl', 'intdcl', 'id', 'print', '$']:
            self.dcls(node)
            self.stmts(node)

            eosNode = Node('$')
            self.parseTree.add_node(eosNode, parent=node)
        else:
            self.syntax_error('prog')
    
    def dcls(self, parentNode):
        self.depth += 1
        self.append_to_call_stack('dcls')

        node = Node('Dcls')
        self.parseTree.add_node(node, parent=parentNode)

        if self.tokens.peek().type in ['floatdcl', 'intdcl']:
            self.dcl(node)
            self.dcls(node)
        elif self.tokens.peek().type in ['id', 'print', '$']:
            # Do nothing for lambda production, $ is next
            lambdaNode = Node('λ')
            self.parseTree.add_node(lambdaNode, parent=node)
        else:
            self.syntax_error('dcls')
        self.depth -= 1

    def dcl(self, parentNode):
        self.depth += 1
        self.append_to_call_stack('dcl')

        node = Node('Dcl')
        self.parseTree.add_node(node, parent=parentNode)

        if self.tokens.peek().type == 'floatdcl':
            self.match('floatdcl', node)
            self.match('id', node)
        elif self.tokens.peek().type == 'intdcl':
            self.match('intdcl', node)
            self.match('id', node)
        else:
            self.syntax_error('dcl')
        self.depth -= 1

    def stmts(self, parentNode):
        self.depth += 1
        self.append_to_call_stack('stmts')

        node = Node('Stmts')
        self.parseTree.add_node(node, parent=parentNode)

        if self.tokens.peek().type in ['id', 'print']:
            self.stmt(node)
            self.stmts(node)
        elif self.tokens.peek().type == '$':
            # Do nothing for lambda production, $ is next
            lambdaNode = Node('λ')
            self.parseTree.add_node(lambdaNode, parent=node)
        else:
            self.syntax_error('stmts')
        self.depth -= 1

    def stmt(self, parentNode):
        self.depth += 1
        self.append_to_call_stack('stmt')

        node = Node('Stmt')
        self.parseTree.add_node(node, parent=parentNode)

        if self.tokens.peek().type == 'id':
            self.match('id', node)
            self.match('assign', node)
            self.val(node)
            self.expr(node)
        elif self.tokens.peek().type == 'print':
            self.match('print', node)
            self.match('id', node)
        else:
            self.syntax_error('stmt')
        self.depth -= 1

    def expr(self, parentNode):
        self.depth += 1
        self.append_to_call_stack('expr')

        node = Node('Expr')
        self.parseTree.add_node(node, parent=parentNode)

        if self.tokens.peek().type == 'plus':
            self.match('plus', node)
            self.val(node)
            self.expr(node)
        elif self.tokens.peek().type == 'minus':
            self.match('minus', node)
            self.val(node)
            self.expr(node)
        elif self.tokens.peek().type in ['id', 'print', '$']:
            # Do nothing for lambda production, Stmt or $ is next
            lambdaNode = Node('λ')
            self.parseTree.add_node(lambdaNode, parent=node)
        else:
            self.syntax_error('expr')
        self.depth -= 1

    def val(self, parentNode):
        self.depth += 1
        self.append_to_call_stack('val')

        node = Node('Val')
        self.parseTree.add_node(node, parent=parentNode)

        if self.tokens.peek().type == 'id':
            self.match('id', node)
        elif self.tokens.peek().type == 'inum':
            self.match('inum', node)
        elif self.tokens.peek().type == 'fnum':
            self.match('fnum', node)
        else:
            self.syntax_error('val')
        self.depth -= 1

    def match(self, compare, parentNode):
        self.depth += 1
        self.append_to_call_stack('match')

        nextToken = self.tokens.peek()
        if nextToken.type == compare:
            node = Node(compare, val=nextToken.val)
            self.parseTree.add_node(node, parent=parentNode)

            self.depth -= 1
            return self.tokens.advance()
        else:
            self.syntax_error('match')

    def get_ast(self):
        return self.ast

    def get_parse_tree(self):
        return self.parseTree
    
    def convert_parse_tree_to_ast(self):       
        progNode = Node('Program')
        self.ast.add_node(progNode)    

        def recurse_on_expr(exprNode, prevTermNode, astParentNode):
            if exprNode.num_children() == 1 and exprNode.get_child(0).get_type() == 'λ':
                self.ast.add_node(prevTermNode, parent=astParentNode)
            else:
                pluOrMinusNode = exprNode.get_child(0)
                self.ast.add_node(pluOrMinusNode, parent=astParentNode)
                self.ast.add_node(prevTermNode, parent=pluOrMinusNode)

                valNode = exprNode.get_child(1).get_child(0)

                nextExprNode = exprNode.get_child(2)
                recurse_on_expr(nextExprNode, valNode, pluOrMinusNode)

        def recurse(parseNode, astParent):
            if parseNode.get_type() == 'Dcl':
                children = parseNode.get_children()
                astType = children[0].get_type() # floatdcl or intdcl
                astVal  = children[1].get_val()  # the id
                node = Node(astType, astVal)
                self.ast.add_node(node, parent=astParent)
                return
            
            if parseNode.get_type() == 'Stmt':
                children = parseNode.get_children()
                if children[0].get_type() == 'print':
                    astVal  = children[1].get_val() # the id
                    node = Node('print', astVal)
                    self.ast.add_node(node, parent=astParent)
                    return
                
                if children[1].get_type() == 'assign':
                    assignNode = Node('assign')
                    self.ast.add_node(assignNode, parent=astParent)

                    idNode = children[0]
                    self.ast.add_node(idNode, parent=assignNode)

                    valNode = children[2].get_child(0)
                    
                    exprNode = children[3]
                    recurse_on_expr(exprNode, valNode, assignNode)
                    
            for child in parseNode.get_children():
                recurse(child, astParent)
        
        recurse(self.parseTree.get_root(), progNode)

    def append_to_call_stack(self, caller):
        remainingTokens = ' '.join([tok.type for tok in self.tokens.get_iterable()])
        formatted = 2*self.depth*' ' + f'{caller}() --> {remainingTokens}'
        self.callStack.append(formatted)

    def syntax_error(self, caller):
        errorMsg = f'Syntax error in call to {caller}(), here is the stacktrace:\n\n'
        errorMsg += '\n'.join(self.callStack)
        sys.exit(errorMsg)

In [7]:
program = '''

f b
i a
a = 5
b = a + 3.2
p b

'''

In [8]:
tokens = Scanner(program).scan()

In [9]:
parser = Parser(tokens)
parser.parse()

Your program is syntactically valid!


In [10]:
parser.get_parse_tree()

Prog
├── Dcls
│   ├── Dcl
│   │   ├── floatdcl
│   │   └── id (b)
│   └── Dcls
│       ├── Dcl
│       │   ├── intdcl
│       │   └── id (a)
│       └── Dcls
│           └── λ
├── Stmts
│   ├── Stmt
│   │   ├── id (a)
│   │   ├── assign
│   │   ├── Val
│   │   │   └── inum (5)
│   │   └── Expr
│   │       └── λ
│   └── Stmts
│       ├── Stmt
│       │   ├── id (b)
│       │   ├── assign
│       │   ├── Val
│       │   │   └── id (a)
│       │   └── Expr
│       │       ├── plus
│       │       │   ├── id (a)
│       │       │   └── fnum (3.2)
│       │       ├── Val
│       │       │   └── fnum (3.2)
│       │       └── Expr
│       │           └── λ
│       └── Stmts
│           ├── Stmt
│           │   ├── print
│           │   └── id (b)
│           └── Stmts
│               └── λ
└── $

In [11]:
parser.get_ast()

Program
├── floatdcl (b)
├── intdcl (a)
├── assign
│   ├── id (a)
│   └── inum (5)
├── assign
│   ├── id (b)
│   └── plus
│       ├── id (a)
│       └── fnum (3.2)
└── print (b)

In [52]:
class ContextFreeGrammar:
    def __init__(self, startSymbol):
        self.startSymbol = startSymbol
        self.nonterminals = set()
        self.terminals = set()
        self.productions = set()

    ####################################
    # Low-level building block methods #
    ####################################

    def add_production(self, A, rhs):
        production = (A, rhs)
        self.productions.add(production)

        # The rest of this method automatically updates the terminals and
        # nonterminals of the grammar; still not sure if we should do this
        # here or require the user to call the methods below to register
        # these symbols... but for now this is way more convenient
        self.add_nonterminal(A)
        if A in self.get_terminals():
            self.terminals.remove(A)
            
        for y in rhs:
            if y not in self.get_nonterminals():
                self.add_terminal(y)
        
    def add_nonterminal(self, A):
        self.nonterminals.add(A)

    def add_terminal(self, x):
        self.terminals.add(x)

    def get_start_symbol(self):
        return self.startSymbol

    def get_productions(self):
        return self.productions

    def get_nonterminals(self):
        return self.nonterminals
    
    def get_terminals(self):
        return self.terminals
    
    def is_terminal(self, x):
        return x in self.get_terminals()
    
    def get_lhs(self, p):
        return p[0]
    def get_rhs(self, p):
        return p[1]

    def get_productions_for(self, A):
        return set(p for p in self.get_productions() if self.get_lhs(p) == A)

    def get_occurrences(self, X):
        occurrences = set()
        for p in self.get_productions():
            indicies = [i for i, e in enumerate(self.get_rhs(p)) if e == X]
            occurrences |= set((p, i) for i in indicies)
        return occurrences

    def get_production(self, occurrence):
        return occurrence[0]

    def get_tail(self, occurrence):
        production, index = occurrence
        rhs = self.get_rhs(production)
        return rhs[index+1:]

    ############################
    # Methods for CFG analysis #
    ############################

    def derives_empty_string(self):
        symbol_derives_empty = {A:False for A in self.get_nonterminals()}
        rule_derives_empty = {}
        count = {}
        work_list = set()

        def check_for_empty(p):
            if count[p] == 0:
                rule_derives_empty[p] = True
                A = self.get_lhs(p)
                if not symbol_derives_empty[A]:
                    symbol_derives_empty[A] = True
                    work_list.add(A)

        for p in self.get_productions():
            rule_derives_empty[p] = False
            count[p] = len(self.get_rhs(p))
            check_for_empty(p)

        while len(work_list) > 0:
            X = work_list.pop()
            for x in self.get_occurrences(X):
                p = self.get_production(x)
                count[p] -= 1
                check_for_empty(p)

        return symbol_derives_empty, rule_derives_empty

    def first(self, nonterminal):
        visited_first = {A:False for A in self.get_nonterminals()}

        def internal_first(XB):
            if len(XB) == 0:
                return set()
            
            first = XB[0]
            rest = XB[1:]

            if first in self.get_terminals():
                return set(first)
            
            # X is nonterminal
            result = set()
            if not visited_first[first]:
                visited_first[first] = True
                for p in self.get_productions_for(first):
                    rhs = self.get_rhs(p)
                    result |= internal_first(rhs)

            symbol_derives_empty, _ = self.derives_empty_string()
            if symbol_derives_empty[first]:
                result |= internal_first(rest)

            return result

        return internal_first((nonterminal,))

    def follow(self, nonterminal):
        visited_follow = {A:False for A in self.get_nonterminals()}

        def internal_follow(A):
            print(f'internal_follow for A = {A}')
            result = set()
            if not visited_follow[A]:
                for a in self.get_occurrences(A):
                    print(f'occurrence a = {a}, tail = {self.get_tail(a)}')
                    result |= self.first(self.get_tail(a))
                    if all_derives_empty(self.get_tail(a)):
                        targ = self.get_lhs(self.get_production(a))
                        ans |= internal_follow(targ)
            return result
        
        def all_derives_empty(γ):
            for X in γ:
                symbol_derives_empty, _ = self.derives_empty_string()
                if not symbol_derives_empty[X] or X in self.get_terminals():
                    return False
            return True

        return internal_follow(nonterminal)

    ################################################
    # Methods for printing / string representation #
    ################################################

    def longest_nonterminal(self):
        return max([len(self.get_lhs(p)) for p in self.get_productions()])

    def nonterminal_to_string(self, lhs, lhsMaxLen):
        string = ''       
        thisLhsLen = len(lhs)
        numSpaces = lhsMaxLen - thisLhsLen

        for i, (_, rhs) in enumerate(self.get_productions_for(lhs)):
            if len(rhs) == 0:
                rhs = ('λ',)

            if i == 0:
                string += lhs + numSpaces*' ' +  ' → ' + ' '.join(rhs) + '\n'
            else:
                string += (len(lhs) + numSpaces)*' ' +  ' | ' + ' '.join(rhs) + '\n'

        return string

    def to_string(self):
        length = self.longest_nonterminal()
        string = self.nonterminal_to_string(self.startSymbol, length)

        for lhs in self.get_nonterminals():
            if lhs != self.startSymbol:
                string += self.nonterminal_to_string(lhs, length)

        return string
    
    def __repr__(self):
        return self.to_string()

In [53]:
cfg = ContextFreeGrammar('E')

cfg.add_production('E',      ('Prefix','(','E',')'))
cfg.add_production('E',      ('v','Tail'))
cfg.add_production('Prefix', ('f',))
cfg.add_production('Prefix', ())
cfg.add_production('Tail',   ('+','E'))
cfg.add_production('Tail',   ())

In [54]:
cfg

E      → v Tail
       | Prefix ( E )
Prefix → f
       | λ
Tail   → + E
       | λ

In [55]:
cfg.follow('Prefix')

internal_follow for A = Prefix
occurrence a = (('E', ('Prefix', '(', 'E', ')')), 0), tail = ('Prefix', '(', 'E', ')')


KeyError: ('Prefix', '(', 'E', ')')

In [56]:
cfg.first(('Prefix', '(', 'E', ')'))

KeyError: ('Prefix', '(', 'E', ')')

In [25]:
cfg.first('Tail')

{'+'}

In [294]:
cfg = ContextFreeGrammar('S')

cfg.add_production('S', ('A','B','c'))
cfg.add_production('A', ('a',))
cfg.add_production('A', ())
cfg.add_production('B', ('b',))
cfg.add_production('B', ())

In [295]:
cfg.first('B')

{'b'}

In [283]:
cfg = ContextFreeGrammar('Program')
cfg.add_production('Program', ('begin','Stmts','end','$'))
cfg.add_production('Stmts',   ('Stmt',';','Stmts'))
cfg.add_production('Stmts',   ())
cfg.add_production('Stmt',    ('simplestmt',))
cfg.add_production('Stmt',    ('begin','Stmts','end'))

In [284]:
cfg

Program → begin Stmts end $
Stmts   → Stmt ; Stmts
        | λ
Stmt    → begin Stmts end
        | simplestmt

In [285]:
cfg = ContextFreeGrammar('A')
cfg.add_production('A', ('B','C','D'))
cfg.add_production('B', ())
cfg.add_production('C', ())
cfg.add_production('D', ())
cfg.add_production('D', ('Z'))

In [286]:
derives_empty_string(cfg)

({'C': True, 'B': True, 'D': True, 'A': True},
 {('D', ()): True,
  ('A', ('B', 'C', 'D')): True,
  ('B', ()): True,
  ('D', 'Z'): False,
  ('C', ()): True})