Skip to content

Commit

Permalink
Merge pull request #1918 from AdeelH/vsi_unit_tests
Browse files Browse the repository at this point in the history
Add unit tests for `VsiFileSystem`
  • Loading branch information
AdeelH committed Sep 11, 2023
2 parents 60fc990 + 06979b3 commit b334d68
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 133 deletions.
228 changes: 95 additions & 133 deletions rastervision_gdal_vsi/rastervision/gdal_vsi/vsi_file_system.py
@@ -1,13 +1,30 @@
from datetime import datetime
from typing import List, Optional
import os
from os.path import join
from pathlib import Path
from typing import (List, Optional)
import re
from datetime import datetime
from urllib.parse import urlparse

from rastervision.pipeline.file_system import FileSystem

from osgeo import gdal

# Nested-archive URIs have the shape
# "<archive_scheme>+<archive_uri>!<file_path>", for example
# "zip+s3://bucket/data.zip!inner/file.tif".
ARCHIVE_URI_FORMAT = (
    r'^(?P<archive_scheme>[^+]+)'  # text before the first "+"
    r'\+(?P<archive_uri>[^!]+)'  # URI of the archive, up to the first "!"
    r'!(?P<file_path>.+)$')  # path of the file inside the archive

# URI scheme -> name of the GDAL VSI handler used for that scheme.
URI_SCHEME_TO_VSI = dict(
    http='vsicurl',
    https='vsicurl',
    ftp='vsicurl',
    s3='vsis3',
    gs='vsigs',
)

# Archive scheme -> name of the GDAL VSI handler that reads that archive type.
ARCHIVE_SCHEME_TO_VSI = dict(
    zip='vsizip',
    gzip='vsigzip',
    tar='vsitar',
)


class VsiFileSystem(FileSystem):
"""A FileSystem to access files over any protocol supported by GDAL's VSI"""
@staticmethod
def uri_to_vsi_path(uri: str) -> str:
    """Convert a URI into the equivalent GDAL VSI path.

    Args:
        uri: URI of the file, possibly nested within archives as follows
            <archive_scheme>+<archive_URI>!path/to/contained/file.ext.
            Acceptable URI schemes are file, s3, gs, http, https, and ftp.
            Allowable archive schema are tar, zip, and gzip.

    Returns:
        The VSI path for the URI. URIs whose scheme is not in
        URI_SCHEME_TO_VSI are treated as local files and returned as an
        absolute path.

    Raises:
        ValueError: If the URI is in archive format but its archive scheme
            is not one of zip, tar, or gzip.
    """
    parsed = VsiFileSystem.parse_archive_format(uri)
    if parsed is None:
        # regular (non-archive) URI
        parts = urlparse(uri)
        vsi_scheme = URI_SCHEME_TO_VSI.get(parts.scheme)
        if vsi_scheme == 'vsicurl':
            # /vsicurl/ needs the complete URL, scheme included, so the URI
            # is passed through whole rather than rebuilt from its parts.
            return f'/{vsi_scheme}/{uri}'
        if vsi_scheme is not None:
            # Cloud handlers (vsis3, vsigs) take "<bucket>/<key>".
            return f'/{vsi_scheme}/{parts.netloc}{parts.path}'
        # assume file schema
        return os.path.abspath(join(parts.netloc, parts.path))

    archive_scheme = parsed['archive_scheme']
    vsi_archive_scheme = ARCHIVE_SCHEME_TO_VSI.get(archive_scheme)
    if vsi_archive_scheme is None:
        raise ValueError('Expected archive scheme to be one of "zip", '
                         f'"tar", or "gzip". Found "{archive_scheme}".')
    # The archive itself may be remote or nested, so convert it recursively.
    vsi_archive_uri = VsiFileSystem.uri_to_vsi_path(parsed['archive_uri'])
    return join(f'/{vsi_archive_scheme}{vsi_archive_uri}',
                parsed['file_path'])

@staticmethod
def parse_archive_format(uri: str) -> Optional[dict]:
    """Split an archive-style URI into its named components.

    Args:
        uri: URI to test against ARCHIVE_URI_FORMAT.

    Returns:
        A dict with the keys "archive_scheme", "archive_uri" and
        "file_path" if the URI matches the archive format, else None.
    """
    match = re.match(ARCHIVE_URI_FORMAT, uri)
    if match is None:
        return None
    return match.groupdict()

@staticmethod
def matches_uri(uri: str, mode: str) -> bool:
    """Returns True if this FS can be used for the given URI/mode pair.

    Args:
        uri: URI of file. Only paths starting with "/vsi" are handled.
        mode: mode to open file in, 'r' or 'w'
    """
    if not uri.startswith('/vsi'):
        return False
    # Paths served through /vsicurl/ are rejected for writing.
    if mode == 'w' and '/vsicurl/' in uri:
        return False
    return True

@staticmethod
def file_exists(vsipath: str, include_dir: bool = True) -> bool:
    """Check if a file exists.

    Args:
        vsipath: The VSI path to check.
        include_dir: Include directories in check, if this file_system
            supports directory reads. Otherwise only return true if a single
            file exists at the path.

    Returns:
        True if something exists at vsipath (and, when include_dir is
        False, it is not a directory); False otherwise.
    """
    file_stats = gdal.VSIStatL(vsipath)
    # A falsy stat result means nothing exists at the path. Return a real
    # bool here rather than leaking the falsy stat value to the caller.
    if not file_stats:
        return False
    if include_dir:
        return True
    return not file_stats.IsDirectory()

@staticmethod
def read_bytes(vsipath: str) -> bytes:
stats = gdal.VSIStatL(vsipath)
if not stats or stats.IsDirectory():
raise FileNotFoundError('{} does not exist'.format(vsipath))
raise FileNotFoundError(f'{vsipath} does not exist')

try:
retval = bytes()
Expand Down Expand Up @@ -117,147 +120,106 @@ def write_bytes(vsipath: str, data: bytes):

@staticmethod
def write_str(uri: str, data: str):
    """Encode the string in data and write the resulting bytes to uri."""
    encoded = data.encode()
    VsiFileSystem.write_bytes(uri, encoded)

@staticmethod
def sync_to_dir(src_dir: str, dst_dir_uri: str, delete: bool = False):
    """Syncs a local source directory to a destination directory.

    Args:
        src_dir: local source directory to sync from
        dst_dir_uri: A destination directory that can be synced to by this
            FileSystem
        delete: True if the destination should be deleted first.

    Raises:
        ValueError: If src_dir is not an existing local directory.
        FileExistsError: If something already exists at dst_dir_uri and
            delete is False.
    """

    def work(src: Path, vsi_dest: str):
        # Recreate the directory, then copy files and recurse into subdirs.
        gdal.Mkdir(vsi_dest, 0o777)

        for item in src.iterdir():
            item_vsi_dest = join(vsi_dest, item.name)
            if item.is_dir():
                work(item, item_vsi_dest)
            else:
                VsiFileSystem.copy_to(str(item), item_vsi_dest)

    # Validate the source BEFORE touching the destination, so that a bad
    # src_dir can never cause the existing destination to be deleted.
    src = Path(src_dir)
    if not (src.exists() and src.is_dir()):
        raise ValueError('Source must be a directory')

    stats = gdal.VSIStatL(dst_dir_uri)
    if stats:
        if not delete:
            raise FileExistsError(
                'Target location must not exist if delete=False')
        if stats.IsDirectory():
            gdal.RmdirRecursive(dst_dir_uri)
        else:
            gdal.Unlink(dst_dir_uri)

    work(src, dst_dir_uri)

@staticmethod
def sync_from_dir(src_dir_uri: str, dst_dir: str, delete: bool = False):
    """Syncs a source directory to a local destination directory.

    Args:
        src_dir_uri: source directory that can be synced from by this
            FileSystem
        dst_dir: A local destination directory
        delete: If False, refuse to overwrite a local file that already
            exists. Nothing is removed up front; existing files are only
            overwritten one by one when this is True.

    Raises:
        ValueError: If src_dir_uri is not a directory, or if a local target
            directory path exists but is not a directory.
        FileExistsError: If a target file already exists and delete is
            False.
    """

    def work(vsi_src: str, dest: Path):
        if dest.exists():
            if not dest.is_dir():
                raise ValueError(
                    f'Local target ({dest}) must be a directory')
        else:
            dest.mkdir()

        for item in VsiFileSystem.list_children(vsi_src):
            item_vsi_src = join(vsi_src, item)
            target = dest.joinpath(item)
            # A child can fail to stat; treat it as a file so the copy
            # reports the failure instead of an AttributeError on None.
            item_stats = gdal.VSIStatL(item_vsi_src)
            if item_stats is not None and item_stats.IsDirectory():
                work(item_vsi_src, target)
            else:
                if target.exists() and not delete:
                    raise FileExistsError(
                        'Target location must not exist if delete=False')
                VsiFileSystem.copy_from(item_vsi_src, str(target))

    stats = gdal.VSIStatL(src_dir_uri)
    if not (stats and stats.IsDirectory()):
        raise ValueError('Source must be a directory')

    work(src_dir_uri, Path(dst_dir))

@staticmethod
def copy_to(src_path: str, dst_uri: str):
    """Copy a local source file to a destination.

    The whole file is read into memory and then written in one call.

    Args:
        src_path: local path to source file
        dst_uri: uri of destination that can be copied to by this FileSystem
    """
    payload = Path(src_path).read_bytes()
    VsiFileSystem.write_bytes(dst_uri, payload)

@staticmethod
def copy_from(src_uri: str, dst_path: str):
    """Copy a source file to a local destination.

    The whole file is read into memory and then written in one call.

    Args:
        src_uri: uri of source that can be copied from by this FileSystem
        dst_path: local path to destination file
    """
    contents = VsiFileSystem.read_bytes(src_uri)
    Path(dst_path).write_bytes(contents)

@staticmethod
def local_path(vsipath: str, download_dir: str) -> str:
    """Return the path where a local copy should be stored.

    Args:
        vsipath: the VSI path of the file to be copied
        download_dir: path of the local directory in which files should
            be copied

    Returns:
        download_dir joined with the final path component of vsipath.
    """
    filename = Path(vsipath).name
    return join(download_dir, filename)

@staticmethod
def last_modified(vsipath: str) -> Optional[datetime]:
    """Get the last modified date of a file.

    Args:
        vsipath: the VSI path of the file

    Returns:
        The modification time from the file's stat record as a datetime,
        or None if the path cannot be stat-ed.
    """
    stats = gdal.VSIStatL(vsipath)
    if not stats:
        return None
    # NOTE(review): fromtimestamp() without a tz argument yields a naive
    # datetime in local time, while the earlier docstring promised UTC.
    # Confirm which one callers expect.
    return datetime.fromtimestamp(stats.mtime)

@staticmethod
def list_paths(vsipath: str, ext: Optional[str] = None) -> List[str]:
    """List paths rooted at a VSI directory.

    Optionally only includes paths with a certain file extension.

    Args:
        vsipath: The VSI path of a directory.
        ext: The optional file extension to filter by.

    Returns:
        The child names from list_children, each joined onto vsipath.
    """
    items = VsiFileSystem.list_children(vsipath, ext=ext)
    return [join(vsipath, item) for item in items]

@staticmethod
def list_children(vsipath: str, ext: Optional[str] = None) -> List[str]:
    """List filenames of children rooted at a VSI directory.

    Optionally only includes filenames with a certain file extension.

    Args:
        vsipath: The VSI path of a directory.
        ext: The optional file extension to filter by.

    Returns:
        List of filenames excluding "." or "..". Empty if the directory
        listing is unavailable.
    """
    ext = ext if ext else ''
    # gdal.ReadDir can return None instead of a list (e.g. when the path
    # cannot be listed); normalise that so iteration never fails.
    items = gdal.ReadDir(vsipath) or []
    return [
        item for item in items
        if item not in ('.', '..') and item.endswith(ext)
    ]
Empty file added tests/gdal_vsi/__init__.py
Empty file.

0 comments on commit b334d68

Please sign in to comment.