Skip to content

Commit

Permalink
Issue 6598: load_dataset broken for data_files on s3
Browse files Browse the repository at this point in the history
  • Loading branch information
matstrand committed May 3, 2024
1 parent 7ae4314 commit ded2cac
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
"moto[server]",
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
"py7zr",
"rarfile>=4.0",
Expand Down
4 changes: 4 additions & 0 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,10 @@ def get_from_cache(
if scheme == "ftp":
connected = ftp_head(url)
elif scheme not in ("http", "https"):
if scheme in ("s3", "s3a") and storage_options is not None and "hf" in storage_options:
# Issue 6598: **storage_options is passed to botocore.session.Session()
# and must not contain keys that become invalid kwargs.
del storage_options["hf"]
response = fsspec_head(url, storage_options=storage_options)
# s3fs uses "ETag", gcsfs uses "etag"
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
Expand Down
38 changes: 38 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
import shutil
import tempfile
import time
from contextlib import contextmanager
from hashlib import sha256
from multiprocessing import Pool
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import boto3
import dill
import pyarrow as pa
import pytest
import requests
from moto.server import ThreadedMotoServer

import datasets
from datasets import config, load_dataset, load_from_disk
Expand Down Expand Up @@ -1648,6 +1651,41 @@ def test_load_from_disk_with_default_in_memory(
_ = load_from_disk(dataset_path)


@contextmanager
def moto_server(port=5000):
    """Run a moto S3 server in a background thread for the duration of the block.

    Server mode (rather than ``moto.mock_aws``) is used because of an
    aiobotocore incompatibility (https://github.com/getmoto/moto/issues/6836).

    Args:
        port: TCP port the mock server listens on. ``AWS_ENDPOINT_URL`` is
            pointed at the same port so boto3/aiobotocore clients created
            inside the block talk to the mock server. Defaults to 5000,
            matching the original hard-coded endpoint.

    Yields:
        The running ``ThreadedMotoServer`` instance (callers may ignore it).
    """
    # Redirect AWS SDK clients to the local mock server via the environment;
    # patch.dict restores the original environment on exit.
    with patch.dict(os.environ, {"AWS_ENDPOINT_URL": f"http://localhost:{port}"}):
        server = ThreadedMotoServer(port=port)
        server.start()
        try:
            yield server
        finally:
            # Always shut the background server thread down, even if the
            # body of the with-block raises.
            server.stop()


def test_load_file_from_s3():
    """End-to-end check that load_dataset can read a CSV from an s3:// URI."""
    # Server mode is required here: moto.mock_aws is incompatible with
    # aiobotocore (https://github.com/getmoto/moto/issues/6836).
    with moto_server():
        bucket = "test-bucket"
        object_key = "test-file.csv"
        file_body = "Name\nPatrick\nMat"

        # Prepare a mock bucket holding a single CSV object.
        client = boto3.client("s3", region_name="us-east-1")
        client.create_bucket(Bucket=bucket)
        client.put_object(Bucket=bucket, Key=object_key, Body=file_body)

        # Point load_dataset at the mocked object through an s3:// URI.
        ds = datasets.load_dataset(
            "csv", data_files={"train": "s3://test-bucket/test-file.csv"}
        )

        # The parsed rows must round-trip the uploaded CSV content.
        assert list(ds["train"]) == [{"Name": "Patrick"}, {"Name": "Mat"}]


@pytest.mark.integration
def test_remote_data_files():
repo_id = "hf-internal-testing/raw_jsonl"
Expand Down

0 comments on commit ded2cac

Please sign in to comment.