In [1]:
from __future__ import annotations

In [2]:
from typing import Tuple, List, Dict

## Following [this](https://build-your-own.org/b2a/r1_parse) tutorial

# R1: Parsing the regex into an AST

In [3]:
# Upper bound on repetition counts: parse_postfix rejects larger minimums and
# match_backtrack_repeat clamps the maximum (including "unbounded") to it.
RE_REPEAT_LIMIT = 1000

class Node:
    """A node in the regex AST.

    NOTE(review): in the code visible here the parser does not build Node
    instances; the AST is made of tuples such as ``('cat', left, right)``,
    single characters, ``'dot'`` and ``None``. Node is only used in type hints.
    """
    def __init__(self, kind: str, *args):
        self.kind = kind  # node type tag (the tuple AST uses 'cat', 'split', 'repeat')
        self.args = args  # child nodes / parameters, in positional order

    def __repr__(self):
        # Join the kind and the args in a single pass so a node without args
        # renders as "Node(kind)" instead of "Node(kind, )".
        parts = [str(self.kind), *(repr(a) for a in self.args)]
        return f'Node({", ".join(parts)})'

def parse_split(r: str, idx: int) -> Tuple[int, Node]:
    """
    Parse one or more concatenations separated by ``|`` into split nodes.

    This function parses the "split" (alternation) construct of regular
    expressions, written with the ``|`` character. Each alternative is parsed
    by parse_concat, and the alternatives are combined into a left-nested
    binary tree. Each ``('split', left, right)`` node is a choice between its
    two children, so ``a|b|c`` becomes ``('split', ('split', 'a', 'b'), 'c')``.

    Parsing stops at the end of ``r`` or at a ``)``. The ``)`` is not consumed;
    it is left to the caller (parse_node closes the group, re_parse reports it).

    Returns ``(idx, node)``, where ``idx`` is where parsing stopped. Despite the
    ``Node`` annotation, ``node`` is a tuple, a single-character string,
    ``'dot'``, or ``None`` for an empty alternative.

    Raises SyntaxError if an alternative is followed by anything other than
    ``|`` or ``)``. Errors from the nested parsers propagate unchanged.
    """
    idx, prev = parse_concat(r, idx)  # the first alternative
    while idx < len(r):
        if r[idx] == ')':
            break  # end of this group; the caller consumes the ")"
        # parse_concat only stops at "|", ")" or the end of input. Check this
        # explicitly, since an ``assert`` would vanish under ``python -O``.
        if r[idx] != '|':
            raise SyntaxError(f'Expected "|" got {r[idx]}')
        idx, node = parse_concat(r, idx + 1)
        prev = ('split', prev, node)
    return idx, prev

def parse_concat(r: str, idx: int) -> Tuple[int, Node]:
    """
    Parse a run of juxtaposed nodes into a left-nested ``'cat'`` tree.

    Concatenation has no operator in regex syntax: consecutive nodes (each
    parsed by parse_node) must match one after another, so ``abc`` becomes
    ``('cat', ('cat', 'a', 'b'), 'c')``.

    The run ends, without consuming the character, at the end of ``r`` or at
    ``|`` / ``)``, so that parse_split or parse_node can deal with it.

    Returns ``(idx, tree)``. ``tree`` is ``None`` when every parsed node was
    itself ``None``, including when nothing was parsed at all.
    """
    tree = None
    while idx < len(r) and r[idx] not in '|)':
        idx, piece = parse_node(r, idx)
        # The first non-empty piece starts the tree; later pieces extend it.
        tree = piece if tree is None else ('cat', tree, piece)
    return idx, tree

def parse_node(r: str, idx: int) -> Tuple[int, Node]:
    """
    Parse one atom at ``r[idx]`` together with any quantifier that follows it.

    An atom is a literal character, ``.`` (returned as ``'dot'``), or a
    parenthesised subexpression, which is parsed recursively with parse_split.
    The atom is then passed to parse_postfix, which may wrap it in a repeat node.

    Expects ``idx < len(r)`` and ``r[idx]`` not in ``|)``; parse_concat only
    calls this function under those conditions.

    Raises SyntaxError when a ``(`` is never closed, and Exception when the
    atom itself is a quantifier character (``*``, ``+`` or ``{``).
    """
    ch = r[idx]
    pos = idx + 1
    assert ch not in '|)', f'Expected a character, got {ch}'
    if ch in '*+{':
        # A quantifier needs a preceding atom to apply to.
        raise Exception('nothing to repeat')

    if ch == '(':
        pos, atom = parse_split(r, pos)
        # parse_split stops either at a ")" or at the end of the input.
        if pos >= len(r) or r[pos] != ')':
            raise SyntaxError('Expected ")": unbalanced parentheses')
        pos += 1
    elif ch == '.':
        atom = 'dot'
    else:
        atom = ch

    return parse_postfix(r, pos, atom)

def parse_postfix(r: str, idx: int, node: Node) -> Tuple[int, Node]:
    """
    Wrap ``node`` in a repeat node if a quantifier follows it at ``r[idx]``.

    Recognised quantifiers (at most one is consumed):
      * ``*``      -- 0 or more times
      * ``+``      -- 1 or more times
      * ``{n}``    -- exactly n times
      * ``{n,m}``  -- between n and m times
      * ``{n,}``   -- n or more times
    An unbounded maximum is represented by ``float('inf')``.

    Returns ``(idx, node)`` unchanged when no quantifier is present. Otherwise
    returns the index past the quantifier and ``('repeat', node, rmin, rmax)``.

    Raises Exception for a missing minimum, an unclosed brace, a minimum larger
    than the maximum, or a minimum above RE_REPEAT_LIMIT.
    """
    op = None if idx == len(r) else r[idx]
    if op is None or op not in '*+{':
        return idx, node

    pos = idx + 1
    if op == '*':
        rmin, rmax = 0, float('inf')
    elif op == '+':
        rmin, rmax = 1, float('inf')
    else:
        # Brace form: a mandatory minimum, then an optional ",max".
        pos, rmin = parse_int(r, pos)
        if rmin is None:
            raise Exception('expect int')
        rmax = rmin
        if pos < len(r) and r[pos] == ',':
            pos, upper = parse_int(r, pos + 1)
            # "{n,}" leaves the maximum open.
            rmax = float('inf') if upper is None else upper
        if pos >= len(r) or r[pos] != '}':
            raise Exception('unbalanced brace')
        pos += 1

    # Reject impossible or oversized ranges.
    if rmax < rmin:
        raise Exception('min repeat greater than max repeat')
    if rmin > RE_REPEAT_LIMIT:
        raise Exception('the repetition number is too large')

    return pos, ('repeat', node, rmin, rmax)

def parse_int(r: str, idx: int) -> Tuple[int, int | None]:
    """
    Parse a run of ASCII decimal digits starting at ``r[idx]``.

    Returns ``(end, value)``, where ``end`` is the index just past the last
    digit consumed and ``value`` is the parsed integer. Returns
    ``(idx, None)`` when ``r[idx]`` is not a digit or ``idx`` is at the end
    of ``r``.

    Only the characters 0-9 are accepted. ``str.isdigit`` is deliberately
    avoided: it is also true for characters such as superscript digits
    (e.g. '²'), which ``int()`` rejects with ValueError.
    """
    start = idx
    while idx < len(r) and '0' <= r[idx] <= '9':
        idx += 1
    return idx, (int(r[start:idx]) if idx != start else None)

def re_parse(r):
    """
    Parse the regex string ``r`` and return its AST.

    The AST is built from tuples (``('cat', a, b)``, ``('split', a, b)``,
    ``('repeat', node, rmin, rmax)``), single characters, ``'dot'``, and
    ``None`` for an empty pattern.

    Raises Exception('unexpected ")"') if any input remains after parsing,
    in addition to whatever the parse_* helpers raise.
    """
    end, tree = parse_split(r, 0)
    if end == len(r):
        return tree
    # parse_split only returns early when it meets a ")" that no "(" opened.
    raise Exception('unexpected ")"')

In [4]:
# Parser smoke tests. The AST is made of tuples, single characters, 'dot'
# and None (the empty expression).
assert re_parse('') is None
assert re_parse('.') == 'dot'
assert re_parse('a') == 'a'
assert re_parse('ab') == ('cat', 'a', 'b')
assert re_parse('a|b') == ('split', 'a', 'b')
assert re_parse('a+') == ('repeat', 'a', 1, float('inf'))
assert re_parse('a{3,6}') == ('repeat', 'a', 3, 6)
assert re_parse('a|bc') == ('split', 'a', ('cat', 'b', 'c'))
# Left-nesting, grouping, the remaining quantifier forms, empty alternatives.
assert re_parse('abc') == ('cat', ('cat', 'a', 'b'), 'c')
assert re_parse('a|b|c') == ('split', ('split', 'a', 'b'), 'c')
assert re_parse('(ab)') == ('cat', 'a', 'b')
assert re_parse('(a|b)*') == ('repeat', ('split', 'a', 'b'), 0, float('inf'))
assert re_parse('a{2}') == ('repeat', 'a', 2, 2)
assert re_parse('a{2,}') == ('repeat', 'a', 2, float('inf'))
assert re_parse('a|') == ('split', 'a', None)

# Malformed patterns must be rejected rather than silently parsed.
for _bad in ['(a', 'a)', '*', 'a{', 'a{3', 'a{3,2}', 'a{1001}']:
    try:
        re_parse(_bad)
    except Exception:
        pass
    else:
        raise AssertionError(f'{_bad!r} should not parse')

# R2: Matching with backtracking

In [5]:
def match_backtrack(node, text, idx):
    """
    Lazily yield every index where a match of ``node`` starting at
    ``text[idx]`` can end.

    ``node`` is an AST as produced by re_parse:
      * ``None``            -- the empty expression; matches without consuming
      * ``'dot'``           -- any single character
      * a 1-char string     -- that literal character
      * ``('cat', a, b)``   -- ``a`` followed by ``b``
      * ``('split', a, b)`` -- ``a`` or ``b`` (ends for ``a`` are yielded first)
      * ``('repeat', sub, rmin, rmax)`` -- see match_backtrack_repeat

    Because this is a generator, a caller can stop at the first acceptable
    end. Raises AssertionError, when iterated, for a node that matches none
    of the forms above.
    """
    if node is None:
        yield idx   # empty string
    elif node == 'dot':
        if idx < len(text):
            yield idx + 1
    elif isinstance(node, str):
        # Explicit raises (not ``assert``) so the checks survive ``python -O``.
        if len(node) != 1:
            raise AssertionError(f'expected a single-character node, got {node!r}')
        if idx < len(text) and text[idx] == node:
            yield idx + 1
    elif node[0] == 'cat':
        # `yield from` re-yields every value of the sub-generator.
        yield from match_backtrack_concat(node, text, idx)
    elif node[0] == 'split':
        yield from match_backtrack(node[1], text, idx)
        yield from match_backtrack(node[2], text, idx)
    elif node[0] == 'repeat':
        yield from match_backtrack_repeat(node, text, idx)
    else:
        raise AssertionError(f'unreachable: unknown regex node {node!r}')

def match_backtrack_concat(node, text, idx):
    """
    Yield end positions for a ``('cat', left, right)`` node matched at ``text[idx]``.

    Every position where ``left`` can end becomes a starting point for
    ``right``. If ``left`` produces the same position more than once, it is
    explored only once, because ``right`` would yield the same ends from it again.
    """
    explored = set()
    for mid in match_backtrack(node[1], text, idx):
        if mid not in explored:
            explored.add(mid)
            yield from match_backtrack(node[2], text, mid)


def match_backtrack_repeat(node, text, idx):
    """
    Yield end positions for a ``('repeat', sub, rmin, rmax)`` node at ``text[idx]``.

    The sub-expression is applied repeatedly. After each repetition, the set
    of positions reached becomes the set of start positions for the next one.
    Positions reached with at least ``rmin`` repetitions are valid ends.

    The maximum is clamped to RE_REPEAT_LIMIT, so an unbounded repeat
    (``rmax == inf``) is tried at most RE_REPEAT_LIMIT times. Iteration also
    stops early when:
      * the sub-expression cannot match from any current position, or
      * a repetition reaches exactly the same set of positions as the previous
        one (e.g. the sub-expression only matches the empty string). Every
        further repetition would repeat that set, so it adds no new ends.

    Repetition is greedy: ends reached with more repetitions are yielded
    first. The same position may be yielded more than once.
    """
    _, sub, rmin, rmax = node
    rmax = min(rmax, RE_REPEAT_LIMIT)
    # The output is buffered so it can be yielded in reverse (greedy) order.
    output = []
    if rmin == 0:
        # Zero repetitions are allowed, so the empty match at idx is valid.
        output.append(idx)
    # Positions reached after the previous number of repetitions.
    start = {idx}
    for i in range(1, rmax + 1):
        found = set()
        for idx1 in start:
            for idx2 in match_backtrack(sub, text, idx1):
                found.add(idx2)
                if i >= rmin:
                    output.append(idx2)
        if not found:
            break  # the sub-expression cannot match again
        if found == start and i >= rmin:
            # Fixed point: later repetitions would reach exactly these
            # positions again. Stopping here avoids up to RE_REPEAT_LIMIT
            # useless iterations, which multiply for nested repeats.
            break
        start = found
    # Repetition is greedy: output the most repetitive match first.
    yield from reversed(output)

def re_full_match_bt(node, text):
    """
    Return True if ``node`` matches the whole of ``text``, using backtracking.

    match_backtrack yields the end index of each prefix of ``text`` that
    ``node`` can match. The text is fully matched exactly when one of those
    ends equals ``len(text)``. ``any`` stops consuming the generator at the
    first such end.
    """
    # NOTE: only full matches are accepted, so the order in which candidate
    #       ends are produced (greedy or not) does not affect the result.
    return any(end == len(text) for end in match_backtrack(node, text, 0))