From c92081e09e98c5c905ff36d97aab250f070653f1 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 14 Sep 2022 23:32:52 +0800 Subject: [PATCH 01/18] [Refactor] Refactor fileio but without breaking bc --- mmengine/fileio/__init__.py | 20 +- mmengine/fileio/backends/__init__.py | 14 + mmengine/fileio/backends/base.py | 25 + mmengine/fileio/backends/http_backend.py | 74 ++ mmengine/fileio/backends/lmdb_backend.py | 80 ++ mmengine/fileio/backends/local_backend.py | 522 +++++++++++ mmengine/fileio/backends/memcached_backend.py | 43 + mmengine/fileio/backends/petrel_backend.py | 754 +++++++++++++++ mmengine/fileio/backends/registry_utils.py | 115 +++ mmengine/fileio/io.py | 762 +++++++++++++++- requirements/tests.txt | 1 + .../test_backends/test_backend_utils.py | 114 +++ .../test_base_storage_backend.py | 27 + .../test_backends/test_http_backend.py | 51 ++ .../test_backends/test_lmdb_backend.py | 35 + .../test_backends/test_local_backend.py | 486 ++++++++++ .../test_backends/test_memcached_backend.py | 59 ++ .../test_backends/test_petrel_backend.py | 861 ++++++++++++++++++ tests/test_fileio/test_backends/utils.py | 9 + tests/test_fileio/test_io.py | 536 +++++++++++ 20 files changed, 4583 insertions(+), 5 deletions(-) create mode 100644 mmengine/fileio/backends/__init__.py create mode 100644 mmengine/fileio/backends/base.py create mode 100644 mmengine/fileio/backends/http_backend.py create mode 100644 mmengine/fileio/backends/lmdb_backend.py create mode 100644 mmengine/fileio/backends/local_backend.py create mode 100644 mmengine/fileio/backends/memcached_backend.py create mode 100644 mmengine/fileio/backends/petrel_backend.py create mode 100644 mmengine/fileio/backends/registry_utils.py create mode 100644 tests/test_fileio/test_backends/test_backend_utils.py create mode 100644 tests/test_fileio/test_backends/test_base_storage_backend.py create mode 100644 tests/test_fileio/test_backends/test_http_backend.py create mode 100644 
tests/test_fileio/test_backends/test_lmdb_backend.py create mode 100644 tests/test_fileio/test_backends/test_local_backend.py create mode 100644 tests/test_fileio/test_backends/test_memcached_backend.py create mode 100644 tests/test_fileio/test_backends/test_petrel_backend.py create mode 100644 tests/test_fileio/test_backends/utils.py create mode 100644 tests/test_fileio/test_io.py diff --git a/mmengine/fileio/__init__.py b/mmengine/fileio/__init__.py index bea658327e..45e68b3e5e 100644 --- a/mmengine/fileio/__init__.py +++ b/mmengine/fileio/__init__.py @@ -1,14 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .backends import register_backend from .file_client import (BaseStorageBackend, FileClient, HardDiskBackend, HTTPBackend, LmdbBackend, MemcachedBackend, PetrelBackend) from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler -from .io import dump, load, register_handler +from .io import (copy_if_symlink_fails, copyfile, copyfile_from_local, + copyfile_to_local, copytree, copytree_from_local, + copytree_to_local, dump, exists, generate_presigned_url, + get_bytes, get_file_backend, get_local_path, get_text, isdir, + isfile, join_path, list_dir_or_file, load, put_bytes, + put_text, register_handler, rmfile, rmtree) from .parse import dict_from_file, list_from_file __all__ = [ 'BaseStorageBackend', 'FileClient', 'PetrelBackend', 'MemcachedBackend', - 'LmdbBackend', 'HardDiskBackend', 'HTTPBackend', 'load', 'dump', - 'register_handler', 'BaseFileHandler', 'JsonHandler', 'PickleHandler', - 'YamlHandler', 'list_from_file', 'dict_from_file' + 'LmdbBackend', 'HardDiskBackend', 'HTTPBackend', 'copy_if_symlink_fails', + 'copyfile', 'copyfile_from_local', 'copyfile_to_local', 'copytree', + 'copytree_from_local', 'copytree_to_local', 'exists', + 'generate_presigned_url', 'get_bytes', 'get_file_backend', + 'get_local_path', 'get_text', 'isdir', 'isfile', 'join_path', + 'list_dir_or_file', 'put_bytes', 'put_text', 'rmfile', 'rmtree', 
'load', + 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler', + 'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file', + 'register_backend' ] diff --git a/mmengine/fileio/backends/__init__.py b/mmengine/fileio/backends/__init__.py new file mode 100644 index 0000000000..fa0008977f --- /dev/null +++ b/mmengine/fileio/backends/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseStorageBackend +from .http_backend import HTTPBackend +from .lmdb_backend import LmdbBackend +from .local_backend import LocalBackend +from .memcached_backend import MemcachedBackend +from .petrel_backend import PetrelBackend +from .registry_utils import backends, prefix_to_backends, register_backend + +__all__ = [ + 'BaseStorageBackend', 'LocalBackend', 'HTTPBackend', 'LmdbBackend', + 'MemcachedBackend', 'PetrelBackend', 'register_backend', 'backends', + 'prefix_to_backends' +] diff --git a/mmengine/fileio/backends/base.py b/mmengine/fileio/backends/base.py new file mode 100644 index 0000000000..6846f0f2d1 --- /dev/null +++ b/mmengine/fileio/backends/base.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: :meth:`get_bytes()` and + :meth:`get_text()`. + + - :meth:`get_bytes()` reads the file as a byte stream. + - :meth:`get_text()` reads the file as texts. + """ + + @property + def name(self): + return self.__class__.__name__ + + @abstractmethod + def get_bytes(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass diff --git a/mmengine/fileio/backends/http_backend.py b/mmengine/fileio/backends/http_backend.py new file mode 100644 index 0000000000..2f4fa3dc1f --- /dev/null +++ b/mmengine/fileio/backends/http_backend.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Union +from urllib.request import urlopen + +from .base import BaseStorageBackend + + +class HTTPBackend(BaseStorageBackend): + """HTTP and HTTPS storage backend.""" + + def get_bytes(self, filepath: str) -> bytes: + """Read bytes from a given ``filepath``. + + Args: + filepath (str): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get_bytes('http://path/of/file') + b'hello world' + """ + return urlopen(filepath).read() + + def get_text(self, filepath, encoding='utf-8') -> str: + """Read text from a given ``filepath``. + + Args: + filepath (str): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get_text('http://path/of/file') + 'hello world' + """ + return urlopen(filepath).read().decode(encoding) + + @contextmanager + def get_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It + can be called with ``with`` statement, and when exiting from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> backend = HTTPBackend() + >>> # After exiting from the ``with`` clause, + >>> # the path will be removed + >>> with backend.get_local_path('http://path/of/file') as path: + ... 
# do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get_bytes(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py new file mode 100644 index 0000000000..49368920ff --- /dev/null +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Union + +from .base import BaseStorageBackend + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_path (str): Lmdb database path. + readonly (bool): Lmdb environment parameter. If True, disallow any + write operations. Defaults to True. + lock (bool): Lmdb environment parameter. If False, when concurrent + access occurs, do not lock the database. Defaults to False. + readahead (bool): Lmdb environment parameter. If False, disable the OS + filesystem readahead mechanism, which may improve random read + performance when a database is larger than RAM. Defaults to False. + + Attributes: + db_path (str): Lmdb database path. + """ + + def __init__(self, + db_path, + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb # noqa: F401 + except ImportError: + raise ImportError( + 'Please run "pip install lmdb" to enable LmdbBackend.') + + self.db_path = str(db_path) + self.readonly = readonly + self.lock = lock + self.readahead = readahead + self.kwargs = kwargs + self._client = None + + def get_bytes(self, filepath: Union[str, Path]) -> bytes: + """Get values according to the filepath. + + Args: + filepath (str or Path): Here, filepath is the lmdb key. + + Returns: + bytes: Expected bytes object. 
+ + Examples: + >>> backend = LmdbBackend('path/to/lmdb') + >>> backend.get_bytes('key') + b'hello world' + """ + if self._client is None: + self._client = self._get_client() + + filepath = str(filepath) + with self._client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + def _get_client(self): + import lmdb + + return lmdb.open( + self.db_path, + readonly=self.readonly, + lock=self.lock, + readahead=self.readahead, + **self.kwargs) + + def __del__(self): + self._client.close() diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py new file mode 100644 index 0000000000..16716275d0 --- /dev/null +++ b/mmengine/fileio/backends/local_backend.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import shutil +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union + +import mmengine +from .base import BaseStorageBackend + + +class LocalBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get_bytes(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get_bytes(filepath) + b'hello world' + """ + with open(filepath, 'rb') as f: + value = f.read() + return value + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. 
+ + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + with open(filepath, encoding=encoding) as f: + text = f.read() + return text + + def put_bytes(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put_bytes`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put_bytes(b'hello world', filepath) + """ + mmengine.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put_text('hello world', filepath) + """ + mmengine.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.exists(filepath) + True + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. 
+ + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/dir' + >>> backend.isdir(filepath) + True + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.isfile(filepath) + True + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + + Examples: + >>> backend = LocalBackend() + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> backend.join_path(filepath1, filepath2, filepath3) + '/path/of/dir1/dir2/path/of/file' + """ + # TODO, if filepath or filepaths are Path, should return Path + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing. + + Args: + filepath (str or Path): Path to read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> backend = LocalBackend() + >>> with backend.get_local_path('s3://bucket/abc.jpg') as path: + ... 
# do something here + """ + yield filepath + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> backend.copyfile(src, dst) + '/path1/of/dir/file' + """ + return shutil.copy(src, dst) + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + TODO: Whether to support dirs_exist_ok parameter. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. 
+ + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree(src, dst) + '/path/of/dir2' + """ + return shutil.copytree(src, dst) + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a local file src to dst and return the destination file. Same + as :meth:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_from_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_from_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. Same as + :meth:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy the file src to local dst and return the destination file. Same + as :meth:`copyfile`. 
+ + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_to_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> backend.copyfile_to_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + backend corresponding to the prefix of uri. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_to_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def rmfile(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + FileNotFoundError: If filepath does not exist, a FileNotFoundError + will be raised. 
+ + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.rmfile(filepath) + """ + os.remove(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + + Examples: + >>> dir_path = '/path/of/dir' + >>> backend.rmtree(dir_path) + """ + shutil.rmtree(dir_path) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directly copy src + to dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + + Returns: + bool: Return True if successfully create a symbolic link pointing + to src. Otherwise, return False. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> backend.copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> backend.copy_if_symlink_fails(src, dst) + True + """ + try: + os.symlink(src, dst) + return True + except Exception: + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. 
+ recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = LocalBackend() + >>> dir_path = '/path/of/dir' + >>> for file_path in backend.list_dir_or_file(dir_path): + ... print(file_path) + """ + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, + list_file, suffix, + recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) diff --git a/mmengine/fileio/backends/memcached_backend.py b/mmengine/fileio/backends/memcached_backend.py new file mode 100644 index 0000000000..6e672468f5 --- /dev/null +++ b/mmengine/fileio/backends/memcached_backend.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Union + +from .base import BaseStorageBackend + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str, optional): Additional path to be appended to `sys.path`. + Defaults to None. 
+ """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector serves as a pointer which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get_bytes(self, filepath: Union[str, Path]): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py new file mode 100644 index 0000000000..c708771bf0 --- /dev/null +++ b/mmengine/fileio/backends/petrel_backend.py @@ -0,0 +1,754 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import re +import tempfile +from contextlib import contextmanager +from pathlib import Path +from shutil import SameFileError +from typing import Generator, Iterator, Optional, Tuple, Union + +import mmengine +from mmengine.utils import has_method +from .base import BaseStorageBackend + + +class PetrelBackend(BaseStorageBackend): + """Petrel storage backend (for internal usage). + + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. + + Args: + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Defaults to None. 
+ enable_mc (bool, optional): Whether to enable memcached support. + Defaults to True. + + Examples: + >>> backend = PetrelBackend() + >>> filepath1 = 'petrel://path/of/file' + >>> filepath2 = 'cluster-name:petrel://path/of/file' + >>> backend.get_bytes(filepath1) # get data from default cluster + >>> backend.get_bytes(filepath2) # get data from 'cluster-name' cluster + """ + + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True): + try: + from petrel_client import client + except ImportError: + raise ImportError('Please install petrel_client to enable ' + 'PetrelBackend.') + + self._client = client.Client(enable_mc=enable_mc) + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. + + Args: + filepath (str or Path): Path to be mapped. + """ + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v, 1) + return filepath + + def _format_path(self, filepath: str) -> str: + """Convert a ``filepath`` to standard format of petrel oss. + + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. + """ + return re.sub(r'\\+', '/', filepath) + + def _replace_prefix(self, filepath: Union[str, Path]) -> str: + filepath = str(filepath) + return filepath.replace('petrel://', 's3://') + + def get_bytes(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Return bytes read from filepath. 
+ + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.get_bytes(filepath) + b'hello world' + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + value = self._client.Get(filepath) + return value + + def get_text( + self, + filepath: Union[str, Path], + encoding: str = 'utf-8', + ) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + return str(self.get_bytes(filepath), encoding=encoding) + + def put_bytes(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath``. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.put_bytes(b'hello world', filepath) + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + self._client.put(filepath, obj) + + def put_text( + self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8', + ) -> None: + """Write text to a given ``filepath``. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to encode the ``obj``. + Defaults to 'utf-8'. 
+ + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.put_text('hello world', filepath) + """ + self.put_bytes(bytes(obj, encoding=encoding), filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.exists(filepath) + True + """ + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/dir' + >>> backend.isdir(filepath) + True + """ + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + return self._client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. 
+ + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.isfile(filepath) + True + """ + if not has_method(self._client, 'contains'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + return self._client.contains(filepath) + + def join_path( + self, + filepath: Union[str, Path], + *filepaths: Union[str, Path], + ) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result after concatenation. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.join_path(filepath, 'another/path') + 'petrel://path/of/file/another/path' + >>> backend.join_path(filepath, '/another/path') + 'petrel://path/of/file/another/path' + """ + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] + for path in filepaths: + formatted_path = self._format_path(self._map_path(path)) + formatted_paths.append(formatted_path.lstrip('/')) + + return '/'.join(formatted_paths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath`` and return a temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. 
It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str or Path): Download a file from ``filepath``. + + Yields: + Iterable[str]: Only yield one temporary path. + + Examples: + >>> backend = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> filepath = 'petrel://path/of/file' + >>> with backend.get_local_path(filepath) as path: + ... # do something here + """ + assert self.isfile(filepath) + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get_bytes(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. 
+ + Examples: + >>> backend = PetrelBackend() + >>> # dst is a file + >>> src = 'petrel://path/of/file' + >>> dst = 'petrel://path/of/file1' + >>> backend.copyfile(src, dst) + 'petrel://path/of/file1' + + >>> # dst is a directory + >>> dst = 'petrel://path/of/dir' + >>> backend.copyfile(src, dst) + 'petrel://path/of/dir/file' + """ + src = self._format_path(self._map_path(src)) + dst = self._format_path(self._map_path(dst)) + if self.isdir(dst): + dst = self.join_path(dst, src.split('/')[-1]) + + if src == dst: + raise SameFileError('src and dst should not be same') + + self.put_bytes(self.get_bytes(src), dst) + return dst + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'petrel://path/of/dir' + >>> dst = 'petrel://path/of/dir1' + >>> backend.copytree(src, dst) + 'petrel://path/of/dir1' + """ + src = self._format_path(self._map_path(src)) + dst = self._format_path(self._map_path(dst)) + + if self.exists(dst): + raise FileExistsError('dst should not exist') + + for path in self.list_dir_or_file(src, list_dir=False, recursive=True): + src_path = self.join_path(src, path) + dst_path = self.join_path(dst, path) + self.put_bytes(self.get_bytes(src_path), dst_path) + + return dst + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Upload a local file src to dst and return the destination file. 
+ + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = PetrelBackend() + >>> # dst is a file + >>> src = 'path/of/your/file' + >>> dst = 'petrel://path/of/file1' + >>> backend.copyfile_from_local(src, dst) + 'petrel://path/of/file1' + + >>> # dst is a directory + >>> dst = 'petrel://path/of/dir' + >>> backend.copyfile_from_local(src, dst) + 'petrel://path/of/dir/file' + """ + dst = self._format_path(self._map_path(dst)) + if self.isdir(dst): + dst = self.join_path(dst, osp.basename(src)) + + with open(src, 'rb') as f: + self.put_bytes(f.read(), dst) + + return dst + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. 
+ + Examples: + >>> backend = PetrelBackend() + >>> src = 'path/of/your/dir' + >>> dst = 'petrel://path/of/dir1' + >>> backend.copytree_from_local(src, dst) + 'petrel://path/of/dir1' + """ + dst = self._format_path(self._map_path(dst)) + if self.exists(dst): + raise FileExistsError('dst should not exist') + + src = str(src) + + for cur_dir, _, files in os.walk(src): + for f in files: + src_path = osp.join(cur_dir, f) + dst_path = self.join_path(dst, src_path.replace(src, '')) + self.copyfile_from_local(src_path, dst_path) + + return dst + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> Union[str, Path]: + """Copy the file src to local dst and return the destination file. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = PetrelBackend() + >>> # dst is a file + >>> src = 'petrel://path/of/file' + >>> dst = 'path/of/your/file' + >>> backend.copyfile_to_local(src, dst) + 'path/of/your/file' + + >>> # dst is a directory + >>> dst = 'path/of/your/dir' + >>> backend.copyfile_to_local(src, dst) + 'path/of/your/dir/file' + """ + if osp.isdir(dst): + basename = osp.basename(src) + if isinstance(dst, str): + dst = osp.join(dst, basename) + else: + assert isinstance(dst, Path) + dst = dst / basename + + with open(dst, 'wb') as f: + f.write(self.get_bytes(src)) + + return dst + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. 
+ dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'petrel://path/of/dir' + >>> dst = 'path/of/your/dir' + >>> backend.copytree_to_local(src, dst) + 'path/of/your/dir' + """ + for path in self.list_dir_or_file(src, list_dir=False, recursive=True): + dst_path = osp.join(dst, path) + mmengine.mkdir_or_exist(osp.dirname(dst_path)) + with open(dst_path, 'wb') as f: + f.write(self.get_bytes(self.join_path(src, path))) + + return dst + + def rmfile(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.rmfile(filepath) + """ + if not has_method(self._client, 'delete'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev ' + 'branch instead.') + + if not self.exists(filepath): + raise FileNotFoundError(f'filepath {filepath} does not exist') + + if self.isdir(filepath): + raise IsADirectoryError('filepath should be a file') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + self._client.delete(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. 
+ + Examples: + >>> backend = PetrelBackend() + >>> dir_path = 'petrel://path/of/dir' + >>> backend.rmtree(dir_path) + """ + for path in self.list_dir_or_file( + dir_path, list_dir=False, recursive=True): + filepath = self.join_path(dir_path, path) + self.rmfile(filepath) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + Directly copy src to dst because PetrelBacekend does not support create + a symbolic link. + + Args: + src (str or Path): A file or directory to be copied. + dst (str or Path): Copy a file or directory to dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + bool: Return False because PetrelBackend does not support create + a symbolic link. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'petrel://path/of/file' + >>> dst = 'petrel://path/of/your/file' + >>> backend.copy_if_symlink_fails(src, dst) + False + >>> src = 'petrel://path/of/dir' + >>> dst = 'petrel://path/of/your/dir' + >>> backend.copy_if_symlink_fails(src, dst) + False + """ + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. 
+ In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the + directory. Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = PetrelBackend() + >>> dir_path = 'petrel://path/of/dir' + >>> for path in backend.list_dir_or_file(dir_path): + ... print(path) + """ + if not has_method(self._client, 'list'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.') + + dir_path = self._map_path(dir_path) + dir_path = self._format_path(dir_path) + dir_path = self._replace_prefix(dir_path) + if list_dir and suffix is not None: + raise TypeError( + '`list_dir` should be False when `suffix` is not None') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` + if path.endswith('/'): # a directory path + next_dir_path = self.join_path(dir_path, path) + if list_dir: + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] + yield rel_dir + if recursive: + yield from 
_list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) + else: # a file path + absolute_path = self.join_path(dir_path, path) + rel_path = absolute_path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + def generate_presigned_url(self, + url: str, + client_method: str = 'get_object', + expires_in: int = 3600) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. Now only work on Petrel backend. + + Note: + Now only work on Petrel backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Default: 'get_object'. + expires_in (int): expires, in seconds. Default: 3600. + + Returns: + str: Generated presigned url. + """ + return self._client.generate_presigned_url(url, client_method, + expires_in) diff --git a/mmengine/fileio/backends/registry_utils.py b/mmengine/fileio/backends/registry_utils.py new file mode 100644 index 0000000000..221107a12f --- /dev/null +++ b/mmengine/fileio/backends/registry_utils.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from typing import Optional, Type, Union + +from .base import BaseStorageBackend +from .http_backend import HTTPBackend +from .lmdb_backend import LmdbBackend +from .local_backend import LocalBackend +from .memcached_backend import MemcachedBackend +from .petrel_backend import PetrelBackend + +backends: dict = {} +prefix_to_backends: dict = {} + + +def _register_backend(name: str, + backend: Type[BaseStorageBackend], + force: bool = False, + prefixes: Union[str, list, tuple, None] = None): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (BaseStorageBackend): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. 
+ force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + """ + global backends, prefix_to_backends + + if not isinstance(name, str): + raise TypeError('the backend name should be a string, ' + f'but got {type(name)}') + + if not inspect.isclass(backend): + raise TypeError(f'backend should be a class, but got {type(backend)}') + if not issubclass(backend, BaseStorageBackend): + raise TypeError( + f'backend {backend} is not a subclass of BaseStorageBackend') + + if name in backends and not force: + raise ValueError(f'{name} is already registered as a storage backend, ' + 'add "force=True" if you want to override it') + backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + + for prefix in prefixes: + if prefix in prefix_to_backends and not force: + raise ValueError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') + + prefix_to_backends[prefix] = backend + + +def register_backend(name: str, + backend: Optional[Type[BaseStorageBackend]] = None, + force: bool = False, + prefixes: Union[str, list, tuple, None] = None): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + + This method can be used as a normal method or a decorator. 
+ + Examples: + + >>> class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + >>> register_backend('new', NewBackend) + + >>> @register_backend('new') + ... class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + """ + if backend is not None: + _register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + _register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + +register_backend('local', LocalBackend, prefixes='') +register_backend('memcached', MemcachedBackend) +register_backend('lmdb', LmdbBackend) +register_backend('petrel', PetrelBackend, prefixes='petrel') +register_backend('http', HTTPBackend, prefixes=['http', 'https']) diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index 7374aa158d..ee71cc5df1 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -1,11 +1,47 @@ # Copyright (c) OpenMMLab. All rights reserved. +"""This module provides unified file I/O related functions, which support +operating I/O with different file backends based on the specified filepath or +backend_args. + +MMEngine currently supports five file backends: + +- HardDiskBackend +- PetrelBackend +- HTTPBackend +- LmdbBackend +- MemcacheBackend + +Note that this module provide a union of all of the above file backends so +NotImplementedError will be raised if the interface in the file backend is not +implemented. + +There are two ways to call a method of a file backend: + +- Initialize a file backend with ``get_file_backend`` and call its methods. +- Directory call unified I/O functions, which will call ``get_file_backend`` + first and then call the corresponding backend method. 
+ +Examples: + >>> # Initialize a file backend and call its methods + >>> import mmengine.fileio as fileio + >>> backend = fileio.get_file_backend(backend_args={'backend': 'petrel'}) + >>> backend.get_bytes('s3://path/of/your/file') + + >>> # Directory call unified I/O functions + >>> fileio.get_bytes('s3://path/of/your/file') +""" +import json +from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union -from mmengine.utils import is_list_of, is_str +from mmengine.utils import is_filepath, is_list_of, is_str +from .backends import backends, prefix_to_backends from .file_client import FileClient from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +backend_instances: dict = {} file_handlers = { 'json': JsonHandler(), 'yaml': YamlHandler(), @@ -15,6 +51,730 @@ } +def _parse_uri_prefix(uri: Union[str, Path]) -> str: + """Parse the prefix of uri. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> _parse_uri_prefix('/home/path/of/your/file') + '' + >>> _parse_uri_prefix('s3://path/of/your/file') + 's3' + >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') + 's3' + + Returns: + str: Return the prefix of uri if the uri contains '://'. Otherwise, + return ''. + """ + assert is_filepath(uri) + uri = str(uri) + # if uri does not contains '://', the uri will be handled by + # HardDiskBackend by default + if '://' not in uri: + return '' + else: + prefix, _ = uri.split('://') + # In the case of PetrelBackend, the prefix may contain the cluster + # name like clusterName:s3://path/of/your/file + if ':' in prefix: + _, prefix = prefix.split(':') + return prefix + + +def _get_file_backend(prefix: str, backend_args: dict): + """Return a file backend based on the prefix or backend_args. + + Args: + prefix (str): Prefix of uri. + backend_args (dict): Arguments to instantiate the corresponding + backend. 
+ """ + # backend name has a higher priority + if 'backend' in backend_args: + backend_name = backend_args.pop('backend') + backend = backends[backend_name](**backend_args) + else: + backend = prefix_to_backends[prefix](**backend_args) + return backend + + +def get_file_backend( + uri: Union[str, Path, None] = None, + *, + backend_args: Optional[dict] = None, + enable_singleton: bool = False, +): + """Return a file backend based on the prefix of uri or backend_args. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + enable_singleton (bool): Whether to enable the singleton pattern. + If it is True, the backend created will be reused if the + signature is same with the previous one. Defaults to False. + + Returns: + BaseStorageBackend: Instantiated Backend object. + + Examples: + >>> # get file backend based on the prefix of uri + >>> uri = 's3://path/of/your/file' + >>> backend = get_file_backend(uri) + >>> # get file backend based on the backend_args + >>> backend = get_file_backend(backend_args={'backend': 'petrel'}) + >>> # backend name has a higher priority if 'backend' in backend_args + >>> backend = get_file_backend(uri, backend_args={'backend': 'petrel'}) + """ + global backend_instances + + if backend_args is None: + backend_args = {} + + if uri is None and 'backend' not in backend_args: + raise ValueError( + 'uri should not be None when "backend" does not exist in ' + 'backend_args') + + if uri is not None: + prefix = _parse_uri_prefix(uri) + else: + prefix = '' + + if enable_singleton: + # TODO: whether to pass sort_key to json.dumps + unique_key = f'{prefix}:{json.dumps(backend_args)}' + if unique_key in backend_instances: + return backend_instances[unique_key] + + backend = _get_file_backend(prefix, backend_args) + backend_instances[unique_key] = backend + return backend + else: + backend = _get_file_backend(prefix, 
backend_args) + return backend + + +def get_bytes( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> filepath = '/path/of/file' + >>> get_bytes(filepath) + b'hello world' + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.get_bytes(filepath) + + +def get_text( + filepath: Union[str, Path], + encoding='utf-8', + backend_args: Optional[dict] = None, +) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> filepath = '/path/of/file' + >>> get_text(filepath) + 'hello world' + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.get_text(filepath, encoding) + + +def put_bytes( + obj: bytes, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put_bytes`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> put_bytes(b'hello world', filepath) + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + backend.put_bytes(obj, filepath) + + +def put_text( + obj: str, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + ``filepath``. Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> filepath = '/path/of/file' + >>> put_text('hello world', filepath) + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + backend.put_text(obj, filepath) + + +def exists( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> exists(filepath) + True + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.exists(filepath) + + +def isdir( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. 
+ + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/dir' + >>> isdir(filepath) + True + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.isdir(filepath) + + +def isfile( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> isfile(filepath) + True + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.isfile(filepath) + + +def join_path( + filepath: Union[str, Path], + *filepaths: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + *filepaths (str or Path): Other paths to be concatenated. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The result of concatenation. 
+ + Examples: + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.join_path(filepath, *filepaths) + + +@contextmanager +def get_local_path( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself and it will + not be released (removed). + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: Only yield one path. + + Examples: + >>> with get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + with backend.get_local_path(str(filepath)) as local_path: + yield local_path + + +def copyfile( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. 
+ + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError will + be raised. + + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> copyfile(src, dst) + '/path1/of/dir/file' + """ + backend = get_file_backend( + src, backend_args=backend_args, enable_singleton=True) + return backend.copyfile(src, dst) + + +def copytree( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will be + raised. + + Examples: + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> copytree(src, dst) + '/path/of/dir2' + """ + backend = get_file_backend( + src, backend_args=backend_args, enable_singleton=True) + return backend.copytree(src, dst) + + +def copyfile_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Copy a local file src to dst and return the destination file. + + Note: + If the backend is the instance of HardDiskBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. 
+ backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = 's3://openmmlab/mmengine/file1' + >>> # src will be copied to 's3://openmmlab/mmengine/file1' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/file1' + + >>> # dst is a directory + >>> dst = 's3://openmmlab/mmengine' + >>> # src will be copied to 's3://openmmlab/mmengine/file' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/file' + """ + backend = get_file_backend( + dst, backend_args=backend_args, enable_singleton=True) + return backend.copyfile_from_local(src, dst) + + +def copytree_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + Note: + If the backend is the instance of HardDiskBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> src = '/path/of/dir' + >>> dst = 's3://openmmlab/mmengine/dir' + >>> copytree_from_local(src, dst) + 's3://openmmlab/mmengine/dir' + """ + backend = get_file_backend( + dst, backend_args=backend_args, enable_singleton=True) + return backend.copytree_from_local(src, dst) + + +def copyfile_to_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Copy the file src to local dst and return the destination file.
+ + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Note: + If the backend is the instance of HardDiskBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = 's3://openmmlab/mmengine/file' + >>> dst = '/path/of/file' + >>> # src will be copied to '/path/of/file' + >>> copyfile_to_local(src, dst) + '/path/of/file' + + >>> # dst is a directory + >>> dst = '/path/of/dir' + >>> # src will be copied to '/path/of/dir/file' + >>> copyfile_to_local(src, dst) + '/path/of/dir/file' + """ + backend = get_file_backend( + dst, backend_args=backend_args, enable_singleton=True) + return backend.copyfile_to_local(src, dst) + + +def copytree_to_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Note: + If the backend is the instance of HardDiskBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory.
+ + Examples: + >>> src = 's3://openmmlab/mmengine/dir' + >>> dst = '/path/of/dir' + >>> copytree_to_local(src, dst) + '/path/of/dir' + """ + backend = get_file_backend( + dst, backend_args=backend_args, enable_singleton=True) + return backend.copytree_to_local(src, dst) + + +def rmfile( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Raises: + FileNotFoundError: If filepath does not exist, a FileNotFoundError + will be raised. + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + + Examples: + >>> filepath = '/path/of/file' + >>> rmfile(filepath) + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + backend.rmfile(filepath) + + +def rmtree( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, +) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> dir_path = '/path/of/dir' + >>> rmtree(dir_path) + """ + backend = get_file_backend( + dir_path, backend_args=backend_args, enable_singleton=True) + backend.rmtree(dir_path) + + +def copy_if_symlink_fails( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directly copy src to + dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None.
+ + Returns: + bool: Return True if successfully create a symbolic link pointing to + src. Otherwise, return False. + + Examples: + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> copy_if_symlink_fails(src, dst) + True + """ + backend = get_file_backend( + src, backend_args=backend_args, enable_singleton=True) + return backend.copy_if_symlink_fails(src, dst) + + +def list_dir_or_file( + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + backend_args: Optional[dict] = None, +) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + """ + backend = get_file_backend( + dir_path, backend_args=backend_args, enable_singleton=True) + yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +def generate_presigned_url( + url: str, + client_method: str = 'get_object', + expires_in: int = 3600, + backend_args: Optional[dict] = None, +) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. 
Now only work on Petrel backend. + + Note: + Now only work on Petrel backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Default: 'get_object'. + expires_in (int): expires, in seconds. Default: 3600. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Generated presigned url. + """ + backend = get_file_backend( + url, backend_args=backend_args, enable_singleton=True) + return backend.generate_presigned_url(url, client_method, expires_in) + + def load(file, file_format=None, file_client_args=None, **kwargs): """Load data from json/yaml/pickle files. diff --git a/requirements/tests.txt b/requirements/tests.txt index c5043abf60..debf7eb171 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,3 +1,4 @@ coverage lmdb +parameterized pytest diff --git a/tests/test_fileio/test_backends/test_backend_utils.py b/tests/test_fileio/test_backends/test_backend_utils.py new file mode 100644 index 0000000000..310ecf7d47 --- /dev/null +++ b/tests/test_fileio/test_backends/test_backend_utils.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine.fileio.backends import (BaseStorageBackend, backends, + prefix_to_backends, register_backend) + + +def test_register_backend(): + # 1. two ways to register backend + # 1.1 use it as a decorator + @register_backend('example') + class ExampleBackend(BaseStorageBackend): + + def get_bytes(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + assert 'example' in backends + + # 1.2 use it as a normal function + class ExampleBackend1(BaseStorageBackend): + + def get_bytes(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + register_backend('example1', ExampleBackend1) + assert 'example1' in backends + + # 2. test `name` parameter + # 2. 
name should a string + with pytest.raises(TypeError, match='name should be a string'): + register_backend(1, ExampleBackend) + + register_backend('example2', ExampleBackend) + assert 'example2' in backends + + # 3. test `backend` parameter + # If backend is not None, it should be a class and a subclass of + # BaseStorageBackend. + with pytest.raises(TypeError, match='backend should be a class'): + + def test_backend(): + pass + + register_backend('example3', test_backend) + + class ExampleBackend2: + + def get_bytes(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + with pytest.raises( + TypeError, match='not a subclass of BaseStorageBackend'): + register_backend('example3', ExampleBackend2) + + # 4. test `force` parameter + # 4.1 force=False + with pytest.raises(ValueError, match='example is already registered'): + register_backend('example', ExampleBackend) + + # 4.2 force=True + register_backend('example', ExampleBackend, force=True) + assert 'example' in backends + + # 5. 
test `prefixes` parameter + class ExampleBackend3(BaseStorageBackend): + + def get_bytes(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + # 5.1 prefixes is a string + register_backend('example3', ExampleBackend3, prefixes='prefix1') + assert 'example3' in backends + assert 'prefix1' in prefix_to_backends + + # 5.2 prefixes is a list (tuple) of strings + register_backend( + 'example4', ExampleBackend3, prefixes=['prefix2', 'prefix3']) + assert 'example4' in backends + assert 'prefix2' in prefix_to_backends + assert 'prefix3' in prefix_to_backends + assert prefix_to_backends['prefix2'] == prefix_to_backends['prefix3'] + + # 5.3 prefixes is an invalid type + with pytest.raises(AssertionError): + register_backend('example5', ExampleBackend3, prefixes=1) + + # 5.4 prefixes is already registered + with pytest.raises(ValueError, match='prefix2 is already registered'): + register_backend('example6', ExampleBackend3, prefixes='prefix2') + + class ExampleBackend4(BaseStorageBackend): + + def get_bytes(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + register_backend( + 'example6', ExampleBackend4, prefixes='prefix2', force=True) + assert 'example6' in backends + assert 'prefix2' in prefix_to_backends diff --git a/tests/test_fileio/test_backends/test_base_storage_backend.py b/tests/test_fileio/test_backends/test_base_storage_backend.py new file mode 100644 index 0000000000..518c1da5d5 --- /dev/null +++ b/tests/test_fileio/test_backends/test_base_storage_backend.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest + +from mmengine.fileio.backends import BaseStorageBackend + + +def test_base_storage_backend(): + # test inheritance + class ExampleBackend(BaseStorageBackend): + pass + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class ExampleBackend"): + ExampleBackend() + + class ExampleBackend(BaseStorageBackend): + + def get_bytes(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + backend = ExampleBackend() + assert backend.get_bytes('test') == 'test' + assert backend.get_text('test') == 'test' diff --git a/tests/test_fileio/test_backends/test_http_backend.py b/tests/test_fileio/test_backends/test_http_backend.py new file mode 100644 index 0000000000..a92cc26245 --- /dev/null +++ b/tests/test_fileio/test_backends/test_http_backend.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from unittest import TestCase + +import cv2 +import numpy as np + +from mmengine.fileio.backends import HTTPBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +def imread(path): + with open(path, 'rb') as f: + content = f.read() + img = imfrombytes(content) + return img + + +class TestHTTPBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.img_url = ( + 'https://download.openmmlab.com/mmengine/test-data/color.jpg') + cls.img_shape = (300, 400, 3) + cls.text_url = ( + 'https://download.openmmlab.com/mmengine/test-data/filelist.txt') + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.text_path = cls.test_data_dir / 'filelist.txt' + + def test_get_bytes(self): + backend = HTTPBackend() + img_bytes = backend.get_bytes(self.img_url) + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, self.img_shape) + + def test_get_text(self): + backend = HTTPBackend() + text = backend.get_text(self.text_url) + 
self.assertEqual(self.text_path.open('r').read(), text) + + def test_get_local_path(self): + backend = HTTPBackend() + with backend.get_local_path(self.img_url) as filepath: + img = imread(filepath) + self.assertEqual(img.shape, self.img_shape) diff --git a/tests/test_fileio/test_backends/test_lmdb_backend.py b/tests/test_fileio/test_backends/test_lmdb_backend.py new file mode 100644 index 0000000000..6847945874 --- /dev/null +++ b/tests/test_fileio/test_backends/test_lmdb_backend.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from unittest import TestCase + +import cv2 +import numpy as np +from parameterized import parameterized + +from mmengine.fileio.backends import LmdbBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +class TestLmdbBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.lmdb_path = cls.test_data_dir / 'demo.lmdb' + + @parameterized.expand([[Path], [str]]) + def test_get_bytes(self, path_type): + backend = LmdbBackend(path_type(self.lmdb_path)) + img_bytes = backend.get_bytes('baboon') + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, (120, 125, 3)) + + def test_get_text(self): + backend = LmdbBackend(self.lmdb_path) + with self.assertRaises(NotImplementedError): + backend.get_text('filepath') diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py new file mode 100644 index 0000000000..41e37fe014 --- /dev/null +++ b/tests/test_fileio/test_backends/test_local_backend.py @@ -0,0 +1,486 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +import os.path as osp +import platform +import tempfile +from contextlib import contextmanager +from pathlib import Path +from shutil import SameFileError +from unittest import TestCase +from unittest.mock import patch + +import cv2 +import numpy as np +from parameterized import parameterized + +from mmengine.fileio.backends import LocalBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +@contextmanager +def build_temporary_directory(): + """Build a temporary directory containing many files to test + ``FileClient.list_dir_or_file``. + + . \n + | -- dir1 \n + | -- | -- text3.txt \n + | -- dir2 \n + | -- | -- dir3 \n + | -- | -- | -- text4.txt \n + | -- | -- img.jpg \n + | -- text1.txt \n + | -- text2.txt \n + """ + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + yield tmp_dir + + +class TestLocalBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.img_path = cls.test_data_dir / 'color.jpg' + cls.img_shape = (300, 400, 3) + cls.text_path = cls.test_data_dir / 'filelist.txt' + + def test_name(self): + backend = LocalBackend() + self.assertEqual(backend.name, 'LocalBackend') + + @parameterized.expand([[Path], [str]]) + def test_get_bytes(self, path_type): + backend = LocalBackend() + img_bytes = backend.get_bytes(path_type(self.img_path)) + self.assertEqual(self.img_path.open('rb').read(), img_bytes) + img = imfrombytes(img_bytes) + 
self.assertEqual(img.shape, self.img_shape) + + @parameterized.expand([[Path], [str]]) + def test_get_text(self, path_type): + backend = LocalBackend() + text = backend.get_text(path_type(self.text_path)) + self.assertEqual(self.text_path.open('r').read(), text) + + @parameterized.expand([[Path], [str]]) + def test_put_bytes(self, path_type): + backend = LocalBackend() + + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'test.jpg' + backend.put_bytes(b'disk', path_type(filepath)) + self.assertEqual(backend.get_bytes(filepath), b'disk') + + # If the directory does not exist, put_bytes will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' + backend.put_bytes(b'disk', path_type(filepath)) + self.assertEqual(backend.get_bytes(filepath), b'disk') + + @parameterized.expand([[Path], [str]]) + def test_put_text(self, path_type): + backend = LocalBackend() + + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', path_type(filepath)) + self.assertEqual(backend.get_text(filepath), 'disk') + + # If the directory does not exist, put_text will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.txt' + backend.put_text('disk', path_type(filepath)) + self.assertEqual(backend.get_text(filepath), 'disk') + + @parameterized.expand([[Path], [str]]) + def test_exists(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + self.assertTrue(backend.exists(path_type(tmp_dir))) + filepath = Path(tmp_dir) / 'test.txt' + self.assertFalse(backend.exists(path_type(filepath))) + backend.put_text('disk', filepath) + self.assertTrue(backend.exists(path_type(filepath))) + backend.rmfile(filepath) + + @parameterized.expand([[Path], [str]]) + def test_isdir(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + self.assertTrue(backend.isdir(path_type(tmp_dir))) + 
filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', filepath) + self.assertFalse(backend.isdir(path_type(filepath))) + + @parameterized.expand([[Path], [str]]) + def test_isfile(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + self.assertFalse(backend.isfile(path_type(tmp_dir))) + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', filepath) + self.assertTrue(backend.isfile(path_type(filepath))) + + @parameterized.expand([[Path], [str]]) + def test_join_path(self, path_type): + backend = LocalBackend() + filepath = backend.join_path( + path_type(self.test_data_dir), path_type('file')) + expected = osp.join(path_type(self.test_data_dir), path_type('file')) + self.assertEqual(filepath, expected) + + filepath = backend.join_path( + path_type(self.test_data_dir), path_type('dir'), path_type('file')) + expected = osp.join( + path_type(self.test_data_dir), path_type('dir'), path_type('file')) + self.assertEqual(filepath, expected) + + @parameterized.expand([[Path], [str]]) + def test_get_local_path(self, path_type): + backend = LocalBackend() + with backend.get_local_path(path_type(self.text_path)) as filepath: + self.assertEqual(path_type(self.text_path), path_type(filepath)) + + @parameterized.expand([[Path], [str]]) + def test_copyfile(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + # dst is a directory + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual( + backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + + # src and src should not be same file + with 
self.assertRaises(SameFileError): + backend.copyfile(path_type(src), path_type(src)) + + @parameterized.expand([[Path], [str]]) + def test_copytree(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + self.assertEqual( + backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) + self.assertTrue(backend.isdir(dst)) + self.assertTrue(backend.isfile(dst / 'text3.txt')) + self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + + # dst should not exist + with self.assertRaises(FileExistsError): + backend.copytree( + path_type(src), path_type(Path(tmp_dir) / 'dir2')) + + @parameterized.expand([[Path], [str]]) + def test_copyfile_from_local(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual( + backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + + # src and src should not be same file + with self.assertRaises(SameFileError): + backend.copyfile(path_type(src), path_type(src)) + + @parameterized.expand([[Path], [str]]) + def test_copytree_from_local(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + self.assertEqual( + backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) + self.assertTrue(backend.isdir(dst)) + self.assertTrue(backend.isfile(dst / 'text3.txt')) + 
self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + + # dst should not exist + with self.assertRaises(FileExistsError): + backend.copytree( + path_type(src), path_type(Path(tmp_dir) / 'dir2')) + + @parameterized.expand([[Path], [str]]) + def test_copyfile_to_local(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual( + backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + + # src and src should not be same file + with self.assertRaises(SameFileError): + backend.copyfile(path_type(src), path_type(src)) + + @parameterized.expand([[Path], [str]]) + def test_copytree_to_local(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + self.assertEqual( + backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) + self.assertTrue(backend.isdir(dst)) + self.assertTrue(backend.isfile(dst / 'text3.txt')) + self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + + # dst should not exist + with self.assertRaises(FileExistsError): + backend.copytree( + path_type(src), path_type(Path(tmp_dir) / 'dir2')) + + @parameterized.expand([[Path], [str]]) + def test_rmfile(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + # filepath is a Path object + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', filepath) + self.assertTrue(backend.exists(filepath)) + backend.rmfile(path_type(filepath)) + 
self.assertFalse(backend.exists(filepath)) + + # raise error if file does not exist + with self.assertRaises(FileNotFoundError): + filepath = Path(tmp_dir) / 'test1.txt' + backend.rmfile(path_type(filepath)) + + # can not remove directory + filepath = Path(tmp_dir) / 'dir' + filepath.mkdir() + with self.assertRaises(IsADirectoryError): + backend.rmfile(path_type(filepath)) + + @parameterized.expand([[Path], [str]]) + def test_rmtree(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + dir_path = Path(tmp_dir) / 'dir1' + self.assertTrue(backend.exists(dir_path)) + backend.rmtree(path_type(dir_path)) + self.assertFalse(backend.exists(dir_path)) + + dir_path = Path(tmp_dir) / 'dir2' + self.assertTrue(backend.exists(dir_path)) + backend.rmtree(path_type(dir_path)) + self.assertFalse(backend.exists(dir_path)) + + @parameterized.expand([[Path], [str]]) + def test_copy_if_symlink_fails(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + # create a symlink for a file + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test_link.txt' + res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) + if platform.system() == 'Linux': + self.assertTrue(res) + self.assertTrue(osp.islink(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + # create a symlink for a directory + src = Path(tmp_dir) / 'dir' + src.mkdir() + dst = Path(tmp_dir) / 'dir_link' + res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) + if platform.system() == 'Linux': + self.assertTrue(res) + self.assertTrue(osp.islink(dst)) + self.assertTrue(backend.exists(dst)) + + def symlink(src, dst): + raise Exception + + # copy files if symblink fails + with patch.object(os, 'symlink', side_effect=symlink): + src = Path(tmp_dir) / 'test.txt' + dst = Path(tmp_dir) / 'test_link1.txt' + res = backend.copy_if_symlink_fails( + 
path_type(src), path_type(dst)) + self.assertFalse(res) + self.assertFalse(osp.islink(dst)) + self.assertTrue(backend.exists(dst)) + + # copy directory if symblink fails + with patch.object(os, 'symlink', side_effect=symlink): + src = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / 'dir_link1' + res = backend.copy_if_symlink_fails( + path_type(src), path_type(dst)) + self.assertFalse(res) + self.assertFalse(osp.islink(dst)) + self.assertTrue(backend.exists(dst)) + + @parameterized.expand([[Path], [str]]) + def test_list_dir_or_file(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # list directories and files + self.assertEqual( + set(backend.list_dir_or_file(path_type(tmp_dir))), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + + # list directories and files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), recursive=True)), + { + 'dir1', + osp.join('dir1', 'text3.txt'), 'dir2', + osp.join('dir2', 'dir3'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + }) + + # only list directories + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_file=False)), + {'dir1', 'dir2'}) + + with self.assertRaisesRegex( + TypeError, + '`suffix` should be None when `list_dir` is True'): + backend.list_dir_or_file( + path_type(tmp_dir), list_file=False, suffix='.txt') + + # only list directories recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_file=False, recursive=True)), + {'dir1', 'dir2', osp.join('dir2', 'dir3')}) + + # only list files + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_dir=False)), + {'text1.txt', 'text2.txt'}) + + # only list files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_dir=False, recursive=True)), + { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + 
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_dir=False, suffix='.txt')), + {'text1.txt', 'text2.txt'}) + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) + + with self.assertRaisesRegex( + TypeError, + '`suffix` must be a string or tuple of strings'): + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix=['.txt', '.jpg']) + + # only list files ending with suffix recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix='.txt', + recursive=True)), { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), + { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + }) diff --git a/tests/test_fileio/test_backends/test_memcached_backend.py b/tests/test_fileio/test_backends/test_memcached_backend.py new file mode 100644 index 0000000000..30cf9793ec --- /dev/null +++ b/tests/test_fileio/test_backends/test_memcached_backend.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import sys +from pathlib import Path +from unittest import TestCase +from unittest.mock import MagicMock, patch + +import cv2 +import numpy as np +from parameterized import parameterized + +from mmengine.fileio.backends import MemcachedBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +sys.modules['mc'] = MagicMock() + + +class MockMemcachedClient: + + def __init__(self, server_list_cfg, client_cfg): + pass + + def Get(self, filepath, buffer): + with open(filepath, 'rb') as f: + buffer.content = f.read() + + +class TestMemcachedBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None) + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.img_path = cls.test_data_dir / 'color.jpg' + cls.img_shape = (300, 400, 3) + + @parameterized.expand([[Path], [str]]) + @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) + @patch('mc.pyvector', MagicMock) + @patch('mc.ConvertBuffer', lambda x: x.content) + def test_get_bytes(self, path_type): + backend = MemcachedBackend(**self.mc_cfg) + img_bytes = backend.get_bytes(path_type(self.img_path)) + self.assertEqual(self.img_path.open('rb').read(), img_bytes) + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, self.img_shape) + + @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) + @patch('mc.pyvector', MagicMock) + @patch('mc.ConvertBuffer', lambda x: x.content) + def test_get_text(self): + backend = MemcachedBackend(**self.mc_cfg) + with self.assertRaises(NotImplementedError): + backend.get_text('filepath') diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py new file mode 100644 index 0000000000..197dd45086 --- /dev/null +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -0,0 +1,861 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
@contextmanager
def build_temporary_directory():
    """Build a temporary directory containing many files to test
    ``FileClient.list_dir_or_file``.

    . \n
    | -- dir1 \n
    | -- | -- text3.txt \n
    | -- dir2 \n
    | -- | -- dir3 \n
    | -- | -- | -- text4.txt \n
    | -- | -- img.jpg \n
    | -- text1.txt \n
    | -- text2.txt \n

    Yields:
        str: Path of the temporary directory. The directory and everything
        in it is removed when the ``with`` block exits.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        root = Path(tmp_dir)
        dir1 = root / 'dir1'
        dir2 = root / 'dir2'
        dir3 = dir2 / 'dir3'

        # write_text()/write_bytes() close the file before returning, so the
        # content is on disk when the caller lists or reads it. The previous
        # ``path.open('w').write(...)`` never closed the handle explicitly.
        (root / 'text1.txt').write_text('text1')
        (root / 'text2.txt').write_text('text2')
        dir1.mkdir()
        (dir1 / 'text3.txt').write_text('text3')
        dir2.mkdir()
        (dir2 / 'img.jpg').write_bytes(b'img')
        dir3.mkdir()
        (dir3 / 'text4.txt').write_text('text4')
        yield tmp_dir
class MockPetrelClient:
    """Local-disk stand-in for ``petrel_client.client.Client``.

    ``Get`` and ``list`` work on the local filesystem; the remaining methods
    are empty placeholders that exist so tests can patch or delete them.
    """

    def __init__(self, enable_mc=True, enable_multi_cluster=False):
        self.enable_mc = enable_mc
        self.enable_multi_cluster = enable_multi_cluster

    def Get(self, filepath):
        """Return the bytes stored in the local file ``filepath``."""
        with open(filepath, 'rb') as stream:
            return stream.read()

    def put(self):
        pass

    def delete(self):
        pass

    def contains(self):
        pass

    def isdir(self):
        pass

    def list(self, dir_path):
        """Yield the entries of ``dir_path``.

        Non-hidden files are yielded by name, directories by name with a
        trailing ``/``. Anything else (e.g. a hidden file) is not yielded.
        """
        for item in os.scandir(dir_path):
            hidden = item.name.startswith('.')
            if not hidden and item.is_file():
                yield item.name
                continue
            if osp.isdir(item.path):
                yield item.name + '/'
def test_join_path(self):
    """``join_path`` joins the parts with exactly one ``/`` between them."""
    backend = PetrelBackend()
    root = self.petrel_dir
    # (arguments passed to join_path, expected joined path)
    cases = (
        ((root, 'file'), f'{root}/file'),
        ((f'{root}/', 'file'), f'{root}/file'),
        ((f'{root}/', '/file'), f'{root}/file'),
        ((root, 'dir', 'file'), f'{root}/dir/file'),
    )
    for parts, expected in cases:
        self.assertEqual(backend.join_path(*parts), expected)
def test_isdir(self):
    """``isdir`` forwards to ``_client.isdir`` using the mapped path."""
    backend = PetrelBackend()
    client = backend._client
    self.assertTrue(has_method(client, 'isdir'))

    # raise Exception if `_client.isdir` is not implemented
    with delete_and_reset_method(client, 'isdir'):
        self.assertFalse(has_method(client, 'isdir'))
        with self.assertRaises(NotImplementedError):
            backend.isdir(self.petrel_path)

    with patch.object(client, 'isdir', return_value=True) as patched_isdir:
        result = backend.isdir(self.petrel_path)
        self.assertTrue(result)
        patched_isdir.assert_called_once_with(self.expected_path)
def test_copyfile(self):
    """``copyfile`` copies to a file path or into an existing directory."""
    backend = PetrelBackend()
    client = backend._client
    src = self.petrel_path

    # dst is a file
    with patch.object(client, 'Get', return_value=b'petrel') as mock_get, \
            patch.object(client, 'put') as mock_put, \
            patch.object(client, 'isdir', return_value=False) as mock_isdir:
        dst = f'{self.petrel_dir}/test.bak.jpg'
        mapped_dst = f'{self.expected_dir}/test.bak.jpg'
        self.assertEqual(backend.copyfile(src, dst), dst)
        mock_get.assert_called_once_with(self.expected_path)
        mock_put.assert_called_once_with(mapped_dst, b'petrel')
        mock_isdir.assert_called_once_with(mapped_dst)

    # dst is a directory
    with patch.object(client, 'Get', return_value=b'petrel') as mock_get, \
            patch.object(client, 'put') as mock_put, \
            patch.object(client, 'isdir', return_value=True) as mock_isdir:
        dst = f'{self.petrel_dir}/dir'
        mapped_dst = f'{self.expected_dir}/dir/test.jpg'
        self.assertEqual(backend.copyfile(src, dst), f'{dst}/test.jpg')
        mock_get.assert_called_once_with(self.expected_path)
        mock_put.assert_called_once_with(mapped_dst, b'petrel')
        mock_isdir.assert_called_once_with(f'{self.expected_dir}/dir')

    # src and dst should not be the same file
    with patch.object(client, 'Get', return_value=b'petrel'), \
            patch.object(client, 'isdir', return_value=False):
        with self.assertRaises(SameFileError):
            backend.copyfile(src, src)
def test_copyfile_from_local(self):
    """``copyfile_from_local`` uploads a local file via ``_client.put``."""
    backend = PetrelBackend()
    src = self.img_path
    # Read the expected payload once with read_bytes(), which closes the
    # file; the former ``src.open('rb').read()`` leaked a handle per call.
    expected_content = src.read_bytes()

    with patch.object(backend._client, 'put') as patched_put, \
            patch.object(backend._client, 'isdir', return_value=False) \
            as patched_isdir:
        dst = f'{self.petrel_dir}/color.bak.jpg'
        expected_dst = f'{self.expected_dir}/color.bak.jpg'
        self.assertEqual(backend.copyfile_from_local(src, dst), dst)
        patched_put.assert_called_once_with(expected_dst, expected_content)
        patched_isdir.assert_called_once_with(expected_dst)

    with patch.object(backend._client, 'put') as patched_put, \
            patch.object(backend._client, 'isdir', return_value=True) \
            as patched_isdir:
        # dst is a directory
        dst = f'{self.petrel_dir}/dir'
        expected_dst = f'{self.expected_dir}/dir/color.jpg'
        self.assertEqual(
            backend.copyfile_from_local(src, dst), f'{dst}/color.jpg')
        patched_put.assert_called_once_with(expected_dst, expected_content)
        patched_isdir.assert_called_once_with(f'{self.expected_dir}/dir')
def test_copyfile_to_local(self):
    """``copyfile_to_local`` writes the fetched bytes to a local path."""
    backend = PetrelBackend()
    with patch.object(backend._client, 'Get',
                      return_value=b'petrel') as patched_get, \
            tempfile.TemporaryDirectory() as tmp_dir:
        src = self.petrel_path
        dst = Path(tmp_dir) / 'test.bak.jpg'
        self.assertEqual(backend.copyfile_to_local(src, dst), dst)
        patched_get.assert_called_once_with(self.expected_path)
        # read_bytes() closes the file, so no handle is still open when
        # the temporary directory is removed.
        self.assertEqual(dst.read_bytes(), b'petrel')

    with patch.object(backend._client, 'Get',
                      return_value=b'petrel') as patched_get, \
            tempfile.TemporaryDirectory() as tmp_dir:
        # dst is a directory
        src = self.petrel_path
        dst = Path(tmp_dir) / 'dir'
        dst.mkdir()
        self.assertEqual(
            backend.copyfile_to_local(src, dst), dst / 'test.jpg')
        patched_get.assert_called_once_with(self.expected_path)
        self.assertEqual((dst / 'test.jpg').read_bytes(), b'petrel')
def test_rmtree(self):
    """``rmtree`` calls ``rmfile`` once per file in the fixture tree."""
    backend = PetrelBackend()
    removed = []

    with build_temporary_directory() as tmp_dir:
        # Record every path handed to rmfile instead of deleting anything.
        with patch.object(backend, 'rmfile', side_effect=removed.append):
            backend.rmtree(tmp_dir)

    # The fixture tree holds five files.
    self.assertEqual(len(removed), 5)
set(backend.list_dir_or_file(tmp_dir, recursive=True)), { + 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', + '/'.join(('dir2', 'dir3')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list directories + self.assertEqual( + set(backend.list_dir_or_file(tmp_dir, list_file=False)), + {'dir1', 'dir2'}) + with self.assertRaisesRegex( + TypeError, + '`list_dir` should be False when `suffix` is not None' + ): + backend.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt') + + # only list directories recursively + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_file=False, recursive=True)), + {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) + + # only list files + self.assertEqual( + set(backend.list_dir_or_file(tmp_dir, list_dir=False)), + {'text1.txt', 'text2.txt'}) + + # only list files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_dir=False, recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt')), + {'text1.txt', 'text2.txt'}) + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix=('.txt', '.jpg'))), + {'text1.txt', 'text2.txt'}) + with self.assertRaisesRegex( + TypeError, + '`suffix` must be a string or tuple of strings'): + backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + + # only list files ending with suffix recursively + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix='.txt', + recursive=True)), { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), + 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + 
def test_put_bytes(self):
    """``put_bytes`` stores content that ``get_bytes`` can read back."""
    backend = PetrelBackend()
    img_path = f'{self.petrel_dir}/img.jpg'
    backend.put_bytes(b'img', img_path)
    # Without a read-back the test only proved that no exception was
    # raised; verify that the object was stored with the given content.
    self.assertTrue(backend.isfile(img_path))
    self.assertEqual(backend.get_bytes(img_path), b'img')
def test_get_local_path(self):
    """``get_local_path`` yields a local copy that is removed on exit."""
    backend = PetrelBackend()
    img_path = f'{self.petrel_dir}/dir2/img.jpg'
    with backend.get_local_path(img_path) as path:
        self.assertTrue(osp.isfile(path))
        # read_bytes() closes the file before the context manager cleans
        # up; ``open('rb').read()`` left the handle open.
        self.assertEqual(Path(path).read_bytes(), b'img')
    # exit the with block and path will be released
    self.assertFalse(osp.isfile(path))
def test_copyfile_to_local(self):
    """``copyfile_to_local`` downloads to a file path or into a directory."""
    backend = PetrelBackend()
    src = f'{self.petrel_dir}/dir2/img.jpg'
    with tempfile.TemporaryDirectory() as tmp_dir:
        # dst is a file
        dst = Path(tmp_dir) / 'img.jpg'
        self.assertEqual(backend.copyfile_to_local(src, dst), dst)
        # read_bytes() closes the file, so no handle is still open when
        # the temporary directory is removed.
        self.assertEqual(dst.read_bytes(), b'img')

        # dst is a directory
        dst = Path(tmp_dir) / 'dir'
        dst.mkdir()
        self.assertEqual(
            backend.copyfile_to_local(src, dst), dst / 'img.jpg')
        self.assertEqual((dst / 'img.jpg').read_bytes(), b'img')
def test_copy_if_symlink_fails(self):
    """``copy_if_symlink_fails`` copies the source and returns a falsy value."""
    backend = PetrelBackend()
    root = self.petrel_dir
    # (source, destination, predicate the destination must satisfy)
    scenarios = (
        # dst is a file
        (f'{root}/dir2/img.jpg', f'{root}/img.jpg', backend.isfile),
        # dst is a directory
        (f'{root}/dir2', f'{root}/dir', backend.isdir),
    )
    for src, dst, has_expected_kind in scenarios:
        self.assertFalse(backend.exists(dst))
        self.assertFalse(backend.copy_if_symlink_fails(src, dst))
        self.assertTrue(has_expected_kind(dst))
set(backend.list_dir_or_file(self.petrel_dir, list_dir=False)), + {'text1.txt', 'text2.txt'}) + + # only list files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, list_dir=False, recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, list_dir=False, suffix='.txt')), + {'text1.txt', 'text2.txt'}) + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, + list_dir=False, + suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) + with self.assertRaisesRegex( + TypeError, + '`suffix` must be a string or tuple of strings'): + backend.list_dir_or_file( + self.petrel_dir, list_dir=False, suffix=['.txt', '.jpg']) + + # only list files ending with suffix recursively + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, + list_dir=False, + suffix='.txt', + recursive=True)), { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), 'text1.txt', + 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + def test_generate_presigned_url(self): + pass diff --git a/tests/test_fileio/test_backends/utils.py b/tests/test_fileio/test_backends/utils.py new file mode 100644 index 0000000000..5f4d8f458a --- /dev/null +++ b/tests/test_fileio/test_backends/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
def imfrombytes(content):
    """Decode encoded image bytes into an array.

    Args:
        content (bytes): Encoded image content, e.g. the bytes of a jpg file.

    Returns:
        The result of ``cv2.imdecode`` called with ``cv2.IMREAD_COLOR``.
    """
    # View the bytes as a flat uint8 buffer, the input imdecode expects.
    return cv2.imdecode(
        np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
def test_parse_uri_prefix():
    """Check the prefix ``_parse_uri_prefix`` extracts for several inputs."""
    parse = fileio.io._parse_uri_prefix

    # input path is None or a list
    for invalid in (None, []):
        with pytest.raises(AssertionError):
            parse(invalid)

    # input path is Path object
    assert parse(uri=text_path) == ''

    # input path starts with https
    assert parse(uri=img_url) == 'https'

    # input path starts with s3, optionally preceded by clusterName:
    s3_uris = (
        's3://your_bucket/img.png',
        'clusterName:s3://your_bucket/img.png',
    )
    for uri in s3_uris:
        assert parse(uri) == 's3'
+ fileio.io.backend_instances = {} + + # uri should not be None when "backend" does not exist in backend_args + with pytest.raises(ValueError, match='uri should not be None'): + fileio.get_file_backend(None, backend_args=None) + + # uri is not None + backend = fileio.get_file_backend(uri=text_path) + assert isinstance(backend, fileio.backends.LocalBackend) + + uri = 'petrel://your_bucket/img.png' + backend = fileio.get_file_backend(uri=uri) + assert isinstance(backend, fileio.backends.PetrelBackend) + + backend = fileio.get_file_backend(uri=img_url) + assert isinstance(backend, fileio.backends.HTTPBackend) + uri = 'http://raw.githubusercontent.com/mmengine/tests/data/img.png' + backend = fileio.get_file_backend(uri=uri) + assert isinstance(backend, fileio.backends.HTTPBackend) + + # backend_args is not None and it contains a backend name + backend_args = {'backend': 'local'} + backend = fileio.get_file_backend(uri=None, backend_args=backend_args) + assert isinstance(backend, fileio.backends.LocalBackend) + + backend_args = {'backend': 'petrel', 'enable_mc': True} + backend = fileio.get_file_backend(uri=None, backend_args=backend_args) + assert isinstance(backend, fileio.backends.PetrelBackend) + + # backend name has a higher priority + backend_args = {'backend': 'http'} + backend = fileio.get_file_backend(uri=text_path, backend_args=backend_args) + assert isinstance(backend, fileio.backends.HTTPBackend) + + # test enable_singleton parameter + assert len(fileio.io.backend_instances) == 0 + backend1 = fileio.get_file_backend(uri=text_path, enable_singleton=True) + assert isinstance(backend1, fileio.backends.LocalBackend) + assert len(fileio.io.backend_instances) == 1 + assert fileio.io.backend_instances[':{}'] is backend1 + + backend2 = fileio.get_file_backend(uri=text_path, enable_singleton=True) + assert isinstance(backend2, fileio.backends.LocalBackend) + assert len(fileio.io.backend_instances) == 1 + assert backend2 is backend1 + + backend3 = 
def test_get_bytes():
    """``fileio.get_bytes`` returns the exact content of a local file."""
    # local path, so this exercises the local-disk backend
    filepath = Path(img_path)
    img_bytes = fileio.get_bytes(filepath)
    # read_bytes() closes the file; ``open('rb').read()`` leaked the handle.
    assert filepath.read_bytes() == img_bytes
fileio.get_text(filepath) + assert filepath.open('r').read() == text + + +def test_put_bytes(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'img.png' + fileio.put_bytes(b'disk', filepath) + assert fileio.get_bytes(filepath) == b'disk' + + # If the directory does not exist, put_bytes will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' + fileio.put_bytes(b'disk', filepath) + assert fileio.get_bytes(filepath) == b'disk' + + +def test_put_text(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'text.txt' + fileio.put_text('text', filepath) + assert fileio.get_text(filepath) == 'text' + + # If the directory does not exist, put_text will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.txt' + fileio.put_text('disk', filepath) + assert fileio.get_text(filepath) == 'disk' + + +def test_exists(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + assert fileio.exists(tmp_dir) + filepath = Path(tmp_dir) / 'test.txt' + assert not fileio.exists(filepath) + fileio.put_text('disk', filepath) + assert fileio.exists(filepath) + + +def test_isdir(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + assert fileio.isdir(tmp_dir) + filepath = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', filepath) + assert not fileio.isdir(filepath) + + +def test_isfile(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + assert not fileio.isfile(tmp_dir) + filepath = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', filepath) + assert fileio.isfile(filepath) + + +def test_join_path(): + # test HardDiskBackend + filepath = fileio.join_path(test_data_dir, 'file') + expected = osp.join(test_data_dir, 'file') + assert filepath == expected + + filepath = fileio.join_path(test_data_dir, 'dir', 'file') + expected = 
osp.join(test_data_dir, 'dir', 'file') + assert filepath == expected + + +def test_get_local_path(): + # test HardDiskBackend + with fileio.get_local_path(text_path) as filepath: + assert str(text_path) == filepath + + +def test_copyfile(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + assert fileio.copyfile(src, dst) == dst + assert fileio.get_text(dst) == 'disk' + + # dst is a directory + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt') + assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk' + + # src and src should not be same file + with pytest.raises(SameFileError): + fileio.copyfile(src, src) + + +def test_copytree(): + # test HardDiskBackend + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + assert fileio.copytree(src, dst) == dst + assert fileio.isdir(dst) + assert fileio.isfile(dst / 'text3.txt') + assert fileio.get_text(dst / 'text3.txt') == 'text3' + + # dst should not exist + with pytest.raises(FileExistsError): + fileio.copytree(src, Path(tmp_dir) / 'dir2') + + +def test_copyfile_from_local(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + assert fileio.copyfile(src, dst) == dst + assert fileio.get_text(dst) == 'disk' + + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt') + assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk' + + # src and src should not be same file + with pytest.raises(SameFileError): + fileio.copyfile(src, src) + + +def test_copytree_from_local(): + # test HardDiskBackend + with build_temporary_directory() as tmp_dir: + # 
src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + assert fileio.copytree(src, dst) == dst + assert fileio.isdir(dst) + assert fileio.isfile(dst / 'text3.txt') + assert fileio.get_text(dst / 'text3.txt') == 'text3' + + # dst should not exist + with pytest.raises(FileExistsError): + fileio.copytree(src, Path(tmp_dir) / 'dir2') + + +def test_copyfile_to_local(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + assert fileio.copyfile(src, dst) == dst + assert fileio.get_text(dst) == 'disk' + + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt') + assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk' + + # src and src should not be same file + with pytest.raises(SameFileError): + fileio.copyfile(src, src) + + +def test_copytree_to_local(): + # test HardDiskBackend + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + assert fileio.copytree(src, dst) == dst + assert fileio.isdir(dst) + assert fileio.isfile(dst / 'text3.txt') + assert fileio.get_text(dst / 'text3.txt') == 'text3' + + # dst should not exist + with pytest.raises(FileExistsError): + fileio.copytree(src, Path(tmp_dir) / 'dir2') + + +def test_rmfile(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + # filepath is a Path object + filepath = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', filepath) + assert fileio.exists(filepath) + fileio.rmfile(filepath) + assert not fileio.exists(filepath) + + # raise error if file does not exist + with pytest.raises(FileNotFoundError): + filepath = Path(tmp_dir) / 'test1.txt' + fileio.rmfile(filepath) + + # can not remove directory + filepath = Path(tmp_dir) / 'dir' + filepath.mkdir() + with 
pytest.raises(IsADirectoryError): + fileio.rmfile(filepath) + + +def test_rmtree(): + # test HardDiskBackend + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + dir_path = Path(tmp_dir) / 'dir1' + assert fileio.exists(dir_path) + fileio.rmtree(dir_path) + assert not fileio.exists(dir_path) + + dir_path = Path(tmp_dir) / 'dir2' + assert fileio.exists(dir_path) + fileio.rmtree(dir_path) + assert not fileio.exists(dir_path) + + +def test_copy_if_symlink_fails(): + # test HardDiskBackend + with tempfile.TemporaryDirectory() as tmp_dir: + # create a symlink for a file + src = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', src) + dst = Path(tmp_dir) / 'test_link.txt' + res = fileio.copy_if_symlink_fails(src, dst) + if platform.system() == 'Linux': + assert res + assert osp.islink(dst) + assert fileio.get_text(dst) == 'disk' + + # create a symlink for a directory + src = Path(tmp_dir) / 'dir' + src.mkdir() + dst = Path(tmp_dir) / 'dir_link' + res = fileio.copy_if_symlink_fails(src, dst) + if platform.system() == 'Linux': + assert res + assert osp.islink(dst) + assert fileio.exists(dst) + + def symlink(src, dst): + raise Exception + + # copy files if symblink fails + with patch.object(os, 'symlink', side_effect=symlink): + src = Path(tmp_dir) / 'test.txt' + dst = Path(tmp_dir) / 'test_link1.txt' + res = fileio.copy_if_symlink_fails(src, dst) + assert not res + assert not osp.islink(dst) + assert fileio.exists(dst) + + # copy directory if symblink fails + with patch.object(os, 'symlink', side_effect=symlink): + src = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / 'dir_link1' + res = fileio.copy_if_symlink_fails(src, dst) + assert not res + assert not osp.islink(dst) + assert fileio.exists(dst) + + +def test_list_dir_or_file(): + # test HardDiskBackend + with build_temporary_directory() as tmp_dir: + # list directories and files + assert set(fileio.list_dir_or_file(tmp_dir)) == { + 'dir1', 'dir2', 'text1.txt', 'text2.txt' + } + + # list 
directories and files recursively + assert set(fileio.list_dir_or_file(tmp_dir, recursive=True)) == { + 'dir1', + osp.join('dir1', 'text3.txt'), 'dir2', + osp.join('dir2', 'dir3'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + } + + # only list directories + assert set(fileio.list_dir_or_file( + tmp_dir, list_file=False)) == {'dir1', 'dir2'} + + with pytest.raises( + TypeError, + match='`suffix` should be None when `list_dir` is True'): + list( + fileio.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt')) + + # only list directories recursively + assert set( + fileio.list_dir_or_file( + tmp_dir, list_file=False, + recursive=True)) == {'dir1', 'dir2', + osp.join('dir2', 'dir3')} + + # only list files + assert set(fileio.list_dir_or_file( + tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} + + # only list files recursively + assert set( + fileio.list_dir_or_file(tmp_dir, list_dir=False, + recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), + 'text1.txt', 'text2.txt' + } + + # only list files ending with suffix + assert set( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, + suffix='.txt')) == {'text1.txt', 'text2.txt'} + assert set( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, + suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} + + with pytest.raises( + TypeError, + match='`suffix` must be a string or tuple of strings'): + list( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])) + + # only list files ending with suffix recursively + assert set( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt', recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + } + + # only list files ending with suffix + assert set( + fileio.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + 
recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + } + + +def test_generate_presigned_url(): + pass From c363fa9aea65ebf1ce65395ae8521c66abe76481 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 15 Sep 2022 17:32:33 +0800 Subject: [PATCH 02/18] handle compatibility --- mmengine/fileio/backends/registry_utils.py | 4 +- mmengine/fileio/file_client.py | 4 ++ mmengine/hooks/checkpoint_hook.py | 53 ++++++++++++----- mmengine/hooks/logger_hook.py | 36 +++++++++--- mmengine/runner/checkpoint.py | 67 +++++++++++----------- mmengine/runner/runner.py | 42 ++++++++++---- 6 files changed, 139 insertions(+), 67 deletions(-) diff --git a/mmengine/fileio/backends/registry_utils.py b/mmengine/fileio/backends/registry_utils.py index 221107a12f..4578a4ca76 100644 --- a/mmengine/fileio/backends/registry_utils.py +++ b/mmengine/fileio/backends/registry_utils.py @@ -111,5 +111,7 @@ def _register(backend_cls): register_backend('local', LocalBackend, prefixes='') register_backend('memcached', MemcachedBackend) register_backend('lmdb', LmdbBackend) -register_backend('petrel', PetrelBackend, prefixes='petrel') +# To avoid breaking backward Compatibility, 's3' is also used as a +# prefix for PetrelBackend +register_backend('petrel', PetrelBackend, prefixes=['petrel', 's3']) register_backend('http', HTTPBackend, prefixes=['http', 'https']) diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index a371a1864e..96822ba234 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -1017,6 +1017,10 @@ def remove(self, filepath: Union[str, Path]) -> None: """ self.client.remove(filepath) + get_bytes = get + put_bytes = put + rmfile = remove + def exists(self, filepath: Union[str, Path]) -> bool: """Check whether a file path exists. 
diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 448d2545cb..d237e4ed06 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union from mmengine.dist import master_only -from mmengine.fileio import FileClient +from mmengine.fileio import FileClient, get_file_backend from mmengine.registry import HOOKS from mmengine.utils import is_list_of, is_seq_of from .hook import Hook @@ -72,8 +72,12 @@ class CheckpointHook(Hook): inferred by 'less' comparison rule. If ``None``, _default_less_keys will be used. Defaults to None. file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmcv.fileio.FileClient` for details. - Defaults to None. + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> # Save best based on single metric @@ -116,6 +120,7 @@ def __init__(self, greater_keys: Optional[Sequence[str]] = None, less_keys: Optional[Sequence[str]] = None, file_client_args: Optional[dict] = None, + backend_args: Optional[dict] = None, **kwargs) -> None: self.interval = interval self.by_epoch = by_epoch @@ -125,7 +130,17 @@ def __init__(self, self.max_keep_ckpts = max_keep_ckpts self.save_last = save_last self.args = kwargs + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
', + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + self.file_client_args = file_client_args + self.backend_args = backend_args # save best logic assert (isinstance(save_best, str) or is_list_of(save_best, str) @@ -197,19 +212,28 @@ def before_train(self, runner) -> None: if self.out_dir is None: self.out_dir = runner.work_dir + # If self.file_client_args is None, self.file_client will not + # used in CheckpointHook. To avoid breaking backward compatibility, + # it will not be removed util the release of MMEngine1.0 self.file_client = FileClient.infer_client(self.file_client_args, self.out_dir) + + if self.file_client_args is None: + self.file_backend = get_file_backend( + self.out_dir, backend_args=self.backend_args) + else: + self.file_backend = self.file_client + # if `self.out_dir` is not equal to `runner.work_dir`, it means that # `self.out_dir` is set so the final `self.out_dir` is the # concatenation of `self.out_dir` and the last level directory of # `runner.work_dir` if self.out_dir != runner.work_dir: basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_client.join_path( + self.out_dir = self.file_backend.join_path( self.out_dir, basename) # type: ignore # noqa: E501 - runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by ' - f'{self.file_client.name}.') + runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.') if self.save_best is not None: if len(self.key_indicators) == 1: @@ -290,11 +314,12 @@ def _save_checkpoint(self, runner) -> None: save_optimizer=self.save_optimizer, save_param_scheduler=self.save_param_scheduler, by_epoch=self.by_epoch, + backend_args=self.backend_args, **self.args) runner.message_hub.update_info( - 'last_ckpt', self.file_client.join_path(self.out_dir, - ckpt_filename)) + 'last_ckpt', + self.file_backend.join_path(self.out_dir, ckpt_filename)) # 
remove other checkpoints if self.max_keep_ckpts > 0: @@ -309,16 +334,15 @@ def _save_checkpoint(self, runner) -> None: -self.interval) filename_tmpl = self.args.get('filename_tmpl', name) for _step in redundant_ckpts: - ckpt_path = self.file_client.join_path( + ckpt_path = self.file_backend.join_path( self.out_dir, filename_tmpl.format(_step)) - if self.file_client.isfile(ckpt_path): - self.file_client.remove(ckpt_path) + if self.file_backend.isfile(ckpt_path): + self.file_backend.rmfile(ckpt_path) else: break save_file = osp.join(runner.work_dir, 'last_checkpoint') - file_client = FileClient.infer_client(uri=self.out_dir) - filepath = file_client.join_path(self.out_dir, ckpt_filename) + filepath = self.file_backend.join_path(self.out_dir, ckpt_filename) with open(save_file, 'w') as f: f.write(filepath) @@ -397,7 +421,8 @@ def _save_best_checkpoint(self, runner, metrics) -> None: file_client_args=self.file_client_args, save_optimizer=False, save_param_scheduler=False, - by_epoch=False) + by_epoch=False, + backend_args=self.backend_args) runner.logger.info( f'The best checkpoint with {best_score:0.4f} {key_indicator} ' f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index c189c2d1c0..c1cf669770 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +import warnings from pathlib import Path from typing import Dict, Optional, Sequence, Union from mmengine.fileio import FileClient, dump +from mmengine.fileio.io import get_file_backend from mmengine.hooks import Hook from mmengine.registry import HOOKS from mmengine.utils import is_tuple_of, scandir @@ -50,12 +52,16 @@ class LoggerHook(Hook): removed. Defaults to True. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. 
+ Defaults to None. It will be deprecated in future. Please use + `backend_args` instead. log_metric_by_epoch (bool): Whether to output metric in validation step by epoch. It can be true when running in epoch based runner. If set to True, `after_val_epoch` will set `step` to self.epoch in `runner.visualizer.add_scalars`. Otherwise `step` will be self.iter. Default to True. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> # The simplest LoggerHook config. @@ -71,7 +77,8 @@ def __init__(self, out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'), keep_local: bool = True, file_client_args: Optional[dict] = None, - log_metric_by_epoch: bool = True): + log_metric_by_epoch: bool = True, + backend_args: Optional[dict] = None): self.interval = interval self.ignore_last = ignore_last self.interval_exp_name = interval_exp_name @@ -82,6 +89,14 @@ def __init__(self, 'specified.') self.out_dir = out_dir + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
', + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + if not (out_dir is None or isinstance(out_dir, str) or is_tuple_of(out_dir, str)): raise TypeError('out_dir should be None or string or tuple of ' @@ -91,9 +106,16 @@ def __init__(self, self.keep_local = keep_local self.file_client_args = file_client_args self.json_log_path: Optional[str] = None + if self.out_dir is not None: self.file_client = FileClient.infer_client(file_client_args, self.out_dir) + if file_client_args is None: + self.file_backend = get_file_backend( + self.out_dir, backend_args=backend_args) + else: + self.file_backend = self.file_client + self.log_metric_by_epoch = log_metric_by_epoch def before_run(self, runner) -> None: @@ -107,10 +129,10 @@ def before_run(self, runner) -> None: # The final `self.out_dir` is the concatenation of `self.out_dir` # and the last level directory of `runner.work_dir` basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_client.join_path(self.out_dir, basename) + self.out_dir = self.file_backend.join_path(self.out_dir, basename) runner.logger.info( - f'Text logs will be saved to {self.out_dir} by ' - f'{self.file_client.name} after the training process.') + f'Text logs will be saved to {self.out_dir} after the training process.' 
+ ) self.json_log_path = f'{runner.timestamp}.json' @@ -245,9 +267,9 @@ def after_run(self, runner) -> None: return for filename in scandir(runner._log_dir, self.out_suffix, True): local_filepath = osp.join(runner._log_dir, filename) - out_filepath = self.file_client.join_path(self.out_dir, filename) + out_filepath = self.file_backend.join_path(self.out_dir, filename) with open(local_filepath) as f: - self.file_client.put_text(f.read(), out_filepath) + self.file_backend.put_text(f.read(), out_filepath) runner.logger.info( f'The file {local_filepath} has been uploaded to ' diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 44ae462d03..1acc7c5c04 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -15,7 +15,7 @@ import mmengine from mmengine.dist import get_dist_info -from mmengine.fileio import FileClient +from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import load as load_file from mmengine.logging import print_log from mmengine.model import is_model_wrapper @@ -334,7 +334,8 @@ def load_from_pavi(filename, map_location=None): return checkpoint -@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://') +@CheckpointLoader.register_scheme( + prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) def load_from_ceph(filename, map_location=None, backend='petrel'): """load checkpoint through the file path prefixed with s3. In distributed setting, this function download ckpt at all ranks to different temporary @@ -343,35 +344,14 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): Args: filename (str): checkpoint file path with s3 prefix map_location (str, optional): Same as :func:`torch.load`. - backend (str, optional): The storage backend type. Options are 'ceph', - 'petrel'. Default: 'petrel'. - - .. warning:: - :class:`mmengine.fileio.file_client.CephBackend` will be deprecated, - please use :class:`mmengine.fileio.file_client.PetrelBackend` instead. 
+ backend (str, optional): The storage backend type. Defaults to 'petrel'. Returns: dict or OrderedDict: The loaded checkpoint. """ - allowed_backends = ['ceph', 'petrel'] - if backend not in allowed_backends: - raise ValueError(f'Load from Backend {backend} is not supported.') - - if backend == 'ceph': - warnings.warn( - 'CephBackend will be deprecated, please use PetrelBackend instead', - DeprecationWarning) - - # CephClient and PetrelBackend have the same prefix 's3://' and the latter - # will be chosen as default. If PetrelBackend can not be instantiated - # successfully, the CephClient will be chosen. - try: - file_client = FileClient(backend=backend) - except ImportError: - allowed_backends.remove(backend) - file_client = FileClient(backend=allowed_backends[0]) - - with io.BytesIO(file_client.get(filename)) as buffer: + file_backend = get_file_backend( + filename, backend_args={'backend': backend}) + with io.BytesIO(file_backend.get_bytes(filename)) as buffer: checkpoint = torch.load(buffer, map_location=map_location) return checkpoint @@ -658,7 +638,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): return destination -def save_checkpoint(checkpoint, filename, file_client_args=None): +def save_checkpoint(checkpoint, + filename, + file_client_args=None, + backend_args=None): """Save checkpoint to file. Args: @@ -666,13 +649,25 @@ def save_checkpoint(checkpoint, filename, file_client_args=None): filename (str): Checkpoint filename. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. + Defaults to None. It will be deprecated in future. Please use + `backend_args` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. """ + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + if filename.startswith('pavi://'): - if file_client_args is not None: + if file_client_args is not None or backend_args is not None: raise ValueError( - 'file_client_args should be "None" if filename starts with' - f'"pavi://", but got {file_client_args}') + '"file_client_args" or "backend_args" should be "None" if ' + 'filename starts with "pavi://"') try: from pavi import exception, modelcloud except ImportError: @@ -693,9 +688,15 @@ def save_checkpoint(checkpoint, filename, file_client_args=None): model.create_file(checkpoint_file, name=model_name) else: file_client = FileClient.infer_client(file_client_args, filename) + if file_client_args is None: + file_backend = get_file_backend( + filename, backend_args=backend_args) + else: + file_backend = file_client + with io.BytesIO() as f: torch.save(checkpoint, f) - file_client.put(f.getvalue(), filename) + file_backend.put_bytes(f.getvalue(), filename) def find_latest_checkpoint(path: str) -> Optional[str]: diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index f0d6ad605d..24a2c6eba2 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -23,7 +23,7 @@ from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, is_distributed, master_only) from mmengine.evaluator import Evaluator -from mmengine.fileio import FileClient +from mmengine.fileio import FileClient, join_path from mmengine.hooks import Hook from mmengine.logging import MessageHub, MMLogger, print_log from mmengine.model import (BaseModel, MMDistributedDataParallel, @@ -1991,14 +1991,17 @@ def load_checkpoint(self, return checkpoint @master_only - def save_checkpoint(self, - out_dir: str, - filename: str, - file_client_args: Optional[dict] = None, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - meta: dict = None, 
- by_epoch: bool = True): + def save_checkpoint( + self, + out_dir: str, + filename: str, + file_client_args: Optional[dict] = None, + save_optimizer: bool = True, + save_param_scheduler: bool = True, + meta: dict = None, + by_epoch: bool = True, + backend_args: Optional[dict] = None, + ): """Save checkpoints. ``CheckpointHook`` invokes this method to save checkpoints @@ -2008,7 +2011,9 @@ def save_checkpoint(self, out_dir (str): The directory that checkpoints are saved. filename (str): The checkpoint filename. file_client_args (dict, optional): Arguments to instantiate a - FileClient. Default: None. + FileClient. See :class:`mmengine.fileio.FileClient` for + details. Defaults to None. It will be deprecated in future. + Please use `backend_args` instead. save_optimizer (bool): Whether to save the optimizer to the checkpoint. Defaults to True. save_param_scheduler (bool): Whether to save the param_scheduler @@ -2017,6 +2022,9 @@ def save_checkpoint(self, checkpoint. Defaults to None. by_epoch (bool): Whether the scheduled momentum is updated by epochs. Defaults to True. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. """ if meta is None: meta = {} @@ -2033,8 +2041,18 @@ def save_checkpoint(self, else: meta.update(epoch=self.epoch, iter=self.iter + 1) - file_client = FileClient.infer_client(file_client_args, out_dir) - filepath = file_client.join_path(out_dir, filename) + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + + file_client = FileClient.infer_client(file_client_args, out_dir) + filepath = file_client.join_path(out_dir, filename) + else: + filepath = join_path(out_dir, filename, backend_args=backend_args) meta.update( cfg=self.cfg.pretty_text, From 64451ad1a0d0f99034d1a8404080f4c57f4c5ea6 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 15 Sep 2022 17:44:11 +0800 Subject: [PATCH 03/18] fix format --- mmengine/hooks/checkpoint_hook.py | 2 +- mmengine/hooks/logger_hook.py | 6 +++--- mmengine/runner/checkpoint.py | 3 ++- mmengine/runner/runner.py | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index d237e4ed06..186c21ba33 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -133,7 +133,7 @@ def __init__(self, if file_client_args is not None: warnings.warn( - '"file_client_args" will be deprecated in future. ', + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index c1cf669770..df1530f327 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -91,7 +91,7 @@ def __init__(self, if file_client_args is not None: warnings.warn( - '"file_client_args" will be deprecated in future. ', + '"file_client_args" will be deprecated in future. 
' 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( @@ -131,8 +131,8 @@ def before_run(self, runner) -> None: basename = osp.basename(runner.work_dir.rstrip(osp.sep)) self.out_dir = self.file_backend.join_path(self.out_dir, basename) runner.logger.info( - f'Text logs will be saved to {self.out_dir} after the training process.' - ) + f'Text logs will be saved to {self.out_dir} after the ' + 'training process.') self.json_log_path = f'{runner.timestamp}.json' diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 1acc7c5c04..2d89572cbf 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -344,7 +344,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): Args: filename (str): checkpoint file path with s3 prefix map_location (str, optional): Same as :func:`torch.load`. - backend (str, optional): The storage backend type. Defaults to 'petrel'. + backend (str, optional): The storage backend type. + Defaults to 'petrel'. Returns: dict or OrderedDict: The loaded checkpoint. 
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 24a2c6eba2..8e70786fcb 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -2052,7 +2052,8 @@ def save_checkpoint( file_client = FileClient.infer_client(file_client_args, out_dir) filepath = file_client.join_path(out_dir, filename) else: - filepath = join_path(out_dir, filename, backend_args=backend_args) + filepath = join_path( # type: ignore + out_dir, filename, backend_args=backend_args) meta.update( cfg=self.cfg.pretty_text, From 675dfc9f1dac3b782cd73bfd28e0b19cf23ba094 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 15 Sep 2022 22:55:04 +0800 Subject: [PATCH 04/18] modify io functions --- mmengine/fileio/io.py | 64 +++++++++++++++++++++++++++----- mmengine/fileio/parse.py | 59 ++++++++++++++++++++++++----- tests/test_fileio/test_fileio.py | 4 +- 3 files changed, 105 insertions(+), 22 deletions(-) diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index ee71cc5df1..c75a74be9d 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -31,6 +31,7 @@ >>> fileio.get_bytes('s3://path/of/your/file') """ import json +import warnings from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path @@ -775,7 +776,11 @@ def generate_presigned_url( return backend.generate_presigned_url(url, client_method, expires_in) -def load(file, file_format=None, file_client_args=None, **kwargs): +def load(file, + file_format=None, + file_client_args=None, + backend_args=None, + **kwargs): """Load data from json/yaml/pickle files. This method provides a unified api for loading data from serialized files. @@ -792,7 +797,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs): "pickle/pkl". file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Default: None. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. 
+ backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> load('/path/of/your/file') # file is storaged in disk @@ -809,14 +818,27 @@ def load(file, file_format=None, file_client_args=None, **kwargs): if file_format not in file_handlers: raise TypeError(f'Unsupported format: {file_format}') + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + handler = file_handlers[file_format] if is_str(file): - file_client = FileClient.infer_client(file_client_args, file) + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend(file, backend_args=backend_args) + if handler.str_like: - with StringIO(file_client.get_text(file)) as f: + with StringIO(file_backend.get_text(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) else: - with BytesIO(file_client.get(file)) as f: + with BytesIO(file_backend.get_bytes(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) elif hasattr(file, 'read'): obj = handler.load_from_fileobj(file, **kwargs) @@ -825,7 +847,12 @@ def load(file, file_format=None, file_client_args=None, **kwargs): return obj -def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): +def dump(obj, + file=None, + file_format=None, + file_client_args=None, + backend_args=None, + **kwargs): """Dump data to json/yaml/pickle strings or files. This method provides a unified api for dumping data as strings or to files, @@ -842,7 +869,11 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): file_format (str, optional): Same as :func:`load`. 
file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Default: None. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> dump('hello world', '/path/of/your/file') # disk @@ -862,19 +893,32 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): if file_format not in file_handlers: raise TypeError(f'Unsupported format: {file_format}') + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + handler = file_handlers[file_format] if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): - file_client = FileClient.infer_client(file_client_args, file) + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend(file, backend_args=backend_args) + if handler.str_like: with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - file_client.put_text(f.getvalue(), file) + file_backend.put_text(f.getvalue(), file) else: with BytesIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - file_client.put(f.getvalue(), file) + file_backend.put_bytes(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: diff --git a/mmengine/fileio/parse.py b/mmengine/fileio/parse.py index 8353b62297..139481aa49 100644 --- a/mmengine/fileio/parse.py +++ b/mmengine/fileio/parse.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import warnings from io import StringIO from .file_client import FileClient +from .io import get_text def list_from_file(filename, @@ -9,7 +11,8 @@ def list_from_file(filename, offset=0, max_num=0, encoding='utf-8', - file_client_args=None): + file_client_args=None, + backend_args=None): """Load a text file and parse the content as a list of strings. ``list_from_file`` supports loading a text file which can be storaged in @@ -21,10 +24,14 @@ def list_from_file(filename, offset (int): The offset of lines. max_num (int): The maximum number of lines to be read, zeros and negatives mean no limitation. - encoding (str): Encoding used to open the file. Default utf-8. + encoding (str): Encoding used to open the file. Defaults to utf-8. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Default: None. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> list_from_file('/path/of/your/file') # disk @@ -35,10 +42,23 @@ def list_from_file(filename, Returns: list[str]: A list of strings. """ + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') cnt = 0 item_list = [] - file_client = FileClient.infer_client(file_client_args, filename) - with StringIO(file_client.get_text(filename, encoding)) as f: + + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, filename) + text = file_client.get_text(filename, encoding) + else: + text = get_text(filename, encoding, backend_args=backend_args) + + with StringIO(text) as f: for _ in range(offset): f.readline() for line in f: @@ -52,7 +72,8 @@ def list_from_file(filename, def dict_from_file(filename, key_type=str, encoding='utf-8', - file_client_args=None): + file_client_args=None, + backend_args=None): """Load a text file and parse the content as a dict. Each line of the text file will be two or more columns split by @@ -66,10 +87,14 @@ def dict_from_file(filename, filename(str): Filename. key_type(type): Type of the dict keys. str is user by default and type conversion will be performed if specified. - encoding (str): Encoding used to open the file. Default utf-8. + encoding (str): Encoding used to open the file. Defaults to utf-8. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Default: None. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> dict_from_file('/path/of/your/file') # disk @@ -80,9 +105,23 @@ def dict_from_file(filename, Returns: dict: The parsed contents. """ + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args and "backend_args" cannot be both set.') + mapping = {} - file_client = FileClient.infer_client(file_client_args, filename) - with StringIO(file_client.get_text(filename, encoding)) as f: + + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, filename) + text = file_client.get_text(filename, encoding) + else: + text = get_text(filename, encoding, backend_args=backend_args) + + with StringIO(text) as f: for line in f: items = line.rstrip('\n').split() assert len(items) >= 2 diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index 3077a948f4..fe56a5a272 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -8,7 +8,7 @@ import pytest import mmengine -from mmengine.fileio.file_client import HTTPBackend, PetrelBackend +from mmengine.fileio.backends import HTTPBackend, PetrelBackend sys.modules['petrel_client'] = MagicMock() sys.modules['petrel_client.client'] = MagicMock() @@ -30,7 +30,7 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): os.remove(tmp_filename) # load/dump with filename from petrel - method = 'put' if 'b' in mode else 'put_text' + method = 'put_bytes' if 'b' in mode else 'put_text' with patch.object(PetrelBackend, method, return_value=None) as mock_method: filename = 's3://path/of/your/file' mmengine.dump(test_obj, filename, file_format=file_format) From 835ada0710572de558f4b2a6cbda7a96aa9c5de2 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 15 Sep 2022 23:18:09 +0800 Subject: [PATCH 05/18] fix ut --- tests/test_fileio/test_fileio.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index fe56a5a272..1b6d3a61fc 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -8,7 +8,8 @@ 
import pytest import mmengine -from mmengine.fileio.backends import HTTPBackend, PetrelBackend +from mmengine.fileio.backends import PetrelBackend as _PetrelBackend +from mmengine.fileio.file_client import HTTPBackend, PetrelBackend sys.modules['petrel_client'] = MagicMock() sys.modules['petrel_client.client'] = MagicMock() @@ -31,7 +32,8 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): # load/dump with filename from petrel method = 'put_bytes' if 'b' in mode else 'put_text' - with patch.object(PetrelBackend, method, return_value=None) as mock_method: + with patch.object( + _PetrelBackend, method, return_value=None) as mock_method: filename = 's3://path/of/your/file' mmengine.dump(test_obj, filename, file_format=file_format) mock_method.assert_called() From 65633d0b2a59983f258c3d87509fff09d9ff3e7e Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 15 Sep 2022 23:30:26 +0800 Subject: [PATCH 06/18] fix ut --- tests/test_fileio/test_fileio.py | 35 ++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index 1b6d3a61fc..e586fca741 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -8,6 +8,7 @@ import pytest import mmengine +from mmengine.fileio.backends import HTTPBackend as _HTTPBackend from mmengine.fileio.backends import PetrelBackend as _PetrelBackend from mmengine.fileio.file_client import HTTPBackend, PetrelBackend @@ -153,30 +154,42 @@ def test_list_from_file(): assert filelist == ['4.jpg', '5.jpg'] # get list from http + filename = 'http://path/of/your/file' with patch.object( HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): - filename = 'http://path/of/your/file' filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file( filename, file_client_args={'prefix': 'http'}) assert 
filelist == ['1.jpg', '2.jpg', '3.jpg'] + + with patch.object( + _HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmengine.list_from_file( + filename, backend_args={'backend': 'http'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] # get list from petrel + filename = 's3://path/of/your/file' with patch.object( PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): - filename = 's3://path/of/your/file' filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file( filename, file_client_args={'prefix': 's3'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + + with patch.object( + _PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmengine.list_from_file( + filename, backend_args={'backend': 'petrel'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] def test_dict_from_file(): @@ -188,28 +201,42 @@ def test_dict_from_file(): assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} # get dict from http + filename = 'http://path/of/your/file' with patch.object( HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - filename = 'http://path/of/your/file' mapping = mmengine.dict_from_file( filename, file_client_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file( filename, file_client_args={'prefix': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + + with patch.object( + _HTTPBackend, 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmengine.dict_from_file( + filename, 
backend_args={'backend': 'http'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} # get dict from petrel + filename = 's3://path/of/your/file' with patch.object( PetrelBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - filename = 's3://path/of/your/file' mapping = mmengine.dict_from_file( filename, file_client_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file( filename, file_client_args={'prefix': 's3'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + + with patch.object( + _PetrelBackend, 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmengine.dict_from_file( + filename, backend_args={'backend': 'petrel'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} From 619fd4081eac833ca7bc4676cf9dd8032f11ba5c Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 16 Sep 2022 19:23:33 +0800 Subject: [PATCH 07/18] rename method names --- mmengine/fileio/__init__.py | 22 +- mmengine/fileio/backends/base.py | 17 +- mmengine/fileio/backends/http_backend.py | 6 +- mmengine/fileio/backends/lmdb_backend.py | 4 +- mmengine/fileio/backends/local_backend.py | 16 +- mmengine/fileio/backends/memcached_backend.py | 2 +- mmengine/fileio/backends/petrel_backend.py | 34 +- mmengine/fileio/file_client.py | 701 +----------------- mmengine/fileio/handlers/__init__.py | 6 +- mmengine/fileio/handlers/registry_utils.py | 42 ++ mmengine/fileio/io.py | 70 +- mmengine/hooks/checkpoint_hook.py | 2 +- mmengine/runner/checkpoint.py | 4 +- .../test_backends/test_backend_utils.py | 10 +- .../test_base_storage_backend.py | 4 +- .../test_backends/test_http_backend.py | 4 +- .../test_backends/test_lmdb_backend.py | 4 +- .../test_backends/test_local_backend.py | 26 +- .../test_backends/test_memcached_backend.py | 4 +- 
.../test_backends/test_petrel_backend.py | 55 +- tests/test_fileio/test_fileclient.py | 10 +- tests/test_fileio/test_fileio.py | 20 +- tests/test_fileio/test_io.py | 24 +- 23 files changed, 206 insertions(+), 881 deletions(-) create mode 100644 mmengine/fileio/handlers/registry_utils.py diff --git a/mmengine/fileio/__init__.py b/mmengine/fileio/__init__.py index 45e68b3e5e..d8ebbc9047 100644 --- a/mmengine/fileio/__init__.py +++ b/mmengine/fileio/__init__.py @@ -3,13 +3,14 @@ from .file_client import (BaseStorageBackend, FileClient, HardDiskBackend, HTTPBackend, LmdbBackend, MemcachedBackend, PetrelBackend) -from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +from .handlers import (BaseFileHandler, JsonHandler, PickleHandler, + YamlHandler, register_handler) from .io import (copy_if_symlink_fails, copyfile, copyfile_from_local, copyfile_to_local, copytree, copytree_from_local, - copytree_to_local, dump, exists, generate_presigned_url, - get_bytes, get_file_backend, get_local_path, get_text, isdir, - isfile, join_path, list_dir_or_file, load, put_bytes, - put_text, register_handler, rmfile, rmtree) + copytree_to_local, dump, exists, generate_presigned_url, get, + get_file_backend, get_local_path, get_text, isdir, isfile, + join_path, list_dir_or_file, load, put, put_text, remove, + rmtree) from .parse import dict_from_file, list_from_file __all__ = [ @@ -17,10 +18,9 @@ 'LmdbBackend', 'HardDiskBackend', 'HTTPBackend', 'copy_if_symlink_fails', 'copyfile', 'copyfile_from_local', 'copyfile_to_local', 'copytree', 'copytree_from_local', 'copytree_to_local', 'exists', - 'generate_presigned_url', 'get_bytes', 'get_file_backend', - 'get_local_path', 'get_text', 'isdir', 'isfile', 'join_path', - 'list_dir_or_file', 'put_bytes', 'put_text', 'rmfile', 'rmtree', 'load', - 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler', - 'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file', - 'register_backend' + 'generate_presigned_url', 
'get', 'get_file_backend', 'get_local_path', + 'get_text', 'isdir', 'isfile', 'join_path', 'list_dir_or_file', 'put', + 'put_text', 'remove', 'rmtree', 'load', 'dump', 'register_handler', + 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', + 'list_from_file', 'dict_from_file', 'register_backend' ] diff --git a/mmengine/fileio/backends/base.py b/mmengine/fileio/backends/base.py index 6846f0f2d1..4ebabd3bb6 100644 --- a/mmengine/fileio/backends/base.py +++ b/mmengine/fileio/backends/base.py @@ -1,23 +1,34 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings from abc import ABCMeta, abstractmethod class BaseStorageBackend(metaclass=ABCMeta): """Abstract class of storage backends. - All backends need to implement two apis: :meth:`get_bytes()` and + All backends need to implement two apis: :meth:`get()` and :meth:`get_text()`. - - :meth:`get_bytes()` reads the file as a byte stream. + - :meth:`get()` reads the file as a byte stream. - :meth:`get_text()` reads the file as texts. """ + # a flag to indicate whether the backend can create a symlink for a file + # This attribute will be deprecated in future. + _allow_symlink = False + + @property + def allow_symlink(self): + warnings.warn('allow_symlink will be deprecated in future', + DeprecationWarning) + return self._allow_symlink + @property def name(self): return self.__class__.__name__ @abstractmethod - def get_bytes(self, filepath): + def get(self, filepath): pass @abstractmethod diff --git a/mmengine/fileio/backends/http_backend.py b/mmengine/fileio/backends/http_backend.py index 2f4fa3dc1f..8929af6339 100644 --- a/mmengine/fileio/backends/http_backend.py +++ b/mmengine/fileio/backends/http_backend.py @@ -12,7 +12,7 @@ class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" - def get_bytes(self, filepath: str) -> bytes: + def get(self, filepath: str) -> bytes: """ead bytes from a given ``filepath``. 
Args: @@ -23,7 +23,7 @@ def get_bytes(self, filepath: str) -> bytes: Examples: >>> backend = HTTPBackend() - >>> backend.get_bytes('http://path/of/file') + >>> backend.get('http://path/of/file') b'hello world' """ return urlopen(filepath).read() @@ -67,7 +67,7 @@ def get_local_path( """ try: f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get_bytes(filepath)) + f.write(self.get(filepath)) f.close() yield f.name finally: diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py index 49368920ff..e4888a0091 100644 --- a/mmengine/fileio/backends/lmdb_backend.py +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -41,7 +41,7 @@ def __init__(self, self.kwargs = kwargs self._client = None - def get_bytes(self, filepath: Union[str, Path]) -> bytes: + def get(self, filepath: Union[str, Path]) -> bytes: """Get values according to the filepath. Args: @@ -52,7 +52,7 @@ def get_bytes(self, filepath: Union[str, Path]) -> bytes: Examples: >>> backend = LmdbBackend('path/to/lmdb') - >>> backend.get_bytes('key') + >>> backend.get('key') b'hello world' """ if self._client is None: diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index 16716275d0..09dec4b263 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -13,7 +13,9 @@ class LocalBackend(BaseStorageBackend): """Raw hard disks storage backend.""" - def get_bytes(self, filepath: Union[str, Path]) -> bytes: + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: """Read bytes from a given ``filepath`` with 'rb' mode. 
Args: @@ -25,7 +27,7 @@ def get_bytes(self, filepath: Union[str, Path]) -> bytes: Examples: >>> backend = LocalBackend() >>> filepath = '/path/of/file' - >>> backend.get_bytes(filepath) + >>> backend.get(filepath) b'hello world' """ with open(filepath, 'rb') as f: @@ -55,11 +57,11 @@ def get_text(self, text = f.read() return text - def put_bytes(self, obj: bytes, filepath: Union[str, Path]) -> None: + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: """Write bytes to a given ``filepath`` with 'wb' mode. Note: - ``put_bytes`` will create a directory if the directory of + ``put`` will create a directory if the directory of ``filepath`` does not exist. Args: @@ -69,7 +71,7 @@ def put_bytes(self, obj: bytes, filepath: Union[str, Path]) -> None: Examples: >>> backend = LocalBackend() >>> filepath = '/path/of/file' - >>> backend.put_bytes(b'hello world', filepath) + >>> backend.put(b'hello world', filepath) """ mmengine.mkdir_or_exist(osp.dirname(filepath)) with open(filepath, 'wb') as f: @@ -393,7 +395,7 @@ def copytree_to_local( """ return self.copytree(src, dst) - def rmfile(self, filepath: Union[str, Path]) -> None: + def remove(self, filepath: Union[str, Path]) -> None: """Remove a file. 
Args: @@ -408,7 +410,7 @@ def rmfile(self, filepath: Union[str, Path]) -> None: Examples: >>> backend = LocalBackend() >>> filepath = '/path/of/file' - >>> backend.rmfile(filepath) + >>> backend.remove(filepath) """ os.remove(filepath) diff --git a/mmengine/fileio/backends/memcached_backend.py b/mmengine/fileio/backends/memcached_backend.py index 6e672468f5..3ef92b04e7 100644 --- a/mmengine/fileio/backends/memcached_backend.py +++ b/mmengine/fileio/backends/memcached_backend.py @@ -32,7 +32,7 @@ def __init__(self, server_list_cfg, client_cfg, sys_path=None): # mc.pyvector servers as a point which points to a memory cache self._mc_buffer = mc.pyvector() - def get_bytes(self, filepath: Union[str, Path]): + def get(self, filepath: Union[str, Path]): filepath = str(filepath) import mc self._client.Get(filepath, self._mc_buffer) diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py index c708771bf0..0328ad8b11 100644 --- a/mmengine/fileio/backends/petrel_backend.py +++ b/mmengine/fileio/backends/petrel_backend.py @@ -32,8 +32,8 @@ class PetrelBackend(BaseStorageBackend): >>> backend = PetrelBackend() >>> filepath1 = 'petrel://path/of/file' >>> filepath2 = 'cluster-name:petrel://path/of/file' - >>> backend.get_bytes(filepath1) # get data from default cluster - >>> client.get_bytes(filepath2) # get data from 'cluster-name' cluster + >>> backend.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster """ def __init__(self, @@ -79,7 +79,7 @@ def _replace_prefix(self, filepath: Union[str, Path]) -> str: filepath = str(filepath) return filepath.replace('petrel://', 's3://') - def get_bytes(self, filepath: Union[str, Path]) -> bytes: + def get(self, filepath: Union[str, Path]) -> bytes: """Read bytes from a given ``filepath`` with 'rb' mode. 
Args: @@ -91,7 +91,7 @@ def get_bytes(self, filepath: Union[str, Path]) -> bytes: Examples: >>> backend = PetrelBackend() >>> filepath = 'petrel://path/of/file' - >>> backend.get_bytes(filepath) + >>> backend.get(filepath) b'hello world' """ filepath = self._map_path(filepath) @@ -121,9 +121,9 @@ def get_text( >>> backend.get_text(filepath) 'hello world' """ - return str(self.get_bytes(filepath), encoding=encoding) + return str(self.get(filepath), encoding=encoding) - def put_bytes(self, obj: bytes, filepath: Union[str, Path]) -> None: + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: """Write bytes to a given ``filepath``. Args: @@ -133,7 +133,7 @@ def put_bytes(self, obj: bytes, filepath: Union[str, Path]) -> None: Examples: >>> backend = PetrelBackend() >>> filepath = 'petrel://path/of/file' - >>> backend.put_bytes(b'hello world', filepath) + >>> backend.put(b'hello world', filepath) """ filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -159,7 +159,7 @@ def put_text( >>> filepath = 'petrel://path/of/file' >>> backend.put_text('hello world', filepath) """ - self.put_bytes(bytes(obj, encoding=encoding), filepath) + self.put(bytes(obj, encoding=encoding), filepath) def exists(self, filepath: Union[str, Path]) -> bool: """Check whether a file path exists. 
@@ -305,7 +305,7 @@ def get_local_path( assert self.isfile(filepath) try: f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get_bytes(filepath)) + f.write(self.get(filepath)) f.close() yield f.name finally: @@ -354,7 +354,7 @@ def copyfile( if src == dst: raise SameFileError('src and dst should not be same') - self.put_bytes(self.get_bytes(src), dst) + self.put(self.get(src), dst) return dst def copytree( @@ -396,7 +396,7 @@ def copytree( for path in self.list_dir_or_file(src, list_dir=False, recursive=True): src_path = self.join_path(src, path) dst_path = self.join_path(dst, path) - self.put_bytes(self.get_bytes(src_path), dst_path) + self.put(self.get(src_path), dst_path) return dst @@ -435,7 +435,7 @@ def copyfile_from_local( dst = self.join_path(dst, osp.basename(src)) with open(src, 'rb') as f: - self.put_bytes(f.read(), dst) + self.put(f.read(), dst) return dst @@ -520,7 +520,7 @@ def copyfile_to_local( dst = dst / basename with open(dst, 'wb') as f: - f.write(self.get_bytes(src)) + f.write(self.get(src)) return dst @@ -552,11 +552,11 @@ def copytree_to_local( dst_path = osp.join(dst, path) mmengine.mkdir_or_exist(osp.dirname(dst_path)) with open(dst_path, 'wb') as f: - f.write(self.get_bytes(self.join_path(src, path))) + f.write(self.get(self.join_path(src, path))) return dst - def rmfile(self, filepath: Union[str, Path]) -> None: + def remove(self, filepath: Union[str, Path]) -> None: """Remove a file. 
Args: @@ -571,7 +571,7 @@ def rmfile(self, filepath: Union[str, Path]) -> None: Examples: >>> backend = PetrelBackend() >>> filepath = 'petrel://path/of/file' - >>> backend.rmfile(filepath) + >>> backend.remove(filepath) """ if not has_method(self._client, 'delete'): raise NotImplementedError( @@ -604,7 +604,7 @@ def rmtree(self, dir_path: Union[str, Path]) -> None: for path in self.list_dir_or_file( dir_path, list_dir=False, recursive=True): filepath = self.join_path(dir_path, path) - self.rmfile(filepath) + self.remove(filepath) def copy_if_symlink_fails( self, diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 96822ba234..ad759e5023 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -1,709 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect -import os -import os.path as osp -import re -import tempfile -from abc import ABCMeta, abstractmethod from contextlib import contextmanager from pathlib import Path from typing import Any, Generator, Iterator, Optional, Tuple, Union -from urllib.request import urlopen -from mmengine.utils import has_method, is_filepath, mkdir_or_exist +from mmengine.utils import is_filepath +from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, + LocalBackend, MemcachedBackend, PetrelBackend) -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: ``get()`` and ``get_text()``. - ``get()`` reads the file as a byte stream and ``get_text()`` reads the file - as texts. 
- """ - - # a flag to indicate whether the backend can create a symlink for a file - _allow_symlink = False +class HardDiskBackend(LocalBackend): @property def name(self): return self.__class__.__name__ - @property - def allow_symlink(self): - return self._allow_symlink - - @abstractmethod - def get(self, filepath): - pass - - @abstractmethod - def get_text(self, filepath): - pass - - -class PetrelBackend(BaseStorageBackend): - """Petrel storage backend (for internal usage). - - PetrelBackend supports reading and writing data to multiple clusters. - If the file path contains the cluster name, PetrelBackend will read data - from specified cluster or write data to it. Otherwise, PetrelBackend will - access the default cluster. - - Args: - path_mapping (dict, optional): Path mapping dict from local path to - Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in - ``filepath`` will be replaced by ``dst``. Default: None. - enable_mc (bool, optional): Whether to enable memcached support. - Default: True. - - Examples: - >>> filepath1 = 's3://path/of/file' - >>> filepath2 = 'cluster-name:s3://path/of/file' - >>> client = PetrelBackend() - >>> client.get(filepath1) # get data from default cluster - >>> client.get(filepath2) # get data from 'cluster-name' cluster - """ - - def __init__(self, - path_mapping: Optional[dict] = None, - enable_mc: bool = True): - try: - from petrel_client import client - except ImportError: - raise ImportError('Please install petrel_client to enable ' - 'PetrelBackend.') - - self._client = client.Client(enable_mc=enable_mc) - assert isinstance(path_mapping, dict) or path_mapping is None - self.path_mapping = path_mapping - - def _map_path(self, filepath: Union[str, Path]) -> str: - """Map ``filepath`` to a string path whose prefix will be replaced by - :attr:`self.path_mapping`. - - Args: - filepath (str): Path to be mapped. 
- """ - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v, 1) - return filepath - - def _format_path(self, filepath: str) -> str: - """Convert a ``filepath`` to standard format of petrel oss. - - If the ``filepath`` is concatenated by ``os.path.join``, in a Windows - environment, the ``filepath`` will be the format of - 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the - above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. - - Args: - filepath (str): Path to be formatted. - """ - return re.sub(r'\\+', '/', filepath) - - def get(self, filepath: Union[str, Path]) -> memoryview: - """Read data from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - memoryview: A memory view of expected bytes object to avoid - copying. The memoryview object can be converted to bytes by - ``value_buf.tobytes()``. - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - value = self._client.Get(filepath) - value_buf = memoryview(value) - return value_buf - - def get_text(self, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Default: 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - return str(self.get(filepath), encoding=encoding) - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Save data to a given ``filepath``. - - Args: - obj (bytes): Data to be saved. - filepath (str or Path): Path to write data. 
- """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - self._client.put(filepath, obj) - - def put_text(self, - obj: str, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> None: - """Save data to a given ``filepath``. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to encode the ``obj``. - Default: 'utf-8'. - """ - self.put(bytes(obj, encoding=encoding), filepath) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - """ - if not has_method(self._client, 'delete'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `delete` method, please use a higher version or dev' - ' branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - self._client.delete(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - if not (has_method(self._client, 'contains') - and has_method(self._client, 'isdir')): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `contains` and `isdir` methods, please use a higher' - 'version or dev branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - return self._client.contains(filepath) or self._client.isdir(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. 
- """ - if not has_method(self._client, 'isdir'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `isdir` method, please use a higher version or dev' - ' branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - return self._client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - if not has_method(self._client, 'contains'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `contains` method, please use a higher version or ' - 'dev branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - return self._client.contains(filepath) - - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: - """Concatenate all file paths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result after concatenation. - """ - filepath = self._format_path(self._map_path(filepath)) - if filepath.endswith('/'): - filepath = filepath[:-1] - formatted_paths = [filepath] - for path in filepaths: - formatted_paths.append(self._format_path(self._map_path(path))) - return '/'.join(formatted_paths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, - Path]) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` and return a temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str | Path): Download a file from ``filepath``. 
- - Examples: - >>> client = PetrelBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with client.get_local_path('s3://path/of/your/file') as path: - ... # do something here - - Yields: - Iterable[str]: Only yield one temporary path. - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - assert self.isfile(filepath) - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - Petrel has no concept of directories but it simulates the directory - hierarchy in the filesystem through public prefixes. In addition, - if the returned path ends with '/', it means the path is a public - prefix which is a logical directory. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - In addition, the returned path of directory will not contains the - suffix '/' which is consistent with other backends. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Default: True. - list_file (bool): List the path of files. Default: True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Default: None. - recursive (bool): If set to True, recursively scan the - directory. Default: False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. 
- """ - if not has_method(self._client, 'list'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `list` method, please use a higher version or dev' - ' branch instead.') - - dir_path = self._map_path(dir_path) - dir_path = self._format_path(dir_path) - if list_dir and suffix is not None: - raise TypeError( - '`list_dir` should be False when `suffix` is not None') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('`suffix` must be a string or tuple of strings') - - # Petrel's simulated directory hierarchy assumes that directory paths - # should end with `/` - if not dir_path.endswith('/'): - dir_path += '/' - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive): - for path in self._client.list(dir_path): - # the `self.isdir` is not used here to determine whether path - # is a directory, because `self.isdir` relies on - # `self._client.list` - if path.endswith('/'): # a directory path - next_dir_path = self.join_path(dir_path, path) - if list_dir: - # get the relative path and exclude the last - # character '/' - rel_dir = next_dir_path[len(root):-1] - yield rel_dir - if recursive: - yield from _list_dir_or_file(next_dir_path, list_dir, - list_file, suffix, - recursive) - else: # a file path - absolute_path = self.join_path(dir_path, path) - rel_path = absolute_path[len(root):] - if (suffix is None - or rel_path.endswith(suffix)) and list_file: - yield rel_path - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) - - -class MemcachedBackend(BaseStorageBackend): - """Memcached storage backend. - - Attributes: - server_list_cfg (str): Config file for memcached server list. - client_cfg (str): Config file for memcached client. - sys_path (str | None): Additional path to be appended to `sys.path`. - Default: None. 
- """ - - def __init__(self, server_list_cfg, client_cfg, sys_path=None): - if sys_path is not None: - import sys - sys.path.append(sys_path) - try: - import mc - except ImportError: - raise ImportError( - 'Please install memcached to enable MemcachedBackend.') - - self.server_list_cfg = server_list_cfg - self.client_cfg = client_cfg - self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, - self.client_cfg) - # mc.pyvector servers as a point which points to a memory cache - self._mc_buffer = mc.pyvector() - - def get(self, filepath): - filepath = str(filepath) - import mc - self._client.Get(filepath, self._mc_buffer) - value_buf = mc.ConvertBuffer(self._mc_buffer) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError - - -class LmdbBackend(BaseStorageBackend): - """Lmdb storage backend. - - Args: - db_path (str): Lmdb database path. - readonly (bool, optional): Lmdb environment parameter. If True, - disallow any write operations. Default: True. - lock (bool, optional): Lmdb environment parameter. If False, when - concurrent access occurs, do not lock the database. Default: False. - readahead (bool, optional): Lmdb environment parameter. If False, - disable the OS filesystem readahead mechanism, which may improve - random read performance when a database is larger than RAM. - Default: False. - - Attributes: - db_path (str): Lmdb database path. - """ - - def __init__(self, - db_path, - readonly=True, - lock=False, - readahead=False, - **kwargs): - try: - import lmdb # NOQA - except ImportError: - raise ImportError('Please install lmdb to enable LmdbBackend.') - - self.db_path = str(db_path) - self.readonly = readonly - self.lock = lock - self.readahead = readahead - self.kwargs = kwargs - self._client = None - - def get(self, filepath): - """Get values according to the filepath. - - Args: - filepath (str | obj:`Path`): Here, filepath is the lmdb key. 
- """ - if self._client is None: - self._client = self._get_client() - - with self._client.begin(write=False) as txn: - value_buf = txn.get(str(filepath).encode('utf-8')) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError - - def _get_client(self): - import lmdb - - return lmdb.open( - self.db_path, - readonly=self.readonly, - lock=self.lock, - readahead=self.readahead, - **self.kwargs) - - def __del__(self): - self._client.close() - - -class HardDiskBackend(BaseStorageBackend): - """Raw hard disks storage backend.""" - - _allow_symlink = True - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read data from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. - """ - with open(filepath, 'rb') as f: - value_buf = f.read() - return value_buf - - def get_text(self, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Default: 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - with open(filepath, encoding=encoding) as f: - value_buf = f.read() - return value_buf - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` will create a directory if the directory of ``filepath`` - does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, 'wb') as f: - f.write(obj) - - def put_text(self, - obj: str, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> None: - """Write data to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` will create a directory if the directory of - ``filepath`` does not exist. 
- - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to open the ``filepath``. - Default: 'utf-8'. - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, 'w', encoding=encoding) as f: - f.write(obj) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - """ - os.remove(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - return osp.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - return osp.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - return osp.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: - """Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of *filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. 
- """ - return osp.join(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, - Path]) -> Generator[Union[str, Path], None, None]: - """Only for unified API and do nothing.""" - yield filepath - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Default: True. - list_file (bool): List the path of files. Default: True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Default: None. - recursive (bool): If set to True, recursively scan the - directory. Default: False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. 
- """ - if list_dir and suffix is not None: - raise TypeError('`suffix` should be None when `list_dir` is True') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('`suffix` must be a string or tuple of strings') - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - if (suffix is None - or rel_path.endswith(suffix)) and list_file: - yield rel_path - elif osp.isdir(entry.path): - if list_dir: - rel_dir = osp.relpath(entry.path, root) - yield rel_dir - if recursive: - yield from _list_dir_or_file(entry.path, list_dir, - list_file, suffix, - recursive) - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) - - -class HTTPBackend(BaseStorageBackend): - """HTTP and HTTPS storage bachend.""" - - def get(self, filepath): - value_buf = urlopen(filepath).read() - return value_buf - - def get_text(self, filepath, encoding='utf-8'): - value_buf = urlopen(filepath).read() - return value_buf.decode(encoding) - - @contextmanager - def get_local_path( - self, filepath: str) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath``. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str): Download a file from ``filepath``. - - Examples: - >>> client = HTTPBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with client.get_local_path('http://path/of/your/file') as path: - ... 
# do something here - """ - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) - class FileClient: """A general file client to access files in different backends. @@ -1017,10 +328,6 @@ def remove(self, filepath: Union[str, Path]) -> None: """ self.client.remove(filepath) - get_bytes = get - put_bytes = put - rmfile = remove - def exists(self, filepath: Union[str, Path]) -> bool: """Check whether a file path exists. diff --git a/mmengine/fileio/handlers/__init__.py b/mmengine/fileio/handlers/__init__.py index aa24d91972..391a60c36b 100644 --- a/mmengine/fileio/handlers/__init__.py +++ b/mmengine/fileio/handlers/__init__.py @@ -2,6 +2,10 @@ from .base import BaseFileHandler from .json_handler import JsonHandler from .pickle_handler import PickleHandler +from .registry_utils import file_handlers, register_handler from .yaml_handler import YamlHandler -__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler'] +__all__ = [ + 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', + 'register_handler', 'file_handlers' +] diff --git a/mmengine/fileio/handlers/registry_utils.py b/mmengine/fileio/handlers/registry_utils.py new file mode 100644 index 0000000000..106fc881f2 --- /dev/null +++ b/mmengine/fileio/handlers/registry_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import is_list_of +from .base import BaseFileHandler +from .json_handler import JsonHandler +from .pickle_handler import PickleHandler +from .yaml_handler import YamlHandler + +file_handlers = { + 'json': JsonHandler(), + 'yaml': YamlHandler(), + 'yml': YamlHandler(), + 'pickle': PickleHandler(), + 'pkl': PickleHandler(), +} + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. 
+ file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError( + f'handler must be a child of BaseFileHandler, not {type(handler)}') + if isinstance(file_formats, str): + file_formats = [file_formats] + if not is_list_of(file_formats, str): + raise TypeError('file_formats must be a str or a list of str') + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index c75a74be9d..5c51a995ce 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -25,10 +25,10 @@ >>> # Initialize a file backend and call its methods >>> import mmengine.fileio as fileio >>> backend = fileio.get_file_backend(backend_args={'backend': 'petrel'}) - >>> backend.get_bytes('s3://path/of/your/file') + >>> backend.get('s3://path/of/your/file') >>> # Directory call unified I/O functions - >>> fileio.get_bytes('s3://path/of/your/file') + >>> fileio.get('s3://path/of/your/file') """ import json import warnings @@ -37,19 +37,15 @@ from pathlib import Path from typing import Generator, Iterator, Optional, Tuple, Union -from mmengine.utils import is_filepath, is_list_of, is_str +from mmengine.utils import is_filepath, is_str from .backends import backends, prefix_to_backends from .file_client import FileClient -from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +# file_handlers and register_handler had been moved to +# mmengine/fileio/handlers/registry_utis. Import them +# in this file to keep backward compatibility. 
+from .handlers import file_handlers, register_handler # noqa: F401 backend_instances: dict = {} -file_handlers = { - 'json': JsonHandler(), - 'yaml': YamlHandler(), - 'yml': YamlHandler(), - 'pickle': PickleHandler(), - 'pkl': PickleHandler() -} def _parse_uri_prefix(uri: Union[str, Path]) -> str: @@ -159,7 +155,7 @@ def get_file_backend( return backend -def get_bytes( +def get( filepath: Union[str, Path], backend_args: Optional[dict] = None, ) -> bytes: @@ -175,12 +171,12 @@ def get_bytes( Examples: >>> filepath = '/path/of/file' - >>> get_bytes(filepath) + >>> get(filepath) b'hello world' """ backend = get_file_backend( filepath, backend_args=backend_args, enable_singleton=True) - return backend.get_bytes(filepath) + return backend.get(filepath) def get_text( @@ -210,7 +206,7 @@ def get_text( return backend.get_text(filepath, encoding) -def put_bytes( +def put( obj: bytes, filepath: Union[str, Path], backend_args: Optional[dict] = None, @@ -218,7 +214,7 @@ def put_bytes( """Write bytes to a given ``filepath`` with 'wb' mode. Note: - ``put_bytes`` should create a directory if the directory of + ``put`` should create a directory if the directory of ``filepath`` does not exist. 
Args: @@ -229,11 +225,11 @@ def put_bytes( Examples: >>> filepath = '/path/of/file' - >>> put_bytes(b'hello world', filepath) + >>> put(b'hello world', filepath) """ backend = get_file_backend( filepath, backend_args=backend_args, enable_singleton=True) - backend.put_bytes(obj, filepath) + backend.put(obj, filepath) def put_text( @@ -628,7 +624,7 @@ def copytree_to_local( return backend.copytree_to_local(src, dst) -def rmfile( +def remove( filepath: Union[str, Path], backend_args: Optional[dict] = None, ) -> None: @@ -647,11 +643,11 @@ def rmfile( Examples: >>> filepath = '/path/of/file' - >>> rmfile(filepath) + >>> remove(filepath) """ backend = get_file_backend( filepath, backend_args=backend_args, enable_singleton=True) - backend.rmfile(filepath) + backend.remove(filepath) def rmtree( @@ -838,7 +834,7 @@ def load(file, with StringIO(file_backend.get_text(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) else: - with BytesIO(file_backend.get_bytes(file)) as f: + with BytesIO(file_backend.get(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) elif hasattr(file, 'read'): obj = handler.load_from_fileobj(file, **kwargs) @@ -918,36 +914,8 @@ def dump(obj, else: with BytesIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - file_backend.put_bytes(f.getvalue(), file) + file_backend.put(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: raise TypeError('"file" must be a filename str or a file-object') - - -def _register_handler(handler, file_formats): - """Register a handler for some file extensions. - - Args: - handler (:obj:`BaseFileHandler`): Handler to be registered. - file_formats (str or list[str]): File formats to be handled by this - handler. 
- """ - if not isinstance(handler, BaseFileHandler): - raise TypeError( - f'handler must be a child of BaseFileHandler, not {type(handler)}') - if isinstance(file_formats, str): - file_formats = [file_formats] - if not is_list_of(file_formats, str): - raise TypeError('file_formats must be a str or a list of str') - for ext in file_formats: - file_handlers[ext] = handler - - -def register_handler(file_formats, **kwargs): - - def wrap(cls): - _register_handler(cls(**kwargs), file_formats) - return cls - - return wrap diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 186c21ba33..ec07297acb 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -337,7 +337,7 @@ def _save_checkpoint(self, runner) -> None: ckpt_path = self.file_backend.join_path( self.out_dir, filename_tmpl.format(_step)) if self.file_backend.isfile(ckpt_path): - self.file_backend.rmfile(ckpt_path) + self.file_backend.remove(ckpt_path) else: break diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2d89572cbf..4080415e46 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -352,7 +352,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): """ file_backend = get_file_backend( filename, backend_args={'backend': backend}) - with io.BytesIO(file_backend.get_bytes(filename)) as buffer: + with io.BytesIO(file_backend.get(filename)) as buffer: checkpoint = torch.load(buffer, map_location=map_location) return checkpoint @@ -697,7 +697,7 @@ def save_checkpoint(checkpoint, with io.BytesIO() as f: torch.save(checkpoint, f) - file_backend.put_bytes(f.getvalue(), filename) + file_backend.put(f.getvalue(), filename) def find_latest_checkpoint(path: str) -> Optional[str]: diff --git a/tests/test_fileio/test_backends/test_backend_utils.py b/tests/test_fileio/test_backends/test_backend_utils.py index 310ecf7d47..7903f5574e 100644 --- 
a/tests/test_fileio/test_backends/test_backend_utils.py +++ b/tests/test_fileio/test_backends/test_backend_utils.py @@ -11,7 +11,7 @@ def test_register_backend(): @register_backend('example') class ExampleBackend(BaseStorageBackend): - def get_bytes(self, filepath): + def get(self, filepath): return filepath def get_text(self, filepath): @@ -22,7 +22,7 @@ def get_text(self, filepath): # 1.2 use it as a normal function class ExampleBackend1(BaseStorageBackend): - def get_bytes(self, filepath): + def get(self, filepath): return filepath def get_text(self, filepath): @@ -51,7 +51,7 @@ def test_backend(): class ExampleBackend2: - def get_bytes(self, filepath): + def get(self, filepath): return filepath def get_text(self, filepath): @@ -73,7 +73,7 @@ def get_text(self, filepath): # 5. test `prefixes` parameter class ExampleBackend3(BaseStorageBackend): - def get_bytes(self, filepath): + def get(self, filepath): return filepath def get_text(self, filepath): @@ -102,7 +102,7 @@ def get_text(self, filepath): class ExampleBackend4(BaseStorageBackend): - def get_bytes(self, filepath): + def get(self, filepath): return filepath def get_text(self, filepath): diff --git a/tests/test_fileio/test_backends/test_base_storage_backend.py b/tests/test_fileio/test_backends/test_base_storage_backend.py index 518c1da5d5..6aa608851d 100644 --- a/tests/test_fileio/test_backends/test_base_storage_backend.py +++ b/tests/test_fileio/test_backends/test_base_storage_backend.py @@ -16,12 +16,12 @@ class ExampleBackend(BaseStorageBackend): class ExampleBackend(BaseStorageBackend): - def get_bytes(self, filepath): + def get(self, filepath): return filepath def get_text(self, filepath): return filepath backend = ExampleBackend() - assert backend.get_bytes('test') == 'test' + assert backend.get('test') == 'test' assert backend.get_text('test') == 'test' diff --git a/tests/test_fileio/test_backends/test_http_backend.py b/tests/test_fileio/test_backends/test_http_backend.py index 
a92cc26245..c69394d147 100644 --- a/tests/test_fileio/test_backends/test_http_backend.py +++ b/tests/test_fileio/test_backends/test_http_backend.py @@ -33,9 +33,9 @@ def setUpClass(cls): cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' cls.text_path = cls.test_data_dir / 'filelist.txt' - def test_get_bytes(self): + def test_get(self): backend = HTTPBackend() - img_bytes = backend.get_bytes(self.img_url) + img_bytes = backend.get(self.img_url) img = imfrombytes(img_bytes) self.assertEqual(img.shape, self.img_shape) diff --git a/tests/test_fileio/test_backends/test_lmdb_backend.py b/tests/test_fileio/test_backends/test_lmdb_backend.py index 6847945874..dc2c7ded2b 100644 --- a/tests/test_fileio/test_backends/test_lmdb_backend.py +++ b/tests/test_fileio/test_backends/test_lmdb_backend.py @@ -23,9 +23,9 @@ def setUpClass(cls): cls.lmdb_path = cls.test_data_dir / 'demo.lmdb' @parameterized.expand([[Path], [str]]) - def test_get_bytes(self, path_type): + def test_get(self, path_type): backend = LmdbBackend(path_type(self.lmdb_path)) - img_bytes = backend.get_bytes('baboon') + img_bytes = backend.get('baboon') img = imfrombytes(img_bytes) self.assertEqual(img.shape, (120, 125, 3)) diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py index 41e37fe014..427ebf789a 100644 --- a/tests/test_fileio/test_backends/test_local_backend.py +++ b/tests/test_fileio/test_backends/test_local_backend.py @@ -71,9 +71,9 @@ def test_name(self): self.assertEqual(backend.name, 'LocalBackend') @parameterized.expand([[Path], [str]]) - def test_get_bytes(self, path_type): + def test_get(self, path_type): backend = LocalBackend() - img_bytes = backend.get_bytes(path_type(self.img_path)) + img_bytes = backend.get(path_type(self.img_path)) self.assertEqual(self.img_path.open('rb').read(), img_bytes) img = imfrombytes(img_bytes) self.assertEqual(img.shape, self.img_shape) @@ -85,19 +85,19 @@ def test_get_text(self, 
path_type): self.assertEqual(self.text_path.open('r').read(), text) @parameterized.expand([[Path], [str]]) - def test_put_bytes(self, path_type): + def test_put(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: filepath = Path(tmp_dir) / 'test.jpg' - backend.put_bytes(b'disk', path_type(filepath)) - self.assertEqual(backend.get_bytes(filepath), b'disk') + backend.put(b'disk', path_type(filepath)) + self.assertEqual(backend.get(filepath), b'disk') - # If the directory does not exist, put_bytes will create a + # If the directory does not exist, put will create a # directory first filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' - backend.put_bytes(b'disk', path_type(filepath)) - self.assertEqual(backend.get_bytes(filepath), b'disk') + backend.put(b'disk', path_type(filepath)) + self.assertEqual(backend.get(filepath), b'disk') @parameterized.expand([[Path], [str]]) def test_put_text(self, path_type): @@ -123,7 +123,7 @@ def test_exists(self, path_type): self.assertFalse(backend.exists(path_type(filepath))) backend.put_text('disk', filepath) self.assertTrue(backend.exists(path_type(filepath))) - backend.rmfile(filepath) + backend.remove(filepath) @parameterized.expand([[Path], [str]]) def test_isdir(self, path_type): @@ -294,26 +294,26 @@ def test_copytree_to_local(self, path_type): path_type(src), path_type(Path(tmp_dir) / 'dir2')) @parameterized.expand([[Path], [str]]) - def test_rmfile(self, path_type): + def test_remove(self, path_type): backend = LocalBackend() with tempfile.TemporaryDirectory() as tmp_dir: # filepath is a Path object filepath = Path(tmp_dir) / 'test.txt' backend.put_text('disk', filepath) self.assertTrue(backend.exists(filepath)) - backend.rmfile(path_type(filepath)) + backend.remove(path_type(filepath)) self.assertFalse(backend.exists(filepath)) # raise error if file does not exist with self.assertRaises(FileNotFoundError): filepath = Path(tmp_dir) / 'test1.txt' - 
backend.rmfile(path_type(filepath)) + backend.remove(path_type(filepath)) # can not remove directory filepath = Path(tmp_dir) / 'dir' filepath.mkdir() with self.assertRaises(IsADirectoryError): - backend.rmfile(path_type(filepath)) + backend.remove(path_type(filepath)) @parameterized.expand([[Path], [str]]) def test_rmtree(self, path_type): diff --git a/tests/test_fileio/test_backends/test_memcached_backend.py b/tests/test_fileio/test_backends/test_memcached_backend.py index 30cf9793ec..d320fcb16b 100644 --- a/tests/test_fileio/test_backends/test_memcached_backend.py +++ b/tests/test_fileio/test_backends/test_memcached_backend.py @@ -43,9 +43,9 @@ def setUpClass(cls): @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) @patch('mc.ConvertBuffer', lambda x: x.content) - def test_get_bytes(self, path_type): + def test_get(self, path_type): backend = MemcachedBackend(**self.mc_cfg) - img_bytes = backend.get_bytes(path_type(self.img_path)) + img_bytes = backend.get(path_type(self.img_path)) self.assertEqual(self.img_path.open('rb').read(), img_bytes) img = imfrombytes(img_bytes) self.assertEqual(img.shape, self.img_shape) diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py index 197dd45086..e260af6f2a 100644 --- a/tests/test_fileio/test_backends/test_petrel_backend.py +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -154,13 +154,12 @@ def test_join_path(self): backend.join_path(self.petrel_dir, 'dir', 'file'), f'{self.petrel_dir}/dir/file') - def test_get_bytes(self): + def test_get(self): backend = PetrelBackend() with patch.object( backend._client, 'Get', return_value=b'petrel') as patched_get: - self.assertEqual( - backend.get_bytes(self.petrel_path), b'petrel') + self.assertEqual(backend.get(self.petrel_path), b'petrel') patched_get.assert_called_once_with(self.expected_path) def test_get_text(self): @@ -171,10 +170,10 @@ def 
test_get_text(self): self.assertEqual(backend.get_text(self.petrel_path), 'petrel') patched_get.assert_called_once_with(self.expected_path) - def test_put_bytes(self): + def test_put(self): backend = PetrelBackend() with patch.object(backend._client, 'put') as patched_put: - backend.put_bytes(b'petrel', self.petrel_path) + backend.put(b'petrel', self.petrel_path) patched_put.assert_called_once_with(self.expected_path, b'petrel') @@ -287,24 +286,24 @@ def test_copyfile(self): def test_copytree(self): backend = PetrelBackend() - put_bytes_inputs = [] - get_bytes_inputs = [] + put_inputs = [] + get_inputs = [] - def put_bytes(obj, filepath): - put_bytes_inputs.append((obj, filepath)) + def put(obj, filepath): + put_inputs.append((obj, filepath)) - def get_bytes(filepath): - get_bytes_inputs.append(filepath) + def get(filepath): + get_inputs.append(filepath) with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'put_bytes', side_effect=put_bytes),\ - patch.object(backend, 'get_bytes', side_effect=get_bytes),\ + patch.object(backend, 'put', side_effect=put),\ + patch.object(backend, 'get', side_effect=get),\ patch.object(backend, 'exists', return_value=False): dst = f'{tmp_dir}/dir' self.assertEqual(backend.copytree(tmp_dir, dst), dst) - self.assertEqual(len(put_bytes_inputs), 5) - self.assertEqual(len(get_bytes_inputs), 5) + self.assertEqual(len(put_inputs), 5) + self.assertEqual(len(get_inputs), 5) # dst should not exist with patch.object(backend, 'exists', return_value=True): @@ -386,32 +385,32 @@ def test_copytree_to_local(self): backend = PetrelBackend() inputs = [] - def get_bytes(filepath): + def get(filepath): inputs.append(filepath) return b'petrel' with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'get_bytes', side_effect=get_bytes): + patch.object(backend, 'get', side_effect=get): dst = f'{tmp_dir}/dir' backend.copytree_to_local(tmp_dir, dst) self.assertEqual(len(inputs), 5) - def test_rmfile(self): + def 
test_remove(self): backend = PetrelBackend() self.assertTrue(has_method(backend._client, 'delete')) # raise Exception if `delete` is not implemented with delete_and_reset_method(backend._client, 'delete'): self.assertFalse(has_method(backend._client, 'delete')) with self.assertRaises(NotImplementedError): - backend.rmfile(self.petrel_path) + backend.remove(self.petrel_path) with patch.object(backend._client, 'delete') as patched_delete, \ patch.object(backend._client, 'isdir', return_value=False) \ as patched_isdir, \ patch.object(backend._client, 'contains', return_value=True) \ as patched_contains: - backend.rmfile(self.petrel_path) + backend.remove(self.petrel_path) patched_delete.assert_called_once_with(self.expected_path) patched_isdir.assert_called_once_with(self.expected_path) patched_contains.assert_called_once_with(self.expected_path) @@ -420,11 +419,11 @@ def test_rmtree(self): backend = PetrelBackend() inputs = [] - def rmfile(filepath): + def remove(filepath): inputs.append(filepath) with build_temporary_directory() as tmp_dir,\ - patch.object(backend, 'rmfile', side_effect=rmfile): + patch.object(backend, 'remove', side_effect=remove): backend.rmtree(tmp_dir) self.assertEqual(len(inputs), 5) @@ -586,20 +585,20 @@ def setUp(self): self.assertTrue(backend.isfile(text4_path)) self.assertTrue(backend.isfile(img_path)) - def test_get_bytes(self): + def test_get(self): backend = PetrelBackend() img_path = f'{self.petrel_dir}/dir2/img.jpg' - self.assertEqual(backend.get_bytes(img_path), b'img') + self.assertEqual(backend.get(img_path), b'img') def test_get_text(self): backend = PetrelBackend() text_path = f'{self.petrel_dir}/text1.txt' self.assertEqual(backend.get_text(text_path), 'text1') - def test_put_bytes(self): + def test_put(self): backend = PetrelBackend() img_path = f'{self.petrel_dir}/img.jpg' - backend.put_bytes(b'img', img_path) + backend.put(b'img', img_path) def test_put_text(self): backend = PetrelBackend() @@ -728,11 +727,11 @@ def 
test_copytree_to_local(self): self.assertTrue(osp.exists(Path(tmp_dir) / 'text1.txt')) self.assertTrue(osp.exists(Path(tmp_dir) / 'dir2' / 'img.jpg')) - def test_rmfile(self): + def test_remove(self): backend = PetrelBackend() img_path = f'{self.petrel_dir}/dir2/img.jpg' self.assertTrue(backend.isfile(img_path)) - backend.rmfile(img_path) + backend.remove(img_path) self.assertFalse(backend.exists(img_path)) def test_rmtree(self): diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py index 629bd7f622..3620ddb015 100644 --- a/tests/test_fileio/test_fileclient.py +++ b/tests/test_fileio/test_fileclient.py @@ -12,7 +12,7 @@ import numpy as np import pytest -from mmengine import BaseStorageBackend, FileClient +from mmengine.fileio import BaseStorageBackend, FileClient from mmengine.utils import has_method sys.modules['ceph'] = MagicMock() @@ -354,9 +354,15 @@ def test_petrel_backend(self, backend, prefix): petrel_backend.remove(petrel_path) with patch.object(petrel_backend.client._client, - 'delete') as mock_delete: + 'delete') as mock_delete, \ + patch.object(petrel_backend.client._client, + 'isdir', return_value=False) as mock_isdir, \ + patch.object(petrel_backend.client._client, + 'contains', return_value=True) as mock_contains: petrel_backend.remove(petrel_path) mock_delete.assert_called_once_with(petrel_path) + mock_isdir.assert_called_once_with(petrel_path) + mock_contains.assert_called_once_with(petrel_path) # test `exists` assert has_method(petrel_backend.client._client, 'contains') diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index e586fca741..33a0956fed 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -8,9 +8,7 @@ import pytest import mmengine -from mmengine.fileio.backends import HTTPBackend as _HTTPBackend -from mmengine.fileio.backends import PetrelBackend as _PetrelBackend -from mmengine.fileio.file_client import HTTPBackend, PetrelBackend +from 
mmengine.fileio import HTTPBackend, PetrelBackend sys.modules['petrel_client'] = MagicMock() sys.modules['petrel_client.client'] = MagicMock() @@ -32,9 +30,8 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): os.remove(tmp_filename) # load/dump with filename from petrel - method = 'put_bytes' if 'b' in mode else 'put_text' - with patch.object( - _PetrelBackend, method, return_value=None) as mock_method: + method = 'put' if 'b' in mode else 'put_text' + with patch.object(PetrelBackend, method, return_value=None) as mock_method: filename = 's3://path/of/your/file' mmengine.dump(test_obj, filename, file_format=file_format) mock_method.assert_called() @@ -164,8 +161,6 @@ def test_list_from_file(): filename, file_client_args={'prefix': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - with patch.object( - _HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file( @@ -182,9 +177,6 @@ def test_list_from_file(): filelist = mmengine.list_from_file( filename, file_client_args={'prefix': 's3'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - - with patch.object( - _PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file( @@ -211,9 +203,6 @@ def test_dict_from_file(): filename, file_client_args={'prefix': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - with patch.object( - _HTTPBackend, 'get_text', - return_value='1 cat\n2 dog cow\n3 panda'): mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file( @@ -232,9 +221,6 @@ def test_dict_from_file(): filename, file_client_args={'prefix': 's3'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - with patch.object( 
- _PetrelBackend, 'get_text', - return_value='1 cat\n2 dog cow\n3 panda'): mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file( diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py index 88f193fd5f..67883d87b3 100644 --- a/tests/test_fileio/test_io.py +++ b/tests/test_fileio/test_io.py @@ -169,10 +169,10 @@ def test_get_file_backend(): assert backend7 is not backend6 -def test_get_bytes(): +def test_get(): # test HardDiskBackend filepath = Path(img_path) - img_bytes = fileio.get_bytes(filepath) + img_bytes = fileio.get(filepath) assert filepath.open('rb').read() == img_bytes @@ -183,18 +183,18 @@ def test_get_text(): assert filepath.open('r').read() == text -def test_put_bytes(): +def test_put(): # test HardDiskBackend with tempfile.TemporaryDirectory() as tmp_dir: filepath = Path(tmp_dir) / 'img.png' - fileio.put_bytes(b'disk', filepath) - assert fileio.get_bytes(filepath) == b'disk' + fileio.put(b'disk', filepath) + assert fileio.get(filepath) == b'disk' - # If the directory does not exist, put_bytes will create a + # If the directory does not exist, put will create a # directory first filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' - fileio.put_bytes(b'disk', filepath) - assert fileio.get_bytes(filepath) == b'disk' + fileio.put(b'disk', filepath) + assert fileio.get(filepath) == b'disk' def test_put_text(): @@ -362,26 +362,26 @@ def test_copytree_to_local(): fileio.copytree(src, Path(tmp_dir) / 'dir2') -def test_rmfile(): +def test_remove(): # test HardDiskBackend with tempfile.TemporaryDirectory() as tmp_dir: # filepath is a Path object filepath = Path(tmp_dir) / 'test.txt' fileio.put_text('disk', filepath) assert fileio.exists(filepath) - fileio.rmfile(filepath) + fileio.remove(filepath) assert not fileio.exists(filepath) # raise error if file does not exist with pytest.raises(FileNotFoundError): filepath = Path(tmp_dir) / 'test1.txt' - 
fileio.rmfile(filepath) + fileio.remove(filepath) # can not remove directory filepath = Path(tmp_dir) / 'dir' filepath.mkdir() with pytest.raises(IsADirectoryError): - fileio.rmfile(filepath) + fileio.remove(filepath) def test_rmtree(): From 26d3bdc2d09637b8550af3a0b45cbefdc132ade4 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 17 Sep 2022 00:23:10 +0800 Subject: [PATCH 08/18] refine --- docs/en/api/fileio.rst | 37 +++++++++++++++++++++-- docs/zh_cn/api/fileio.rst | 37 +++++++++++++++++++++-- mmengine/fileio/__init__.py | 25 +++++++-------- mmengine/fileio/backends/http_backend.py | 2 +- mmengine/fileio/backends/local_backend.py | 2 +- mmengine/fileio/file_client.py | 18 ++++++++++- tests/test_hooks/test_logger_hook.py | 16 ++++++++++ 7 files changed, 118 insertions(+), 19 deletions(-) diff --git a/docs/en/api/fileio.rst b/docs/en/api/fileio.rst index bf27fc0e62..1b8c14b42a 100644 --- a/docs/en/api/fileio.rst +++ b/docs/en/api/fileio.rst @@ -11,7 +11,7 @@ mmengine.fileio .. currentmodule:: mmengine.fileio -File Client +File Backend ---------------- .. autosummary:: @@ -22,11 +22,18 @@ File Client BaseStorageBackend FileClient HardDiskBackend + LocalBackend HTTPBackend LmdbBackend MemcachedBackend PetrelBackend +.. autosummary:: + :toctree: generated + :nosignatures: + + register_backend + File Handler ---------------- @@ -40,6 +47,12 @@ File Handler PickleHandler YamlHandler +.. 
autosummary:: + :toctree: generated + :nosignatures: + + register_handler + File IO ---------------- @@ -49,7 +62,27 @@ File IO dump load - register_handler + copy_if_symlink_fails + copyfile + copyfile_from_local + copyfile_to_local + copytree + copytree_from_local + copytree_to_local + exists + generate_presigned_url + get + get_file_backend + get_local_path + get_text + isdir + isfile + join_path + list_dir_or_file + put + put_text + remove + rmtree Parse File ---------------- diff --git a/docs/zh_cn/api/fileio.rst b/docs/zh_cn/api/fileio.rst index bf27fc0e62..1b8c14b42a 100644 --- a/docs/zh_cn/api/fileio.rst +++ b/docs/zh_cn/api/fileio.rst @@ -11,7 +11,7 @@ mmengine.fileio .. currentmodule:: mmengine.fileio -File Client +File Backend ---------------- .. autosummary:: @@ -22,11 +22,18 @@ File Client BaseStorageBackend FileClient HardDiskBackend + LocalBackend HTTPBackend LmdbBackend MemcachedBackend PetrelBackend +.. autosummary:: + :toctree: generated + :nosignatures: + + register_backend + File Handler ---------------- @@ -40,6 +47,12 @@ File Handler PickleHandler YamlHandler +.. autosummary:: + :toctree: generated + :nosignatures: + + register_handler + File IO ---------------- @@ -49,7 +62,27 @@ File IO dump load - register_handler + copy_if_symlink_fails + copyfile + copyfile_from_local + copyfile_to_local + copytree + copytree_from_local + copytree_to_local + exists + generate_presigned_url + get + get_file_backend + get_local_path + get_text + isdir + isfile + join_path + list_dir_or_file + put + put_text + remove + rmtree Parse File ---------------- diff --git a/mmengine/fileio/__init__.py b/mmengine/fileio/__init__.py index d8ebbc9047..81adcd4c02 100644 --- a/mmengine/fileio/__init__.py +++ b/mmengine/fileio/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .backends import register_backend -from .file_client import (BaseStorageBackend, FileClient, HardDiskBackend, - HTTPBackend, LmdbBackend, MemcachedBackend, - PetrelBackend) +from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, + LocalBackend, MemcachedBackend, PetrelBackend, + register_backend) +from .file_client import FileClient, HardDiskBackend from .handlers import (BaseFileHandler, JsonHandler, PickleHandler, YamlHandler, register_handler) from .io import (copy_if_symlink_fails, copyfile, copyfile_from_local, @@ -15,12 +15,13 @@ __all__ = [ 'BaseStorageBackend', 'FileClient', 'PetrelBackend', 'MemcachedBackend', - 'LmdbBackend', 'HardDiskBackend', 'HTTPBackend', 'copy_if_symlink_fails', - 'copyfile', 'copyfile_from_local', 'copyfile_to_local', 'copytree', - 'copytree_from_local', 'copytree_to_local', 'exists', - 'generate_presigned_url', 'get', 'get_file_backend', 'get_local_path', - 'get_text', 'isdir', 'isfile', 'join_path', 'list_dir_or_file', 'put', - 'put_text', 'remove', 'rmtree', 'load', 'dump', 'register_handler', - 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', - 'list_from_file', 'dict_from_file', 'register_backend' + 'LmdbBackend', 'HardDiskBackend', 'LocalBackend', 'HTTPBackend', + 'copy_if_symlink_fails', 'copyfile', 'copyfile_from_local', + 'copyfile_to_local', 'copytree', 'copytree_from_local', + 'copytree_to_local', 'exists', 'generate_presigned_url', 'get', + 'get_file_backend', 'get_local_path', 'get_text', 'isdir', 'isfile', + 'join_path', 'list_dir_or_file', 'put', 'put_text', 'remove', 'rmtree', + 'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler', + 'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file', + 'register_backend' ] diff --git a/mmengine/fileio/backends/http_backend.py b/mmengine/fileio/backends/http_backend.py index 8929af6339..393e66f206 100644 --- a/mmengine/fileio/backends/http_backend.py +++ b/mmengine/fileio/backends/http_backend.py @@ -13,7 +13,7 @@ 
class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" def get(self, filepath: str) -> bytes: - """ead bytes from a given ``filepath``. + """Read bytes from a given ``filepath``. Args: filepath (str): Path to read data. diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index 09dec4b263..8b67e55a66 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -11,7 +11,7 @@ class LocalBackend(BaseStorageBackend): - """Raw hard disks storage backend.""" + """Raw local storage backend.""" _allow_symlink = True diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index ad759e5023..4bcda26c9f 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +import warnings from contextlib import contextmanager from pathlib import Path from typing import Any, Generator, Iterator, Optional, Tuple, Union @@ -10,6 +11,12 @@ class HardDiskBackend(LocalBackend): + """Raw hard disks storage backend.""" + + def __init__(self) -> None: + warnings.warn( + '"HardDiskBackend" is the alias of "LocalBackend" ' + 'and the former will be deprecated in future.', DeprecationWarning) @property def name(self): @@ -30,9 +37,13 @@ class FileClient: avoid repeated object creation. If the arguments are the same, the same object will be returned. + Warning: + `FileClient` will be deprecated in future. Please use io functions + in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io + Args: backend (str, optional): The storage backend type. Options are "disk", - "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. + "memcached", "lmdb", "http" and "petrel". Default: None. prefix (str, optional): The prefix of the registered storage backend. Options are "s3", "http", "https". Default: None. 
@@ -71,6 +82,11 @@ class FileClient: client: Any def __new__(cls, backend=None, prefix=None, **kwargs): + warnings.warn( + '"FileClient" will be deprecated in future. Please use io ' + 'functions in ' + 'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io' + ) if backend is None and prefix is None: backend = 'disk' if backend is not None and backend not in cls._backends: diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index 230355ccb0..9c531d0702 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -26,6 +26,22 @@ def test_init(self): with pytest.raises(ValueError): LoggerHook(file_client_args=dict(enable_mc=True)) + # test `file_client_args` and `backend_args` + with pytest.warns( + DeprecationWarning, + match='"file_client_args" will be deprecated in future'): + logger_hook = LoggerHook( + out_dir='tmp.txt', file_client_args={'backend': 'disk'}) + + with pytest.raises( + ValueError, + match='"file_client_args and "backend_args" cannot be both set' + ): + logger_hook = LoggerHook( + out_dir='tmp.txt', + file_client_args={'backend': 'disk'}, + backend_args={'backend': 'local'}) + def test_before_run(self): runner = MagicMock() runner.iter = 10 From 0972fb84b6a0e9efa5d11c2ecb13722587ac03fd Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 18 Sep 2022 21:04:27 +0800 Subject: [PATCH 09/18] refine docstring --- mmengine/fileio/backends/http_backend.py | 6 +++++- mmengine/fileio/backends/petrel_backend.py | 3 ++- mmengine/fileio/file_client.py | 4 ++-- mmengine/fileio/io.py | 12 ++++++------ .../test_fileio/test_backends/test_petrel_backend.py | 3 --- tests/test_fileio/test_io.py | 4 ---- 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/mmengine/fileio/backends/http_backend.py b/mmengine/fileio/backends/http_backend.py index 393e66f206..b3e65bbdbb 100644 --- a/mmengine/fileio/backends/http_backend.py +++ b/mmengine/fileio/backends/http_backend.py @@ -49,7 
+49,8 @@ def get_text(self, filepath, encoding='utf-8') -> str: @contextmanager def get_local_path( self, filepath: str) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath``. + """Download a file from ``filepath`` to a local temporary directory, + and return the temporary path. ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It can be called with ``with`` statement, and when exists from the @@ -58,6 +59,9 @@ def get_local_path( Args: filepath (str): Download a file from ``filepath``. + Yields: + Iterable[str]: Only yield one temporary path. + Examples: >>> backend = HTTPBackend() >>> # After existing from the ``with`` clause, diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py index 0328ad8b11..dacd76783d 100644 --- a/mmengine/fileio/backends/petrel_backend.py +++ b/mmengine/fileio/backends/petrel_backend.py @@ -282,7 +282,8 @@ def get_local_path( self, filepath: Union[str, Path], ) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` and return a temporary path. + """Download a file from ``filepath`` to a local temporary directory, + and return the temporary path. ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It can be called with ``with`` statement, and when exists from the diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 4bcda26c9f..51cbde299a 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -85,8 +85,8 @@ def __new__(cls, backend=None, prefix=None, **kwargs): warnings.warn( '"FileClient" will be deprecated in future. 
Please use io ' 'functions in ' - 'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io' - ) + 'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io', # noqa: E501 + DeprecationWarning) if backend is None and prefix is None: backend = 'disk' if backend is not None and backend not in cls._backends: diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index 5c51a995ce..f821b0135d 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -5,7 +5,7 @@ MMEngine currently supports five file backends: -- HardDiskBackend +- LocalBackend - PetrelBackend - HTTPBackend - LmdbBackend @@ -69,7 +69,7 @@ def _parse_uri_prefix(uri: Union[str, Path]) -> str: assert is_filepath(uri) uri = str(uri) # if uri does not contains '://', the uri will be handled by - # HardDiskBackend by default + # LocalBackend by default if '://' not in uri: return '' else: @@ -484,7 +484,7 @@ def copyfile_from_local( """Copy a local file src to dst and return the destination file. Note: - If the backend is the instance of HardDiskBackend, it does the same + If the backend is the instance of LocalBackend, it does the same thing with :func:`copyfile`. Args: @@ -525,7 +525,7 @@ def copytree_from_local( named dst and return the destination directory. Note: - If the backend is the instance of HardDiskBackend, it does the same + If the backend is the instance of LocalBackend, it does the same thing with :func:`copytree`. Args: @@ -560,7 +560,7 @@ def copyfile_to_local( exists, it will be replaced. Note: - If the backend is the instance of HardDiskBackend, it does the same + If the backend is the instance of LocalBackend, it does the same thing with :func:`copyfile`. Args: @@ -601,7 +601,7 @@ def copytree_to_local( directory named dst and return the destination directory. Note: - If the backend is the instance of HardDiskBackend, it does the same + If the backend is the instance of LocalBackend, it does the same thing with :func:`copytree`. 
Args: diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py index e260af6f2a..8628e82845 100644 --- a/tests/test_fileio/test_backends/test_petrel_backend.py +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -855,6 +855,3 @@ def test_list_dir_or_file(self): ('dir2', 'dir3', 'text4.txt')), '/'.join( ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' }) - - def test_generate_presigned_url(self): - pass diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py index 67883d87b3..e6145cf0cc 100644 --- a/tests/test_fileio/test_io.py +++ b/tests/test_fileio/test_io.py @@ -530,7 +530,3 @@ def test_list_dir_or_file(): osp.join('dir2', 'dir3', 'text4.txt'), osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' } - - -def test_generate_presigned_url(): - pass From c6d2234cef3509ccf3a24d858af30d0e048eafc9 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 18 Sep 2022 21:29:04 +0800 Subject: [PATCH 10/18] fix ut in windows --- mmengine/fileio/backends/local_backend.py | 6 +++ .../test_backends/test_petrel_backend.py | 1 + tests/test_fileio/test_io.py | 38 +++++++++---------- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index 8b67e55a66..ad15e4849b 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -412,6 +412,12 @@ def remove(self, filepath: Union[str, Path]) -> None: >>> filepath = '/path/of/file' >>> backend.remove(filepath) """ + if not self.exists(filepath): + raise FileNotFoundError(f'filepath {filepath} does not exist') + + if self.isdir(filepath): + raise IsADirectoryError('filepath should be a file') + os.remove(filepath) def rmtree(self, dir_path: Union[str, Path]) -> None: diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py index 
8628e82845..63c9284a92 100644 --- a/tests/test_fileio/test_backends/test_petrel_backend.py +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -299,6 +299,7 @@ def get(filepath): patch.object(backend, 'put', side_effect=put),\ patch.object(backend, 'get', side_effect=get),\ patch.object(backend, 'exists', return_value=False): + tmp_dir = tmp_dir.replace('\\', '/') dst = f'{tmp_dir}/dir' self.assertEqual(backend.copytree(tmp_dir, dst), dst) diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py index e6145cf0cc..1e8698cc68 100644 --- a/tests/test_fileio/test_io.py +++ b/tests/test_fileio/test_io.py @@ -170,21 +170,21 @@ def test_get_file_backend(): def test_get(): - # test HardDiskBackend + # test LocalBackend filepath = Path(img_path) img_bytes = fileio.get(filepath) assert filepath.open('rb').read() == img_bytes def test_get_text(): - # test HardDiskBackend + # test LocalBackend filepath = Path(text_path) text = fileio.get_text(filepath) assert filepath.open('r').read() == text def test_put(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: filepath = Path(tmp_dir) / 'img.png' fileio.put(b'disk', filepath) @@ -198,7 +198,7 @@ def test_put(): def test_put_text(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: filepath = Path(tmp_dir) / 'text.txt' fileio.put_text('text', filepath) @@ -212,7 +212,7 @@ def test_put_text(): def test_exists(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: assert fileio.exists(tmp_dir) filepath = Path(tmp_dir) / 'test.txt' @@ -222,7 +222,7 @@ def test_exists(): def test_isdir(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: assert fileio.isdir(tmp_dir) filepath = Path(tmp_dir) / 'test.txt' @@ -231,7 +231,7 @@ def test_isdir(): def test_isfile(): - # test HardDiskBackend + # test LocalBackend with 
tempfile.TemporaryDirectory() as tmp_dir: assert not fileio.isfile(tmp_dir) filepath = Path(tmp_dir) / 'test.txt' @@ -240,7 +240,7 @@ def test_isfile(): def test_join_path(): - # test HardDiskBackend + # test LocalBackend filepath = fileio.join_path(test_data_dir, 'file') expected = osp.join(test_data_dir, 'file') assert filepath == expected @@ -251,13 +251,13 @@ def test_join_path(): def test_get_local_path(): - # test HardDiskBackend + # test LocalBackend with fileio.get_local_path(text_path) as filepath: assert str(text_path) == filepath def test_copyfile(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: src = Path(tmp_dir) / 'test.txt' fileio.put_text('disk', src) @@ -277,7 +277,7 @@ def test_copyfile(): def test_copytree(): - # test HardDiskBackend + # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects src = Path(tmp_dir) / 'dir1' @@ -293,7 +293,7 @@ def test_copytree(): def test_copyfile_from_local(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: src = Path(tmp_dir) / 'test.txt' fileio.put_text('disk', src) @@ -312,7 +312,7 @@ def test_copyfile_from_local(): def test_copytree_from_local(): - # test HardDiskBackend + # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects src = Path(tmp_dir) / 'dir1' @@ -328,7 +328,7 @@ def test_copytree_from_local(): def test_copyfile_to_local(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: src = Path(tmp_dir) / 'test.txt' fileio.put_text('disk', src) @@ -347,7 +347,7 @@ def test_copyfile_to_local(): def test_copytree_to_local(): - # test HardDiskBackend + # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects src = Path(tmp_dir) / 'dir1' @@ -363,7 +363,7 @@ def test_copytree_to_local(): def test_remove(): - # test HardDiskBackend + # test LocalBackend with 
tempfile.TemporaryDirectory() as tmp_dir: # filepath is a Path object filepath = Path(tmp_dir) / 'test.txt' @@ -385,7 +385,7 @@ def test_remove(): def test_rmtree(): - # test HardDiskBackend + # test LocalBackend with build_temporary_directory() as tmp_dir: # src and dst are Path objects dir_path = Path(tmp_dir) / 'dir1' @@ -400,7 +400,7 @@ def test_rmtree(): def test_copy_if_symlink_fails(): - # test HardDiskBackend + # test LocalBackend with tempfile.TemporaryDirectory() as tmp_dir: # create a symlink for a file src = Path(tmp_dir) / 'test.txt' @@ -445,7 +445,7 @@ def symlink(src, dst): def test_list_dir_or_file(): - # test HardDiskBackend + # test LocalBackend with build_temporary_directory() as tmp_dir: # list directories and files assert set(fileio.list_dir_or_file(tmp_dir)) == { From 2964c3c2ecb826f8cb42d9a382f7f55bbbd0b7e4 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 18 Sep 2022 22:32:00 +0800 Subject: [PATCH 11/18] update ut --- tests/test_fileio/test_backends/utils.py | 9 --- tests/test_hooks/test_checkpoint_hook.py | 72 ++++++++++++++++-------- 2 files changed, 50 insertions(+), 31 deletions(-) delete mode 100644 tests/test_fileio/test_backends/utils.py diff --git a/tests/test_fileio/test_backends/utils.py b/tests/test_fileio/test_backends/utils.py deleted file mode 100644 index 5f4d8f458a..0000000000 --- a/tests/test_fileio/test_backends/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 -import numpy as np - - -def imfrombytes(content): - img_np = np.frombuffer(content, np.uint8) - img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) - return img diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index d6bf79d4a5..8b62739d1e 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -1,42 +1,50 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os import os.path as osp -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest +from mmengine.fileio import FileClient, LocalBackend from mmengine.hooks import CheckpointHook from mmengine.logging import MessageHub -class MockPetrel: - - _allow_symlink = False - - def __init__(self): - pass - - @property - def name(self): - return self.__class__.__name__ - - @property - def allow_symlink(self): - return self._allow_symlink - - -prefix_to_backends = {'s3': MockPetrel} - - class TestCheckpointHook: - @patch('mmengine.fileio.file_client.FileClient._prefix_to_backends', - prefix_to_backends) + def test_init(self, tmp_path): + # Test file_client_args and backend_args + with pytest.warns( + DeprecationWarning, + match='"file_client_args" will be deprecated in future'): + CheckpointHook(file_client_args={'backend': 'disk'}) + + with pytest.raises( + ValueError, + match='"file_client_args and "backend_args" cannot be both set' + ): + CheckpointHook( + file_client_args={'backend': 'disk'}, + backend_args={'backend': 'local'}) + def test_before_train(self, tmp_path): runner = Mock() work_dir = str(tmp_path) runner.work_dir = work_dir + # file_client_args is None + checkpoint_hook = CheckpointHook() + checkpoint_hook.before_train(runner) + assert isinstance(checkpoint_hook.file_client, FileClient) + assert isinstance(checkpoint_hook.file_backend, LocalBackend) + + # file_client_args is not None + checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'}) + checkpoint_hook.before_train(runner) + assert isinstance(checkpoint_hook.file_client, FileClient) + # file_backend is the alias of file_client + assert checkpoint_hook.file_backend is checkpoint_hook.file_client + # the out_dir of the checkpoint hook is None checkpoint_hook = CheckpointHook(interval=1, by_epoch=True) checkpoint_hook.before_train(runner) @@ -330,6 +338,26 @@ def test_after_train_epoch(self, tmp_path): assert (runner.epoch + 1) % 2 == 0 assert not 
os.path.exists(f'{work_dir}/epoch_8.pth') + # save_checkpoint of runner should be called with expected arguments + runner = Mock() + work_dir = str(tmp_path) + runner.work_dir = tmp_path + runner.epoch = 1 + runner.message_hub = MessageHub.get_instance('test_after_train_epoch2') + + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) + checkpoint_hook.before_train(runner) + checkpoint_hook.after_train_epoch(runner) + + runner.save_checkpoint.assert_called_once_with( + runner.work_dir, + 'epoch_2.pth', + None, + backend_args=None, + by_epoch=True, + save_optimizer=True, + save_param_scheduler=True) + def test_after_train_iter(self, tmp_path): work_dir = str(tmp_path) runner = Mock() From 3ba11025c6b32f780c04fb6e7a6e7c4fb1cd62fe Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 19 Sep 2022 00:47:02 +0800 Subject: [PATCH 12/18] minor fix --- mmengine/fileio/file_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 51cbde299a..7f4a67160d 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -73,6 +73,7 @@ class FileClient: _prefix_to_backends: dict = { 's3': PetrelBackend, + 'petrel': PetrelBackend, 'http': HTTPBackend, 'https': HTTPBackend, } From 7d2c2c94f9deb1e083ec7113f2606516eea0171a Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 21 Sep 2022 10:49:38 +0800 Subject: [PATCH 13/18] ensure client is not None when closing it --- mmengine/fileio/backends/lmdb_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py index e4888a0091..b8357e02c3 100644 --- a/mmengine/fileio/backends/lmdb_backend.py +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -77,4 +77,5 @@ def _get_client(self): **self.kwargs) def __del__(self): - self._client.close() + if self._client is not None: + self._client.close() From 1b5eeeb7d1cfd6aaf8eb5bb1092765fd483c7997 Mon Sep 
17 00:00:00 2001 From: zhouzaida Date: Wed, 21 Sep 2022 11:06:22 +0800 Subject: [PATCH 14/18] add more examples for list_dir_or_file interface --- mmengine/fileio/backends/local_backend.py | 15 ++++++++++++++- mmengine/fileio/backends/petrel_backend.py | 19 ++++++++++++++++--- mmengine/fileio/io.py | 15 +++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index ad15e4849b..e054148721 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -498,9 +498,22 @@ def list_dir_or_file(self, Examples: >>> backend = LocalBackend() >>> dir_path = '/path/of/dir' + >>> # list those files and directories in current directory >>> for file_path in backend.list_dir_or_file(dir_path): ... print(file_path) - """ + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive): + ... print(file_path) + """ # noqa: E501 if list_dir and suffix is not None: raise TypeError('`suffix` should be None when `list_dir` is True') diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py index dacd76783d..de2ce8c954 100644 --- a/mmengine/fileio/backends/petrel_backend.py +++ b/mmengine/fileio/backends/petrel_backend.py @@ -679,9 +679,22 @@ def list_dir_or_file(self, Examples: >>> backend = PetrelBackend() >>> dir_path = 'petrel://path/of/dir' - >>> for path in backend.list_dir_or_file(dir_path): - ... 
print(path) - """ + >>> # list those files and directories in current directory + >>> for file_path in backend.list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive): + ... print(file_path) + """ # noqa: E501 if not has_method(self._client, 'list'): raise NotImplementedError( 'Current version of Petrel Python SDK has not supported ' diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index f821b0135d..5fe45ddf88 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -737,6 +737,21 @@ def list_dir_or_file( >>> dir_path = '/path/of/dir' >>> for file_path in list_dir_or_file(dir_path): ... print(file_path) + >>> # list those files and directories in current directory + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in list_dir_or_file(dir_path, recursive): + ... 
print(file_path) """ backend = get_file_backend( dir_path, backend_args=backend_args, enable_singleton=True) From 30809b57bc0d286f6c6b94f0e459e557acbf9194 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 22 Sep 2022 15:45:42 +0800 Subject: [PATCH 15/18] refine docstring --- mmengine/fileio/backends/local_backend.py | 2 +- mmengine/fileio/backends/memcached_backend.py | 15 +++++++++++++++ mmengine/fileio/backends/petrel_backend.py | 2 +- mmengine/fileio/io.py | 2 +- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index e054148721..0c6c7774fa 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -511,7 +511,7 @@ def list_dir_or_file(self, >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): ... print(file_path) >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive): + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): ... print(file_path) """ # noqa: E501 if list_dir and suffix is not None: diff --git a/mmengine/fileio/backends/memcached_backend.py b/mmengine/fileio/backends/memcached_backend.py index 3ef92b04e7..2458e7c6ec 100644 --- a/mmengine/fileio/backends/memcached_backend.py +++ b/mmengine/fileio/backends/memcached_backend.py @@ -33,6 +33,21 @@ def __init__(self, server_list_cfg, client_cfg, sys_path=None): self._mc_buffer = mc.pyvector() def get(self, filepath: Union[str, Path]): + """Get values according to the filepath. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. 
+ + Examples: + >>> server_list_cfg = '/path/of/server_list.conf' + >>> client_cfg = '/path/of/mc.conf' + >>> backend = MemcachedBackend(server_list_cfg, client_cfg) + >>> backend.get('/path/of/file') + b'hello world' + """ filepath = str(filepath) import mc self._client.Get(filepath, self._mc_buffer) diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py index de2ce8c954..bfb23bd586 100644 --- a/mmengine/fileio/backends/petrel_backend.py +++ b/mmengine/fileio/backends/petrel_backend.py @@ -692,7 +692,7 @@ def list_dir_or_file(self, >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): ... print(file_path) >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive): + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): ... print(file_path) """ # noqa: E501 if not has_method(self._client, 'list'): diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index 5fe45ddf88..99182ecfe4 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -750,7 +750,7 @@ def list_dir_or_file( >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): ... print(file_path) >>> # list all files and directory recursively - >>> for file_path in list_dir_or_file(dir_path, recursive): + >>> for file_path in list_dir_or_file(dir_path, recursive=True): ... 
print(file_path) """ backend = get_file_backend( From 064fb1e720abd6fb6e91780cc8ec69983acec7e4 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 24 Sep 2022 00:45:28 +0800 Subject: [PATCH 16/18] refine deprecated info --- mmengine/fileio/io.py | 6 ++++-- mmengine/fileio/parse.py | 6 ++++-- mmengine/hooks/checkpoint_hook.py | 3 ++- mmengine/hooks/logger_hook.py | 3 ++- mmengine/runner/checkpoint.py | 3 ++- mmengine/runner/runner.py | 3 ++- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index 99182ecfe4..62f9a4ef7a 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -835,7 +835,8 @@ def load(file, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') handler = file_handlers[file_format] if is_str(file): @@ -910,7 +911,8 @@ def dump(obj, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') handler = file_handlers[file_format] if file is None: diff --git a/mmengine/fileio/parse.py b/mmengine/fileio/parse.py index 139481aa49..080ae023c2 100644 --- a/mmengine/fileio/parse.py +++ b/mmengine/fileio/parse.py @@ -48,7 +48,8 @@ def list_from_file(filename, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') cnt = 0 item_list = [] @@ -111,7 +112,8 @@ def dict_from_file(filename, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be
both set.') + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') mapping = {} diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index b432eca823..cbd13124d3 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -144,7 +144,8 @@ def __init__(self, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') self.file_client_args = file_client_args self.backend_args = backend_args diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index df1530f327..71752be142 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -95,7 +95,8 @@ def __init__(self, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') if not (out_dir is None or isinstance(out_dir, str) or is_tuple_of(out_dir, str)): diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index eba2116afd..9fe42b6dec 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -661,7 +661,8 @@ def save_checkpoint(checkpoint, 'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') if filename.startswith('pavi://'): if file_client_args is not None or backend_args is not None: diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 1dc086afb6..5d905ae2bd 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -2061,7 +2061,8 @@ def save_checkpoint( 
'Please use "backend_args" instead', DeprecationWarning) if backend_args is not None: raise ValueError( - '"file_client_args and "backend_args" cannot be both set.') + '"file_client_args" and "backend_args" cannot be set at ' + 'the same time.') file_client = FileClient.infer_client(file_client_args, out_dir) filepath = file_client.join_path(out_dir, filename) From c131ad54e9a9145f493df63f7f9e5005e4a728fa Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 25 Sep 2022 21:39:43 +0800 Subject: [PATCH 17/18] fix ut --- tests/test_hooks/test_checkpoint_hook.py | 4 ++-- tests/test_hooks/test_logger_hook.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index 9fc894109d..8fbb1a56db 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -83,8 +83,8 @@ def test_init(self, tmp_path): with pytest.raises( ValueError, - match='"file_client_args and "backend_args" cannot be both set' - ): + match='"file_client_args" and "backend_args" cannot be set ' + 'at the same time'): CheckpointHook( file_client_args={'backend': 'disk'}, backend_args={'backend': 'local'}) diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index 9c531d0702..3a3ddb37e8 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -35,8 +35,8 @@ def test_init(self): with pytest.raises( ValueError, - match='"file_client_args and "backend_args" cannot be both set' - ): + match='"file_client_args" and "backend_args" cannot be ' + 'set at the same time'): logger_hook = LoggerHook( out_dir='tmp.txt', file_client_args={'backend': 'disk'}, From 7237eaf02219d4c369655e0271590ea1530bdefd Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 25 Sep 2022 23:20:13 +0800 Subject: [PATCH 18/18] add a description for lmdb docstring --- mmengine/fileio/backends/lmdb_backend.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py index b8357e02c3..eb47923e56 100644 --- a/mmengine/fileio/backends/lmdb_backend.py +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -17,6 +17,7 @@ class LmdbBackend(BaseStorageBackend): readahead (bool): Lmdb environment parameter. If False, disable the OS filesystem readahead mechanism, which may improve random read performance when a database is larger than RAM. Defaults to False. + **kwargs: Keyword arguments passed to `lmdb.open`. Attributes: db_path (str): Lmdb database path.