Skip to content

Commit

Permalink
Add LocalFileManagerDriver
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanholmes committed Apr 9, 2024
1 parent 44b0653 commit 77fef0e
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 200 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- `list_files_from_disk` activity to `FileManager` Tool.
- `BaseFileManagerDriver` to abstract file management operations.
- `LocalFileManagerDriver` for managing files on the local file system.
- Added optional `BaseLoader.encoding` field.

### Changed
- Improved RAG performance in `VectorQueryEngine`.
- **BREAKING**: Secret fields (ex: api_key) removed from serialized Drivers.
- **BREAKING**: Removed `workdir`, `loaders`, `default_loader`, and `save_file_encoding` fields from `FileManager` and added `file_manager_driver`.

## [0.24.2] - 2024-04-04

Expand Down
5 changes: 5 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@
from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver
from .web_scraper.markdownify_web_scraper_driver import MarkdownifyWebScraperDriver

from .file_manager.base_file_manager_driver import BaseFileManagerDriver
from .file_manager.local_file_manager_driver import LocalFileManagerDriver

__all__ = [
"BasePromptDriver",
"OpenAiChatPromptDriver",
Expand Down Expand Up @@ -161,4 +164,6 @@
"BaseWebScraperDriver",
"TrafilaturaWebScraperDriver",
"MarkdownifyWebScraperDriver",
"BaseFileManagerDriver",
"LocalFileManagerDriver",
]
Empty file.
52 changes: 52 additions & 0 deletions griptape/drivers/file_manager/base_file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from attr import Factory, define, field
from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact, InfoArtifact
from griptape.loaders import BaseLoader, CsvLoader, ImageLoader, PdfLoader, TextLoader


@define
class BaseFileManagerDriver(ABC):
"""
BaseFileManagerDriver can be used to list, load, and save files.
Attributes:
default_loader: The default loader to use for loading file contents into artifacts.
loaders: Dictionary of file extension specifc loaders to use for loading file contents into artifacts.
"""

default_loader: Optional[BaseLoader] = field(default=None, kw_only=True)
loaders: dict[str, BaseLoader] = field(
default=Factory(
lambda: {
"pdf": PdfLoader(),
"csv": CsvLoader(),
"txt": TextLoader(),
"html": TextLoader(),
"json": TextLoader(),
"yaml": TextLoader(),
"xml": TextLoader(),
"png": ImageLoader(),
"jpg": ImageLoader(),
"jpeg": ImageLoader(),
"webp": ImageLoader(),
"gif": ImageLoader(),
"bmp": ImageLoader(),
"tiff": ImageLoader(),
}
),
kw_only=True,
)

@abstractmethod
def list_files(self, path: str) -> TextArtifact | ErrorArtifact:
...

@abstractmethod
def load_file(self, path: str) -> BaseArtifact:
...

@abstractmethod
def save_file(self, path: str, value: bytes | str) -> InfoArtifact | ErrorArtifact:
...
74 changes: 74 additions & 0 deletions griptape/drivers/file_manager/local_file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations
import os
from pathlib import Path
from attr import define, field, Factory
from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, BaseArtifact, TextArtifact
from .base_file_manager_driver import BaseFileManagerDriver


@define
class LocalFileManagerDriver(BaseFileManagerDriver):
"""
LocalFileManagerDriver can be used to list, load, and save files on the local file system.
Attributes:
workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory.
"""

workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True)

@workdir.validator # pyright: ignore
def validate_workdir(self, _, workdir: str) -> None:
if not Path(workdir).is_absolute():
raise ValueError("Workdir must be an absolute path")

def list_files(self, path: str) -> TextArtifact | ErrorArtifact:
path = path.lstrip("/")
full_path = Path(os.path.join(self.workdir, path))

if os.path.exists(full_path):
entries = os.listdir(full_path)

return TextArtifact("\n".join([e for e in entries]))
else:
return ErrorArtifact("Path not found")

def load_file(self, path: str) -> BaseArtifact:
path = path.lstrip("/")
full_path = Path(os.path.join(self.workdir, path))
extension = path.split(".")[-1]
loader = self.loaders.get(extension) or self.default_loader
try:
result = loader.load(full_path)
except Exception as e:
return ErrorArtifact(f"Failed to load file: {str(e)}")

if isinstance(result, BaseArtifact):
return result
else:
return ListArtifact(result)

def save_file(self, path: str, value: bytes | str) -> InfoArtifact | ErrorArtifact:
path = path.lstrip("/")
full_path = Path(os.path.join(self.workdir, path))
extension = path.split(".")[-1]
loader = self.loaders.get(extension) or self.default_loader
encoding = None if loader is None else loader.encoding

os.makedirs(os.path.dirname(full_path), exist_ok=True)

try:
if isinstance(value, str):
if encoding is None:
value = value.encode()
else:
value = value.encode(encoding=encoding)
elif isinstance(value, bytearray) or isinstance(value, memoryview):
raise ValueError(f"Unsupported type: {type(value)}")

with open(full_path, "wb") as file:
file.write(value)
except Exception as e:
return ErrorArtifact(f"Failed to save file: {str(e)}")

return InfoArtifact("Successfully saved file")
3 changes: 2 additions & 1 deletion griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from concurrent import futures
from typing import Any
from typing import Any, Optional
from collections.abc import Mapping, Sequence

from attr import define, field, Factory
Expand All @@ -13,6 +13,7 @@
@define
class BaseLoader(ABC):
futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)
encoding: Optional[str] = field(default=None, kw_only=True)

@abstractmethod
def load(self, source: Any, *args, **kwargs) -> BaseArtifact | Sequence[BaseArtifact]:
Expand Down
136 changes: 32 additions & 104 deletions griptape/tools/file_manager/tool.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,23 @@
from __future__ import annotations
import logging
import os
from pathlib import Path
from attr import define, field, Factory
from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, BaseArtifact, TextArtifact
from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact
from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver
from griptape.tools import BaseTool
from griptape.utils.decorators import activity
from griptape.loaders import FileLoader, BaseLoader, PdfLoader, CsvLoader, TextLoader, ImageLoader
from schema import Schema, Literal
from typing import Optional, Any


@define
class FileManager(BaseTool):
"""
FileManager is a tool that can be used to load and save files.
FileManager is a tool that can be used to list, load, and save files.
Attributes:
workdir: The absolute directory to load files from and save files to.
loaders: Dictionary of file extensions and matching loaders to use when loading files in load_files_from_disk.
default_loader: The loader to use when loading files in load_files_from_disk without any matching loader in `loaders`.
save_file_encoding: The encoding to use when saving files to disk.
file_manager_driver: File Manager Driver to use to list, load, and save files.
"""

workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True)
default_loader: BaseLoader = field(default=Factory(lambda: FileLoader()))
loaders: dict[str, BaseLoader] = field(
default=Factory(
lambda: {
"pdf": PdfLoader(),
"csv": CsvLoader(),
"txt": TextLoader(),
"html": TextLoader(),
"json": TextLoader(),
"yaml": TextLoader(),
"xml": TextLoader(),
"png": ImageLoader(),
"jpg": ImageLoader(),
"jpeg": ImageLoader(),
"webp": ImageLoader(),
"gif": ImageLoader(),
"bmp": ImageLoader(),
"tiff": ImageLoader(),
}
),
kw_only=True,
)
save_file_encoding: Optional[str] = field(default=None, kw_only=True)

@workdir.validator # pyright: ignore
def validate_workdir(self, _, workdir: str) -> None:
if not Path(workdir).is_absolute():
raise ValueError("workdir has to be absolute absolute")
file_manager_driver: BaseFileManagerDriver = field(default=Factory(lambda: LocalFileManagerDriver()), kw_only=True)

@activity(
config={
Expand All @@ -62,15 +28,8 @@ def validate_workdir(self, _, workdir: str) -> None:
}
)
def list_files_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:
path = params["values"]["path"].lstrip("/")
full_path = Path(os.path.join(self.workdir, path))

if os.path.exists(full_path):
entries = os.listdir(full_path)

return TextArtifact("\n".join([e for e in entries]))
else:
return ErrorArtifact("Path not found")
path = params["values"]["path"]
return self.file_manager_driver.list_files(path)

@activity(
config={
Expand All @@ -86,22 +45,14 @@ def list_files_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:
}
)
def load_files_from_disk(self, params: dict) -> ListArtifact | ErrorArtifact:
paths = params["values"]["paths"]
artifacts = []

for path in params["values"]["paths"]:
path = path.lstrip("/")
full_path = Path(os.path.join(self.workdir, path))
extension = path.split(".")[-1]
loader = self.loaders.get(extension) or self.default_loader
result = loader.load(full_path)

if isinstance(result, list):
artifacts.extend(result)
elif isinstance(result, BaseArtifact):
artifacts.append(result)
for path in paths:
artifact = self.file_manager_driver.load_file(path)
if isinstance(artifact, ListArtifact):
artifacts.extend(artifact.value)
else:
logging.warning(f"Unknown loader return type for file {path}")

artifacts.append(artifact)
return ListArtifact(artifacts)

@activity(
Expand All @@ -121,33 +72,29 @@ def load_files_from_disk(self, params: dict) -> ListArtifact | ErrorArtifact:
}
)
def save_memory_artifacts_to_disk(self, params: dict) -> ErrorArtifact | InfoArtifact:
memory = self.find_input_memory(params["values"]["memory_name"])
artifact_namespace = params["values"]["artifact_namespace"]
dir_name = params["values"]["dir_name"]
file_name = params["values"]["file_name"]
memory_name = params["values"]["memory_name"]
artifact_namespace = params["values"]["artifact_namespace"]

if memory:
list_artifact = memory.load_artifacts(artifact_namespace)
memory = self.find_input_memory(params["values"]["memory_name"])
if not memory:
return ErrorArtifact(f"Failed to save memory artifacts to disk - memory named '{memory_name}' not found")

if len(list_artifact) == 0:
return ErrorArtifact("no artifacts found")
elif len(list_artifact) == 1:
try:
self._save_to_disk(os.path.join(self.workdir, dir_name, file_name), list_artifact.value[0].value)
list_artifact = memory.load_artifacts(artifact_namespace)

return InfoArtifact("saved successfully")
except Exception as e:
return ErrorArtifact(f"error writing file to disk: {e}")
else:
try:
for a in list_artifact.value:
self._save_to_disk(os.path.join(self.workdir, dir_name, f"{a.name}-{file_name}"), a.to_text())
if len(list_artifact) == 0:
return ErrorArtifact(
f"Failed to save memory artifacts to disk - memory named '{memory_name}' does not contain any artifacts"
)

return InfoArtifact("saved successfully")
except Exception as e:
return ErrorArtifact(f"error writing file to disk: {e}")
else:
return ErrorArtifact("memory not found")
for artifact in list_artifact.value:
formatted_file_name = f"{artifact.name}-{file_name}" if len(list_artifact) > 1 else file_name
result = self.file_manager_driver.save_file(os.path.join(dir_name, formatted_file_name), artifact.value)
if isinstance(result, ErrorArtifact):
return result

return InfoArtifact("Successfully saved memory artifacts to disk")

@activity(
config={
Expand All @@ -164,25 +111,6 @@ def save_memory_artifacts_to_disk(self, params: dict) -> ErrorArtifact | InfoArt
}
)
def save_content_to_file(self, params: dict) -> ErrorArtifact | InfoArtifact:
path = params["values"]["path"]
content = params["values"]["content"]
new_path = params["values"]["path"].lstrip("/")
full_path = os.path.join(self.workdir, new_path)

try:
self._save_to_disk(full_path, content)

return InfoArtifact("saved successfully")
except Exception as e:
return ErrorArtifact(f"error writing file to disk: {e}")

def _save_to_disk(self, path: str, value: Any) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)

with open(path, "wb") as file:
if isinstance(value, str):
if self.save_file_encoding:
file.write(value.encode(self.save_file_encoding))
else:
file.write(value.encode())
else:
file.write(value)
return self.file_manager_driver.save_file(path, content)
Empty file.
Loading

0 comments on commit 77fef0e

Please sign in to comment.