diff --git a/rastervision_gdal_vsi/rastervision/gdal_vsi/vsi_file_system.py b/rastervision_gdal_vsi/rastervision/gdal_vsi/vsi_file_system.py index 75f016486..7d590dc78 100644 --- a/rastervision_gdal_vsi/rastervision/gdal_vsi/vsi_file_system.py +++ b/rastervision_gdal_vsi/rastervision/gdal_vsi/vsi_file_system.py @@ -1,13 +1,30 @@ -from datetime import datetime +from typing import List, Optional import os +from os.path import join from pathlib import Path -from typing import (List, Optional) +import re +from datetime import datetime from urllib.parse import urlparse from rastervision.pipeline.file_system import FileSystem from osgeo import gdal +ARCHIVE_URI_FORMAT = ( + r'^(?P<archive_scheme>[^+]+)\+(?P<archive_uri>[^!]+)!(?P<file_path>.+)$') +URI_SCHEME_TO_VSI = { + 'http': 'vsicurl', + 'https': 'vsicurl', + 'ftp': 'vsicurl', + 's3': 'vsis3', + 'gs': 'vsigs', +} +ARCHIVE_SCHEME_TO_VSI = { + 'zip': 'vsizip', + 'gzip': 'vsigzip', + 'tar': 'vsitar', +} + class VsiFileSystem(FileSystem): """A FileSystem to access files over any protocol supported by GDAL's VSI""" @@ -18,77 +35,63 @@ def uri_to_vsi_path(uri: str) -> str: Args: uri: URI of the file, possibly nested within archives as follows - <archive_scheme>+<archive_uri>!path/to/contained/file.ext - Acceptable URI schemes are file, s3, gs, http, https, and ftp - Allowable archive schema are tar, zip, and gzip - """ - parsed = urlparse(uri) - scheme = parsed.scheme.split('+')[0] + <archive_scheme>+<archive_uri>!path/to/contained/file.ext. + Acceptable URI schemes are file, s3, gs, http, https, and ftp. + Allowable archive schema are tar, zip, and gzip. - archive_content = uri.rfind('!') - if archive_content == -1: + Raises: + ValueError: If URI format or schema is invalid. 
+ """ + parsed = VsiFileSystem.parse_archive_format(uri) + if parsed is None: # regular URI - if scheme == 'http' or scheme == 'https' or scheme == 'ftp': - return '/vsicurl/{}'.format(uri) - elif scheme == 's3' or scheme == 'gs': - return '/vsi{}/{}{}'.format(scheme, parsed.netloc, parsed.path) - else: - # assume file schema - return os.path.abspath( - os.path.join(parsed.netloc, parsed.path)) - else: - archive_target = uri.find('+') - assert archive_target != -1 - - if scheme in ['zip', 'tar', 'gzip']: - return '/vsi{}/{}/{}'.format( - scheme, - VsiFileSystem.uri_to_vsi_path( - uri[archive_target + 1:archive_content]), - uri[archive_content + 1:]) - else: - raise ValueError( - 'Attempted access into archive with unsupported scheme "{}"'. - format(scheme)) + parsed = urlparse(uri) + scheme, netloc, path = parsed.scheme, parsed.netloc, parsed.path + if scheme in URI_SCHEME_TO_VSI: + return join('/', URI_SCHEME_TO_VSI[scheme], f'{netloc}{path}') + # assume file schema + return os.path.abspath(join(netloc, path)) + + archive_scheme = parsed['archive_scheme'] + archive_uri = parsed['archive_uri'] + file_path = parsed['file_path'] + try: + vsi_archive_scheme = ARCHIVE_SCHEME_TO_VSI[archive_scheme] + except KeyError: + raise ValueError('Expected archive scheme to be one of "zip", ' + f'"tar", or "gzip". Found "{archive_scheme}".') + vsi_archive_uri = VsiFileSystem.uri_to_vsi_path(archive_uri) + vsipath = join(f'/{vsi_archive_scheme}{vsi_archive_uri}', file_path) + return vsipath @staticmethod - def matches_uri(vsipath: str, mode: str) -> bool: - """Returns True if this FS can be used for the given URI/mode pair. 
+ def parse_archive_format(uri: str) -> re.Match: + match = re.match(ARCHIVE_URI_FORMAT, uri) + if match is None: + return None + return match.groupdict() - Args: - uri: URI of file - mode: mode to open file in, 'r' or 'w' - """ - if mode == 'r' and vsipath.startswith('/vsi'): - return True - elif mode == 'w' and vsipath.startswith( - '/vsi') and '/vsicurl/' not in vsipath: - return True - else: + @staticmethod + def matches_uri(uri: str, mode: str) -> bool: + if not uri.startswith('/vsi'): return False + if mode == 'w' and '/vsicurl/' in uri: + return False + return True @staticmethod def file_exists(vsipath: str, include_dir: bool = True) -> bool: - """Check if a file exists. - - Args: - uri: The URI to check - include_dir: Include directories in check, if this file_system - supports directory reads. Otherwise only return true if a single - file exists at the URI. - """ file_stats = gdal.VSIStatL(vsipath) if include_dir: - return True if file_stats else False + return bool(file_stats) else: - return True if file_stats and not file_stats.IsDirectory( - ) else False + return file_stats and not file_stats.IsDirectory() @staticmethod def read_bytes(vsipath: str) -> bytes: stats = gdal.VSIStatL(vsipath) if not stats or stats.IsDirectory(): - raise FileNotFoundError('{} does not exist'.format(vsipath)) + raise FileNotFoundError(f'{vsipath} does not exist') try: retval = bytes() @@ -117,27 +120,15 @@ def write_bytes(vsipath: str, data: bytes): @staticmethod def write_str(uri: str, data: str): - """Write string in data to URI.""" VsiFileSystem.write_bytes(uri, data.encode()) @staticmethod def sync_to_dir(src_dir: str, dst_dir_uri: str, delete: bool = False): - """Syncs a local source directory to a destination directory. - - If the FileSystem is remote, this involves uploading. 
- - Args: - src_dir: local source directory to sync from - dst_dir_uri: A destination directory that can be synced to by this - FileSystem - delete: True if the destination should be deleted first. - """ - - def work(src, vsi_dest): + def work(src: Path, vsi_dest: str): gdal.Mkdir(vsi_dest, 0o777) for item in src.iterdir(): - item_vsi_dest = os.path.join(vsi_dest, item.name) + item_vsi_dest = join(vsi_dest, item.name) if item.is_dir(): work(item, item_vsi_dest) else: @@ -145,119 +136,90 @@ def work(src, vsi_dest): stats = gdal.VSIStatL(dst_dir_uri) if stats: - assert delete, 'Cannot overwrite existing files if delete=False' + if not delete: + raise FileExistsError( + 'Target location must not exist if delete=False') if stats.IsDirectory(): gdal.RmdirRecursive(dst_dir_uri) else: gdal.Unlink(dst_dir_uri) src = Path(src_dir) - assert src.exists() and src.is_dir(), \ - 'Local source ({}) must be a directory'.format(src_dir) + if not (src.exists() and src.is_dir()): + raise ValueError('Source must be a directory') work(src, dst_dir_uri) @staticmethod def sync_from_dir(src_dir_uri: str, dst_dir: str, delete: bool = False): - """Syncs a source directory to a local destination directory. - - If the FileSystem is remote, this involves downloading. - - Args: - src_dir_uri: source directory that can be synced from by this FileSystem - dst_dir: A local destination directory - delete: True if the destination should be deleted first. 
- """ - - def work(vsi_src, dest): + def work(vsi_src: str, dest: Path): if dest.exists(): - assert dest.is_dir( - ), 'Local target ({}) must be a directory'.format(dest) + if not dest.is_dir(): + raise ValueError( + f'Local target ({dest}) must be a directory') else: dest.mkdir() - for item in gdal.ReadDir(vsi_src): - item_vsi_src = os.path.join(vsi_src, item) + for item in VsiFileSystem.list_children(vsi_src): + item_vsi_src = join(vsi_src, item) target = dest.joinpath(item) if gdal.VSIStatL(item_vsi_src).IsDirectory(): work(item_vsi_src, target) else: - assert not target.exists() or delete, \ - 'Target location must not exist if delete=False' + if target.exists() and not delete: + raise FileExistsError( + 'Target location must not exist if delete=False') VsiFileSystem.copy_from(item_vsi_src, str(target)) stats = gdal.VSIStatL(src_dir_uri) - assert stats and stats.IsDirectory(), 'Source must be a directory' + if not (stats and stats.IsDirectory()): + raise ValueError('Source must be a directory') work(src_dir_uri, Path(dst_dir)) @staticmethod def copy_to(src_path: str, dst_uri: str): - """Copy a local source file to a destination. - - If the FileSystem is remote, this involves uploading. - - Args: - src_path: local path to source file - dst_uri: uri of destination that can be copied to by this FileSystem - """ with open(src_path, 'rb') as f: buf = f.read() VsiFileSystem.write_bytes(dst_uri, buf) @staticmethod def copy_from(src_uri: str, dst_path: str): - """Copy a source file to a local destination. - - If the FileSystem is remote, this involves downloading. - - Args: - src_uri: uri of source that can be copied from by this FileSystem - dst_path: local path to destination file - """ buf = VsiFileSystem.read_bytes(src_uri) with open(dst_path, 'wb') as f: f.write(buf) @staticmethod def local_path(vsipath: str, download_dir: str) -> str: - """Return the path where a local copy should be stored. 
- - Args: - uri: the URI of the file to be copied - download_dir: path of the local directory in which files should - be copied - """ filename = Path(vsipath).name - return os.path.join(download_dir, filename) + return join(download_dir, filename) @staticmethod def last_modified(vsipath: str) -> Optional[datetime]: - """Get the last modified date of a file. - - Args: - uri: the URI of the file - - Returns: - the last modified date in UTC of a file or None if this FileSystem - does not support this operation. - """ stats = gdal.VSIStatL(vsipath) return datetime.fromtimestamp(stats.mtime) if stats else None @staticmethod def list_paths(vsipath: str, ext: Optional[str] = None) -> List[str]: - """List paths rooted at URI. + items = VsiFileSystem.list_children(vsipath, ext=ext) + paths = [join(vsipath, item) for item in items] + return paths - Optionally only includes paths with a certain file extension. + @staticmethod + def list_children(vsipath: str, ext: Optional[str] = None) -> List[str]: + """List filenames of children rooted at URI. + + Optionally only includes filenames with a certain file extension. Args: - uri: the URI of a directory - ext: the optional file extension to filter by + uri: The URI of a directory. + ext: The optional file extension to filter by. + + Returns: + List of filenames excluding "." or "..". 
""" - items = gdal.ReadDir(vsipath) ext = ext if ext else '' - return [ - os.path.join(vsipath, item) # This may not work for windows paths - for item in filter(lambda x: x.endswith(ext), items) - ] + items = gdal.ReadDir(vsipath) + items = [item for item in items if item not in ['.', '..']] + items = [item for item in items if item.endswith(ext)] + return items diff --git a/tests/gdal_vsi/__init__.py b/tests/gdal_vsi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/gdal_vsi/test_vsi_file_system.py b/tests/gdal_vsi/test_vsi_file_system.py new file mode 100644 index 000000000..6223e8627 --- /dev/null +++ b/tests/gdal_vsi/test_vsi_file_system.py @@ -0,0 +1,146 @@ +from os.path import join +import unittest + +from rastervision.pipeline.file_system import (get_tmp_dir, str_to_file, + LocalFileSystem) +from rastervision.gdal_vsi.vsi_file_system import VsiFileSystem + +fs = VsiFileSystem + + +class TestVsiFileSystem(unittest.TestCase): + def test_uri_to_vsi_path(self): + self.assertEqual(fs.uri_to_vsi_path('/a/b/c'), '/a/b/c') + self.assertEqual(fs.uri_to_vsi_path('http://a/b/c'), '/vsicurl/a/b/c') + self.assertEqual(fs.uri_to_vsi_path('https://a/b/c'), '/vsicurl/a/b/c') + self.assertEqual(fs.uri_to_vsi_path('ftp://a/b/c'), '/vsicurl/a/b/c') + self.assertEqual(fs.uri_to_vsi_path('s3://a/b/c'), '/vsis3/a/b/c') + self.assertEqual(fs.uri_to_vsi_path('gs://a/b/c'), '/vsigs/a/b/c') + + def test_uri_to_vsi_path_archive(self): + with self.assertRaises(ValueError): + _ = fs.uri_to_vsi_path('wrongscheme+s3://a/b!c') + + self.assertEqual( + fs.uri_to_vsi_path('zip+s3://a/b!c'), '/vsizip/vsis3/a/b/c') + self.assertEqual( + fs.uri_to_vsi_path('gzip+s3://a/b!c'), '/vsigzip/vsis3/a/b/c') + self.assertEqual( + fs.uri_to_vsi_path('tar+s3://a/b!c'), '/vsitar/vsis3/a/b/c') + + def test_matches_uri(self): + self.assertFalse(fs.matches_uri('/a/b/c', 'r')) + self.assertTrue(fs.matches_uri('/vsis3/a/b/c', 'r')) + self.assertTrue(fs.matches_uri('/vsis3/a/b/c', 
'w')) + self.assertTrue(fs.matches_uri('/vsicurl/a/b/c', 'r')) + self.assertFalse(fs.matches_uri('/vsicurl/a/b/c', 'w')) + + def test_local_path(self): + vsipath = '/vsicurl/a/b/c' + self.assertEqual(fs.local_path(vsipath, '/'), '/c') + + def test_read_write_bytes(self): + with get_tmp_dir() as tmp_dir: + path = join(tmp_dir, 'test.bin') + path_vsi = fs.uri_to_vsi_path(path) + bytes_in = bytes([0x00, 0x01, 0x02]) + fs.write_bytes(path_vsi, bytes_in) + bytes_out = fs.read_bytes(path_vsi) + self.assertEqual(bytes_in, bytes_out) + + with self.assertRaises(FileNotFoundError): + fs.read_bytes(path_vsi) + + def test_read_write_str(self): + with get_tmp_dir() as tmp_dir: + path = join(tmp_dir, 'test.txt') + path_vsi = fs.uri_to_vsi_path(path) + str_in = 'abc' + fs.write_str(path_vsi, str_in) + str_out = fs.read_str(path_vsi) + self.assertEqual(str_in, str_out) + + def test_list_paths(self): + with get_tmp_dir() as tmp_dir: + dir_vsi = fs.uri_to_vsi_path(tmp_dir) + str_to_file('abc', join(tmp_dir, '1.txt')) + str_to_file('def', join(tmp_dir, '2.txt')) + str_to_file('ghi', join(tmp_dir, '3.tiff')) + paths = fs.list_paths(dir_vsi, ext='txt') + self.assertListEqual( + paths, [join(tmp_dir, '1.txt'), + join(tmp_dir, '2.txt')]) + + def test_sync_to_from(self): + with get_tmp_dir() as src, get_tmp_dir() as dst: + src_vsi = fs.uri_to_vsi_path(src) + dst_vsi = fs.uri_to_vsi_path(dst) + str_to_file('abc', join(src, '1.txt')) + str_to_file('def', join(src, '2.txt')) + str_to_file('ghi', join(src, 'subdir', '3.txt')) + fs.sync_to_dir(src_vsi, dst_vsi, delete=True) + paths = fs.list_paths(dst_vsi) + self.assertListEqual(paths, [ + join(dst, 'subdir'), + join(dst, '1.txt'), + join(dst, '2.txt'), + ]) + paths = fs.list_paths(dst_vsi, ext='txt') + self.assertListEqual(paths, [ + join(dst, '1.txt'), + join(dst, '2.txt'), + ]) + paths = fs.list_paths(join(dst_vsi, 'subdir')) + self.assertListEqual(paths, [join(dst, 'subdir', '3.txt')]) + + with self.assertRaises(FileExistsError): + 
fs.sync_to_dir(src_vsi, dst_vsi, delete=False) + + with self.assertRaises(ValueError): + fs.sync_to_dir(join(src, '1.txt'), dst_vsi, delete=True) + + fs.sync_from_dir(src_vsi, dst_vsi, delete=True) + paths = fs.list_paths(src_vsi) + self.assertListEqual(paths, [ + join(src, 'subdir'), + join(src, '1.txt'), + join(src, '2.txt'), + ]) + paths = fs.list_paths(src_vsi, ext='txt') + self.assertListEqual(paths, [ + join(src, '1.txt'), + join(src, '2.txt'), + ]) + paths = fs.list_paths(join(src, 'subdir')) + self.assertListEqual(paths, [join(src, 'subdir', '3.txt')]) + + with self.assertRaises(FileExistsError): + fs.sync_from_dir(src_vsi, dst_vsi, delete=False) + + with self.assertRaises(ValueError): + fs.sync_from_dir(src_vsi, join(dst, '1.txt'), delete=True) + + with self.assertRaises(ValueError): + fs.sync_from_dir(join(src, '1.txt'), dst_vsi, delete=True) + + def test_last_modified(self): + with get_tmp_dir() as tmp_dir: + path = join(tmp_dir, '1.txt') + str_to_file('abc', path) + path_vsi = fs.uri_to_vsi_path(path) + self.assertEqual( + fs.last_modified(path_vsi).timestamp(), + int(LocalFileSystem.last_modified(path).timestamp())) + + def test_file_exists(self): + with get_tmp_dir() as tmp_dir: + path = join(tmp_dir, '1.txt') + str_to_file('abc', path) + path_vsi = fs.uri_to_vsi_path(path) + self.assertTrue(fs.file_exists(path_vsi, include_dir=False)) + dir_vsi = fs.uri_to_vsi_path(tmp_dir) + self.assertTrue(fs.file_exists(dir_vsi)) + + +if __name__ == '__main__': + unittest.main()