In [1]:
import os
import logging

from automata.singletons.dependency_factory import dependency_factory, DependencyFactory
from automata.singletons.py_module_loader import py_module_loader
from automata.symbol.graph import SymbolGraph
from automata.symbol_embedding.vector_databases import (
    ChromaSymbolEmbeddingVectorDatabase,
)
from automata.symbol_embedding.base import SymbolCodeEmbedding
from automata.llm.providers.openai import OpenAIEmbeddingProvider



In [2]:
# Reset the module loader, then initialize it for the project configured below.
py_module_loader.reset()

# The original values are kept as defaults so existing runs behave the same;
# set the environment variables to point the notebook at a different checkout
# without editing this cell.
# NOTE(review): the default path ends in "llama_index" while the project name is
# "langchain" -- confirm this pairing is intended.
root_path = os.environ.get(
    "AUTOMATA_DEMO_ROOT_PATH", "/Users/ocolegrove/repo_store/llama_index"
)
project_name = os.environ.get("AUTOMATA_DEMO_PROJECT_NAME", "langchain")

py_module_loader.initialize(root_path, project_name)

In [3]:
# Open the Chroma-backed code-embedding database for `project_name`, using the
# dependency factory's default embedding path as the persist directory.
embedding_db_options = {
    "persist_directory": DependencyFactory.DEFAULT_CODE_EMBEDDING_FPATH,
    "factory": SymbolCodeEmbedding.from_args,
}
code_embedding_db = ChromaSymbolEmbeddingVectorDatabase(
    project_name, **embedding_db_options
)

# Display how many entries the database currently returns.
stored_entries = code_embedding_db.get_ordered_entries()
len(stored_entries)

5321

In [4]:

# Optional: a symbol graph can also be supplied as an override. It is left
# disabled here; to enable it, build the graph and pass `symbol_graph=symbol_graph`
# to `set_overrides` below.
# symbol_graph = SymbolGraph(
#     os.path.join(DependencyFactory.DEFAULT_SCIP_FPATH, f"{project_name}.scip")
# )
embedding_provider = OpenAIEmbeddingProvider()

# Register the objects built in this notebook as overrides before requesting
# anything from the dependency factory.
dependency_factory.set_overrides(
    code_embedding_db=code_embedding_db,
    embedding_provider=embedding_provider,
)

symbol_code_embedding_handler = dependency_factory.get("symbol_code_embedding_handler")
embedding_similarity_calculator = dependency_factory.get("embedding_similarity_calculator")

code_embeddings = symbol_code_embedding_handler.get_ordered_entries()

query = "What is langchain?"

print(f"Demonstrating code embedding search results for query = '{query}'")
code_similarity_results = embedding_similarity_calculator.calculate_query_similarity_dict(
    code_embeddings, query
)

# Only the first ten entries of the result mapping are printed.
# NOTE(review): the label says "distance" but the values come from
# `calculate_query_similarity_dict` -- confirm which term is correct.
leading_results = list(code_similarity_results.items())[0:10]
for position, (symbol, score) in enumerate(leading_results):
    print(f"rank {position} = {symbol.dotpath} with distance {score:.3f}")
    

Demonstrating code embedding search results for query = 'What is langchain?'
rank 0 = langchain.callbacks.mlflow_callback.MlflowLogger.langchain_artifact with distance 0.794
rank 1 = langchain.callbacks.manager._get_debug with distance 0.786
rank 2 = langchain.chains.llm.LLMChain._chain_type with distance 0.780
rank 3 = langchain.chains.llm.LLMChain with distance 0.778
rank 4 = langchain.chains.llm_bash.base.LLMBashChain with distance 0.777
rank 5 = langchain.schema.output_parser.NoOpOutputParser.lc_serializable with distance 0.777
rank 6 = langchain.chains.natbot.base.NatBotChain with distance 0.776
rank 7 = langchain.chains.llm_math.base.LLMMathChain with distance 0.775
rank 8 = langchain.chains.sql_database.base.SQLDatabaseChain with distance 0.774
rank 9 = langchain.chains.llm_math.base.LLMMathChain._chain_type with distance 0.773
