In [43]:
#| default_exp helper.latex.__init__

# helper.latex
> Helper functions for latex functionalities


We remark that many of the functions in this module are AI generated or assisted.

In [44]:
#| export
import re
from typing import Callable, Union

from Levenshtein import distance
from pylatexenc.latexwalker import LatexCharsNode, LatexEnvironmentNode, LatexNode, LatexWalker, LatexMacroNode

from trouver.helper import sublist_generator
from trouver.helper.regex import latex_indices, replace_string_by_indices
from trouver.helper.latex.macros_and_commands import regex_pattern_detecting_command

In [45]:
from fastcore.test import *

## Validity of latex syntax

#### Test latex syntax

We require some functions to evaluate whether a latex math mode string is syntactically valid.

In [46]:
#| export
def _does_not_end_with_script(s):
    """
    This is a helper function to `math_mode_string_is_syntactically_valid`.
    `s` is not supposed to have math mode delimits
    """
    s = s.strip(' $')
    return not s.endswith('_') and not s.endswith('^')

In [47]:
#| hide
assert not _does_not_end_with_script(r'n=p_1^{e_1} p_2^{e_2} \cdots p_k^')
assert not _does_not_end_with_script(r'$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$')
assert not _does_not_end_with_script(r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$$')
assert _does_not_end_with_script(r'n=p_1^{e_1} p_2^{e_2} \cdots p_k')
assert _does_not_end_with_script(r'$n=p_1^{e_1} p_2^{e_2} \cdots p_k$')
assert _does_not_end_with_script(r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k$$')
assert not _does_not_end_with_script(r'n=p_1^{e_1} p_2^{e_2} \cdots p_k_')


In [48]:
#| export
def _is_balanced_braces(s):
    """
    This is a helper function to `math_mode_string_is_syntactically_valid`.

    Note that curly braces (`{`, `}`) that have 
    """
    stack = []
    escaped = False
    
    for _, char in enumerate(s):
        if char == '\\':
            escaped = True
        elif char == '{' and not escaped:
            stack.append(char)
        elif char == '}' and not escaped:
            if not stack:
                return False
            stack.pop()
        else:
            escaped = False
    
    return len(stack) == 0


In [49]:
#| hide
assert _is_balanced_braces('{{}}')
assert _is_balanced_braces('{asdf_{}}')
assert not _is_balanced_braces('{hi')
assert not _is_balanced_braces('{hi asdf}}')
assert _is_balanced_braces(r'\{hi')
assert _is_balanced_braces(r'\{{hi}')
assert not _is_balanced_braces(r'\{{hi\}')

In [50]:
#| export
def _detect_backslash_space_curly(
        text: str
        ) -> bool:
    """
    Return `True` if there is some backslash `\` followed
    by spaces and then followed by curly brackets `{`

    Note that the presence of such a combination of text
    will induce a syntax error in LaTeX math mode string.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    pattern = r'\\\s+[{}]'
    match = re.search(pattern, text)
    return bool(match)

In [51]:
#| hide
assert _detect_backslash_space_curly(r'\ {')
assert not _detect_backslash_space_curly(r'\{')
assert not _detect_backslash_space_curly(r'{')
assert _detect_backslash_space_curly(r'\ }')

In [52]:
#| export
def _is_left_right_balanced(
        latex_string: str
        ) -> bool:
    """
    Return `True` if occurrences of `\left` and `\right` are balanced. 

    This is a helper function of `math_mode_string_is_syntactically_valid`

    This function does not test whether occurrences of the
    appropriately corresponding braces are balanced. For instance,
    the function would return `True` on the input `\left . \right)`.
    Compare against `_is_semantically_left_right_balanced`, which
    is a similar function that tests whether left-right braces
    are "semantically" balanced.
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Find all \left and \right commands
    left_commands = re.findall(r'\\left', latex_string)
    right_commands = re.findall(r'\\right', latex_string)
    
    # Check if the number of \left and \right commands are equal
    if len(left_commands) != len(right_commands):
        return False
    
    # Check if \left always comes before \right
    left_indices = [m.start() for m in re.finditer(r'\\left', latex_string)]
    right_indices = [m.start() for m in re.finditer(r'\\right', latex_string)]
    
    for left, right in zip(left_indices, right_indices):
        if left > right:
            return False
    
    return True



    # # Dictionary to store counts of left and right commands for each brace type
    # brace_counts = {
    #     '(': 0, ')': 0,
    #     '[': 0, ']': 0,
    #     '{': 0, '}': 0,
    #     '|': 0,
    #     '\\|': 0,
    #     '\\{': 0, '\\}': 0,
    #     '.': 0  # For \left. and \right.
    # }
    
    # # Regular expression to match \left and \right commands
    # pattern = r'\\(left|right)(\(|\)|\[|\]|{|}|\||\\\||\\{|\\}|\.)'
    
    # # Find all matches in the latex_string
    # matches = re.finditer(pattern, latex_string)
    
    # for match in matches:
    #     command, brace = match.groups()
    #     if command == 'left':
    #         brace_counts[brace] += 1
    #     elif command == 'right':
    #         brace_counts[brace] -= 1
    
    # # Check if all counts are zero (balanced)
    # return all(count == 0 for count in brace_counts.values()) 

In [53]:
#| hide
assert _is_left_right_balanced(r"\left( x \right)")
assert _is_left_right_balanced(r"\left(x\right)")
assert _is_left_right_balanced(r"\left( x \left[ y \right] \right)")
assert _is_left_right_balanced(r"\left( x \right) \left[ y \right]")
assert not _is_left_right_balanced(r"\left( x \left[ y \right)")
assert _is_left_right_balanced(r"\left( x \right] \left[ y \right)")
assert _is_left_right_balanced(r"x + y")
assert not _is_left_right_balanced(r"\right)")
assert _is_left_right_balanced(r"\left. \right)")


In [54]:
#| export
def _is_semantically_left_right_balanced(
        latex_string: str
        ) -> bool:
    """
    Return `True` if occurrences of `\left` and `\right` macros
    preceding various braces are balanced.

    This is a helper function of `math_mode_string_is_syntactically_clean`

    Compare against `_is_left_right_balanced`, which
    is a similar function that only tests whether left-right
    macros are "syntactically" balanced, without regard to
    the types of braces actually used.
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Define a stack to keep track of opening delimiters
    stack = []
    
    # Define a dictionary to match opening and closing delimiters
    delimiters = {
        '(': ')', '[': ']', '{': '}', '<': '>',
        r'\left(': r'\right)', r'\left[': r'\right]', 
        r'\left{': r'\right}', r'\left<': r'\right>',
        r'\left\{': r'\right\}', r'\left|': r'\right|',
        r'\left\|': r'\right\|', r'\left.': r'\right.',
        r'\left\langle': r'\right\rangle'
    }
    
    # Regular expression to match \left and \right commands with their delimiters
    pattern = r'(\\left[\(\[\{\<\|\.\|]|\\right[\)\]\}\>\|\.\|]|\(|\)|\[|\]|\{|\}|\<|\>)'
    
    # Find all delimiters and \left/\right commands
    tokens = re.findall(pattern, latex_string)
    
    for token in tokens:
        if token.startswith(r'\left') or token in '([{<':
            stack.append(token)
        elif token.startswith(r'\right') or token in ')]}>':
            if not stack:
                return False
            last_open = stack.pop()
            if token != delimiters.get(last_open):
                return False
    
    # If the stack is empty, all delimiters are balanced
    return len(stack) == 0


In [55]:
#| hide
# Test cases
    
assert _is_semantically_left_right_balanced(r"\left( x \right)")
assert _is_semantically_left_right_balanced(r"\left( x \left[ y \right] \right)")
assert _is_semantically_left_right_balanced(r"\left( x \right) \left[ y \right]")
assert not _is_semantically_left_right_balanced(r"\left( x \left[ y \right)")
assert not _is_semantically_left_right_balanced(r"\left( x \right] \left[ y \right)")
assert _is_semantically_left_right_balanced(r"x + y")
assert _is_semantically_left_right_balanced(r"\left\{ x \left( y \right) \right\}")
assert not _is_semantically_left_right_balanced(r"\left. x \right|_{a}^{b}")
assert _is_semantically_left_right_balanced(r"\left\| x \right\|")
assert _is_semantically_left_right_balanced(r"\left< x , y \right>")
assert _is_semantically_left_right_balanced(r"\left( \left[ \left\{ x \right\} \right] \right)")
assert not _is_semantically_left_right_balanced(r"\left( \left[ \left\{ x \right\} \right] \right]")

# TODO: the example of r"\left. x \right|_{a}^{b}" could be osmething that occurs often. do something about this.

In [56]:
#| export
def _has_invalid_left_right_bracket(
        latex_string: str
        ) -> bool:
    """
    Return `True` is there is at least one invalid use of
    a `\left` or `\right` command.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Define valid brackets for \left and \right
    valid_brackets = [
        r'(', r')',
        r'[', r']',
        r'\(', r'\)', r'\[', r'\]',  # Parentheses and square brackets
        r'\{', r'\}',                # Curly braces (escaped)
        r'<', r'>',                  # Angle brackets
        r'\|',                       # Vertical bar
        r'|',                       # Vertical bar
        r'\\\|',                     # Double vertical bar (escaped)
        r'\.',                       # Dot
        r'.',                       # Dot
        r'\\lfloor', r'\\rfloor',    # Floor brackets
        r'\\lceil', r'\\rceil',      # Ceiling brackets
        r'\\langle', r'\\rangle'     # Angle brackets (commands)
    ]
    
    # Escape special regex characters and join with |
    valid_brackets_pattern = '|'.join(re.escape(b) for b in valid_brackets)
    
    # Pattern to match \left or \right followed by a valid bracket
    valid_pattern = rf'\\(left|right)({valid_brackets_pattern})'
    
    # Find all \left and \right commands
    commands = list(re.finditer(r'\\(left|right)', latex_string))
    
    for command in commands:
        # Check if the command is followed by a valid bracket
        if not re.match(valid_pattern, latex_string[command.start():]):
            # If not, return True and the invalid command
            invalid_part = latex_string[command.start():command.start()+6]  # Adjust slice as needed
            # return True, invalid_part
            return True
    
    # return False, None
    return False

In [57]:
assert not _has_invalid_left_right_bracket(r"\left( x \right)")
assert not _has_invalid_left_right_bracket(r"\left[ x \right]")
assert not _has_invalid_left_right_bracket(r"\left\{ x \right\}")
assert not _has_invalid_left_right_bracket(r"\left< x \right>")
assert not _has_invalid_left_right_bracket(r"\left| x \right|")
assert not _has_invalid_left_right_bracket(r"\left\| x \right\|")
assert not _has_invalid_left_right_bracket(r"\left\| x \right\|")
assert _has_invalid_left_right_bracket(r"\lefta x \right)")
assert _has_invalid_left_right_bracket(r"\left( x \righta")
assert _has_invalid_left_right_bracket(r"\left x \right)")
assert _has_invalid_left_right_bracket(r"\left( x \right x")
assert _has_invalid_left_right_bracket(r"\left\backslash x \right/")
assert not _has_invalid_left_right_bracket(r"x + y")
assert _has_invalid_left_right_bracket(r"\left\\")
assert _has_invalid_left_right_bracket(r"\right\\")


In [58]:
#| export
import re

def _has_double_script(
        latex_string: str
        ) -> bool:
    """
    Return `True` is there is at least one double superscript
    or double subscript in `latex_string` 

    This function fails to give correct outputs for more
    nuanced texts, such as `r"x^{2}_{3}^{4}"`; while in
    principle, the function should return `True` on this
    input, the actual return value is `False`.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Function to match balanced braces
    def match_braces(s, start):
        count = 0
        for i, char in enumerate(s[start:], start):
            if char == '{':
                count += 1
            elif char == '}':
                count -= 1
                if count == 0:
                    return i
        return len(s) - 1

    # Find all subscripts and superscripts
    i = 0
    last_script = None
    while i < len(latex_string):
        if latex_string[i] in '^_' and (i == 0 or latex_string[i-1] != '\\'):
            current_script = latex_string[i]
            i += 1
            if i < len(latex_string):
                if latex_string[i] == '{':
                    end = match_braces(latex_string, i)
                    script_content = latex_string[i:end+1]
                    i = end + 1
                else:
                    script_content = latex_string[i]
                    i += 1
                
                if last_script and last_script[0] == current_script:
                    return True
                last_script = (current_script, script_content)
        else:
            if latex_string[i] not in '^_':
                last_script = None
            i += 1

    return False

In [59]:
#| hide
assert not _has_double_script(r"x^2")
assert not _has_double_script(r"x_2")
assert not _has_double_script(r"x^{2}")
assert not _has_double_script(r"x_{2}")
assert _has_double_script(r"x^2^3")
assert _has_double_script(r"x_2_3")
assert _has_double_script(r"x^{2}^{3}")
assert _has_double_script(r"x_{2}_{3}")
assert not _has_double_script(r"x^{2^3}")
assert not _has_double_script(r"x_{2_3}")
assert not _has_double_script(r"x_2^3")
assert not _has_double_script(r"x^2_3")
assert not _has_double_script(r"x\^2")
assert not _has_double_script(r"x\_2")
assert not _has_double_script(r"x^{2^{3^4}}")
assert not _has_double_script(r"x\_2")
assert not _has_double_script(r"x_{2_{3_4}}")
assert not _has_double_script(r'$$x^2 + y^2$$')
assert not _has_double_script(r'\operatorname{Res}^{G}_{H} M')

# The function does not work correctly on the below input
# assert _has_double_script(r"x^{2}_{3}^{4}")

In [60]:
#| export
def _has_unescaped_dollar(s):
    # Pattern explanation:
    # (?<!\\) is a negative lookbehind assertion that ensures the dollar sign is not preceded by a backslash
    # \$ matches a literal dollar sign
    pattern = r'(?<!\\)(\\\\)*\$'
    match = re.search(pattern, s)
    return bool(match)

In [61]:
#| hide
assert not _has_unescaped_dollar("This is a normal string")
assert not _has_unescaped_dollar("This has an escaped dollar sign \\$")
assert _has_unescaped_dollar("This has an unescaped dollar sign $")
assert _has_unescaped_dollar("Mixed case: \\$ and $")
assert _has_unescaped_dollar("Multiple unescaped: $ $")
assert not _has_unescaped_dollar("Escaped at the end \\$")
assert _has_unescaped_dollar("Unescaped at the end $")
assert _has_unescaped_dollar("$")
assert not _has_unescaped_dollar("\\$")
assert _has_unescaped_dollar("\\\\$")  # Double backslash followed by dollar
assert not _has_unescaped_dollar("\\\\\\$")
assert _has_unescaped_dollar("\\ $")

In [62]:
pattern = regex_pattern_detecting_command(('Sur', 0, None, r'\mathrm{Sur}'))
text = r'The number of element of $\Sur(\operatorname{Cl} \mathcal{O}_L, A)$ is ...'
match = pattern.search(text)
start, end = match.span()
test_eq(text[start:end], r'\Sur')


In [63]:
#| export
def extract_latex_commands(latex_string):
    # Create a LatexWalker instance
    walker = LatexWalker(latex_string)
    
    # Get the nodes from the LaTeX string
    try:
        nodelist, _, _ = walker.get_latex_nodes()
    except Exception as e:
        print(f"Error parsing LaTeX: {e}")
        return []  # Return an empty list if there's a parsing error
    # Extract commands
    commands = []
    extract_commands_from_nodes(commands, nodelist)
    return commands


def extract_commands_from_nodes(
        commands: list[str],
        nodes: list[LatexNode]
        ):
    """
    This is a helper function to `extract_latex_commands`.
    """
    for node in nodes:
        # If the node is a character node, we skip it
        if isinstance(node, LatexCharsNode):
            continue
        elif isinstance(node, LatexMacroNode):
            commands.append(node.macroname)
            # Check for arguments of the macro node
            for arg in node.nodeargs:
                if arg and not isinstance(arg, LatexCharsNode):
                    if hasattr(arg, 'nodelist'):  # Ensure the argument is not None
                        extract_commands_from_nodes(commands, arg.nodelist)  # Extract from argument nodes
                    elif isinstance(arg, LatexMacroNode):
                        commands.append(arg.macroname)
        # elif isinstance(node, LatexEnvironmentNode):

        elif isinstance(node, LatexEnvironmentNode):
            commands.extend(_detect_begin_and_end_environments(node.latex_verbatim()))
        # If the node has a nodelist, extract commands from it
        if hasattr(node, 'nodelist'):
            extract_commands_from_nodes(commands, node.nodelist)

def _detect_begin_and_end_environments(
        latex_string: str
        ) -> list[str]:
    """
    Return a list of at most two items containing 'begin' if there is a \begin and containing 'end' if there is an \end

    This is a helper function to `extract_latex_commands`.
    """
    # Regular expressions to match \begin and \end with optional spaces
    begin_pattern = r'\\\s*begin'
    end_pattern = r'\\\s*end'
    
    # Initialize an empty result list
    result = []
    
    # Check for \begin
    if re.search(begin_pattern, latex_string):
        result.append('begin')
    
    # Check for \end
    if re.search(end_pattern, latex_string):
        result.append('end')
    
    return result

In [64]:


# Example usage
assert extract_latex_commands(r"\frac{a}{b}") == ['frac']
assert extract_latex_commands(r"$\frac{a}{b}$") == ['frac']
assert extract_latex_commands(r"\sqrt[n]{x}") == ['sqrt']
assert extract_latex_commands(r"\binom{n}{k}") == ['binom']
assert extract_latex_commands(r"x^2 + y^2") == []  # No commands, just variables
assert extract_latex_commands(r"\overset{a}{b}") == ['overset']

# Additional tests
assert extract_latex_commands(r"\sum_{i=1}^{n} i") == ['sum']
assert extract_latex_commands(r"\int_{0}^{\infty} e^{-x} dx") == ['int', 'infty']
assert extract_latex_commands(r"\lim_{x \to 0} f(x)") == ['lim', 'to']
assert extract_latex_commands(r"\prod_{i=1}^{n} i") == ['prod']
assert extract_latex_commands(r"\text{Hello} + \frac{1}{2}") == ['text', 'frac']

# Multiple commands in one string
assert extract_latex_commands(r"\frac{a}{b} + \sqrt{c} + \binom{n}{k}") == ['frac', 'sqrt', 'binom']
assert extract_latex_commands(r"\sum_{i=1}^{n} i + \int_{0}^{\infty} e^{-x} dx") == ['sum', 'int', 'infty']
assert extract_latex_commands(r"\lim_{x \to 0} f(x) = \frac{1}{x}") == ['lim', 'to', 'frac']
assert extract_latex_commands(r"\overset{a}{b} + \underset{c}{d}") == ['overset', 'underset']
assert extract_latex_commands(r"\text{This is } \textbf{bold} + \textit{italic} + \frac{1}{2}") == ['text', 'textbf', 'textit', 'frac']

# Complex expressions
test_eq(extract_latex_commands(r"\frac{\sum_{i=1}^{n} i}{n} = \frac{n(n+1)}{2}"), ['frac', 'sum', 'frac'])
test_eq(extract_latex_commands(r"\int_{0}^{1} x^2 \, dx = \frac{1}{3}"), ['int', ',', 'frac'])
assert extract_latex_commands(r"\sqrt{\frac{a}{b}} + \binom{n}{k}") == ['sqrt', 'frac', 'binom']

# Incorrect synntax
assert extract_latex_commands(r"\frac{}}") == ['frac']
assert extract_latex_commands(r"\frac{a}{b}{c}") == ['frac']  # Extra argument
assert extract_latex_commands(r"\frac{a}{b + \frac{c}{d}}") == ['frac', 'frac']  # Nested command
test_eq(extract_latex_commands(r"\sum_{i=1}^{n} i + \int_{0}^{\infty} e^{-x} dx = \frac{1}{2}"), ['sum', 'int', 'infty', 'frac'])
# Comment
assert extract_latex_commands(r"%hi") == []
# Environment Node
test_eq(extract_latex_commands(r"\begin{align} \end{align}"), ['begin', 'end'])
test_eq(extract_latex_commands(r"\begin{align}"), ['begin'])
# test_eq(extract_latex_commands(r"\ begin{align} \end{align}"), [' '])

test_eq(extract_latex_commands(r'\text\in'), ['text', 'in'])

In [65]:
#| export
# Some arguments that can be used towards `regex_pattern_detecting_command`
# for some basic latex arguments.
# Note that the last argument doesn't actually matter, because
# we just want to be able to detect uses of comands, see
# `regex_pattern_detecting_commands``
REGEX_PATTERN_DETECTIONS = [
    ('frac', 2, None, None),
    ('binom', 2, None, None),
    ('sqrt', 1, '2', None),
    ('overset', 2, None, None),
    ('underset', 2, None, None),
    ('stackrel', 2, None, None),
    ('dfrac', 2, None, None),
    ('cfrac', 2, None, None),
    ('sideset', 3, None, None),
    ('xrightarrow', 1, None, None),
    ('xleftarrow', 1, None, None),
    ('overline', 1, None, None),
    ('bar', 1, None, None),
    ('arccos', 1, None, None),
    ('arcsin', 1, None, None),
    ('arctan', 1, None, None),
    ('arg', 1, None, None),
    ('atop', 2, None, None),
    ('begin', 1, None, None),
    ('boldsymbol', 1, None, None),
    ('breve', 1, None, None),
    ('check', 1, None, None),
    ('cline', 1, None, None),
    ('cos', 1, None, None),
    ('cosh', 1, None, None),
    ('cot', 1, None, None),
    ('csc', 1, None, None),
    ('dddot', 1, None, None),
    ('ddot', 1, None, None),
    ('dot', 1, None, None),
    ('end', 1, None, None),
    ('exp', 1, None, None),
    ('gcd', 2, None, None),
    ('grave', 1, None, None),
    ('hat', 1, None, None),
    # ('int', '1', None, None),
    ('lcm', 2, None, None),
    # ('left', 1, None, None),
    ('lg', 1, None, None),
    ('lim', 1, None, None),
    ('liminf', 1, None, None),
    ('limsup', 1, None, None),
    ('ln', 1, None, None),
    ('log', 1, None, None),
    ('longdiv', 2, None, None),
    ('lvert', 1, None, None),
    ('mapsto', 1, None, None),
    ('mathbb', 1, None, None),
    ('mathbf', 1, None, None),
    ('mathcal', 1, None, None),
    ('mathfrak', 1, None, None),
    ('mathop', 1, None, None),
    ('mathrm', 1, None, None),
    ('mathscr', 1, None, None),
    ('max', 1, None, None),
    ('min', 1, None, None),
    ('multicolumn', 3, 'center', None),
    ('multirow', 3, None, None),
    ('not', 1, None, None),
    ('oint', 1, None, None),
    ('overbrace', 1, None, None),
    ('overleftarrow', 1, None, None),
    ('overleftrightarrow', 1, None, None),
    ('overrightarrow', 1, None, None),
    # ('prod', 1, None, None),
    # ('right', 1, None, None),
    ('rvert', 1, None, None),
    ('sec', 1, None, None),
    ('section', 1, None, None),
    ('sin', 1, None, None),
    ('sinh', 1, None, None),
    ('stackrel', 2, None, None),
    ('subsection', 2, None, None),
    ('substack', 2, None, None),
    ('subsubsection', 2, None, None),
    # ('sum', 1, None, None),
    ('sup', 1, None, None),
    ('tag', 1, None, None),
    ('tan', 1, None, None),
    ('tanh', 1, None, None),
    ('text', 1, None, None),
    ('textbf', 1, None, None),
    ('textrm', 1, None, None),
    ('tilde', 1, None, None),
    ('underbrace', 1, None, None),
    ('underline', 1, None, None),
    ('underset', 2, None, None),
    ('varliminf', 1, None, None),
    ('varlimsup', 1, None, None),
    ('vec', 1, None, None),
    ('widehat', 1, None, None),
    ('widetilde', 1, None, None),
    ('xrightarrow', 1, None, None),
]
temp_dict = {}
for entry in REGEX_PATTERN_DETECTIONS:
    temp_dict[entry[0]] = entry
REGEX_PATTERN_DETECTIONS = temp_dict



In [66]:
#| export
def detect_incorrect_latex_commands(
        latex_string: str,
        ) -> bool:
    """
    Return `True` if there is at least one syntactically
    incorrect use of a latex command detected in `latex_string`.

    This is a helper function to `math_mode_string_is_syntactically_valid`.
    """
    commands_in_string = set(extract_latex_commands(latex_string))
    for command in commands_in_string:
        if command not in temp_dict:
            continue
        tuppy = temp_dict[command]
        pattern = regex_pattern_detecting_command(tuppy)
        # Look at each invocation of the command to see if 
        # each invocation is properly used.
        simp_pattern = rf"\\\s*{command}"
        simp_matches = re.finditer(simp_pattern, latex_string)
        # simp_matches = re.findall(simp_pattern, latex_string)
        for match in simp_matches:
            trailing_substring = latex_string[match.start():]
            alt_match = pattern.search(trailing_substring)
            if not alt_match or alt_match.span()[0] != 0:
                return True

        # if not matches and not simp_matches:
        #     continue
        # if len(matches) != len(simp_matches):
        #     return True
    return False

In [67]:
#| hide
    
# Tests
# Correct usage
assert not detect_incorrect_latex_commands(r'\frac{a}{b}')
assert not detect_incorrect_latex_commands(r'\binom{n}{k}')
assert not detect_incorrect_latex_commands(r'\sqrt[n]{x}')
assert not detect_incorrect_latex_commands(r'\overset{a}{b}')
assert not detect_incorrect_latex_commands(r'\underset{a}{b}')
assert not detect_incorrect_latex_commands(r'\stackrel{a}{b}')
assert not detect_incorrect_latex_commands(r'\dfrac{a}{b}')
assert not detect_incorrect_latex_commands(r'\cfrac{1}{1+\cfrac{1}{x}}')
assert not detect_incorrect_latex_commands(r'\xleftarrow{text}')
assert not detect_incorrect_latex_commands(r'\xrightarrow{text}')
assert not detect_incorrect_latex_commands(r'\left( \right.')
assert not detect_incorrect_latex_commands(r'\overbrace{x+y+z}^{\text{sum}}')
assert not detect_incorrect_latex_commands(r'\underbrace{x+y+z}_{\text{sum}}')
assert not detect_incorrect_latex_commands(r'\overbrace{x+y+z}')
assert not detect_incorrect_latex_commands(r'\underbrace{x+y+z}')

# Incorrect usage (missing arguments)
assert detect_incorrect_latex_commands(r'\frac{a}')
assert detect_incorrect_latex_commands(r'\binom{n}')
assert detect_incorrect_latex_commands(r'\overset{a}')
assert detect_incorrect_latex_commands(r'\underset{a}')
assert detect_incorrect_latex_commands(r'\stackrel{a}')
assert detect_incorrect_latex_commands(r'\dfrac{a}')
assert detect_incorrect_latex_commands(r'\cfrac{1}')
assert detect_incorrect_latex_commands(r'\sideset{_1^2}')

#Extra arguments are technically okay 
assert not detect_incorrect_latex_commands(r'\frac{a}{b}{c}')
assert not detect_incorrect_latex_commands(r'\binom{n}{k}{m}')
assert not detect_incorrect_latex_commands(r'\overset{a}{b}{c}')

# Mixed correct and incorrect usage
assert detect_incorrect_latex_commands(r'\frac{a}{b} + \frac{c}')
assert detect_incorrect_latex_commands(r'\binom{n}{k} \cdot \binom{m}')

assert detect_incorrect_latex_commands(r'\text\in')


In [68]:
#| export
def detect_unbalanced_environments(
        latex_string: str) -> list[str]:
    # Define a regex pattern to match \begin{...} and \end{...}
    pattern = r'\\(begin|end)\{([^}]+)\}'
    
    # Stack to keep track of opened environments
    stack = []
    # List to store errors
    errors = []

    # Find all matches in the LaTeX string
    for match in re.finditer(pattern, latex_string):
        command, env_name = match.groups()
        
        if command == 'begin':
            # Push the environment name onto the stack
            stack.append(env_name)
        elif command == 'end':
            # Check if there is a matching begin for this end
            if stack and stack[-1] == env_name:
                stack.pop()  # Match found, pop from stack
            else:
                # Mismatch found, record the error
                errors.append(f"Mismatched \\end{{{env_name}}} at position {match.start()}")

    # If there are any unmatched begin commands left in the stack, report them
    while stack:
        unmatched_env = stack.pop()
        errors.append(f"Unmatched \\begin{{{unmatched_env}}}")

    return errors

In [69]:

# Example usage
latex_code = r"""
\begin{document}
This is a sample document.
\begin{itemize}
    \item First item
    \begin{enumerate}
        \item First sub-item
    \end{enumerate}
    \item Second item
\end{itemize}
\end{document}
\begin{wrongenv}  % This environment is unmatched
"""

# Detect unbalanced environments
unbalanced = detect_unbalanced_environments(latex_code)

# Print the results
# if unbalanced:
#     print("Unbalanced environments detected:")
#     for error in unbalanced:
#         print(error)
# else:
#     print("All environments are balanced.")

assert unbalanced


latex_code = r"""
\begin{document}
This is a sample document.
\begin{itemize}
    \item First item
    \begin{enumerate}
        \item First sub-item
    \end{enumerate}
    \item Second item
\end{itemize}
\end{document}
"""
# Detect unbalanced environments
unbalanced = detect_unbalanced_environments(latex_code)
assert not unbalanced

In [70]:
#| export
def math_mode_string_is_syntactically_valid(
        text: str,
        ) -> bool:
    """
    Return `True` if `text` is determined to be syntactically valid
    as a latex str.

    There may be TeX syntax rules beyond the scope of this function.

    Some caveats:

    `text` is allowed to have dollar signs `$` and is also allowed to not have
    dollar signs. Even if `text` does not have dollar signs, this function
    may return `True`. Even if `text` has dollar signs, this function may return
    `False` if the entire string is not a singular math mode string or if the
    dollar signs are not used in a math-mode-valid way.
    """
    # 
    text = text.strip()
    math_mode_indices = latex_indices(text)
    if _has_unescaped_dollar(text):
        if len(math_mode_indices) != 1:
            return False
        if (math_mode_indices[0][0] != 0 or math_mode_indices[0][1] != len(text)):
            return False
    if not _does_not_end_with_script(text):
        return False
    if _detect_backslash_space_curly(text):
        return False
    if not _is_balanced_braces(text):
        return False
    if _has_invalid_left_right_bracket(text):
        return False
    if not _is_left_right_balanced(text):
        return False
    if _has_double_script(text):
        return False
    if detect_incorrect_latex_commands(text):
        return False
    if bool(detect_unbalanced_environments(text)):
        return False
    return True



In [71]:
assert not math_mode_string_is_syntactically_valid(r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$$')
assert not math_mode_string_is_syntactically_valid(r'$x^2 + y^2')
assert not math_mode_string_is_syntactically_valid(r'$$x^2 + y^2$')
assert not math_mode_string_is_syntactically_valid(r'$$x^2 + y^2$ $')
assert math_mode_string_is_syntactically_valid(r'hi')
assert math_mode_string_is_syntactically_valid(r'$hi$')
assert not math_mode_string_is_syntactically_valid(r'$hi$$')
assert math_mode_string_is_syntactically_valid(r'$\\dim ^ a$')
assert not math_mode_string_is_syntactically_valid(r'{ hi')
assert math_mode_string_is_syntactically_valid(r'\{ hi')
assert math_mode_string_is_syntactically_valid(r'\ [')
assert math_mode_string_is_syntactically_valid(r'\left( \right.')
assert not math_mode_string_is_syntactically_valid(r'\left \right.')
assert math_mode_string_is_syntactically_valid(r'$$\left|\sum_{i=0} \right|$$')
assert math_mode_string_is_syntactically_valid(r'$\\\$$')
assert not math_mode_string_is_syntactically_valid(r'\begin{enumerate}')
assert math_mode_string_is_syntactically_valid(r'\begin{enumerate} asdf \end{enumerate}')
# TODO there is something to be considered here; the below
# example would be a syntax error, and yet the functions  implemented
# above don't really detect as such.
# assert not detect_incorrect_latex_commands(r'\sideset{_1^2}{_3^4}')

math_mode_string_is_syntactically_valid(r'\text\in')

False

The `math_mode_string_is_syntactically_valid` experimentally assesses whether a given math mode LaTeX string is syntactically valid. In principal, this should mean that a LaTeX syntax error caused by the string should be detected by the function.

TODO: consider the following to :


Unescaped % sign (starts a comment):
`$x = 50% of y$`

Using ! (negative space) at the beginning of math mode:
`$\!x + y$`

The following lists some example outputs of the `math_mode_string_is_syntactically_valid` function along with explanations.

Unmatched curly braces are a common syntactical error:

In [72]:
assert not math_mode_string_is_syntactically_valid(r'\sqrt{x}}')

However, using `\{` or `\}` does not count towards curly bracket matching:

In [73]:
assert math_mode_string_is_syntactically_valid(r'\{hi')

On the other hand, a backslash `\` followed by spaces ` ` and then followed by a curly bracket is in itself an invalid syntax.

In [74]:
assert not math_mode_string_is_syntactically_valid(r'\ {hi')

`math_mode_string_is_syntactically_valid` will consider the validity of a string whether or not the string has math mode delimiters. 

In [75]:
assert math_mode_string_is_syntactically_valid('\operatorname{Gal}')
assert math_mode_string_is_syntactically_valid('$\operatorname{Gal}$')

However, `math_mode_string_is_syntactically_valid` returns `False` if the string has dollar sign delimiters and more than one math mode string is detected in the string (use `latex_indices` to separate out math mode strings.),  

In [76]:
# More than one math mode string is present
assert not math_mode_string_is_syntactically_valid('$hi$ $bye$')
# the math mode delimiter `$` is unbalanced.
assert not math_mode_string_is_syntactically_valid(r'$x^2 + y^2')
# the math mode delimiters `$$` and `$` are unbalanced.
assert not math_mode_string_is_syntactically_valid(r'$$x^2 + y^2$')

In [77]:
#| export
# def math_mode_string_is_syntactically_clean(
#         text: str,
#         ) -> bool:
#     """
#     Return `True` if `text` is syntactically "clean" as a LaTeX math mode str.
    
#     While the precise meaning of this may be subjective, here we will
#     consider `text` to be clean, assuming that it is syntactically valid, if

#     - It does not have double blackslashes
#     """
#     if r'\\' in text:
#         return False

## Tweak a latex string

Sometimes, when autogenerating a latex string through an ML model, some minor formatting eyesores occur, such as a curly bracket `{` or an underscore `_` followed by an unncessary space. We provide some functions to fix such formatting.

In [78]:
#| export
def reduce_unnecessary_spaces(
        text: str,

        ) -> str:
    """
    Return a string modifying `text` by removing spaces which are
    unnecessary for the purposes of considering the string as a 
    LaTeX string.
    """
    pattern = r'([{_^\\()])\s+'
    text = re.sub(pattern, r'\1', text)
    pattern = r'\s+([}_^()])'
    text = re.sub(pattern, r'\1', text)
    return text
    # for char in ['{', '_', '^', '}', '\\']:
    #     text = re.sub(fr'\s*{chr}\s*', chr, text)

In [79]:

# It might not be necessary or desirable to eliminate the space before the backslash `\``
test_eq(reduce_unnecessary_spaces(r'something something \  operatorname'), r'something something \operatorname')
test_eq(reduce_unnecessary_spaces(r'\operatorname{Res}  ^ G_ H (R)'), r'\operatorname{Res}^G_H(R)')
test_eq(reduce_unnecessary_spaces(r'\operatorname{Res}^{ G}_{ H } (R)'), r'\operatorname{Res}^{G}_{H}(R)')
test_eq(reduce_unnecessary_spaces(r'M_{ f}'), r'M_{f}')
test_eq(reduce_unnecessary_spaces(r'h_{ p}'), r'h_{p}')
test_eq(reduce_unnecessary_spaces(r'\zeta (s)'), r'\zeta(s)')
test_eq(reduce_unnecessary_spaces(r'\mathcal{ H} _{ v}'), r'\mathcal{H}_{v}')

#### Make fixes to summary

In [80]:
#| export
def fix_autogen_formatting(
        text: str
        ) -> str:
    """Fix some latex formatting issues in an autogenerated text
    """
    text = text.replace(r'\ ', '\\')
    text = text.replace(r'{ ', r'{')
    text = text.replace(r' }', r'}')
    text, _ = re.subn(r'\$\s*([^\$]+?)\s*\$', r'$\1$', text)
    # TODO: do $ <latex_string> $ into $<latex_stinrg>$
    # TODO: if the replacement of r'\ ' by '\\' happesn to
    # make `\` stick to the previous chunk of things
    # (e.g. r'd\in\mathbb{Z}_{\geq 0}`, then give it some
    # space, e.g. r'd \in \mathbb{Z}_{\geq 0}'.
    text = reduce_unnecessary_spaces(text)
    text = _insert_newline_or_spaces_around_latex(text)
    return text


def _insert_newline_or_spaces_around_latex(
        text:  str
        ) -> str:
    """
    Insert spaces or newlines around latex math mode strings inside `text`
    if necessary.
    """
    math_mode_indices = latex_indices(text)
    replacements = []
    for start, end in math_mode_indices:
        math_mode = text[start:end]
        spaces_potentially_added = math_mode
        if not math_mode.startswith('$$'): #starts with exactly one $
            if start != 0 and text[start-1] != ' ':
                spaces_potentially_added = f' {spaces_potentially_added}'
            if end != len(text) and text[end] != ' ':
                spaces_potentially_added = f'{spaces_potentially_added} '
            replacements.append(spaces_potentially_added)
            continue
        if start != 0 and text[start-1] != '\n':
            front_newline_count = 2
        elif start > 1 and text[start-2] != '\n':
            front_newline_count = 1
        else:
            front_newline_count = 0
        spaces_potentially_added = front_newline_count * '\n' + spaces_potentially_added

        if end != len(text) and text[end] != '\n':
            back_newline_count = 2
        elif end < len(text) - 1 and text[end-1] != '\n':
            back_newline_count = 1
        else:
            back_newline_count = 0
        spaces_potentially_added = spaces_potentially_added + '\n'*back_newline_count
        replacements.append(spaces_potentially_added)
    text = replace_string_by_indices(text, math_mode_indices, replacements)
    text = text.replace('$  $', '$ $')
    text = text.replace('$$\n\n\n\n$$', '$$\n\n$$')
    text = text.replace('$$\n\n\n$$', '$$\n\n$$')
    return text

In [81]:
#| hide
sample_text = '$hi$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), sample_text)
sample_text = '$hi$asdf'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), '$hi$ asdf')
sample_text = 'asdf$hi$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), 'asdf $hi$')
sample_text = 'asdf$hi$asdf'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), 'asdf $hi$ asdf')


sample_text = '$$hi$$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), sample_text)
sample_text = 'asdf$$hi$$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), 'asdf\n\n$$hi$$')
sample_text = '$$hi$$asdf'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), '$$hi$$\n\nasdf')
sample_text = 'asdf$$hi$$asdf'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), 'asdf\n\n$$hi$$\n\nasdf')

sample_text = '$hi$ $hi$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), '$hi$ $hi$')
sample_text = '$hi$$hi$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), '$hi$ $hi$')
# sample_text = '$$hi$$ $$hi$$'
# test_eq(_insert_newline_or_spaces_around_latex(sample_text), '$$hi$$\n\n$$hi$$')
sample_text = '$$hi$$$$hi$$'
test_eq(_insert_newline_or_spaces_around_latex(sample_text), '$$hi$$\n\n$$hi$$')

Currently, the model is inclined to decode and format its summarizations in such a way that creates formatting issues either for LaTeX or `Obsidian.md`. For example, the model would output a str containing

- `\ <command_name>` instead of `\<command_name>`
- `{ ` when `{` is preferable
- `$ <latex_string> $` when `$<latex_string>$` is needed for `Obsidian.md`.

The `fix_summary_formatting` function attempts to get around some of these issues.

In [82]:
text = r'\ to'
sample_output = fix_autogen_formatting(text)
assert r'\to' in sample_output

text = r'$d\ in\ mathbb{ Z}_{\ geq 0} $'
sample_output = fix_autogen_formatting(text)
assert r'\in' in sample_output
assert r'\mathbb{Z}' in sample_output
assert r'\geq 0' in sample_output


In [83]:
text = r'There are some extra spaces in this math mode string: $  5 + 7 = 12 $.'
sample_output = fix_autogen_formatting(text)
print(sample_output)
assert r'$5' in sample_output
assert r'12$' in sample_output

There are some extra spaces in this math mode string: $5 + 7 = 12$ .


In [84]:
text=  r'the group of $G$-coinvariants of $A$. It is defined as $$A_{G} :=A / I_\G} A$$'
sample_output = fix_autogen_formatting(text)
print(sample_output)

the group of $G$ -coinvariants of $A$ . It is defined as 

$$A_{G} :=A / I_\G} A$$


## Correct syntax errors in autogenerated math mode strings

In [85]:
#| export
def _tokenize_latex_math(
        latex_string: str
        ) -> list[str]:
    """
    Tokenize `latex_string` by the following principles:

    1. A latex command/macro invoked (but not the inputs) is a token.
    2. the special characters ^ { } _ are tokens.
    3. groups of consecutive whitespaces are tokens.
    4. afterwards, all "words" (one or more consecutive non-whitespace non-special characters) are tokens.
    """
    # Define the regex pattern for tokenization
    pattern = r"""
        (\\[a-zA-Z]+)        # Match LaTeX commands (e.g., \alpha, \sum)
        | ([^\\\s^{}_]+)     # Match words (consecutive non-whitespace, non-special characters)
        | ([^\\\s])          # Match special characters (including ^, {, }, _, etc.)
        | (\s+)              # Match groups of consecutive whitespace
    """
    # Use re.findall to find all matches based on the pattern
    tokens = re.findall(pattern, latex_string, re.VERBOSE)
    # Extract the matched groups, filtering out empty strings
    token_list = [token for group in tokens for token in group if token]
    return token_list


In [86]:
#| hide
# Example usage
latex_string = r"\alpha + \beta^{2} - \gamma_{1} + 3 \times \text{some text}"
tokens = _tokenize_latex_math(latex_string)
# print(tokens)
test_eq(
    ['\\alpha', ' ', '+', ' ', '\\beta', '^', '{', '2', '}', ' ', '-', ' ', '\\gamma', '_', '{', '1', '}', ' ', '+', ' ', '3', ' ', '\\times', ' ', '\\text', '{', 'some', ' ', 'text', '}'],
    tokens
    )
test_eq(''.join(tokens), latex_string)

In [87]:
#| export
def _list_of_candidates_from_math_mode_strings(
        main_content: str, # A text of LaTeX code. In practice, this should be the `main content` of an information note, cf. `summarize_notation`.`
        syntax_validation: Callable[str, bool] = math_mode_string_is_syntactically_valid # A test to tell whether a math mode string is syntactically  valid.
        ) -> set[str]:
    """
    Return a substrings from latex math mode strings in `main_content`
    that are syntactically valid .

    None of the elements in the output have delimiters (`$`, `$$`)
    """
    syntactically_valid_substrings = [] 
    math_mode_indices = latex_indices(main_content)
    for start, end in math_mode_indices:
        latex_str = main_content[start:end]
        latex_str = latex_str.strip('$')
        tokenization = _tokenize_latex_math(latex_str)
        for sublist in sublist_generator(tokenization):
            substring = ''.join(sublist)
            if syntax_validation(substring):
                syntactically_valid_substrings.append(substring.strip())
    return set(syntactically_valid_substrings)

In [88]:
#| hide
output = _list_of_candidates_from_math_mode_strings('$\operatorname{Gal}(L/K)$', math_mode_string_is_syntactically_valid)
assert r'\operatorname{Gal}' in output
assert r'Gal' in output

output = _list_of_candidates_from_math_mode_strings('$\operatorname{Gal}(L/K) \to G_\ell^\infty$', math_mode_string_is_syntactically_valid)

In [89]:
#| hide
_list_of_candidates_from_math_mode_strings(r'the signum of the complete factorization $\\text\\in S_n$ into disjoint cycles. It is defined by$$\\operatorname sgn(\\left )=(-1)n-t .$$')

{'',
 ')=(-1)n-t',
 ')=(-1)n-t .',
 '.',
 'S',
 'S_n',
 '\\in',
 '\\in S',
 '\\in S_n',
 '\\operatorname',
 '\\operatorname sgn(',
 '_n',
 'n',
 'sgn('}

In [90]:
#| export
def _find_closest_match(
        math_mode_text: str,
        replacement_candidates: list[str]
        ) -> Union[str, None]:
    """This is a helper function to `correct_latex_syntax_error`."""
    if not replacement_candidates:
        return None
    # Calculate Levenshtein distance for each candidate
    distances = [(candidate, distance(math_mode_text, candidate)) for candidate in replacement_candidates]
    # Find the candidate with the minimum distance
    closest_match = min(distances, key=lambda x: x[1])
    return closest_match[0]

In [91]:
#| hide
test_eq(_find_closest_match('hi', ['hib', 'basdy']), 'hib')

In [92]:
#| export
def correct_latex_syntax_error(
        summary: str, # The autogenerated summary
        replacement_candidates: list[str], # A list of candidates to replace. This is expected to be an output of `_list_of_candidates_from_math_mode_strings`
        # min_length_to_replace_math_mode_string: int = 5, # The minimum length that a math mode string needs to be (exclusing delimiting dollar signs `$`, `$$`) in summary in order to be considered for replacement.
        syntax_validation: Callable[str, bool] = math_mode_string_is_syntactically_valid # A test to tell whether a math mode string is syntactically  valid.
        ) -> str:
    """
    Attempt to replace within `summary` a modified version in which
    the syntactically incorrect latex math mode strings are replaced
    with the most closely resembling element of `replacement_candidates`. 
     
    with a modified version in which the
    latex math mode strings within `summary` that are syntactically
    incorrect 

    TODO: consider the possibility that not all math mode str delimiters
    are formatted correctly.
    """
    math_mode_indices = latex_indices(summary)
    replacements = []
    for start, end in math_mode_indices:
        math_mode_text = summary[start:end]
        if syntax_validation(math_mode_text) or not replacement_candidates:
            replacements.append(math_mode_text)
            continue
        delimiter = '$$' if math_mode_text.startswith('$$') else '$'
        replacement = _find_closest_match(math_mode_text, replacement_candidates)
        replacement = f'{delimiter}{replacement}{delimiter}'
        replacements.append(replacement)
    return replace_string_by_indices(summary, math_mode_indices, replacements)



In [93]:
sample_summary = r'the group of $G$-coinvariants of $A$. It is defined as $$A_{G} :=A / I_\G} A$$'
replacement_candidates = [
    'A',
    'A_',
    'A_{G}',
    'A_{G}:=A',
    'A_{G}:=A',
    'A_{G}:=A /',
    'A_{G}:=A / I_{G}',
    'A_{G}:=A / I_{G} A',
    'H_{0}(G, A)',
    'H_{0}(G, A) \\simeq',
    'H_{0}(G, A) \\simeq A',
    'H_{0}(G, A) \\simeq A_',
    'H_{0}(G, A) \\simeq A_{G}',
]
test_eq(correct_latex_syntax_error(sample_summary, replacement_candidates), r'the group of $G$-coinvariants of $A$. It is defined as $$A_{G}:=A / I_{G} A$$')
# replacement_candidates