Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Commit

Permalink
Feature/refactor for added clarity (#104)
Browse files Browse the repository at this point in the history
* More clarity in directory layout

* complete refactor for additional codebase clarity
  • Loading branch information
emrgnt-cmplxty committed Jun 30, 2023
1 parent a1fe32f commit 64a820c
Show file tree
Hide file tree
Showing 75 changed files with 289 additions and 317 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Examples of these classes are:
Code example for creating an instance of 'SymbolCodeEmbedding':
```python
import numpy as np
from automata.core.symbol.base import SymbolCodeEmbedding
from automata.core.symbol_embedding.base import SymbolCodeEmbedding
from automata.core.symbol.parser import parse_symbol

symbol_str = 'scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `automata.core.agent.agent_enums`/ActionIndicator#'
Expand Down
6 changes: 2 additions & 4 deletions automata/cli/scripts/run_code_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from automata.config.base import ConfigCategory
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.coding.py.module_loader import py_module_loader
from automata.core.embedding.code_embedding import (
SymbolCodeEmbeddingBuilder,
SymbolCodeEmbeddingHandler,
)
from automata.core.llm.providers.openai import OpenAIEmbeddingProvider
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import SymbolCodeEmbeddingBuilder
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
16 changes: 7 additions & 9 deletions automata/cli/scripts/run_doc_embedding_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,22 @@
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.embedding.code_embedding import (
SymbolCodeEmbeddingBuilder,
SymbolCodeEmbeddingHandler,
)
from automata.core.embedding.doc_embedding import (
SymbolDocEmbeddingBuilder,
SymbolDocEmbeddingHandler,
)
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import (
OpenAIChatCompletionProvider,
OpenAIEmbeddingProvider,
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.base import SymbolDescriptor
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import (
SymbolCodeEmbeddingBuilder,
SymbolDocEmbeddingBuilder,
)
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
16 changes: 7 additions & 9 deletions automata/cli/scripts/run_doc_embedding_l3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,22 @@
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.embedding.code_embedding import (
SymbolCodeEmbeddingBuilder,
SymbolCodeEmbeddingHandler,
)
from automata.core.embedding.doc_embedding import (
SymbolDocEmbeddingBuilder,
SymbolDocEmbeddingHandler,
)
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import (
OpenAIChatCompletionProvider,
OpenAIEmbeddingProvider,
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.base import SymbolDescriptor
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import (
SymbolCodeEmbeddingBuilder,
SymbolDocEmbeddingBuilder,
)
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions automata/config/prompt/doc_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@
│   └── vector.py
├── embedding
│   ├── __init__.py
│   ├── code_embedding.py
│   ├── doc_embedding.py
│   ├── symbol_code_embedding.py
│   ├── symbol_doc_embedding.py
│   ├── embedding_types.py
│   └── symbol_similarity.py
├── symbol
Expand Down
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_code_embedding.json
Git LFS file not shown
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_doc_embedding_l2.json
Git LFS file not shown
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_doc_embedding_l3.json
Git LFS file not shown
12 changes: 6 additions & 6 deletions automata/core/agent/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, instructions: str, config: AutomataOpenAIAgentConfig) -> None
super().__init__(instructions)
self.config = config
self.iteration_count = 0
self.conversation = OpenAIConversation()
self.agent_conversations = OpenAIConversation()
self.completed = False
self._setup()

Expand Down Expand Up @@ -133,7 +133,7 @@ def run(self) -> str:
except AgentStopIteration:
break

last_message = self.conversation.get_latest_message()
last_message = self.agent_conversations.get_latest_message()
if self.iteration_count >= self.config.max_iterations:
raise AgentMaxIterError("The agent exceeded the maximum number of iterations.")
if not self.completed or not isinstance(last_message, OpenAIChatMessage):
Expand All @@ -148,7 +148,7 @@ def set_database_provider(self, provider: LLMConversationDatabaseProvider) -> No
if self.database_provider:
raise AgentDatabaseError("The database provider has already been set.")
self.database_provider = provider
self.conversation.register_observer(provider)
self.agent_conversations.register_observer(provider)

def _build_initial_messages(
self, instruction_formatter: Dict[str, str]
Expand Down Expand Up @@ -217,21 +217,21 @@ def _setup(self) -> None:
AgentError: If the agent fails to initialize.
"""
logger.debug(f"Setting up agent with tools = {self.config.tools}")
self.conversation.add_message(
self.agent_conversations.add_message(
OpenAIChatMessage(role="system", content=self.config.system_instruction)
)
for message in list(
self._build_initial_messages({"user_input_instructions": self.instructions})
):
logger.debug(f"Adding the following initial mesasge to the conversation {message}")
self.conversation.add_message(message)
self.agent_conversations.add_message(message)
logging.debug(f"\n{('-' * 120)}")

self.chat_provider = OpenAIChatCompletionProvider(
model=self.config.model,
temperature=self.config.temperature,
stream=self.config.stream,
conversation=self.conversation,
conversation=self.agent_conversations,
functions=self.functions,
)
self._initialized = True
Expand Down
2 changes: 1 addition & 1 deletion automata/core/agent/tool/builder/context_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from automata.core.agent.tool.registry import AutomataOpenAIAgentToolBuilderRegistry
from automata.core.base.agent import AgentToolBuilder, AgentToolProviders
from automata.core.base.tool import Tool
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import OpenAITool
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator

logger = logging.getLogger(__name__)

Expand Down
10 changes: 4 additions & 6 deletions automata/core/agent/tool/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.embedding.code_embedding import SymbolCodeEmbeddingHandler
from automata.core.embedding.doc_embedding import (
SymbolDocEmbeddingBuilder,
SymbolDocEmbeddingHandler,
)
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import (
OpenAIChatCompletionProvider,
OpenAIEmbeddingProvider,
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRank, SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol_embedding.embedding_builders import SymbolDocEmbeddingBuilder
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion automata/core/base/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import jsonpickle

from automata.core.symbol.base import SymbolEmbedding
from automata.core.symbol_embedding.base import SymbolEmbedding

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion automata/core/coding/py/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
find_syntax_tree_node,
)
from automata.core.coding.py.reader import PyReader
from automata.core.symbol.base import Symbol, SymbolDocEmbedding
from automata.core.symbol.base import Symbol
from automata.core.symbol_embedding.base import SymbolDocEmbedding

logger = logging.getLogger(__name__)

Expand Down
16 changes: 8 additions & 8 deletions automata/core/llm/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(
self.temperature = temperature
self.stream = stream
self.functions = functions
self.conversation = conversation
self.agent_conversations = conversation
set_openai_api_key()

def get_next_assistant_completion(self) -> OpenAIChatMessage:
Expand All @@ -229,15 +229,15 @@ def get_next_assistant_completion(self) -> OpenAIChatMessage:
if functions:
response = openai.ChatCompletion.create(
model=self.model,
messages=self.conversation.get_messages_for_next_completion(),
messages=self.agent_conversations.get_messages_for_next_completion(),
functions=functions,
function_call="auto", # auto is default, but we'll be explicit
stream=self.stream,
)
else:
response = openai.ChatCompletion.create(
model=self.model,
messages=self.conversation.get_messages_for_next_completion(),
messages=self.agent_conversations.get_messages_for_next_completion(),
stream=self.stream,
)
if self.stream:
Expand All @@ -249,11 +249,11 @@ def get_next_assistant_completion(self) -> OpenAIChatMessage:
)

def reset(self) -> None:
self.conversation.reset_conversation()
self.agent_conversations.reset_conversation()

def standalone_call(self, prompt: str) -> str:
"""Return the completion message based on the provided prompt."""
if self.conversation.messages:
if self.agent_conversations.messages:
raise ValueError(
"The conversation is not empty. Please call reset() before calling standalone_call()."
)
Expand All @@ -266,11 +266,11 @@ def standalone_call(self, prompt: str) -> str:

def add_message(self, message: LLMChatMessage) -> None:
if not isinstance(message, OpenAIChatMessage):
self.conversation.add_message(
self.agent_conversations.add_message(
OpenAIChatMessage(role=message.role, content=message.content)
)
else:
self.conversation.add_message(message)
self.agent_conversations.add_message(message)
logger.debug(
f"Approximately {self.get_approximate_tokens_consumed()} tokens were after adding the latest message."
)
Expand Down Expand Up @@ -361,7 +361,7 @@ def get_approximate_tokens_consumed(self) -> int:
"\n".join(
[
json.dumps(ele)
for ele in self.conversation.get_messages_for_next_completion()
for ele in self.agent_conversations.get_messages_for_next_completion()
]
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from automata.core.llm.completion import LLMChatMessage, LLMConversationDatabaseProvider


class AutomataAgentConversationDatabase(LLMConversationDatabaseProvider):
class AgentConversationDatabase(LLMConversationDatabaseProvider):
"""A conversation database for an Automata agent."""

def __init__(self, session_id: str, db_path: str = CONVERSATION_DB_PATH) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.core.llm.embedding import (
EmbeddingProvider,
SymbolEmbeddingBuilder,
SymbolEmbeddingHandler,
)
from automata.core.symbol.base import Symbol, SymbolCodeEmbedding
from automata.core.llm.embedding import SymbolEmbeddingHandler
from automata.core.symbol.base import Symbol
from automata.core.symbol_embedding.base import SymbolCodeEmbedding
from automata.core.symbol_embedding.embedding_builders import SymbolCodeEmbeddingBuilder

logger = logging.getLogger(__name__)


class SymbolCodeEmbeddingBuilder(SymbolEmbeddingBuilder):
"""Builds `Symbol` source code embeddings."""

def __init__(
self,
embedding_provider: EmbeddingProvider,
) -> None:
self.embedding_provider = embedding_provider

def build(self, source_code: str, symbol: Symbol) -> SymbolCodeEmbedding:
embedding_vector = self.embedding_provider.build_embedding_array(source_code)
return SymbolCodeEmbedding(symbol, source_code, embedding_vector)


class SymbolCodeEmbeddingHandler(SymbolEmbeddingHandler):
"""Handles a database for `Symbol` source code embeddings."""

Expand Down
49 changes: 49 additions & 0 deletions automata/core/memory_store/symbol_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.core.llm.embedding import SymbolEmbeddingHandler
from automata.core.symbol.base import Symbol
from automata.core.symbol_embedding.base import SymbolDocEmbedding
from automata.core.symbol_embedding.embedding_builders import SymbolDocEmbeddingBuilder

logger = logging.getLogger(__name__)


class SymbolDocEmbeddingHandler(SymbolEmbeddingHandler):
"""A class to handle the embedding of symbols"""

def __init__(
self,
embedding_db: VectorDatabaseProvider,
embedding_builder: SymbolDocEmbeddingBuilder,
) -> None:
self.embedding_db = embedding_db
self.embedding_builder = embedding_builder

def get_embedding(self, symbol: Symbol) -> SymbolDocEmbedding:
return self.embedding_db.get(symbol.dotpath)

def process_embedding(self, symbol: Symbol) -> None:
source_code = self.embedding_builder.fetch_embedding_context(symbol)

if not source_code:
raise ValueError(f"Symbol {symbol} has no source code")

if self.embedding_db.contains(symbol.dotpath):
self.update_existing_embedding(source_code, symbol)
else:
symbol_embedding = self.embedding_builder.build(source_code, symbol)
self.embedding_db.add(symbol_embedding)

def update_existing_embedding(self, source_code: str, symbol: Symbol) -> None:
existing_embedding = self.embedding_db.get(symbol.dotpath)
if existing_embedding.symbol != symbol or existing_embedding.source_code != source_code:
logger.debug(
f"Rolling forward the embedding for {existing_embedding.symbol} to {symbol}"
)
self.embedding_db.discard(symbol.dotpath)
existing_embedding.symbol = symbol
existing_embedding.source_code = source_code
self.embedding_db.add(existing_embedding)
else:
logger.debug("Passing for %s", symbol)
Loading

0 comments on commit 64a820c

Please sign in to comment.