Skip to content

Commit

Permalink
Add config classes (#591)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 12, 2024
1 parent f73c21c commit c87c85f
Show file tree
Hide file tree
Showing 114 changed files with 1,931 additions and 245 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `AzureMongoDbVectorStoreDriver` for using CosmosDB with MongoDB vCore API.
- `vector_path` field on `MongoDbAtlasVectorStoreDriver`.
- `LeonardoImageGenerationDriver` supports image to image generation.
- `OpenAiStructureConfig` for providing Structures with all OpenAi Driver configuration.
- `AmazonBedrockStructureConfig` for providing Structures with all Amazon Bedrock Driver configuration.
- `StructureConfig` for building your own Structure configuration.
- `JsonExtractionTask` for convenience over using `ExtractionTask` with a `JsonExtractionEngine`.
- `CsvExtractionTask` for convenience over using `ExtractionTask` with a `CsvExtractionEngine`.

### Fixed
- `BedrockStableDiffusionImageGenerationModelDriver` request parameters for SDXLv1.
Expand All @@ -21,6 +26,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Make `index_name` on `MongoDbAtlasVectorStoreDriver` a required field.
- **BREAKING**: Remove `create_index()` from `MarqoVectorStoreDriver`, `OpenSearchVectorStoreDriver`, `PineconeVectorStoreDriver`, `RedisVectorStoreDriver`.
- **BREAKING**: `ImageLoader().load()` now accepts image bytes instead of a file path.
- Deprecated `Structure.prompt_driver` in favor of `Structure.global_drivers.prompt_driver`.
- Deprecated `Structure.embedding_driver` in favor of `Structure.global_drivers.embedding_driver`.
- Deprecated `Structure.stream` in favor of `Structure.global_drivers.prompt_driver.stream`.
- `TextSummaryTask.summary_engine` now defaults to a `PromptSummaryEngine` with a Prompt Driver default of `Structure.global_drivers.prompt_driver`.
- `TextQueryTask.query_engine` now defaults to a `VectorQueryEngine` with a Prompt Driver default of `Structure.global_drivers.prompt_driver` and Vector Store Driver default of `Structure.global_drivers.vector_store_driver`.
- `PromptImageGenerationTask.image_generation_engine` now defaults to a `PromptImageGenerationEngine` with an Image Generation Driver default of `Structure.global_drivers.image_generation_driver`.
- `VariationImageGenerationTask.image_generation_engine` now defaults to a `VariationImageGenerationEngine` with an Image Generation Driver default of `Structure.global_drivers.image_generation_driver`.
- `InpaintingImageGenerationTask.image_generation_engine` now defaults to an `InpaintingImageGenerationEngine` with an Image Generation Driver default of `Structure.global_drivers.image_generation_driver`.
- `OutpaintingImageGenerationTask.image_generation_engine` now defaults to an `OutpaintingImageGenerationEngine` with an Image Generation Driver default of `Structure.global_drivers.image_generation_driver`.

## [0.22.3] - 2024-01-22

Expand Down
30 changes: 30 additions & 0 deletions griptape/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from .base_config import BaseConfig

from .structure_global_drivers_config import StructureGlobalDriversConfig
from .structure_task_memory_extraction_engine_csv_config import StructureTaskMemoryExtractionEngineCsvConfig
from .structure_task_memory_extraction_engine_json_config import StructureTaskMemoryExtractionEngineJsonConfig
from .structure_task_memory_extraction_engine_config import StructureTaskMemoryExtractionEngineConfig
from .structure_task_memory_query_engine_config import StructureTaskMemoryQueryEngineConfig
from .structure_task_memory_summary_engine_config import StructureTaskMemorySummaryEngineConfig
from .structure_task_memory_config import StructureTaskMemoryConfig
from .base_structure_config import BaseStructureConfig

from .structure_config import StructureConfig
from .openai_structure_config import OpenAiStructureConfig
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig


__all__ = [
"BaseConfig",
"BaseStructureConfig",
"StructureTaskMemoryConfig",
"StructureGlobalDriversConfig",
"StructureTaskMemoryQueryEngineConfig",
"StructureTaskMemorySummaryEngineConfig",
"StructureTaskMemoryExtractionEngineConfig",
"StructureTaskMemoryExtractionEngineCsvConfig",
"StructureTaskMemoryExtractionEngineJsonConfig",
"StructureConfig",
"OpenAiStructureConfig",
"AmazonBedrockStructureConfig",
]
77 changes: 77 additions & 0 deletions griptape/config/amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from attrs import Factory, define, field

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import (
AmazonBedrockImageGenerationDriver,
AmazonBedrockPromptDriver,
AmazonBedrockTitanEmbeddingDriver,
BedrockClaudePromptModelDriver,
BedrockTitanImageGenerationModelDriver,
BedrockTitanPromptModelDriver,
LocalVectorStoreDriver,
)


@define()
class AmazonBedrockStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="anthropic.claude-v2", stream=False, prompt_model_driver=BedrockClaudePromptModelDriver()
),
image_generation_driver=AmazonBedrockImageGenerationDriver(
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
),
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
),
)
),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver()
),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver()
)
),
json=StructureTaskMemoryExtractionEngineJsonConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver()
)
),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="amazon.titan-text-express-v1", prompt_model_driver=BedrockTitanPromptModelDriver()
)
),
)
),
kw_only=True,
metadata={"serializable": True},
)
10 changes: 10 additions & 0 deletions griptape/config/base_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC

from attrs import define

from griptape.mixins.serializable_mixin import SerializableMixin


@define
class BaseConfig(SerializableMixin, ABC):
...
20 changes: 20 additions & 0 deletions griptape/config/base_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from abc import ABC

from attr import define, field

from griptape.config import BaseConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig
from griptape.utils import dict_merge


@define
class BaseStructureConfig(BaseConfig, ABC):
global_drivers: StructureGlobalDriversConfig = field(kw_only=True, metadata={"serializable": True})
task_memory: StructureTaskMemoryConfig = field(kw_only=True, metadata={"serializable": True})

def merge_config(self, config: dict) -> BaseStructureConfig:
base_config = self.to_dict()
merged_config = dict_merge(base_config, config)

return BaseStructureConfig.from_dict(merged_config)
61 changes: 61 additions & 0 deletions griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from attrs import Factory, define, field

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import (
LocalVectorStoreDriver,
OpenAiChatPromptDriver,
OpenAiEmbeddingDriver,
OpenAiImageGenerationDriver,
)


@define
class OpenAiStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"),
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002")
),
)
),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-ada-002")
),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo")
),
json=StructureTaskMemoryExtractionEngineJsonConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo")
),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo")
),
)
),
kw_only=True,
metadata={"serializable": True},
)
13 changes: 13 additions & 0 deletions griptape/config/structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from attrs import Factory, define, field

from griptape.config import BaseStructureConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig


@define
class StructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(lambda: StructureGlobalDriversConfig()), kw_only=True, metadata={"serializable": True}
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(lambda: StructureTaskMemoryConfig()), kw_only=True, metadata={"serializable": True}
)
35 changes: 35 additions & 0 deletions griptape/config/structure_global_drivers_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from attrs import Factory, define, field

from griptape.drivers import (
BaseConversationMemoryDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
BasePromptDriver,
BaseVectorStoreDriver,
DummyVectorStoreDriver,
DummyEmbeddingDriver,
DummyImageGenerationDriver,
DummyPromptDriver,
)
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureGlobalDriversConfig(SerializableMixin):
prompt_driver: BasePromptDriver = field(
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}
)
image_generation_driver: BaseImageGenerationDriver = field(
kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True}
)
embedding_driver: BaseEmbeddingDriver = field(
kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}
)
vector_store_driver: BaseVectorStoreDriver = field(
default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True}
)
conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field(
default=None, kw_only=True, metadata={"serializable": True}
)
23 changes: 23 additions & 0 deletions griptape/config/structure_task_memory_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from attrs import Factory, define, field

from griptape.config import (
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureTaskMemoryConfig(SerializableMixin):
query_engine: StructureTaskMemoryQueryEngineConfig = field(
kw_only=True, default=Factory(lambda: StructureTaskMemoryQueryEngineConfig()), metadata={"serializable": True}
)
extraction_engine: StructureTaskMemoryExtractionEngineConfig = field(
kw_only=True,
default=Factory(lambda: StructureTaskMemoryExtractionEngineConfig()),
metadata={"serializable": True},
)
summary_engine: StructureTaskMemorySummaryEngineConfig = field(
kw_only=True, default=Factory(lambda: StructureTaskMemorySummaryEngineConfig()), metadata={"serializable": True}
)
18 changes: 18 additions & 0 deletions griptape/config/structure_task_memory_extraction_engine_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from attrs import Factory, define, field

from griptape.config import StructureTaskMemoryExtractionEngineCsvConfig, StructureTaskMemoryExtractionEngineJsonConfig
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureTaskMemoryExtractionEngineConfig(SerializableMixin):
csv: StructureTaskMemoryExtractionEngineCsvConfig = field(
kw_only=True,
default=Factory(lambda: StructureTaskMemoryExtractionEngineCsvConfig()),
metadata={"serializable": True},
)
json: StructureTaskMemoryExtractionEngineJsonConfig = field(
kw_only=True,
default=Factory(lambda: StructureTaskMemoryExtractionEngineJsonConfig()),
metadata={"serializable": True},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from attrs import define, field, Factory

from griptape.drivers import BasePromptDriver, DummyPromptDriver
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureTaskMemoryExtractionEngineCsvConfig(SerializableMixin):
prompt_driver: BasePromptDriver = field(
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from attrs import define, field, Factory

from griptape.drivers import BasePromptDriver, DummyPromptDriver
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureTaskMemoryExtractionEngineJsonConfig(SerializableMixin):
prompt_driver: BasePromptDriver = field(
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}
)
22 changes: 22 additions & 0 deletions griptape/config/structure_task_memory_query_engine_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from attrs import Factory, define, field

from griptape.drivers import (
BasePromptDriver,
BaseVectorStoreDriver,
DummyVectorStoreDriver,
DummyEmbeddingDriver,
DummyPromptDriver,
)
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureTaskMemoryQueryEngineConfig(SerializableMixin):
prompt_driver: BasePromptDriver = field(
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}
)
vector_store_driver: BaseVectorStoreDriver = field(
kw_only=True,
default=Factory(lambda: DummyVectorStoreDriver(embedding_driver=DummyEmbeddingDriver())),
metadata={"serializable": True},
)
11 changes: 11 additions & 0 deletions griptape/config/structure_task_memory_summary_engine_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from attrs import Factory, define, field

from griptape.drivers import BasePromptDriver, DummyPromptDriver
from griptape.mixins.serializable_mixin import SerializableMixin


@define
class StructureTaskMemorySummaryEngineConfig(SerializableMixin):
prompt_driver: BasePromptDriver = field(
kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}
)

0 comments on commit c87c85f

Please sign in to comment.