In [1]:
import os
import logging
from typing import Any, Set

from automata.config.base import AgentConfigName
from automata.singletons.dependency_factory import dependency_factory, DependencyFactory
from automata.singletons.py_module_loader import py_module_loader
from automata.context_providers.symbol_synchronization import (
    SymbolProviderSynchronizationContext,
)
from automata.symbol.graph import SymbolGraph
from automata.symbol_embedding.vector_databases import (
    ChromaSymbolEmbeddingVectorDatabase,
)
from automata.symbol_embedding.base import SymbolCodeEmbedding, SymbolDocEmbedding
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.experimental.search.rank import SymbolRank, SymbolRankConfig

# Module-level logger named after this notebook/module (not referenced by the cells shown here).
logger = logging.getLogger(name=__name__)


In [2]:
# Project whose symbols are loaded and searched.
# Override via environment variables instead of editing absolute paths in place.
# Alternative project previously toggled here:
#   root_path = "/Users/ocolegrove/repo_store/llama_index", project_name = "llama_index"
root_path = os.environ.get(
    "AUTOMATA_NOTEBOOK_ROOT_PATH", "/Users/ocolegrove/AutomataNoCruft"
)
project_name = os.environ.get("AUTOMATA_NOTEBOOK_PROJECT_NAME", "automata")
py_module_loader.initialize(root_path, project_name)

In [4]:
# Build the project-specific providers and register them as overrides on the
# shared dependency factory, so later `dependency_factory.get(...)` calls
# receive these instances.
code_embedding_db = ChromaSymbolEmbeddingVectorDatabase(
    project_name,
    persist_directory=DependencyFactory.DEFAULT_CODE_EMBEDDING_FPATH,
    factory=SymbolCodeEmbedding.from_args,
)

# The SCIP index for the project is expected at <DEFAULT_SCIP_FPATH>/<project_name>.scip.
scip_index_path = os.path.join(
    DependencyFactory.DEFAULT_SCIP_FPATH, f"{project_name}.scip"
)
symbol_graph = SymbolGraph(scip_index_path)
embedding_provider = OpenAIEmbeddingProvider()

dependency_factory.set_overrides(
    symbol_graph=symbol_graph,
    code_embedding_db=code_embedding_db,
    embedding_provider=embedding_provider,
)



self.overrides.get("disable_synchronization", False) =  False
provider =  <automata.symbol.graph.SymbolGraph object at 0x159f05b50>
self.overrides =  {'symbol_graph': <automata.symbol.graph.SymbolGraph object at 0x159f05b50>, 'code_embedding_db': <automata.symbol_embedding.vector_databases.ChromaSymbolEmbeddingVectorDatabase object at 0x159f05ac0>, 'embedding_provider': <automata.llm.providers.openai.OpenAIEmbeddingProvider object at 0x10808cfd0>}


In [5]:
# Import both helpers here: `convert_to_ast_object` is used by the AST-inspection
# loop further down, which otherwise only works if a later duplicate cell (or
# leftover kernel state) happened to import it.
from automata.symbol.symbol_utils import convert_to_ast_object, get_rankable_symbols

# Keep only the graph's supported symbols that `get_rankable_symbols` accepts.
filtered_symbols = get_rankable_symbols(symbol_graph.get_sorted_supported_symbols())
print(f'len(filtered_symbols) = {len(filtered_symbols)}')



len(filtered_symbols) = 645


In [9]:
# Recompute the rankable subset; this cell also imports `convert_to_ast_object`,
# which the AST-inspection loop below uses.
# NOTE(review): the saved outputs show 645 symbols from the earlier identical
# cell but 0 here, so the graph or filter state apparently changed between
# runs. Investigate before relying on `filtered_symbols`.
from automata.symbol.symbol_utils import (
    convert_to_ast_object,
    get_rankable_symbols,
)

supported_symbols = symbol_graph.get_sorted_supported_symbols()
filtered_symbols = get_rankable_symbols(supported_symbols)
print(f"len(filtered_symbols) = {len(filtered_symbols)}")


len(filtered_symbols) = 0


In [7]:
# Print each rankable symbol next to the AST node it converts to, as a quick
# check that symbol -> AST conversion succeeds across the filtered set.
for position, symbol in enumerate(filtered_symbols):
    print(f"filtered_symbol = {symbol}, i={position}")
    ast_object = convert_to_ast_object(symbol)
    print(f"ast_object =  {ast_object}")

filtered_symbol = Symbol(scip-python python automata 6e304eca4cda41fbac2bcc8445ed191a85e8b260 `automata.agent.agent`/Agent#, scip-python, Package(python automata 6e304eca4cda41fbac2bcc8445ed191a85e8b260), (Descriptor(automata.agent.agent, 1), Descriptor(Agent, 2))), i=0
ast_object =  <ast.ClassDef object at 0x17995c6d0>
filtered_symbol = Symbol(scip-python python automata 6e304eca4cda41fbac2bcc8445ed191a85e8b260 `automata.agent.agent`/Agent#__init__()., scip-python, Package(python automata 6e304eca4cda41fbac2bcc8445ed191a85e8b260), (Descriptor(automata.agent.agent, 1), Descriptor(Agent, 2), Descriptor(__init__, 4))), i=1
ast_object =  <ast.FunctionDef object at 0x17995c7f0>
filtered_symbol = Symbol(scip-python python automata 6e304eca4cda41fbac2bcc8445ed191a85e8b260 `automata.agent.agent`/Agent#__iter__()., scip-python, Package(python automata 6e304eca4cda41fbac2bcc8445ed191a85e8b260), (Descriptor(automata.agent.agent, 1), Descriptor(Agent, 2), Descriptor(__iter__, 4))), i=2
ast_object

In [8]:
# Fetch the symbol search service from the dependency factory.
# NOTE(review): in the saved run this call raised
# `RuntimeError: Must synchronize symbol providers in synchronization context`
# while handling the SymbolCodeEmbeddingHandler provider. The imported-but-unused
# SymbolProviderSynchronizationContext is presumably meant to be set up first.
# TODO confirm the intended synchronization flow.
search_dependency_name = "symbol_search"
search = dependency_factory.get(search_dependency_name)

self.overrides.get("disable_synchronization", False) =  False
provider =  <automata.memory_store.symbol_code_embedding.SymbolCodeEmbeddingHandler object at 0x179a58550>


RuntimeError: Must synchronize symbol providers in synchronization context

In [None]:
# Fetch the code-embedding handler directly so its supported symbols can be inspected below.
handler_dependency_name = "symbol_code_embedding_handler"
symbol_code_embedding_handler = dependency_factory.get(handler_dependency_name)


In [None]:
# Number of symbols the code-embedding handler currently supports
# (left as the bare last expression so Jupyter displays it).
supported_code_symbols = symbol_code_embedding_handler.get_sorted_supported_symbols()
len(supported_code_symbols)

In [None]:
# Configure SymbolRank and run it for a natural-language query.
# alpha is presumably a PageRank-style damping factor and z_score_power a
# re-weighting exponent; both are passed through as-is (TODO confirm semantics
# in automata.experimental.search).
search.symbol_rank_config = SymbolRankConfig(alpha=0.85)
search.z_score_power = 10
query = "How do I build a storage context?"
symbol_rank_search_results = search.symbol_rank_search(query)

# Interpolate the query actually searched so the header cannot drift out of
# sync with it (it previously named a different query).
print(f"Demonstrating SymbolRank results for query = `{query}`\n")
for i, rank in enumerate(symbol_rank_search_results[:10]):
    # Each result is indexed as (symbol, score).
    print(f"rank {i} = {rank[0].dotpath} with rank {rank[1]:.3f}")

In [None]:
# Baseline for comparison: plain embedding similarity between the query and the
# stored code embeddings.
embedding_similarity_calculator = dependency_factory.get("embedding_similarity_calculator")
code_embeddings = symbol_code_embedding_handler.get_ordered_entries()

print(f"Demonstrating code embedding search results for query = `{query}`\n")
code_similarity_results = embedding_similarity_calculator.calculate_query_similarity_dict(
    code_embeddings, query
)
# zip with range(10) stops after the first ten entries without copying the whole
# result dict into a list.
# NOTE(review): this assumes the returned dict is already ordered best-first.
# The values are similarity scores (per the method name), not distances.
for i, rank in zip(range(10), code_similarity_results.items()):
    print(f"rank {i} = {rank[0].dotpath} with similarity {rank[1]:.3f}")
    
    