In [None]:
#| default_exp helper.latex.core

# helper.latex.core
> Core functions for LaTeX functionality

In [None]:
#| export
import re


In [None]:
from fastcore.test import *

## Validity of latex syntax

#### Test latex syntax

We require some functions to evaluate whether a LaTeX math mode string is syntactically valid.

In [None]:
#| export
def _does_not_end_with_script(s):
    """
    This is a helper function to `math_mode_string_is_syntactically_valid`.
    `s` is not supposed to have math mode delimits
    """
    s = s.strip(' $')
    return not s.endswith('_') and not s.endswith('^')

In [None]:
#| hide
# A trailing `^` is reported, with or without surrounding `$` / `$$`.
assert not _does_not_end_with_script(r'n=p_1^{e_1} p_2^{e_2} \cdots p_k^')
assert not _does_not_end_with_script(r'$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$')
assert not _does_not_end_with_script(r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k^$$')
# No trailing script character: the check passes.
assert _does_not_end_with_script(r'n=p_1^{e_1} p_2^{e_2} \cdots p_k')
assert _does_not_end_with_script(r'$n=p_1^{e_1} p_2^{e_2} \cdots p_k$')
assert _does_not_end_with_script(r'$$n=p_1^{e_1} p_2^{e_2} \cdots p_k$$')
# A trailing `_` is reported as well.
assert not _does_not_end_with_script(r'n=p_1^{e_1} p_2^{e_2} \cdots p_k_')


In [None]:
#| export
def _is_balanced_braces(s):
    """
    This is a helper function to `math_mode_string_is_syntactically_valid`.

    Note that curly braces (`{`, `}`) that are not preceded by backslashes
    '\\' are counted towards "balancing". 
    """
    stack = []
    escaped = False
    
    for _, char in enumerate(s):
        if char == '\\':
            escaped = True
        elif char == '{' and not escaped:
            stack.append(char)
        elif char == '}' and not escaped:
            if not stack:
                return False
            stack.pop()
        else:
            escaped = False
    
    return len(stack) == 0


In [None]:
#| hide
# Plain (unescaped) braces must pair up.
assert _is_balanced_braces('{{}}')
assert _is_balanced_braces('{asdf_{}}')
assert not _is_balanced_braces('{hi')
assert not _is_balanced_braces('{hi asdf}}')
# Braces preceded by a backslash are not counted.
assert _is_balanced_braces(r'\{hi')
assert _is_balanced_braces(r'\{{hi}')
assert not _is_balanced_braces(r'\{{hi\}')
# A closing brace with nothing open is unbalanced.
assert not _is_balanced_braces('}')

In [None]:
#| export
def _first_curly_bracket(s) -> str|None:
    r"""
    Return whether a left curly bracket `{` or a right curly bracket `}`
    appears first in the string, ignoring escaped curly brackets `\{` or `\}`.
    """
    i = 0
    while i < len(s):
        if s[i] == '\\' and i + 1 < len(s):
            if s[i+1] in '{}':
                # Skip this escaped bracket
                i += 2
                continue
        elif s[i] == '{':
            return '{'
        elif s[i] == '}':
            return '}'
        i += 1
    return None


In [None]:
#| hide
# Unescaped brackets: the leftmost one wins.
assert _first_curly_bracket("{abc}") == '{'
assert _first_curly_bracket("}abc{") == '}'
assert _first_curly_bracket("abc{def}") == '{'
assert _first_curly_bracket("abc}def{") == '}'
assert _first_curly_bracket("abc") == None
# A bracket directly after a backslash is skipped.
assert _first_curly_bracket("\\{abc}") == '}'
assert _first_curly_bracket("a\\}b{c}") == '{'
assert _first_curly_bracket("\\{\\}abc") == None
# Even after two backslashes, the bracket that follows is treated as escaped.
assert _first_curly_bracket("\\\\{abc}") == '}'
assert _first_curly_bracket("\\a{bc}") == '{'
# Degenerate inputs.
assert _first_curly_bracket("") == None
assert _first_curly_bracket("\\") == None
assert _first_curly_bracket("\\\\{}") == '}'

In [None]:
#| export
# def _last_curly_bracket(s) -> str|None:
def _last_curly_bracket(s):
    i = len(s) - 1
    while i >= 0:
        if s[i] in '{}':
            if i > 0 and s[i-1] == '\\':
                # This bracket is escaped
                i -= 2
                continue
            return s[i]
        i -= 1
    return None

In [None]:
#| hide
# Unescaped brackets: the rightmost one wins.
assert _last_curly_bracket("{abc}") == '}'
assert _last_curly_bracket("}abc{") == '{'
assert _last_curly_bracket("abc{def}") == '}'
assert _last_curly_bracket("abc}def{") == '{'
assert _last_curly_bracket("abc") == None
# A bracket directly after a backslash is skipped.
assert _last_curly_bracket("abc\\}") == None
assert _last_curly_bracket("a{b\\}c}") == '}'
assert _last_curly_bracket("\\{\\}abc") == None
# Even after two backslashes, the bracket that follows is treated as escaped.
assert _last_curly_bracket("abc{\\\\}") == '{'
assert _last_curly_bracket("{abc\\{") == '{'
# Degenerate inputs and trailing/leading backslashes.
assert _last_curly_bracket("") == None
assert _last_curly_bracket("\\") == None
assert _last_curly_bracket("{}\\\\") == '}'
assert _last_curly_bracket("\\\\{}") == '}'

In [None]:
#| export
def _detect_backslash_space_curly(
        text: str
        ) -> bool:
    r"""
    Return `True` if there is some backslash `\` followed
    by spaces and then followed by curly brackets `{`

    Note that the presence of such a combination of text
    will induce a syntax error in LaTeX math mode string.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    pattern = r'\\\s+[{}]'
    match = re.search(pattern, text)
    return bool(match)

In [None]:
#| hide
# Backslash, whitespace, then a curly bracket (opening or closing) is detected;
# a bracket directly after the backslash, or with no backslash at all, is not.
assert _detect_backslash_space_curly(r'\ {')
assert not _detect_backslash_space_curly(r'\{')
assert not _detect_backslash_space_curly(r'{')
assert _detect_backslash_space_curly(r'\ }')

In [None]:
#| export
def _is_left_right_balanced(
        latex_string: str
        ) -> bool:
    r"""
    Return `True` if occurrences of `\left` and `\right` are balanced. 

    This is a helper function of `math_mode_string_is_syntactically_valid`

    This function does not test whether occurrences of the
    appropriately corresponding braces are balanced. For instance,
    the function would return `True` on the input `\left . \right)`.
    Compare against `_is_semantically_left_right_balanced`, which
    is a similar function that tests whether left-right braces
    are "semantically" balanced.
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Find all \left and \right commands
    left_commands = re.findall(r'\\left', latex_string)
    right_commands = re.findall(r'\\right', latex_string)
    
    # Check if the number of \left and \right commands are equal
    if len(left_commands) != len(right_commands):
        return False
    
    # Check if \left always comes before \right
    left_indices = [m.start() for m in re.finditer(r'\\left', latex_string)]
    right_indices = [m.start() for m in re.finditer(r'\\right', latex_string)]
    
    for left, right in zip(left_indices, right_indices):
        if left > right:
            return False
    
    return True

In [None]:
#| hide
# Matching numbers of `\left` and `\right`, in a valid order.
assert _is_left_right_balanced(r"\left( x \right)")
assert _is_left_right_balanced(r"\left(x\right)")
assert _is_left_right_balanced(r"\left( x \left[ y \right] \right)")
assert _is_left_right_balanced(r"\left( x \right) \left[ y \right]")
# One `\left` too many.
assert not _is_left_right_balanced(r"\left( x \left[ y \right)")
# The delimiter types are not compared, only the commands themselves.
assert _is_left_right_balanced(r"\left( x \right] \left[ y \right)")
assert _is_left_right_balanced(r"x + y")
# A `\right` without a preceding `\left`.
assert not _is_left_right_balanced(r"\right)")
assert _is_left_right_balanced(r"\left. \right)")


In [None]:
#| export
def _is_semantically_left_right_balanced(
        latex_string: str
        ) -> bool:
    r"""
    Return `True` if occurrences of `\left` and `\right` macros
    preceding various braces are balanced.

    This is a helper function of `math_mode_string_is_syntactically_clean`

    Compare against `_is_left_right_balanced`, which
    is a similar function that only tests whether left-right
    macros are "syntactically" balanced, without regard to
    the types of braces actually used.
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Define a stack to keep track of opening delimiters
    stack = []
    
    # Define a dictionary to match opening and closing delimiters
    delimiters = {
        '(': ')', '[': ']', '{': '}', '<': '>',
        r'\left(': r'\right)', r'\left[': r'\right]', 
        r'\left{': r'\right}', r'\left<': r'\right>',
        r'\left\{': r'\right\}', r'\left|': r'\right|',
        r'\left\|': r'\right\|', r'\left.': r'\right.',
        r'\left\langle': r'\right\rangle'
    }
    
    # Regular expression to match \left and \right commands with their delimiters
    pattern = r'(\\left[\(\[\{\<\|\.\|]|\\right[\)\]\}\>\|\.\|]|\(|\)|\[|\]|\{|\}|\<|\>)'
    
    # Find all delimiters and \left/\right commands
    tokens = re.findall(pattern, latex_string)
    
    for token in tokens:
        if token.startswith(r'\left') or token in '([{<':
            stack.append(token)
        elif token.startswith(r'\right') or token in ')]}>':
            if not stack:
                return False
            last_open = stack.pop()
            if token != delimiters.get(last_open):
                return False
    
    # If the stack is empty, all delimiters are balanced
    return len(stack) == 0


In [None]:
#| hide
# Test cases
    
# Properly nested / sequenced `\left ... \right` pairs of the same kind.
assert _is_semantically_left_right_balanced(r"\left( x \right)")
assert _is_semantically_left_right_balanced(r"\left( x \left[ y \right] \right)")
assert _is_semantically_left_right_balanced(r"\left( x \right) \left[ y \right]")
# Missing or mismatched closing delimiters.
assert not _is_semantically_left_right_balanced(r"\left( x \left[ y \right)")
assert not _is_semantically_left_right_balanced(r"\left( x \right] \left[ y \right)")
assert _is_semantically_left_right_balanced(r"x + y")
assert _is_semantically_left_right_balanced(r"\left\{ x \left( y \right) \right\}")
# `\left.` paired with `\right|` is currently reported as unbalanced (see TODO below).
assert not _is_semantically_left_right_balanced(r"\left. x \right|_{a}^{b}")
assert _is_semantically_left_right_balanced(r"\left\| x \right\|")
assert _is_semantically_left_right_balanced(r"\left< x , y \right>")
assert _is_semantically_left_right_balanced(r"\left( \left[ \left\{ x \right\} \right] \right)")
assert not _is_semantically_left_right_balanced(r"\left( \left[ \left\{ x \right\} \right] \right]")

# TODO: the example of r"\left. x \right|_{a}^{b}" could be something that occurs often. Do something about this.

In [None]:
#| export
def _has_invalid_left_right_bracket(
        latex_string: str
        ) -> bool:
    r"""
    Return `True` is there is at least one invalid use of
    a `\left` or `\right` command.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Define valid brackets for \left and \right
    valid_brackets = [
        r'(', r')',
        r'[', r']',
        r'\(', r'\)', r'\[', r'\]',  # Parentheses and square brackets
        r'\{', r'\}',                # Curly braces (escaped)
        r'<', r'>',                  # Angle brackets
        r'\|',                       # Vertical bar
        r'|',                       # Vertical bar
        r'\\\|',                     # Double vertical bar (escaped)
        r'\.',                       # Dot
        r'.',                       # Dot
        r'\\lfloor', r'\\rfloor',    # Floor brackets
        r'\\lceil', r'\\rceil',      # Ceiling brackets
        r'\\langle', r'\\rangle'     # Angle brackets (commands)
    ]
    
    # Escape special regex characters and join with |
    valid_brackets_pattern = '|'.join(re.escape(b) for b in valid_brackets)
    
    # Pattern to match \left or \right followed by a valid bracket
    valid_pattern = rf'\\(left|right)({valid_brackets_pattern})'
    
    # Find all \left and \right commands
    commands = list(re.finditer(r'\\(left|right)', latex_string))
    
    for command in commands:
        # Check if the command is followed by a valid bracket
        if not re.match(valid_pattern, latex_string[command.start():]):
            # If not, return True and the invalid command
            invalid_part = latex_string[command.start():command.start()+6]  # Adjust slice as needed
            # return True, invalid_part
            return True
    
    # return False, None
    return False

In [None]:
# Accepted delimiters after `\left` / `\right`.
assert not _has_invalid_left_right_bracket(r"\left( x \right)")
assert not _has_invalid_left_right_bracket(r"\left[ x \right]")
assert not _has_invalid_left_right_bracket(r"\left\{ x \right\}")
assert not _has_invalid_left_right_bracket(r"\left< x \right>")
assert not _has_invalid_left_right_bracket(r"\left| x \right|")
assert not _has_invalid_left_right_bracket(r"\left\| x \right\|")
assert not _has_invalid_left_right_bracket(r"\left\| x \right\|")
# `\left` / `\right` followed by a letter counts as an invalid use,
# whether or not whitespace separates them.
assert _has_invalid_left_right_bracket(r"\lefta x \right)")
assert _has_invalid_left_right_bracket(r"\left( x \righta")
assert _has_invalid_left_right_bracket(r"\left x \right)")
assert _has_invalid_left_right_bracket(r"\left( x \right x")
# `\backslash` and `/` are not in the accepted delimiter list.
assert _has_invalid_left_right_bracket(r"\left\backslash x \right/")
assert not _has_invalid_left_right_bracket(r"x + y")
# A double backslash is not a delimiter.
assert _has_invalid_left_right_bracket(r"\left\\")
assert _has_invalid_left_right_bracket(r"\right\\")


In [None]:
#| export

def _has_double_script(
        latex_string: str
        ) -> bool:
    """
    Return `True` if there is at least one double superscript
    or double subscript in `latex_string` 

    This function fails to give correct outputs for more
    nuanced texts, such as `r"x^{2}_{3}^{4}"`; while in
    principle, the function should return `True` on this
    input, the actual return value is `False`.

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    # Remove all whitespace from the string
    latex_string = ''.join(latex_string.split())
    
    # Function to match balanced braces
    def match_braces(s, start):
        count = 0
        for i, char in enumerate(s[start:], start):
            if char == '{':
                count += 1
            elif char == '}':
                count -= 1
                if count == 0:
                    return i
        return len(s) - 1

    # Find all subscripts and superscripts
    i = 0
    last_script = None
    while i < len(latex_string):
        if latex_string[i] in '^_' and (i == 0 or latex_string[i-1] != '\\'):
            current_script = latex_string[i]
            i += 1
            if i < len(latex_string):
                if latex_string[i] == '{':
                    end = match_braces(latex_string, i)
                    script_content = latex_string[i:end+1]
                    i = end + 1
                else:
                    script_content = latex_string[i]
                    i += 1
                
                if last_script and last_script[0] == current_script:
                    return True
                last_script = (current_script, script_content)
        else:
            if latex_string[i] not in '^_':
                last_script = None
            i += 1

    return False

In [None]:
#| hide
# A single script, with or without braces, is fine.
assert not _has_double_script(r"x^2")
assert not _has_double_script(r"x_2")
assert not _has_double_script(r"x^{2}")
assert not _has_double_script(r"x_{2}")
# The same kind of script twice in a row is a double script.
assert _has_double_script(r"x^2^3")
assert _has_double_script(r"x_2_3")
assert _has_double_script(r"x^{2}^{3}")
assert _has_double_script(r"x_{2}_{3}")
# Scripts nested inside a script's brace group are not double scripts.
assert not _has_double_script(r"x^{2^3}")
assert not _has_double_script(r"x_{2_3}")
# One subscript plus one superscript on the same base is fine.
assert not _has_double_script(r"x_2^3")
assert not _has_double_script(r"x^2_3")
# Escaped `\^` / `\_` are not scripts.
assert not _has_double_script(r"x\^2")
assert not _has_double_script(r"x\_2")
assert not _has_double_script(r"x^{2^{3^4}}")
assert not _has_double_script(r"x\_2")
assert not _has_double_script(r"x_{2_{3_4}}")
assert not _has_double_script(r'$$x^2 + y^2$$')
assert not _has_double_script(r'\operatorname{Res}^{G}_{H} M')

# The function does not work correctly on the below input
# assert _has_double_script(r"x^{2}_{3}^{4}")

In [None]:
#| export
def _has_double_script_literal(
        latex_string: str
        ) -> bool:
    """
    Return `True` if there is at least one double superscript
    or double subscript in `latex_string` by virtue of having
    `__`, `_^`, `^_, `^^`

    This is a helper function of `math_mode_string_is_syntactically_valid`
    """
    for bad_text in ['__', '_^', '^_', '^^']:
        if bad_text in latex_string:
            return True
    return False

In [None]:
#| hide
# The literal `__` in `\Omega__X` is what triggers the detection here.
assert _has_double_script_literal(r"$$R=\sum_P\in X\operatorname length\left(\Omega__X / Y\right)_p\cdot P$$")

In [None]:
#| export
def _has_unescaped_dollar(s):
    # Pattern explanation:
    # (?<!\\) is a negative lookbehind assertion that ensures the dollar sign is not preceded by a backslash
    # \$ matches a literal dollar sign
    pattern = r'(?<!\\)(\\\\)*\$'
    match = re.search(pattern, s)
    return bool(match)

In [None]:
#| hide
# No dollar sign at all, or only `\$`: nothing unescaped.
assert not _has_unescaped_dollar("This is a normal string")
assert not _has_unescaped_dollar("This has an escaped dollar sign \\$")
assert _has_unescaped_dollar("This has an unescaped dollar sign $")
assert _has_unescaped_dollar("Mixed case: \\$ and $")
assert _has_unescaped_dollar("Multiple unescaped: $ $")
assert not _has_unescaped_dollar("Escaped at the end \\$")
assert _has_unescaped_dollar("Unescaped at the end $")
assert _has_unescaped_dollar("$")
assert not _has_unescaped_dollar("\\$")
# An even number of backslashes leaves the dollar sign unescaped; an odd number escapes it.
assert _has_unescaped_dollar("\\\\$")  # Double backslash followed by dollar
assert not _has_unescaped_dollar("\\\\\\$")
# The backslash must be directly in front of the dollar sign to escape it.
assert _has_unescaped_dollar("\\ $")

In [None]:
#| export
def detect_unbalanced_environments(
        latex_string: str) -> list[str]:
    r"""
    Return a list of messages describing `\begin{...}` / `\end{...}`
    commands in `latex_string` that do not pair up.

    The commands are processed in order of appearance with a stack of
    open environment names:

    - An `\end{name}` that does not match the most recently opened
      environment (or that appears while nothing is open) is reported as
      `"Mismatched \end{name} at position N"`, where `N` is the index of
      the command in `latex_string`. The stack is left unchanged in that
      case.
    - Every environment still open at the end is reported as
      `"Unmatched \begin{name}"`, innermost first.

    An empty list means that no problem was found.

    **Parameters**

    - `latex_string` - The LaTeX text to inspect.

    **Returns**

    - `list[str]`
    """
    open_envs = []  # names of environments opened and not yet closed
    problems = []

    for found in re.finditer(r'\\(begin|end)\{([^}]+)\}', latex_string):
        kind, name = found.groups()

        if kind == 'begin':
            open_envs.append(name)
            continue

        # kind == 'end': it must close the innermost open environment
        if open_envs and open_envs[-1] == name:
            open_envs.pop()
        else:
            problems.append(f"Mismatched \\end{{{name}}} at position {found.start()}")

    # Whatever is still open was never closed; report innermost first
    problems.extend(
        f"Unmatched \\begin{{{name}}}" for name in reversed(open_envs))

    return problems

In [None]:

# Example usage
# First case: every environment is closed except the trailing `wrongenv`,
# so at least one error message is expected.
latex_code = r"""
\begin{document}
This is a sample document.
\begin{itemize}
    \item First item
    \begin{enumerate}
        \item First sub-item
    \end{enumerate}
    \item Second item
\end{itemize}
\end{document}
\begin{wrongenv}  % This environment is unmatched
"""

# Detect unbalanced environments
unbalanced = detect_unbalanced_environments(latex_code)

# Print the results
# if unbalanced:
#     print("Unbalanced environments detected:")
#     for error in unbalanced:
#         print(error)
# else:
#     print("All environments are balanced.")

# The returned list of messages is non-empty
assert unbalanced


# Second case: the same document without the stray `\begin{wrongenv}`;
# no error message is expected.
latex_code = r"""
\begin{document}
This is a sample document.
\begin{itemize}
    \item First item
    \begin{enumerate}
        \item First sub-item
    \end{enumerate}
    \item Second item
\end{itemize}
\end{document}
"""
# Detect unbalanced environments
unbalanced = detect_unbalanced_environments(latex_code)
assert not unbalanced