In [None]:
#| default_exp helper.latex.__init__

# helper.latex
> Helper functions for LaTeX functionality


We remark that many of the functions in this module were generated by AI or written with AI assistance.

In [None]:
#| export
import random
import re
import string
from typing import Callable, List, Optional, Tuple, Union

from Levenshtein import distance
from pylatexenc.latexwalker import LatexCharsNode, LatexEnvironmentNode, LatexNode, LatexWalker, LatexMacroNode

from trouver.helper import sublist_generator
from trouver.helper.regex import latex_indices, replace_string_by_indices
from trouver.helper.latex.macros_and_commands import regex_pattern_detecting_command

In [None]:
from fastcore.test import *

## Validity of latex syntax

#### Test latex syntax

We require some functions to evaluate whether a LaTeX math-mode string is syntactically valid.

In [None]:
#| export
def _does_not_end_with_script(s):
    """
    This is a helper function to `math_mode_string_is_syntactically_valid`.
    `s` is not supposed to have math mode delimits
    """
    s = s.strip(' $')
    return not s.endswith('_') and not s.endswith('^')

In [None]:
#| hide
# Inputs ending in a dangling script marker, with and without `$`/`$$` delimiters.
for _case in (r'n=p_1^{e_1} p_2^{e_2} \cdots p_k^',
              r'$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$',
              r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$$',
              r'n=p_1^{e_1} p_2^{e_2} \cdots p_k_'):
    assert not _does_not_end_with_script(_case), _case

# The same inputs with the dangling marker removed.
for _case in (r'n=p_1^{e_1} p_2^{e_2} \cdots p_k',
              r'$n=p_1^{e_1} p_2^{e_2} \cdots p_k$',
              r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k$$'):
    assert _does_not_end_with_script(_case), _case


In [None]:
#| export
def _is_balanced_braces(s):
    """
    This is a helper function to `math_mode_string_is_syntactically_valid`.

    Note that curly braces (`{`, `}`) that are not preceded by backslashes
    '\\' are counted towards "balancing". 
    """
    stack = []
    escaped = False
    
    for _, char in enumerate(s):
        if char == '\\':
            escaped = True
        elif char == '{' and not escaped:
            stack.append(char)
        elif char == '}' and not escaped:
            if not stack:
                return False
            stack.pop()
        else:
            escaped = False
    
    return len(stack) == 0


In [None]:
#| hide
# Only braces not immediately preceded by a backslash are counted.
for _case in ('{{}}', '{asdf_{}}', r'\{hi', r'\{{hi}'):
    assert _is_balanced_braces(_case), _case
for _case in ('{hi', '{hi asdf}}', r'\{{hi\}', '}'):
    assert not _is_balanced_braces(_case), _case

In [None]:
#| export
def _first_curly_bracket(s) -> str|None:
    r"""
    Return whether a left curly bracket `{` or a right curly bracket `}`
    appears first in the string, ignoring escaped curly brackets `\{` or `\}`.
    """
    i = 0
    while i < len(s):
        if s[i] == '\\' and i + 1 < len(s):
            if s[i+1] in '{}':
                # Skip this escaped bracket
                i += 2
                continue
        elif s[i] == '{':
            return '{'
        elif s[i] == '}':
            return '}'
        i += 1
    return None


In [None]:
#| hide
# (input, expected first unescaped curly bracket)
for _case, _expected in [
    ("{abc}", '{'),
    ("}abc{", '}'),
    ("abc{def}", '{'),
    ("abc}def{", '}'),
    ("abc", None),
    ("\\{abc}", '}'),
    ("a\\}b{c}", '{'),
    ("\\{\\}abc", None),
    ("\\\\{abc}", '}'),
    ("\\a{bc}", '{'),
    ("", None),
    ("\\", None),
    ("\\\\{}", '}'),
]:
    assert _first_curly_bracket(_case) == _expected, _case

In [None]:
#| export
# def _last_curly_bracket(s) -> str|None:
def _last_curly_bracket(s):
    i = len(s) - 1
    while i >= 0:
        if s[i] in '{}':
            if i > 0 and s[i-1] == '\\':
                # This bracket is escaped
                i -= 2
                continue
            return s[i]
        i -= 1
    return None

In [None]:
#| hide
# (input, expected last unescaped curly bracket)
for _case, _expected in [
    ("{abc}", '}'),
    ("}abc{", '{'),
    ("abc{def}", '}'),
    ("abc}def{", '{'),
    ("abc", None),
    ("abc\\}", None),
    ("a{b\\}c}", '}'),
    ("\\{\\}abc", None),
    ("abc{\\\\}", '{'),
    ("{abc\\{", '{'),
    ("", None),
    ("\\", None),
    ("{}\\\\", '}'),
    ("\\\\{}", '}'),
]:
    assert _last_curly_bracket(_case) == _expected, _case

In [None]:
#| export
def _detect_backslash_space_curly(
        text: str
        ) -> bool:
    r"""
    Return `True` if there is some backslash `\` followed
    by spaces and then followed by curly brackets `{`

    Note that the presence of such a combination of text
    will induce a syntax error in LaTeX math mode string.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    pattern = r'\\\s+[{}]'
    match = re.search(pattern, text)
    return bool(match)

In [None]:
#| hide
# A backslash, whitespace, then a curly bracket is flagged; escaped brackets are not.
for _case in (r'\ {', r'\ }'):
    assert _detect_backslash_space_curly(_case), _case
for _case in (r'\{', r'{'):
    assert not _detect_backslash_space_curly(_case), _case

In [None]:
#| export
def _is_left_right_balanced(
        latex_string: str
        ) -> bool:
    r"""
    Return `True` if occurrences of `\left` and `\right` are balanced. 

    This is a helper function of `math_mode_string_is_syntactically_valid`

    This function does not test whether occurrences of the
    appropriately corresponding braces are balanced. For instance,
    the function would return `True` on the input `\left . \right)`.
    Compare against `_is_semantically_left_right_balanced`, which
    is a similar function that tests whether left-right braces
    are "semantically" balanced.
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Find all \left and \right commands
    left_commands = re.findall(r'\\left', latex_string)
    right_commands = re.findall(r'\\right', latex_string)
    
    # Check if the number of \left and \right commands are equal
    if len(left_commands) != len(right_commands):
        return False
    
    # Check if \left always comes before \right
    left_indices = [m.start() for m in re.finditer(r'\\left', latex_string)]
    right_indices = [m.start() for m in re.finditer(r'\\right', latex_string)]
    
    for left, right in zip(left_indices, right_indices):
        if left > right:
            return False
    
    return True

In [None]:
#| hide
# Only the pairing of \left with \right matters here, not the delimiter types.
for _case in (r"\left( x \right)",
              r"\left(x\right)",
              r"\left( x \left[ y \right] \right)",
              r"\left( x \right) \left[ y \right]",
              r"\left( x \right] \left[ y \right)",
              r"x + y",
              r"\left. \right)"):
    assert _is_left_right_balanced(_case), _case
for _case in (r"\left( x \left[ y \right)", r"\right)"):
    assert not _is_left_right_balanced(_case), _case


In [None]:
#| export
def _is_semantically_left_right_balanced(
        latex_string: str
        ) -> bool:
    r"""
    Return `True` if occurrences of `\left` and `\right` macros
    preceding various braces are balanced.

    This is a helper function of `math_mode_string_is_syntactically_clean`

    Compare against `_is_left_right_balanced`, which
    is a similar function that only tests whether left-right
    macros are "syntactically" balanced, without regard to
    the types of braces actually used.
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Define a stack to keep track of opening delimiters
    stack = []
    
    # Define a dictionary to match opening and closing delimiters
    delimiters = {
        '(': ')', '[': ']', '{': '}', '<': '>',
        r'\left(': r'\right)', r'\left[': r'\right]', 
        r'\left{': r'\right}', r'\left<': r'\right>',
        r'\left\{': r'\right\}', r'\left|': r'\right|',
        r'\left\|': r'\right\|', r'\left.': r'\right.',
        r'\left\langle': r'\right\rangle'
    }
    
    # Regular expression to match \left and \right commands with their delimiters
    pattern = r'(\\left[\(\[\{\<\|\.\|]|\\right[\)\]\}\>\|\.\|]|\(|\)|\[|\]|\{|\}|\<|\>)'
    
    # Find all delimiters and \left/\right commands
    tokens = re.findall(pattern, latex_string)
    
    for token in tokens:
        if token.startswith(r'\left') or token in '([{<':
            stack.append(token)
        elif token.startswith(r'\right') or token in ')]}>':
            if not stack:
                return False
            last_open = stack.pop()
            if token != delimiters.get(last_open):
                return False
    
    # If the stack is empty, all delimiters are balanced
    return len(stack) == 0


In [None]:
#| hide
# Test cases: each \left delimiter must be closed by the matching \right delimiter.
for _case in (r"\left( x \right)",
              r"\left( x \left[ y \right] \right)",
              r"\left( x \right) \left[ y \right]",
              r"x + y",
              r"\left\{ x \left( y \right) \right\}",
              r"\left\| x \right\|",
              r"\left< x , y \right>",
              r"\left( \left[ \left\{ x \right\} \right] \right)"):
    assert _is_semantically_left_right_balanced(_case), _case

for _case in (r"\left( x \left[ y \right)",
              r"\left( x \right] \left[ y \right)",
              r"\left. x \right|_{a}^{b}",
              r"\left( \left[ \left\{ x \right\} \right] \right]"):
    assert not _is_semantically_left_right_balanced(_case), _case

# TODO: the example of r"\left. x \right|_{a}^{b}" could be something that occurs often. do something about this.

In [None]:
#| export
def _has_invalid_left_right_bracket(
        latex_string: str
        ) -> bool:
    r"""
    Return `True` is there is at least one invalid use of
    a `\left` or `\right` command.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Define valid brackets for \left and \right
    valid_brackets = [
        r'(', r')',
        r'[', r']',
        r'\(', r'\)', r'\[', r'\]',  # Parentheses and square brackets
        r'\{', r'\}',                # Curly braces (escaped)
        r'<', r'>',                  # Angle brackets
        r'\|',                       # Vertical bar
        r'|',                       # Vertical bar
        r'\\\|',                     # Double vertical bar (escaped)
        r'\.',                       # Dot
        r'.',                       # Dot
        r'\\lfloor', r'\\rfloor',    # Floor brackets
        r'\\lceil', r'\\rceil',      # Ceiling brackets
        r'\\langle', r'\\rangle'     # Angle brackets (commands)
    ]
    
    # Escape special regex characters and join with |
    valid_brackets_pattern = '|'.join(re.escape(b) for b in valid_brackets)
    
    # Pattern to match \left or \right followed by a valid bracket
    valid_pattern = rf'\\(left|right)({valid_brackets_pattern})'
    
    # Find all \left and \right commands
    commands = list(re.finditer(r'\\(left|right)', latex_string))
    
    for command in commands:
        # Check if the command is followed by a valid bracket
        if not re.match(valid_pattern, latex_string[command.start():]):
            # If not, return True and the invalid command
            invalid_part = latex_string[command.start():command.start()+6]  # Adjust slice as needed
            # return True, invalid_part
            return True
    
    # return False, None
    return False

In [None]:
# \left / \right followed by a recognised delimiter.
for _case in (r"\left( x \right)",
              r"\left[ x \right]",
              r"\left\{ x \right\}",
              r"\left< x \right>",
              r"\left| x \right|",
              r"\left\| x \right\|",
              r"x + y"):
    assert not _has_invalid_left_right_bracket(_case), _case

# \left / \right followed by something that is not a delimiter.
for _case in (r"\lefta x \right)",
              r"\left( x \righta",
              r"\left x \right)",
              r"\left( x \right x",
              r"\left\backslash x \right/",
              r"\left\\",
              r"\right\\"):
    assert _has_invalid_left_right_bracket(_case), _case


In [None]:
#| export

def _has_double_script(
        latex_string: str
        ) -> bool:
    """
    Return `True` if there is at least one double superscript
    or double subscript in `latex_string` 

    This function fails to give correct outputs for more
    nuanced texts, such as `r"x^{2}_{3}^{4}"`; while in
    principle, the function should return `True` on this
    input, the actual return value is `False`.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Function to match balanced braces
    def match_braces(s, start):
        count = 0
        for i, char in enumerate(s[start:], start):
            if char == '{':
                count += 1
            elif char == '}':
                count -= 1
                if count == 0:
                    return i
        return len(s) - 1

    # Find all subscripts and superscripts
    i = 0
    last_script = None
    while i < len(latex_string):
        if latex_string[i] in '^_' and (i == 0 or latex_string[i-1] != '\\'):
            current_script = latex_string[i]
            i += 1
            if i < len(latex_string):
                if latex_string[i] == '{':
                    end = match_braces(latex_string, i)
                    script_content = latex_string[i:end+1]
                    i = end + 1
                else:
                    script_content = latex_string[i]
                    i += 1
                
                if last_script and last_script[0] == current_script:
                    return True
                last_script = (current_script, script_content)
        else:
            if latex_string[i] not in '^_':
                last_script = None
            i += 1

    return False

In [None]:
#| hide
# At most one superscript and one subscript per base.
for _case in (r"x^2", r"x_2", r"x^{2}", r"x_{2}",
              r"x^{2^3}", r"x_{2_3}", r"x_2^3", r"x^2_3",
              r"x\^2", r"x\_2", r"x^{2^{3^4}}", r"x_{2_{3_4}}",
              r'$$x^2 + y^2$$', r'\operatorname{Res}^{G}_{H} M'):
    assert not _has_double_script(_case), _case

for _case in (r"x^2^3", r"x_2_3", r"x^{2}^{3}", r"x_{2}_{3}"):
    assert _has_double_script(_case), _case

# The function does not work correctly on the below input
# assert _has_double_script(r"x^{2}_{3}^{4}")

In [None]:
#| export
def _has_double_script_literal(
        latex_string: str
        ) -> bool:
    """
    Return `True` if there is at least one double superscript
    or double subscript in `latex_string` by virtue of having
    `__`, `_^`, `^_, `^^`

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    for bad_text in ['__', '_^', '^_', '^^']:
        if bad_text in latex_string:
            return True
    return False

In [None]:
#| hide
# `\Omega__X` contains two subscript markers in a row.
_doubled_subscript = r"$$R=\sum_P\in X\operatorname length\left(\Omega__X / Y\right)_p\cdot P$$"
assert _has_double_script_literal(_doubled_subscript)

In [None]:
#| export
def _has_unescaped_dollar(s):
    # Pattern explanation:
    # (?<!\\) is a negative lookbehind assertion that ensures the dollar sign is not preceded by a backslash
    # \$ matches a literal dollar sign
    pattern = r'(?<!\\)(\\\\)*\$'
    match = re.search(pattern, s)
    return bool(match)

In [None]:
#| hide
# Strings whose every `$` (if any) is escaped by an odd number of backslashes.
for _case in ("This is a normal string",
              "This has an escaped dollar sign \\$",
              "Escaped at the end \\$",
              "\\$",
              "\\\\\\$"):
    assert not _has_unescaped_dollar(_case), _case

# Strings with at least one unescaped `$`.
for _case in ("This has an unescaped dollar sign $",
              "Mixed case: \\$ and $",
              "Multiple unescaped: $ $",
              "Unescaped at the end $",
              "$",
              "\\\\$",  # Double backslash followed by dollar
              "\\ $"):
    assert _has_unescaped_dollar(_case), _case

In [None]:
# A zero-argument custom command: the detected span should be exactly `\Sur`,
# without the `(` that follows it.
text = r'The number of element of $\Sur(\operatorname{Cl} \mathcal{O}_L, A)$ is ...'
pattern = regex_pattern_detecting_command(('Sur', 0, None, r'\mathrm{Sur}'))
match = pattern.search(text)
start, end = match.span()
test_eq(text[start:end], r'\Sur')


In [None]:
#| export
def extract_latex_commands(latex_string):
    """
    Return the names of the LaTeX commands used in `latex_string`, in order
    of appearance and with repetitions, e.g. `['frac', 'sum', 'frac']`.

    Environments contribute `'begin'` and/or `'end'` rather than their
    environment names. If parsing with `LatexWalker` raises an exception,
    the error is printed and an empty list is returned.
    """
    walker = LatexWalker(latex_string)
    try:
        parsed_nodes, _, _ = walker.get_latex_nodes()
    except Exception as e:
        # Best effort: report the failure and treat the string as command-free.
        print(f"Error parsing LaTeX: {e}")
        return []
    found_commands = []
    extract_commands_from_nodes(found_commands, parsed_nodes)
    return found_commands


def extract_commands_from_nodes(
        commands: list[str],
        nodes: list[LatexNode]
        ) -> None:
    """
    Append the names of the macros found in `nodes` to `commands`, in the
    order they are encountered. `commands` is mutated in place; nothing is
    returned.

    Character nodes are skipped. For a macro node, its name is appended and
    its arguments (`nodeargs`) are searched as well. For an environment node,
    'begin' and/or 'end' are appended, depending on whether such macros occur
    in its verbatim source. Every node that has a `nodelist` attribute is then
    searched recursively.

    This is a helper function to `extract_latex_commands`.
    """
    for node in nodes:
        # Plain character nodes contain no commands, so skip them.
        if isinstance(node, LatexCharsNode):
            continue
        elif isinstance(node, LatexMacroNode):
            commands.append(node.macroname)
            # Search the macro's arguments as well.
            for arg in node.nodeargs:
                # Skip missing (falsy) arguments and plain-text arguments.
                if arg and not isinstance(arg, LatexCharsNode):
                    if hasattr(arg, 'nodelist'):  # An argument with child nodes (presumably a group): recurse into them
                        extract_commands_from_nodes(commands, arg.nodelist)  # Extract from argument nodes
                    elif isinstance(arg, LatexMacroNode):
                        # An unbraced macro argument, e.g. `\in` in `\text\in`.
                        commands.append(arg.macroname)
        elif isinstance(node, LatexEnvironmentNode):
            # Environments are reported as 'begin'/'end', not by environment name.
            commands.extend(_detect_begin_and_end_environments(node.latex_verbatim()))
        # If the node has a nodelist (e.g. math or environment contents), extract commands from it
        if hasattr(node, 'nodelist'):
            extract_commands_from_nodes(commands, node.nodelist)

def _detect_begin_and_end_environments(
        latex_string: str
        ) -> list[str]:
    r"""
    Return a list of at most two items containing 'begin' if there is a \begin and containing 'end' if there is an \end

    This is a helper function to `extract_latex_commands`.
    """
    # Regular expressions to match \begin and \end with optional spaces
    begin_pattern = r'\\\s*begin'
    end_pattern = r'\\\s*end'
    
    # Initialize an empty result list
    result = []
    
    # Check for \begin
    if re.search(begin_pattern, latex_string):
        result.append('begin')
    
    # Check for \end
    if re.search(end_pattern, latex_string):
        result.append('end')
    
    return result

In [None]:


# (latex input, expected commands in order of appearance)
_EXTRACTION_CASES = [
    # Example usage
    (r"\frac{a}{b}", ['frac']),
    (r"$\frac{a}{b}$", ['frac']),
    (r"\sqrt[n]{x}", ['sqrt']),
    (r"\binom{n}{k}", ['binom']),
    (r"x^2 + y^2", []),  # No commands, just variables
    (r"\overset{a}{b}", ['overset']),
    # Additional tests
    (r"\sum_{i=1}^{n} i", ['sum']),
    (r"\int_{0}^{\infty} e^{-x} dx", ['int', 'infty']),
    (r"\lim_{x \to 0} f(x)", ['lim', 'to']),
    (r"\prod_{i=1}^{n} i", ['prod']),
    (r"\text{Hello} + \frac{1}{2}", ['text', 'frac']),
    # Multiple commands in one string
    (r"\frac{a}{b} + \sqrt{c} + \binom{n}{k}", ['frac', 'sqrt', 'binom']),
    (r"\sum_{i=1}^{n} i + \int_{0}^{\infty} e^{-x} dx", ['sum', 'int', 'infty']),
    (r"\lim_{x \to 0} f(x) = \frac{1}{x}", ['lim', 'to', 'frac']),
    (r"\overset{a}{b} + \underset{c}{d}", ['overset', 'underset']),
    (r"\text{This is } \textbf{bold} + \textit{italic} + \frac{1}{2}", ['text', 'textbf', 'textit', 'frac']),
    # Complex expressions
    (r"\frac{\sum_{i=1}^{n} i}{n} = \frac{n(n+1)}{2}", ['frac', 'sum', 'frac']),
    (r"\int_{0}^{1} x^2 \, dx = \frac{1}{3}", ['int', ',', 'frac']),
    (r"\sqrt{\frac{a}{b}} + \binom{n}{k}", ['sqrt', 'frac', 'binom']),
    # Incorrect syntax
    (r"\frac{}}", ['frac']),
    (r"\frac{a}{b}{c}", ['frac']),  # Extra argument
    (r"\frac{a}{b + \frac{c}{d}}", ['frac', 'frac']),  # Nested command
    (r"\sum_{i=1}^{n} i + \int_{0}^{\infty} e^{-x} dx = \frac{1}{2}", ['sum', 'int', 'infty', 'frac']),
    # Comment
    (r"%hi", []),
    # Environment Node
    (r"\begin{align} \end{align}", ['begin', 'end']),
    (r"\begin{align}", ['begin']),
    (r'\text\in', ['text', 'in']),
]
for _latex, _expected in _EXTRACTION_CASES:
    test_eq(extract_latex_commands(_latex), _expected)

# test_eq(extract_latex_commands(r"\ begin{align} \end{align}"), [' '])

In [None]:
#| export
# Argument tuples for `regex_pattern_detecting_command`, keyed by command name,
# for some basic LaTeX commands. `detect_incorrect_latex_commands` uses them to
# check that each listed command is followed by its required arguments.
# Note that the last tuple entry (the definition) doesn't actually matter,
# because we only want to detect uses of commands, see
# `regex_pattern_detecting_command`.
#
# Commands that take no mandatory argument in LaTeX are intentionally NOT listed:
# operator names such as `\sin`, `\cos`, `\log`, `\lim`, `\max`, `\sup`, `\gcd`,
# `\exp`, as well as `\mapsto`, `\not`, `\atop`, `\oint`, `\lvert`, `\rvert`,
# `\int`, `\sum`, `\prod`, `\left`, `\right`. Checking them for arguments could
# only flag valid input such as `\sin x`.
REGEX_PATTERN_DETECTIONS = {entry[0]: entry for entry in [
    ('frac', 2, None, None),
    ('binom', 2, None, None),
    ('sqrt', 1, '2', None),
    ('overset', 2, None, None),
    ('underset', 2, None, None),
    ('stackrel', 2, None, None),
    ('dfrac', 2, None, None),
    ('cfrac', 2, None, None),
    ('sideset', 2, None, None),  # \sideset{_1^2}{_3^4}\sum: two arguments
    ('xrightarrow', 1, None, None),
    ('xleftarrow', 1, None, None),
    ('overline', 1, None, None),
    ('bar', 1, None, None),
    ('begin', 1, None, None),
    ('boldsymbol', 1, None, None),
    ('breve', 1, None, None),
    ('check', 1, None, None),
    ('cline', 1, None, None),
    ('dddot', 1, None, None),
    ('ddot', 1, None, None),
    ('dot', 1, None, None),
    ('end', 1, None, None),
    ('grave', 1, None, None),
    ('hat', 1, None, None),
    ('longdiv', 2, None, None),
    ('mathbb', 1, None, None),
    ('mathbf', 1, None, None),
    ('mathcal', 1, None, None),
    ('mathfrak', 1, None, None),
    ('mathop', 1, None, None),
    ('mathrm', 1, None, None),
    ('mathscr', 1, None, None),
    ('multicolumn', 3, 'center', None),
    ('multirow', 3, None, None),
    ('overbrace', 1, None, None),
    ('overleftarrow', 1, None, None),
    ('overleftrightarrow', 1, None, None),
    ('overrightarrow', 1, None, None),
    ('section', 1, None, None),
    ('subsection', 1, None, None),
    ('substack', 1, None, None),
    ('subsubsection', 1, None, None),
    ('tag', 1, None, None),
    ('text', 1, None, None),
    ('textbf', 1, None, None),
    ('textrm', 1, None, None),
    ('tilde', 1, None, None),
    ('underbrace', 1, None, None),
    ('underline', 1, None, None),
    ('vec', 1, None, None),
    ('widehat', 1, None, None),
    ('widetilde', 1, None, None),
]}
# Backwards-compatible alias for the same dict.
temp_dict = REGEX_PATTERN_DETECTIONS



In [None]:
#| export
def detect_incorrect_latex_commands(
        latex_string: str,
        ) -> bool:
    r"""
    Return `True` if there is at least one syntactically
    incorrect use of a latex command detected in `latex_string`.

    Only commands listed in `REGEX_PATTERN_DETECTIONS` are checked. Every
    invocation `\command` (optionally with whitespace after the backslash)
    must be matched, starting at its backslash, by the pattern that
    `regex_pattern_detecting_command` builds for that command. Invocations
    are matched on whole macro names, so for instance `\text` is not
    confused with the beginning of `\textbf`, nor `\dot` with `\dots`.

    This is a helper function to `math_mode_string_is_syntactically_valid`.
    """
    commands_in_string = set(extract_latex_commands(latex_string))
    for command in commands_in_string:
        command_tuple = REGEX_PATTERN_DETECTIONS.get(command)
        if command_tuple is None:
            continue
        pattern = regex_pattern_detecting_command(command_tuple)
        # `(?![a-zA-Z])`: a TeX macro name is a maximal run of letters, so
        # `\textbf` is never an invocation of `\text`.
        invocation_pattern = rf"\\\s*{re.escape(command)}(?![a-zA-Z])"
        # Look at each invocation of the command to see if it is properly used.
        for invocation in re.finditer(invocation_pattern, latex_string):
            trailing_substring = latex_string[invocation.start():]
            # Equivalent to "search finds a match starting at 0", without scanning the rest.
            if not pattern.match(trailing_substring):
                return True
    return False

In [None]:
#| hide
# Correct usage; extra arguments are technically okay
for _case in (r'\frac{a}{b}',
              r'\binom{n}{k}',
              r'\sqrt[n]{x}',
              r'\overset{a}{b}',
              r'\underset{a}{b}',
              r'\stackrel{a}{b}',
              r'\dfrac{a}{b}',
              r'\cfrac{1}{1+\cfrac{1}{x}}',
              r'\xleftarrow{text}',
              r'\xrightarrow{text}',
              r'\left( \right.',
              r'\overbrace{x+y+z}^{\text{sum}}',
              r'\underbrace{x+y+z}_{\text{sum}}',
              r'\overbrace{x+y+z}',
              r'\underbrace{x+y+z}',
              r'\frac{a}{b}{c}',
              r'\binom{n}{k}{m}',
              r'\overset{a}{b}{c}'):
    assert not detect_incorrect_latex_commands(_case), _case

# Incorrect usage (missing arguments), also mixed with correct usage
for _case in (r'\frac{a}',
              r'\binom{n}',
              r'\overset{a}',
              r'\underset{a}',
              r'\stackrel{a}',
              r'\dfrac{a}',
              r'\cfrac{1}',
              r'\sideset{_1^2}',
              r'\frac{a}{b} + \frac{c}',
              r'\binom{n}{k} \cdot \binom{m}',
              r'\text\in'):
    assert detect_incorrect_latex_commands(_case), _case


In [None]:
#| export
def detect_unbalanced_environments(
        latex_string: str) -> list[str]:
    r"""
    Return error messages describing unbalanced `\begin{...}` / `\end{...}`
    pairs in `latex_string`; an empty list means no problem was found.

    An `\end{name}` that does not close the most recently opened
    environment is reported as mismatched, together with its position, and
    leaves the open environments untouched. Environments still open at the
    end are reported as unmatched, innermost first.
    """
    open_environments = []
    problems = []

    for found in re.finditer(r'\\(begin|end)\{([^}]+)\}', latex_string):
        keyword, env_name = found.groups()
        if keyword == 'begin':
            open_environments.append(env_name)
        elif open_environments and open_environments[-1] == env_name:
            open_environments.pop()
        else:
            problems.append(f"Mismatched \\end{{{env_name}}} at position {found.start()}")

    problems.extend(
        f"Unmatched \\begin{{{env_name}}}" for env_name in reversed(open_environments))
    return problems

In [None]:

# Example usage: the trailing \begin{wrongenv} is never closed and must be reported.
unbalanced_latex_code = r"""
\begin{document}
This is a sample document.
\begin{itemize}
    \item First item
    \begin{enumerate}
        \item First sub-item
    \end{enumerate}
    \item Second item
\end{itemize}
\end{document}
\begin{wrongenv}  % This environment is unmatched
"""
assert detect_unbalanced_environments(unbalanced_latex_code)

# The same document without the stray environment is balanced.
latex_code = r"""
\begin{document}
This is a sample document.
\begin{itemize}
    \item First item
    \begin{enumerate}
        \item First sub-item
    \end{enumerate}
    \item Second item
\end{itemize}
\end{document}
"""
unbalanced = detect_unbalanced_environments(latex_code)
assert not unbalanced

In [None]:
#| export
def math_mode_string_is_syntactically_valid(
        text: str,
        ) -> bool:
    """
    Heuristically decide whether `text` is a syntactically valid latex str.

    `text` is stripped of surrounding whitespace and then passed, in order,
    through the helper checks listed below; it is accepted only if none of
    them reports a problem. TeX has syntax rules beyond the scope of these
    checks.

    Some caveats:

    `text` may or may not contain dollar signs `$`; even without dollar signs
    this function may return `True`. If `text` contains unescaped dollar
    signs, it is rejected unless `latex_indices` finds exactly one math mode
    string and that string spans all of (the stripped) `text`.
    """
    stripped = text.strip()
    spans = latex_indices(stripped)
    if _has_unescaped_dollar(stripped):
        # Delimited input must be one single math mode string covering everything.
        if len(spans) != 1:
            return False
        span_start, span_end = spans[0][0], spans[0][1]
        if span_start != 0 or span_end != len(stripped):
            return False
    # Each detector returns something truthy when it finds a problem.
    problem_detectors = (
        lambda s: not _does_not_end_with_script(s),
        _detect_backslash_space_curly,
        lambda s: not _is_balanced_braces(s),
        _has_invalid_left_right_bracket,
        lambda s: not _is_left_right_balanced(s),
        _has_double_script,
        _has_double_script_literal,
        detect_incorrect_latex_commands,
        detect_unbalanced_environments,
    )
    return not any(detect(stripped) for detect in problem_detectors)



In [None]:
_invalid_examples = [
    r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$$',
    r'$x^2 + y^2',
    r'$$x^2 + y^2$',
    r'$$x^2 + y^2$ $',
    r'$hi$$',
    r'{ hi',
    r'\left \right.',
    r'\begin{enumerate}',
    r'$$R=\sum_P\in X\operatorname length\left(\Omega__X / Y\right)_p\cdot P$$',
]
_valid_examples = [
    r'hi',
    r'$hi$',
    r'$\\dim ^ a$',
    r'\{ hi',
    r'\ [',
    r'\left( \right.',
    r'$$\left|\sum_{i=0} \right|$$',
    r'$\\\$$',
    r'\begin{enumerate} asdf \end{enumerate}',
]
for _example in _invalid_examples:
    assert not math_mode_string_is_syntactically_valid(_example), _example
for _example in _valid_examples:
    assert math_mode_string_is_syntactically_valid(_example), _example
# TODO there is something to be considered here; the below
# example would be a syntax error, and yet the functions implemented
# above don't really detect as such.
# assert not detect_incorrect_latex_commands(r'\sideset{_1^2}{_3^4}')

math_mode_string_is_syntactically_valid(r'\text\in')

False

The `math_mode_string_is_syntactically_valid` function heuristically assesses whether a given math mode LaTeX string is syntactically valid. In principle, any LaTeX syntax error caused by the string should be detected by the function.

TODO: consider also detecting the following errors:


Unescaped `%` sign (starts a comment):
`$x = 50% of y$`

Using `\!` (negative space) at the beginning of math mode:
`$\!x + y$`

The following lists some example outputs of the `math_mode_string_is_syntactically_valid` function along with explanations.

Unmatched curly braces are a common syntactical error:

In [None]:
test_eq(math_mode_string_is_syntactically_valid(r'\sqrt{x}}'), False)

However, using `\{` or `\}` does not count towards curly bracket matching:

In [None]:
test_eq(math_mode_string_is_syntactically_valid(r'\{hi'), True)

On the other hand, a backslash `\` followed by one or more spaces and then by a curly bracket is itself invalid syntax:

In [None]:
test_eq(math_mode_string_is_syntactically_valid(r'\ {hi'), False)

`math_mode_string_is_syntactically_valid` will consider the validity of a string whether or not the string has math mode delimiters. 

In [None]:
for _gal in (r'\operatorname{Gal}', r'$\operatorname{Gal}$'):
    assert math_mode_string_is_syntactically_valid(_gal)

However, `math_mode_string_is_syntactically_valid` returns `False` if the string has dollar sign delimiters and either more than one math mode string is detected in it or the delimiters are unbalanced (use `latex_indices` to separate out the math mode strings of a longer text):

In [None]:
# Rejected: more than one math mode string is present.
test_eq(math_mode_string_is_syntactically_valid('$hi$ $bye$'), False)
# Rejected: the math mode delimiter `$` is unbalanced.
test_eq(math_mode_string_is_syntactically_valid(r'$x^2 + y^2'), False)
# Rejected: the math mode delimiters `$$` and `$` do not match.
test_eq(math_mode_string_is_syntactically_valid(r'$$x^2 + y^2$'), False)

In [None]:
#| export
# def math_mode_string_is_syntactically_clean(
#         text: str,
#         ) -> bool:
#     """
#     Return `True` if `text` is syntactically "clean" as a LaTeX math mode str.
    
#     While the precise meaning of this may be subjective, here we will
#     consider `text` to be clean, assuming that it is syntactically valid, if

#     - It does not have double blackslashes
#     """
#     if r'\\' in text:
#         return False

## Tweak a latex string

Sometimes, when autogenerating a LaTeX string through an ML model, minor formatting eyesores occur, such as a curly bracket `{` or an underscore `_` followed by an unnecessary space. We provide some functions to fix such formatting.

In [None]:
#| export
def reduce_unnecessary_spaces(
        text: str,
        ) -> str:
    """
    Return `text` with whitespace that is unnecessary for reading it as a
    LaTeX string removed.

    Two passes are made, in this order:

    1. whitespace right after any of `{`, `_`, `^`, a backslash, `(` or `)`
       is deleted;
    2. whitespace right before any of `}`, `_`, `^`, `(` or `)` is deleted.

    Note that this also applies outside of math mode.
    """
    strip_after = r'([{_^\\()])\s+'
    strip_before = r'\s+([}_^()])'
    for pattern in (strip_after, strip_before):
        text = re.sub(pattern, r'\1', text)
    return text

In [None]:

# Spaces right *after* a backslash are removed too, which turns `\  operatorname`
# back into the command `\operatorname` (this might not always be desirable).
_reduce_cases = [
    (r'something something \  operatorname', r'something something \operatorname'),
    (r'\operatorname{Res}  ^ G_ H (R)', r'\operatorname{Res}^G_H(R)'),
    (r'\operatorname{Res}^{ G}_{ H } (R)', r'\operatorname{Res}^{G}_{H}(R)'),
    (r'M_{ f}', r'M_{f}'),
    (r'h_{ p}', r'h_{p}'),
    (r'\zeta (s)', r'\zeta(s)'),
    (r'\mathcal{ H} _{ v}', r'\mathcal{H}_{v}'),
]
for _raw, _expected in _reduce_cases:
    test_eq(reduce_unnecessary_spaces(_raw), _expected)

#### Make fixes to summary

In [None]:
#| export
def fix_autogen_formatting(
        text: str
        ) -> str:
    r"""
    Fix some latex formatting issues in an autogenerated text.

    The fixes are applied in this order:

    1. every backslash followed by a space becomes a bare backslash, so that
       e.g. `\ to` becomes `\to`;
    2. `{ ` becomes `{` and ` }` becomes `}`;
    3. whitespace just inside math mode delimiters is removed, e.g.
       `$ x $` becomes `$x$`;
    4. `reduce_unnecessary_spaces` is applied;
    5. spaces/newlines are inserted around math mode strings by
       `_insert_newline_or_spaces_around_latex`.
    """
    text = text.replace(r'\ ', '\\')
    text = text.replace(r'{ ', r'{')
    text = text.replace(r' }', r'}')
    # The backreference makes the closing delimiter equal to the opening one
    # (`$` or `$$`). Otherwise the closing `$` of one `$$...$$` block and the
    # opening `$` of the next could be read as a `$...$` pair, and the spaces
    # around the prose between them would be stripped.
    text = re.sub(r'(\${1,2})\s*([^$]+?)\s*\1', r'\1\2\1', text)
    # TODO: if the replacement of r'\ ' by '\\' happens to
    # make `\` stick to the previous chunk of things
    # (e.g. r'd\in\mathbb{Z}_{\geq 0}`, then give it some
    # space, e.g. r'd \in \mathbb{Z}_{\geq 0}'.
    text = reduce_unnecessary_spaces(text)
    text = _insert_newline_or_spaces_around_latex(text)
    return text


def _insert_newline_or_spaces_around_latex(
        text: str
        ) -> str:
    """
    Insert spaces or newlines around latex math mode strings inside `text`
    if necessary.

    - An inline math mode string (starting with exactly one `$`) gets a space
      before/after it unless it is already adjacent to a space or to the
      start/end of `text`.
    - A display math mode string (starting with `$$`) gets newlines before
      and after it so that, counting the newlines already present, it is
      separated from the neighboring content by two newlines (nothing is
      added at the start/end of `text`).
    """
    math_mode_indices = latex_indices(text)
    replacements = []
    for start, end in math_mode_indices:
        math_mode = text[start:end]
        if not math_mode.startswith('$$'):  # starts with exactly one $
            padded = math_mode
            if start != 0 and text[start-1] != ' ':
                padded = f' {padded}'
            if end != len(text) and text[end] != ' ':
                padded = f'{padded} '
            replacements.append(padded)
            continue
        # Display math: top up the newlines in front to two...
        if start != 0 and text[start-1] != '\n':
            front_newline_count = 2
        elif start > 1 and text[start-2] != '\n':
            front_newline_count = 1
        else:
            front_newline_count = 0
        # ...and symmetrically behind it. The character checked in the second
        # branch is the one *after* the existing newline, i.e. `text[end+1]`.
        if end != len(text) and text[end] != '\n':
            back_newline_count = 2
        elif end < len(text) - 1 and text[end+1] != '\n':
            back_newline_count = 1
        else:
            back_newline_count = 0
        replacements.append(
            '\n'*front_newline_count + math_mode + '\n'*back_newline_count)
    text = replace_string_by_indices(text, math_mode_indices, replacements)
    # Adjacent math mode strings may each have received padding; collapse it.
    text = text.replace('$  $', '$ $')
    text = text.replace('$$\n\n\n\n$$', '$$\n\n$$')
    text = text.replace('$$\n\n\n$$', '$$\n\n$$')
    return text

In [None]:
#| hide
_spacing_cases = [
    # inline math mode
    ('$hi$', '$hi$'),
    ('$hi$asdf', '$hi$ asdf'),
    ('asdf$hi$', 'asdf $hi$'),
    ('asdf$hi$asdf', 'asdf $hi$ asdf'),
    # display math mode
    ('$$hi$$', '$$hi$$'),
    ('asdf$$hi$$', 'asdf\n\n$$hi$$'),
    ('$$hi$$asdf', '$$hi$$\n\nasdf'),
    ('asdf$$hi$$asdf', 'asdf\n\n$$hi$$\n\nasdf'),
    # consecutive math mode strings
    ('$hi$ $hi$', '$hi$ $hi$'),
    ('$hi$$hi$', '$hi$ $hi$'),
    # ('$$hi$$ $$hi$$', '$$hi$$\n\n$$hi$$'),  # currently not handled
    ('$$hi$$$$hi$$', '$$hi$$\n\n$$hi$$'),
]
for sample_text, _expected in _spacing_cases:
    test_eq(_insert_newline_or_spaces_around_latex(sample_text), _expected)

Currently, the model is inclined to decode and format its summarizations in such a way that creates formatting issues either for LaTeX or `Obsidian.md`. For example, the model would output a str containing

- `\ <command_name>` instead of `\<command_name>`
- `{ ` when `{` is preferable
- `$ <latex_string> $` when `$<latex_string>$` is needed for `Obsidian.md`.

The `fix_autogen_formatting` function attempts to get around some of these issues.

In [None]:
sample_output = fix_autogen_formatting(r'\ to')
assert r'\to' in sample_output

text = r'$d\ in\ mathbb{ Z}_{\ geq 0} $'
sample_output = fix_autogen_formatting(text)
for _fragment in (r'\in', r'\mathbb{Z}', r'\geq 0'):
    assert _fragment in sample_output


In [None]:
text = r'There are some extra spaces in this math mode string: $  5 + 7 = 12 $.'
sample_output = fix_autogen_formatting(text)
print(sample_output)
# The whitespace just inside the dollar signs is gone.
for _fragment in (r'$5', r'12$'):
    assert _fragment in sample_output

There are some extra spaces in this math mode string: $5 + 7 = 12$ .


In [None]:
# `fix_autogen_formatting` does not repair the malformed `I_\G}` below;
# see `correct_latex_syntax_error` for that.
text = r'the group of $G$-coinvariants of $A$. It is defined as $$A_{G} :=A / I_\G} A$$'
sample_output = fix_autogen_formatting(text)
print(sample_output)

the group of $G$ -coinvariants of $A$ . It is defined as 

$$A_{G} :=A / I_\G} A$$


## Correct syntax errors in autogenerated math mode strings

In [None]:
#| export
def _tokenize_latex_math(
        latex_string: str
        ) -> list[str]:
    """
    Tokenize `latex_string` by the following principles:

    1. A latex command/macro invoked (but not the inputs) is a token.
    2. the special characters ^ { } _ are tokens.
    3. groups of consecutive whitespaces are tokens.
    4. afterwards, all "words" (one or more consecutive non-whitespace non-special characters) are tokens.
    """
    # Define the regex pattern for tokenization
    pattern = r"""
        (\\[a-zA-Z]+)        # Match LaTeX commands (e.g., \alpha, \sum)
        | ([^\\\s^{}_]+)     # Match words (consecutive non-whitespace, non-special characters)
        | ([^\\\s])          # Match special characters (including ^, {, }, _, etc.)
        | (\s+)              # Match groups of consecutive whitespace
    """
    # Use re.findall to find all matches based on the pattern
    tokens = re.findall(pattern, latex_string, re.VERBOSE)
    # Extract the matched groups, filtering out empty strings
    token_list = [token for group in tokens for token in group if token]
    return token_list


In [None]:
#| hide
# Example usage
latex_string = r"\alpha + \beta^{2} - \gamma_{1} + 3 \times \text{some text}"
tokens = _tokenize_latex_math(latex_string)
_expected_tokens = [
    '\\alpha', ' ', '+', ' ', '\\beta', '^', '{', '2', '}', ' ', '-', ' ',
    '\\gamma', '_', '{', '1', '}', ' ', '+', ' ', '3', ' ', '\\times', ' ',
    '\\text', '{', 'some', ' ', 'text', '}',
]
test_eq(_expected_tokens, tokens)
# Concatenating the tokens gives back the original string.
test_eq(''.join(tokens), latex_string)

In [None]:
#| export
def _list_of_candidates_from_math_mode_strings(
        main_content: str, # A text of LaTeX code. In practice, this should be the `main content` of an information note, cf. `summarize_notation`.
        syntax_validation: Callable[[str], bool] = math_mode_string_is_syntactically_valid # A test to tell whether a math mode string is syntactically valid.
        ) -> set[str]:
    """
    Return the syntactically valid substrings of the latex math mode strings
    in `main_content`.

    Each math mode string found by `latex_indices` is stripped of its `$`
    delimiters and tokenized with `_tokenize_latex_math`. For every sublist
    of tokens produced by `sublist_generator`, the joined substring is kept
    (with surrounding whitespace stripped) if `syntax_validation` accepts it.
    Substrings that are empty after stripping are skipped, since they are
    useless as replacement candidates.

    None of the elements in the output have delimiters (`$`, `$$`).
    """
    candidates: set[str] = set()
    for start, end in latex_indices(main_content):
        latex_str = main_content[start:end].strip('$')
        tokenization = _tokenize_latex_math(latex_str)
        for sublist in sublist_generator(tokenization):
            substring = ''.join(sublist)
            stripped = substring.strip()
            # Whitespace-only substrings would otherwise contribute `''`.
            if stripped and syntax_validation(substring):
                candidates.add(stripped)
    return candidates

In [None]:
#| hide
output = _list_of_candidates_from_math_mode_strings(r'$\operatorname{Gal}(L/K)$', math_mode_string_is_syntactically_valid)
for _expected_candidate in (r'\operatorname{Gal}', r'Gal'):
    assert _expected_candidate in output

output = _list_of_candidates_from_math_mode_strings(
    r'$\operatorname{Gal}(L/K) \to G_\ell^\infty$',
    math_mode_string_is_syntactically_valid)

In [None]:
#| hide
_signum_main_content = r'the signum of the complete factorization $\\text\\in S_n$ into disjoint cycles. It is defined by$$\\operatorname sgn(\\left )=(-1)n-t .$$'
_list_of_candidates_from_math_mode_strings(_signum_main_content)

{'',
 ')=(-1)n-t',
 ')=(-1)n-t .',
 '.',
 'S',
 'S_n',
 '\\in',
 '\\in S',
 '\\in S_n',
 '\\operatorname',
 '\\operatorname sgn(',
 '_n',
 'n',
 'sgn('}

In [None]:
#| export
def _find_closest_match(
        math_mode_text: str,
        replacement_candidates: list[str]
        ) -> Union[str, None]:
    """
    Return the element of `replacement_candidates` with the smallest
    Levenshtein distance to `math_mode_text`, or `None` if there are no
    candidates. Ties go to the candidate that comes first in iteration order.

    This is a helper function to `correct_latex_syntax_error`.
    """
    if not replacement_candidates:
        return None
    return min(
        replacement_candidates,
        key=lambda candidate: distance(math_mode_text, candidate))

In [None]:
#| hide
assert _find_closest_match('hi', ['hib', 'basdy']) == 'hib'

In [None]:
#| export
def correct_latex_syntax_error(
        summary: str, # The autogenerated summary
        replacement_candidates: list[str], # A list of candidates to replace. This is expected to be an output of `_list_of_candidates_from_math_mode_strings`
        # min_length_to_replace_math_mode_string: int = 5, # The minimum length that a math mode string needs to be (exclusing delimiting dollar signs `$`, `$$`) in summary in order to be considered for replacement.
        syntax_validation: Callable[[str], bool] = math_mode_string_is_syntactically_valid # A test to tell whether a math mode string is syntactically valid.
        ) -> str:
    """
    Return `summary` with each syntactically incorrect latex math mode
    string replaced by the most closely resembling element of
    `replacement_candidates`.

    A math mode string (as found by `latex_indices`) is kept as is if
    `syntax_validation` accepts it, or if `replacement_candidates` is empty.
    Otherwise its content, i.e. the text without the `$`/`$$` delimiters
    and surrounding whitespace, is compared by Levenshtein distance with the
    candidates (which carry no delimiters either). The closest candidate is
    then wrapped in the original delimiter.

    TODO: consider the possibility that not all math mode str delimiters
    are formatted correctly.
    """
    math_mode_indices = latex_indices(summary)
    replacements = []
    for start, end in math_mode_indices:
        math_mode_text = summary[start:end]
        if syntax_validation(math_mode_text) or not replacement_candidates:
            replacements.append(math_mode_text)
            continue
        delimiter = '$$' if math_mode_text.startswith('$$') else '$'
        # Compare like with like: the candidates are delimiter-free.
        content = math_mode_text.strip('$').strip()
        replacement = _find_closest_match(content, replacement_candidates)
        replacements.append(f'{delimiter}{replacement}{delimiter}')
    return replace_string_by_indices(summary, math_mode_indices, replacements)



In [None]:
sample_summary = r'the group of $G$-coinvariants of $A$. It is defined as $$A_{G} :=A / I_\G} A$$'
replacement_candidates = [
    'A',
    'A_',
    'A_{G}',
    'A_{G}:=A',
    'A_{G}:=A',
    'A_{G}:=A /',
    'A_{G}:=A / I_{G}',
    'A_{G}:=A / I_{G} A',
    'H_{0}(G, A)',
    'H_{0}(G, A) \\simeq',
    'H_{0}(G, A) \\simeq A',
    'H_{0}(G, A) \\simeq A_',
    'H_{0}(G, A) \\simeq A_{G}',
]
# Only the malformed display math string is replaced; `$G$` and `$A$` are kept.
_expected_summary = r'the group of $G$-coinvariants of $A$. It is defined as $$A_{G}:=A / I_{G} A$$'
test_eq(correct_latex_syntax_error(sample_summary, replacement_candidates), _expected_summary)

## Augment latex text

For data augmentation, it can be useful to introduce latex typos intentionally. The following functions do so.

### Modify just latex str

In [None]:
#| export

# Font style command names (without the leading backslash) targeted by the
# random font style augmentations.
FONT_STYLE_COMMANDS = [
    "mathscr", "mathcal", "mathfrak", "mathbb",
    "mathbf", "mathrm", "operatorname", "text",
]
# Less common font style command names.
UNCOMMON_FONT_STYLE_COMMANDS = ["mathit", "mathsf", "mathtt"]
# Possible mapping to commonly confused font styles (currently unused):
# COMMON_FONT_STYLE_TYPOS = {
#     "mathscr": {"mathcal", "mathfrak"},
#     "mathcal": {"mathscr"},
#     "mathrm": {"operatorname"},
#     "mathrmfrak": {"mathcal", "mathbf"}
# }





def modify_at_random(
        latex_string: str, # A latex str, surrounded by dollar signs (either single or double) as necessary.
        pattern: Union[str,re.Pattern],
        chance: float, # The chance that each change is performed
        replace_func: Callable[[re.Match, float], str],
        seed: Optional[int] = None
    ) -> str:
    """
    Replace every match of `pattern` in `latex_string` by
    `replace_func(match, chance)` and return the result.

    `replace_func` decides whether and how each match is modified; `chance`
    is passed to it unchanged. If `seed` is given, the global `random`
    module is seeded with it first so that results are reproducible.
    """
    if seed is not None:
        random.seed(seed)

    def _substitute(match: re.Match) -> str:
        return replace_func(match, chance)

    return re.sub(pattern, _substitute, latex_string)



In [None]:
#| export
def remove_font_styles_at_random(
        latex_string: str, # A latex str, surrounded by dollar signs (either single or double) as necessary.
        p: float = 0.05, # The chance that each font styling command is removed
        seed: Optional[int] = None
        ) -> str:
    r"""
    Randomly strip font style commands from `latex_string`, keeping their
    arguments.

    Each occurrence of `\<command>{<content>}` (optionally with whitespace
    before the brace), where `<command>` is in `FONT_STYLE_COMMANDS`, is
    independently replaced by `<content>` with probability `p`. `<content>`
    may contain at most one level of nested braces. `seed` is forwarded to
    `modify_at_random`.
    """
    all_commands = FONT_STYLE_COMMANDS # + UNCOMMON_FONT_STYLE_COMMANDS
    command_alternation = '|'.join(all_commands)
    # group(1): the command name; group(2): the braced argument's content
    pattern = r'\\(' + command_alternation + r')\s*\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}'

    def replace_func(match, p: float):
        return match.group(2) if random.random() < p else match.group(0)

    return modify_at_random(latex_string, pattern, p, replace_func, seed)

In [None]:
# Test 1: every command is removed when p=1
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result = remove_font_styles_at_random(latex, p=1.0, seed=42)
test_eq(result, "$Bold and Calligraphic$")

# Test 2: nothing is removed when p=0
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result = remove_font_styles_at_random(latex, p=0.0, seed=42)
test_eq(result, latex)

# Test 3: with p=0.5 (and this seed), some but not all commands are removed
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic} and \mathfrak{Fraktur}$"
result = remove_font_styles_at_random(latex, p=0.5, seed=42)
test_ne(result, latex)
test_ne(result, "$Bold and Calligraphic and Fraktur$")

# Test 4: nested commands (disabled). A single pass only strips the
# outermost command, so this currently gives r"$\mathcal{Nested}$".
# latex = r"$\mathbf{\mathcal{Nested}}$"
# result = remove_font_styles_at_random(latex, p=1.0, seed=42)
# test_eq(result, "$Nested$")

# Test 5: uncommon commands (disabled; they are not among the removed commands)
# latex = r"$\mathtt{Typewriter} and \mathsf{Sans Serif}$"
# result = remove_font_styles_at_random(latex, p=1.0, seed=42)
# test_eq(result, "$Typewriter and Sans Serif$")

# Test 6: \text and \operatorname
latex = r"$\text{Plain text} and \operatorname{sin}(x)$"
result = remove_font_styles_at_random(latex, p=1.0, seed=42)
test_eq(result, "$Plain text and sin(x)$")

# Test 7: no commands present
latex = "$x + y = z$"
result = remove_font_styles_at_random(latex, p=1.0, seed=42)
test_eq(result, latex)

# Test 8: display math delimiters
latex = r"$$\mathbf{Equation}: E = mc^2$$"
result = remove_font_styles_at_random(latex, p=1.0, seed=42)
test_eq(result, "$$Equation: E = mc^2$$")

# Test 9: the same seed gives the same result
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result1 = remove_font_styles_at_random(latex, p=0.5, seed=42)
result2 = remove_font_styles_at_random(latex, p=0.5, seed=42)
test_eq(result1, result2)

# Test 10: these two seeds give different results
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result1 = remove_font_styles_at_random(latex, p=0.5, seed=42)
result2 = remove_font_styles_at_random(latex, p=0.5, seed=43)
test_ne(result1, result2)


In [None]:
#| export

def change_font_styles_at_random(
        latex_string: str,
        p: float = 0.1,
        seed: Optional[int] = None
        ) -> str:
    r"""
    Randomly swap font style commands in `latex_string` for other ones.

    Each occurrence of `\<command>{<content>}` (optionally with whitespace
    before the brace), where `<command>` is in `FONT_STYLE_COMMANDS`, is
    independently rewritten with probability `p` as `\<other>{<content>}`.
    `<other>` is a different element of `FONT_STYLE_COMMANDS`, chosen
    uniformly at random. `seed` is forwarded to `modify_at_random`.
    """
    all_commands = FONT_STYLE_COMMANDS # + UNCOMMON_FONT_STYLE_COMMANDS
    pattern = r'\\(' + '|'.join(all_commands) + r')\s*\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}'

    def replace_func(match, p: float):
        if not random.random() < p:
            return match.group(0)
        current_command, content = match.group(1), match.group(2)
        alternatives = [cmd for cmd in all_commands if cmd != current_command]
        return '\\' + random.choice(alternatives) + '{' + content + '}'

    return modify_at_random(latex_string, pattern, p, replace_func, seed)

In [None]:

# Test 1: a single command is always changed when p=1
latex = r"$\mathbf{Bold}$"
result = change_font_styles_at_random(latex, p=1.0, seed=42)
test_ne(result, latex)
assert re.match(r"\$\\\w+{Bold}\$", result) is not None

# Test 2: nothing changes when p=0
latex = r"$\mathcal{Calligraphic}$"
result = change_font_styles_at_random(latex, p=0.0, seed=42)
test_eq(result, latex)

# Test 3: several commands
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result = change_font_styles_at_random(latex, p=1.0, seed=42)
test_ne(result, latex)
assert re.match(r"\$\\\w+{Bold} and \\\w+{Calligraphic}\$", result) is not None

# Test 4: nested commands (only the outer command is matched in one pass)
latex = r"$\mathbf{\mathcal{Nested}}$"
result = change_font_styles_at_random(latex, p=1.0, seed=42)
test_ne(result, latex)
assert re.match(r"\$\\\w+{\\\w+{Nested}}\$", result) is not None

# Test 5: uncommon commands (disabled; they are not among the changed commands)
# latex = r"$\mathtt{Typewriter}$"
# result = change_font_styles_at_random(latex, p=1.0, seed=42)
# test_ne(result, latex)
# assert re.match(r"\$\\\w+{Typewriter}\$", result) is not None

# Test 6: the same seed gives the same result
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result1 = change_font_styles_at_random(latex, p=0.5, seed=42)
result2 = change_font_styles_at_random(latex, p=0.5, seed=42)
test_eq(result1, result2)

# Test 7: these two seeds give different results
latex = r"$\mathbf{Bold} and \mathcal{Calligraphic}$"
result1 = change_font_styles_at_random(latex, p=0.5, seed=42)
result2 = change_font_styles_at_random(latex, p=0.5, seed=43)
test_ne(result1, result2)


In [None]:
#| export
# List of Greek letters in LaTeX
GREEK_LETTERS = [
    'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', 'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu', 
    'nu', 'xi', 'omicron', 'pi', 'rho', 'sigma', 'tau', 'upsilon', 'phi', 'chi', 'psi', 'omega',
    'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi', 'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega'
]

def change_greek_letters_at_random(
        latex_string: str,
        p: float = 0.05,
        seed: Optional[int] = None,
) -> str:
    r"""
    Randomly change Greek letters in `latex_string`.

    Each occurrence of `\<letter>` with `<letter>` in `GREEK_LETTERS` is
    independently replaced, with probability `p`, by a different element of
    `GREEK_LETTERS`, chosen uniformly at random. A command only counts if its
    name is not followed by another letter, so `\pitchfork` is left alone
    while the `\alpha` in `\alpha_1` or the `\pi` in `\pi2` can be changed.

    If `seed` is given, the global `random` module is seeded with it first
    so that results are reproducible.
    """
    if seed is not None:
        random.seed(seed)

    def replace_func(match, p: float):
        if random.random() < p:
            current_letter = match.group(1)
            new_letter = random.choice([l for l in GREEK_LETTERS if l != current_letter])
            return f"\\{new_letter}"
        return match.group(0)

    # A LaTeX command name ends at the first non-letter character. `\b` is not
    # suitable here because `_` and digits are word characters, so e.g. the
    # `\alpha` in `\alpha_1` would never match.
    pattern = r'\\(' + '|'.join(GREEK_LETTERS) + r')(?![a-zA-Z])'
    return re.sub(pattern, lambda match: replace_func(match, p), latex_string)


In [None]:
# Each case: (latex str, probability, whether the output must differ)
_greek_cases = [
    (r"$\alpha + \beta = \gamma$", 1.0, True),        # basic functionality
    (r"$\delta \times \epsilon$", 0.0, False),        # no change when p=0
    (r"$f(x) = \theta x + \phi$", 1.0, True),         # mixed content
    (r"$\Gamma(x) + \Delta y = \Omega$", 1.0, True),  # uppercase Greek letters
    (r"$\alpha + \alpha = 2\alpha$", 1.0, True),      # repeated occurrences
]
for latex, _p, _should_change in _greek_cases:
    result = change_greek_letters_at_random(latex, p=_p)
    if _should_change:
        test_ne(result, latex)
    else:
        test_eq(result, latex)


In [None]:
#| export
# def push_dollar_signs_surrounding_latex(
#         latex_string: str, # A latex str, surrounded by dollar signs (either single or double) as necessary.
#         remove_pushed_out_font_style_command: bool = True
#         ) -> str:
#     """
#     Modify `latex_string` so that in effect, dollar signs are 
#     """
#     return ""

### Modify latex str 

In [None]:
#| export
def random_char_modification(text, p=0.05):
    """
    Randomly change characters in `text`.

    Each character of `text` is considered exactly once. With probability
    `p`, one of the following actions, chosen uniformly at random, is
    applied to it:

    - 'delete': drop the character;
    - 'add': insert a random character (ASCII letter, digit, punctuation or
      space) just before it;
    - 'modify': replace it with such a random character.
    """
    all_chars = string.ascii_letters + string.digits + string.punctuation + ' '
    # Build a new list instead of inserting into the list being indexed:
    # an in-place insert shifts the current character to the next index, so
    # it would be processed again and the last characters never visited.
    result = []
    for char in text:
        if random.random() < p:
            action = random.choice(['delete', 'add', 'modify'])
            if action == 'add':
                result.append(random.choice(all_chars))
                result.append(char)
            elif action == 'modify':
                result.append(random.choice(all_chars))
            # 'delete': append nothing
        else:
            result.append(char)
    return ''.join(result)


def dollar_sign_manipulation(text, p=0.05):
    """
    Randomly delete or move dollar signs (which are usually there for latex
    math mode) in `text`, while preserving all whitespace characters.

    `text` is split into whitespace runs and non-whitespace tokens; the
    tokens containing `$` are determined once, up front. Each such token is
    affected with probability `p`, and then one of two actions is chosen
    uniformly at random:

    - 'delete': all `$` characters are removed from the token;
    - 'move': all `$` characters are removed from the token, and a single
      `$` is appended to another randomly chosen `$`-containing token. If
      there is no other such token, this is the same as 'delete'.
    """
    # The capturing group keeps the whitespace runs as separate pieces.
    pieces = re.split(r'(\s+)', text)
    dollar_slots = [
        index for index, piece in enumerate(pieces)
        if '$' in piece and not piece.isspace()
    ]
    for index in dollar_slots:
        if not random.random() < p:
            continue
        action = random.choice(['delete', 'move'])
        if action == 'move' and len(dollar_slots) > 1:
            others = [slot for slot in dollar_slots if slot != index]
            pieces[random.choice(others)] += '$'
        pieces[index] = pieces[index].replace('$', '')
    return ''.join(pieces)

# def dollar_sign_manipulation(text, p=0.05):
#     """
#     Either delete or move dollar signs (which are usually there for latex math mode) from `text`.
#     """
#     words = text.split()
#     for i in range(len(words)):
#         if '$' in words[i] and random.random() < p:
#             action = random.choice(['delete', 'move'])
#             if action == 'delete':
#                 words[i] = words[i].replace('$', '')
#             else:
#                 new_pos = random.randint(0, len(words) - 1)
#                 words[new_pos] += '$'
#                 words[i] = words[i].replace('$', '')
#     return ' '.join(words)

def remove_math_keywords(text, p=0.05):
    """
    Randomly remove mentions such as "Theorem 3.2" or "Definition 1.4.5"
    from `text`.

    A mention is one of the keywords Definition, Remark, Proposition,
    Exercise, Example, Theorem, Lemma or Corollary, followed by whitespace
    and a dotted label with two to four parts (e.g. `3.2` or `A.1.4`). Each
    mention is removed independently with probability `p`.
    """
    mention_pattern = r"(Definition|Remark|Proposition|Exercise|Example|Theorem|Lemma|Corollary)\s+\w+(\.\w+){1,3}"
    return re.sub(
        mention_pattern,
        lambda match: '' if random.random() < p else match.group(0),
        text)

def random_word_removal(text, p=0.05):
    """
    Randomly remove words from `text` while preserving all whitespace
    characters.

    `text` is split into whitespace runs and the non-whitespace words
    between them. Whitespace is always kept; each word is kept only if
    `random.random() > p`.
    """
    pieces = re.split(r'(\s+)', text)
    # `random.random()` is only drawn for non-whitespace pieces.
    return ''.join(
        piece for piece in pieces
        if not piece.strip() or random.random() > p
    )

# def random_word_removal(text, p=0.05):
#     """
#     Randomly remove words
#     """
#     words = text.split()
#     return ' '.join(word for word in words if random.random() > p)

def random_latex_command_removal(text, p=0.1):
    r"""
    Randomly remove latex commands from `text`.

    Every command `\name` (a backslash followed by letters), together with
    an immediately following brace group without a `}` inside it, is
    removed unless `random.random() > p`.
    """
    command_pattern = re.compile(r'\\[a-zA-Z]+(\{[^}]*\})?')

    def _maybe_drop(match):
        if random.random() > p:
            return match.group(0)
        return ''

    return command_pattern.sub(_maybe_drop, text)


In [None]:
#| export
# TODO: this function is still heuristic: escaped dollar signs (`\$`), the
# `\(...\)` / `\[...\]` delimiters, and braced arguments that contain spaces
# (e.g. `\mathbf{a b}`, which can be split by a push) are not handled.
def push_dollar_signs(
        latex: str,
        p: float = 0.1, # Push probability
        seed: Optional[int] = None, # If not `None`, `random.seed(seed)` is called before any pushing.
        return_indices_of_math_mode_content: bool = False # If `True`, additionally return a `list[tuple[int, int]]` of indices within the outputted `str` signifying the location of what was essentially the content of the original math mode.
        ) -> Union[str, tuple[str, list[tuple[int, int]]]]:
    """
    Push dollar signs delimiting math mode into each other at random within a text.

    Groups of one or two dollar signs are paired up in order of appearance
    (a trailing unpaired group is ignored). Each pair is selected with
    probability `p`. For a selected pair, either the opening or the closing
    delimiter (chosen uniformly at random) is moved inward to the next word
    boundary, so the first/last "word" of the math content ends up outside
    of math mode. Font style commands (`FONT_STYLE_COMMANDS` and
    `UNCOMMON_FONT_STYLE_COMMANDS`) are unwrapped in the pushed-out text.

    A push that would move a delimiter onto or past its partner (e.g. when the
    math content contains no whitespace in the chosen direction) is skipped,
    because it would otherwise corrupt the pairing of the delimiters.

    Returns the modified text, or `(text, indices)` if
    `return_indices_of_math_mode_content` is `True`. Each `(start, end)` in
    `indices` spans both delimiters and the pushed-out text of one pushed pair.
    """
    if seed is not None:
        random.seed(seed)

    font_commands = FONT_STYLE_COMMANDS + UNCOMMON_FONT_STYLE_COMMANDS 
    font_commands = [rf'\{font_command}' for font_command in font_commands]
    
    # Find all single and double dollar sign positions
    dollar_positions = [(m.start(), m.end()) for m in re.finditer(r'\${1,2}', latex)]
    
    # Ensure we have an even number of dollar sign groups
    if len(dollar_positions) % 2 != 0:
        dollar_positions = dollar_positions[:-1]

    new_indices_of_math_mode_content: list[tuple[int, int]] = []
    for i in range(0, len(dollar_positions), 2):
        # `dollar_positions` is recomputed after every push; guard against it
        # ever holding fewer groups than the initial pairing did.
        if i + 1 >= len(dollar_positions):
            break
        start, end = dollar_positions[i], dollar_positions[i+1]
        if random.random() >= p:
            continue
        # Decide which dollar sign group to push
        push_start = random.choice([True, False])
        
        if push_start:
            new_start = push_dollar_sign(latex, start[1], direction='right')
            if new_start >= end[0]:
                # No word boundary strictly inside the math content: pushing would
                # move the opening delimiter onto/past the closing one.
                continue
            dollar_length = start[1] - start[0]
            new_starting_part = latex[:start[0]]
            pushed_out_math_mode_part = remove_split_commands(latex[start[1]:new_start], font_commands).rstrip() + ' ' + '$'*dollar_length
            new_ending_part = latex[new_start:].lstrip()
            latex = new_starting_part + pushed_out_math_mode_part + new_ending_part 
            # Span: from where the pushed-out text now starts to just past the closing delimiter.
            new_indices_of_math_mode_content.append(
                (start[0], start[0] + len(pushed_out_math_mode_part) + new_ending_part.index('$'*dollar_length) + dollar_length))
        else:
            new_end = push_dollar_sign(latex, end[0], direction='left')
            if new_end <= start[1]:
                # No word boundary strictly inside the math content: pushing would
                # move the closing delimiter onto/past the opening one.
                continue
            dollar_length = end[1] - end[0]
            new_starting_part = latex[:new_end].rstrip()
            pushed_out_math_mode_part = '$'*dollar_length + ' ' + remove_split_commands(latex[new_end:end[0]], font_commands).lstrip()
            new_ending_part = latex[end[1]:]
            latex = new_starting_part + pushed_out_math_mode_part + new_ending_part
            # Span: from the opening delimiter to the end of the pushed-out text.
            last_dollar_in_new_starting_part = new_starting_part.rindex('$'*dollar_length)
            new_indices_of_math_mode_content.append(
                (last_dollar_in_new_starting_part, len(new_starting_part) + len(pushed_out_math_mode_part))
                )
    
        # Positions after this pair have shifted; recompute them for the next pair.
        dollar_positions = [(m.start(), m.end()) for m in re.finditer(r'\${1,2}', latex)]

    if return_indices_of_math_mode_content:
        return latex, new_indices_of_math_mode_content
    else:
        return latex

def push_dollar_sign(latex: str, pos: int, direction: str) -> int:
    """
    Return the index a math mode delimiter at `pos` is pushed to.

    - `direction == 'right'`: find the first space (' ') at or after `pos` and
      return the index of the first non-whitespace character after it, or
      `len(latex)` if there is no such space or the text ends in whitespace.
    - any other `direction` (treated as left): find the last space before `pos`
      and return the index just past the last non-whitespace character before
      it, or `0` if there is no such space or only whitespace precedes it.
    """
    if direction == 'right':
        space_idx = latex.find(' ', pos)
        if space_idx == -1:
            return len(latex)
        tail = latex[space_idx:]
        # Skip the whole whitespace run that starts at the space.
        return space_idx + len(tail) - len(tail.lstrip())

    space_idx = latex.rfind(' ', 0, pos)
    if space_idx == -1:
        return 0
    # Back up over any whitespace immediately preceding the space.
    return len(latex[:space_idx].rstrip())

def remove_split_commands(latex: str, commands: List[str]) -> str:
    """
    Unwrap font style commands, keeping only their argument.

    For each command in `commands` (applied in order), every occurrence of
    `<command>{arg}`, optionally with whitespace before the brace, is replaced
    by `arg`. The argument ends at the first closing brace, so nested braces
    are not handled.
    """
    result = latex
    for command in commands:
        wrapped_command = re.compile(re.escape(command) + r'\s*\{([^}]*)\}')
        result = wrapped_command.sub(r'\1', result)
    return result

In [None]:
# Push the delimiters of an inline (`$`) and a display (`$$`) math mode string.
# Both calls reseed with 17 and contain a single pair, so the same push decision is made.
latex1 = r"This is $\mathbf{bold} and \mathrm{roman}$ text"
latex2 = r"This is $$\mathbf{bold} and \mathrm{roman}$$ text"
result1 = push_dollar_signs(latex1, p=0.7, seed=17)
result2 = push_dollar_signs(latex2, p=0.7, seed=17)
print(result1)
print(result2)
assert result1 == r"This is $\mathbf{bold} and$ roman text"
assert result2 == r"This is $$\mathbf{bold} and$$ roman text"

This is $\mathbf{bold} and$ roman text
This is $$\mathbf{bold} and$$ roman text


In [None]:

# With p=1.0 the only math mode pair is always pushed; the seed decides which
# delimiter moves, and hence which font command gets unwrapped.
latex = r"This is $\mathbf{bold} and \mathrm{roman}$ text"
for seed, unwrapped_command, surviving_command in ((42, r'\mathbf', r'\mathrm'),
                                                   (17, r'\mathrm', r'\mathbf')):
    result = push_dollar_signs(latex, p=1.0, seed=seed)
    print(result)
    assert unwrapped_command not in result
    assert surviving_command in result

This is bold $and \mathrm{roman}$ text
This is $\mathbf{bold} and$ roman text


In [None]:
# Sub/superscripts attached to a pushed-out font command stay with the unwrapped text.
latex = r"This is $\mathbf{bold}_a^b and \mathrm{roman}$ text"
result = push_dollar_signs(latex, p=1.0, seed=42)
print(result)
assert result == r"This is bold_a^b $and \mathrm{roman}$ text"

This is bold_a^b $and \mathrm{roman}$ text


In [None]:
# Without whitespace after `\mathbf{bold}`, the whole run `\mathbf{bold}and` is pushed out.
latex = r"This is $\mathbf{bold}and \mathrm{roman}$ text"
result = push_dollar_signs(latex, p=1.0, seed=42)
print(result)
assert result == r"This is boldand $\mathrm{roman}$ text"

This is boldand $\mathrm{roman}$ text


In [None]:
latex = r"This is $\mathbf{bold}_a^b and \mathrm{roman}$ text; This is $$\mathbf{bold}_a^b and \mathrm{roman}$$ text"
result, original_math_mode_content_indices = push_dollar_signs(latex, p=1.0, seed=17, return_indices_of_math_mode_content=True)
print(result)
print(original_math_mode_content_indices)
print(result[original_math_mode_content_indices[0][0]:original_math_mode_content_indices[0][1]])
print(result[original_math_mode_content_indices[1][0]:original_math_mode_content_indices[1][1]])
assert result == r"This is $\mathbf{bold}_a^b and$ roman text; This is $$\mathbf{bold}_a^b and$$ roman text"
# Each span covers both delimiters of a pair plus the text pushed out of it.
assert original_math_mode_content_indices == [(8, 37), (52, 83)]

This is $\mathbf{bold}_a^b and$ roman text; This is $$\mathbf{bold}_a^b and$$ roman text
[(8, 37), (52, 83)]
$\mathbf{bold}_a^b and$ roman
$$\mathbf{bold}_a^b and$$ roman


In [None]:
# Everything from the second (display) math mode pair to the end of the text;
# derived from the returned indices instead of hard-coded, out-of-range offsets.
print(result[original_math_mode_content_indices[1][0]:])

$$\mathbf{bold}_a^b and$$ roman text


In [None]:
latex = r"This is $A + B=C$"
result = push_dollar_signs(latex, p=1.0, seed=43)
print(result)
assert result == r"This is A $+ B=C$"

This is A $+ B=C$


In [None]:
latex = r"This is $\mathbf{bold} and \mathrm{roman}$ text"
result, original_math_mode_content_indices = push_dollar_signs(latex, p=1.0, seed=42, return_indices_of_math_mode_content=True)
print(result)
print(original_math_mode_content_indices)
print(result[original_math_mode_content_indices[0][0]:original_math_mode_content_indices[0][1]])
assert result == r"This is bold $and \mathrm{roman}$ text"
assert original_math_mode_content_indices == [(8, 33)]

This is bold $and \mathrm{roman}$ text
[(8, 33)]
bold $and \mathrm{roman}$


In [None]:
latex = r"This is $\mathbf{bold} and \mathrm{roman}$ text"
result, original_math_mode_content_indices = push_dollar_signs(latex, p=1.0, seed=17, return_indices_of_math_mode_content=True)
print(result)
print(original_math_mode_content_indices)
print(result[original_math_mode_content_indices[0][0]:original_math_mode_content_indices[0][1]])
assert result == r"This is $\mathbf{bold} and$ roman text"
assert original_math_mode_content_indices == [(8, 33)]

This is $\mathbf{bold} and$ roman text
[(8, 33)]
$\mathbf{bold} and$ roman


In [None]:
#| export

### Using the modification functions to augment text

In [None]:
#| export
def augment_text(
        text: str,
        methods: list[Callable[[str], str]],
        ) -> str:
    """
    Pipe `text` through each of `methods` in order and return the final result.

    Each method receives the output of the previous one; with no methods,
    `text` is returned unchanged.
    """
    augmented = text
    for transform in methods:
        augmented = transform(augmented)
    return augmented

In [None]:
#| export
# def add_typos(
#         latex_string: str, # A latex str, surrounded by dollar signs (either single or double) as necessary.
#         seed: Optional[int] = None
#         ) -> str: # A new str that is a modification of `latex_string` with "typos".

#     return ""



In [None]:
#| export
def _create_method(method, p, scale):
    """
    Helper function to `choose_modification_methods_at_random`
    """
    return lambda x: method(x, p=p*scale)

In [None]:
#| export
def choose_modification_methods_at_random(
        methods: list[tuple[Callable, float]], # Pairs `(method, p)`; each `method` is called as `method(text, p=...)`.
        method_inclusion_chance: float = 0.5, # The chance to include each method
        scale: float = 1.0, # The amount by which to "scale" the method's tendency to modify the text.
        ) -> list[Callable[[str], str]]: # The included methods, each taking only the text.
    """
    Randomly select modification methods and bind each one to a scaled probability.

    Each `(method, p)` in `methods` is independently included with probability
    `method_inclusion_chance`. An included method is wrapped so that calling it
    on a text runs `method(text, p=p*scale)`. The returned list can be passed
    to `augment_text`.
    """
    random_methods: list[Callable[[str], str]] = []
    for method, p in methods:
        if random.random() < method_inclusion_chance:
            random_methods.append(_create_method(method, p, scale))
    return random_methods