Skip to content

Commit

Permalink
Remove file reading responsibility from loaders (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanholmes committed Apr 9, 2024
1 parent 44b0653 commit c9edaff
Show file tree
Hide file tree
Showing 26 changed files with 317 additions and 431 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Improved RAG performance in `VectorQueryEngine`.
- **BREAKING**: Secret fields (ex: api_key) removed from serialized Drivers.
- **BREAKING**: Remove `FileLoader`.
- **BREAKING**: `CsvLoader` no longer accepts `str` file paths as a source. It will now accept the content of the CSV file as a `str` or `bytes` object.
- **BREAKING**: `PdfLoader` no longer accepts `str` file content, `Path` file paths or `IO` objects as sources. Instead, it will only accept the content of the PDF file as a `bytes` object.
- **BREAKING**: `TextLoader` no longer accepts `Path` file paths as a source. It will now accept the content of the text file as a `str` or `bytes` object.
- **BREAKING**: `FileManager.default_loader` is now `None` by default.

## [0.24.2] - 2024-04-04

Expand Down
2 changes: 0 additions & 2 deletions griptape/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .sql_loader import SqlLoader
from .csv_loader import CsvLoader
from .dataframe_loader import DataFrameLoader
from .file_loader import FileLoader
from .email_loader import EmailLoader
from .image_loader import ImageLoader

Expand All @@ -20,7 +19,6 @@
"SqlLoader",
"CsvLoader",
"DataFrameLoader",
"FileLoader",
"EmailLoader",
"ImageLoader",
]
21 changes: 19 additions & 2 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from attr import define, field, Factory

from griptape.artifacts import BaseArtifact
from griptape.utils.futures import execute_futures_dict
from griptape.utils.hash import bytes_to_hash, str_to_hash


@define
Expand All @@ -18,8 +20,23 @@ class BaseLoader(ABC):
def load(self, source: Any, *args, **kwargs) -> BaseArtifact | Sequence[BaseArtifact]:
...

@abstractmethod
def load_collection(
self, sources: list[Any], *args, **kwargs
) -> Mapping[str, BaseArtifact | Sequence[BaseArtifact | Sequence[BaseArtifact]]]:
...
# Create a dictionary before actually submitting the jobs to the executor
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}
return execute_futures_dict(
{
key: self.futures_executor.submit(self.load, source, *args, **kwargs)
for key, source in sources_by_key.items()
}
)

def to_key(self, source: Any, *args, **kwargs) -> str:
if isinstance(source, bytes):
return bytes_to_hash(source)
elif isinstance(source, str):
return str_to_hash(source)
else:
return str_to_hash(str(source))
21 changes: 10 additions & 11 deletions griptape/loaders/base_text_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from abc import ABC
from typing import Optional
from typing import Any, Optional, Union, cast

from attrs import define, field, Factory
from pathlib import Path

from griptape.artifacts import TextArtifact
from griptape.artifacts.error_artifact import ErrorArtifact
from griptape.chunkers import TextChunker, BaseChunker
from griptape.drivers import BaseEmbeddingDriver
from griptape.loaders import BaseLoader
Expand All @@ -33,19 +33,18 @@ class BaseTextLoader(BaseLoader, ABC):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)

def _text_to_artifacts(self, text: str | Path) -> list[TextArtifact]:
artifacts = []
def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, ErrorArtifact | list[TextArtifact]]:
return cast(
dict[str, Union[ErrorArtifact, list[TextArtifact]]], super().load_collection(sources, *args, **kwargs)
)

if isinstance(text, Path):
with open(text, encoding=self.encoding) as file:
body = file.read()
else:
body = text
def _text_to_artifacts(self, text: str) -> list[TextArtifact]:
artifacts = []

if self.chunker:
chunks = self.chunker.chunk(body)
chunks = self.chunker.chunk(text)
else:
chunks = [TextArtifact(body)]
chunks = [TextArtifact(text)]

if self.embedding_driver:
for chunk in chunks:
Expand Down
46 changes: 27 additions & 19 deletions griptape/loaders/csv_loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations
import csv
from typing import Optional
from io import StringIO
from typing import Optional, Union, cast

from attr import define, field

from griptape import utils
from griptape.artifacts import CsvRowArtifact
from griptape.artifacts import CsvRowArtifact, ErrorArtifact
from griptape.drivers import BaseEmbeddingDriver
from griptape.loaders import BaseLoader

Expand All @@ -13,27 +14,34 @@
class CsvLoader(BaseLoader):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
delimiter: str = field(default=",", kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)

def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]:
return self._load_file(source)

def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]:
return utils.execute_futures_dict(
{utils.str_to_hash(source): self.futures_executor.submit(self._load_file, source) for source in sources}
)

def _load_file(self, filename: str) -> list[CsvRowArtifact]:
def load(self, source: bytes | str, *args, **kwargs) -> ErrorArtifact | list[CsvRowArtifact]:
artifacts = []

with open(filename, encoding="utf-8") as csv_file:
reader = csv.DictReader(csv_file, delimiter=self.delimiter)
chunks = [CsvRowArtifact(row) for row in reader]
if isinstance(source, bytes):
try:
source = source.decode(encoding=self.encoding)
except UnicodeDecodeError:
return ErrorArtifact(f"Failed to decode bytes to string using encoding: {self.encoding}")
elif isinstance(source, (bytearray, memoryview)):
return ErrorArtifact(f"Unsupported source type: {type(source)}")

if self.embedding_driver:
for chunk in chunks:
chunk.generate_embedding(self.embedding_driver)
reader = csv.DictReader(StringIO(source), delimiter=self.delimiter)
chunks = [CsvRowArtifact(row) for row in reader]

if self.embedding_driver:
for chunk in chunks:
artifacts.append(chunk)
chunk.generate_embedding(self.embedding_driver)

for chunk in chunks:
artifacts.append(chunk)

return artifacts

def load_collection(
self, sources: list[bytes | str], *args, **kwargs
) -> dict[str, ErrorArtifact | list[CsvRowArtifact]]:
return cast(
dict[str, Union[ErrorArtifact, list[CsvRowArtifact]]], super().load_collection(sources, *args, **kwargs)
)
24 changes: 8 additions & 16 deletions griptape/loaders/dataframe_loader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, cast

from attr import define, field

from griptape import utils
from griptape.artifacts import CsvRowArtifact
from griptape.drivers import BaseEmbeddingDriver
from griptape.loaders import BaseLoader
from griptape.utils import import_optional_dependency
from griptape.utils.hash import str_to_hash

if TYPE_CHECKING:
from pandas import DataFrame
Expand All @@ -19,20 +19,9 @@ class DataFrameLoader(BaseLoader):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)

def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]:
return self._load_file(source)

def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]:
return utils.execute_futures_dict(
{
self._dataframe_to_hash(source): self.futures_executor.submit(self._load_file, source)
for source in sources
}
)

def _load_file(self, dataframe: DataFrame) -> list[CsvRowArtifact]:
artifacts = []

chunks = [CsvRowArtifact(row) for row in dataframe.to_dict(orient="records")]
chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")]

if self.embedding_driver:
for chunk in chunks:
Expand All @@ -43,7 +32,10 @@ def _load_file(self, dataframe: DataFrame) -> list[CsvRowArtifact]:

return artifacts

def _dataframe_to_hash(self, dataframe: DataFrame) -> str:
def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]:
return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs))

def to_key(self, source: DataFrame, *args, **kwargs) -> str:
hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object

return utils.str_to_hash(str(hash_pandas_object(dataframe, index=True).values))
return str_to_hash(str(hash_pandas_object(source, index=True).values))
20 changes: 6 additions & 14 deletions griptape/loaders/email_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from typing import Optional
from typing import Optional, Union, cast
import logging
import imaplib

from attr import astuple, define, field

from griptape.utils import execute_futures_dict, import_optional_dependency, str_to_hash
from griptape.utils import import_optional_dependency
from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact
from griptape.loaders import BaseLoader

Expand Down Expand Up @@ -34,19 +34,8 @@ class EmailQuery:
password: str = field(kw_only=True)

def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact | ErrorArtifact:
return self._retrieve_email(source)

def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact | ErrorArtifact]:
return execute_futures_dict(
{
str_to_hash(str(source)): self.futures_executor.submit(self._retrieve_email, source)
for source in set(sources)
}
)

def _retrieve_email(self, query: EmailQuery) -> ListArtifact | ErrorArtifact:
mailparser = import_optional_dependency("mailparser")
label, key, search_criteria, max_count = astuple(query)
label, key, search_criteria, max_count = astuple(source)

artifacts = []
try:
Expand Down Expand Up @@ -88,3 +77,6 @@ def _retrieve_email(self, query: EmailQuery) -> ListArtifact | ErrorArtifact:

def _count_messages(self, message_numbers: bytes):
return len(list(filter(None, message_numbers.decode().split(" "))))

def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact | ErrorArtifact]:
return cast(dict[str, Union[ListArtifact, ErrorArtifact]], super().load_collection(sources, *args, **kwargs))
43 changes: 0 additions & 43 deletions griptape/loaders/file_loader.py

This file was deleted.

15 changes: 5 additions & 10 deletions griptape/loaders/image_loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from io import BytesIO
from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, cast

from attr import define, field

from griptape.utils import execute_futures_dict, str_to_hash, import_optional_dependency
from griptape.utils import import_optional_dependency
from griptape.artifacts import ImageArtifact
from griptape.loaders import BaseLoader

Expand Down Expand Up @@ -35,14 +35,6 @@ class ImageLoader(BaseLoader):
}

def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
return self._load(source)

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]:
return execute_futures_dict(
{str_to_hash(str(source)): self.futures_executor.submit(self._load, source) for source in sources}
)

def _load(self, source: bytes) -> ImageArtifact:
Image = import_optional_dependency("PIL.Image")
image = Image.open(BytesIO(source))

Expand All @@ -67,3 +59,6 @@ def _get_mime_type(self, image_format: str | None) -> str:
raise ValueError(f"Unsupported image format {image_format}")

return self.FORMAT_TO_MIME_TYPE[image_format.lower()]

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]:
return cast(dict[str, ImageArtifact], super().load_collection(sources, *args, **kwargs))
Loading

0 comments on commit c9edaff

Please sign in to comment.