-
Notifications
You must be signed in to change notification settings - Fork 135
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f73c21c
commit c87c85f
Showing
114 changed files
with
1,931 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from .base_config import BaseConfig | ||
|
||
from .structure_global_drivers_config import StructureGlobalDriversConfig | ||
from .structure_task_memory_extraction_engine_csv_config import StructureTaskMemoryExtractionEngineCsvConfig | ||
from .structure_task_memory_extraction_engine_json_config import StructureTaskMemoryExtractionEngineJsonConfig | ||
from .structure_task_memory_extraction_engine_config import StructureTaskMemoryExtractionEngineConfig | ||
from .structure_task_memory_query_engine_config import StructureTaskMemoryQueryEngineConfig | ||
from .structure_task_memory_summary_engine_config import StructureTaskMemorySummaryEngineConfig | ||
from .structure_task_memory_config import StructureTaskMemoryConfig | ||
from .base_structure_config import BaseStructureConfig | ||
|
||
from .structure_config import StructureConfig | ||
from .openai_structure_config import OpenAiStructureConfig | ||
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig | ||
|
||
|
||
__all__ = [ | ||
"BaseConfig", | ||
"BaseStructureConfig", | ||
"StructureTaskMemoryConfig", | ||
"StructureGlobalDriversConfig", | ||
"StructureTaskMemoryQueryEngineConfig", | ||
"StructureTaskMemorySummaryEngineConfig", | ||
"StructureTaskMemoryExtractionEngineConfig", | ||
"StructureTaskMemoryExtractionEngineCsvConfig", | ||
"StructureTaskMemoryExtractionEngineJsonConfig", | ||
"StructureConfig", | ||
"OpenAiStructureConfig", | ||
"AmazonBedrockStructureConfig", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
BaseStructureConfig, | ||
StructureGlobalDriversConfig, | ||
StructureTaskMemoryConfig, | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryExtractionEngineCsvConfig, | ||
StructureTaskMemoryExtractionEngineJsonConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.drivers import ( | ||
AmazonBedrockImageGenerationDriver, | ||
AmazonBedrockPromptDriver, | ||
AmazonBedrockTitanEmbeddingDriver, | ||
BedrockClaudePromptModelDriver, | ||
BedrockTitanImageGenerationModelDriver, | ||
BedrockTitanPromptModelDriver, | ||
LocalVectorStoreDriver, | ||
) | ||
|
||
|
||
@define() | ||
class AmazonBedrockStructureConfig(BaseStructureConfig): | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory( | ||
lambda: StructureGlobalDriversConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="anthropic.claude-v2", stream=False, prompt_model_driver=BedrockClaudePromptModelDriver() | ||
), | ||
image_generation_driver=AmazonBedrockImageGenerationDriver( | ||
model="amazon.titan-image-generator-v1", | ||
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), | ||
), | ||
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory( | ||
lambda: StructureTaskMemoryConfig( | ||
query_engine=StructureTaskMemoryQueryEngineConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") | ||
), | ||
), | ||
extraction_engine=StructureTaskMemoryExtractionEngineConfig( | ||
csv=StructureTaskMemoryExtractionEngineCsvConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
) | ||
), | ||
json=StructureTaskMemoryExtractionEngineJsonConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
) | ||
), | ||
), | ||
summary_engine=StructureTaskMemorySummaryEngineConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
) | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from abc import ABC | ||
|
||
from attrs import define | ||
|
||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class BaseConfig(SerializableMixin, ABC): | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
|
||
from attr import define, field | ||
|
||
from griptape.config import BaseConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig | ||
from griptape.utils import dict_merge | ||
|
||
|
||
@define | ||
class BaseStructureConfig(BaseConfig, ABC): | ||
global_drivers: StructureGlobalDriversConfig = field(kw_only=True, metadata={"serializable": True}) | ||
task_memory: StructureTaskMemoryConfig = field(kw_only=True, metadata={"serializable": True}) | ||
|
||
def merge_config(self, config: dict) -> BaseStructureConfig: | ||
base_config = self.to_dict() | ||
merged_config = dict_merge(base_config, config) | ||
|
||
return BaseStructureConfig.from_dict(merged_config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
BaseStructureConfig, | ||
StructureGlobalDriversConfig, | ||
StructureTaskMemoryConfig, | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryExtractionEngineCsvConfig, | ||
StructureTaskMemoryExtractionEngineJsonConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.drivers import ( | ||
LocalVectorStoreDriver, | ||
OpenAiChatPromptDriver, | ||
OpenAiEmbeddingDriver, | ||
OpenAiImageGenerationDriver, | ||
) | ||
|
||
|
||
@define | ||
class OpenAiStructureConfig(BaseStructureConfig): | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory( | ||
lambda: StructureGlobalDriversConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), | ||
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), | ||
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002"), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002") | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory( | ||
lambda: StructureTaskMemoryConfig( | ||
query_engine=StructureTaskMemoryQueryEngineConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002") | ||
), | ||
), | ||
extraction_engine=StructureTaskMemoryExtractionEngineConfig( | ||
csv=StructureTaskMemoryExtractionEngineCsvConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo") | ||
), | ||
json=StructureTaskMemoryExtractionEngineJsonConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo") | ||
), | ||
), | ||
summary_engine=StructureTaskMemorySummaryEngineConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo") | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import BaseStructureConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig | ||
|
||
|
||
@define | ||
class StructureConfig(BaseStructureConfig): | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory(lambda: StructureGlobalDriversConfig()), kw_only=True, metadata={"serializable": True} | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory(lambda: StructureTaskMemoryConfig()), kw_only=True, metadata={"serializable": True} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Optional | ||
|
||
from attrs import Factory, define, field | ||
|
||
from griptape.drivers import ( | ||
BaseConversationMemoryDriver, | ||
BaseEmbeddingDriver, | ||
BaseImageGenerationDriver, | ||
BasePromptDriver, | ||
BaseVectorStoreDriver, | ||
DummyVectorStoreDriver, | ||
DummyEmbeddingDriver, | ||
DummyImageGenerationDriver, | ||
DummyPromptDriver, | ||
) | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureGlobalDriversConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} | ||
) | ||
image_generation_driver: BaseImageGenerationDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True} | ||
) | ||
embedding_driver: BaseEmbeddingDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True} | ||
) | ||
vector_store_driver: BaseVectorStoreDriver = field( | ||
default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True} | ||
) | ||
conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( | ||
default=None, kw_only=True, metadata={"serializable": True} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureTaskMemoryConfig(SerializableMixin): | ||
query_engine: StructureTaskMemoryQueryEngineConfig = field( | ||
kw_only=True, default=Factory(lambda: StructureTaskMemoryQueryEngineConfig()), metadata={"serializable": True} | ||
) | ||
extraction_engine: StructureTaskMemoryExtractionEngineConfig = field( | ||
kw_only=True, | ||
default=Factory(lambda: StructureTaskMemoryExtractionEngineConfig()), | ||
metadata={"serializable": True}, | ||
) | ||
summary_engine: StructureTaskMemorySummaryEngineConfig = field( | ||
kw_only=True, default=Factory(lambda: StructureTaskMemorySummaryEngineConfig()), metadata={"serializable": True} | ||
) |
18 changes: 18 additions & 0 deletions
18
griptape/config/structure_task_memory_extraction_engine_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import StructureTaskMemoryExtractionEngineCsvConfig, StructureTaskMemoryExtractionEngineJsonConfig | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureTaskMemoryExtractionEngineConfig(SerializableMixin): | ||
csv: StructureTaskMemoryExtractionEngineCsvConfig = field( | ||
kw_only=True, | ||
default=Factory(lambda: StructureTaskMemoryExtractionEngineCsvConfig()), | ||
metadata={"serializable": True}, | ||
) | ||
json: StructureTaskMemoryExtractionEngineJsonConfig = field( | ||
kw_only=True, | ||
default=Factory(lambda: StructureTaskMemoryExtractionEngineJsonConfig()), | ||
metadata={"serializable": True}, | ||
) |
11 changes: 11 additions & 0 deletions
11
griptape/config/structure_task_memory_extraction_engine_csv_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from attrs import define, field, Factory | ||
|
||
from griptape.drivers import BasePromptDriver, DummyPromptDriver | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureTaskMemoryExtractionEngineCsvConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} | ||
) |
11 changes: 11 additions & 0 deletions
11
griptape/config/structure_task_memory_extraction_engine_json_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from attrs import define, field, Factory | ||
|
||
from griptape.drivers import BasePromptDriver, DummyPromptDriver | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureTaskMemoryExtractionEngineJsonConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} | ||
) |
22 changes: 22 additions & 0 deletions
22
griptape/config/structure_task_memory_query_engine_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.drivers import ( | ||
BasePromptDriver, | ||
BaseVectorStoreDriver, | ||
DummyVectorStoreDriver, | ||
DummyEmbeddingDriver, | ||
DummyPromptDriver, | ||
) | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureTaskMemoryQueryEngineConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} | ||
) | ||
vector_store_driver: BaseVectorStoreDriver = field( | ||
kw_only=True, | ||
default=Factory(lambda: DummyVectorStoreDriver(embedding_driver=DummyEmbeddingDriver())), | ||
metadata={"serializable": True}, | ||
) |
11 changes: 11 additions & 0 deletions
11
griptape/config/structure_task_memory_summary_engine_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.drivers import BasePromptDriver, DummyPromptDriver | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class StructureTaskMemorySummaryEngineConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field( | ||
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} | ||
) |
Oops, something went wrong.