In [None]:
# Dependency bootstrap: upgrades pip, then installs every package this notebook
# imports, using the %pip magic once per package.
# NOTE(review): no versions are pinned, so a fresh run may pull different
# releases than the ones this notebook was written against.
%pip install --upgrade pip
%pip install unstructured
%pip install unstructured-ingest
%pip install llama-index
%pip install llama-index-readers-json
%pip install llama-index-readers-file
%pip install llama-index-graph-stores-neo4j
%pip install llama-index-embeddings-azure-openai
%pip install llama-index-llms-azure-openai
%pip install llama-index-llms-huggingface
%pip install llama-index-vector-stores-neo4jvector
%pip install llama-index-extractors-entity
%pip install neo4j
%pip install "transformers[torch]" "huggingface_hub[inference]"
%pip install accelerate
%pip install python-dotenv
%pip install llama-index-embeddings-huggingface
%pip install llama-index-embeddings-instructor

In [None]:
import logging
import warnings

# Suppress every Python warning for the remainder of the kernel session.
warnings.filterwarnings('ignore')

# Raise the threshold of the "neo4j" logger so only CRITICAL records pass.
neo4j_log = logging.getLogger("neo4j")
neo4j_log.setLevel(logging.CRITICAL)

In [None]:
import json
from src.classes.utils.DebugLogger import DebugLogger
from src.classes.utils.EnvLoader import EnvLoader
import os
import nest_asyncio
import nltk
from llama_index.core import PropertyGraphIndex, Settings
from llama_index.core import SimpleDirectoryReader
from typing import Dict

from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.vector_stores.types import VectorStore
import re
from llama_index.core.base.response.schema import Response
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SimpleFileNodeParser
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.schema import TransformComponent, NodeWithScore, Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.neo4jvector import Neo4jVectorStore
from neo4j import GraphDatabase
from llama_index.embeddings.openai import OpenAIEmbedding
from abc import ABC, abstractmethod
from typing import Union, Optional
from typing import List, Tuple, Type, Literal
from llama_index.core.graph_stores import PropertyGraphStore
from llama_index.core.indices.base import BaseIndex
from llama_index.legacy.chat_engine.types import BaseChatEngine, AgentChatResponse
from llama_index.llms.huggingface import HuggingFaceLLM

# Patch asyncio via nest_asyncio before any other setup runs.
# NOTE(review): presumably needed because the notebook kernel already runs an
# event loop — confirm.
nest_asyncio.apply()
# Download the two NLTK resources named below.
# NOTE(review): presumably required by the document readers/parsers — confirm.
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

# Project helper; presumably loads the env files found under "config" into
# os.environ (ModelManager and Neo4jDBManager read their settings via os.getenv).
EnvLoader(env_dir="config").load_env_files()

# Notebook-level logger (project class DebugLogger).
logger = DebugLogger(use_panel_for_errors=True)

# Backend selectors consumed by ModelManager.get_llm / get_embedding_model.
# get_llm accepts "openai", "azure" or "local"; get_embedding_model accepts
# "openai" or "local".
LLM_MODE = "openai"
EMBEDDING_MODE = "local"

In [None]:
class ModelManager:
    """
    Factory for the LLM and embedding backends used in this notebook
    (OpenAI, Azure OpenAI and Hugging Face).

    Model names and Azure connection details are read from environment variables
    at construction time. No fallback values are applied: an unset variable is
    stored, and later passed to the backend, as ``None``. Every model object is
    created on first access and reused afterwards.
    """

    def __init__(self) -> None:
        env = os.getenv

        # OpenAI chat model name.
        self.openai_model = env("OPENAI_MODEL_NAME_CHAT")

        # Azure OpenAI settings. The model name is read from the same variable
        # as the plain OpenAI model name.
        self.azure_model = env("OPENAI_MODEL_NAME_CHAT")
        self.azure_deployment_name = env("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
        self.azure_api_key = env("AZURE_OPENAI_API_KEY")
        self.azure_endpoint = env("AZURE_OPENAI_ENDPOINT")
        self.azure_api_version = env("AZURE_OPENAI_API_VERSION")

        # Hugging Face model identifiers (LLM and embedding model).
        self.huggingface_model_name = env("HUGGINGFACE_LLM_MODEL")
        self.huggingface_embed_model_name = env("HUGGINGFACE_EMBED_MODEL")

        # Caches filled in by the properties below on first use.
        self._openai_llm = None
        self._azure_llm = None
        self._local_llm = None
        self._openai_embed_model = None
        self._local_embed_model = None

    @property
    def openai_llm(self) -> OpenAI:
        """OpenAI LLM, built on first access and cached."""
        if self._openai_llm is not None:
            return self._openai_llm
        self._openai_llm = OpenAI(model=self.openai_model)
        return self._openai_llm

    @property
    def azure_llm(self) -> AzureOpenAI:
        """Azure OpenAI LLM, built on first access and cached."""
        if self._azure_llm is not None:
            return self._azure_llm
        azure_settings = {
            "model": self.azure_model,
            "deployment_name": self.azure_deployment_name,
            "api_key": self.azure_api_key,
            "azure_endpoint": self.azure_endpoint,
            "api_version": self.azure_api_version,
        }
        self._azure_llm = AzureOpenAI(**azure_settings)
        return self._azure_llm

    @property
    def local_llm(self) -> HuggingFaceLLM:
        """Hugging Face LLM, built on first access and cached."""
        if self._local_llm is not None:
            return self._local_llm
        self._local_llm = HuggingFaceLLM(
            model_name=self.huggingface_model_name,
            device_map="auto",
        )
        return self._local_llm

    @property
    def openai_embed_model(self) -> OpenAIEmbedding:
        """OpenAI embedding model (fixed to "text-embedding-ada-002"), built on first access."""
        if self._openai_embed_model is not None:
            return self._openai_embed_model
        self._openai_embed_model = OpenAIEmbedding(model="text-embedding-ada-002")
        return self._openai_embed_model

    @property
    def local_embed_model(self) -> HuggingFaceEmbedding:
        """Hugging Face embedding model, built on first access and cached."""
        if self._local_embed_model is not None:
            return self._local_embed_model
        self._local_embed_model = HuggingFaceEmbedding(
            model_name=self.huggingface_embed_model_name
        )
        return self._local_embed_model

    def get_llm(self, llm_type: str) -> Union[OpenAI, AzureOpenAI, HuggingFaceLLM]:
        """
        Return the LLM that corresponds to ``llm_type``.

        :param llm_type: One of "openai", "azure" or "local".
        :return: The matching (lazily created) LLM instance.
        :raises ValueError: If ``llm_type`` is none of the supported values.
        """
        if llm_type == "openai":
            return self.openai_llm
        if llm_type == "azure":
            return self.azure_llm
        if llm_type == "local":
            return self.local_llm
        raise ValueError(f"Invalid llm_type '{llm_type}'. Expected 'openai', 'azure', or 'local'.")

    def get_embedding_model(self, embed_type: str) -> Union[OpenAIEmbedding, HuggingFaceEmbedding]:
        """
        Return the embedding model that corresponds to ``embed_type``.

        :param embed_type: Either "openai" or "local".
        :return: The matching (lazily created) embedding model instance.
        :raises ValueError: If ``embed_type`` is none of the supported values.
        """
        if embed_type == "openai":
            return self.openai_embed_model
        if embed_type == "local":
            return self.local_embed_model
        raise ValueError(f"Invalid embed_type '{embed_type}'. Expected 'openai' or 'local'.")

In [None]:
# Install the selected LLM and embedding model on llama_index's Settings object.
# Each line constructs its own ModelManager, so the two instances share no
# cached models; each one only builds the single model requested from it.
Settings.llm = ModelManager().get_llm(LLM_MODE)
Settings.embed_model = ModelManager().get_embedding_model(EMBEDDING_MODE)

In [None]:
class Neo4jDBManager:
    """
    Holds the Neo4j connection settings and builds the graph and vector stores
    that use them.
    """

    def __init__(self, url: str = None, username: str = None, password: str = None, database: str = None):
        """
        Resolve the connection settings: an explicit (truthy) argument wins,
        otherwise the matching environment variable is used.

        :param url: Neo4j URL; falls back to NEO4J_URL, then "bolt://localhost:7687".
        :param username: Neo4j user; falls back to NEO4J_USERNAME, then "neo4j".
        :param password: Neo4j password; falls back to NEO4J_PASSWORD (no default).
        :param database: Database name; falls back to NEO4J_DATABASE, then "neo4j".
        :raises ValueError: If no password could be resolved.
        """
        self.logger = DebugLogger(use_panel_for_errors=True)
        self.url = url if url else os.getenv("NEO4J_URL", "bolt://localhost:7687")
        self.username = username if username else os.getenv("NEO4J_USERNAME", "neo4j")
        self.password = password if password else os.getenv("NEO4J_PASSWORD")
        self.database = database if database else os.getenv("NEO4J_DATABASE", "neo4j")

        self._validate_password()
        self.logger.success(f"Neo4jDBManager initialized with URL: '{self.url}', Database: '{self.database}'.")

    def _validate_password(self):
        """
        Ensure a non-empty password was resolved.

        :raises ValueError: If the password is missing or empty.
        """
        if self.password:
            return

        error_message = (
            "Neo4j password is required. Set it in the environment or pass it directly."
        )
        self.logger.error(error_message)
        raise ValueError(error_message)

    def create_graph_store(self) -> Neo4jPropertyGraphStore:
        """
        Build a Neo4jPropertyGraphStore from the stored connection settings.

        :return: The new Neo4jPropertyGraphStore instance.
        :raises RuntimeError: If the store cannot be created.
        """
        return self._create_store(Neo4jPropertyGraphStore, "Neo4jPropertyGraphStore")

    def create_vector_store(self, embedding_dimension: int = 384, hybrid_search: bool = True) -> Neo4jVectorStore:
        """
        Build a Neo4jVectorStore from the stored connection settings.

        :param embedding_dimension: Dimension of the stored embeddings, defaults to 384.
        :param hybrid_search: Whether hybrid search is enabled, defaults to True.
        :return: The new Neo4jVectorStore instance.
        :raises RuntimeError: If the store cannot be created.
        """
        vector_options = {
            "embedding_dimension": embedding_dimension,
            "hybrid_search": hybrid_search,
        }
        return self._create_store(Neo4jVectorStore, "Neo4jVectorStore", **vector_options)

    def _create_store(self, store_class: type, store_name: str, **kwargs):
        """
        Instantiate ``store_class`` with the connection settings plus ``kwargs``.

        :param store_class: Store class to instantiate.
        :param store_name: Human-readable name used in log and error messages.
        :param kwargs: Extra keyword arguments forwarded to ``store_class``.
        :return: The new store instance.
        :raises RuntimeError: If instantiation fails; the original exception is chained.
        """
        connection = {
            "username": self.username,
            "password": self.password,
            "url": self.url,
            "database": self.database,
        }
        try:
            store_instance = store_class(**connection, **kwargs)
            self.logger.success(f"{store_name} instance created successfully.")
            return store_instance
        except Exception as e:
            error_message = f"Failed to create {store_name}: {e}"
            self.logger.error(error_message)
            raise RuntimeError(error_message) from e

In [None]:
# Resolve the Neo4j connection settings (arguments/environment) via Neo4jDBManager.
# The constructor raises ValueError when no password can be resolved.
db_config = Neo4jDBManager()

# Open a raw Neo4j driver with the resolved URL and credentials.
# Note: db_config.database is not passed to the driver here.
driver = GraphDatabase.driver(db_config.url, auth=(db_config.username, db_config.password))

In [None]:
# Function to reset the database completely
def reset_database(driver, database=None):
    """
    Wipe a Neo4j database: delete all nodes and relationships, then drop every
    constraint and every index.

    This is destructive and cannot be undone.

    :param driver: An open Neo4j driver; ``driver.session()`` is used to run the queries.
    :param database: Optional database name to run against. When omitted, the
        session is opened without a database argument (the driver's default).
    """

    def _quoted(name):
        # Cypher identifiers are quoted with backticks; a literal backtick is
        # escaped by doubling it. Quoting keeps names with spaces, dashes or
        # other special characters from breaking the DROP statement.
        return "`" + str(name).replace("`", "``") + "`"

    session_kwargs = {"database": database} if database else {}

    with driver.session(**session_kwargs) as session:
        # Clear all nodes and relationships
        session.run("MATCH (n) DETACH DELETE n")
        print("Data cleared.")

        # Drop all constraints. Names are collected first so the listing is
        # fully read before any DROP statement runs on the same session.
        constraint_names = [record["name"] for record in session.run("SHOW CONSTRAINTS")]
        for constraint_name in constraint_names:
            session.run(f"DROP CONSTRAINT {_quoted(constraint_name)} IF EXISTS")
        print("All constraints dropped.")

        # Drop all indexes that remain after the constraints are gone.
        # IF EXISTS keeps the loop going if an index has already disappeared.
        index_names = [record["name"] for record in session.run("SHOW INDEXES")]
        for index_name in index_names:
            session.run(f"DROP INDEX {_quoted(index_name)} IF EXISTS")
        print("All indexes dropped.")


# Execute the reset function. WARNING: this wipes the database and runs on every
# "Run All". The try/finally guarantees the driver is released even when the
# reset raises part-way through.
try:
    reset_database(driver)
finally:
    # Close the driver connection
    driver.close()

In [None]:
class UnstructuredTransform(TransformComponent):
    """
    Transform step that runs the incoming documents through an IngestionPipeline
    whose first stage is a SimpleFileNodeParser, and returns the resulting nodes.
    """

    def __call__(self, docs, **kwargs):
        # Extra pipeline stages go here (e.g. an EntityExtractor); none are enabled.
        extra_stages = []

        stages = [SimpleFileNodeParser()] + extra_stages
        return IngestionPipeline(transformations=stages).run(documents=docs, show_progress=True)

In [None]:
class SchemaHandler:
    """
    Schema definitions for the smart-contract reentrancy knowledge graph:
    the allowed entity types, the allowed relation types, and the list of valid
    (entity, relation, entity) triples.

    Invariant: every entity and relation used in ``get_validation_schema()`` is
    also listed in ``get_entities()`` / ``get_relations()``. A triple whose
    relation is missing from ``get_relations()`` can never be produced by an
    extractor restricted to that relation list.
    """

    @staticmethod
    def get_validation_schema() -> List[Tuple[str, str, str]]:
        """
        Retrieve the validation schema defining valid triples in the knowledge graph
        for reentrancy detection.

        :return: A list of tuples representing valid (entity, relation, entity) triples.
        """
        return [
            # Smart Contract-specific triples
            ("SMART_CONTRACT", "CONTAINS", "FUNCTION"),
            ("SMART_CONTRACT", "DEPLOYS", "SMART_CONTRACT"),
            ("SMART_CONTRACT", "INTERACTS_WITH", "EXTERNAL_CONTRACT"),
            ("SMART_CONTRACT", "USES", "VARIABLE"),
            ("SMART_CONTRACT", "CALLS", "FUNCTION"),
            ("SMART_CONTRACT", "SUFFERED_FROM", "VULNERABILITY"),
            ("SMART_CONTRACT", "ASSOCIATED_WITH", "REENTRANCY_PATTERN"),

            # Function-related triples
            ("FUNCTION", "CALLS", "FUNCTION"),
            ("FUNCTION", "CALLS", "EXTERNAL_FUNCTION"),
            ("FUNCTION", "CONTAINS", "STATE_CHANGE"),
            ("FUNCTION", "READS", "VARIABLE"),
            ("FUNCTION", "WRITES", "VARIABLE"),
            ("FUNCTION", "USES", "REENTRANCY_PATTERN"),
            ("FUNCTION", "TRIGGERED_BY", "TRANSACTION"),

            # Vulnerability-related triples
            ("VULNERABILITY", "AFFECTS", "FUNCTION"),
            ("VULNERABILITY", "RELATED_TO", "REENTRANCY_PATTERN"),
            ("VULNERABILITY", "EXPLOITS", "STATE_CHANGE"),
            ("VULNERABILITY", "FOUND_IN", "SMART_CONTRACT"),

            # State Change-related triples
            ("STATE_CHANGE", "MODIFIES", "VARIABLE"),
            ("STATE_CHANGE", "LEADS_TO", "VULNERABILITY"),
            ("STATE_CHANGE", "TRIGGERED_BY", "CALL"),

            # Variable-related triples
            ("VARIABLE", "MODIFIED_BY", "FUNCTION"),
            ("VARIABLE", "READ_BY", "FUNCTION"),
            ("VARIABLE", "AFFECTED_BY", "STATE_CHANGE"),

            # Call-related triples
            ("CALL", "MAKES", "EXTERNAL_CALL"),
            ("CALL", "RETURNS", "VALUE"),
            ("CALL", "RESULTS_IN", "STATE_CHANGE"),

            # Reentrancy-related triples
            ("REENTRANCY_PATTERN", "IDENTIFIED_IN", "FUNCTION"),
            ("REENTRANCY_PATTERN", "LEADS_TO", "VULNERABILITY"),
            ("REENTRANCY_PATTERN", "EXPLOITED_BY", "CALL"),

            # Transaction-related triples
            ("TRANSACTION", "TRIGGERS", "FUNCTION"),
            ("TRANSACTION", "LEADS_TO", "STATE_CHANGE"),
            ("TRANSACTION", "RESULTS_IN", "VULNERABILITY"),
            ("TRANSACTION", "SENT_TO", "SMART_CONTRACT"),
        ]

    @staticmethod
    def get_entities() -> Type[str]:
        """
        Retrieve the possible entity types for the knowledge graph.

        :return: A Literal type whose arguments are the valid entity type names.
        """
        return Literal[
            "SMART_CONTRACT", "FUNCTION", "EXTERNAL_FUNCTION", "VARIABLE",
            "STATE_CHANGE", "CALL", "EXTERNAL_CALL", "REENTRANCY_PATTERN",
            "VULNERABILITY", "TRANSACTION", "EXTERNAL_CONTRACT", "VALUE"
        ]

    @staticmethod
    def get_relations() -> Type[str]:
        """
        Retrieve the possible relation types for the knowledge graph.

        The list covers every relation that appears in ``get_validation_schema()``.

        :return: A Literal type whose arguments are the valid relation type names.
        """
        return Literal[
            "CONTAINS", "CALLS", "READS", "WRITES", "MODIFIES", "TRIGGERED_BY",
            "LEADS_TO", "RESULTS_IN", "AFFECTS", "RELATED_TO", "FOUND_IN",
            "IDENTIFIED_IN", "EXPLOITED_BY", "SUFFERED_FROM", "DEPLOYS",
            "INTERACTS_WITH", "USES", "MAKES", "RETURNS", "SENT_TO",
            # Also referenced by the validation schema, so they must be listed
            # here for those triples to be producible.
            "ASSOCIATED_WITH", "EXPLOITS", "MODIFIED_BY", "READ_BY",
            "AFFECTED_BY", "TRIGGERS",
        ]

In [None]:
class KnowledgeManager(ABC):
    """
    Abstract base class for managing indexing, retrieval, and querying of knowledge data.
    Provides methods for indexing documents, creating a chat engine, and executing queries.

    Subclasses supply ``create_index`` (build an index from documents) and
    ``load_index`` (attach to an index that already exists in the store).
    """

    def __init__(self, store: PropertyGraphStore, storage_context: StorageContext) -> None:
        """
        Initialize the KnowledgeManager with a store and storage context.

        :param store: The storage backend for managing knowledge data.
        :param storage_context: Configuration context for the storage backend.
        """
        self.store = store
        self.storage_context = storage_context
        self.persist_dir: str = ""
        self.index: Optional[BaseIndex] = None
        self.chat_engine: Optional[BaseChatEngine] = None
        self.logger = DebugLogger(use_panel_for_errors=True)

    def get_index(self) -> Optional[BaseIndex]:
        """
        Get the current index, if available.

        :return: The current index or None if not initialized.
        """
        return self.index

    def get_query_engine(self) -> Optional[BaseChatEngine]:
        """
        Get the current chat engine, if available.

        :return: The current chat engine or None if not initialized.
        """
        return self.chat_engine

    def index_documents(self, documents: List[Document], reload_index: bool = False) -> None:
        """
        Index a list of documents into the knowledge store.

        Order of precedence:
        1. ``reload_index`` is True and an existing index loads -> nothing is re-indexed.
        2. An index is already held -> it is refreshed with ``documents`` and the
           method returns (no full rebuild).
        3. Otherwise a new index is created from ``documents``.

        :param documents: List of Document objects to be indexed.
        :param reload_index: Whether to reload an existing index if available.
        :raises Exception: If an error occurs during indexing (logged, then re-raised).
        """
        try:
            if reload_index and self.load_index():
                self.logger.success("Index loaded successfully. Skipping re-indexing.")
                return

            if self.index:
                # Incremental update only. Falling through to create_index()
                # here would rebuild the whole index straight after refreshing it.
                self.refresh_index(documents)
                return

            self.logger.info("Starting document indexing... This may take a while.")
            self.create_index(documents)
            self.logger.success("Document indexing completed successfully.")
        except Exception:
            self.logger.error("Error during document indexing:", exc_info=True)
            raise

    def refresh_index(self, documents: List[Document]):
        """
        Refresh the current index with ``documents``, loading the index first
        when none is held yet.

        :param documents: List of Document objects to refresh the index with.
        :raises ValueError: If no index is held and none could be loaded.
        :raises Exception: Any error raised while refreshing (logged, then re-raised).
        """
        try:
            if not self.index and self.load_index():
                self.logger.success("Index loaded successfully.")

            if self.index is None:
                # load_index() failed or was a no-op: fail with a clear message
                # instead of an AttributeError on None below.
                raise ValueError("Cannot refresh index: no index is initialized and none could be loaded.")

            self.logger.info("Re-indexing with new documents.")
            # NOTE(review): confirm that the index's refresh() actually applies
            # the `transformations` keyword to the inserted documents.
            self.index.refresh(documents, transformations=[UnstructuredTransform()])

        except Exception:
            self.logger.error("Error during new document indexing:", exc_info=True)
            raise

    def create_chat_engine(self) -> None:
        """
        Initialize the chat engine from the current index.

        :raises ValueError: If no index is initialized or the chat engine setup fails.
        """
        if not self.index:
            error_message = "Cannot create chat engine: Index is not initialized."
            self.logger.error(error_message)
            raise ValueError(error_message)

        try:
            self.chat_engine = self.index.as_chat_engine(verbose=True)
            self.logger.success("Query engine initialized successfully.")
        except Exception as e:
            error_message = f"Failed to initialize chat engine: {e}"
            self.logger.error(error_message, exc_info=True)
            raise ValueError(error_message) from e

    def execute_query(self, query: str) -> Optional[AgentChatResponse]:
        """
        Execute a query on the knowledge store and return the result.

        The chat engine is created on demand. Errors raised while chatting are
        logged and reported as ``None`` rather than propagated; errors raised
        while creating the chat engine do propagate.

        :param query: The query string to execute.
        :return: The response object if the query is successful, or None otherwise.
        :raises ValueError: If the chat engine has to be created and that fails.
        """
        if not self.chat_engine:
            self.logger.info("Query engine not initialized. Creating a new chat engine...")
            self.create_chat_engine()

        try:
            return self.chat_engine.chat(query)
        except Exception:
            self.logger.error("Query execution failed:", exc_info=True)
            return None

    @abstractmethod
    def create_index(self, documents: List[Document]) -> None:
        """
        Abstract method to create an index from a list of documents.

        :param documents: List of Document objects to be indexed.
        """
        pass

    @abstractmethod
    def load_index(self) -> bool:
        """
        Abstract method to load an existing index from storage.

        :return: True if the index was loaded successfully, False otherwise.
        """
        pass

In [None]:
class GraphManager(KnowledgeManager):
    """
    KnowledgeManager backed by a property graph store. Builds a
    PropertyGraphIndex from documents, or attaches to one that already exists
    in the store.
    """

    def __init__(self, store: PropertyGraphStore, storage_context: StorageContext) -> None:
        """
        Set up the manager on top of a graph store.

        :param store: The graph data storage backend.
        :param storage_context: Storage context handed to the index on creation.
        """
        super().__init__(store, storage_context)
        self.persist_dir = "graph_index"
        self.logger.info("GraphManager initialized with a graph store.")

    def create_index(self, documents: List[Document]) -> None:
        """
        Build a PropertyGraphIndex from ``documents`` and keep it in ``self.index``.

        Triples are extracted with a SchemaLLMPathExtractor configured from
        SchemaHandler (non-strict, at most 3 triplets per chunk).

        :param documents: List of Document objects to be indexed.
        :raises RuntimeError: If indexing fails; the original exception is chained.
        """
        self.logger.info("Starting document indexing into the graph store. This may take some time.")

        try:
            # A fresh ModelManager supplies the LLM selected by LLM_MODE.
            schema_extractor = SchemaLLMPathExtractor(
                llm=ModelManager().get_llm(LLM_MODE),
                possible_entities=SchemaHandler.get_entities(),
                possible_relations=SchemaHandler.get_relations(),
                kg_validation_schema=SchemaHandler.get_validation_schema(),
                strict=False,
                max_triplets_per_chunk=3
            )
            self.index = PropertyGraphIndex.from_documents(
                documents=documents,
                kg_extractors=[schema_extractor],
                property_graph_store=self.store,
                storage_context=self.storage_context,
                embed_kg_nodes=True,
                show_progress=True,
                transformations=[UnstructuredTransform()],
            )
            self.logger.success("Document indexing completed successfully.")
        except Exception as e:
            self.logger.error("Error during document indexing:", exc_info=True)
            raise RuntimeError("Graph indexing failed.") from e

    def load_index(self) -> bool:
        """
        Attach to the index that already lives in the graph store.

        :return: True when the index was loaded, False when there is no store
            or loading raised.
        """
        if not self.store:
            self.logger.warning("No graph store is available. Unable to load index.")
            return False

        try:
            self.logger.info("Attempting to load the index from the graph store.")
            self.index = PropertyGraphIndex.from_existing(
                property_graph_store=self.store,
                embed_kg_nodes=True,
            )
            self.logger.success("Index loaded successfully from the graph store.")
        except Exception as e:
            self.logger.error("Error while loading the index:", exc_info=True)
            return False
        return True

In [None]:
class VectorManager(KnowledgeManager):
    """
    KnowledgeManager backed by a vector store. Builds a VectorStoreIndex from
    documents, or attaches to the data already held by the vector store.
    """

    def __init__(self, store: VectorStore, storage_context: StorageContext) -> None:
        """
        Set up the manager on top of a vector store.

        :param store: The vector store used for managing indexed data.
        :param storage_context: Storage context handed to the index on creation.
        """
        super().__init__(store, storage_context)
        self.persist_dir = "vector_index"
        self.logger.info("VectorManager initialized with vector storage backend.")

    def create_index(self, documents: List[Document]) -> None:
        """
        Build a VectorStoreIndex from ``documents`` and keep it in ``self.index``.

        :param documents: List of Document objects to be indexed.
        :raises RuntimeError: If indexing fails; the original exception is chained.
        """
        self.logger.info("Starting vector document indexing... This may take a while.")

        try:
            index_options = {
                "storage_context": self.storage_context,
                "show_progress": True,
                "store_nodes_override": True,
                "transformations": [UnstructuredTransform()],
            }
            self.index = VectorStoreIndex.from_documents(documents=documents, **index_options)

            # Debug aid: log the whole docstore content of the new index.
            indexed_docs = self.index.storage_context.docstore.docs
            self.logger.debug(f"Indexed documents: {indexed_docs}")
            self.logger.success("Vector document indexing completed successfully.")
        except Exception as e:
            self.logger.error("Error occurred during vector document indexing:", exc_info=True)
            raise RuntimeError("Vector document indexing failed.") from e

    def load_index(self) -> bool:
        """
        Attach to the data already held by the vector store.

        No check on ``self.index`` is made beforehand: this call is what
        populates it.

        :return: True when the index was loaded, False when loading raised.
        """
        try:
            self.logger.info("Attempting to load the index from the vector store.")
            self.index = VectorStoreIndex.from_vector_store(vector_store=self.store)
            self.logger.success("Index loaded successfully from the vector store.")
        except Exception as e:
            self.logger.error("Error occurred while loading the index:", exc_info=True)
            return False
        return True

In [None]:
class SmartContractProcessor:
    """
    Pre-processes Solidity source files into JSON batches of function-level
    chunks, each chunk carrying metadata about the contract it came from.
    """

    # Matches, in order of priority: a double-quoted string, a single-quoted
    # string, a `//` line comment, a `/* ... */` block comment. Strings are
    # matched so that comment markers inside them (e.g. "http://...") are
    # recognised as string content and left alone.
    _STRING_OR_COMMENT_RE = re.compile(
        r'"(?:\\.|[^"\\\n])*"'
        r"|'(?:\\.|[^'\\\n])*'"
        r"|//[^\n]*"
        r"|/\*.*?\*/",
        re.DOTALL,
    )

    # Header of a named function that has a body: `function name(args) <anything
    # except { } ;> {`. The free-form middle part admits any mix of visibility,
    # mutability, modifiers and a `returns (...)` clause. Body-less declarations
    # (ending in `;`) do not match. No capturing groups are used.
    _FUNCTION_HEADER_RE = re.compile(
        r"\bfunction\s+[a-zA-Z0-9_]+\s*\([^)]*\)[^{};]*\{"
    )

    def __init__(self, input_dir: str, output_dir: str, batch_size: int = 10):
        """
        Initialize the batch processor.

        :param input_dir: Directory containing Solidity files.
        :param output_dir: Directory to save processed JSON batches.
        :param batch_size: Number of contracts to process per batch (must be >= 1).
        :raises ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A zero step would crash range() later; a negative one would
            # silently process nothing.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.batch_size = batch_size

    @staticmethod
    def load_contract(file_path: str) -> str:
        """
        Load the Solidity smart contract from a file.

        The file is decoded as latin-1, which maps every byte to a character and
        therefore never raises a decoding error.

        :param file_path: Path of the Solidity file.
        :return: The file content as text.
        """
        with open(file_path, "r", encoding="latin-1") as f:
            return f.read()

    @staticmethod
    def remove_comments(code: str) -> str:
        """
        Remove single-line (`//`) and multi-line (`/* */`) comments.

        Comment markers that occur inside string literals are preserved.

        :param code: Solidity source text.
        :return: The source text with comments deleted.
        """

        def _keep_strings(match):
            token = match.group(0)
            # String literals start with a quote; everything else is a comment.
            return token if token[0] in "\"'" else ""

        return SmartContractProcessor._STRING_OR_COMMENT_RE.sub(_keep_strings, code)

    @staticmethod
    def extract_metadata(code: str) -> Dict:
        """
        Extract metadata such as contract name, Solidity version, and functions.

        :param code: Solidity source text (ideally with comments removed).
        :return: Dict with keys ``solidity_version`` (pragma value or "unknown"),
            ``contract_name`` (first contract found or "UnnamedContract") and
            ``functions`` (names of all named functions, in source order).
        """
        metadata = {}
        # Extract Solidity version
        pragma_match = re.search(r"pragma\s+solidity\s+([^;]+);", code)
        metadata["solidity_version"] = pragma_match.group(1) if pragma_match else "unknown"

        # Extract contract name; \b keeps words that merely end in "contract"
        # from matching.
        contract_match = re.search(r"\bcontract\s+([a-zA-Z0-9_]+)", code)
        metadata["contract_name"] = contract_match.group(1) if contract_match else "UnnamedContract"

        # Extract function names
        metadata["functions"] = re.findall(r"\bfunction\s+([a-zA-Z0-9_]+)", code)

        return metadata

    @staticmethod
    def split_into_chunks(code: str) -> List[Dict]:
        """
        Split the Solidity code into chunks, one per named function with a body.

        Each chunk starts at a function header and runs up to the next function
        header (or the end of the code). Text before the first function is not
        part of any chunk.

        :param code: Solidity source text.
        :return: List of ``{"content": <chunk text>}`` dicts; empty when no
            function is found.
        """
        starts = [m.start() for m in SmartContractProcessor._FUNCTION_HEADER_RE.finditer(code)]
        # Each chunk ends where the next one begins; the last one ends at EOF.
        bounds = starts + [len(code)]
        return [{"content": code[begin:end]} for begin, end in zip(bounds, bounds[1:])]

    def process_contract(self, file_path: str) -> List[Dict]:
        """
        Process a single Solidity contract.

        :param file_path: Path of the Solidity file.
        :return: One dict per function chunk with keys ``content`` (stripped
            chunk text) and ``metadata`` (the contract-level metadata dict,
            shared by all chunks of the file).
        """
        # Load and clean contract
        code = self.remove_comments(self.load_contract(file_path))

        # Contract-level metadata, attached to every chunk
        metadata = self.extract_metadata(code)

        return [
            {"content": chunk["content"].strip(), "metadata": metadata}
            for chunk in self.split_into_chunks(code)
        ]

    def process_batch(self):
        """
        Process contracts in batches and save them as JSON.

        Every ``.sol`` file in ``input_dir`` is processed; files are taken in
        sorted order so that batch contents are reproducible. Batch ``k`` is
        written to ``<output_dir>/batch_<k>.json``.
        """
        all_files = [
            os.path.join(self.input_dir, f)
            for f in sorted(os.listdir(self.input_dir))
            if f.endswith(".sol")
        ]
        os.makedirs(self.output_dir, exist_ok=True)

        for i in range(0, len(all_files), self.batch_size):
            batch_number = i // self.batch_size + 1
            batch_results = []

            for file in all_files[i:i + self.batch_size]:
                print(f"Processing: {file}")
                batch_results.extend(self.process_contract(file))

            # Save the batch results
            batch_output_path = os.path.join(self.output_dir, f"batch_{batch_number}.json")
            with open(batch_output_path, "w", encoding="utf-8") as f:
                json.dump(batch_results, f, indent=4)

            print(f"Saved batch {batch_number} to {batch_output_path}")


In [None]:
class RAG:
    """
    A hybrid retrieval pipeline utilizing Neo4j for both vector-based document retrieval
    and knowledge graph storage. Supports structured and unstructured query handling.

    The base class owns the logger, the database manager and an initially unset
    knowledge manager; subclasses override ``_initialize_managers`` to install a
    concrete knowledge manager.
    """

    def __init__(self, logger: Optional[DebugLogger] = None, db_manager: Optional[Neo4jDBManager] = None) -> None:
        """
        Initializes the RAG pipeline with components for graph and vector-based retrieval.

        :param logger: Logger to use; a ``DebugLogger(use_panel_for_errors=True)`` is created when omitted.
        :param db_manager: Instance of Neo4jDBManager for database interaction; a new one is created when omitted.
        """
        self.logger = logger or DebugLogger(use_panel_for_errors=True)
        self.db_manager = db_manager or Neo4jDBManager()
        self.chat_engine = None
        # Reset before _initialize_managers so a failed setup leaves it as None.
        self.knowledge_manager = None
        self._initialize_managers()

    def _initialize_managers(self) -> None:
        """
        Initializes necessary managers (e.g., RAG Manager).
        Placeholder for future initialization logic; the base implementation only
        logs and leaves ``knowledge_manager`` unset.
        """
        # Example initialization (Replace with actual manager setup)
        self.logger.info("Initializing RAG managers...")
        # self.knowledge_manager = RAGManager(self.db_manager)

    def load_and_index_documents(self, folder_path: str, reload_index: bool = False) -> None:
        """
        Loads, chunks, and indexes documents into the Neo4j vector store and knowledge graph.

        :param folder_path: Path to the folder containing document files.
        :param reload_index: If True, no files are read from ``folder_path``; an empty
            document list is handed to the knowledge manager together with this flag.
        """
        self.logger.info(f"{'Reloading' if reload_index else 'Loading'} documents from: {folder_path}")

        docs = [] if reload_index else self._load_docs(folder_path)

        # An empty folder is only a problem when we were asked to load from it.
        if not docs and not reload_index:
            self.logger.warning("No documents available for indexing.")
            return

        self._index(docs, reload_index=reload_index)

    @staticmethod
    @DebugLogger.profile
    def _load_docs(folder_path: str) -> List[Document]:
        """
        Loads documents from the specified folder.

        Files are read with ``SimpleDirectoryReader`` using latin-1 decoding and
        strict error handling.

        :param folder_path: Path to the folder containing document files.
        :return: List of loaded Document objects.
        """
        reader = SimpleDirectoryReader(input_dir=folder_path, errors="strict", encoding="latin-1")
        return reader.load_data(show_progress=True)

    @DebugLogger.profile
    def _index(self, docs: List[Document], reload_index: bool) -> None:
        """
        Indexes the given documents through the knowledge manager.

        Any exception raised while indexing is logged and not re-raised.

        :param docs: List of Document objects to index.
        :param reload_index: If True, reloads the index.
        """
        try:
            self.logger.info("Indexing documents into Neo4j...")
            self.knowledge_manager.index_documents(docs, reload_index)
            self.logger.success("Document indexing completed successfully.")
        except Exception as e:
            self.logger.error(f"Error during indexing: {e}", exc_info=True)

    def _initialize_chat_engine(self) -> None:
        """
        Sets up the RAG chat engine for document retrieval.

        On failure the error is logged and ``chat_engine`` keeps its previous value.
        """
        try:
            self.logger.info("Initializing chat engine...")
            self.knowledge_manager.create_chat_engine()
            self.chat_engine = self.knowledge_manager.get_query_engine()
            self.logger.success("Query engine initialized successfully.")
        except Exception as e:
            self.logger.error(f"Failed to initialize chat engine: {e}", exc_info=True)

    def as_chat_engine(self) -> Optional[RetrieverQueryEngine]:
        """
        Returns the query engine, creating it on first use.

        :return: The engine obtained from the knowledge manager, or None when
            initialization failed (the failure is logged).
        """
        if not self.chat_engine:
            self._initialize_chat_engine()
        return self.chat_engine

    @DebugLogger.profile
    def query(self, question: str) -> Optional[Response]:
        """
        Executes a query through the knowledge manager.

        :param question: The input query as a string.
        :return: Query response or None in case of an error.
        """
        if not self.chat_engine:
            self._initialize_chat_engine()

        try:
            self.logger.info(f"Executing query: {question}")
            response = self.knowledge_manager.execute_query(question)
            self.logger.success("Query executed successfully.")
            return response
        except Exception as e:
            self.logger.error(f"Error during query execution: {e}", exc_info=True)
            return None

    @staticmethod
    def fetch_sources(nodes: List[NodeWithScore], similarity_cutoff: float = 0.75) -> List[str]:
        """
        Filters nodes by similarity score and extracts unique source filenames.

        :param nodes: List of nodes from the query response.
        :param similarity_cutoff: Cutoff passed to ``SimilarityPostprocessor``; defaults to 0.75.
        :return: Unique filenames from the filtered nodes, in the order the
            postprocessor returned them (first occurrence wins).
        """
        processor = SimilarityPostprocessor(similarity_cutoff=similarity_cutoff)
        filtered_nodes = processor.postprocess_nodes(nodes)
        file_names = (
            node.node.metadata["file_name"]
            for node in filtered_nodes
            if "file_name" in node.node.metadata
        )
        # dict.fromkeys de-duplicates while keeping insertion order, so the result
        # is stable between runs (a set would not be).
        return list(dict.fromkeys(file_names))

In [None]:
class GraphRAG(RAG):
    """
    A specialized implementation of the RAG (Retrieval-Augmented Generation) pipeline
    that leverages Neo4j for both vector-based document retrieval and knowledge graph storage.

    This class supports structured and unstructured query handling by integrating graph-based
    storage and retrieval mechanisms with the RAG pipeline.
    """

    def __init__(self, logger: Optional[DebugLogger] = None, db_manager: Optional[Neo4jDBManager] = None) -> None:
        """
        Initializes the GraphRAG pipeline with components for graph-based and vector-based retrieval.

        :param logger: Optional logger forwarded to the base class.
        :param db_manager: Optional Neo4jDBManager forwarded to the base class.
        """
        super().__init__(logger=logger, db_manager=db_manager)

    def _initialize_managers(self) -> None:
        """
        Configures the graph manager and its associated storage context.

        Creates a graph store through the database manager, wraps it in a storage
        context, and installs a ``GraphManager`` as ``knowledge_manager``.

        Failures are logged and not re-raised; in that case ``knowledge_manager``
        is left unassigned by this method.
        """
        try:
            self.logger.info("Initializing graph manager...")

            # Create the graph store and its associated storage context
            graph_store = self.db_manager.create_graph_store()
            graph_storage_context = StorageContext.from_defaults(graph_store=graph_store)

            # Initialize the GraphManager
            self.knowledge_manager = GraphManager(graph_store, graph_storage_context)

            self.logger.success("Graph manager initialized successfully.")
        except Exception as e:
            self.logger.error(f"Failed to initialize graph manager: {e}", exc_info=True)

In [None]:
class VectorRAG(RAG):
    """
    A specialized implementation of the Retrieval-Augmented Generation (RAG) pipeline
    that focuses on vector-based document retrieval.

    This class integrates vector-based storage and retrieval functionality, supporting
    efficient query execution and management of vector embeddings.
    """

    def __init__(self, logger: Optional[DebugLogger] = None, db_manager: Optional[Neo4jDBManager] = None) -> None:
        """
        Initializes the VectorRAG pipeline by setting up components for vector-based retrieval.

        :param logger: Optional logger forwarded to the base class.
        :param db_manager: Optional Neo4jDBManager forwarded to the base class.
        """
        super().__init__(logger=logger, db_manager=db_manager)

    def _initialize_managers(self) -> None:
        """
        Configures the vector manager and its associated storage context.

        Creates a vector store through the database manager, wraps it in a storage
        context, and installs a ``VectorManager`` as ``knowledge_manager``.

        Failures are logged and not re-raised; in that case ``knowledge_manager``
        is left unassigned by this method.
        """
        try:
            self.logger.info("Initializing vector manager...")

            # Create the vector store and its associated storage context
            vector_store = self.db_manager.create_vector_store()
            vector_storage_context = StorageContext.from_defaults(vector_store=vector_store)

            # Initialize the VectorManager with the configured store and context
            self.knowledge_manager = VectorManager(vector_store, vector_storage_context)

            self.logger.success("Vector manager initialized successfully.")
        except Exception as e:
            self.logger.error(f"Failed to initialize vector manager: {e}", exc_info=True)

In [None]:
# Build the vector-based pipeline; construction also attempts to set up its
# vector manager (setup errors are logged rather than raised).
rag = VectorRAG()

In [None]:
# Load and index the contracts stored under the "reentrant" source folder.
path_to_reentrant = os.path.join("..", "dataset", "manually-verified-tiny", "source", "reentrant")
rag.load_and_index_documents(path_to_reentrant)

In [None]:
# Load and index the contracts stored under the "safe" source folder.
path_to_safe = os.path.join("..", "dataset", "manually-verified-tiny", "source", "safe")
rag.load_and_index_documents(path_to_safe)

In [None]:
# Instruction template for the reentrancy classification query. The contract to
# analyze is concatenated directly after the trailing "### Input" heading.
prompt = """ You must follow these steps:

1. **Retrieve Examples**:
   Search your knowledge base for relevant examples of Solidity smart contracts labeled as **reentrant** or **non-reentrant**. Focus on:
   - Contracts with **reentrancy vulnerabilities**, such as making external calls (`call`, `delegatecall`, `transfer`) before updating state variables.
   - Contracts that use **mitigations** like the *checks-effects-interactions* pattern, `ReentrancyGuard` modifiers, or mutex locks.

   Provide **contract snippets** and **explanations** of why these examples were labeled as reentrant or non-reentrant.

2. **Analyze the Target Contract**:
   Carefully analyze the **input Solidity contract** to identify:
   - Use of external calls (`msg.sender.call`, `delegatecall`, `send`, etc.).
   - Whether state variables are updated **before** or **after** the external call.
   - Reentrancy mitigations like `ReentrancyGuard` modifiers or the *checks-effects-interactions* pattern.

3. **Classify**:
   Based on the retrieved examples and your analysis, classify the target contract as:
   - **Reentrant**: If it contains vulnerabilities that allow external calls before updating state variables.
   - **Non-Reentrant**: If it uses proper safeguards or patterns to prevent reentrancy.

4. **Justify the Classification**:
   Explain your reasoning in detail. Compare the patterns you observed in the target contract with the retrieved examples. Highlight specific lines or functions that led to your conclusion.

5. **Output**:
   Return the result in the following structured JSON format:

---

### Output Format

```json
{
  "classification": "Reentrant / Non-Reentrant",
  "justification": "Provide a detailed explanation of your reasoning.",
  "retrieved_examples": [
    {
      "contract_snippet": "Relevant Solidity contract code or fragment.",
      "label": "Reentrant / Non-Reentrant",
      "explanation": "Why this contract is classified as such."
    },
    {
      "contract_snippet": "Relevant Solidity contract code or fragment.",
      "label": "Reentrant / Non-Reentrant",
      "explanation": "Why this contract is classified as such."
    }
  ],
  "analysis": "Key observations about the target contract, including function behaviors, external calls, and state updates."
}
```

---

### Input

"""

In [96]:
# Sample Solidity contract used as the query input; its withdraw() sends Ether
# before zeroing the caller's balance.
test_contract = """
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

contract ReentrantBank {
    mapping(address => uint256) public balances;

    // Deposit function
    function deposit() public payable {
        balances[msg.sender] += msg.value;
    }

    // Withdraw function (Vulnerable to Reentrancy)
    function withdraw() public {
        uint256 amount = balances[msg.sender];

        require(amount > 0, "Insufficient balance");

        // External call to the sender before updating state
        (bool success, ) = msg.sender.call{value: amount}("");
        require(success, "Transfer failed");

        // Update state AFTER external call (vulnerability)
        balances[msg.sender] = 0;
    }

    // View balance
    function getBalance() public view returns (uint256) {
        return balances[msg.sender];
    }
}
"""

# Ask the pipeline to classify the sample contract. rag.query returns None when
# the query fails (the error is logged), so guard before reading source_nodes.
answer = rag.query(prompt + test_contract)
if answer is None:
    sources = []
    print("Query failed - see the logged error; no answer or sources available.")
else:
    sources = rag.fetch_sources(answer.source_nodes)
    print(f"{answer} - \n\n --> SOURCES: {sources}")

Added user message to memory:  You must follow these steps:

1. **Retrieve Examples**:
   Search your knowledge base for relevant examples of Solidity smart contracts labeled as **reentrant** or **non-reentrant**. Focus on:
   - Contracts with **reentrancy vulnerabilities**, such as making external calls (`call`, `delegatecall`, `transfer`) before updating state variables.
   - Contracts that use **mitigations** like the *checks-effects-interactions* pattern, `ReentrancyGuard` modifiers, or mutex locks.

   Provide **contract snippets** and **explanations** of why these examples were labeled as reentrant or non-reentrant.

2. **Analyze the Target Contract**:
   Carefully analyze the **input Solidity contract** to identify:
   - Use of external calls (`msg.sender.call`, `delegatecall`, `send`, etc.).
   - Whether state variables are updated **before** or **after** the external call.
   - Reentrancy mitigations like `ReentrancyGuard` modifiers or the *checks-effects-interactions* pattern

```json
{
  "classification": "Reentrant",
  "justification": "The target contract, `ReentrantBank`, is classified as reentrant due to the vulnerability present in its `withdraw` function. The function makes an external call to `msg.sender.call{value: amount}('')` before updating the state variable `balances[msg.sender]`. This sequence allows an attacker to exploit the contract by re-entering the `withdraw` function before the balance is updated, enabling multiple withdrawals.",
  "retrieved_examples": [
    {
      "contract_snippet": "function withdrawFunds_re_ent17 (uint256 _weiToWithdraw) public {\n    require(balances_re_ent17[msg.sender] >= _weiToWithdraw);\n    (bool success,) = msg.sender.call.value(_weiToWithdraw)(\"\");\n    require(success);  //bug\n    balances_re_ent17[msg.sender] -= _weiToWithdraw;\n}",
      "label": "Reentrant",
      "explanation": "This contract is labeled as reentrant because it makes an external call to transfer Ether before updating the user's bala