In [80]:
import os
import re

def clean_latex_file(input_str):
    r"""
    Convert LaTeX source into Markdown-friendly text and write it to a ``.txt`` file.

    Processing steps, in order: strip ``%`` comments, turn the abstract environment
    into ``\section*{Abstract}``, pull out ``\title`` and ``\author``, keep only the
    text between ``\begin{document}`` and ``\end{document}`` (or everything when they
    are absent), normalise whitespace, expand ``\maketitle``, then convert figures,
    tables, equations, headings, lists and inline formatting. Spacing and page-break
    commands are dropped at the end.

    Parameters:
        input_str (str or os.PathLike): LaTeX source text (any string containing a
            newline) or the path of a ``.tex`` file.

    Returns:
        None. The result is written to ``output.txt`` when source text was given, or
        next to the input file with a ``.txt`` extension when a path was given. If the
        input file cannot be read, a message is printed and nothing is written.

    Raises:
        TypeError: If ``input_str`` is neither a string nor a path-like object.
        ValueError: If a path is given that does not end in ``.tex``.
    """

    # ------------------------------------------------------------------
    # Generic helpers shared by all conversion steps
    # ------------------------------------------------------------------
    def find_closing_brace(text, start):
        r"""
        Return the index of the ``}`` closing a group whose ``{`` sits just before
        ``start``, or None when the braces are unbalanced.

        A backslash escapes the following character, so ``\{`` and ``\}`` (literal
        braces in LaTeX) do not take part in the matching.
        """
        depth = 1
        idx = start
        while idx < len(text):
            ch = text[idx]
            if ch == '\\':
                idx += 2  # skip the escaped character as well
                continue
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return idx
            idx += 1
        return None

    def rewrite_command(text, opener, render):
        """
        Replace every ``opener`` match together with its brace-balanced argument.

        ``opener`` is a compiled pattern ending in the opening ``{``; ``render`` is
        called as ``render(match, argument)`` and returns the replacement text. The
        argument is rewritten recursively first so nested uses are handled. An opener
        whose braces never close is left untouched.
        """
        pieces = []
        pos = 0
        while True:
            found = opener.search(text, pos)
            if not found:
                pieces.append(text[pos:])
                return ''.join(pieces)
            close = find_closing_brace(text, found.end())
            if close is None:
                # Unbalanced: keep the opener verbatim and look further on.
                pieces.append(text[pos:found.end()])
                pos = found.end()
                continue
            argument = rewrite_command(text[found.end():close], opener, render)
            pieces.append(text[pos:found.start()])
            pieces.append(render(found, argument))
            pos = close + 1

    def pop_command_argument(text, command):
        r"""
        Remove the first ``\command{...}`` from ``text``.

        Returns ``(remaining_text, argument)``; the argument is "" when the command
        is absent. With unbalanced braces everything after the opener is taken as the
        argument and a warning is printed.
        """
        found = re.search(r'\\' + command + r'\{', text)
        if not found:
            return text, ""
        close = find_closing_brace(text, found.end())
        if close is None:
            print(f"Warning: Unbalanced braces in \\{command} command.")
            return text[:found.start()], text[found.end():].strip()
        return text[:found.start()] + text[close + 1:], text[found.end():close].strip()

    # ------------------------------------------------------------------
    # Load the input
    # ------------------------------------------------------------------
    if isinstance(input_str, os.PathLike):
        input_str = os.fspath(input_str)
    if not isinstance(input_str, str):
        raise TypeError("input_str must be LaTeX source text or a path to a .tex file.")

    if '\n' in input_str:
        # A string with a newline is treated as LaTeX source, not as a file name.
        content = input_str
        output_path = "output.txt"
    else:
        if not input_str.endswith('.tex'):
            raise ValueError("Input file must have a .tex extension.")

        base, _ = os.path.splitext(input_str)
        output_path = f"{base}.txt"

        try:
            with open(input_str, 'r', encoding='utf-8') as infile:
                content = infile.read()
        except FileNotFoundError:
            print(f"Error: The file {input_str} does not exist.")
            return
        except Exception as e:
            print(f"An error occurred: {e}")
            return

    # Strip comments first so that commented-out \title{...}, \begin{document}, or
    # stray braces inside comments cannot confuse the later steps. An even run of
    # backslashes before % (e.g. the row terminator in "a & b \\% note") is kept,
    # while an escaped \% is left alone.
    content = re.sub(r'(?<!\\)((?:\\\\)*)%.*', r'\1', content)

    # The abstract becomes an unnumbered section so the heading step renders it.
    content = re.sub(r'\\begin\{abstract\}', r'\\section*{Abstract}', content)
    content = re.sub(r'\\end\{abstract\}', '', content)

    # \title and \author usually live in the preamble, so extract them before the
    # document body is cut out.
    content, title = pop_command_argument(content, 'title')
    content, authors = pop_command_argument(content, 'author')

    # Extract content between \begin{document} and \end{document}, if present
    doc_match = re.search(r'\\begin\{document\}(.*?)\\end\{document\}', content, flags=re.DOTALL)
    document_content = doc_match.group(1) if doc_match else content

    # Normalize whitespace: keep paragraph breaks, join everything else with spaces.
    document_content = re.sub(r'\r\n', '\n', document_content)
    document_content = re.sub(r'\n{3,}', '\n\n', document_content)
    document_content = document_content.replace('\n\n', '<<<PARA_BREAK>>>')
    document_content = re.sub(r'\n', ' ', document_content)
    document_content = re.sub(r'\s+', ' ', document_content).strip()
    document_content = document_content.replace('<<<PARA_BREAK>>>', '\n\n')

    # ------------------------------------------------------------------
    # Inline formatting
    # ------------------------------------------------------------------
    emph_opener = re.compile(r'\\(?:emph|textit)\{')
    bold_opener = re.compile(r'\\textbf\{')
    smallcaps_opener = re.compile(r'\\textsc\{')
    label_pattern = re.compile(r'\\label\{([^}]+)\}')

    def apply_inline_formats(text):
        r"""Map \emph/\textit to *...*, \textbf to **...** and \textsc to `...`."""
        text = rewrite_command(text, emph_opener, lambda m, arg: "*" + arg.strip() + "*")
        text = rewrite_command(text, bold_opener, lambda m, arg: "**" + arg.strip() + "**")
        text = rewrite_command(text, smallcaps_opener, lambda m, arg: "`" + arg.strip() + "`")
        return text

    def replace_maketitle(text, title, authors):
        r"""
        Replaces \maketitle with Markdown-formatted title and authors.
        """
        replacement = ""
        if title:
            replacement += f"# {apply_inline_formats(title)}\n\n"
        if authors:
            # Split authors by \and or commas
            names = [name.strip() for name in re.split(r'\\and|,', authors) if name.strip()]
            replacement += f"**Authors:** {', '.join(names)}\n\n"

        # A callable replacement keeps backslashes in the title from being read as
        # regex escape sequences.
        return re.sub(r'\\maketitle', lambda m: replacement, text)

    document_content = replace_maketitle(document_content, title, authors)

    # ------------------------------------------------------------------
    # Figures and tables
    # ------------------------------------------------------------------
    def collect_captions(block, env_name):
        r"""
        Join all \caption{...} texts found in ``block`` (labels removed, inline
        formats applied). ``env_name`` is only used in the warning message.
        """
        caption_opener = re.compile(r'\\caption\{')
        captions = []
        pos = 0
        while True:
            found = caption_opener.search(block, pos)
            if not found:
                break
            close = find_closing_brace(block, found.end())
            if close is None:
                print(f"Warning: Unbalanced braces in {env_name} caption.")
                raw = block[found.end():]
            else:
                raw = block[found.end():close]
            captions.append(re.sub(r'\\label\{[^}]+\}', '', raw.strip()).strip())
            if close is None:
                break
            pos = close + 1
        return apply_inline_formats(' '.join(captions).strip())

    def replace_figures(text):
        """Replace each figure environment by a '**Figure:** caption \\label{...}' line."""
        figure_env = re.compile(r'\\begin\{figure.*?\}(.*?)\\end\{figure.*?\}', flags=re.DOTALL)

        def figure_repl(m):
            inner = m.group(1)
            figure_markdown = "\n\n**Figure:** " + collect_captions(inner, 'figure')
            for label in label_pattern.findall(inner):
                figure_markdown += f" \\label{{{label}}}"
            return figure_markdown + "\n\n"

        return figure_env.sub(figure_repl, text)

    document_content = replace_figures(document_content)

    def has_nested_tabulars(text):
        """
        Checks if there are nested tabular environments within the given text.

        Parameters:
            text (str): The text to check for nested tabulars.

        Returns:
            bool: True if nested tabulars are detected, False otherwise.
        """
        depth = 0
        for marker in re.finditer(r'\\(begin|end)\{tabular\}', text):
            if marker.group(1) == 'begin':
                depth += 1
                if depth > 1:
                    return True
            else:
                depth -= 1
        return False

    def tabular_body_to_markdown(body):
        """Convert the rows of one tabular body to a Markdown table ('' if empty)."""
        # Horizontal rules carry no content.
        body = re.sub(r'\\(?:toprule|midrule|bottomrule|hline)(?![A-Za-z])', '', body)
        body = re.sub(r'\\(?:cmidrule|cline)(?:\([^)]*\))?\{[^\}]*\}', '', body)

        rows = [row.strip() for row in re.split(r'\\\\', body) if row.strip()]
        if not rows:
            return ""

        table_rows = []
        for row in rows:
            row = re.sub(r'\\textcolor\{[^\}]*\}\{(.*?)\}', r'\1', row)
            row = re.sub(r'\\textbf\{(.*?)\}', r'**\1**', row)
            row = re.sub(r'\\emph\{(.*?)\}', r'*\1*', row)
            # Split on column separators only; an escaped \& is a literal
            # ampersand and is unescaped after the split.
            cells = [cell.strip().replace(r'\&', '&') for cell in re.split(r'(?<!\\)&', row)]
            table_rows.append(cells)

        num_cols = len(table_rows[0])
        md_table = "\n\n| " + " | ".join(table_rows[0]) + " |\n"
        md_table += "| " + " | ".join(["---"] * num_cols) + " |\n"
        for row in table_rows[1:]:
            # Pad short rows and truncate long ones to the header width.
            fitted = (row + [""] * (num_cols - len(row)))[:num_cols]
            md_table += "| " + " | ".join(fitted) + " |\n"
        return md_table + "\n\n"

    def convert_tabular_to_markdown(inner):
        """
        Convert every tabular environment in ``inner`` to a Markdown table.

        Returns the concatenated tables, "" when no tabular is found or conversion
        fails, and None when tabulars are nested (caller keeps the LaTeX verbatim).
        """
        scalebox_opener = re.compile(r'\\scalebox\{[^\}]*\}\{')
        tabular_begin = re.compile(r'\\begin\{tabular\}(?:\[[^\]]*\])?\{')
        tabular_end = re.compile(r'\\end\{tabular\}')
        try:
            # Unwrap \scalebox{factor}{...} using balanced braces so the wrapped
            # tabular stays intact.
            inner = rewrite_command(inner, scalebox_opener, lambda m, arg: arg)

            if has_nested_tabulars(inner):
                print("Warning: Nested tabular environments detected within a table.")
                return None  # Signal to handle verbatim copy

            md_tables = []
            pos = 0
            while True:
                begin = tabular_begin.search(inner, pos)
                if not begin:
                    break
                # The column spec may itself contain braces, e.g. {@{}lc@{}}.
                spec_close = find_closing_brace(inner, begin.end())
                if spec_close is None:
                    break
                end = tabular_end.search(inner, spec_close + 1)
                if not end:
                    break
                pos = end.end()
                md_table = tabular_body_to_markdown(inner[spec_close + 1:end.start()])
                if md_table:
                    md_tables.append(md_table)

            return ''.join(md_tables)
        except Exception as e:
            print(f"Error during tabular conversion: {e}")
            return ""

    def remove_captions_and_labels(tex):
        r"""
        Remove all \caption{...} commands (with possible nested braces)
        and all \label{...} commands from the given LaTeX code.
        """
        caption_opener = re.compile(r'\\caption\{')
        kept = []
        pos = 0
        while True:
            found = caption_opener.search(tex, pos)
            if not found:
                break
            kept.append(tex[pos:found.start()])
            close = find_closing_brace(tex, found.end())
            # An unterminated caption swallows the rest of the text.
            pos = len(tex) if close is None else close + 1
        kept.append(tex[pos:])
        return re.sub(r'\\label\{[^}]+\}', '', ''.join(kept))

    def replace_tables(text):
        """Replace each table environment by a caption line plus a Markdown table."""
        table_env = re.compile(r'(\\begin\{table.*?\}.*?\\end\{table.*?\})', flags=re.DOTALL)

        def table_repl(m):
            entire_table = m.group(1)
            full_caption = collect_captions(entire_table, 'table')
            labels_str = ' '.join(
                f"\\label{{{label}}}" for label in label_pattern.findall(entire_table)
            )
            heading = f"\n\n**Table:** {full_caption} {labels_str}\n\n"

            markdown_table = convert_tabular_to_markdown(entire_table)

            if markdown_table is None:
                print("Warning: Nested tabular environments detected. Retaining tabular content as LaTeX code.")
                # Greedy match: keep everything from the first \begin{tabular} to
                # the last \end{tabular} so the outer environment is not truncated.
                nested = re.search(r'\\begin\{tabular\}.*\\end\{tabular\}', entire_table, flags=re.DOTALL)
                tabular_content = nested.group(0) if nested else ''
                return f"{heading}```latex\n{tabular_content}\n```\n\n"

            if not markdown_table:
                print("Warning: Table conversion failed or no tabular environments found. Retaining original LaTeX table without captions and labels.")
                cleaned_table = remove_captions_and_labels(entire_table)
                return f"{heading}```latex\n{cleaned_table}\n```\n\n"

            return f"{heading}{markdown_table}\n"

        return table_env.sub(table_repl, text)

    document_content = replace_tables(document_content)

    # ------------------------------------------------------------------
    # Equations, headings, lists
    # ------------------------------------------------------------------
    def replace_equations(text):
        """Wrap equation environments in $$ ... $$, each label on its own line."""
        eq_env = re.compile(r'\\begin\{equation\}(.*?)\\end\{equation\}', flags=re.DOTALL)

        def eq_repl(m):
            eq_text = m.group(1).strip()
            eq_text = re.sub(r'(\\label\{[^}]+\})\s*', r'\1\n', eq_text)
            return "\n\n$$\n" + eq_text + "\n$$\n\n"

        return eq_env.sub(eq_repl, text)

    document_content = replace_equations(document_content)

    def replace_headings(text):
        """Turn sectioning commands into Markdown headings, keeping a trailing label."""
        heading_opener = re.compile(
            r'\\(section|subsection|subsubsection|paragraph|runningtitle)\*?\{'
        )
        trailing_label = re.compile(r'\s*\\label\{([^}]+)\}')
        level_map = {
            "section": 1,
            "subsection": 2,
            "subsubsection": 3,
            "paragraph": 4,
            "runningtitle": 1
        }

        pieces = []
        pos = 0
        while True:
            found = heading_opener.search(text, pos)
            if not found:
                pieces.append(text[pos:])
                break
            close = find_closing_brace(text, found.end())
            if close is None:
                # Unbalanced heading: leave it as it is.
                pieces.append(text[pos:found.end()])
                pos = found.end()
                continue

            level = level_map.get(found.group(1), 2)
            markdown = "\n\n" + "#" * level + " " + text[found.end():close].strip()
            resume = close + 1
            label = trailing_label.match(text, resume)
            if label:
                markdown += "\n\\label{" + label.group(1) + "}"
                resume = label.end()
            markdown += "\n\n"

            pieces.append(text[pos:found.start()])
            pieces.append(markdown)
            pos = resume
        return ''.join(pieces)

    document_content = replace_headings(document_content)

    def replace_lists(text):
        """Convert enumerate/itemize environments (and stray \\item) to Markdown lists."""
        def list_items(inner):
            # The lookahead keeps commands such as \itemsep from being split.
            raw_items = re.split(r'\\item(?![A-Za-z])', inner)
            return [apply_inline_formats(item.strip()) for item in raw_items if item.strip()]

        def enum_repl(m):
            lines = [f"{idx}. {item}\n" for idx, item in enumerate(list_items(m.group(2)), start=1)]
            return "\n\n" + ''.join(lines) + "\n"

        def item_repl(m):
            lines = [f"- {item}\n" for item in list_items(m.group(2))]
            return "\n\n" + ''.join(lines) + "\n"

        enum_env = re.compile(r'\\begin\{enumerate\}(\[[^\]]*\])?(.*?)\\end\{enumerate\}', flags=re.DOTALL)
        item_env = re.compile(r'\\begin\{itemize\}(\[[^\]]*\])?(.*?)\\end\{itemize\}', flags=re.DOTALL)

        text = enum_env.sub(enum_repl, text)
        text = item_env.sub(item_repl, text)
        return re.sub(r'\\item\s+', '\n- ', text)

    document_content = replace_lists(document_content)

    document_content = apply_inline_formats(document_content)

    # ------------------------------------------------------------------
    # Leftover layout commands
    # ------------------------------------------------------------------
    def remove_leftover_commands(text):
        r"""
        Drop spacing commands (with their argument, starred or not) and turn
        page/line-break commands into paragraph breaks.
        """
        commands_to_remove = ['vspace', 'hspace', 'bigskip', 'smallskip', 'medskip', 'ignore', 'bibliographystyle']
        commands_to_replace_newline = ['newpage', 'pagebreak', 'linebreak', 'clearpage', 'cleardoublepage']
        removable = '|'.join(commands_to_remove)
        breaking = '|'.join(commands_to_replace_newline)

        # \cmd{...} and \cmd*{...}, argument matched with balanced braces.
        with_argument = re.compile(r'\\(?:' + removable + r')\*?\{')
        text = rewrite_command(text, with_argument, lambda m, arg: '')

        # Bare commands. The lookahead stops \ignore from eating the start of
        # \ignorespaces, \bigskip from matching \bigskipamount, and so on.
        text = re.sub(r'\\(?:' + removable + r')\*?(?![A-Za-z])\s*', '', text)

        # Break commands become a blank line; \pagebreak[4]-style digits are dropped.
        text = re.sub(r'\\(?:' + breaking + r')\*?(?![A-Za-z])(?:\[\d\])?\s*', '\n\n', text)

        return re.sub(r'\n{3,}', '\n\n', text)

    document_content = remove_leftover_commands(document_content)

    # Paragraphs must not start with stray whitespace.
    document_content = re.sub(r'\n\n\s+', '\n\n', document_content)

    def final_cleanup(text):
        """Collapse repeated spaces and blank lines; end with exactly one newline."""
        text = re.sub(r' {2,}', ' ', text)
        text = re.sub(r'\n{3,}', '\n\n', text)
        return text.strip() + "\n"

    document_content = final_cleanup(document_content)

    try:
        with open(output_path, 'w', encoding='utf-8') as outfile:
            outfile.write(document_content)
        print(f"Cleaned file has been written to: {output_path}")
    except Exception as e:
        print(f"An error occurred while writing the output file: {e}")


In [81]:
import gzip
import os
import re
import tarfile
import tempfile
import zlib
from pathlib import Path

import requests

def get_arxiv_tex(identifier):
    """
    Download and extract arXiv source files, returning the main tex content.

    The source is fetched from ``https://arxiv.org/src/<paper number>``. It is
    normally a gzipped tarball; a payload that is gzip-compressed but not a tar
    archive is treated as a single-file TeX submission. The main file is chosen
    by, in order: a list of common file names, a file named after its directory,
    the first file containing ``\\documentclass``, and finally the first ``.tex``
    file in sorted order.

    Args:
        identifier (str): Either full arXiv URL or just the paper number (e.g. '2412.06264')

    Returns:
        str: Content of the main tex file (undecodable bytes are replaced, not fatal)

    Raises:
        ValueError: If the identifier is invalid or source files cannot be accessed
        RuntimeError: If no tex file is found in the source
        tarfile.ReadError: If the download is neither a tar archive nor gzip data
    """
    # Extract paper number from URL if needed
    if identifier.startswith('http'):
        match = re.search(r'arxiv\.org/(?:abs|pdf)/(\d+\.\d+)', identifier)
        if not match:
            raise ValueError("Invalid arXiv URL format")
        paper_number = match.group(1)
    else:
        # Verify the paper number format
        if not re.match(r'^\d+\.\d+$', identifier):
            raise ValueError("Invalid arXiv identifier format")
        paper_number = identifier

    source_url = f'https://arxiv.org/src/{paper_number}'

    # A timeout keeps a stalled connection from hanging the notebook forever.
    response = requests.get(source_url, timeout=30)
    if response.status_code != 200:
        raise ValueError(f"Failed to download source files (Status code: {response.status_code})")

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        tar_path = temp_root / 'source.tar.gz'
        tar_path.write_bytes(response.content)

        try:
            tar = tarfile.open(tar_path, 'r:gz')
        except tarfile.ReadError as tar_error:
            # Not a tarball: try a single gzipped file instead.
            try:
                single_file = gzip.decompress(response.content)
            except (OSError, EOFError, zlib.error):
                raise tar_error from None
            # Skip PDF/PostScript-only submissions; they contain no TeX.
            if not single_file.lstrip().startswith((b'%PDF', b'%!PS')):
                (temp_root / f'{paper_number}.tex').write_bytes(single_file)
        else:
            with tar:
                # The 'data' filter rejects unsafe members (absolute paths, links
                # pointing outside the destination, device files).
                tar.extractall(path=temp_dir, filter='data')

        # Sorted so the choice does not depend on filesystem enumeration order.
        tex_files = sorted(temp_root.rglob('*.tex'))
        if not tex_files:
            raise RuntimeError("No tex files found in the source")

        # Common main file names (all lowercase for comparison)
        main_candidates = [
            'main.tex',
            'paper.tex',
            'article.tex',
            'manuscript.tex',
            'submission.tex',
            'arxiv.tex',
            'document.tex',
            'draft.tex',
            'preprint.tex',
            'source.tex',
            'neurips.tex',
            'icml.tex',
            'iclr.tex',
            'aaai.tex',
            'ijcai.tex',
            f'{paper_number}.tex'
        ]

        # First try common file names (case-insensitive); candidate order wins.
        first_by_name = {}
        for tex_file in tex_files:
            first_by_name.setdefault(tex_file.name.lower(), tex_file)
        main_file = next(
            (first_by_name[name] for name in main_candidates if name in first_by_name),
            None,
        )

        # If no common names found, try the directory name (case-insensitive)
        if main_file is None:
            main_file = next(
                (
                    tex_file for tex_file in tex_files
                    if tex_file.parent.name.lower() + '.tex' == tex_file.name.lower()
                ),
                None,
            )

        # If still no match, look for file with \documentclass. Decoding errors are
        # replaced so one legacy-encoded file cannot abort the scan.
        if main_file is None:
            for tex_file in tex_files:
                text = tex_file.read_text(encoding='utf-8', errors='replace')
                if r'\documentclass' in text:
                    main_file = tex_file
                    break

        # Last resort: the first tex file (tex_files is known to be non-empty here).
        if main_file is None:
            main_file = tex_files[0]

        return main_file.read_text(encoding='utf-8', errors='replace')
            
# Example usage:
# tex_content = get_arxiv_tex('2412.06264')
# tex_content = get_arxiv_tex('https://arxiv.org/abs/2412.06264')

In [82]:
# To clean a local file instead, assign its path (it must end in .tex):
#input_file = './maintext.tex'  # Replace with your .tex file path
# NOTE: despite its name, input_file here holds the LaTeX source text returned by
# get_arxiv_tex, not a path. clean_latex_file treats any string containing a newline
# as source text and writes the result to output.txt.
input_file = get_arxiv_tex('2412.06264')
clean_latex_file(input_file)

Cleaned file has been written to: output.txt
