Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
laughingman7743 committed May 12, 2024
1 parent 28e4bf1 commit ae45082
Showing 1 changed file with 82 additions and 53 deletions.
135 changes: 82 additions & 53 deletions pyathena/filesystem/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
bucket=bucket,
key=None,
version_id=None,
delimiter=None,
)
self.dircache[bucket] = file
else:
Expand Down Expand Up @@ -207,6 +208,7 @@ def _head_object(
bucket=bucket,
key=key,
version_id=version_id,
delimiter=None,
)
self.dircache[path] = file
else:
Expand Down Expand Up @@ -234,6 +236,7 @@ def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
bucket=b["Name"],
key=None,
version_id=None,
delimiter=None,
)
for b in response["Buckets"]
]
Expand All @@ -254,58 +257,63 @@ def _ls_dirs(
bucket, key, version_id = self.parse_path(path)
if key:
prefix = f"{key}/{prefix if prefix else ''}"
if path not in self.dircache or refresh:
files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
)
for c in response.get("CommonPrefixes", [])
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
)
for c in response.get("Contents", [])
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[path] = files
else:

if path in self.dircache and not refresh:
cache = self.dircache[path]
if not isinstance(cache, list):
files = [cache]
caches = [cache]
else:
files = cache
caches = cache
if all([f.delimiter == delimiter for f in caches]):
return caches

files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
delimiter=delimiter,
)
for c in response.get("CommonPrefixes", [])
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
delimiter=delimiter,
)
for c in response.get("Contents", [])
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[path] = files
return files

def ls(
Expand Down Expand Up @@ -340,6 +348,7 @@ def info(self, path: str, **kwargs) -> S3Object:
bucket=bucket,
key=None,
version_id=None,
delimiter=None,
)
if not refresh:
caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
Expand All @@ -366,6 +375,7 @@ def info(self, path: str, **kwargs) -> S3Object:
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
delimiter=None,
)
if key:
object_info = self._head_object(path, refresh=refresh, version_id=version_id)
Expand Down Expand Up @@ -402,31 +412,50 @@ def info(self, path: str, **kwargs) -> S3Object:
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
delimiter=None,
)
else:
raise FileNotFoundError(path)

def find(
def _find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
# TODO: Support maxdepth and withdirs
) -> List[S3Object]:
path = self._strip_protocol(path)
if path in ["", "/"]:
raise ValueError("Cannot traverse all files in S3.")
bucket, key, _ = self.parse_path(path)
prefix = kwargs.pop("prefix", "")
if maxdepth:
return super().find(
path=path,
maxdepth=maxdepth,
withdirs=withdirs,
detail=True,
**kwargs
).values()

files = self._ls_dirs(path, prefix=prefix, delimiter="")
if not files and key:
try:
files = [self.info(path)]
except FileNotFoundError:
files = []
return files

def find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
# TODO: Support withdirs
files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
if detail:
return {f.name: f for f in files}
else:
Expand Down

0 comments on commit ae45082

Please sign in to comment.