Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Feature/refactor for added clarity #104

Merged
merged 2 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Examples of these classes are:
Code example for creating an instance of 'SymbolCodeEmbedding':
```python
import numpy as np
from automata.core.symbol.base import SymbolCodeEmbedding
from automata.core.symbol_embedding.base import SymbolCodeEmbedding
from automata.core.symbol.parser import parse_symbol

symbol_str = 'scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `automata.core.agent.agent_enums`/ActionIndicator#'
Expand Down
6 changes: 2 additions & 4 deletions automata/cli/scripts/run_code_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from automata.config.base import ConfigCategory
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.coding.py.module_loader import py_module_loader
from automata.core.embedding.code_embedding import (
SymbolCodeEmbeddingBuilder,
SymbolCodeEmbeddingHandler,
)
from automata.core.llm.providers.openai import OpenAIEmbeddingProvider
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import SymbolCodeEmbeddingBuilder
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
16 changes: 7 additions & 9 deletions automata/cli/scripts/run_doc_embedding_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,22 @@
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.embedding.code_embedding import (
SymbolCodeEmbeddingBuilder,
SymbolCodeEmbeddingHandler,
)
from automata.core.embedding.doc_embedding import (
SymbolDocEmbeddingBuilder,
SymbolDocEmbeddingHandler,
)
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import (
OpenAIChatCompletionProvider,
OpenAIEmbeddingProvider,
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.base import SymbolDescriptor
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import (
SymbolCodeEmbeddingBuilder,
SymbolDocEmbeddingBuilder,
)
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
16 changes: 7 additions & 9 deletions automata/cli/scripts/run_doc_embedding_l3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,22 @@
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.embedding.code_embedding import (
SymbolCodeEmbeddingBuilder,
SymbolCodeEmbeddingHandler,
)
from automata.core.embedding.doc_embedding import (
SymbolDocEmbeddingBuilder,
SymbolDocEmbeddingHandler,
)
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import (
OpenAIChatCompletionProvider,
OpenAIEmbeddingProvider,
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.base import SymbolDescriptor
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import (
SymbolCodeEmbeddingBuilder,
SymbolDocEmbeddingBuilder,
)
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions automata/config/prompt/doc_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@
│   └── vector.py
├── embedding
│   ├── __init__.py
│   ├── code_embedding.py
│   ├── doc_embedding.py
│   ├── symbol_code_embedding.py
│   ├── symbol_doc_embedding.py
│   ├── embedding_types.py
│   └── symbol_similarity.py
├── symbol
Expand Down
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_code_embedding.json
Git LFS file not shown
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_doc_embedding_l2.json
Git LFS file not shown
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_doc_embedding_l3.json
Git LFS file not shown
12 changes: 6 additions & 6 deletions automata/core/agent/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, instructions: str, config: AutomataOpenAIAgentConfig) -> None
super().__init__(instructions)
self.config = config
self.iteration_count = 0
self.conversation = OpenAIConversation()
self.agent_conversations = OpenAIConversation()
self.completed = False
self._setup()

Expand Down Expand Up @@ -133,7 +133,7 @@ def run(self) -> str:
except AgentStopIteration:
break

last_message = self.conversation.get_latest_message()
last_message = self.agent_conversations.get_latest_message()
if self.iteration_count >= self.config.max_iterations:
raise AgentMaxIterError("The agent exceeded the maximum number of iterations.")
if not self.completed or not isinstance(last_message, OpenAIChatMessage):
Expand All @@ -148,7 +148,7 @@ def set_database_provider(self, provider: LLMConversationDatabaseProvider) -> No
if self.database_provider:
raise AgentDatabaseError("The database provider has already been set.")
self.database_provider = provider
self.conversation.register_observer(provider)
self.agent_conversations.register_observer(provider)

def _build_initial_messages(
self, instruction_formatter: Dict[str, str]
Expand Down Expand Up @@ -217,21 +217,21 @@ def _setup(self) -> None:
AgentError: If the agent fails to initialize.
"""
logger.debug(f"Setting up agent with tools = {self.config.tools}")
self.conversation.add_message(
self.agent_conversations.add_message(
OpenAIChatMessage(role="system", content=self.config.system_instruction)
)
for message in list(
self._build_initial_messages({"user_input_instructions": self.instructions})
):
logger.debug(f"Adding the following initial mesasge to the conversation {message}")
self.conversation.add_message(message)
self.agent_conversations.add_message(message)
logging.debug(f"\n{('-' * 120)}")

self.chat_provider = OpenAIChatCompletionProvider(
model=self.config.model,
temperature=self.config.temperature,
stream=self.config.stream,
conversation=self.conversation,
conversation=self.agent_conversations,
functions=self.functions,
)
self._initialized = True
Expand Down
2 changes: 1 addition & 1 deletion automata/core/agent/tool/builder/context_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from automata.core.agent.tool.registry import AutomataOpenAIAgentToolBuilderRegistry
from automata.core.base.agent import AgentToolBuilder, AgentToolProviders
from automata.core.base.tool import Tool
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import OpenAITool
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator

logger = logging.getLogger(__name__)

Expand Down
10 changes: 4 additions & 6 deletions automata/core/agent/tool/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.embedding.code_embedding import SymbolCodeEmbeddingHandler
from automata.core.embedding.doc_embedding import (
SymbolDocEmbeddingBuilder,
SymbolDocEmbeddingHandler,
)
from automata.core.embedding.symbol_similarity import SymbolSimilarityCalculator
from automata.core.llm.providers.openai import (
OpenAIChatCompletionProvider,
OpenAIEmbeddingProvider,
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRank, SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol_embedding.embedding_builders import SymbolDocEmbeddingBuilder
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion automata/core/base/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import jsonpickle

from automata.core.symbol.base import SymbolEmbedding
from automata.core.symbol_embedding.base import SymbolEmbedding

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion automata/core/coding/py/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
find_syntax_tree_node,
)
from automata.core.coding.py.reader import PyReader
from automata.core.symbol.base import Symbol, SymbolDocEmbedding
from automata.core.symbol.base import Symbol
from automata.core.symbol_embedding.base import SymbolDocEmbedding

logger = logging.getLogger(__name__)

Expand Down
16 changes: 8 additions & 8 deletions automata/core/llm/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(
self.temperature = temperature
self.stream = stream
self.functions = functions
self.conversation = conversation
self.agent_conversations = conversation
set_openai_api_key()

def get_next_assistant_completion(self) -> OpenAIChatMessage:
Expand All @@ -229,15 +229,15 @@ def get_next_assistant_completion(self) -> OpenAIChatMessage:
if functions:
response = openai.ChatCompletion.create(
model=self.model,
messages=self.conversation.get_messages_for_next_completion(),
messages=self.agent_conversations.get_messages_for_next_completion(),
functions=functions,
function_call="auto", # auto is default, but we'll be explicit
stream=self.stream,
)
else:
response = openai.ChatCompletion.create(
model=self.model,
messages=self.conversation.get_messages_for_next_completion(),
messages=self.agent_conversations.get_messages_for_next_completion(),
stream=self.stream,
)
if self.stream:
Expand All @@ -249,11 +249,11 @@ def get_next_assistant_completion(self) -> OpenAIChatMessage:
)

def reset(self) -> None:
self.conversation.reset_conversation()
self.agent_conversations.reset_conversation()

def standalone_call(self, prompt: str) -> str:
"""Return the completion message based on the provided prompt."""
if self.conversation.messages:
if self.agent_conversations.messages:
raise ValueError(
"The conversation is not empty. Please call reset() before calling standalone_call()."
)
Expand All @@ -266,11 +266,11 @@ def standalone_call(self, prompt: str) -> str:

def add_message(self, message: LLMChatMessage) -> None:
if not isinstance(message, OpenAIChatMessage):
self.conversation.add_message(
self.agent_conversations.add_message(
OpenAIChatMessage(role=message.role, content=message.content)
)
else:
self.conversation.add_message(message)
self.agent_conversations.add_message(message)
logger.debug(
f"Approximately {self.get_approximate_tokens_consumed()} tokens were after adding the latest message."
)
Expand Down Expand Up @@ -361,7 +361,7 @@ def get_approximate_tokens_consumed(self) -> int:
"\n".join(
[
json.dumps(ele)
for ele in self.conversation.get_messages_for_next_completion()
for ele in self.agent_conversations.get_messages_for_next_completion()
]
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from automata.core.llm.completion import LLMChatMessage, LLMConversationDatabaseProvider


class AutomataAgentConversationDatabase(LLMConversationDatabaseProvider):
class AgentConversationDatabase(LLMConversationDatabaseProvider):
"""A conversation database for an Automata agent."""

def __init__(self, session_id: str, db_path: str = CONVERSATION_DB_PATH) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.core.llm.embedding import (
EmbeddingProvider,
SymbolEmbeddingBuilder,
SymbolEmbeddingHandler,
)
from automata.core.symbol.base import Symbol, SymbolCodeEmbedding
from automata.core.llm.embedding import SymbolEmbeddingHandler
from automata.core.symbol.base import Symbol
from automata.core.symbol_embedding.base import SymbolCodeEmbedding
from automata.core.symbol_embedding.embedding_builders import SymbolCodeEmbeddingBuilder

logger = logging.getLogger(__name__)


class SymbolCodeEmbeddingBuilder(SymbolEmbeddingBuilder):
    """Builds `Symbol` source code embeddings."""

    def __init__(self, embedding_provider: EmbeddingProvider) -> None:
        """Keep a reference to the provider that produces embedding arrays."""
        self.embedding_provider = embedding_provider

    def build(self, source_code: str, symbol: Symbol) -> SymbolCodeEmbedding:
        """Embed `source_code` and bundle the result with its `symbol`.

        Args:
            source_code: Text handed to the provider's `build_embedding_array`.
            symbol: The symbol the source code belongs to.

        Returns:
            A `SymbolCodeEmbedding` built from the symbol, the source code and
            the array returned by the provider.
        """
        vector = self.embedding_provider.build_embedding_array(source_code)
        return SymbolCodeEmbedding(symbol, source_code, vector)


class SymbolCodeEmbeddingHandler(SymbolEmbeddingHandler):
"""Handles a database for `Symbol` source code embeddings."""

Expand Down
49 changes: 49 additions & 0 deletions automata/core/memory_store/symbol_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.core.llm.embedding import SymbolEmbeddingHandler
from automata.core.symbol.base import Symbol
from automata.core.symbol_embedding.base import SymbolDocEmbedding
from automata.core.symbol_embedding.embedding_builders import SymbolDocEmbeddingBuilder

logger = logging.getLogger(__name__)


class SymbolDocEmbeddingHandler(SymbolEmbeddingHandler):
    """Stores and refreshes symbol embeddings in a vector database."""

    def __init__(
        self,
        embedding_db: VectorDatabaseProvider,
        embedding_builder: SymbolDocEmbeddingBuilder,
    ) -> None:
        """Remember the database and the builder this handler works with."""
        self.embedding_builder = embedding_builder
        self.embedding_db = embedding_db

    def get_embedding(self, symbol: Symbol) -> SymbolDocEmbedding:
        """Return the database entry stored under `symbol.dotpath`."""
        key = symbol.dotpath
        return self.embedding_db.get(key)

    def process_embedding(self, symbol: Symbol) -> None:
        """Create or refresh the stored embedding for `symbol`.

        The builder's `fetch_embedding_context` supplies the text for the
        symbol. If the database already contains the symbol's dotpath, the
        existing entry is handed to `update_existing_embedding`; otherwise a
        new embedding is built and added.

        Raises:
            ValueError: If the builder returns no (or empty) context.
        """
        embedding_context = self.embedding_builder.fetch_embedding_context(symbol)
        if not embedding_context:
            raise ValueError(f"Symbol {symbol} has no source code")

        if self.embedding_db.contains(symbol.dotpath):
            self.update_existing_embedding(embedding_context, symbol)
            return

        self.embedding_db.add(self.embedding_builder.build(embedding_context, symbol))

    def update_existing_embedding(self, source_code: str, symbol: Symbol) -> None:
        """Re-point the stored embedding at `symbol` / `source_code` if they changed.

        Only the `symbol` and `source_code` attributes of the stored entry are
        overwritten; the entry is discarded and re-added under the same
        dotpath. When neither value differs, nothing is written.
        """
        existing_embedding = self.embedding_db.get(symbol.dotpath)
        is_stale = (
            existing_embedding.symbol != symbol
            or existing_embedding.source_code != source_code
        )
        if not is_stale:
            logger.debug("Passing for %s", symbol)
            return

        logger.debug(
            f"Rolling forward the embedding for {existing_embedding.symbol} to {symbol}"
        )
        # Remove the old record before mutating and re-inserting it.
        self.embedding_db.discard(symbol.dotpath)
        existing_embedding.symbol = symbol
        existing_embedding.source_code = source_code
        self.embedding_db.add(existing_embedding)
Loading