Skip to content

Commit

Permalink
Refactor deprecated warning
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 12, 2024
1 parent 6342b1a commit 5c01d68
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 19 deletions.
61 changes: 42 additions & 19 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import uuid
import warnings
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Any, Optional
Expand All @@ -12,13 +11,7 @@

from griptape.artifacts import BlobArtifact, TextArtifact
from griptape.config import BaseStructureConfig, OpenAiStructureConfig
from griptape.drivers import (
BaseEmbeddingDriver,
BasePromptDriver,
NopPromptDriver,
NopVectorStoreDriver,
NopEmbeddingDriver,
)
from griptape.drivers import BaseEmbeddingDriver, BasePromptDriver, NopPromptDriver, NopVectorStoreDriver
from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine, VectorQueryEngine
from griptape.events import BaseEvent, EventListener
from griptape.events.finish_structure_run_event import FinishStructureRunEvent
Expand All @@ -29,6 +22,7 @@
from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage
from griptape.rules import Rule, Ruleset
from griptape.tasks import BaseTask
from griptape.utils.decorators import deprecated

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory
Expand All @@ -39,9 +33,9 @@ class Structure(ABC):
LOGGER_NAME = "griptape"

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
stream: Optional[bool] = field(default=None, kw_only=True)
prompt_driver: Optional[BasePromptDriver] = field(default=None)
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
_stream: Optional[bool] = field(default=None, kw_only=True, alias="stream")
_prompt_driver: Optional[BasePromptDriver] = field(default=None, alias="prompt_driver")
_embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True, alias="embedding_driver")
config: BaseStructureConfig = field(
default=Factory(lambda self: self.default_config, takes_self=True), kw_only=True
)
Expand Down Expand Up @@ -92,6 +86,43 @@ def __attrs_post_init__(self) -> None:
def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]:
return self.add_tasks(*other) if isinstance(other, list) else self + [other]

@property
@deprecated("use `config.global_drivers.prompt_driver` instead.")
def prompt_driver(self) -> Optional[BasePromptDriver]:
if self._prompt_driver is not None:
return self._prompt_driver
else:
return None

@prompt_driver.setter
def prompt_driver(self, prompt_driver: Optional[BasePromptDriver]) -> None:
self._prompt_driver = prompt_driver

@property
@deprecated("use `config.global_drivers.embedding_driver` instead.")
def embedding_driver(self) -> Optional[BaseEmbeddingDriver]:
if self._embedding_driver is not None:
return self._embedding_driver
else:
return None

@embedding_driver.setter
def embedding_driver(self, embedding_driver: Optional[BaseEmbeddingDriver]) -> None:
self._embedding_driver = embedding_driver

@property
@deprecated("`stream` is deprecated, use `config.prompt_driver.stream` instead.")
def stream(self) -> Optional[bool]:
if self._stream is not None:
return self._stream
else:
return None

@stream.setter
@deprecated("`stream` is deprecated, use `config.prompt_driver.stream` instead.")
def stream(self, stream: Optional[bool]) -> None:
self._stream = stream

@property
def execution_args(self) -> tuple:
return self._execution_args
Expand Down Expand Up @@ -127,22 +158,14 @@ def default_config(self) -> BaseStructureConfig:
config = OpenAiStructureConfig()

if self.prompt_driver is not None:
warnings.warn(
"`prompt_driver` is deprecated, use `config.global_drivers.prompt_driver` instead.", DeprecationWarning
)
config.global_drivers.prompt_driver = self.prompt_driver
config.task_memory.query_engine.prompt_driver = self.prompt_driver
config.task_memory.summary_engine.prompt_driver = self.prompt_driver
config.task_memory.extraction_engine.csv.prompt_driver = self.prompt_driver
config.task_memory.extraction_engine.json.prompt_driver = self.prompt_driver
if self.embedding_driver is not None:
warnings.warn(
"`embedding_driver` is deprecated, use `config.global_drivers.embedding_driver` instead.",
DeprecationWarning,
)
config.task_memory.query_engine.vector_store_driver.embedding_driver = self.embedding_driver
if self.stream is not None:
warnings.warn("`stream` is deprecated, use `config.prompt_driver.stream` instead.", DeprecationWarning)
config.global_drivers.prompt_driver.stream = self.stream

return config
Expand Down
31 changes: 31 additions & 0 deletions griptape/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import functools
import schema
import inspect
import warnings

from schema import Schema


Expand All @@ -26,3 +29,31 @@ def wrapper(self, *args, **kwargs):
return wrapper

return decorator


def deprecated(reason: str):
"""This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used.
Args:
reason: The reason why the function is deprecated.
"""

def decorator(func):
if inspect.isclass(func):
message = "Call to deprecated class {name} ({reason})."
else:
message = "Call to deprecated function {name} ({reason})."

@functools.wraps(func)
def wrapper(*args, **kwargs):
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(message.format(name=func.__name__, reason=reason), category=DeprecationWarning, stacklevel=2)
warnings.simplefilter("default", DeprecationWarning)

return func(*args, **kwargs)

return wrapper

return decorator
12 changes: 12 additions & 0 deletions tests/unit/utils/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest
from griptape.utils.decorators import deprecated


class TestDecorators:
def test_deprecated(self):
@deprecated("This function is deprecated")
def test_function():
pass

with pytest.deprecated_call():
test_function()

0 comments on commit 5c01d68

Please sign in to comment.