In [4]:
from tree_sitter import Language, Parser, Node
from tree_sitter_cpp import language
from tqdm import tqdm
import os
import pandas as pd
from typing import Callable, Optional
import re
from pathlib import Path

# Register tqdm's pandas integration (adds progress-bar variants such as
# `progress_apply`); none of the cells shown here call them.
tqdm.pandas()

class FunctionReplacer:
    """
    Find, replace and rename functions/identifiers in C-family sources.

    Parsing uses tree-sitter with the ``tree_sitter_cpp`` grammar (the C++
    grammar), which this class also applies to plain C files.

    Most public methods accept either source text or a path: when
    ``os.path.exists(code)`` is true the file is read as UTF-8, otherwise the
    string itself is parsed.

    Attributes:
        output_postfix: Inserted between the file stem and suffix when
            ``replace_function_definition`` writes its result
            (``foo.c`` -> ``foo_updated.c``).
        parser: The tree-sitter ``Parser`` used for every parse.
        c_language: The ``Language`` of ``parser`` when it can be determined,
            otherwise ``None``.
    """

    def __init__(self, output_postfix: str = '_updated', parser: Parser | None = None):
        """
        Args:
            output_postfix: Suffix added to the file stem of written outputs.
            parser: Optional pre-configured tree-sitter parser. When omitted,
                one is built with ``get_tree_sitter_parser``.

        Raises:
            RuntimeError: If no parser was given and building one failed.
        """
        self.output_postfix = output_postfix
        self.parser = parser
        # Also set for caller-supplied parsers (py-tree-sitter exposes the
        # grammar as ``Parser.language``); None when it is not available.
        self.c_language = getattr(parser, "language", None)

        if self.parser is None:
            result = self.get_tree_sitter_parser()
            if result is None:
                # get_tree_sitter_parser has already printed the underlying error.
                raise RuntimeError("Failed to initialize the tree-sitter parser")
            self.parser, self.c_language = result

    def get_tree_sitter_parser(self):
        """
        Build a tree-sitter parser for the ``tree_sitter_cpp`` grammar.

        Returns:
            A ``(parser, language)`` tuple, or None if initialization failed
            (the error is printed, not raised).
        """
        try:
            # `language` is imported from tree_sitter_cpp, i.e. the C++ grammar.
            c_language = Language(language())
            parser = Parser(c_language)
            return parser, c_language
        except Exception as e:
            print(f"Tree-sitter initialization error: {e}")
            return None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _load_source(code: str) -> tuple[str, bool]:
        """
        Resolve *code* to source text.

        Returns:
            ``(text, from_file)``: the UTF-8 contents of the file and True when
            ``os.path.exists(code)``, otherwise ``(code, False)``.
        """
        if os.path.exists(code):
            with open(code, "r", encoding="utf-8") as file:
                return file.read(), True
        return code, False

    @staticmethod
    def _iter_function_definitions(root: Node):
        """
        Yield every ``function_definition`` node under *root* in lexical
        (pre-order, left-to-right) order, nested definitions included.
        """
        stack = [root]
        while stack:
            node = stack.pop()
            if node.type == "function_definition":
                yield node
            # Push in reverse so the leftmost child is visited next.
            stack.extend(reversed(node.children))

    @staticmethod
    def _function_name(function_node: Node, code_bytes: bytes | bytearray) -> str | None:
        """
        Return the name declared by a ``function_definition`` node, or None.

        The declarator is searched depth-first, left-to-right, for the first
        ``identifier``, so names nested inside other declarators (e.g. the
        pointer declarator of ``char *foo(void)``) are found. Parameter lists
        are not searched, so a parameter is never reported as the function name.
        """
        declarator = function_node.child_by_field_name("declarator")
        if declarator is None:
            return None

        pending = [declarator]
        while pending:
            current = pending.pop()
            if current.type == "identifier":
                return code_bytes[current.start_byte:current.end_byte].decode("utf-8")
            if current.type != "parameter_list":
                pending.extend(reversed(current.children))
        return None

    # --------------------------------------------------------------- inspection

    def get_identifiers(self, code: str) -> dict[str, list[Node]]:
        """
        Collect all ``identifier`` nodes in *code*, grouped by their text.

        Args:
            code: The source code as a string (not a path).

        Returns:
            A dict mapping identifier names to the list of tree-sitter nodes
            where each name occurs.
        """
        code_bytes = bytes(code, 'utf8')
        root = self.parser.parse(code_bytes).root_node

        identifiers: dict[str, list[Node]] = {}
        stack = [root]
        while stack:
            node = stack.pop()
            if node.type == "identifier":
                name = code_bytes[node.start_byte:node.end_byte].decode('utf8')
                identifiers.setdefault(name, []).append(node)
            # Traversal order is kept unchanged: rename_identifiers numbers
            # identifiers (the `id` given to transforms) in this dict order.
            stack.extend(node.children)

        return identifiers

    def get_first_function_name(self, code: str) -> str | None:
        """
        Return the name of the lexically first function defined in *code*,
        or None if *code* contains no function definition.
        """
        code_bytes = code.encode("utf-8")
        root = self.parser.parse(code_bytes).root_node

        for node in self._iter_function_definitions(root):
            name = self._function_name(node, code_bytes)
            if name is not None:
                return name
        return None

    def get_function_names(self, code: str) -> list[str]:
        """
        Return the names of *all* functions defined in *code*, in lexical
        order. The list is empty if there are no function definitions.
        """
        code_bytes = code.encode("utf-8")
        root = self.parser.parse(code_bytes).root_node

        names: list[str] = []
        for node in self._iter_function_definitions(root):
            name = self._function_name(node, code_bytes)
            if name is not None:
                names.append(name)
        return names

    # ----------------------------------------------------------------- renaming

    def constant_transform(self, new_name: str) -> Callable[[Node, int, bytes], str]:
        """
        Build a rename transform that ignores its arguments and always
        returns *new_name*.

        Args:
            new_name: The name every matched identifier is renamed to.

        Returns:
            A ``transform(node, id, code_bytes) -> str`` callable.
        """
        return lambda _node, _id, _bytes: new_name

    def rename_identifiers(
        self,
        code: str,
        transform: Callable[[Node, int, bytes], str],
        rename_pattern: str = r'.*(good|bad|cwe).*',
        cache: dict[str, str] | None = None,
        flags: int = re.IGNORECASE,
    ):
        """
        Rename every identifier in *code* that matches *rename_pattern*.

        An identifier that occurs anywhere as a direct child of a preprocessor
        node (e.g. ``#define NAME`` or ``#ifdef NAME``) is left untouched at all
        of its occurrences, so macros are never partially renamed.

        Args:
            code: Source code, or a path to a file whose contents are used.
            transform: ``transform(node, id, code_bytes)`` returning the new
                name; ``id`` counts renamed identifiers from 0.
            rename_pattern: Regex searched in each identifier name.
            cache: Optional mapping old name -> new name. It is consulted first
                and updated with every newly computed name.
            flags: ``re`` flags for *rename_pattern* (case-insensitive by default).

        Returns:
            The transformed source text (files are never written).
        """
        # Read the file *before* parsing so node offsets refer to its contents.
        source, _ = self._load_source(code)
        identifiers = self.get_identifiers(source)
        code_bytes = bytearray(source, 'utf8')

        if cache is None:
            cache = {}
        pattern = re.compile(rename_pattern, flags)

        code_edits = []
        counter = 0
        for identifier, nodes in identifiers.items():
            if any("preproc" in node.parent.type for node in nodes):
                continue
            if not pattern.search(identifier):
                continue

            if identifier in cache:
                renamed_identifier = cache[identifier]
            else:
                renamed_identifier = transform(nodes[0], counter, code_bytes)
                cache[identifier] = renamed_identifier

            renamed_bytes = renamed_identifier.encode('utf8')
            code_edits.extend((node.start_byte, node.end_byte, renamed_bytes) for node in nodes)
            counter += 1

        # Apply from the end of the buffer so earlier offsets stay valid.
        for start_byte, end_byte, renamed_bytes in sorted(code_edits, key=lambda edit: edit[0], reverse=True):
            code_bytes[start_byte:end_byte] = renamed_bytes

        return code_bytes.decode('utf8')

    def replace_function_definition(self, code: str, function_name: str, new_definition: str) -> str:
        """
        Replace the definition of *function_name* with *new_definition*.

        If the new definition declares a different name, every identifier
        exactly equal (case-sensitive) to the old name is renamed to it. This
        assumes the new function keeps the same signature.

        Args:
            code: Source code, or a path to an existing file.
            function_name: Name of the function whose definition is replaced
                (the lexically first matching definition is used).
            new_definition: Full source of the replacement definition.

        Returns:
            The updated code; or, when *code* is a file path, the path of the
            written copy (``<stem><output_postfix><suffix>`` next to the input).

        Raises:
            ValueError: If no definition of *function_name* exists.
        """
        source, from_file = self._load_source(code)
        code_bytes = bytearray(source, 'utf8')
        tree = self.parser.parse(bytes(code_bytes))

        function_node = next(
            (
                node
                for node in self._iter_function_definitions(tree.root_node)
                if self._function_name(node, code_bytes) == function_name
            ),
            None,
        )
        if function_node is None:
            raise ValueError(f"No function definition for '{function_name}' found in {code}")
        print(f"Found {function_name} in {code[:100]}...")

        code_bytes[function_node.start_byte:function_node.end_byte] = new_definition.encode("utf-8")
        updated_code = code_bytes.decode("utf-8")

        # Update references to the old name. Skipped when the new definition
        # has no recognizable function name (nothing to rename to).
        new_name = self.get_first_function_name(new_definition)
        if new_name is not None and new_name != function_name:
            updated_code = self.rename_identifiers(
                updated_code,
                transform=self.constant_transform(new_name),
                rename_pattern=rf"^{re.escape(function_name)}$",
                cache={},
                flags=0,  # exact, case-sensitive match of the old name only
            )

        if not from_file:
            return updated_code

        # Insert the postfix before the suffix without touching the rest of
        # the path, so the input file is never overwritten.
        source_path = Path(code)
        output_path = source_path.with_name(f"{source_path.stem}{self.output_postfix}{source_path.suffix}")
        output_path.write_text(updated_code, encoding="utf-8")
        return str(output_path)

    def remove_comments_from_source(self, code: str) -> str:
        """
        Remove all comment nodes from the given source using tree-sitter.

        Args:
            code: Source code, or a path to a file whose contents are used.

        Returns:
            The source text without comments. If there are no comments, or an
            error occurs (it is printed), the unmodified source text is returned.
        """
        source, _ = self._load_source(code)
        code_bytes = bytearray(source, 'utf8')

        try:
            tree = self.parser.parse(bytes(code_bytes))

            # Walk the tree directly instead of using a query, so this works
            # without ``c_language`` and across py-tree-sitter query API changes.
            comment_nodes = []
            stack = [tree.root_node]
            while stack:
                node = stack.pop()
                if node.type == "comment":
                    comment_nodes.append(node)
                else:
                    stack.extend(node.children)

            if not comment_nodes:
                return source

            # Delete from the end so earlier byte offsets stay valid.
            comment_nodes.sort(key=lambda node: node.start_byte, reverse=True)
            for node in comment_nodes:
                del code_bytes[node.start_byte:node.end_byte]

            return code_bytes.decode('utf8')

        except Exception as e:
            print(f"Error removing comments:\n{e}")
            return source

    def process_row(self, row, function_column='function', new_function_column='generated_function', code_column='file', output_column='updated_file'):
        """
        Replace a function in a row's code/file with the row's new definition.

        Args:
            row: The DataFrame row (modified in place and returned).
            function_column: Column holding the name of the function to replace.
            new_function_column: Column holding the new function definition.
            code_column: Column holding the source code or a file path.
            output_column: Column receiving the result of
                ``replace_function_definition`` (None on skip or failure).

        Returns:
            The updated row.
        """
        row[output_column] = None
        code = row[code_column]
        function_name = row[function_column]
        new_function = row[new_function_column]

        if pd.isna(code) or pd.isna(function_name) or pd.isna(new_function):
            return row

        if not code or not code.strip():
            print(f"Empty code in row: {row}")
            return row

        try:
            row[output_column] = self.replace_function_definition(code, function_name, new_function)
        except Exception as e:
            # Best effort: log and leave output_column as None.
            print(f"Error processing file {code}: {e}")
        return row

    # ---------------------------------------------------------------- debugging

    def print_tree_structure_recursive(self, node: Node, code_bytes: bytes, depth: int = 0, print_id = False):
        """
        Print *node* and its descendants, one per line, indented by depth.

        Node text longer than 70 characters is truncated with "...".
        """
        node_text = code_bytes[node.start_byte:node.end_byte].decode('utf8')
        if len(node_text) > 70:
            node_text = node_text[:70] + "..."

        # Replace newlines with spaces for compact display.
        node_text = node_text.replace('\n', ' ')
        id_text = f"id: {node.id}, " if print_id else ""
        print(f"{'  ' * depth}{id_text}{node.type}: {node_text}")

        for child in node.children:
            self.print_tree_structure_recursive(child, code_bytes, depth + 1, print_id=print_id)

    def print_tree_structure(self, code: str, print_id = False):
        """
        Print the tree-sitter syntax tree of the given code.

        Args:
            code: Source code, or a path to a file whose contents are used.
            print_id: Also print each node's tree-sitter id.
        """
        source, _ = self._load_source(code)
        code_bytes = bytearray(source, 'utf8')
        tree = self.parser.parse(bytes(code_bytes))
        self.print_tree_structure_recursive(tree.root_node, code_bytes, print_id=print_id)

# Create the shared FunctionReplacer instance used by the cells below
# (output files get the "_updated" postfix).
print("Loading data...")
replacer = FunctionReplacer(output_postfix="_updated")


Loading data...


In [7]:
# Source file to inspect -- replace with your actual file path.
file_path = "../temp_test_repos/libssh/src/chachapoly.c"
code = Path(file_path).read_text(encoding="utf-8")

# Names of every function defined in the file, in source order.
function_names = replacer.get_function_names(code)
print(function_names)

['chacha20_set_encrypt_key', 'chacha20_poly1305_aead_encrypt', 'chacha20_poly1305_aead_decrypt_length', 'chacha20_poly1305_aead_decrypt', 'chacha20_cleanup', 'ssh_get_chacha20poly1305_cipher']
