In [2]:
import logging
from typing import Any, Set

from automata.config.base import AgentConfigName
from automata.core.singletons.dependency_factory import dependency_factory
from automata.core.singletons.py_module_loader import py_module_loader
from automata.core.context_providers.symbol_synchronization import (
    SymbolProviderSynchronizationContext,
)
    
from automata.core.tools.base import Tool
from automata.core.tools.builders.context_oracle import ContextOracleToolkitBuilder

# Logger scoped to this notebook's module name.
logger = logging.getLogger(name=__name__)

# Initialise the shared module loader before later cells request anything from
# `dependency_factory`.
# NOTE(review): re-running this cell calls initialize() a second time — confirm
# the loader tolerates repeated initialization.
py_module_loader.initialize()

In [19]:
# Pull the embedding components used by the context oracle from the dependency
# factory; each global is bound to the dependency of the same name, requested in
# the order: code handler, doc handler, similarity calculator.
(
    symbol_code_embedding_handler,
    symbol_doc_embedding_handler,
    embedding_similarity_calculator,
) = (
    dependency_factory.get(dependency_name)
    for dependency_name in (
        'symbol_code_embedding_handler',
        'symbol_doc_embedding_handler',
        'embedding_similarity_calculator',
    )
)

In [11]:
# `symbol_search` was previously only available from hidden kernel state (this cell
# ran as In [11], before the In [19] cell above), so Restart & Run All raised a
# NameError here. Fetch it explicitly from the dependency factory instead.
symbol_search = dependency_factory.get('symbol_search')

# Register the symbol graph and the search embedding handler with the
# synchronization context and synchronize them (presumably aligning their
# supported symbols — confirm against SymbolProviderSynchronizationContext).
with SymbolProviderSynchronizationContext() as synchronization_context:
    synchronization_context.register_provider(symbol_search.symbol_graph)
    synchronization_context.register_provider(symbol_search.search_embedding_handler)
    synchronization_context.synchronize()

graph_symbols = symbol_search.symbol_graph.get_sorted_supported_symbols()
embedding_symbols = symbol_search.search_embedding_handler.get_sorted_supported_symbols()

# Symbols supported by both providers; `intersection` accepts any iterable, so the
# second list does not need to be converted to a set first.
available_symbols = set(graph_symbols).intersection(embedding_symbols)

In [23]:

# Build the context-oracle toolkit from the embedding components fetched from the
# dependency factory in the earlier cell (argument order is unchanged).
context_oracle = ContextOracleToolkitBuilder(
    symbol_doc_embedding_handler,
    symbol_code_embedding_handler,
    embedding_similarity_calculator,
)

In [24]:
# Ask the context oracle for context relevant to a refactoring request.
# NOTE(review): `_get_context` is a private method of ContextOracleToolkitBuilder;
# confirm there is no public equivalent before relying on it outside this notebook.
context = context_oracle._get_context(
    "Provide new code which refactors the SymbolDocEmbeddingHandler class to be more robust."
)

In [25]:
# Use print (not a bare expression) so the multi-line context text renders with
# real line breaks rather than as an escaped string repr.
print(context)

class SymbolDocEmbeddingHandler(SymbolEmbeddingHandler):
    """A class to handle the embedding of symbols"""

    def __init__(
        self,
        embedding_db: JSONSymbolEmbeddingVectorDatabase,
        embedding_builder: SymbolDocEmbeddingBuilder,
    ) -> None:
        super().__init__(embedding_db, embedding_builder)

    def get_embedding(self, symbol: Symbol) -> SymbolDocEmbedding:
        return self.embedding_db.get(symbol.dotpath)

    def process_embedding(self, symbol: Symbol) -> None:
        """
        Process the embedding for a `Symbol` -
        Currently we do nothing if the symbol is already contained
        """

        source_code = self.embedding_builder.fetch_embedding_source_code(symbol)

        if not source_code:
            raise ValueError(f"Symbol {symbol} has no source code")

        if self.embedding_db.contains(symbol.dotpath):
            self.update_existing_embedding(source_code, symbol)
            return
        # else:
        symbol_embeddi