-
Notifications
You must be signed in to change notification settings - Fork 146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add config classes #591
Add config classes #591
Changes from 84 commits
2960468
e0fbd0e
5f90b95
88d12a3
5ff1db0
32ce57c
d9234c3
eecd9fc
0ab1e21
a072af9
41d998e
c0579aa
921bdc8
7e75e06
9ebad74
a9a91b3
c1504e5
7b1ec6d
942ca51
71046ce
c58b340
5504e93
2ac97b0
140be19
11baf50
2c944c3
35004b3
36da3b7
9d4832e
5b78e8b
c042586
0bd6e07
f0d57d8
1d27584
df45392
e5bff8e
7253194
d4cbce2
da492b5
fb65e65
fc57891
d063c00
d6a0c5a
287f77f
d4536d1
a340f0c
67dcd4b
c7e955e
f54722d
949be05
8389d97
4b901fa
e3a5045
30b0af6
25059f7
669f591
73c291d
73dc252
9d87bfa
7e3769f
81afccb
cd729ea
a6ab72b
4bdcdbe
eccb0c6
069b394
48f46d0
d8fa2ee
a84db08
eba57c3
3f3a9b1
03cbf90
1798a26
87577ac
c71ddb1
815e3ea
54ebb24
4e84d01
137a981
9a05f8e
bc92356
4ae312c
81172e1
d2b66fa
503bc47
1b668b0
6342b1a
5c01d68
da8b50d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from .base_config import BaseConfig | ||
|
||
from .structure_global_drivers_config import StructureGlobalDriversConfig | ||
from .structure_task_memory_extraction_engine_csv_config import StructureTaskMemoryExtractionEngineCsvConfig | ||
from .structure_task_memory_extraction_engine_json_config import StructureTaskMemoryExtractionEngineJsonConfig | ||
from .structure_task_memory_extraction_engine_config import StructureTaskMemoryExtractionEngineConfig | ||
from .structure_task_memory_query_engine_config import StructureTaskMemoryQueryEngineConfig | ||
from .structure_task_memory_summary_engine_config import StructureTaskMemorySummaryEngineConfig | ||
from .structure_task_memory_config import StructureTaskMemoryConfig | ||
from .base_structure_config import BaseStructureConfig | ||
|
||
from .structure_config import StructureConfig | ||
from .openai_structure_config import OpenAiStructureConfig | ||
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig | ||
|
||
|
||
__all__ = [ | ||
"BaseConfig", | ||
"BaseStructureConfig", | ||
"StructureTaskMemoryConfig", | ||
"StructureGlobalDriversConfig", | ||
"StructureTaskMemoryQueryEngineConfig", | ||
"StructureTaskMemorySummaryEngineConfig", | ||
"StructureTaskMemoryExtractionEngineConfig", | ||
"StructureTaskMemoryExtractionEngineCsvConfig", | ||
"StructureTaskMemoryExtractionEngineJsonConfig", | ||
"StructureConfig", | ||
"OpenAiStructureConfig", | ||
"AmazonBedrockStructureConfig", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
BaseStructureConfig, | ||
StructureGlobalDriversConfig, | ||
StructureTaskMemoryConfig, | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryExtractionEngineCsvConfig, | ||
StructureTaskMemoryExtractionEngineJsonConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.drivers import ( | ||
AmazonBedrockImageGenerationDriver, | ||
AmazonBedrockPromptDriver, | ||
AmazonBedrockTitanEmbeddingDriver, | ||
BedrockClaudePromptModelDriver, | ||
BedrockTitanImageGenerationModelDriver, | ||
BedrockTitanPromptModelDriver, | ||
LocalVectorStoreDriver, | ||
) | ||
|
||
|
||
@define(kw_only=True) | ||
class AmazonBedrockStructureConfig(BaseStructureConfig): | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory( | ||
lambda: StructureGlobalDriversConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="anthropic.claude-v2", stream=False, prompt_model_driver=BedrockClaudePromptModelDriver() | ||
), | ||
image_generation_driver=AmazonBedrockImageGenerationDriver( | ||
model="amazon.titan-image-generator-v1", | ||
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), | ||
), | ||
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory( | ||
lambda: StructureTaskMemoryConfig( | ||
query_engine=StructureTaskMemoryQueryEngineConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") | ||
), | ||
), | ||
extraction_engine=StructureTaskMemoryExtractionEngineConfig( | ||
csv=StructureTaskMemoryExtractionEngineCsvConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
) | ||
), | ||
json=StructureTaskMemoryExtractionEngineJsonConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
) | ||
), | ||
), | ||
summary_engine=StructureTaskMemorySummaryEngineConfig( | ||
prompt_driver=AmazonBedrockPromptDriver( | ||
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver() | ||
) | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from abc import ABC | ||
|
||
from attrs import define | ||
|
||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define | ||
class BaseConfig(SerializableMixin, ABC): | ||
... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
|
||
from attr import define, field | ||
|
||
from griptape.config import BaseConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig | ||
from griptape.utils import dict_merge | ||
|
||
|
||
@define | ||
class BaseStructureConfig(BaseConfig, ABC): | ||
global_drivers: StructureGlobalDriversConfig = field(kw_only=True, metadata={"serializable": True}) | ||
task_memory: StructureTaskMemoryConfig = field(kw_only=True, metadata={"serializable": True}) | ||
|
||
def merge_config(self, config: dict) -> BaseStructureConfig: | ||
base_config = self.to_dict() | ||
merged_config = dict_merge(base_config, config) | ||
|
||
return BaseStructureConfig.from_dict(merged_config) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
BaseStructureConfig, | ||
StructureGlobalDriversConfig, | ||
StructureTaskMemoryConfig, | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryExtractionEngineCsvConfig, | ||
StructureTaskMemoryExtractionEngineJsonConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.drivers import ( | ||
LocalVectorStoreDriver, | ||
OpenAiChatPromptDriver, | ||
OpenAiEmbeddingDriver, | ||
OpenAiImageGenerationDriver, | ||
) | ||
|
||
|
||
@define(kw_only=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we apply this to all new classes in this PR and, in a separate PR, to most (or all?) classes in Griptape? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'd actually vote to remove it from this PR, and do it in a single sweep in another. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, let's remove it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
class OpenAiStructureConfig(BaseStructureConfig): | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory( | ||
lambda: StructureGlobalDriversConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), | ||
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), | ||
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002"), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002") | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory( | ||
lambda: StructureTaskMemoryConfig( | ||
query_engine=StructureTaskMemoryQueryEngineConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), | ||
vector_store_driver=LocalVectorStoreDriver( | ||
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002") | ||
), | ||
), | ||
extraction_engine=StructureTaskMemoryExtractionEngineConfig( | ||
csv=StructureTaskMemoryExtractionEngineCsvConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo") | ||
), | ||
json=StructureTaskMemoryExtractionEngineJsonConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo") | ||
), | ||
), | ||
summary_engine=StructureTaskMemorySummaryEngineConfig( | ||
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo") | ||
), | ||
) | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import BaseStructureConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureConfig(BaseStructureConfig): | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory(lambda: StructureGlobalDriversConfig()), kw_only=True, metadata={"serializable": True} | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory(lambda: StructureTaskMemoryConfig()), kw_only=True, metadata={"serializable": True} | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Optional | ||
|
||
from attrs import Factory, define, field | ||
|
||
from griptape.drivers import ( | ||
BaseConversationMemoryDriver, | ||
BaseEmbeddingDriver, | ||
BaseImageGenerationDriver, | ||
BasePromptDriver, | ||
BaseVectorStoreDriver, | ||
NopVectorStoreDriver, | ||
NopEmbeddingDriver, | ||
NopImageGenerationDriver, | ||
NopPromptDriver, | ||
) | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureGlobalDriversConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field( | ||
kw_only=True, default=Factory(lambda: NopPromptDriver()), metadata={"serializable": True} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wanted to make sure this is the best convention...do we like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think @andrewfrench? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Traditionally I use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good with me too. Before I do a big refactor, @vasinov are you also good with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That works! It's more fun than |
||
) | ||
image_generation_driver: BaseImageGenerationDriver = field( | ||
kw_only=True, default=Factory(lambda: NopImageGenerationDriver()), metadata={"serializable": True} | ||
) | ||
embedding_driver: BaseEmbeddingDriver = field( | ||
kw_only=True, default=Factory(lambda: NopEmbeddingDriver()), metadata={"serializable": True} | ||
) | ||
vector_store_driver: BaseVectorStoreDriver = field( | ||
default=Factory(lambda: NopVectorStoreDriver()), kw_only=True, metadata={"serializable": True} | ||
) | ||
conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( | ||
default=None, kw_only=True, metadata={"serializable": True} | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureTaskMemoryConfig(SerializableMixin): | ||
query_engine: StructureTaskMemoryQueryEngineConfig = field( | ||
default=Factory(lambda: StructureTaskMemoryQueryEngineConfig()), metadata={"serializable": True} | ||
) | ||
extraction_engine: StructureTaskMemoryExtractionEngineConfig = field( | ||
default=Factory(lambda: StructureTaskMemoryExtractionEngineConfig()), metadata={"serializable": True} | ||
) | ||
summary_engine: StructureTaskMemorySummaryEngineConfig = field( | ||
default=Factory(lambda: StructureTaskMemorySummaryEngineConfig()), metadata={"serializable": True} | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import StructureTaskMemoryExtractionEngineCsvConfig, StructureTaskMemoryExtractionEngineJsonConfig | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureTaskMemoryExtractionEngineConfig(SerializableMixin): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Random thought: will we ever want to use configs as mixins? In other words, will it ever make sense to inherit from two configs at once? If that's the case, should we add the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm having a difficult time imaging a configuration situation where we'd inherit from multiple. Did you have something specific in mind? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nothing specific, just wondering. Let's keep it the way it is then :) |
||
csv: StructureTaskMemoryExtractionEngineCsvConfig = field( | ||
default=Factory(lambda: StructureTaskMemoryExtractionEngineCsvConfig()), metadata={"serializable": True} | ||
) | ||
json: StructureTaskMemoryExtractionEngineJsonConfig = field( | ||
default=Factory(lambda: StructureTaskMemoryExtractionEngineJsonConfig()), metadata={"serializable": True} | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from attrs import define, field, Factory | ||
|
||
from griptape.drivers import BasePromptDriver, NopPromptDriver | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureTaskMemoryExtractionEngineCsvConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field(default=Factory(lambda: NopPromptDriver()), metadata={"serializable": True}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from attrs import define, field, Factory | ||
|
||
from griptape.drivers import BasePromptDriver, NopPromptDriver | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureTaskMemoryExtractionEngineJsonConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field(default=Factory(lambda: NopPromptDriver()), metadata={"serializable": True}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from attrs import Factory, define, field | ||
|
||
from griptape.drivers import ( | ||
BasePromptDriver, | ||
BaseVectorStoreDriver, | ||
NopVectorStoreDriver, | ||
NopEmbeddingDriver, | ||
NopPromptDriver, | ||
) | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureTaskMemoryQueryEngineConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field(default=Factory(lambda: NopPromptDriver()), metadata={"serializable": True}) | ||
vector_store_driver: BaseVectorStoreDriver = field( | ||
default=Factory(lambda: NopVectorStoreDriver(embedding_driver=NopEmbeddingDriver())), | ||
metadata={"serializable": True}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from attrs import define, field, Factory | ||
|
||
from griptape.drivers import BasePromptDriver, NopPromptDriver | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
|
||
@define(kw_only=True) | ||
class StructureTaskMemorySummaryEngineConfig(SerializableMixin): | ||
prompt_driver: BasePromptDriver = field(default=Factory(lambda: NopPromptDriver()), metadata={"serializable": True}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we rename it to
global
andStructureGlobalConfig
in case we add more things, other than drivers, to it in the future?Or should we keep it namespaced to
drivers
for clarity and just introduce another config "bundle" in the future?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer to keep it namespaced; I could see a singular
global
becoming quite unruly.