In [None]:
import re
from typing import Dict, List, Set
from typing import OrderedDict as OrderedDictType
from collections import OrderedDict
from dataclasses import dataclass

In [None]:
@dataclass
class Production:
    """
    A single production rule of the form ``left -> right[0] right[1] ...``.

    Instances are hashable: the hash is derived from the textual form of the
    rule, so the object must not be mutated while it is stored in a set or
    used as a dict key.
    """
    left: str  # symbol on the left-hand side of the arrow
    right: List[str]  # symbols on the right-hand side, in order

    def __str__(self):
        # Right-hand symbols are separated by single spaces.
        rhs = " ".join(self.right)
        return f"{self.left} -> {rhs}"

    def __repr__(self):
        # The debug representation is the same text as the printable one.
        return self.__str__()

    def __hash__(self):
        # Hash the printable form of the rule.
        return hash(self.__str__())

In [None]:
class Grammar:
    """
    A grammar described by its productions (P) and start symbol (S).

    The non-terminal set (V_N) and terminal set (V_T) are derived from the
    productions when the object is constructed.
    """

    def __init__(self, productions: Set[Production], start_symbol: str):
        self.productions: Set[Production] = productions  # P
        self.terminals: Set[str] = set()  # V_T
        self.non_terminals: Set[str] = set()  # V_N
        self.start_symbol: str = start_symbol  # S

        # Non-terminals must be known first: terminals are defined relative to them.
        self._compute_non_terminals()
        self._compute_terminals()

    def _compute_non_terminals(self):
        """
        Collect the non-terminals: every symbol that appears on the left-hand
        side of some production.
        """
        self.non_terminals.update(rule.left for rule in self.productions)

    def _compute_terminals(self):
        """
        Collect the terminals: every right-hand-side symbol that is not a
        non-terminal.
        """
        if not self.non_terminals:
            # Nothing collected yet; derive the non-terminals before classifying.
            self._compute_non_terminals()
        self.terminals.update(
            symbol
            for rule in self.productions
            for symbol in rule.right
            if symbol not in self.non_terminals
        )

    def __str__(self):
        # The text starts with a newline, then one labelled line per component,
        # followed by one production per line.
        header = (
            f"\nStart Symbol: {self.start_symbol}"
            f"\nTerminals: {self.terminals}"
            f"\nNon-terminals: {self.non_terminals}"
            f"\nProductions:\n"
        )
        return header + "\n".join(map(str, self.productions))

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        # Hash of the printable form.
        return hash(self.__str__())

In [None]:
# Module-level analysis tables; all three start out empty.
FIRST: Dict[str, Set[str]] = dict()  # symbol -> FIRST set
FOLLOW: Dict[str, Set[str]] = dict()  # symbol -> FOLLOW set
SELECT: Dict[Production, Set[str]] = dict()  # production -> SELECT set

In [None]:
# One rule per line: a word-character left-hand side, the literal " -> "
# separator, then everything up to the end of the line as the right-hand side.
production_regex = re.compile(
    r'(?P<left>\w+) -> (?P<right>.+)'
)

In [None]:
# Read grammar.txt into `productions`: one "LEFT -> SYM SYM ..." rule per
# non-blank line.
with open("grammar.txt", "r") as f:
    productions: Set[Production] = set()
    for line in f:  # iterate lazily instead of materialising every line first
        line = line.strip()  # drop surrounding whitespace and the newline
        if line == "":
            continue  # blank lines are allowed between rules
        match = production_regex.match(line)
        if match is None:
            raise ValueError(f"Invalid production: {line}")
        left = match.group("left")
        # split() without an argument collapses runs of whitespace, so doubled
        # spaces or tabs between symbols no longer create empty-string symbols.
        right = match.group("right").split()
        productions.add(Production(left, right))

In [None]:
# Show the parsed rules, one per line.
print(*productions, sep="\n")

In [None]:
# Build the grammar from the parsed rules, with "program" as the start symbol.
grammar = Grammar(
    productions=productions,
    start_symbol="program",
)

In [None]:
# Print the start symbol, terminal/non-terminal sets and productions of `grammar`.
print(grammar)

In [None]:
def eliminate_left_recursion(grammar: Grammar) -> Grammar:
    """
    Eliminate left recursion from a grammar.

    Non-terminals are processed in sorted order P_1 ... P_n. For each P_i:
        1. Substitution: every production ``P_i -> P_j gamma`` with j < i is
           replaced by ``P_i -> delta gamma`` for each ``P_j -> delta``.
        2. Direct left recursion ``P_i -> P_i gamma | beta`` is replaced by
           ``P_i -> beta P_i'`` and ``P_i' -> gamma P_i' | $``.

    ``$`` is the marker this function uses for the empty string (epsilon). A
    right-hand side that is exactly ``["$"]`` is treated as empty when it is
    concatenated with other symbols, so no ``$`` ends up in the middle of a rule.

    Every production's right-hand side is assumed to be non-empty (its first
    symbol is inspected).

    NOTE: as with the textbook algorithm, a recursion-free result is only
    expected for grammars without epsilon productions and without cycles.

    Args:
        grammar: the grammar to transform. It is not modified.

    Returns:
        A new Grammar with the same start symbol whose productions are the
        transformed rules.
    """
    epsilon = "$"

    # A set of strings has no stable iteration order across interpreter runs,
    # so fix the processing order to make the result reproducible.
    non_terminals: List[str] = sorted(grammar.non_terminals)

    # Every symbol already in use; helper non-terminals must not reuse one.
    taken_names: Set[str] = set(grammar.non_terminals) | set(grammar.terminals)

    def fresh_name(base: str) -> str:
        """Return base + "'" (with more primes if needed) that is not used yet."""
        candidate = base + "'"
        while candidate in taken_names:
            candidate += "'"
        taken_names.add(candidate)
        return candidate

    def concat(head: List[str], tail: List[str]) -> List[str]:
        """Concatenate two symbol strings, treating a lone epsilon as empty."""
        if head == [epsilon]:
            head = []
        if tail == [epsilon]:
            tail = []
        return (head + tail) or [epsilon]

    def unique(items: List[Production]) -> List[Production]:
        """Drop repeated productions while keeping first-seen order."""
        return list(OrderedDict.fromkeys(items))

    # Group rules by left-hand side: [A -> B, A -> C, B -> ...] => A: [A -> B, A -> C], B: [...]
    productions_by_key: OrderedDictType[str, List[Production]] = OrderedDict()
    for production in sorted(grammar.productions, key=str):
        productions_by_key.setdefault(production.left, []).append(production)

    for i, P_i in enumerate(non_terminals):
        current: List[Production] = productions_by_key[P_i]

        # Step 1: remove indirect recursion through earlier non-terminals.
        for P_j in non_terminals[:i]:
            productions_p_j = productions_by_key[P_j]  # already processed
            substituted: List[Production] = []
            for production_p_i in current:
                if production_p_i.right[0] != P_j:
                    # P_i -> beta: does not start with P_j, keep unchanged
                    substituted.append(production_p_i)
                    continue
                # P_i -> P_j gamma  =>  P_i -> delta gamma for each P_j -> delta
                gamma = production_p_i.right[1:]
                for production_p_j in productions_p_j:
                    substituted.append(Production(
                        left=P_i,
                        right=concat(production_p_j.right, gamma)
                    ))
            # Each rule is handled exactly once per pass; only genuine repeats
            # produced by the substitution are removed here.
            current = unique(substituted)

        # Step 2: split into P_i -> P_i gamma (recursive) and P_i -> beta.
        non_recursive: List[Production] = []
        gammas: List[List[str]] = []
        for production_p_i in current:
            if production_p_i.right[0] != P_i:
                non_recursive.append(production_p_i)
                continue
            gamma = production_p_i.right[1:]
            if gamma and gamma != [epsilon]:
                gammas.append(gamma)
            # else: P_i -> P_i derives nothing new; keeping it would yield the
            # left-recursive rule P_i' -> P_i', so it is dropped.

        if not gammas:
            # No direct left recursion: the rules are final as they are.
            productions_by_key[P_i] = non_recursive
            continue

        P_i_prime = fresh_name(P_i)

        # P_i -> beta P_i'
        productions_by_key[P_i] = unique([
            Production(left=P_i, right=concat(p.right, [P_i_prime]))
            for p in non_recursive
        ])

        # P_i' -> gamma P_i' | epsilon
        productions_p_i_prime = [
            Production(left=P_i_prime, right=concat(gamma, [P_i_prime]))
            for gamma in gammas
        ]
        productions_p_i_prime.append(Production(left=P_i_prime, right=[epsilon]))
        productions_by_key[P_i_prime] = unique(productions_p_i_prime)

    new_productions: List[Production] = [
        production
        for group in productions_by_key.values()
        for production in group
    ]

    return Grammar(
        productions=set(new_productions),
        start_symbol=grammar.start_symbol
    )

In [None]:
# Read grammar_recursion.txt into `productions` (one "LEFT -> SYM SYM ..." rule
# per non-blank line), then wrap the rules in a grammar whose start symbol is "R".
with open("grammar_recursion.txt", "r") as f:
    productions: Set[Production] = set()
    for line in f:  # iterate lazily instead of materialising every line first
        line = line.strip()  # drop surrounding whitespace and the newline
        if line == "":
            continue  # blank lines are allowed between rules
        match = production_regex.match(line)
        if match is None:
            raise ValueError(f"Invalid production: {line}")
        left = match.group("left")
        # split() without an argument collapses runs of whitespace, so doubled
        # spaces or tabs between symbols no longer create empty-string symbols.
        right = match.group("right").split()
        productions.add(Production(left, right))

old_grammar = Grammar(
    productions=productions,
    start_symbol="R"
)

In [None]:
# Display the grammar as loaded, before left-recursion elimination.
old_grammar

In [None]:
# Transform the loaded grammar; `old_grammar` itself is kept for comparison.
new_grammar = eliminate_left_recursion(old_grammar)

In [None]:
# Display the transformed rules ordered by their left-hand side.
sorted(new_grammar.productions, key=lambda production: production.left)