Skip to content

Commit

Permalink
Add AnthropicStructureConfig using VoyageAi
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Mar 20, 2024
1 parent 194c275 commit e650a4c
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura.
- `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify.
- `VoyageAiEmbeddingDriver` for use with VoyageAi's embedding models.
- `AnthropicStructureConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration.

### Fixed
- Improved system prompt in `ToolTask` to support more use cases.
Expand Down
2 changes: 2 additions & 0 deletions griptape/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .structure_config import StructureConfig
from .openai_structure_config import OpenAiStructureConfig
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig
from .anthropic_structure_config import AnthropicStructureConfig


__all__ = [
Expand All @@ -27,4 +28,5 @@
"StructureConfig",
"OpenAiStructureConfig",
"AmazonBedrockStructureConfig",
"AnthropicStructureConfig",
]
48 changes: 48 additions & 0 deletions griptape/config/anthropic_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from attrs import Factory, define, field

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import LocalVectorStoreDriver, AnthropicPromptDriver, VoyageAiEmbeddingDriver


@define
class AnthropicStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=AnthropicPromptDriver(model="claude-3-opus-20240229"),
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")
),
)
),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda self: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=self.global_drivers.prompt_driver,
vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver),
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
7 changes: 2 additions & 5 deletions griptape/drivers/embedding/voyageai_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from typing import Optional, Any
from attr import define, field, Factory
from griptape.utils import import_optional_dependency
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import VoyageAiTokenizer

if TYPE_CHECKING:
from voyageai import Client


@define
class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
Expand All @@ -24,7 +21,7 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: Client = field(
client: Any = field(
default=Factory(
lambda self: import_optional_dependency("voyageai").Client(api_key=self.api_key), takes_self=True
)
Expand Down
9 changes: 3 additions & 6 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Optional, Any
from collections.abc import Iterator
from attr import define, field, Factory
from griptape.artifacts import TextArtifact
from griptape.utils import PromptStack, import_optional_dependency
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AnthropicTokenizer

if TYPE_CHECKING:
from anthropic import Anthropic


@define
class AnthropicPromptDriver(BasePromptDriver):
Expand All @@ -21,9 +18,9 @@ class AnthropicPromptDriver(BasePromptDriver):
tokenizer: Custom `AnthropicTokenizer`.
"""

api_key: str = field(kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
client: Anthropic = field(
client: Any = field(
default=Factory(
lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
),
Expand Down
119 changes: 119 additions & 0 deletions tests/unit/config/test_anthropic_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from pytest import fixture
from griptape.config import AnthropicStructureConfig


class TestAnthropicStructureConfig:
@fixture(autouse=True)
def mock_anthropic(self, mocker):
mocker.patch("anthropic.Anthropic")
mocker.patch("voyageai.Client")

@fixture
def config(self):
return AnthropicStructureConfig()

def test_to_dict(self, config):
assert config.to_dict() == {
"type": "AnthropicStructureConfig",
"global_drivers": {
"type": "StructureGlobalDriversConfig",
"prompt_driver": {
"type": "AnthropicPromptDriver",
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"api_key": None,
"model": "claude-3-opus-20240229",
"top_p": 0.999,
"top_k": 250,
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"image_query_driver": {"type": "DummyImageQueryDriver"},
"embedding_driver": {
"type": "VoyageAiEmbeddingDriver",
"model": "voyage-large-2",
"api_key": None,
"input_type": "document",
},
"vector_store_driver": {
"type": "LocalVectorStoreDriver",
"embedding_driver": {
"type": "VoyageAiEmbeddingDriver",
"model": "voyage-large-2",
"api_key": None,
"input_type": "document",
},
},
"conversation_memory_driver": None,
},
"task_memory": {
"type": "StructureTaskMemoryConfig",
"query_engine": {
"type": "StructureTaskMemoryQueryEngineConfig",
"prompt_driver": {
"type": "AnthropicPromptDriver",
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"api_key": None,
"model": "claude-3-opus-20240229",
"top_p": 0.999,
"top_k": 250,
},
"vector_store_driver": {
"type": "LocalVectorStoreDriver",
"embedding_driver": {
"type": "VoyageAiEmbeddingDriver",
"model": "voyage-large-2",
"api_key": None,
"input_type": "document",
},
},
},
"extraction_engine": {
"type": "StructureTaskMemoryExtractionEngineConfig",
"csv": {
"type": "StructureTaskMemoryExtractionEngineCsvConfig",
"prompt_driver": {
"type": "AnthropicPromptDriver",
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"api_key": None,
"model": "claude-3-opus-20240229",
"top_p": 0.999,
"top_k": 250,
},
},
"json": {
"type": "StructureTaskMemoryExtractionEngineJsonConfig",
"prompt_driver": {
"type": "AnthropicPromptDriver",
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"api_key": None,
"model": "claude-3-opus-20240229",
"top_p": 0.999,
"top_k": 250,
},
},
},
"summary_engine": {
"type": "StructureTaskMemorySummaryEngineConfig",
"prompt_driver": {
"type": "AnthropicPromptDriver",
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"api_key": None,
"model": "claude-3-opus-20240229",
"top_p": 0.999,
"top_k": 250,
},
},
},
}

def test_from_dict(self, config):
assert AnthropicStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict()

0 comments on commit e650a4c

Please sign in to comment.