Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add structure run tool #764

Merged
merged 7 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ jobs:
GOOGLE_AUTH_URI: ${{ secrets.INTEG_GOOGLE_AUTH_URI }}
GOOGLE_TOKEN_URI: ${{ secrets.INTEG_GOOGLE_TOKEN_URI }}
GOOGLE_AUTH_PROVIDER_X509_CERT_URL: ${{ secrets.INTEG_GOOGLE_AUTH_PROVIDER_X509_CERT_URL }}
GRIPTAPE_CLOUD_API_KEY: ${{ secrets.INTEG_GRIPTAPE_CLOUD_API_KEY }}
GRIPTAPE_CLOUD_STRUCTURE_ID: ${{ secrets.INTEG_GRIPTAPE_CLOUD_STRUCTURE_ID }}
OPENWEATHER_API_KEY: ${{ secrets.INTEG_OPENWEATHER_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.INTEG_ANTHROPIC_API_KEY }}
SAGEMAKER_LLAMA_ENDPOINT_NAME: ${{ secrets.INTEG_LLAMA_ENDPOINT_NAME }}
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `AmazonS3FileManagerDriver` for managing files on Amazon S3.
- `MediaArtifact` as a base class for `ImageArtifact` and future media Artifacts.
- Optional `exception` field to `ErrorArtifact`.
- `GriptapeCloudStructureRunClient` tool for invoking Griptape Cloud Structure Run APIs.

### Changed
- **BREAKING**: Secret fields (ex: api_key) removed from serialized Drivers.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# GriptapeCloudStructureRunClient

The GriptapeCloudStructureRunClient tool provides a way to interact with the Griptape Cloud Structure Run API. It can be used to execute a Structure Run and retrieve the results.

```python
from griptape.tools import GriptapeCloudStructureRunClient
from griptape.structures import Agent
import os

api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"]
structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"]

# Create the GriptapeCloudStructureRunClient tool
structure_run_tool = GriptapeCloudStructureRunClient(
description="Danish Baker Agent - Structure to invoke with natural language queries about Danish pastries",
api_key=api_key,
structure_id=structure_id,
off_prompt=False,
)

# Set up an agent using the GriptapeCloudStructureRunClient tool
agent = Agent(
tools=[structure_run_tool]
)

# Task: Ask the Griptape Cloud Hosted Structure about new Danish pastries
agent.run(
"What are the new pastries?"
)
```
```
[04/29/24 20:46:14] INFO ToolkitTask 3b3f31a123584f05be9bcb02a58dddb6
Input: what are the new pastries?
[04/29/24 20:46:23] INFO Subtask 2740dcd92bdf4b159dc7a7fb132c98f3
Thought: To find out about new pastries, I need to use the Danish Baker Agent Structure. I will execute a run of
this Structure with the query "what are the new pastries".

Actions: [
{
"name": "GriptapeCloudStructureRunClient",
"path": "execute_structure_run",
"input": {
"values": {
"args": ["what are the new pastries"]
}
},
"tag": "query_new_pastries"
}
]
[04/29/24 20:47:01] INFO Subtask 2740dcd92bdf4b159dc7a7fb132c98f3
Response: {'id': '4a329cbd09ad42e0bd265e9ba4690400', 'name': '4a329cbd09ad42e0bd265e9ba4690400', 'type':
'TextArtifact', 'value': 'Ah, my friend, I am glad you asked! We have been busy in the bakery, kneading dough
and sprinkling sugar. Our new pastries include the "Copenhagen Cream Puff", a delightful puff pastry filled with
sweet cream and dusted with powdered sugar. We also have the "Danish Delight", a buttery croissant filled with
raspberry jam and topped with a drizzle of white chocolate. And let\'s not forget the "Nordic Nutella Twist", a
flaky pastry twisted with Nutella and sprinkled with chopped hazelnuts. I promise, each bite will transport you
to a cozy Danish bakery!'}
[04/29/24 20:47:07] INFO ToolkitTask 3b3f31a123584f05be9bcb02a58dddb6
Output: The new pastries include the "Copenhagen Cream Puff," which is a puff pastry filled with sweet cream and
dusted with powdered sugar; the "Danish Delight," a buttery croissant filled with raspberry jam and topped with
white chocolate; and the "Nordic Nutella Twist," a flaky pastry twisted with Nutella and sprinkled with chopped
hazelnuts.
Assistant: The new pastries include the "Copenhagen Cream Puff," which is a puff pastry filled with sweet cream and dusted with powdered sugar; the "Danish Delight," a buttery croissant filled with raspberry jam and topped with white chocolate; and the "Nordic Nutella Twist," a flaky pastry twisted with Nutella and sprinkled with chopped hazelnuts.
```
1 change: 1 addition & 0 deletions griptape/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .inpainting_image_generation_client.tool import InpaintingImageGenerationClient
from .outpainting_image_generation_client.tool import OutpaintingImageGenerationClient
from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient
from .griptape_cloud_structure_run_client.tool import GriptapeCloudStructureRunClient
from .image_query_client.tool import ImageQueryClient

__all__ = [
Expand Down
20 changes: 20 additions & 0 deletions griptape/tools/base_griptape_cloud_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations
from abc import ABC
from attr import Factory, define, field
from griptape.tools import BaseTool


@define
class BaseGriptapeCloudClient(BaseTool, ABC):
"""
Attributes:
base_url: Base URL for the Griptape Cloud Knowledge Base API.
api_key: API key for Griptape Cloud.
headers: Headers for the Griptape Cloud Knowledge Base API.
"""

base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
api_key: str = field(kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
)
14 changes: 3 additions & 11 deletions griptape/tools/griptape_cloud_knowledge_base_client/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,21 @@
from typing import Optional
from urllib.parse import urljoin
from schema import Schema, Literal
from attr import define, field, Factory
from griptape.tools import BaseTool
from attr import define, field
from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient
from griptape.utils.decorators import activity
from griptape.artifacts import TextArtifact, ErrorArtifact


@define
class GriptapeCloudKnowledgeBaseClient(BaseTool):
class GriptapeCloudKnowledgeBaseClient(BaseGriptapeCloudClient):
"""
Attributes:
description: LLM-friendly knowledge base description.
base_url: Base URL for the Griptape Cloud Knowledge Base API.
api_key: API key for Griptape Cloud.
headers: Headers for the Griptape Cloud Knowledge Base API.
knowledge_base_id: ID of the Griptape Cloud Knowledge Base.
"""

description: Optional[str] = field(default=None, kw_only=True)
base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
api_key: str = field(kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
)
knowledge_base_id: str = field(kw_only=True)

@activity(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
version: "v1"
name: Griptape Cloud Structure Run Client
description: Tool for using the Griptape Cloud Structure Run API.
contact_email: hello@griptape.ai
legal_info_url: https://www.griptape.ai/legal
99 changes: 99 additions & 0 deletions griptape/tools/griptape_cloud_structure_run_client/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations
import time
from typing import Any, Optional
from urllib.parse import urljoin
from schema import Schema, Literal
from attr import define, field
from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient
from griptape.utils.decorators import activity
from griptape.artifacts import InfoArtifact, TextArtifact, ErrorArtifact


@define
class GriptapeCloudStructureRunClient(BaseGriptapeCloudClient):
"""
Attributes:
description: LLM-friendly structure description.
structure_id: ID of the Griptape Cloud Structure.
"""

_description: Optional[str] = field(default=None, kw_only=True)
structure_id: str = field(kw_only=True)
structure_run_wait_time_interval: int = field(default=2, kw_only=True)
structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True)

@property
def description(self) -> str:
if self._description is None:
from requests import get

url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/")

response = get(url, headers=self.headers).json()
if "description" in response:
self._description = response["description"]
else:
raise ValueError(f'Error getting Structure description: {response["message"]}')

return self._description

@description.setter
def description(self, value: str) -> None:
self._description = value

@activity(
config={
"description": "Can be used to execute a Run of a Structure with the following description: {{ _self.description }}",
"schema": Schema(
{Literal("args", description="A list of string arguments to submit to the Structure Run"): list}
),
}
)
def execute_structure_run(self, params: dict) -> InfoArtifact | TextArtifact | ErrorArtifact:
from requests import post, exceptions, HTTPError, Response

args: list[str] = params["values"]["args"]
url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs")

try:
response: Response = post(url, json={"args": args}, headers=self.headers)
response.raise_for_status()
response_json = response.json()
return self._get_structure_run_result(response_json["structure_run_id"])

except (exceptions.RequestException, HTTPError) as err:
return ErrorArtifact(str(err))

def _get_structure_run_result(self, structure_run_id: str) -> InfoArtifact | TextArtifact | ErrorArtifact:
url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{structure_run_id}")

result = self._get_structure_run_result_attempt(url)
status = result["status"]

wait_attempts = 0
while status in ("QUEUED", "RUNNING") and wait_attempts < self.structure_run_max_wait_time_attempts:
# wait
time.sleep(self.structure_run_wait_time_interval)
wait_attempts += 1
result = self._get_structure_run_result_attempt(url)
status = result["status"]

if wait_attempts >= self.structure_run_max_wait_time_attempts:
return ErrorArtifact(
f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts."
)

if status != "SUCCEEDED":
return ErrorArtifact(result)

if "output" in result:
return TextArtifact(result["output"])
else:
return InfoArtifact("No output found in response")

def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any:
from requests import get, Response

response: Response = get(structure_run_url, headers=self.headers)
response.raise_for_status()
return response.json()
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ nav:
- GoogleGmailClient: "griptape-tools/official-tools/google-gmail-client.md"
- GoogleDriveClient: "griptape-tools/official-tools/google-drive-client.md"
- GoogleDocsClient: "griptape-tools/official-tools/google-docs-client.md"
- GriptapeCloudStructureRunClient: "griptape-tools/official-tools/griptape-cloud-structure-run-client.md"
- OpenWeatherClient: "griptape-tools/official-tools/openweather-client.md"
- RestApiClient: "griptape-tools/official-tools/rest-api-client.md"
- SqlClient: "griptape-tools/official-tools/sql-client.md"
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/tools/test_griptape_cloud_structure_run_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from griptape.artifacts import TextArtifact


class TestGriptapeCloudStructureRunClient:
@pytest.fixture
def client(self, mocker):
from griptape.tools import GriptapeCloudStructureRunClient

mock_response = mocker.Mock()
mock_response.json.return_value = {"structure_run_id": 1}
mocker.patch("requests.post", return_value=mock_response)

mock_response = mocker.Mock()
mock_response.json.return_value = {"description": "fizz buzz", "output": "fooey booey", "status": "SUCCEEDED"}
mocker.patch("requests.get", return_value=mock_response)

return GriptapeCloudStructureRunClient(base_url="https://api.griptape.ai", api_key="foo bar", structure_id="1")

def test_execute_structure_run(self, client):
assert isinstance(client.execute_structure_run({"values": {"args": ["foo bar"]}}), TextArtifact)

def test_get_structure_description(self, client):
assert client.description == "fizz buzz"

client.description = "foo bar"
assert client.description == "foo bar"
Loading