Skip to content

Commit

Permalink
Remove uses of deprecated fields (griptape-ai#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 13, 2024
1 parent a5bd70e commit a28eb73
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
20 changes: 10 additions & 10 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,16 @@ def finished_tasks(self) -> list[BaseTask]:
def default_config(self) -> BaseStructureConfig:
config = OpenAiStructureConfig()

if self.prompt_driver is not None:
config.global_drivers.prompt_driver = self.prompt_driver
config.task_memory.query_engine.prompt_driver = self.prompt_driver
config.task_memory.summary_engine.prompt_driver = self.prompt_driver
config.task_memory.extraction_engine.csv.prompt_driver = self.prompt_driver
config.task_memory.extraction_engine.json.prompt_driver = self.prompt_driver
if self.embedding_driver is not None:
config.task_memory.query_engine.vector_store_driver.embedding_driver = self.embedding_driver
if self.stream is not None:
config.global_drivers.prompt_driver.stream = self.stream
if self._prompt_driver is not None:
config.global_drivers.prompt_driver = self._prompt_driver
config.task_memory.query_engine.prompt_driver = self._prompt_driver
config.task_memory.summary_engine.prompt_driver = self._prompt_driver
config.task_memory.extraction_engine.csv.prompt_driver = self._prompt_driver
config.task_memory.extraction_engine.json.prompt_driver = self._prompt_driver
if self._embedding_driver is not None:
config.task_memory.query_engine.vector_store_driver.embedding_driver = self._embedding_driver
if self._stream is not None:
config.global_drivers.prompt_driver.stream = self._stream

return config

Expand Down
4 changes: 2 additions & 2 deletions griptape/utils/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Chat:
)

def default_output_fn(self, text: str) -> None:
if self.structure.prompt_driver.stream:
if self.structure.config.global_drivers.prompt_driver.stream:
print(text, end="", flush=True)
else:
print(text)
Expand All @@ -36,7 +36,7 @@ def start(self) -> None:
self.output_fn(self.exiting_text)
break

if self.structure.prompt_driver.stream:
if self.structure.config.global_drivers.prompt_driver.stream:
self.output_fn(self.processing_text + "\n")
stream = Stream(self.structure).run(question)
first_chunk = next(stream)
Expand Down
2 changes: 1 addition & 1 deletion griptape/utils/prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def add_conversation_memory(self, memory: BaseConversationMemory, index: Optiona

if memory.autoprune and hasattr(memory, "structure"):
should_prune = True
prompt_driver = memory.structure.prompt_driver
prompt_driver = memory.structure.config.global_drivers.prompt_driver
temp_stack = PromptStack()

# Try to determine how many Conversation Memory runs we can
Expand Down
5 changes: 3 additions & 2 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING
from collections.abc import Iterator
from threading import Thread
from queue import Queue
from griptape.artifacts.text_artifact import TextArtifact
Expand Down Expand Up @@ -32,7 +33,7 @@ class Stream:

@structure.validator # pyright: ignore
def validate_structure(self, _, structure: Structure):
if structure and not structure.prompt_driver.stream:
if structure and not structure.config.global_drivers.prompt_driver.stream:
raise ValueError("prompt driver does not have streaming enabled, enable with stream=True")

_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()))
Expand Down

0 comments on commit a28eb73

Please sign in to comment.