From b2a22574154e1387997a438491f3cb24e591be5c Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 9 Sep 2021 21:58:17 +0800 Subject: [PATCH 01/46] [Feature] Choose storage backend by the prefix of filepath --- mmcv/fileio/file_client.py | 56 ++++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index e8a6cbdb08..b6d9c630e5 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -212,9 +212,13 @@ class FileClient: and return it as a binary file. it can also register other backend accessor with a given name and backend class. - Attributes: + Args: backend (str): The storage backend type. Options are "disk", "ceph", "memcached", "lmdb" and "http". + prefixes (str or list[str] or tuple[str]): The prefix of the + registered storage backend. + + Attributes: client (:obj:`BaseStorageBackend`): The backend object. """ @@ -226,17 +230,35 @@ class FileClient: 'petrel': PetrelBackend, 'http': HTTPBackend, } + _prefix_to_backends = { + 's3://': PetrelBackend, + 'http://': HTTPBackend, + 'https://': HTTPBackend, + } - def __init__(self, backend='disk', **kwargs): - if backend not in self._backends: + def __init__(self, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = 'disk' + if backend is not None and prefix is not None: + raise ValueError( + 'backend and prefix should not be `None` at the same time') + if backend is not None and backend not in self._backends: raise ValueError( f'Backend {backend} is not supported. Currently supported ones' f' are {list(self._backends.keys())}') - self.backend = backend - self.client = self._backends[backend](**kwargs) + if prefix is not None and prefix not in self._prefix_to_backends: + raise ValueError( + f'prefix {prefix} is not supported. Currently supported ones' + f' are {list(self._prefix_to_backends.keys())}') + + if backend is not None: + self.client = self._backends[backend](**kwargs) + else: + _backend = self._prefix_to_backends[prefix] + self.client = self._backends[_backend](**kwargs) @classmethod - def _register_backend(cls, name, backend, force=False): + def _register_backend(cls, name, backend, force=False, prefixes=None): if not isinstance(name, str): raise TypeError('the backend name should be a string, ' f'but got {type(name)}') @@ -252,9 +274,21 @@ def _register_backend(cls, name, backend, force=False): 'add "force=True" if you want to override it') cls._backends[name] = backend + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if (prefix not in cls._prefix_to_backends) or force: + cls._prefix_to_backends[prefix] = backend + else: + raise KeyError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') @classmethod - def register_backend(cls, name, backend=None, force=False): + def register_backend(cls, name, backend=None, force=False, prefixes=None): """Register a backend to FileClient. This method can be used as a normal class method or a decorator. @@ -292,13 +326,17 @@ def get_text(self, filepath): Defaults to None. force (bool, optional): Whether to override the backend if the name has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str]): The prefix of the + registered storage backend. """ if backend is not None: - cls._register_backend(name, backend, force=force) + cls._register_backend( + name, backend, force=force, prefixes=prefixes) return def _register(backend_cls): - cls._register_backend(name, backend_cls, force=force) + cls._register_backend( + name, backend_cls, force=force, prefixes=prefixes) return backend_cls return _register From 073f73ef38603c148f37f17ddb82e73ffff751c5 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 11 Sep 2021 00:01:37 +0800 Subject: [PATCH 02/46] refactor FileClient and add unittest --- mmcv/fileio/file_client.py | 58 ++++++++++++++++------- tests/test_fileclient.py | 96 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 133 insertions(+), 21 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index b6d9c630e5..9f942976ac 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -214,9 +214,16 @@ class FileClient: Args: backend (str): The storage backend type. Options are "disk", "ceph", - "memcached", "lmdb" and "http". - prefixes (str or list[str] or tuple[str]): The prefix of the - registered storage backend. + "memcached", "lmdb", "http" and "petrel". Default: None. + prefixes (str or list[str] or tuple[str]): The prefixes of the + registered storage backend. Both backend and prefixes can be used + to choose a storage backend, but backend has a higher priority that + is if they are all set, the storage backend will be chosen by the + backend rather than prefixes. If backend and prefixes are all + `None`. The dist backend is be chosen. Default: None. + + .. versionadd:: 1.3.14 + The *prefixes* parameter. Attributes: client (:obj:`BaseStorageBackend`): The backend object. @@ -231,31 +238,45 @@ class FileClient: 'http': HTTPBackend, } _prefix_to_backends = { - 's3://': PetrelBackend, - 'http://': HTTPBackend, - 'https://': HTTPBackend, + 's3': PetrelBackend, + 'http': HTTPBackend, + 'https': HTTPBackend, } - def __init__(self, backend=None, prefix=None, **kwargs): - if backend is None and prefix is None: + def __init__(self, backend=None, prefixes=None, **kwargs): + if backend is None and prefixes is None: backend = 'disk' - if backend is not None and prefix is not None: - raise ValueError( - 'backend and prefix should not be `None` at the same time') if backend is not None and backend not in self._backends: raise ValueError( f'Backend {backend} is not supported. Currently supported ones' f' are {list(self._backends.keys())}') - if prefix is not None and prefix not in self._prefix_to_backends: - raise ValueError( - f'prefix {prefix} is not supported. Currently supported ones' - f' are {list(self._prefix_to_backends.keys())}') + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + + if not set(prefixes).issubset(self._prefix_to_backends.keys()): + raise ValueError( + f'prefixes {prefixes} is not supported. Currently ' + 'supported ones are ' + f'{list(self._prefix_to_backends.keys())}') if backend is not None: self.client = self._backends[backend](**kwargs) else: - _backend = self._prefix_to_backends[prefix] - self.client = self._backends[_backend](**kwargs) + for prefix in prefixes: + self.client = self._prefix_to_backends[prefix](**kwargs) + break + + @staticmethod + def parse_uri_prefix(uri): + uri = str(uri) + if '://' not in uri: + return None + else: + prefix, _ = uri.split('://') + return prefix @classmethod def _register_backend(cls, name, backend, force=False, prefixes=None): @@ -328,6 +349,9 @@ def get_text(self, filepath): has already been registered. Defaults to False. prefixes (str or list[str] or tuple[str]): The prefix of the registered storage backend. + + .. versionadd:: 1.3.14 + The *prefixes* parameter. """ if backend is not None: cls._register_backend( diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 80357cf31d..ef241413f4 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -104,8 +104,10 @@ def test_ceph_backend(self): str(self.img_path).replace(str(self.test_data_dir), ceph_path)) @patch('petrel_client.client.Client', MockS3Client) - def test_petrel_backend(self): - petrel_backend = FileClient('petrel') + @pytest.mark.parametrize('backend,prefixes', [('petrel', None), + (None, 's3')]) + def test_petrel_backend(self, backend, prefixes): + petrel_backend = FileClient(backend=backend, prefixes=prefixes) # input path is Path object with pytest.raises(NotImplementedError): @@ -182,8 +184,10 @@ def test_lmdb_backend(self): img = mmcv.imfrombytes(img_bytes) assert img.shape == (120, 125, 3) - def test_http_backend(self): - http_backend = FileClient('http') + @pytest.mark.parametrize('backend,prefixes', [('http', None), + (None, 'http')]) + def test_http_backend(self, backend, prefixes): + http_backend = FileClient(backend=backend, prefixes=prefixes) img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ 'master/tests/data/color.jpg' text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ @@ -208,6 +212,21 @@ def test_http_backend(self): value_buf = http_backend.get_text(text_url) assert self.text_path.open('r').read() == value_buf + def test_parse_uri_prefix(self): + # input path is Path object + assert FileClient.parse_uri_prefix(self.img_path) is None + # input path is str + assert FileClient.parse_uri_prefix(str(self.img_path)) is None + + # input path starts with https + img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ + 'master/tests/data/color.jpg' + assert FileClient.parse_uri_prefix(img_url) == 'https' + + # input path starts with s3 + img_url = 's3://your_bucket/img.png' + assert FileClient.parse_uri_prefix(img_url) == 's3' + def test_register_backend(self): # name must be a string @@ -299,3 +318,72 @@ def get_text(self, filepath): example_backend = FileClient('example3') assert example_backend.get(self.img_path) == 'bytes5' assert example_backend.get_text(self.text_path) == 'text5' + + # prefixes is a str + class Example6Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes6' + + def get_text(self, filepath): + return 'text6' + + FileClient.register_backend( + 'example4', + Example6Backend, + force=True, + prefixes='example4_prefix') + example_backend = FileClient('example4') + assert example_backend.get(self.img_path) == 'bytes6' + assert example_backend.get_text(self.text_path) == 'text6' + example_backend = FileClient(prefixes='example4_prefix') + assert example_backend.get(self.img_path) == 'bytes6' + assert example_backend.get_text(self.text_path) == 'text6' + example_backend = FileClient('example4', prefixes='example4_prefix') + assert example_backend.get(self.img_path) == 'bytes6' + assert example_backend.get_text(self.text_path) == 'text6' + + # prefixes is a list of str + class Example7Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes7' + + def get_text(self, filepath): + return 'text7' + + FileClient.register_backend( + 'example5', + Example7Backend, + force=True, + prefixes=['example5_prefix1', 'example5_prefix2']) + example_backend = FileClient('example5') + assert example_backend.get(self.img_path) == 'bytes7' + assert example_backend.get_text(self.text_path) == 'text7' + example_backend = FileClient(prefixes='example5_prefix1') + assert example_backend.get(self.img_path) == 'bytes7' + assert example_backend.get_text(self.text_path) == 'text7' + example_backend = FileClient(prefixes='example5_prefix2') + assert example_backend.get(self.img_path) == 'bytes7' + assert example_backend.get_text(self.text_path) == 'text7' + + # backend has a higher priority than prefixes + class Example8Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes8' + + def get_text(self, filepath): + return 'text8' + + FileClient.register_backend( + 'example6', + Example8Backend, + force=True, + prefixes='example6_prefix') + example_backend = FileClient('example6') + assert example_backend.get(self.img_path) == 'bytes8' + assert example_backend.get_text(self.text_path) == 'text8' + example_backend = FileClient('example6', prefixes='example4_prefix') + assert example_backend.get(self.img_path) == 'bytes8' + assert example_backend.get_text(self.text_path) == 'text8' From dfb9fc4a5d141b71e0ba3b613b486053977c321f Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 11 Sep 2021 22:08:38 +0800 Subject: [PATCH 03/46] support loading from different backends --- mmcv/fileio/file_client.py | 70 ++++++++++++++++++++------ mmcv/fileio/handlers/base.py | 1 + mmcv/fileio/handlers/pickle_handler.py | 2 + mmcv/fileio/io.py | 39 ++++++++++++-- mmcv/fileio/parse.py | 39 ++++++++++++-- tests/test_fileclient.py | 41 +++++++++------ tests/test_fileio.py | 70 +++++++++++++++++++++++++- 7 files changed, 222 insertions(+), 40 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 9f942976ac..4ca0a5f968 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect from abc import ABCMeta, abstractmethod +from pathlib import Path +from typing import Optional, Union from urllib.request import urlopen @@ -49,7 +51,7 @@ def get(self, filepath): value_buf = memoryview(value) return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding=None): raise NotImplementedError @@ -61,20 +63,26 @@ class PetrelBackend(BaseStorageBackend): path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will be replaced by `dst`. Default: None. enable_mc (bool): whether to enable memcached support. Default: True. + enable_multi_cluster (bool): Whether to enable multi clusters. + Default: False. """ - def __init__(self, path_mapping=None, enable_mc=True): + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True, + enable_multi_cluster: bool = False): try: from petrel_client import client except ImportError: raise ImportError('Please install petrel_client to enable ' 'PetrelBackend.') - self._client = client.Client(enable_mc=enable_mc) + self._client = client.Client( + enable_mc=enable_mc, enable_multi_cluster=enable_multi_cluster) assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping - def get(self, filepath): + def get(self, filepath: Union[str, Path]) -> memoryview: filepath = str(filepath) if self.path_mapping is not None: for k, v in self.path_mapping.items(): @@ -83,8 +91,23 @@ def get(self, filepath): value_buf = memoryview(value) return value_buf - def get_text(self, filepath): - raise NotImplementedError + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + return str(self.get(filepath), encoding=encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + self._client.put(filepath, obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + self.put(bytes(obj, encoding=encoding), filepath) class MemcachedBackend(BaseStorageBackend): @@ -121,7 +144,7 @@ def get(self, filepath): value_buf = mc.ConvertBuffer(self._mc_buffer) return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding=None): raise NotImplementedError @@ -173,7 +196,7 @@ def get(self, filepath): value_buf = txn.get(filepath.encode('ascii')) return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding=None): raise NotImplementedError @@ -186,12 +209,22 @@ def get(self, filepath): value_buf = f.read() return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): filepath = str(filepath) - with open(filepath, 'r') as f: + with open(filepath, 'r', encoding=encoding) as f: value_buf = f.read() return value_buf + def put(self, obj, filepath): + filepath = str(filepath) + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, obj, filepath, encoding='utf-8'): + filepath = str(filepath) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -200,9 +233,9 @@ def get(self, filepath): value_buf = urlopen(filepath).read() return value_buf - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): value_buf = urlopen(filepath).read() - return value_buf.decode('utf-8') + return value_buf.decode(encoding) class FileClient: @@ -276,6 +309,9 @@ def parse_uri_prefix(uri): return None else: prefix, _ = uri.split('://') + # clusterName:s3:// + if ':' in prefix: + _, prefix = prefix.split(':') return prefix @classmethod @@ -368,5 +404,11 @@ def _register(backend_cls): def get(self, filepath): return self.client.get(filepath) - def get_text(self, filepath): - return self.client.get_text(filepath) + def get_text(self, filepath, encoding='utf-8'): + return self.client.get_text(filepath, encoding) + + def put(self, obj, filepath): + self.client.put(obj, filepath) + + def put_text(self, obj, filepath): + self.client.put_text(obj, filepath) diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py index 235727557c..0e398301c6 100644 --- a/mmcv/fileio/handlers/base.py +++ b/mmcv/fileio/handlers/base.py @@ -3,6 +3,7 @@ class BaseFileHandler(metaclass=ABCMeta): + str_like_obj = True @abstractmethod def load_from_fileobj(self, file, **kwargs): diff --git a/mmcv/fileio/handlers/pickle_handler.py b/mmcv/fileio/handlers/pickle_handler.py index 0250459957..0809995f51 100644 --- a/mmcv/fileio/handlers/pickle_handler.py +++ b/mmcv/fileio/handlers/pickle_handler.py @@ -6,6 +6,8 @@ class PickleHandler(BaseFileHandler): + str_like_obj = False + def load_from_fileobj(self, file, **kwargs): return pickle.load(file, **kwargs) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 015d36e808..5ec6b3db90 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +from io import BytesIO, StringIO from pathlib import Path from ..utils import is_list_of, is_str +from .file_client import FileClient from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler file_handlers = { @@ -13,7 +15,7 @@ } -def load(file, file_format=None, **kwargs): +def load(file, file_format=None, file_client_args=None, **kwargs): """Load data from json/yaml/pickle files. This method provides a unified api for loading data from serialized files. @@ -25,6 +27,8 @@ def load(file, file_format=None, **kwargs): inferred from the file extension, otherwise use the specified one. Currently supported formats include "json", "yaml/yml" and "pickle/pkl". + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. Returns: The content from the file. @@ -36,9 +40,20 @@ def load(file, file_format=None, **kwargs): if file_format not in file_handlers: raise TypeError(f'Unsupported format: {file_format}') + if file_client_args is None: + file_prefix = FileClient.parse_uri_prefix(file) + client = FileClient(prefixes=file_prefix) + else: + client = FileClient(**file_client_args) + handler = file_handlers[file_format] if is_str(file): - obj = handler.load_from_path(file, **kwargs) + if handler.str_like_obj: + with StringIO(client.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(client.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) elif hasattr(file, 'read'): obj = handler.load_from_fileobj(file, **kwargs) else: @@ -46,7 +61,7 @@ def load(file, file_format=None, **kwargs): return obj -def dump(obj, file=None, file_format=None, **kwargs): +def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): """Dump data to json/yaml/pickle strings or files. This method provides a unified api for dumping data as strings or to files, @@ -58,7 +73,8 @@ def dump(obj, file=None, file_format=None, **kwargs): specified, then the object is dump to a str, otherwise to a file specified by the filename or file-like object. file_format (str, optional): Same as :func:`load`. - + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. Returns: bool: True for success, False otherwise. """ @@ -73,11 +89,24 @@ def dump(obj, file=None, file_format=None, **kwargs): if file_format not in file_handlers: raise TypeError(f'Unsupported format: {file_format}') + if file_client_args is None: + file_prefix = FileClient.parse_uri_prefix(file) + client = FileClient(prefixes=file_prefix) + else: + client = FileClient(**file_client_args) + handler = file_handlers[file_format] if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): - handler.dump_to_path(obj, file, **kwargs) + if handler.str_like_obj: + f = StringIO() + handler.dump_to_fileobj(obj, f, **kwargs) + client.put_text(f.getvalue(), file) + else: + f = BytesIO() + handler.dump_to_fileobj(obj, f, **kwargs) + client.put(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index ffe86d3de9..feff0ce5a0 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -1,5 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. -def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'): + +from io import StringIO + +from .file_client import FileClient + + +def list_from_file(filename, + prefix='', + offset=0, + max_num=0, + encoding='utf-8', + file_client_args=None): """Load a text file and parse the content as a list of strings. Args: @@ -9,13 +20,21 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'): max_num (int): The maximum number of lines to be read, zeros and negatives mean no limitation. encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. Returns: list[str]: A list of strings. """ cnt = 0 item_list = [] - with open(filename, 'r', encoding=encoding) as f: + if file_client_args is None: + file_prefix = FileClient.parse_uri_prefix(filename) + client = FileClient(prefixes=file_prefix) + else: + client = FileClient(**file_client_args) + + with StringIO(client.get_text(filename, encoding)) as f: for _ in range(offset): f.readline() for line in f: @@ -26,7 +45,10 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'): return item_list -def dict_from_file(filename, key_type=str): +def dict_from_file(filename, + key_type=str, + encoding='utf-8', + file_client_args=None): """Load a text file and parse the content as a dict. Each line of the text file will be two or more columns split by @@ -37,12 +59,21 @@ def dict_from_file(filename, key_type=str): filename(str): Filename. key_type(type): Type of the dict keys. str is user by default and type conversion will be performed if specified. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. Returns: dict: The parsed contents. """ mapping = {} - with open(filename, 'r') as f: + if file_client_args is None: + file_prefix = FileClient.parse_uri_prefix(filename) + client = FileClient(prefixes=file_prefix) + else: + client = FileClient(**file_client_args) + + with StringIO(client.get_text(filename, encoding)) as f: for line in f: items = line.rstrip('\n').split() assert len(items) >= 2 diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index ef241413f4..446e33b719 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -24,6 +24,18 @@ def Get(self, filepath): return content +class MockPetrelClient: + + def __init__(self, enable_mc=True, enable_multi_cluster=False): + self.enable_mc = enable_mc + self.enable_multi_cluster = enable_multi_cluster + + def Get(self, filepath): + with open(filepath, 'rb') as f: + content = f.read() + return content + + class MockMemcachedClient: def __init__(self, server_list_cfg, client_cfg): @@ -103,19 +115,12 @@ def test_ceph_backend(self): ceph_backend.client._client.Get.assert_called_with( str(self.img_path).replace(str(self.test_data_dir), ceph_path)) - @patch('petrel_client.client.Client', MockS3Client) + @patch('petrel_client.client.Client', MockPetrelClient) @pytest.mark.parametrize('backend,prefixes', [('petrel', None), (None, 's3')]) def test_petrel_backend(self, backend, prefixes): petrel_backend = FileClient(backend=backend, prefixes=prefixes) - # input path is Path object - with pytest.raises(NotImplementedError): - petrel_backend.get_text(self.text_path) - # input path is str - with pytest.raises(NotImplementedError): - petrel_backend.get_text(str(self.text_path)) - # input path is Path object img_bytes = petrel_backend.get(self.img_path) img = mmcv.imfrombytes(img_bytes) @@ -227,6 +232,10 @@ def test_parse_uri_prefix(self): img_url = 's3://your_bucket/img.png' assert FileClient.parse_uri_prefix(img_url) == 's3' + # input path starts with clusterName:s3 + img_url = 'clusterName:s3://your_bucket/img.png' + assert FileClient.parse_uri_prefix(img_url) == 's3' + def test_register_backend(self): # name must be a string @@ -254,7 +263,7 @@ class ExampleBackend(BaseStorageBackend): def get(self, filepath): return filepath - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return filepath FileClient.register_backend('example', ExampleBackend) @@ -268,7 +277,7 @@ class Example2Backend(BaseStorageBackend): def get(self, filepath): return 'bytes2' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text2' # force=False @@ -286,7 +295,7 @@ class Example3Backend(BaseStorageBackend): def get(self, filepath): return 'bytes3' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text3' example_backend = FileClient('example3') @@ -303,7 +312,7 @@ class Example4Backend(BaseStorageBackend): def get(self, filepath): return 'bytes4' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text4' @FileClient.register_backend(name='example3', force=True) @@ -312,7 +321,7 @@ class Example5Backend(BaseStorageBackend): def get(self, filepath): return 'bytes5' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text5' example_backend = FileClient('example3') @@ -325,7 +334,7 @@ class Example6Backend(BaseStorageBackend): def get(self, filepath): return 'bytes6' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text6' FileClient.register_backend( @@ -349,7 +358,7 @@ class Example7Backend(BaseStorageBackend): def get(self, filepath): return 'bytes7' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text7' FileClient.register_backend( @@ -373,7 +382,7 @@ class Example8Backend(BaseStorageBackend): def get(self, filepath): return 'bytes8' - def get_text(self, filepath): + def get_text(self, filepath, encoding='utf-8'): return 'text8' FileClient.register_backend( diff --git a/tests/test_fileio.py b/tests/test_fileio.py index a9d70f515a..5b701d6da9 100644 --- a/tests/test_fileio.py +++ b/tests/test_fileio.py @@ -1,11 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +import sys import tempfile +from unittest.mock import MagicMock, patch import pytest import mmcv +from mmcv.fileio.file_client import HTTPBackend, PetrelBackend + +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() def _test_handler(file_format, test_obj, str_checker, mode='r+'): @@ -13,7 +19,7 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): dump_str = mmcv.dump(test_obj, file_format=file_format) str_checker(dump_str) - # load/dump with filenames + # load/dump with filenames from disk tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_dump') mmcv.dump(test_obj, tmp_filename, file_format=file_format) assert osp.isfile(tmp_filename) @@ -21,6 +27,13 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'): assert load_obj == test_obj os.remove(tmp_filename) + # load/dump with filename from petrel + method = 'put' if 'b' in mode else 'put_text' + with patch.object(PetrelBackend, method, return_value=None) as mock_method: + filename = 's3://path/of/your/file' + mmcv.dump(test_obj, filename, file_format=file_format) + mock_method.assert_called() + # json load/dump with a file-like object with tempfile.NamedTemporaryFile(mode, delete=False) as f: tmp_filename = f.name @@ -122,6 +135,7 @@ def dump_to_str(self, obj, **kwargs): def test_list_from_file(): + # get list from disk filename = osp.join(osp.dirname(__file__), 'data/filelist.txt') filelist = mmcv.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg'] @@ -134,10 +148,64 @@ def test_list_from_file(): filelist = mmcv.list_from_file(filename, offset=3, max_num=3) assert filelist == ['4.jpg', '5.jpg'] + # get list from http + with patch.object( + HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + filename = 'http://path/of/your/file' + filelist = mmcv.list_from_file( + filename, file_client_args={'backend': 'http'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file( + filename, file_client_args={'prefixes': 'http'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file(filename) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + + # get list from petrel + with patch.object( + PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + filename = 's3://path/of/your/file' + filelist = mmcv.list_from_file( + filename, file_client_args={'backend': 'petrel'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file( + filename, file_client_args={'prefixes': 's3'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmcv.list_from_file(filename) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + def test_dict_from_file(): + # get dict from disk filename = osp.join(osp.dirname(__file__), 'data/mapping.txt') mapping = mmcv.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmcv.dict_from_file(filename, key_type=int) assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} + + # get dict from http + with patch.object( + HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): + filename = 'http://path/of/your/file' + mapping = mmcv.dict_from_file( + filename, file_client_args={'backend': 'http'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file( + filename, file_client_args={'prefixes': 'http'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file(filename) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + + # get dict from petrel + with patch.object( + PetrelBackend, 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): + filename = 's3://path/of/your/file' + mapping = mmcv.dict_from_file( + filename, file_client_args={'backend': 'petrel'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file( + filename, file_client_args={'prefixes': 's3'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmcv.dict_from_file(filename) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} From 48cfdad7833d84ad00c9932a62df1ff288de1808 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 21 Sep 2021 10:15:38 +0800 Subject: [PATCH 04/46] polish docstring --- mmcv/fileio/file_client.py | 29 ++++++++++++++++++++--------- mmcv/fileio/io.py | 12 ++++++------ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 4ca0a5f968..3c8eb1641e 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -63,7 +63,7 @@ class PetrelBackend(BaseStorageBackend): path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will be replaced by `dst`. Default: None. enable_mc (bool): whether to enable memcached support. Default: True. - enable_multi_cluster (bool): Whether to enable multi clusters. + enable_multi_cluster (bool): Whether to enable multiple clusters. Default: False. """ @@ -242,22 +242,32 @@ class FileClient: """A general file client to access files in different backend. The client loads a file or text in a specified backend from its path - and return it as a binary file. it can also register other backend - accessor with a given name and backend class. + and return it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefixes of path. Although both of + them can be used to choose a storage backend, backend has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the dist backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. Args: backend (str): The storage backend type. Options are "disk", "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. prefixes (str or list[str] or tuple[str]): The prefixes of the - registered storage backend. Both backend and prefixes can be used - to choose a storage backend, but backend has a higher priority that - is if they are all set, the storage backend will be chosen by the - backend rather than prefixes. If backend and prefixes are all - `None`. The dist backend is be chosen. Default: None. + registered storage backend. Options are "s3", "http", "https". + Default: None. .. versionadd:: 1.3.14 The *prefixes* parameter. + Example: + >>> # only set backend + >>> file_client = FileClient(backend='ceph') + >>> # only set prefixes + >>> file_client = FileClient(prefixes='s3') + >>> # set both backend and prefixes but use backend to choose client + >>> file_client = FileClient(backend='ceph', prefixes='s3') + Attributes: client (:obj:`BaseStorageBackend`): The backend object. """ @@ -309,7 +319,8 @@ def parse_uri_prefix(uri): return None else: prefix, _ = uri.split('://') - # clusterName:s3:// + # In the case of ceph, the prefix may contains the cluster name + # like clusterName:s3 if ':' in prefix: _, prefix = prefix.split(':') return prefix diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 5ec6b3db90..22146fbb47 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -100,13 +100,13 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): return handler.dump_to_str(obj, **kwargs) elif is_str(file): if handler.str_like_obj: - f = StringIO() - handler.dump_to_fileobj(obj, f, **kwargs) - client.put_text(f.getvalue(), file) + with StringIO as f: + handler.dump_to_fileobj(obj, f, **kwargs) + client.put_text(f.getvalue(), file) else: - f = BytesIO() - handler.dump_to_fileobj(obj, f, **kwargs) - client.put(f.getvalue(), file) + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + client.put(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: From c2c9fc07ed41b203a1e25789ded61e109c8fa783 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 21 Sep 2021 12:08:16 +0800 Subject: [PATCH 05/46] fix unittet --- mmcv/fileio/file_client.py | 22 ++++++++++++++++------ mmcv/fileio/io.py | 2 +- tests/test_fileclient.py | 5 +++++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 3c8eb1641e..e6a3188600 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -107,7 +107,14 @@ def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = 'utf-8') -> None: - self.put(bytes(obj, encoding=encoding), filepath) + self.put(bytes(obj, encoding=encoding), str(filepath)) + + def remove(self, filepath: Union[str, Path]) -> None: + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + self._client.delete(filepath) class MemcachedBackend(BaseStorageBackend): @@ -244,11 +251,11 @@ class FileClient: The client loads a file or text in a specified backend from its path and return it as a binary or text file. There are two ways to choose a backend, the name of backend and the prefixes of path. Although both of - them can be used to choose a storage backend, backend has a higher priority - that is if they are all set, the storage backend will be chosen by the - backend argument. If they are all `None`, the dist backend will be chosen. - Note that It can also register other backend accessor with a given name, - prefixes, and backend class. + them can be used to choose a storage backend, ``backend`` has a higher + priority that is if they are all set, the storage backend will be chosen by + the backend argument. If they are all `None`, the disk backend will be + chosen. Note that It can also register other backend accessor with a given + name, prefixes, and backend class. Args: backend (str): The storage backend type. Options are "disk", "ceph", @@ -423,3 +430,6 @@ def put(self, obj, filepath): def put_text(self, obj, filepath): self.client.put_text(obj, filepath) + + def remove(self, filepath): + self.client.remove(filepath) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 22146fbb47..e60c05ecc9 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -100,7 +100,7 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): return handler.dump_to_str(obj, **kwargs) elif is_str(file): if handler.str_like_obj: - with StringIO as f: + with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) client.put_text(f.getvalue(), file) else: diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 446e33b719..146c4940b2 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -144,6 +144,11 @@ def test_petrel_backend(self, backend, prefixes): assert img.shape == self.img_shape petrel_backend.client._client.Get.assert_called_with( str(self.img_path).replace(str(self.test_data_dir), petrel_path)) + # test remove + petrel_backend.client._client.delete = MagicMock() + petrel_backend.remove(self.img_path) + petrel_backend.client._client.delete.assert_called_with( + str(self.img_path).replace(str(self.test_data_dir), petrel_path)) @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) From d641a8c74fa19a39beef7bc6510c430090abe261 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 22 Sep 2021 14:24:38 +0800 Subject: [PATCH 06/46] rename attribute str_like_obj to is_str_like_obj --- mmcv/fileio/handlers/base.py | 2 +- mmcv/fileio/handlers/pickle_handler.py | 2 +- mmcv/fileio/io.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py index 0e398301c6..c184cf3c10 100644 --- a/mmcv/fileio/handlers/base.py +++ b/mmcv/fileio/handlers/base.py @@ -3,7 +3,7 @@ class BaseFileHandler(metaclass=ABCMeta): - str_like_obj = True + is_str_like_obj = True @abstractmethod def load_from_fileobj(self, file, **kwargs): diff --git a/mmcv/fileio/handlers/pickle_handler.py b/mmcv/fileio/handlers/pickle_handler.py index 0809995f51..648bf22b9c 100644 --- a/mmcv/fileio/handlers/pickle_handler.py +++ b/mmcv/fileio/handlers/pickle_handler.py @@ -6,7 +6,7 @@ class PickleHandler(BaseFileHandler): - str_like_obj = False + is_str_like_obj = False def load_from_fileobj(self, file, **kwargs): return pickle.load(file, **kwargs) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index e60c05ecc9..9ede89d847 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -48,7 +48,7 @@ def load(file, file_format=None, file_client_args=None, **kwargs): handler = file_handlers[file_format] if is_str(file): - if handler.str_like_obj: + if handler.is_str_like_obj: with StringIO(client.get_text(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) else: @@ -99,7 +99,7 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): - if handler.str_like_obj: + if handler.is_str_like_obj: with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) client.put_text(f.getvalue(), file) From 68f0ab65d02db89418c1c572c6d2f392f8f5684a Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 23 Sep 2021 14:26:21 +0800 Subject: [PATCH 07/46] add infer_client method --- mmcv/fileio/file_client.py | 38 +++++++++++++++++++++++++++++++++++++- mmcv/fileio/io.py | 14 ++------------ tests/test_fileclient.py | 25 ++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 14 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index e6a3188600..aa7fa05cef 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +import os from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Optional, Union @@ -232,6 +233,11 @@ def put_text(self, obj, filepath, encoding='utf-8'): with open(filepath, 'w', encoding=encoding) as f: f.write(obj) + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file.""" + filepath = str(filepath) + os.remove(filepath) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -319,8 +325,19 @@ def __init__(self, backend=None, prefixes=None, **kwargs): self.client = self._prefix_to_backends[prefix](**kwargs) break + for backend_name, backend_cls in self._backends.items(): + if isinstance(self.client, backend_cls): + self.backend_name = backend_name + break + @staticmethod - def parse_uri_prefix(uri): + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed its prefix. + """ + assert isinstance(uri, str) or isinstance(uri, Path) uri = str(uri) if '://' not in uri: return None @@ -332,6 +349,25 @@ def parse_uri_prefix(uri): _, prefix = prefix.split(':') return prefix + @classmethod + def infer_client(cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None) -> 'FileClient': + """Infer a file client. + + Args: + file_client_args (dict): Arguments to instantiate a FileClient. + Default: None. + uri (str | Path, optional): Uri to be parsed its prefix. + Default: None. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefixes=file_prefix) + else: + return cls(**file_client_args) + @classmethod def _register_backend(cls, name, backend, force=False, prefixes=None): if not isinstance(name, str): diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 9ede89d847..5c93acf2b2 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -40,14 +40,9 @@ def load(file, file_format=None, file_client_args=None, **kwargs): if file_format not in file_handlers: raise TypeError(f'Unsupported format: {file_format}') - if file_client_args is None: - file_prefix = FileClient.parse_uri_prefix(file) - client = FileClient(prefixes=file_prefix) - else: - client = FileClient(**file_client_args) - handler = file_handlers[file_format] if is_str(file): + client = FileClient.infer_client(file_client_args, file) if handler.is_str_like_obj: with StringIO(client.get_text(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) @@ -89,16 +84,11 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): if file_format not in file_handlers: raise TypeError(f'Unsupported format: {file_format}') - if file_client_args is None: - file_prefix = FileClient.parse_uri_prefix(file) - client = FileClient(prefixes=file_prefix) - else: - client = FileClient(**file_client_args) - handler = file_handlers[file_format] if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): + client = FileClient.infer_client(file_client_args, file) if handler.is_str_like_obj: with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 146c4940b2..223ec14720 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -144,7 +144,7 @@ def test_petrel_backend(self, backend, prefixes): assert img.shape == self.img_shape petrel_backend.client._client.Get.assert_called_with( str(self.img_path).replace(str(self.test_data_dir), petrel_path)) - # test remove + # remove a file petrel_backend.client._client.delete = MagicMock() petrel_backend.remove(self.img_path) petrel_backend.client._client.delete.assert_called_with( @@ -223,6 +223,13 @@ def test_http_backend(self, backend, prefixes): assert self.text_path.open('r').read() == value_buf def test_parse_uri_prefix(self): + # input path is None + with pytest.raises(AssertionError): + FileClient.parse_uri_prefix(None) + # input path is list + with pytest.raises(AssertionError): + FileClient.parse_uri_prefix([]) + # input path is Path object assert FileClient.parse_uri_prefix(self.img_path) is None # input path is str @@ -241,6 +248,22 @@ def test_parse_uri_prefix(self): img_url = 'clusterName:s3://your_bucket/img.png' assert FileClient.parse_uri_prefix(img_url) == 's3' + def test_infer_client(self): + # HardDiskBackend + file_client_args = {'backend': 'disk'} + client = FileClient.infer_client(file_client_args) + assert client.backend_name == 'disk' + client = FileClient.infer_client(uri=self.img_path) + assert client.backend_name == 'disk' + + # PetrelBackend + file_client_args = {'backend': 'petrel'} + client = FileClient.infer_client(file_client_args) + assert client.backend_name == 'petrel' + uri = 's3://user_data' + client = FileClient.infer_client(uri=uri) + assert client.backend_name == 'petrel' + def test_register_backend(self): # name must be a string From 31caf8e307a3b7c77b9025539ab5b6d1c99bab0f Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 23 Sep 2021 22:16:42 +0800 Subject: [PATCH 08/46] add check_exist method --- mmcv/fileio/file_client.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index aa7fa05cef..ca8626a173 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -1,11 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect import os +import os.path as osp from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Optional, Union from urllib.request import urlopen +from mmcv.utils.path import is_filepath + class BaseStorageBackend(metaclass=ABCMeta): """Abstract class of storage backends. @@ -83,11 +86,14 @@ def __init__(self, assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping - def get(self, filepath: Union[str, Path]) -> memoryview: - filepath = str(filepath) + def _path_mapping(self, filepath: str) -> str: if self.path_mapping is not None: for k, v in self.path_mapping.items(): filepath = filepath.replace(k, v) + return filepath + + def get(self, filepath: Union[str, Path]) -> memoryview: + filepath = self._path_mapping(str(filepath)) value = self._client.Get(filepath) value_buf = memoryview(value) return value_buf @@ -98,10 +104,7 @@ def get_text(self, return str(self.get(filepath), encoding=encoding) def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v) + filepath = self._path_mapping(str(filepath)) self._client.put(filepath, obj) def put_text(self, @@ -111,12 +114,13 @@ def put_text(self, self.put(bytes(obj, encoding=encoding), str(filepath)) def remove(self, filepath: Union[str, Path]) -> None: - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v) + filepath = self._path_mapping(str(filepath)) self._client.delete(filepath) + def check_exist(self, filepath: Union[str, Path]) -> bool: + # TODO, need other team to support the feature + return True + class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -238,6 +242,9 @@ def remove(self, filepath: Union[str, Path]) -> None: filepath = str(filepath) os.remove(filepath) + def check_exist(self, filepath: Union[str, Path]) -> bool: + return osp.exists(str(filepath)) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -337,7 +344,7 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: Args: uri (str | Path): Uri to be parsed its prefix. """ - assert isinstance(uri, str) or isinstance(uri, Path) + assert is_filepath(uri) uri = str(uri) if '://' not in uri: return None @@ -469,3 +476,6 @@ def put_text(self, obj, filepath): def remove(self, filepath): self.client.remove(filepath) + + def check_exist(self, filepath): + return self.client.check_exist(filepath) From 7e7a80ff3600d7278926de2c03f7a20148b2d787 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 24 Sep 2021 14:34:17 +0800 Subject: [PATCH 09/46] rename var client to file_client --- mmcv/fileio/io.py | 12 ++++++------ mmcv/fileio/parse.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 5c93acf2b2..1ba9947319 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -42,12 +42,12 @@ def load(file, file_format=None, file_client_args=None, **kwargs): handler = file_handlers[file_format] if is_str(file): - client = FileClient.infer_client(file_client_args, file) + file_client = FileClient.infer_client(file_client_args, file) if handler.is_str_like_obj: - with StringIO(client.get_text(file)) as f: + with StringIO(file_client.get_text(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) else: - with BytesIO(client.get(file)) as f: + with BytesIO(file_client.get(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) elif hasattr(file, 'read'): obj = handler.load_from_fileobj(file, **kwargs) @@ -88,15 +88,15 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): - client = FileClient.infer_client(file_client_args, file) + file_client = FileClient.infer_client(file_client_args, file) if handler.is_str_like_obj: with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - client.put_text(f.getvalue(), file) + file_client.put_text(f.getvalue(), file) else: with BytesIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - client.put(f.getvalue(), file) + file_client.put(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index feff0ce5a0..2bdf0ecc4b 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -30,11 +30,11 @@ def list_from_file(filename, item_list = [] if file_client_args is None: file_prefix = FileClient.parse_uri_prefix(filename) - client = FileClient(prefixes=file_prefix) + file_client = FileClient(prefixes=file_prefix) else: - client = FileClient(**file_client_args) + file_client = FileClient(**file_client_args) - with StringIO(client.get_text(filename, encoding)) as f: + with StringIO(file_client.get_text(filename, encoding)) as f: for _ in range(offset): f.readline() for line in f: @@ -69,11 +69,11 @@ def dict_from_file(filename, mapping = {} if file_client_args is None: file_prefix = FileClient.parse_uri_prefix(filename) - client = FileClient(prefixes=file_prefix) + file_client = FileClient(prefixes=file_prefix) else: - client = FileClient(**file_client_args) + file_client = FileClient(**file_client_args) - with StringIO(client.get_text(filename, encoding)) as f: + with StringIO(file_client.get_text(filename, encoding)) as f: for line in f: items = line.rstrip('\n').split() assert len(items) >= 2 From aa8274ba9a59df8d245e8f9f2e49d148632b4497 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 26 Sep 2021 15:33:26 +0800 Subject: [PATCH 10/46] polish docstring --- mmcv/fileio/file_client.py | 8 ++++---- mmcv/fileio/handlers/base.py | 6 ++++++ mmcv/fileio/io.py | 1 + mmcv/fileio/parse.py | 14 ++------------ 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index ca8626a173..16847c1743 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -342,7 +342,7 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: """Parse the prefix of a uri. Args: - uri (str | Path): Uri to be parsed its prefix. + uri (str | Path): Uri to be parsed that contains the file prefix. """ assert is_filepath(uri) uri = str(uri) @@ -360,13 +360,13 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: def infer_client(cls, file_client_args: Optional[dict] = None, uri: Optional[Union[str, Path]] = None) -> 'FileClient': - """Infer a file client. + """Infer a suitable file client based on the URI and arguments. Args: file_client_args (dict): Arguments to instantiate a FileClient. Default: None. - uri (str | Path, optional): Uri to be parsed its prefix. - Default: None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Default: None. """ assert file_client_args is not None or uri is not None if file_client_args is None: diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py index c184cf3c10..22d66d5b1b 100644 --- a/mmcv/fileio/handlers/base.py +++ b/mmcv/fileio/handlers/base.py @@ -3,6 +3,12 @@ class BaseFileHandler(metaclass=ABCMeta): + # is_str_like_obj is a flag to mark which type of file object is processed, + # bytes-like object or str-like object. For example, pickle only process + # the bytes-like object and json only process the str-like object. The flag + # will be used to check which type of buffer is used. If str-like object, + # StringIO will be used. If bytes-like object, BytesIO will be used. The + # usage of the flag can be found in `mmcv.load` or `mmcv.dump` is_str_like_obj = True @abstractmethod diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 1ba9947319..c5206f208e 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -70,6 +70,7 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): file_format (str, optional): Same as :func:`load`. file_client_args (dict): Arguments to instantiate a FileClient. See :class:`mmcv.fileio.FileClient` for details. Default: None. + Returns: bool: True for success, False otherwise. """ diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index 2bdf0ecc4b..d7242c0501 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -28,12 +28,7 @@ def list_from_file(filename, """ cnt = 0 item_list = [] - if file_client_args is None: - file_prefix = FileClient.parse_uri_prefix(filename) - file_client = FileClient(prefixes=file_prefix) - else: - file_client = FileClient(**file_client_args) - + file_client = FileClient.infer_client(file_client_args, filename) with StringIO(file_client.get_text(filename, encoding)) as f: for _ in range(offset): f.readline() @@ -67,12 +62,7 @@ def dict_from_file(filename, dict: The parsed contents. """ mapping = {} - if file_client_args is None: - file_prefix = FileClient.parse_uri_prefix(filename) - file_client = FileClient(prefixes=file_prefix) - else: - file_client = FileClient(**file_client_args) - + file_client = FileClient.infer_client(file_client_args, filename) with StringIO(file_client.get_text(filename, encoding)) as f: for line in f: items = line.rstrip('\n').split() From bb4712d7574323c20dd85a3699bb2a6f135a6962 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 27 Sep 2021 18:42:11 +0800 Subject: [PATCH 11/46] add join_paths method --- mmcv/fileio/file_client.py | 9 +++++++++ tests/test_fileclient.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 16847c1743..fa225c1fb1 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -121,6 +121,9 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: # TODO, need other team to support the feature return True + def join_paths(self, path, *paths) -> str: + return f'{path}/{"/".join(paths)}' + class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -245,6 +248,9 @@ def remove(self, filepath: Union[str, Path]) -> None: def check_exist(self, filepath: Union[str, Path]) -> bool: return osp.exists(str(filepath)) + def join_paths(self, path, *paths) -> str: + return osp.join(path, *paths) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -479,3 +485,6 @@ def remove(self, filepath): def check_exist(self, filepath): return self.client.check_exist(filepath) + + def join_paths(self, path, *paths) -> str: + return self.client.join_paths(path, *paths) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 223ec14720..5f239b2cfa 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -1,3 +1,4 @@ +import platform import sys from pathlib import Path from unittest.mock import MagicMock, patch @@ -264,6 +265,23 @@ def test_infer_client(self): client = FileClient.infer_client(uri=uri) assert client.backend_name == 'petrel' + def test_join_paths(self): + # HardDiskBackend + file_client_args = {'backend': 'disk'} + client = FileClient(**file_client_args) + dir1 = '/path/of/your/directory' + dir2 = 'c:\\windows\\path\\of\\your\\directory' + filename1 = 'filename' + if platform.system() == 'Windows': + assert client.join_paths(dir2, filename1) == f'{dir2}\\{filename1}' + else: + assert client.join_paths(dir1, filename1) == f'{dir1}/{filename1}' + + # PetrelBackend + file_client_args = {'backend': 'petrel'} + client = FileClient(**file_client_args) + assert client.join_paths(dir1, filename1) == f'{dir1}/{filename1}' + def test_register_backend(self): # name must be a string From d4b6d962a85b23b917749424cad12f1cab22486d Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 28 Sep 2021 15:42:29 +0800 Subject: [PATCH 12/46] remove join_paths and add _format_path --- mmcv/fileio/file_client.py | 25 +++++++++++++++---------- tests/test_fileclient.py | 18 ------------------ 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index fa225c1fb1..dcec9d03ee 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -2,6 +2,7 @@ import inspect import os import os.path as osp +import re from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Optional, Union @@ -92,8 +93,19 @@ def _path_mapping(self, filepath: str) -> str: filepath = filepath.replace(k, v) return filepath + def _format_path(self, filepath: str) -> str: + """Convert filepath to standard format of petrel oss. + + Since the filepath is concatenated by `os.path.join`, in a windows + environment, the filepath will be the format of + 's3://bucket_name\\image.jpg'. By invoking `_format_path`, the above + filepath will be converted to 's3://bucket_name/image.jpn'. + """ + return re.sub(r'\\+', '/', filepath) + def get(self, filepath: Union[str, Path]) -> memoryview: filepath = self._path_mapping(str(filepath)) + filepath = self._format_path(filepath) value = self._client.Get(filepath) value_buf = memoryview(value) return value_buf @@ -105,25 +117,24 @@ def get_text(self, def put(self, obj: bytes, filepath: Union[str, Path]) -> None: filepath = self._path_mapping(str(filepath)) + filepath = self._format_path(filepath) self._client.put(filepath, obj) def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = 'utf-8') -> None: - self.put(bytes(obj, encoding=encoding), str(filepath)) + self.put(bytes(obj, encoding=encoding), filepath) def remove(self, filepath: Union[str, Path]) -> None: filepath = self._path_mapping(str(filepath)) + filepath = self._format_path(filepath) self._client.delete(filepath) def check_exist(self, filepath: Union[str, Path]) -> bool: # TODO, need other team to support the feature return True - def join_paths(self, path, *paths) -> str: - return f'{path}/{"/".join(paths)}' - class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -248,9 +259,6 @@ def remove(self, filepath: Union[str, Path]) -> None: def check_exist(self, filepath: Union[str, Path]) -> bool: return osp.exists(str(filepath)) - def join_paths(self, path, *paths) -> str: - return osp.join(path, *paths) - class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -485,6 +493,3 @@ def remove(self, filepath): def check_exist(self, filepath): return self.client.check_exist(filepath) - - def join_paths(self, path, *paths) -> str: - return self.client.join_paths(path, *paths) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 5f239b2cfa..223ec14720 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -1,4 +1,3 @@ -import platform import sys from pathlib import Path from unittest.mock import MagicMock, patch @@ -265,23 +264,6 @@ def test_infer_client(self): client = FileClient.infer_client(uri=uri) assert client.backend_name == 'petrel' - def test_join_paths(self): - # HardDiskBackend - file_client_args = {'backend': 'disk'} - client = FileClient(**file_client_args) - dir1 = '/path/of/your/directory' - dir2 = 'c:\\windows\\path\\of\\your\\directory' - filename1 = 'filename' - if platform.system() == 'Windows': - assert client.join_paths(dir2, filename1) == f'{dir2}\\{filename1}' - else: - assert client.join_paths(dir1, filename1) == f'{dir1}/{filename1}' - - # PetrelBackend - file_client_args = {'backend': 'petrel'} - client = FileClient(**file_client_args) - assert client.join_paths(dir1, filename1) == f'{dir1}/{filename1}' - def test_register_backend(self): # name must be a string From 767f7fb6671a68b077701046904642c384440227 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 3 Oct 2021 17:04:23 +0800 Subject: [PATCH 13/46] enhance unittest --- mmcv/fileio/file_client.py | 29 +++++++++++++++++++---------- tests/test_fileclient.py | 7 ++++++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index dcec9d03ee..4526348112 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -29,7 +29,7 @@ def get_text(self, filepath): class CephBackend(BaseStorageBackend): - """Ceph storage backend. + """Ceph storage backend (for internal use). Args: path_mapping (dict|None): path mapping dict from local path to Petrel @@ -63,27 +63,35 @@ def get_text(self, filepath, encoding=None): class PetrelBackend(BaseStorageBackend): """Petrel storage backend (for internal use). + PetrelBackend supports reading or writing data to multiple clusters. If the + filepath contains the cluster name, PetrelBackend will read from the + filepath or write to the filepath. Otherwise, PetrelBackend will access + the default cluster. + Args: path_mapping (dict|None): path mapping dict from local path to Petrel path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will be replaced by `dst`. Default: None. enable_mc (bool): whether to enable memcached support. Default: True. - enable_multi_cluster (bool): Whether to enable multiple clusters. - Default: False. + + Examples: + >>> filepath1 = 's3://path/of/file' + >>> filepath2 = 'cluster-name:s3://path/of/file' + >>> client = PetrelBackend() + >>> client.get(filepath1) # get from default cluster + >>> client.get(filepath2) # get from cluster-name """ def __init__(self, path_mapping: Optional[dict] = None, - enable_mc: bool = True, - enable_multi_cluster: bool = False): + enable_mc: bool = True): try: from petrel_client import client except ImportError: raise ImportError('Please install petrel_client to enable ' 'PetrelBackend.') - self._client = client.Client( - enable_mc=enable_mc, enable_multi_cluster=enable_multi_cluster) + self._client = client.Client(enable_mc=enable_mc) assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping @@ -99,7 +107,7 @@ def _format_path(self, filepath: str) -> str: Since the filepath is concatenated by `os.path.join`, in a windows environment, the filepath will be the format of 's3://bucket_name\\image.jpg'. By invoking `_format_path`, the above - filepath will be converted to 's3://bucket_name/image.jpn'. + filepath will be converted to 's3://bucket_name/image.jpg'. """ return re.sub(r'\\+', '/', filepath) @@ -132,8 +140,9 @@ def remove(self, filepath: Union[str, Path]) -> None: self._client.delete(filepath) def check_exist(self, filepath: Union[str, Path]) -> bool: - # TODO, need other team to support the feature - return True + filepath = self._path_mapping(str(filepath)) + filepath = self._format_path(filepath) + return self._client.contains(filepath) class MemcachedBackend(BaseStorageBackend): diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 223ec14720..6eb9253690 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -144,11 +144,16 @@ def test_petrel_backend(self, backend, prefixes): assert img.shape == self.img_shape petrel_backend.client._client.Get.assert_called_with( str(self.img_path).replace(str(self.test_data_dir), petrel_path)) - # remove a file + # test `remove` petrel_backend.client._client.delete = MagicMock() petrel_backend.remove(self.img_path) petrel_backend.client._client.delete.assert_called_with( str(self.img_path).replace(str(self.test_data_dir), petrel_path)) + # test `check_exist` + petrel_backend.client._client.contains = MagicMock(return_value=True) + assert petrel_backend.check_exist(self.img_path) + petrel_backend.client._client.contains.assert_called_with( + str(self.img_path).replace(str(self.test_data_dir), petrel_path)) @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) From b930678e7e616a00241c99d3e77d55885d894159 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 3 Oct 2021 21:07:16 +0800 Subject: [PATCH 14/46] refactor unittest --- mmcv/fileio/file_client.py | 14 +++++++++++++ tests/test_fileclient.py | 43 +++++++++++++++++++++++++------------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 4526348112..9c37554330 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -144,6 +144,14 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: filepath = self._format_path(filepath) return self._client.contains(filepath) + def isfile(self, filepath: Union[str, Path]) -> bool: + filepath = self._path_mapping(str(filepath)) + filepath = self._format_path(filepath) + # petrel checks a filepath whether it is a file by its ending char + if filepath.endswith('/'): + return False + return self.check_exist(filepath) + class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -268,6 +276,9 @@ def remove(self, filepath: Union[str, Path]) -> None: def check_exist(self, filepath: Union[str, Path]) -> bool: return osp.exists(str(filepath)) + def isfile(self, filepath: Union[str, Path]) -> bool: + return osp.isfile(str(filepath)) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -502,3 +513,6 @@ def remove(self, filepath): def check_exist(self, filepath): return self.client.check_exist(filepath) + + def isfile(self, filepath): + return self.client.isfile(filepath) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 6eb9253690..0781c07d24 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -133,27 +133,42 @@ def test_petrel_backend(self, backend, prefixes): # `path_mapping` is either None or dict with pytest.raises(AssertionError): FileClient('petrel', path_mapping=1) - # test `path_mapping` + + # test `_path_mapping` petrel_path = 's3://user/data' petrel_backend = FileClient( 'petrel', path_mapping={str(self.test_data_dir): petrel_path}) - petrel_backend.client._client.Get = MagicMock( - return_value=petrel_backend.client._client.Get(self.img_path)) - img_bytes = petrel_backend.get(self.img_path) - img = mmcv.imfrombytes(img_bytes) - assert img.shape == self.img_shape - petrel_backend.client._client.Get.assert_called_with( - str(self.img_path).replace(str(self.test_data_dir), petrel_path)) + assert petrel_backend.client._path_mapping(str(self.img_path)) == \ + str(self.img_path).replace(str(self.test_data_dir), petrel_path) + + petrel_path = 's3://user/data/test.jpg' + petrel_backend = FileClient('petrel') + + # test `_format_path` + assert petrel_backend.client._format_path('s3://user\\data\\test.jpg')\ + == petrel_path + + # test `get` + petrel_backend.client._client.Get = MagicMock(return_value=b'petrel') + petrel_backend.get(petrel_path) + petrel_backend.client._client.Get.assert_called_with(petrel_path) + # test `remove` petrel_backend.client._client.delete = MagicMock() - petrel_backend.remove(self.img_path) - petrel_backend.client._client.delete.assert_called_with( - str(self.img_path).replace(str(self.test_data_dir), petrel_path)) + petrel_backend.remove(petrel_path) + petrel_backend.client._client.delete.assert_called_with(petrel_path) + # test `check_exist` petrel_backend.client._client.contains = MagicMock(return_value=True) - assert petrel_backend.check_exist(self.img_path) - petrel_backend.client._client.contains.assert_called_with( - str(self.img_path).replace(str(self.test_data_dir), petrel_path)) + assert petrel_backend.check_exist(petrel_path) + petrel_backend.client._client.contains.assert_called_with(petrel_path) + + # test `isfile` + petrel_backend.client._client.contains = MagicMock(return_value=True) + assert petrel_backend.isfile(petrel_path) + petrel_backend.client._client.contains.assert_called_with(petrel_path) + # if ending with '/', it is not a file + assert not petrel_backend.isfile(f'{petrel_path}/') @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) From 1752698aee5bda7fa05ba81f5406cc361791ae22 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 4 Oct 2021 21:49:37 +0800 Subject: [PATCH 15/46] singleton pattern --- mmcv/fileio/file_client.py | 107 +++++++++++++++++++++---------------- tests/test_fileclient.py | 92 ++++++++++++++++++++----------- 2 files changed, 123 insertions(+), 76 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 9c37554330..01b12b37a5 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -69,10 +69,10 @@ class PetrelBackend(BaseStorageBackend): the default cluster. Args: - path_mapping (dict|None): path mapping dict from local path to Petrel - path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will - be replaced by `dst`. Default: None. - enable_mc (bool): whether to enable memcached support. Default: True. + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When `path_mapping={'src': 'dst'}`, `src` in + `filepath` will be replaced by `dst`. Default: None. + enable_mc (bool): Whether to enable memcached support. Default: True. Examples: >>> filepath1 = 's3://path/of/file' @@ -297,30 +297,28 @@ class FileClient: The client loads a file or text in a specified backend from its path and return it as a binary or text file. There are two ways to choose a - backend, the name of backend and the prefixes of path. Although both of - them can be used to choose a storage backend, ``backend`` has a higher - priority that is if they are all set, the storage backend will be chosen by - the backend argument. If they are all `None`, the disk backend will be - chosen. Note that It can also register other backend accessor with a given - name, prefixes, and backend class. + backend, the name of backend and the prefix of path. Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creationIf the arguments are the same, the same + object is returned. Args: - backend (str): The storage backend type. Options are "disk", "ceph", - "memcached", "lmdb", "http" and "petrel". Default: None. - prefixes (str or list[str] or tuple[str]): The prefixes of the - registered storage backend. Options are "s3", "http", "https". - Default: None. - - .. versionadd:: 1.3.14 - The *prefixes* parameter. + backend (str, optional): The storage backend type. Options are "disk", + "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Default: None. Example: >>> # only set backend >>> file_client = FileClient(backend='ceph') >>> # only set prefixes - >>> file_client = FileClient(prefixes='s3') + >>> file_client = FileClient(prefix='s3') >>> # set both backend and prefixes but use backend to choose client - >>> file_client = FileClient(backend='ceph', prefixes='s3') + >>> file_client = FileClient(backend='ceph', prefix='s3') Attributes: client (:obj:`BaseStorageBackend`): The backend object. @@ -334,42 +332,55 @@ class FileClient: 'petrel': PetrelBackend, 'http': HTTPBackend, } + # This collection is used to record the overridden backend, and when a + # backend appears in the collection, the singleton pattern is disabled for + # that backend, because if the singleton pattern is used, then the object + # returned will be the backend before the override + _overridden_backends = set() _prefix_to_backends = { 's3': PetrelBackend, 'http': HTTPBackend, 'https': HTTPBackend, } + _overridden_prefixes = set() - def __init__(self, backend=None, prefixes=None, **kwargs): - if backend is None and prefixes is None: + _instances = {} + + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: backend = 'disk' - if backend is not None and backend not in self._backends: + if backend is not None and backend not in cls._backends: raise ValueError( f'Backend {backend} is not supported. Currently supported ones' - f' are {list(self._backends.keys())}') - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) + f' are {list(cls._backends.keys())}') + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f'prefix {prefix} is not supported. Currently supported ones ' + f'are {list(cls._prefix_to_backends.keys())}') - if not set(prefixes).issubset(self._prefix_to_backends.keys()): - raise ValueError( - f'prefixes {prefixes} is not supported. Currently ' - 'supported ones are ' - f'{list(self._prefix_to_backends.keys())}') + arg_key = f'{backend}:{prefix}' + for key, value in kwargs.items(): + arg_key += f':{key}:{value}' - if backend is not None: - self.client = self._backends[backend](**kwargs) + if (arg_key in cls._instances + and backend not in cls._overridden_backends + and prefix not in cls._overridden_prefixes): + _instance = cls._instances[arg_key] else: - for prefix in prefixes: - self.client = self._prefix_to_backends[prefix](**kwargs) - break + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + _instance.backend_name = backend + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + # infer the backend name according to prefix + for backend_name, backend_cls in cls._backends.items(): + if isinstance(_instance.client, backend_cls): + _instance.backend_name = backend_name + break + cls._instances[arg_key] = _instance - for backend_name, backend_cls in self._backends.items(): - if isinstance(self.client, backend_cls): - self.backend_name = backend_name - break + return _instance @staticmethod def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: @@ -405,7 +416,7 @@ def infer_client(cls, assert file_client_args is not None or uri is not None if file_client_args is None: file_prefix = cls.parse_uri_prefix(uri) # type: ignore - return cls(prefixes=file_prefix) + return cls(prefix=file_prefix) else: return cls(**file_client_args) @@ -425,14 +436,20 @@ def _register_backend(cls, name, backend, force=False, prefixes=None): f'{name} is already registered as a storage backend, ' 'add "force=True" if you want to override it') + if name in cls._backends and force: + cls._overridden_backends.add(name) cls._backends[name] = backend + if prefixes is not None: if isinstance(prefixes, str): prefixes = [prefixes] else: assert isinstance(prefixes, (list, tuple)) for prefix in prefixes: - if (prefix not in cls._prefix_to_backends) or force: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + cls._overridden_prefixes.add(prefix) cls._prefix_to_backends[prefix] = backend else: raise KeyError( diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 0781c07d24..c20e50eddd 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -116,10 +116,10 @@ def test_ceph_backend(self): str(self.img_path).replace(str(self.test_data_dir), ceph_path)) @patch('petrel_client.client.Client', MockPetrelClient) - @pytest.mark.parametrize('backend,prefixes', [('petrel', None), - (None, 's3')]) - def test_petrel_backend(self, backend, prefixes): - petrel_backend = FileClient(backend=backend, prefixes=prefixes) + @pytest.mark.parametrize('backend,prefix', [('petrel', None), + (None, 's3')]) + def test_petrel_backend(self, backend, prefix): + petrel_backend = FileClient(backend=backend, prefix=prefix) # input path is Path object img_bytes = petrel_backend.get(self.img_path) @@ -214,10 +214,10 @@ def test_lmdb_backend(self): img = mmcv.imfrombytes(img_bytes) assert img.shape == (120, 125, 3) - @pytest.mark.parametrize('backend,prefixes', [('http', None), - (None, 'http')]) - def test_http_backend(self, backend, prefixes): - http_backend = FileClient(backend=backend, prefixes=prefixes) + @pytest.mark.parametrize('backend,prefix', [('http', None), + (None, 'http')]) + def test_http_backend(self, backend, prefix): + http_backend = FileClient(backend=backend, prefix=prefix) img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ 'master/tests/data/color.jpg' text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ @@ -242,6 +242,36 @@ def test_http_backend(self, backend, prefixes): value_buf = http_backend.get_text(text_url) assert self.text_path.open('r').read() == value_buf + def test_new_magic_method(self): + + class DummyBackend1(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath, encoding='utf-8'): + return filepath + + FileClient.register_backend('dummy_backend', DummyBackend1) + client1 = FileClient(backend='dummy_backend') + client2 = FileClient(backend='dummy_backend') + assert client1 is client2 + + # if a backend is overwrote, it will disable the singleton pattern for + # the backend + class DummyBackend2(BaseStorageBackend): + + def get(self, filepath): + pass + + def get_text(self, filepath): + pass + + FileClient.register_backend('dummy_backend', DummyBackend2, force=True) + client3 = FileClient(backend='dummy_backend') + client4 = FileClient(backend='dummy_backend') + assert client3 is not client4 + def test_parse_uri_prefix(self): # input path is None with pytest.raises(AssertionError): @@ -323,7 +353,7 @@ def get_text(self, filepath, encoding='utf-8'): class Example2Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes2' + return b'bytes2' def get_text(self, filepath, encoding='utf-8'): return 'text2' @@ -334,20 +364,20 @@ def get_text(self, filepath, encoding='utf-8'): FileClient.register_backend('example', Example2Backend, force=True) example_backend = FileClient('example') - assert example_backend.get(self.img_path) == 'bytes2' + assert example_backend.get(self.img_path) == b'bytes2' assert example_backend.get_text(self.text_path) == 'text2' @FileClient.register_backend(name='example3') class Example3Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes3' + return b'bytes3' def get_text(self, filepath, encoding='utf-8'): return 'text3' example_backend = FileClient('example3') - assert example_backend.get(self.img_path) == 'bytes3' + assert example_backend.get(self.img_path) == b'bytes3' assert example_backend.get_text(self.text_path) == 'text3' assert 'example3' in FileClient._backends @@ -358,7 +388,7 @@ def get_text(self, filepath, encoding='utf-8'): class Example4Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes4' + return b'bytes4' def get_text(self, filepath, encoding='utf-8'): return 'text4' @@ -367,20 +397,20 @@ def get_text(self, filepath, encoding='utf-8'): class Example5Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes5' + return b'bytes5' def get_text(self, filepath, encoding='utf-8'): return 'text5' example_backend = FileClient('example3') - assert example_backend.get(self.img_path) == 'bytes5' + assert example_backend.get(self.img_path) == b'bytes5' assert example_backend.get_text(self.text_path) == 'text5' # prefixes is a str class Example6Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes6' + return b'bytes6' def get_text(self, filepath, encoding='utf-8'): return 'text6' @@ -391,20 +421,20 @@ def get_text(self, filepath, encoding='utf-8'): force=True, prefixes='example4_prefix') example_backend = FileClient('example4') - assert example_backend.get(self.img_path) == 'bytes6' + assert example_backend.get(self.img_path) == b'bytes6' assert example_backend.get_text(self.text_path) == 'text6' - example_backend = FileClient(prefixes='example4_prefix') - assert example_backend.get(self.img_path) == 'bytes6' + example_backend = FileClient(prefix='example4_prefix') + assert example_backend.get(self.img_path) == b'bytes6' assert example_backend.get_text(self.text_path) == 'text6' - example_backend = FileClient('example4', prefixes='example4_prefix') - assert example_backend.get(self.img_path) == 'bytes6' + example_backend = FileClient('example4', prefix='example4_prefix') + assert example_backend.get(self.img_path) == b'bytes6' assert example_backend.get_text(self.text_path) == 'text6' # prefixes is a list of str class Example7Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes7' + return b'bytes7' def get_text(self, filepath, encoding='utf-8'): return 'text7' @@ -415,20 +445,20 @@ def get_text(self, filepath, encoding='utf-8'): force=True, prefixes=['example5_prefix1', 'example5_prefix2']) example_backend = FileClient('example5') - assert example_backend.get(self.img_path) == 'bytes7' + assert example_backend.get(self.img_path) == b'bytes7' assert example_backend.get_text(self.text_path) == 'text7' - example_backend = FileClient(prefixes='example5_prefix1') - assert example_backend.get(self.img_path) == 'bytes7' + example_backend = FileClient(prefix='example5_prefix1') + assert example_backend.get(self.img_path) == b'bytes7' assert example_backend.get_text(self.text_path) == 'text7' - example_backend = FileClient(prefixes='example5_prefix2') - assert example_backend.get(self.img_path) == 'bytes7' + example_backend = FileClient(prefix='example5_prefix2') + assert example_backend.get(self.img_path) == b'bytes7' assert example_backend.get_text(self.text_path) == 'text7' # backend has a higher priority than prefixes class Example8Backend(BaseStorageBackend): def get(self, filepath): - return 'bytes8' + return b'bytes8' def get_text(self, filepath, encoding='utf-8'): return 'text8' @@ -439,8 +469,8 @@ def get_text(self, filepath, encoding='utf-8'): force=True, prefixes='example6_prefix') example_backend = FileClient('example6') - assert example_backend.get(self.img_path) == 'bytes8' + assert example_backend.get(self.img_path) == b'bytes8' assert example_backend.get_text(self.text_path) == 'text8' - example_backend = FileClient('example6', prefixes='example4_prefix') - assert example_backend.get(self.img_path) == 'bytes8' + example_backend = FileClient('example6', prefix='example4_prefix') + assert example_backend.get(self.img_path) == b'bytes8' assert example_backend.get_text(self.text_path) == 'text8' From fb9567c6847cc147a30edb0326ec57dd6743fd7f Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 4 Oct 2021 22:59:02 +0800 Subject: [PATCH 16/46] fix test_clientio.py --- tests/test_fileio.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_fileio.py b/tests/test_fileio.py index 5b701d6da9..556a44a133 100644 --- a/tests/test_fileio.py +++ b/tests/test_fileio.py @@ -156,7 +156,7 @@ def test_list_from_file(): filename, file_client_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmcv.list_from_file( - filename, file_client_args={'prefixes': 'http'}) + filename, file_client_args={'prefix': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmcv.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] @@ -169,7 +169,7 @@ def test_list_from_file(): filename, file_client_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmcv.list_from_file( - filename, file_client_args={'prefixes': 's3'}) + filename, file_client_args={'prefix': 's3'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmcv.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] @@ -191,7 +191,7 @@ def test_dict_from_file(): filename, file_client_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmcv.dict_from_file( - filename, file_client_args={'prefixes': 'http'}) + filename, file_client_args={'prefix': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmcv.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} @@ -205,7 +205,7 @@ def test_dict_from_file(): filename, file_client_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmcv.dict_from_file( - filename, file_client_args={'prefixes': 's3'}) + filename, file_client_args={'prefix': 's3'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmcv.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} From 00505f8bef7008de9bd439c2e8ed9b2417f99933 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 4 Oct 2021 23:42:32 +0800 Subject: [PATCH 17/46] deprecate CephBackend --- mmcv/fileio/file_client.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 01b12b37a5..42639a69b0 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -3,6 +3,7 @@ import os import os.path as osp import re +import warnings from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Optional, Union @@ -35,6 +36,9 @@ class CephBackend(BaseStorageBackend): path_mapping (dict|None): path mapping dict from local path to Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath`` will be replaced by ``dst``. Default: None. + + ..warning:: + :class:`CephBackend` is deprecated using :class:`PetrelBackend` instead """ def __init__(self, path_mapping=None): @@ -43,6 +47,7 @@ def __init__(self, path_mapping=None): except ImportError: raise ImportError('Please install ceph to enable CephBackend.') + warnings.warn('CephBackend is deprecated using PetrelBackend instead') self._client = ceph.S3Client() assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping @@ -293,7 +298,7 @@ def get_text(self, filepath, encoding='utf-8'): class FileClient: - """A general file client to access files in different backend. + """A general file client to access files in different backends. The client loads a file or text in a specified backend from its path and return it as a binary or text file. There are two ways to choose a @@ -303,7 +308,7 @@ class FileClient: backend argument. If they are all `None`, the disk backend will be chosen. Note that It can also register other backend accessor with a given name, prefixes, and backend class. In addition, We use the singleton pattern to - avoid repeated object creationIf the arguments are the same, the same + avoid repeated object creation. If the arguments are the same, the same object is returned. Args: From 225d3a63db0bef362eb419a6c7e6cf6793e19649 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 6 Oct 2021 12:48:53 +0800 Subject: [PATCH 18/46] enhance docstring --- mmcv/fileio/file_client.py | 32 +++++++++++++++++++++++--------- mmcv/fileio/io.py | 8 ++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 42639a69b0..0da6fd894c 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -37,7 +37,7 @@ class CephBackend(BaseStorageBackend): path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath`` will be replaced by ``dst``. Default: None. - ..warning:: + .. warning:: :class:`CephBackend` is deprecated using :class:`PetrelBackend` instead """ @@ -84,7 +84,7 @@ class PetrelBackend(BaseStorageBackend): >>> filepath2 = 'cluster-name:s3://path/of/file' >>> client = PetrelBackend() >>> client.get(filepath1) # get from default cluster - >>> client.get(filepath2) # get from cluster-name + >>> client.get(filepath2) # get from 'cluster-name' cluster """ def __init__(self, @@ -317,13 +317,17 @@ class FileClient: prefix (str, optional): The prefix of the registered storage backend. Options are "s3", "http", "https". Default: None. - Example: + Examples: >>> # only set backend - >>> file_client = FileClient(backend='ceph') - >>> # only set prefixes + >>> file_client = FileClient(backend='petrel') + >>> # only set prefix >>> file_client = FileClient(prefix='s3') - >>> # set both backend and prefixes but use backend to choose client - >>> file_client = FileClient(backend='ceph', prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='petrel', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='petrel') + >>> file_client1 is file_client + True Attributes: client (:obj:`BaseStorageBackend`): The backend object. @@ -393,6 +397,10 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: Args: uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' """ assert is_filepath(uri) uri = str(uri) @@ -400,8 +408,8 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: return None else: prefix, _ = uri.split('://') - # In the case of ceph, the prefix may contains the cluster name - # like clusterName:s3 + # In the case of PetrelBackend, the prefix may contains the cluster + # name like clusterName:s3 if ':' in prefix: _, prefix = prefix.split(':') return prefix @@ -417,6 +425,12 @@ def infer_client(cls, Default: None. uri (str | Path, optional): Uri to be parsed that contains the file prefix. Default: None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 'petrel'} + >>> file_client = FileClient.infer_client(file_client_args) """ assert file_client_args is not None or uri is not None if file_client_args is None: diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index c5206f208e..33fd8f38b3 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -20,6 +20,10 @@ def load(file, file_format=None, file_client_args=None, **kwargs): This method provides a unified api for loading data from serialized files. + Note: + In v1.3.15 and later, :function:`load` supports loading data from + serialized files those can be storaged in different backends. + Args: file (str or :obj:`Path` or file-like object): Filename or a file-like object. @@ -62,6 +66,10 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): This method provides a unified api for dumping data as strings or to files, and also supports custom arguments for each file format. + Note: + In v1.3.15 and later, :function:`dump` supports dumping data as strings + or to different backends. + Args: obj (any): The python object to be dumped. file (str or :obj:`Path` or file-like object, optional): If not From 22644dadcde556258953b636f5110327b2f92725 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 6 Oct 2021 13:45:16 +0800 Subject: [PATCH 19/46] refactor unittest for petrel --- mmcv/fileio/io.py | 29 ++++++++++++++++++++--------- mmcv/fileio/parse.py | 32 ++++++++++++++++++++++++++++---- mmcv/image/photometric.py | 2 +- tests/test_fileclient.py | 19 ++++++++++++++++++- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 33fd8f38b3..430711c2cd 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -21,8 +21,8 @@ def load(file, file_format=None, file_client_args=None, **kwargs): This method provides a unified api for loading data from serialized files. Note: - In v1.3.15 and later, :function:`load` supports loading data from - serialized files those can be storaged in different backends. + In v1.3.15 and later, `load` supports loading data from serialized + files those can be storaged in different backends. Args: file (str or :obj:`Path` or file-like object): Filename or a file-like @@ -31,8 +31,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs): inferred from the file extension, otherwise use the specified one. Currently supported formats include "json", "yaml/yml" and "pickle/pkl". - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmcv.fileio.FileClient` for details. Default: None. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storage in Internet + >>> load('s3://path/of/your/file') # file is storage in ceph or petrel Returns: The content from the file. @@ -67,17 +73,22 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): and also supports custom arguments for each file format. Note: - In v1.3.15 and later, :function:`dump` supports dumping data as strings - or to different backends. + In v1.3.15 and later, `dump` supports dumping data as strings or to + files which is saved to different backends. Args: obj (any): The python object to be dumped. file (str or :obj:`Path` or file-like object, optional): If not - specified, then the object is dump to a str, otherwise to a file + specified, then the object is dumped to a str, otherwise to a file specified by the filename or file-like object. file_format (str, optional): Same as :func:`load`. - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmcv.fileio.FileClient` for details. Default: None. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel Returns: bool: True for success, False otherwise. diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index d7242c0501..3337a289c3 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -13,6 +13,11 @@ def list_from_file(filename, file_client_args=None): """Load a text file and parse the content as a list of strings. + Note: + In v1.3.15 and later, `list_from_file` supports loading a text file + which can be storaged in different backends and parsing the content as + a list for strings. + Args: filename (str): Filename. prefix (str): The prefix to be inserted to the begining of each item. @@ -20,8 +25,15 @@ def list_from_file(filename, max_num (int): The maximum number of lines to be read, zeros and negatives mean no limitation. encoding (str): Encoding used to open the file. Default utf-8. - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmcv.fileio.FileClient` for details. Default: None. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> list_from_file('/path/of/your/file') # disk + ['hello', 'world'] + >>> list_from_file('s3://path/of/your/file') # ceph or petrel + ['hello', 'world'] Returns: list[str]: A list of strings. @@ -50,13 +62,25 @@ def dict_from_file(filename, whitespaces or tabs. The first column will be parsed as dict keys, and the following columns will be parsed as dict values. + Note: + In v1.3.15 and later, `dict_from_file` supports loading a text file + which can be storaged in different backends and parsing the content as + a dict. + Args: filename(str): Filename. key_type(type): Type of the dict keys. str is user by default and type conversion will be performed if specified. encoding (str): Encoding used to open the file. Default utf-8. - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmcv.fileio.FileClient` for details. Default: None. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dict_from_file('/path/of/your/file') # disk + {'key1': 'value1', 'key2': 'value2'} + >>> dict_from_file('s3://path/of/your/file') # ceph or petrel + {'key1': 'value1', 'key2': 'value2'} Returns: dict: The parsed contents. diff --git a/mmcv/image/photometric.py b/mmcv/image/photometric.py index 3c1f68f1f5..e234394084 100644 --- a/mmcv/image/photometric.py +++ b/mmcv/image/photometric.py @@ -309,7 +309,7 @@ def adjust_sharpness(img, factor=1., kernel=None): kernel (np.ndarray, optional): Filter kernel to be applied on the img to obtain the degenerated img. Defaults to None. - Notes:: + Note: No value sanity check is enforced on the kernel set by users. So with an inappropriate kernel, the `adjust_sharpness` may fail to perform the function its name indicates but end up performing whatever diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index c20e50eddd..da6f119ae1 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -150,9 +150,26 @@ def test_petrel_backend(self, backend, prefix): # test `get` petrel_backend.client._client.Get = MagicMock(return_value=b'petrel') - petrel_backend.get(petrel_path) + assert petrel_backend.get(petrel_path) == b'petrel' petrel_backend.client._client.Get.assert_called_with(petrel_path) + # test `get_text` + petrel_backend.client._client.Get = MagicMock(return_value=b'petrel') + assert petrel_backend.get_text(petrel_path) == 'petrel' + petrel_backend.client._client.Get.assert_called_with(petrel_path) + + # test `put` + petrel_backend.client._client.put = MagicMock() + petrel_backend.put(b'petrel', petrel_path) + petrel_backend.client._client.put.assert_called_with( + petrel_path, b'petrel') + + # test `put_text` + petrel_backend.client._client.put = MagicMock() + petrel_backend.put_text('petrel', petrel_path) + petrel_backend.client._client.put.assert_called_with( + petrel_path, b'petrel') + # test `remove` petrel_backend.client._client.delete = MagicMock() petrel_backend.remove(petrel_path) From 058b7e89609f8b9609378f1f0ede7c334ff7416d Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 6 Oct 2021 17:07:19 +0800 Subject: [PATCH 20/46] refactor unittest for disk backend --- tests/test_fileclient.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index da6f119ae1..116ef42022 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -1,4 +1,5 @@ import sys +import tempfile from pathlib import Path from unittest.mock import MagicMock, patch @@ -62,6 +63,7 @@ def test_error(self): def test_disk_backend(self): disk_backend = FileClient('disk') + # test `get` # input path is Path object img_bytes = disk_backend.get(self.img_path) img = mmcv.imfrombytes(img_bytes) @@ -73,6 +75,7 @@ def test_disk_backend(self): assert self.img_path.open('rb').read() == img_bytes assert img.shape == self.img_shape + # test `get_text` # input path is Path object value_buf = disk_backend.get_text(self.text_path) assert self.text_path.open('r').read() == value_buf @@ -80,6 +83,26 @@ def test_disk_backend(self): value_buf = disk_backend.get_text(str(self.text_path)) assert self.text_path.open('r').read() == value_buf + with tempfile.TemporaryDirectory() as tmp_dir: + # test `put` + filepath1 = Path(tmp_dir) / 'test_put' + disk_backend.put(b'hello world', filepath1) + assert filepath1.open('rb').read() == b'hello world' + + # test `put_text` + filepath2 = Path(tmp_dir) / 'test_put_text' + disk_backend.put_text('hello world', filepath2) + assert filepath2.open('r').read() == 'hello world' + + # test `isfile` + assert disk_backend.isfile(filepath2) + + # test `remove` + disk_backend.remove(filepath2) + + # test `check_exist` + assert not disk_backend.check_exist(filepath2) + @patch('ceph.S3Client', MockS3Client) def test_ceph_backend(self): ceph_backend = FileClient('ceph') From 16926788239dc4ddf3304c9eabaf0386d448bd05 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 6 Oct 2021 22:49:01 +0800 Subject: [PATCH 21/46] update io.md --- docs/understand_mmcv/io.md | 50 ++++++++++++++++++++++++++++++-- docs_zh_CN/understand_mmcv/io.md | 48 +++++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/docs/understand_mmcv/io.md b/docs/understand_mmcv/io.md index 50314d13d0..521d4dd773 100644 --- a/docs/understand_mmcv/io.md +++ b/docs/understand_mmcv/io.md @@ -2,11 +2,17 @@ This module provides two universal API to load and dump files of different formats. +```{note} +In v1.3.15 and later, `File IO` also supports loading data from different backends and dumping data to different backends. More datails at https://github.com/open-mmlab/mmcv/pull/1330. +``` + ### Load and dump data `mmcv` provides a universal api for loading and dumping data, currently supported formats are json, yaml and pickle. ++ Load from disk or dump to disk + ```python import mmcv @@ -29,6 +35,20 @@ with open('test.yaml', 'w') as f: data = mmcv.dump(data, f, file_format='yaml') ``` ++ Load from other backends or dump to other backends + +```python +import mmcv + +# load data from a file +data = mmcv.load('s3://bucket-name/test.json') +data = mmcv.load('s3://bucket-name/test.yaml') +data = mmcv.load('s3://bucket-name/test.pkl') + +# dump data to a file with a filename (infer format from file extension) +mmcv.dump(data, 's3://bucket-name/out.pkl') +``` + It is also very convenient to extend the api to support more file formats. All you need to do is to write a file handler inherited from `BaseFileHandler` and register it with one or several file formats. @@ -92,7 +112,9 @@ d e ``` -Then use `list_from_file` to load the list from a.txt. ++ Load from disk + +Use `list_from_file` to load the list from a.txt. ```python >>> mmcv.list_from_file('a.txt') @@ -113,7 +135,7 @@ For example `b.txt` is a text file with 3 lines. 3 panda ``` -Then use `dict_from_file` to load the dict from `b.txt` . +Then use `dict_from_file` to load the dict from `b.txt`. ```python >>> mmcv.dict_from_file('b.txt') @@ -121,3 +143,27 @@ Then use `dict_from_file` to load the dict from `b.txt` . >>> mmcv.dict_from_file('b.txt', key_type=int) {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` + ++ load from other backends + +Use `list_from_file` to load the list from `s3://bucket-name/a.txt`. + +```python +>>> mmcv.list_from_file('s3://bucket-name/a.txt') +['a', 'b', 'c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2) +['c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2) +['a', 'b'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/') +['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] +``` + +Use `dict_from_file` to load the dict from `s3://bucket-name/b.txt`. + +```python +>>> mmcv.dict_from_file('s3://bucket-name/b.txt') +{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} +>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int) +{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} +``` diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index 8d3844f77c..13e7732da8 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -2,10 +2,16 @@ 文件输入输出模块提供了两个通用的 API 接口用于读取和保存不同格式的文件。 +```{note} +在 v1.3.15 及之后的版本中,`File IO` 支持从不同后端读取数据和将数据保存至不同后端。更多细节请访问 https://github.com/open-mmlab/mmcv/pull/1330。 +``` + ### 读取和保存数据 `mmcv` 提供了一个通用的 api 用于读取和保存数据,目前支持的格式有 json、yaml 和 pickle。 ++ 从硬盘读取数据或者将数据保存至硬盘 + ```python import mmcv @@ -28,6 +34,20 @@ with open('test.yaml', 'w') as f: data = mmcv.dump(data, f, file_format='yaml') ``` ++ 从其他后端加载或者保存至其他后端 + +```python +import mmcv + +# 从 s3 文件读取数据 +data = mmcv.load('s3://bucket-name/test.json') +data = mmcv.load('s3://bucket-name/test.yaml') +data = mmcv.load('s3://bucket-name/test.pkl') + +# 将数据保存至 s3 文件 (根据文件名后缀反推文件类型) +mmcv.dump(data, 's3://bucket-name/out.pkl') +``` + 我们提供了易于拓展的方式以支持更多的文件格式。我们只需要创建一个继承自 `BaseFileHandler` 的 文件句柄类并将其注册到 `mmcv` 中即可。句柄类至少需要重写三个方法。 @@ -88,6 +108,8 @@ d e ``` ++ 从硬盘读取 + 使用 `list_from_file` 读取 `a.txt` 。 ```python @@ -109,7 +131,7 @@ e 3 panda ``` -使用 `dict_from_file` 读取 `b.txt` 。 +使用 `dict_from_file` 读取 `b.txt`。 ```python >>> mmcv.dict_from_file('b.txt') @@ -117,3 +139,27 @@ e >>> mmcv.dict_from_file('b.txt', key_type=int) {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` + ++ 从其他后端读取 + +使用 `list_from_file` 读取 `s3://bucket-name/a.txt` 。 + +```python +>>> mmcv.list_from_file('s3://bucket-name/a.txt') +['a', 'b', 'c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2) +['c', 'd', 'e'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2) +['a', 'b'] +>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/') +['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] +``` + +使用 `dict_from_file` 读取 `b.txt`。 + +```python +>>> mmcv.dict_from_file('s3://bucket-name/b.txt') +{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} +>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int) +{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} +``` From 01b9807febeb364076d42e69c56b77573775bbc9 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 6 Oct 2021 23:16:08 +0800 Subject: [PATCH 22/46] add concat_paths method --- mmcv/fileio/file_client.py | 14 +++++++++++++- tests/test_fileclient.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 0da6fd894c..234cc83e3a 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -157,6 +157,12 @@ def isfile(self, filepath: Union[str, Path]) -> bool: return False return self.check_exist(filepath) + def concat_paths(self, path, *paths) -> str: + formatted_paths = [self._format_path(self._path_mapping(path))] + for path in paths: + formatted_paths.append(self._format_path(self._path_mapping(path))) + return '/'.join(formatted_paths) + class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -284,6 +290,9 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: def isfile(self, filepath: Union[str, Path]) -> bool: return osp.isfile(str(filepath)) + def concat_paths(self, path, *paths): + return osp.join(path, *paths) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -341,7 +350,7 @@ class FileClient: 'petrel': PetrelBackend, 'http': HTTPBackend, } - # This collection is used to record the overridden backend, and when a + # This collection is used to record the overridden backends, and when a # backend appears in the collection, the singleton pattern is disabled for # that backend, because if the singleton pattern is used, then the object # returned will be the backend before the override @@ -552,3 +561,6 @@ def check_exist(self, filepath): def isfile(self, filepath): return self.client.isfile(filepath) + + def concat_paths(self, path, *paths): + return self.client.concat_paths(path, *paths) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 116ef42022..58965544b0 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -1,3 +1,4 @@ +import os.path as osp import sys import tempfile from pathlib import Path @@ -103,6 +104,12 @@ def test_disk_backend(self): # test `check_exist` assert not disk_backend.check_exist(filepath2) + disk_dir = '/path/of/your/directory' + assert disk_backend.concat_paths(disk_dir, 'file') == \ + osp.join(disk_dir, 'file') + assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \ + osp.join(disk_dir, 'dir', 'file') + @patch('ceph.S3Client', MockS3Client) def test_ceph_backend(self): ceph_backend = FileClient('ceph') @@ -210,6 +217,13 @@ def test_petrel_backend(self, backend, prefix): # if ending with '/', it is not a file assert not petrel_backend.isfile(f'{petrel_path}/') + # test `concat_paths` + petrel_dir = 's3://path/of/your/directory' + assert petrel_backend.concat_paths(petrel_dir, 'file') == \ + f'{petrel_dir}/file' + assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ + f'{petrel_dir}/dir/file' + @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) @patch('mc.ConvertBuffer', lambda x: x.content) From fed5a3937116e6bdaa71ba4d0bae5923caa37cd6 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 8 Oct 2021 17:55:28 +0800 Subject: [PATCH 23/46] improve docstring --- mmcv/fileio/file_client.py | 229 ++++++++++++++++++++++++++++++++----- 1 file changed, 198 insertions(+), 31 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 234cc83e3a..a786dc66a7 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -38,7 +38,8 @@ class CephBackend(BaseStorageBackend): will be replaced by ``dst``. Default: None. .. warning:: - :class:`CephBackend` is deprecated using :class:`PetrelBackend` instead + :class:`CephBackend` will be deprecated, please use + :class:`PetrelBackend` instead """ def __init__(self, path_mapping=None): @@ -47,7 +48,8 @@ def __init__(self, path_mapping=None): except ImportError: raise ImportError('Please install ceph to enable CephBackend.') - warnings.warn('CephBackend is deprecated using PetrelBackend instead') + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') self._client = ceph.S3Client() assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping @@ -101,6 +103,11 @@ def __init__(self, self.path_mapping = path_mapping def _path_mapping(self, filepath: str) -> str: + """Replace the prefix of filepath with path_mapping. + + Args: + filepath (str): Path to be mapped. + """ if self.path_mapping is not None: for k, v in self.path_mapping.items(): filepath = filepath.replace(k, v) @@ -113,10 +120,18 @@ def _format_path(self, filepath: str) -> str: environment, the filepath will be the format of 's3://bucket_name\\image.jpg'. By invoking `_format_path`, the above filepath will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. """ return re.sub(r'\\+', '/', filepath) def get(self, filepath: Union[str, Path]) -> memoryview: + """Read data from a given filepath with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + """ filepath = self._path_mapping(str(filepath)) filepath = self._format_path(filepath) value = self._client.Get(filepath) @@ -126,9 +141,22 @@ def get(self, filepath: Union[str, Path]) -> memoryview: def get_text(self, filepath: Union[str, Path], encoding: str = 'utf-8') -> str: + """Read data from a given filepath with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ return str(self.get(filepath), encoding=encoding) def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Save data to a given filepath. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + """ filepath = self._path_mapping(str(filepath)) filepath = self._format_path(filepath) self._client.put(filepath, obj) @@ -137,19 +165,42 @@ def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = 'utf-8') -> None: + """Save data to a given filepath. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to encode the + `obj`. Default: 'utf-8'. + """ self.put(bytes(obj, encoding=encoding), filepath) def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ filepath = self._path_mapping(str(filepath)) filepath = self._format_path(filepath) self._client.delete(filepath) def check_exist(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + """ filepath = self._path_mapping(str(filepath)) filepath = self._format_path(filepath) return self._client.contains(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether it is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + """ filepath = self._path_mapping(str(filepath)) filepath = self._format_path(filepath) # petrel checks a filepath whether it is a file by its ending char @@ -157,10 +208,19 @@ def isfile(self, filepath: Union[str, Path]) -> bool: return False return self.check_exist(filepath) - def concat_paths(self, path, *paths) -> str: - formatted_paths = [self._format_path(self._path_mapping(path))] - for path in paths: - formatted_paths.append(self._format_path(self._path_mapping(path))) + def concat_paths(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + """ + formatted_paths = [ + self._format_path(self._path_mapping(str(filepath))) + ] + for path in filepaths: + formatted_paths.append( + self._format_path(self._path_mapping(str(path)))) return '/'.join(formatted_paths) @@ -257,41 +317,97 @@ def get_text(self, filepath, encoding=None): class HardDiskBackend(BaseStorageBackend): """Raw hard disks storage backend.""" - def get(self, filepath): + def get(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given filepath with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + """ filepath = str(filepath) with open(filepath, 'rb') as f: value_buf = f.read() return value_buf - def get_text(self, filepath, encoding='utf-8'): + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given filepath with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ filepath = str(filepath) with open(filepath, 'r', encoding=encoding) as f: value_buf = f.read() return value_buf - def put(self, obj, filepath): + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given filepath with 'wb' mode. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ filepath = str(filepath) with open(filepath, 'wb') as f: f.write(obj) - def put_text(self, obj, filepath, encoding='utf-8'): + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given filepath with 'w' mode. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ filepath = str(filepath) with open(filepath, 'w', encoding=encoding) as f: f.write(obj) def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file.""" + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ filepath = str(filepath) os.remove(filepath) def check_exist(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + """ return osp.exists(str(filepath)) def isfile(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether it is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + """ return osp.isfile(str(filepath)) - def concat_paths(self, path, *paths): - return osp.join(path, *paths) + def concat_paths(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all filepaths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + """ + filepath = str(filepath) + filepaths = [str(path) for path in filepaths] + return osp.join(filepath, *filepaths) class HTTPBackend(BaseStorageBackend): @@ -310,7 +426,7 @@ class FileClient: """A general file client to access files in different backends. The client loads a file or text in a specified backend from its path - and return it as a binary or text file. There are two ways to choose a + and returns it as a binary or text file. There are two ways to choose a backend, the name of backend and the prefix of path. Although both of them can be used to choose a storage backend, ``backend`` has a higher priority that is if they are all set, the storage backend will be chosen by the @@ -391,7 +507,7 @@ def __new__(cls, backend=None, prefix=None, **kwargs): _instance.backend_name = backend else: _instance.client = cls._prefix_to_backends[prefix](**kwargs) - # infer the backend name according to prefix + # infer the backend name according to the prefix for backend_name, backend_cls in cls._backends.items(): if isinstance(_instance.client, backend_cls): _instance.backend_name = backend_name @@ -407,6 +523,9 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: Args: uri (str | Path): Uri to be parsed that contains the file prefix. + Returns: + return the prefix of uri if it contains "://" else None. + Examples: >>> FileClient.parse_uri_prefix('s3://path/of/your/file') 's3' @@ -430,8 +549,8 @@ def infer_client(cls, """Infer a suitable file client based on the URI and arguments. Args: - file_client_args (dict): Arguments to instantiate a FileClient. - Default: None. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Default: None. uri (str | Path, optional): Uri to be parsed that contains the file prefix. Default: None. @@ -523,11 +642,9 @@ def get_text(self, filepath): Defaults to None. force (bool, optional): Whether to override the backend if the name has already been registered. Defaults to False. - prefixes (str or list[str] or tuple[str]): The prefix of the - registered storage backend. - - .. versionadd:: 1.3.14 - The *prefixes* parameter. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Default: None. + `New in version 1.3.15.` """ if backend is not None: cls._register_backend( @@ -541,26 +658,76 @@ def _register(backend_cls): return _register - def get(self, filepath): + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given filepath with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + """ return self.client.get(filepath) - def get_text(self, filepath, encoding='utf-8'): + def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: + """Read data from a given filepath with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ return self.client.get_text(filepath, encoding) - def put(self, obj, filepath): + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given filepath with 'wb' mode. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ self.client.put(obj, filepath) - def put_text(self, obj, filepath): + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given filepath with 'w' mode. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ self.client.put_text(obj, filepath) - def remove(self, filepath): + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ self.client.remove(filepath) - def check_exist(self, filepath): + def check_exist(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + """ return self.client.check_exist(filepath) - def isfile(self, filepath): + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether it is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + """ return self.client.isfile(filepath) - def concat_paths(self, path, *paths): - return self.client.concat_paths(path, *paths) + def concat_paths(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all filepaths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + """ + return self.client.concat_paths(filepath, *filepaths) From 495968760ef2f931ae660dacc3ff73cf6a5eb845 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 8 Oct 2021 20:08:34 +0800 Subject: [PATCH 24/46] improve docstring --- docs/understand_mmcv/io.md | 10 +++++----- docs_zh_CN/understand_mmcv/io.md | 13 ++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/understand_mmcv/io.md b/docs/understand_mmcv/io.md index 521d4dd773..c83ba9bdc1 100644 --- a/docs/understand_mmcv/io.md +++ b/docs/understand_mmcv/io.md @@ -3,7 +3,7 @@ This module provides two universal API to load and dump files of different formats. ```{note} -In v1.3.15 and later, `File IO` also supports loading data from different backends and dumping data to different backends. More datails at https://github.com/open-mmlab/mmcv/pull/1330. +Since v1.3.15, the IO modules support loading and dumping data from and to different backends, respectively. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330). ``` ### Load and dump data @@ -11,7 +11,7 @@ In v1.3.15 and later, `File IO` also supports loading data from different backen `mmcv` provides a universal api for loading and dumping data, currently supported formats are json, yaml and pickle. -+ Load from disk or dump to disk +#### Load from disk or dump to disk ```python import mmcv @@ -35,7 +35,7 @@ with open('test.yaml', 'w') as f: data = mmcv.dump(data, f, file_format='yaml') ``` -+ Load from other backends or dump to other backends +#### Load from other backends or dump to other backends ```python import mmcv @@ -112,7 +112,7 @@ d e ``` -+ Load from disk +#### Load from disk Use `list_from_file` to load the list from a.txt. @@ -144,7 +144,7 @@ Then use `dict_from_file` to load the dict from `b.txt`. {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` -+ load from other backends +#### Load from other backends Use `list_from_file` to load the list from `s3://bucket-name/a.txt`. diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index 13e7732da8..b6c2a8b224 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -3,14 +3,14 @@ 文件输入输出模块提供了两个通用的 API 接口用于读取和保存不同格式的文件。 ```{note} -在 v1.3.15 及之后的版本中,`File IO` 支持从不同后端读取数据和将数据保存至不同后端。更多细节请访问 https://github.com/open-mmlab/mmcv/pull/1330。 +在 v1.3.15 及之后的版本中,IO 模块支持从不同后端读取数据和将数据保存至不同后端。更多细节请访问 PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330)。 ``` ### 读取和保存数据 `mmcv` 提供了一个通用的 api 用于读取和保存数据,目前支持的格式有 json、yaml 和 pickle。 -+ 从硬盘读取数据或者将数据保存至硬盘 +#### 从硬盘读取数据或者将数据保存至硬盘 ```python import mmcv @@ -34,7 +34,7 @@ with open('test.yaml', 'w') as f: data = mmcv.dump(data, f, file_format='yaml') ``` -+ 从其他后端加载或者保存至其他后端 +#### 从其他后端加载或者保存至其他后端 ```python import mmcv @@ -69,7 +69,7 @@ class TxtHandler1(mmcv.BaseFileHandler): return str(obj) ``` -举 `PickleHandler` 为例。 +举 `PickleHandler` 为例 ```python import pickle @@ -107,8 +107,7 @@ c d e ``` - -+ 从硬盘读取 +#### 从硬盘读取 使用 `list_from_file` 读取 `a.txt` 。 @@ -140,7 +139,7 @@ e {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` -+ 从其他后端读取 +#### 从其他后端读取 使用 `list_from_file` 读取 `s3://bucket-name/a.txt` 。 From aea920a64246c96227b3c396256b87e5c387cd9a Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 10 Oct 2021 16:00:13 +0800 Subject: [PATCH 25/46] add isdir and copyfile for file backend --- mmcv/fileio/file_client.py | 97 ++++++++++++++++++++++++++++++++++++++ tests/test_fileclient.py | 81 ++++++++++++++++++++++++++----- 2 files changed, 167 insertions(+), 11 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index a786dc66a7..aea1d0c417 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -3,6 +3,7 @@ import os import os.path as osp import re +import shutil import warnings from abc import ABCMeta, abstractmethod from pathlib import Path @@ -195,6 +196,20 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: filepath = self._format_path(filepath) return self._client.contains(filepath) + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether it is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + """ + filepath = self._path_mapping(str(filepath)) + filepath = self._format_path(filepath) + # petrel checks a filepath whether it is a file by its ending char + if not filepath.endswith('/'): + return False + return self.check_exist(filepath) + def isfile(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether it is a file. @@ -223,6 +238,32 @@ def concat_paths(self, filepath: Union[str, Path], self._format_path(self._path_mapping(str(path)))) return '/'.join(formatted_paths) + def copyfile(self, + src_path: Union[str, Path], + dst_path: Union[str, Path], + force: bool = False) -> None: + """Copy a file from src_path starting with 's3' to dst_path which is a + disk path. + + Args: + src_path (str or Path): Download a file from ``src_path``. + dst_path (str or Path): Save a file to ``dst_path``. + force (bool): Whether to overwrite the file when ``dst_path`` + exists. Default: False. + """ + src_path = str(src_path) + dst_path = str(dst_path) + + if self.isdir(src_path) or osp.isdir(dst_path): + raise ValueError('src_path and dst_path should both be a path of ' + 'file rather than a directory') + if not self.isfile(src_path): + raise FileNotFoundError(f'src_path {src_path} cannot be found') + + if force or not osp.isfile(dst_path): + with open(dst_path, 'wb') as f: + f.write(self.get(src_path)) + class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -387,6 +428,15 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: """ return osp.exists(str(filepath)) + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether it is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + """ + return osp.isdir(str(filepath)) + def isfile(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether it is a file. @@ -409,6 +459,30 @@ def concat_paths(self, filepath: Union[str, Path], filepaths = [str(path) for path in filepaths] return osp.join(filepath, *filepaths) + def copyfile(self, + src_path: Union[str, Path], + dst_path: Union[str, Path], + force: bool = False) -> None: + """Copy a file from src_path to dst_path. + + Args: + src_path (str or Path): Copy a file from ``src_path``. + dst_path (str or Path): Save a file to ``dst_path``. + force (bool): Whether to overwrite the file when ``dst_path`` + exists. Default: False. + """ + src_path = str(src_path) + dst_path = str(dst_path) + + if self.isdir(src_path) or self.isdir(dst_path): + raise ValueError('src_path and dst_path should both be a path of ' + 'file rather than a directory') + if not self.isfile(src_path): + raise FileNotFoundError(f'src_path {src_path} cannot be found') + + if force or not self.isfile(dst_path): + shutil.copyfile(src_path, dst_path) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -712,6 +786,15 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: """ return self.client.check_exist(filepath) + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check a filepath whether it is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + """ + return self.client.isdir(filepath) + def isfile(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether it is a file. @@ -731,3 +814,17 @@ def concat_paths(self, filepath: Union[str, Path], filepath (str or Path): Path to be concatenated. """ return self.client.concat_paths(filepath, *filepaths) + + def copyfile(self, + src_path: Union[str, Path], + dst_path: Union[str, Path], + force: bool = False) -> None: + """Copy a file from src_path to dst_path. + + Args: + src_path (str or Path): Copy a file from ``src_path``. + dst_path (str or Path): Save a file to ``dst_path``. + force (bool): Whether to overwrite the file when ``dst_path`` + exists. Default: False. + """ + self.client.copyfile(src_path, dst_path, force) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 58965544b0..e7e2706c9d 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -86,17 +86,21 @@ def test_disk_backend(self): with tempfile.TemporaryDirectory() as tmp_dir: # test `put` - filepath1 = Path(tmp_dir) / 'test_put' - disk_backend.put(b'hello world', filepath1) - assert filepath1.open('rb').read() == b'hello world' + filepath1 = Path(tmp_dir) / 'test.jpg' + disk_backend.put(b'disk', filepath1) + assert filepath1.open('rb').read() == b'disk' # test `put_text` - filepath2 = Path(tmp_dir) / 'test_put_text' - disk_backend.put_text('hello world', filepath2) - assert filepath2.open('r').read() == 'hello world' + filepath2 = Path(tmp_dir) / 'test.txt' + disk_backend.put_text('disk', filepath2) + assert filepath2.open('r').read() == 'disk' + + # test `isdir` + assert disk_backend.isdir(tmp_dir) # test `isfile` assert disk_backend.isfile(filepath2) + assert not disk_backend.isfile(Path(tmp_dir) / 'not/existed/path') # test `remove` disk_backend.remove(filepath2) @@ -104,6 +108,27 @@ def test_disk_backend(self): # test `check_exist` assert not disk_backend.check_exist(filepath2) + # test `copyfile` + filepath3 = Path(tmp_dir) / 'test1.jpg' + with pytest.raises(ValueError): + # src_path and dst_path should both be a path of file rather + # than a directory + disk_backend.copyfile(src_path=tmp_dir, dst_path=tmp_dir) + with pytest.raises(FileNotFoundError): + # src_path should exist + disk_backend.copyfile( + src_path='/not/existed/path', dst_path=filepath3) + disk_backend.copyfile(filepath1, filepath3) + assert disk_backend.isfile(filepath3) + # force = False + disk_backend.put(b'overwrite disk', filepath3) + # filepath2 exists + disk_backend.copyfile(filepath1, filepath3, force=False) + assert filepath3.open('rb').read() == b'overwrite disk' + # force = True + disk_backend.copyfile(filepath1, filepath3, force=True) + assert filepath3.open('rb').read() == b'disk' + disk_dir = '/path/of/your/directory' assert disk_backend.concat_paths(disk_dir, 'file') == \ osp.join(disk_dir, 'file') @@ -165,13 +190,13 @@ def test_petrel_backend(self, backend, prefix): FileClient('petrel', path_mapping=1) # test `_path_mapping` - petrel_path = 's3://user/data' + petrel_dir = 's3://user/data' petrel_backend = FileClient( - 'petrel', path_mapping={str(self.test_data_dir): petrel_path}) + 'petrel', path_mapping={str(self.test_data_dir): petrel_dir}) assert petrel_backend.client._path_mapping(str(self.img_path)) == \ - str(self.img_path).replace(str(self.test_data_dir), petrel_path) + str(self.img_path).replace(str(self.test_data_dir), petrel_dir) - petrel_path = 's3://user/data/test.jpg' + petrel_path = f'{petrel_dir}/test.jpg' petrel_backend = FileClient('petrel') # test `_format_path` @@ -210,20 +235,54 @@ def test_petrel_backend(self, backend, prefix): assert petrel_backend.check_exist(petrel_path) petrel_backend.client._client.contains.assert_called_with(petrel_path) + # test `isdir` + assert petrel_backend.isdir(f'{petrel_dir}/') + # directory should end with '/' + assert not petrel_backend.isdir(petrel_dir) + # test `isfile` petrel_backend.client._client.contains = MagicMock(return_value=True) + petrel_backend.client._client.contains = MagicMock(return_value=True) assert petrel_backend.isfile(petrel_path) petrel_backend.client._client.contains.assert_called_with(petrel_path) # if ending with '/', it is not a file assert not petrel_backend.isfile(f'{petrel_path}/') # test `concat_paths` - petrel_dir = 's3://path/of/your/directory' assert petrel_backend.concat_paths(petrel_dir, 'file') == \ f'{petrel_dir}/file' assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ f'{petrel_dir}/dir/file' + # test `copyfile` + with tempfile.TemporaryDirectory() as tmp_dir: + disk_path = Path(tmp_dir) / 'test.jpg' + with pytest.raises(ValueError): + # src_path and dst_path should both be a path of file rather + # than a directory + petrel_backend.copyfile(src_path=petrel_dir, dst_path=tmp_dir) + with pytest.raises(FileNotFoundError): + # src_path should exist + petrel_backend.client._client.contains = MagicMock( + return_value=False) + petrel_backend.copyfile( + src_path=petrel_path, dst_path=disk_path) + petrel_backend.client._client.contains = MagicMock( + return_value=True) + petrel_backend.copyfile(petrel_path, disk_path) + assert osp.isfile(disk_path) + assert disk_path.open('rb').read() == b'petrel' + + # force = False + # filepath2 exists + petrel_backend.client._client.Get = MagicMock( + return_value=b'new petrel') + petrel_backend.copyfile(petrel_path, disk_path, force=False) + assert disk_path.open('rb').read() == b'petrel' + # force = True + petrel_backend.copyfile(petrel_path, disk_path, force=True) + assert disk_path.open('rb').read() == b'new petrel' + @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) @patch('mc.ConvertBuffer', lambda x: x.content) From 641210352db1fb4b45654a54f6025d990a8afa2b Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 11 Oct 2021 14:57:30 +0800 Subject: [PATCH 26/46] delete copyfile and add get_local_path --- mmcv/fileio/file_client.py | 117 ++++++++++++++++++++----------------- tests/test_fileclient.py | 66 ++++++--------------- 2 files changed, 81 insertions(+), 102 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index aea1d0c417..e306ed7be5 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -3,9 +3,10 @@ import os import os.path as osp import re -import shutil +import tempfile import warnings from abc import ABCMeta, abstractmethod +from contextlib import contextmanager from pathlib import Path from typing import Optional, Union from urllib.request import urlopen @@ -238,31 +239,27 @@ def concat_paths(self, filepath: Union[str, Path], self._format_path(self._path_mapping(str(path)))) return '/'.join(formatted_paths) - def copyfile(self, - src_path: Union[str, Path], - dst_path: Union[str, Path], - force: bool = False) -> None: - """Copy a file from src_path starting with 's3' to dst_path which is a - disk path. + def _release_resource(self, filepath: str) -> None: + """Release the resource generated by _get_local_path. Args: - src_path (str or Path): Download a file from ``src_path``. - dst_path (str or Path): Save a file to ``dst_path``. - force (bool): Whether to overwrite the file when ``dst_path`` - exists. Default: False. + filepath (str): Path to be released. """ - src_path = str(src_path) - dst_path = str(dst_path) + os.remove(filepath) + + def _get_local_path(self, filepath: str) -> str: + """Download a file from filepath. - if self.isdir(src_path) or osp.isdir(dst_path): - raise ValueError('src_path and dst_path should both be a path of ' - 'file rather than a directory') - if not self.isfile(src_path): - raise FileNotFoundError(f'src_path {src_path} cannot be found') + Args: + filepath (str): Download a file from ``filepath``. + """ + assert self.isfile(filepath) - if force or not osp.isfile(dst_path): - with open(dst_path, 'wb') as f: - f.write(self.get(src_path)) + # the file will be removed when calling _release_resource() + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + return f.name class MemcachedBackend(BaseStorageBackend): @@ -459,29 +456,13 @@ def concat_paths(self, filepath: Union[str, Path], filepaths = [str(path) for path in filepaths] return osp.join(filepath, *filepaths) - def copyfile(self, - src_path: Union[str, Path], - dst_path: Union[str, Path], - force: bool = False) -> None: - """Copy a file from src_path to dst_path. - - Args: - src_path (str or Path): Copy a file from ``src_path``. - dst_path (str or Path): Save a file to ``dst_path``. - force (bool): Whether to overwrite the file when ``dst_path`` - exists. Default: False. - """ - src_path = str(src_path) - dst_path = str(dst_path) - - if self.isdir(src_path) or self.isdir(dst_path): - raise ValueError('src_path and dst_path should both be a path of ' - 'file rather than a directory') - if not self.isfile(src_path): - raise FileNotFoundError(f'src_path {src_path} cannot be found') + def _release_resource(self, filepath: str) -> None: + """Do nothing in order to unify API.""" + pass - if force or not self.isfile(dst_path): - shutil.copyfile(src_path, dst_path) + def _get_local_path(self, filepath: str) -> str: + """Do nothing in order to unify API.""" + return filepath class HTTPBackend(BaseStorageBackend): @@ -495,6 +476,26 @@ def get_text(self, filepath, encoding='utf-8'): value_buf = urlopen(filepath).read() return value_buf.decode(encoding) + def _release_resource(self, filepath: str) -> None: + """Release the resource generated by _get_local_path. + + Args: + filepath (str): Path to be released. + """ + os.remove(filepath) + + def _get_local_path(self, filepath: str) -> str: + """Download a file from filepath. + + Args: + filepath (str): Download a file from ``filepath``. + """ + # the file will be removed when calling _release_resource() + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + return f.name + class FileClient: """A general file client to access files in different backends. @@ -815,16 +816,26 @@ def concat_paths(self, filepath: Union[str, Path], """ return self.client.concat_paths(filepath, *filepaths) - def copyfile(self, - src_path: Union[str, Path], - dst_path: Union[str, Path], - force: bool = False) -> None: - """Copy a file from src_path to dst_path. + @contextmanager + def get_local_path(self, filepath: Union[str, Path]): + """Download data from given filepath and write the data to local path. + + If the ``filepath`` is a local path, just return the ``filepath``. + + Note: + ``get_local_path`` is an experimental interface that may change in + the future. Args: - src_path (str or Path): Copy a file from ``src_path``. - dst_path (str or Path): Save a file to ``dst_path``. - force (bool): Whether to overwrite the file when ``dst_path`` - exists. Default: False. + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here """ - self.client.copyfile(src_path, dst_path, force) + path = self.client._get_local_path(str(filepath)) + try: + yield path + finally: + self.client._release_resource(path) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index e7e2706c9d..2bc4f9aeb0 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -108,26 +108,11 @@ def test_disk_backend(self): # test `check_exist` assert not disk_backend.check_exist(filepath2) - # test `copyfile` - filepath3 = Path(tmp_dir) / 'test1.jpg' - with pytest.raises(ValueError): - # src_path and dst_path should both be a path of file rather - # than a directory - disk_backend.copyfile(src_path=tmp_dir, dst_path=tmp_dir) - with pytest.raises(FileNotFoundError): - # src_path should exist - disk_backend.copyfile( - src_path='/not/existed/path', dst_path=filepath3) - disk_backend.copyfile(filepath1, filepath3) - assert disk_backend.isfile(filepath3) - # force = False - disk_backend.put(b'overwrite disk', filepath3) - # filepath2 exists - disk_backend.copyfile(filepath1, filepath3, force=False) - assert filepath3.open('rb').read() == b'overwrite disk' - # force = True - disk_backend.copyfile(filepath1, filepath3, force=True) - assert filepath3.open('rb').read() == b'disk' + # test `_get_local_path` + # if the backend is disk, `get_local_path` just return the input + with disk_backend.get_local_path(filepath1) as path: + assert str(filepath1) == path + assert osp.isfile(filepath1) disk_dir = '/path/of/your/directory' assert disk_backend.concat_paths(disk_dir, 'file') == \ @@ -242,7 +227,6 @@ def test_petrel_backend(self, backend, prefix): # test `isfile` petrel_backend.client._client.contains = MagicMock(return_value=True) - petrel_backend.client._client.contains = MagicMock(return_value=True) assert petrel_backend.isfile(petrel_path) petrel_backend.client._client.contains.assert_called_with(petrel_path) # if ending with '/', it is not a file @@ -254,34 +238,12 @@ def test_petrel_backend(self, backend, prefix): assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ f'{petrel_dir}/dir/file' - # test `copyfile` - with tempfile.TemporaryDirectory() as tmp_dir: - disk_path = Path(tmp_dir) / 'test.jpg' - with pytest.raises(ValueError): - # src_path and dst_path should both be a path of file rather - # than a directory - petrel_backend.copyfile(src_path=petrel_dir, dst_path=tmp_dir) - with pytest.raises(FileNotFoundError): - # src_path should exist - petrel_backend.client._client.contains = MagicMock( - return_value=False) - petrel_backend.copyfile( - src_path=petrel_path, dst_path=disk_path) - petrel_backend.client._client.contains = MagicMock( - return_value=True) - petrel_backend.copyfile(petrel_path, disk_path) - assert osp.isfile(disk_path) - assert disk_path.open('rb').read() == b'petrel' - - # force = False - # filepath2 exists - petrel_backend.client._client.Get = MagicMock( - return_value=b'new petrel') - petrel_backend.copyfile(petrel_path, disk_path, force=False) - assert disk_path.open('rb').read() == b'petrel' - # force = True - petrel_backend.copyfile(petrel_path, disk_path, force=True) - assert disk_path.open('rb').read() == b'new petrel' + # test `_get_local_path` + # exist the with block and path will be released + petrel_backend.client._client.contains = MagicMock(return_value=True) + with petrel_backend.get_local_path(petrel_path) as path: + assert Path(path).open('rb').read() == b'petrel' + assert not osp.isfile(path) @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) @@ -355,6 +317,12 @@ def test_http_backend(self, backend, prefix): value_buf = http_backend.get_text(text_url) assert self.text_path.open('r').read() == value_buf + # test `_get_local_path` + # exist the with block and path will be released + with http_backend.get_local_path(img_url) as path: + assert mmcv.imread(path).shape == self.img_shape + assert not osp.isfile(path) + def test_new_magic_method(self): class DummyBackend1(BaseStorageBackend): From eeda74cbf146fe21461315072cc6733f942134d4 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 12 Oct 2021 11:09:05 +0800 Subject: [PATCH 27/46] remove isdir method of petrel --- mmcv/fileio/file_client.py | 29 +++++------------------------ tests/test_fileclient.py | 18 ++++-------------- 2 files changed, 9 insertions(+), 38 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index e306ed7be5..4208de9fb7 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -187,7 +187,7 @@ def remove(self, filepath: Union[str, Path]) -> None: filepath = self._format_path(filepath) self._client.delete(filepath) - def check_exist(self, filepath: Union[str, Path]) -> bool: + def exists(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether exists. Args: @@ -197,32 +197,13 @@ def check_exist(self, filepath: Union[str, Path]) -> bool: filepath = self._format_path(filepath) return self._client.contains(filepath) - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether it is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - """ - filepath = self._path_mapping(str(filepath)) - filepath = self._format_path(filepath) - # petrel checks a filepath whether it is a file by its ending char - if not filepath.endswith('/'): - return False - return self.check_exist(filepath) - def isfile(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether it is a file. Args: filepath (str or Path): Path to be checked whether it is a file. """ - filepath = self._path_mapping(str(filepath)) - filepath = self._format_path(filepath) - # petrel checks a filepath whether it is a file by its ending char - if filepath.endswith('/'): - return False - return self.check_exist(filepath) + return self.exists(filepath) def concat_paths(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: @@ -417,7 +398,7 @@ def remove(self, filepath: Union[str, Path]) -> None: filepath = str(filepath) os.remove(filepath) - def check_exist(self, filepath: Union[str, Path]) -> bool: + def exists(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether exists. Args: @@ -779,13 +760,13 @@ def remove(self, filepath: Union[str, Path]) -> None: """ self.client.remove(filepath) - def check_exist(self, filepath: Union[str, Path]) -> bool: + def exists(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether exists. Args: filepath (str or Path): Path to be checked whether exists. """ - return self.client.check_exist(filepath) + return self.client.exists(filepath) def isdir(self, filepath: Union[str, Path]) -> bool: """Check a filepath whether it is a directory. diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 2bc4f9aeb0..e5868b1e87 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -95,9 +95,6 @@ def test_disk_backend(self): disk_backend.put_text('disk', filepath2) assert filepath2.open('r').read() == 'disk' - # test `isdir` - assert disk_backend.isdir(tmp_dir) - # test `isfile` assert disk_backend.isfile(filepath2) assert not disk_backend.isfile(Path(tmp_dir) / 'not/existed/path') @@ -105,8 +102,8 @@ def test_disk_backend(self): # test `remove` disk_backend.remove(filepath2) - # test `check_exist` - assert not disk_backend.check_exist(filepath2) + # test `exists` + assert not disk_backend.exists(filepath2) # test `_get_local_path` # if the backend is disk, `get_local_path` just return the input @@ -215,22 +212,15 @@ def test_petrel_backend(self, backend, prefix): petrel_backend.remove(petrel_path) petrel_backend.client._client.delete.assert_called_with(petrel_path) - # test `check_exist` + # test `exists` petrel_backend.client._client.contains = MagicMock(return_value=True) - assert petrel_backend.check_exist(petrel_path) + assert petrel_backend.exists(petrel_path) petrel_backend.client._client.contains.assert_called_with(petrel_path) - # test `isdir` - assert petrel_backend.isdir(f'{petrel_dir}/') - # directory should end with '/' - assert not petrel_backend.isdir(petrel_dir) - # test `isfile` petrel_backend.client._client.contains = MagicMock(return_value=True) assert petrel_backend.isfile(petrel_path) petrel_backend.client._client.contains.assert_called_with(petrel_path) - # if ending with '/', it is not a file - assert not petrel_backend.isfile(f'{petrel_path}/') # test `concat_paths` assert petrel_backend.concat_paths(petrel_dir, 'file') == \ From ad5242804121d6c79c76f1106d17bd2652a6cf6d Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 12 Oct 2021 11:50:44 +0800 Subject: [PATCH 28/46] fix typo --- mmcv/fileio/io.py | 8 ++++---- mmcv/fileio/parse.py | 4 ++-- mmcv/image/photometric.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 430711c2cd..1225b75a29 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -21,7 +21,7 @@ def load(file, file_format=None, file_client_args=None, **kwargs): This method provides a unified api for loading data from serialized files. Note: - In v1.3.15 and later, `load` supports loading data from serialized + In v1.3.15 and later, ``load`` supports loading data from serialized files those can be storaged in different backends. Args: @@ -37,8 +37,8 @@ def load(file, file_format=None, file_client_args=None, **kwargs): Examples: >>> load('/path/of/your/file') # file is storaged in disk - >>> load('https://path/of/your/file') # file is storage in Internet - >>> load('s3://path/of/your/file') # file is storage in ceph or petrel + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in petrel Returns: The content from the file. @@ -73,7 +73,7 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): and also supports custom arguments for each file format. Note: - In v1.3.15 and later, `dump` supports dumping data as strings or to + In v1.3.15 and later, ``dump`` supports dumping data as strings or to files which is saved to different backends. Args: diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index 8ed84bfbe4..0f368ab9cb 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -14,7 +14,7 @@ def list_from_file(filename, """Load a text file and parse the content as a list of strings. Note: - In v1.3.15 and later, `list_from_file` supports loading a text file + In v1.3.15 and later, ``list_from_file`` supports loading a text file which can be storaged in different backends and parsing the content as a list for strings. @@ -63,7 +63,7 @@ def dict_from_file(filename, the following columns will be parsed as dict values. Note: - In v1.3.15 and later, `dict_from_file` supports loading a text file + In v1.3.15 and later, ``dict_from_file`` supports loading a text file which can be storaged in different backends and parsing the content as a dict. diff --git a/mmcv/image/photometric.py b/mmcv/image/photometric.py index e234394084..5085d01201 100644 --- a/mmcv/image/photometric.py +++ b/mmcv/image/photometric.py @@ -311,7 +311,7 @@ def adjust_sharpness(img, factor=1., kernel=None): Note: No value sanity check is enforced on the kernel set by users. So with - an inappropriate kernel, the `adjust_sharpness` may fail to perform + an inappropriate kernel, the ``adjust_sharpness`` may fail to perform the function its name indicates but end up performing whatever transform determined by the kernel. From 941a8846290e31118b74f4521a9ca891e399855d Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 13 Oct 2021 10:58:30 +0800 Subject: [PATCH 29/46] add comment and polish docstring --- mmcv/fileio/file_client.py | 93 ++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 4208de9fb7..514857f22d 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -72,23 +72,24 @@ def get_text(self, filepath, encoding=None): class PetrelBackend(BaseStorageBackend): """Petrel storage backend (for internal use). - PetrelBackend supports reading or writing data to multiple clusters. If the - filepath contains the cluster name, PetrelBackend will read from the - filepath or write to the filepath. Otherwise, PetrelBackend will access - the default cluster. + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. Args: path_mapping (dict, optional): Path mapping dict from local path to - Petrel path. When `path_mapping={'src': 'dst'}`, `src` in - `filepath` will be replaced by `dst`. Default: None. - enable_mc (bool): Whether to enable memcached support. Default: True. + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Default: None. + enable_mc (bool, optional): Whether to enable memcached support. + Default: True. Examples: >>> filepath1 = 's3://path/of/file' >>> filepath2 = 'cluster-name:s3://path/of/file' >>> client = PetrelBackend() - >>> client.get(filepath1) # get from default cluster - >>> client.get(filepath2) # get from 'cluster-name' cluster + >>> client.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster """ def __init__(self, @@ -105,7 +106,7 @@ def __init__(self, self.path_mapping = path_mapping def _path_mapping(self, filepath: str) -> str: - """Replace the prefix of filepath with path_mapping. + """Replace the prefix of ``filepath`` with :attr:`path_mapping`. Args: filepath (str): Path to be mapped. @@ -116,12 +117,12 @@ def _path_mapping(self, filepath: str) -> str: return filepath def _format_path(self, filepath: str) -> str: - """Convert filepath to standard format of petrel oss. + """Convert a ``filepath`` to standard format of petrel oss. - Since the filepath is concatenated by `os.path.join`, in a windows - environment, the filepath will be the format of - 's3://bucket_name\\image.jpg'. By invoking `_format_path`, the above - filepath will be converted to 's3://bucket_name/image.jpg'. + If the ``filepath`` is concatenated by ``os.path.join``, in a windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. Args: filepath (str): Path to be formatted. @@ -129,7 +130,7 @@ def _format_path(self, filepath: str) -> str: return re.sub(r'\\+', '/', filepath) def get(self, filepath: Union[str, Path]) -> memoryview: - """Read data from a given filepath with 'rb' mode. + """Read data from a given ``filepath`` with 'rb' mode. Args: filepath (str or Path): Path to read data. @@ -143,17 +144,17 @@ def get(self, filepath: Union[str, Path]) -> memoryview: def get_text(self, filepath: Union[str, Path], encoding: str = 'utf-8') -> str: - """Read data from a given filepath with 'r' mode. + """Read data from a given ``filepath`` with 'r' mode. Args: filepath (str or Path): Path to read data. encoding (str, optional): The encoding format used to open the - `filepath`. Default: 'utf-8'. + ``filepath``. Default: 'utf-8'. """ return str(self.get(filepath), encoding=encoding) def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Save data to a given filepath. + """Save data to a given ``filepath``. Args: obj (bytes): Data to be saved. @@ -167,13 +168,13 @@ def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = 'utf-8') -> None: - """Save data to a given filepath. + """Save data to a given ``filepath``. Args: obj (str): Data to be written. filepath (str or Path): Path to write data. encoding (str, optional): The encoding format used to encode the - `obj`. Default: 'utf-8'. + ``obj``. Default: 'utf-8'. """ self.put(bytes(obj, encoding=encoding), filepath) @@ -188,7 +189,7 @@ def remove(self, filepath: Union[str, Path]) -> None: self._client.delete(filepath) def exists(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether exists. + """Check a ``filepath`` whether exists. Args: filepath (str or Path): Path to be checked whether exists. @@ -198,7 +199,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: return self._client.contains(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether it is a file. + """Check a ``filepath`` whether it is a file. Args: filepath (str or Path): Path to be checked whether it is a file. @@ -221,7 +222,7 @@ def concat_paths(self, filepath: Union[str, Path], return '/'.join(formatted_paths) def _release_resource(self, filepath: str) -> None: - """Release the resource generated by _get_local_path. + """Release the resource generated by :meth:`_get_local_path`. Args: filepath (str): Path to be released. @@ -229,7 +230,7 @@ def _release_resource(self, filepath: str) -> None: os.remove(filepath) def _get_local_path(self, filepath: str) -> str: - """Download a file from filepath. + """Download a file from ``filepath``. Args: filepath (str): Download a file from ``filepath``. @@ -337,7 +338,7 @@ class HardDiskBackend(BaseStorageBackend): """Raw hard disks storage backend.""" def get(self, filepath: Union[str, Path]) -> bytes: - """Read data from a given filepath with 'rb' mode. + """Read data from a given ``filepath`` with 'rb' mode. Args: filepath (str or Path): Path to read data. @@ -350,12 +351,12 @@ def get(self, filepath: Union[str, Path]) -> bytes: def get_text(self, filepath: Union[str, Path], encoding: str = 'utf-8') -> str: - """Read data from a given filepath with 'r' mode. + """Read data from a given ``filepath`` with 'r' mode. Args: filepath (str or Path): Path to read data. encoding (str, optional): The encoding format used to open the - `filepath`. Default: 'utf-8'. + ``filepath``. Default: 'utf-8'. """ filepath = str(filepath) with open(filepath, 'r', encoding=encoding) as f: @@ -363,7 +364,7 @@ def get_text(self, return value_buf def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given filepath with 'wb' mode. + """Write data to a given ``filepath`` with 'wb' mode. Args: obj (bytes): Data to be written. @@ -377,7 +378,7 @@ def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = 'utf-8') -> None: - """Write data to a given filepath with 'w' mode. + """Write data to a given ``filepath`` with 'w' mode. Args: obj (str): Data to be written. @@ -399,7 +400,7 @@ def remove(self, filepath: Union[str, Path]) -> None: os.remove(filepath) def exists(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether exists. + """Check a ``filepath`` whether exists. Args: filepath (str or Path): Path to be checked whether exists. @@ -407,7 +408,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: return osp.exists(str(filepath)) def isdir(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether it is a directory. + """Check a ``filepath`` whether it is a directory. Args: filepath (str or Path): Path to be checked whether it is a @@ -416,7 +417,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: return osp.isdir(str(filepath)) def isfile(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether it is a file. + """Check a ``filepath`` whether it is a file. Args: filepath (str or Path): Path to be checked whether it is a file. @@ -458,7 +459,7 @@ def get_text(self, filepath, encoding='utf-8'): return value_buf.decode(encoding) def _release_resource(self, filepath: str) -> None: - """Release the resource generated by _get_local_path. + """Release the resource generated by :meth:`_get_local_path`. Args: filepath (str): Path to be released. @@ -525,7 +526,7 @@ class FileClient: # This collection is used to record the overridden backends, and when a # backend appears in the collection, the singleton pattern is disabled for # that backend, because if the singleton pattern is used, then the object - # returned will be the backend before the override + # returned will be the backend before overwriting _overridden_backends = set() _prefix_to_backends = { 's3': PetrelBackend, @@ -548,15 +549,19 @@ def __new__(cls, backend=None, prefix=None, **kwargs): f'prefix {prefix} is not supported. Currently supported ones ' f'are {list(cls._prefix_to_backends.keys())}') + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created arg_key = f'{backend}:{prefix}' for key, value in kwargs.items(): arg_key += f':{key}:{value}' + # if a backend was overridden, it will create a new object if (arg_key in cls._instances and backend not in cls._overridden_backends and prefix not in cls._overridden_prefixes): _instance = cls._instances[arg_key] else: + # create a new object and put it to _instance _instance = super().__new__(cls) if backend is not None: _instance.client = cls._backends[backend](**kwargs) @@ -715,7 +720,7 @@ def _register(backend_cls): return _register def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: - """Read data from a given filepath with 'rb' mode. + """Read data from a given ``filepath`` with 'rb' mode. Args: filepath (str or Path): Path to read data. @@ -723,7 +728,7 @@ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: return self.client.get(filepath) def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: - """Read data from a given filepath with 'r' mode. + """Read data from a given ``filepath`` with 'r' mode. Args: filepath (str or Path): Path to read data. @@ -733,7 +738,7 @@ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: return self.client.get_text(filepath, encoding) def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given filepath with 'wb' mode. + """Write data to a given ``filepath`` with 'wb' mode. Args: obj (bytes): Data to be written. @@ -742,7 +747,7 @@ def put(self, obj: bytes, filepath: Union[str, Path]) -> None: self.client.put(obj, filepath) def put_text(self, obj: str, filepath: Union[str, Path]) -> None: - """Write data to a given filepath with 'w' mode. + """Write data to a given ``filepath`` with 'w' mode. Args: obj (str): Data to be written. @@ -761,7 +766,7 @@ def remove(self, filepath: Union[str, Path]) -> None: self.client.remove(filepath) def exists(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether exists. + """Check a ``filepath`` whether exists. Args: filepath (str or Path): Path to be checked whether exists. @@ -769,7 +774,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: return self.client.exists(filepath) def isdir(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether it is a directory. + """Check a ``filepath`` whether it is a directory. Args: filepath (str or Path): Path to be checked whether it is a @@ -778,7 +783,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: return self.client.isdir(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: - """Check a filepath whether it is a file. + """Check a ``filepath`` whether it is a file. Args: filepath (str or Path): Path to be checked whether it is a file. @@ -799,9 +804,9 @@ def concat_paths(self, filepath: Union[str, Path], @contextmanager def get_local_path(self, filepath: Union[str, Path]): - """Download data from given filepath and write the data to local path. + """Download data from ``filepath`` and write the data to local path. - If the ``filepath`` is a local path, just return the ``filepath``. + If the ``filepath`` is a local path, just return the itself. Note: ``get_local_path`` is an experimental interface that may change in From 198a465e48e791b843927ffc9b4857d98929fe28 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Thu, 14 Oct 2021 21:30:30 +0800 Subject: [PATCH 30/46] polish docstring --- mmcv/fileio/file_client.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 514857f22d..28df487e2f 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -189,7 +189,7 @@ def remove(self, filepath: Union[str, Path]) -> None: self._client.delete(filepath) def exists(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether exists. + """Check whether a file path exists. Args: filepath (str or Path): Path to be checked whether exists. @@ -199,7 +199,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: return self._client.contains(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether it is a file. + """Check whether a file path is a file. Args: filepath (str or Path): Path to be checked whether it is a file. @@ -208,7 +208,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: def concat_paths(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - """Concatenate all filepaths. + """Concatenate all file paths. Args: filepath (str or Path): Path to be concatenated. @@ -400,7 +400,7 @@ def remove(self, filepath: Union[str, Path]) -> None: os.remove(filepath) def exists(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether exists. + """Check whether a file path exists. Args: filepath (str or Path): Path to be checked whether exists. @@ -408,7 +408,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: return osp.exists(str(filepath)) def isdir(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether it is a directory. + """Check whether a file path is a directory. Args: filepath (str or Path): Path to be checked whether it is a @@ -426,7 +426,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: def concat_paths(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - """Concatenate all filepaths. + """Concatenate all file paths. Join one or more filepath components intelligently. The return value is the concatenation of filepath and any members of *filepaths. @@ -766,7 +766,7 @@ def remove(self, filepath: Union[str, Path]) -> None: self.client.remove(filepath) def exists(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether exists. + """Check whether a file path exists. Args: filepath (str or Path): Path to be checked whether exists. @@ -774,7 +774,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: return self.client.exists(filepath) def isdir(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether it is a directory. + """Check whether a file path is a directory. Args: filepath (str or Path): Path to be checked whether it is a @@ -783,7 +783,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: return self.client.isdir(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: - """Check a ``filepath`` whether it is a file. + """Check whether a file path is a file. Args: filepath (str or Path): Path to be checked whether it is a file. @@ -792,7 +792,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: def concat_paths(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - """Concatenate all filepaths. + """Concatenate all file paths. Join one or more filepath components intelligently. The return value is the concatenation of filepath and any members of *filepaths. From e0d6a839c5a59ba996acb0dff902a92c94424035 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 15 Oct 2021 15:58:38 +0800 Subject: [PATCH 31/46] rename _path_mapping to _map_path --- mmcv/fileio/file_client.py | 18 ++++++++---------- tests/test_fileclient.py | 4 ++-- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 28df487e2f..14604fbade 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -105,7 +105,7 @@ def __init__(self, assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping - def _path_mapping(self, filepath: str) -> str: + def _map_path(self, filepath: str) -> str: """Replace the prefix of ``filepath`` with :attr:`path_mapping`. Args: @@ -135,7 +135,7 @@ def get(self, filepath: Union[str, Path]) -> memoryview: Args: filepath (str or Path): Path to read data. """ - filepath = self._path_mapping(str(filepath)) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) value = self._client.Get(filepath) value_buf = memoryview(value) @@ -160,7 +160,7 @@ def put(self, obj: bytes, filepath: Union[str, Path]) -> None: obj (bytes): Data to be saved. filepath (str or Path): Path to write data. """ - filepath = self._path_mapping(str(filepath)) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) self._client.put(filepath, obj) @@ -184,7 +184,7 @@ def remove(self, filepath: Union[str, Path]) -> None: Args: filepath (str or Path): Path to be removed. """ - filepath = self._path_mapping(str(filepath)) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) self._client.delete(filepath) @@ -194,7 +194,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether exists. """ - filepath = self._path_mapping(str(filepath)) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) return self._client.contains(filepath) @@ -213,12 +213,10 @@ def concat_paths(self, filepath: Union[str, Path], Args: filepath (str or Path): Path to be concatenated. """ - formatted_paths = [ - self._format_path(self._path_mapping(str(filepath))) - ] + formatted_paths = [self._format_path(self._map_path(str(filepath)))] for path in filepaths: formatted_paths.append( - self._format_path(self._path_mapping(str(path)))) + self._format_path(self._map_path(str(path)))) return '/'.join(formatted_paths) def _release_resource(self, filepath: str) -> None: @@ -384,7 +382,7 @@ def put_text(self, obj (str): Data to be written. filepath (str or Path): Path to write data. encoding (str, optional): The encoding format used to open the - `filepath`. Default: 'utf-8'. + ``filepath``. Default: 'utf-8'. """ filepath = str(filepath) with open(filepath, 'w', encoding=encoding) as f: diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index e5868b1e87..935206a392 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -171,11 +171,11 @@ def test_petrel_backend(self, backend, prefix): with pytest.raises(AssertionError): FileClient('petrel', path_mapping=1) - # test `_path_mapping` + # test `_map_path` petrel_dir = 's3://user/data' petrel_backend = FileClient( 'petrel', path_mapping={str(self.test_data_dir): petrel_dir}) - assert petrel_backend.client._path_mapping(str(self.img_path)) == \ + assert petrel_backend.client._map_path(str(self.img_path)) == \ str(self.img_path).replace(str(self.test_data_dir), petrel_dir) petrel_path = f'{petrel_dir}/test.jpg' From ae0cdd3d597811b7f1b759c6990d82a153b775c7 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 15 Oct 2021 18:28:07 +0800 Subject: [PATCH 32/46] polish docstring and fix typo --- docs/understand_mmcv/io.md | 2 +- mmcv/fileio/file_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/understand_mmcv/io.md b/docs/understand_mmcv/io.md index c83ba9bdc1..f09a15a245 100644 --- a/docs/understand_mmcv/io.md +++ b/docs/understand_mmcv/io.md @@ -3,7 +3,7 @@ This module provides two universal API to load and dump files of different formats. ```{note} -Since v1.3.15, the IO modules support loading and dumping data from and to different backends, respectively. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330). +Since v1.3.15, the IO modules support loading (dumping) data from (to) different backends, respectively. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330). ``` ### Load and dump data diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 14604fbade..6ec92d831f 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -119,7 +119,7 @@ def _map_path(self, filepath: str) -> str: def _format_path(self, filepath: str) -> str: """Convert a ``filepath`` to standard format of petrel oss. - If the ``filepath`` is concatenated by ``os.path.join``, in a windows + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows environment, the ``filepath`` will be the format of 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. From a2e0162859f7862a09800397328e2b303746d9b1 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 16 Oct 2021 17:24:50 +0800 Subject: [PATCH 33/46] refactor get_local_path --- mmcv/fileio/file_client.py | 89 ++++++++++++++++++-------------------- tests/test_fileclient.py | 4 +- 2 files changed, 45 insertions(+), 48 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 6ec92d831f..d3b32a79a9 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import Optional, Union +from typing import Iterable, Optional, Union from urllib.request import urlopen from mmcv.utils.path import is_filepath @@ -219,27 +219,28 @@ def concat_paths(self, filepath: Union[str, Path], self._format_path(self._map_path(str(path)))) return '/'.join(formatted_paths) - def _release_resource(self, filepath: str) -> None: - """Release the resource generated by :meth:`_get_local_path`. - - Args: - filepath (str): Path to be released. - """ - os.remove(filepath) - - def _get_local_path(self, filepath: str) -> str: + @contextmanager + def get_local_path(self, filepath: str) -> Iterable[str]: """Download a file from ``filepath``. Args: filepath (str): Download a file from ``filepath``. + + Examples: + >>> client = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('s3://path/of/your/file') as path: + ... # do something here """ assert self.isfile(filepath) - - # the file will be removed when calling _release_resource() - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - return f.name + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) class MemcachedBackend(BaseStorageBackend): @@ -436,13 +437,10 @@ def concat_paths(self, filepath: Union[str, Path], filepaths = [str(path) for path in filepaths] return osp.join(filepath, *filepaths) - def _release_resource(self, filepath: str) -> None: - """Do nothing in order to unify API.""" - pass - - def _get_local_path(self, filepath: str) -> str: - """Do nothing in order to unify API.""" - return filepath + @contextmanager + def get_local_path(self, filepath: str) -> Iterable[str]: + """Only for unified API and do nothing.""" + yield filepath class HTTPBackend(BaseStorageBackend): @@ -456,25 +454,27 @@ def get_text(self, filepath, encoding='utf-8'): value_buf = urlopen(filepath).read() return value_buf.decode(encoding) - def _release_resource(self, filepath: str) -> None: - """Release the resource generated by :meth:`_get_local_path`. - - Args: - filepath (str): Path to be released. - """ - os.remove(filepath) - - def _get_local_path(self, filepath: str) -> str: - """Download a file from filepath. + @contextmanager + def get_local_path(self, filepath: str) -> Iterable[str]: + """Download a file from ``filepath``. Args: filepath (str): Download a file from ``filepath``. + + Examples: + >>> client = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('http://path/of/your/file') as path: + ... # do something here """ - # the file will be removed when calling _release_resource() - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - return f.name + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) class FileClient: @@ -801,12 +801,12 @@ def concat_paths(self, filepath: Union[str, Path], return self.client.concat_paths(filepath, *filepaths) @contextmanager - def get_local_path(self, filepath: Union[str, Path]): + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: """Download data from ``filepath`` and write the data to local path. - If the ``filepath`` is a local path, just return the itself. + If the ``filepath`` is a local path, just return itself. - Note: + .. warning:: ``get_local_path`` is an experimental interface that may change in the future. @@ -818,8 +818,5 @@ def get_local_path(self, filepath: Union[str, Path]): >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: ... # do something here """ - path = self.client._get_local_path(str(filepath)) - try: - yield path - finally: - self.client._release_resource(path) + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 935206a392..820524fadf 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -105,7 +105,7 @@ def test_disk_backend(self): # test `exists` assert not disk_backend.exists(filepath2) - # test `_get_local_path` + # test `get_local_path` # if the backend is disk, `get_local_path` just return the input with disk_backend.get_local_path(filepath1) as path: assert str(filepath1) == path @@ -228,7 +228,7 @@ def test_petrel_backend(self, backend, prefix): assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ f'{petrel_dir}/dir/file' - # test `_get_local_path` + # test `get_local_path` # exist the with block and path will be released petrel_backend.client._client.contains = MagicMock(return_value=True) with petrel_backend.get_local_path(petrel_path) as path: From 50ba26f5f920ed4440b2d4e275b3c8cdbfed22dd Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 17 Oct 2021 10:47:02 +0800 Subject: [PATCH 34/46] add list_dir_or_file for FileClient --- mmcv/fileio/file_client.py | 70 +++++++++++++++++++++++++++- tests/test_fileclient.py | 95 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index d3b32a79a9..983ab0a25c 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import Iterable, Optional, Union +from typing import Iterable, Iterator, Optional, Tuple, Union from urllib.request import urlopen from mmcv.utils.path import is_filepath @@ -442,6 +442,53 @@ def get_local_path(self, filepath: str) -> Iterable[str]: """Only for unified API and do nothing.""" yield filepath + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + """ + dir_path = str(dir_path) + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, + list_file, suffix, + recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + class HTTPBackend(BaseStorageBackend): """HTTP and HTTPS storage bachend.""" @@ -820,3 +867,24 @@ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: """ with self.client.get_local_path(str(filepath)) as local_path: yield local_path + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, + suffix, recursive) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 820524fadf..6d9a29b3b1 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -111,12 +111,107 @@ def test_disk_backend(self): assert str(filepath1) == path assert osp.isfile(filepath1) + # test `concat_paths` disk_dir = '/path/of/your/directory' assert disk_backend.concat_paths(disk_dir, 'file') == \ osp.join(disk_dir, 'file') assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \ osp.join(disk_dir, 'dir', 'file') + # test `list_dir_or_file` + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + # 1. list directories and files + assert set(disk_backend.list_dir_or_file(tmp_dir)) == set( + ['dir1', 'dir2', 'text1.txt', 'text2.txt']) + # 2. list directories and files recursively + assert set(disk_backend.list_dir_or_file( + tmp_dir, recursive=True)) == set([ + 'dir1', + osp.join('dir1', 'text3.txt'), 'dir2', + osp.join('dir2', 'dir3'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + # 3. only list directories + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_file=False)) == set(['dir1', 'dir2']) + with pytest.raises( + TypeError, + match='`suffix` should be None when `list_dir` is True'): + # Exception is raised among the `list_dir_or_file` of client, + # so we need to invode the client to trigger the exception + disk_backend.client.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt') + # 4. only list directories recursively + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_file=False, recursive=True)) == set( + ['dir1', 'dir2', + osp.join('dir2', 'dir3')]) + # 5. only list files + assert set(disk_backend.list_dir_or_file( + tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt']) + # 6. only list files recursively + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix='.txt')) == set(['text1.txt', 'text2.txt']) + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix=('.txt', + '.jpg'))) == set(['text1.txt', 'text2.txt']) + with pytest.raises( + TypeError, + match='`suffix` must be a string or tuple of strings'): + disk_backend.client.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + # 8. only list files ending with suffix recursively + assert set( + disk_backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt', + recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + disk_backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + @patch('ceph.S3Client', MockS3Client) def test_ceph_backend(self): ceph_backend = FileClient('ceph') From 4ad3bf5e0cd30c7f2fddc17673ed04cc114f7008 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 18 Oct 2021 22:26:30 +0800 Subject: [PATCH 35/46] add list_dir_or_file for PetrelBackend --- mmcv/fileio/file_client.py | 91 ++++++++++++++- tests/test_fileclient.py | 230 +++++++++++++++++++++++++++++-------- 2 files changed, 274 insertions(+), 47 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 983ab0a25c..6664ef8d67 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -196,7 +196,18 @@ def exists(self, filepath: Union[str, Path]) -> bool: """ filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) - return self._client.contains(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + """ + filepath = self._map_path(str(filepath)) + filepath = self._format_path(filepath) + return self._client.isdir(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: """Check whether a file path is a file. @@ -204,7 +215,9 @@ def isfile(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a file. """ - return self.exists(filepath) + filepath = self._map_path(str(filepath)) + filepath = self._format_path(filepath) + return self._client.contains(filepath) def concat_paths(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: @@ -233,6 +246,8 @@ def get_local_path(self, filepath: str) -> Iterable[str]: >>> with client.get_local_path('s3://path/of/your/file') as path: ... # do something here """ + filepath = self._map_path(str(filepath)) + filepath = self._format_path(filepath) assert self.isfile(filepath) try: f = tempfile.NamedTemporaryFile(delete=False) @@ -242,6 +257,72 @@ def get_local_path(self, filepath: str) -> Iterable[str]: finally: os.remove(f.name) + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + """ + dir_path = self._map_path(str(dir_path)) + dir_path = self._format_path(dir_path) + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine if path is a + # directory, because `self.isdir` relies on `self._client.list` + if path.endswith('/'): # a directory path + if list_dir: + # exclude the last character '/' + rel_dir = path[len(root):-1] + yield rel_dir + if recursive: + yield from _list_dir_or_file(path, list_dir, list_file, + suffix, recursive) + else: # a file path + rel_path = path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + class MemcachedBackend(BaseStorageBackend): """Memcached storage backend. @@ -451,6 +532,9 @@ def list_dir_or_file(self, """Scan a directory to find the interested directories or files in arbitrary order. + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + Args: dir_path (str | obj:`Path`): Path of the directory. list_dir (bool): List the directories. Default: True. @@ -877,6 +961,9 @@ def list_dir_or_file(self, """Scan a directory to find the interested directories or files in arbitrary order. + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + Args: dir_path (str | obj:`Path`): Path of the directory. list_dir (bool): List the directories. Default: True. diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 6d9a29b3b1..d79e7706ce 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -1,6 +1,8 @@ +import os import os.path as osp import sys import tempfile +from contextlib import contextmanager from pathlib import Path from unittest.mock import MagicMock, patch @@ -15,6 +17,41 @@ sys.modules['mc'] = MagicMock() +@contextmanager +def build_temporary_directory(): + """Build a temporary directory containing many files to test + ``FileClient.list_dir_or_file``. + + . \n + | -- dir1 \n + | -- | -- text3.txt \n + | -- dir2 \n + | -- | -- dir3 \n + | -- | -- | -- text4.txt \n + | -- | -- img.jpg \n + | -- text1.txt \n + | -- text2.txt \n + """ + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + yield tmp_dir + + class MockS3Client: def __init__(self, enable_mc=True): @@ -37,6 +74,25 @@ def Get(self, filepath): content = f.read() return content + def put(self): + pass + + def delete(self): + pass + + def contains(self): + pass + + def isdir(self): + pass + + def list(self, dir_path): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + yield entry.path + elif osp.isdir(entry.path): + yield entry.path + '/' + class MockMemcachedClient: @@ -119,23 +175,7 @@ def test_disk_backend(self): osp.join(disk_dir, 'dir', 'file') # test `list_dir_or_file` - with tempfile.TemporaryDirectory() as tmp_dir: - text1 = Path(tmp_dir) / 'text1.txt' - text1.open('w').write('text1') - text2 = Path(tmp_dir) / 'text2.txt' - text2.open('w').write('text2') - dir1 = Path(tmp_dir) / 'dir1' - dir1.mkdir() - text3 = dir1 / 'text3.txt' - text3.open('w').write('text3') - dir2 = Path(tmp_dir) / 'dir2' - dir2.mkdir() - jpg1 = dir2 / 'img.jpg' - jpg1.open('wb').write(b'img') - dir3 = dir2 / 'dir3' - dir3.mkdir() - text4 = dir3 / 'text4.txt' - text4.open('w').write('text4') + with build_temporary_directory() as tmp_dir: # 1. list directories and files assert set(disk_backend.list_dir_or_file(tmp_dir)) == set( ['dir1', 'dir2', 'text1.txt', 'text2.txt']) @@ -281,41 +321,55 @@ def test_petrel_backend(self, backend, prefix): == petrel_path # test `get` - petrel_backend.client._client.Get = MagicMock(return_value=b'petrel') - assert petrel_backend.get(petrel_path) == b'petrel' - petrel_backend.client._client.Get.assert_called_with(petrel_path) + with patch.object( + petrel_backend.client._client, 'Get', + return_value=b'petrel') as mock_get: + assert petrel_backend.get(petrel_path) == b'petrel' + mock_get.assert_called_once_with(petrel_path) # test `get_text` - petrel_backend.client._client.Get = MagicMock(return_value=b'petrel') - assert petrel_backend.get_text(petrel_path) == 'petrel' - petrel_backend.client._client.Get.assert_called_with(petrel_path) + with patch.object( + petrel_backend.client._client, 'Get', + return_value=b'petrel') as mock_get: + assert petrel_backend.get_text(petrel_path) == 'petrel' + mock_get.assert_called_once_with(petrel_path) # test `put` - petrel_backend.client._client.put = MagicMock() - petrel_backend.put(b'petrel', petrel_path) - petrel_backend.client._client.put.assert_called_with( - petrel_path, b'petrel') + with patch.object(petrel_backend.client._client, 'put') as mock_put: + petrel_backend.put(b'petrel', petrel_path) + mock_put.assert_called_once_with(petrel_path, b'petrel') # test `put_text` - petrel_backend.client._client.put = MagicMock() - petrel_backend.put_text('petrel', petrel_path) - petrel_backend.client._client.put.assert_called_with( - petrel_path, b'petrel') + with patch.object(petrel_backend.client._client, 'put') as mock_put: + petrel_backend.put_text('petrel', petrel_path) + mock_put.assert_called_once_with(petrel_path, b'petrel') # test `remove` - petrel_backend.client._client.delete = MagicMock() - petrel_backend.remove(petrel_path) - petrel_backend.client._client.delete.assert_called_with(petrel_path) + with patch.object(petrel_backend.client._client, + 'delete') as mock_delete: + petrel_backend.remove(petrel_path) + mock_delete.assert_called_once_with(petrel_path) # test `exists` - petrel_backend.client._client.contains = MagicMock(return_value=True) - assert petrel_backend.exists(petrel_path) - petrel_backend.client._client.contains.assert_called_with(petrel_path) + with patch.object( + petrel_backend.client._client, 'contains', + return_value=True) as mock_contains: + assert petrel_backend.exists(petrel_path) + mock_contains.assert_called_once_with(petrel_path) + + # test `isdir` + with patch.object( + petrel_backend.client._client, 'isdir', + return_value=True) as mock_isdir: + assert petrel_backend.isdir(petrel_dir) + mock_isdir.assert_called_once_with(petrel_dir) # test `isfile` - petrel_backend.client._client.contains = MagicMock(return_value=True) - assert petrel_backend.isfile(petrel_path) - petrel_backend.client._client.contains.assert_called_with(petrel_path) + with patch.object( + petrel_backend.client._client, 'contains', + return_value=True) as mock_contains: + assert petrel_backend.isfile(petrel_path) + mock_contains.assert_called_once_with(petrel_path) # test `concat_paths` assert petrel_backend.concat_paths(petrel_dir, 'file') == \ @@ -324,11 +378,97 @@ def test_petrel_backend(self, backend, prefix): f'{petrel_dir}/dir/file' # test `get_local_path` - # exist the with block and path will be released - petrel_backend.client._client.contains = MagicMock(return_value=True) - with petrel_backend.get_local_path(petrel_path) as path: - assert Path(path).open('rb').read() == b'petrel' - assert not osp.isfile(path) + with patch.object(petrel_backend.client._client, 'Get', + return_value=b'petrel') as mock_get, \ + patch.object(petrel_backend.client._client, 'contains', + return_value=True) as mock_contains: + with petrel_backend.get_local_path(petrel_path) as path: + assert Path(path).open('rb').read() == b'petrel' + # exist the with block and path will be released + assert not osp.isfile(path) + mock_get.assert_called_once_with(petrel_path) + mock_contains.assert_called_once_with(petrel_path) + + # test `list_dir_or_file` + with build_temporary_directory() as tmp_dir: + # 1. list directories and files + assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set( + ['dir1', 'dir2', 'text1.txt', 'text2.txt']) + # 2. list directories and files recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, recursive=True)) == set([ + 'dir1', + osp.join('dir1', 'text3.txt'), 'dir2', + osp.join('dir2', 'dir3'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + # 3. only list directories + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_file=False)) == set(['dir1', 'dir2']) + with pytest.raises( + TypeError, + match='`suffix` should be None when `list_dir` is True'): + # Exception is raised among the `list_dir_or_file` of client, + # so we need to invode the client to trigger the exception + petrel_backend.client.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt') + # 4. only list directories recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_file=False, recursive=True)) == set( + ['dir1', 'dir2', + osp.join('dir2', 'dir3')]) + # 5. only list files + assert set( + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False)) == set( + ['text1.txt', 'text2.txt']) + # 6. only list files recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix='.txt')) == set(['text1.txt', 'text2.txt']) + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, + suffix=('.txt', + '.jpg'))) == set(['text1.txt', 'text2.txt']) + with pytest.raises( + TypeError, + match='`suffix` must be a string or tuple of strings'): + petrel_backend.client.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + # 8. only list files ending with suffix recursively + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt', + recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + ]) + # 7. only list files ending with suffix + assert set( + petrel_backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)) == set([ + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + ]) @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) @patch('mc.pyvector', MagicMock) From df207d1adaa56c3392010bc9ff89b82ac008e0d7 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Mon, 18 Oct 2021 23:39:05 +0800 Subject: [PATCH 36/46] fix windows ci --- tests/test_fileclient.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index d79e7706ce..c5c159ccb1 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -89,9 +89,11 @@ def isdir(self): def list(self, dir_path): for entry in os.scandir(dir_path): if not entry.name.startswith('.') and entry.is_file(): - yield entry.path + path = entry.path.replace(os.sep, '/') + yield path elif osp.isdir(entry.path): - yield entry.path + '/' + path = entry.path.replace(os.sep, '/') + yield path + '/' class MockMemcachedClient: @@ -398,11 +400,10 @@ def test_petrel_backend(self, backend, prefix): assert set( petrel_backend.list_dir_or_file( tmp_dir, recursive=True)) == set([ - 'dir1', - osp.join('dir1', 'text3.txt'), 'dir2', - osp.join('dir2', 'dir3'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', + '/'.join(('dir2', 'dir3')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' ]) # 3. only list directories assert set( @@ -419,8 +420,7 @@ def test_petrel_backend(self, backend, prefix): assert set( petrel_backend.list_dir_or_file( tmp_dir, list_file=False, recursive=True)) == set( - ['dir1', 'dir2', - osp.join('dir2', 'dir3')]) + ['dir1', 'dir2', '/'.join(('dir2', 'dir3'))]) # 5. only list files assert set( petrel_backend.list_dir_or_file(tmp_dir, @@ -430,9 +430,9 @@ def test_petrel_backend(self, backend, prefix): assert set( petrel_backend.list_dir_or_file( tmp_dir, list_dir=False, recursive=True)) == set([ - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' ]) # 7. only list files ending with suffix assert set( @@ -454,8 +454,8 @@ def test_petrel_backend(self, backend, prefix): petrel_backend.list_dir_or_file( tmp_dir, list_dir=False, suffix='.txt', recursive=True)) == set([ - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), 'text1.txt', 'text2.txt' ]) # 7. only list files ending with suffix @@ -465,9 +465,9 @@ def test_petrel_backend(self, backend, prefix): list_dir=False, suffix=('.txt', '.jpg'), recursive=True)) == set([ - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' ]) @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) From d29a88da2d244681ec3532f0c929e52a03369cb0 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 19 Oct 2021 15:48:41 +0800 Subject: [PATCH 37/46] Add return docstring --- mmcv/fileio/file_client.py | 125 ++++++++++++++++++++++++++++++++----- mmcv/fileio/io.py | 4 +- mmcv/fileio/parse.py | 4 +- 3 files changed, 114 insertions(+), 19 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 6664ef8d67..3287618030 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -134,6 +134,11 @@ def get(self, filepath: Union[str, Path]) -> memoryview: Args: filepath (str or Path): Path to read data. + + Returns: + memoryview: A memory view of expected bytes object to avoid + copying. The memoryview object can be converted to bytes by + ``value_buf.tobytes()``. """ filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) @@ -148,8 +153,11 @@ def get_text(self, Args: filepath (str or Path): Path to read data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Default: 'utf-8'. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. """ return str(self.get(filepath), encoding=encoding) @@ -173,8 +181,8 @@ def put_text(self, Args: obj (str): Data to be written. filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to encode the - ``obj``. Default: 'utf-8'. + encoding (str): The encoding format used to encode the ``obj``. + Default: 'utf-8'. """ self.put(bytes(obj, encoding=encoding), filepath) @@ -193,6 +201,9 @@ def exists(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) @@ -204,6 +215,10 @@ def isdir(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. """ filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) @@ -214,6 +229,10 @@ def isfile(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. """ filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) @@ -225,6 +244,9 @@ def concat_paths(self, filepath: Union[str, Path], Args: filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. """ formatted_paths = [self._format_path(self._map_path(str(filepath)))] for path in filepaths: @@ -234,7 +256,11 @@ def concat_paths(self, filepath: Union[str, Path], @contextmanager def get_local_path(self, filepath: str) -> Iterable[str]: - """Download a file from ``filepath``. + """Download a file from ``filepath`` and return a temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. Args: filepath (str): Download a file from ``filepath``. @@ -245,6 +271,9 @@ def get_local_path(self, filepath: str) -> Iterable[str]: >>> # the path will be removed >>> with client.get_local_path('s3://path/of/your/file') as path: ... # do something here + + Yields: + Iterable[str]: Only yield one temporary path. """ filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) @@ -285,6 +314,9 @@ def list_dir_or_file(self, that we are interested in. Default: None. recursive (bool): If set to True, recursively scan the directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. """ dir_path = self._map_path(str(dir_path)) dir_path = self._format_path(dir_path) @@ -422,6 +454,9 @@ def get(self, filepath: Union[str, Path]) -> bytes: Args: filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. """ filepath = str(filepath) with open(filepath, 'rb') as f: @@ -435,8 +470,11 @@ def get_text(self, Args: filepath (str or Path): Path to read data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Default: 'utf-8'. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. """ filepath = str(filepath) with open(filepath, 'r', encoding=encoding) as f: @@ -463,8 +501,8 @@ def put_text(self, Args: obj (str): Data to be written. filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Default: 'utf-8'. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. """ filepath = str(filepath) with open(filepath, 'w', encoding=encoding) as f: @@ -484,6 +522,9 @@ def exists(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ return osp.exists(str(filepath)) @@ -493,6 +534,10 @@ def isdir(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. """ return osp.isdir(str(filepath)) @@ -501,6 +546,10 @@ def isfile(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. """ return osp.isfile(str(filepath)) @@ -513,6 +562,9 @@ def concat_paths(self, filepath: Union[str, Path], Args: filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. """ filepath = str(filepath) filepaths = [str(path) for path in filepaths] @@ -543,6 +595,9 @@ def list_dir_or_file(self, that we are interested in. Default: None. recursive (bool): If set to True, recursively scan the directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. """ dir_path = str(dir_path) if list_dir and suffix is not None: @@ -589,6 +644,10 @@ def get_text(self, filepath, encoding='utf-8'): def get_local_path(self, filepath: str) -> Iterable[str]: """Download a file from ``filepath``. + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + Args: filepath (str): Download a file from ``filepath``. @@ -713,12 +772,13 @@ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: Args: uri (str | Path): Uri to be parsed that contains the file prefix. - Returns: - return the prefix of uri if it contains "://" else None. - Examples: >>> FileClient.parse_uri_prefix('s3://path/of/your/file') 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' + else ``None``. """ assert is_filepath(uri) uri = str(uri) @@ -749,6 +809,9 @@ def infer_client(cls, >>> file_client = FileClient.infer_client(uri=uri) >>> file_client_args = {'backend': 'petrel'} >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. """ assert file_client_args is not None or uri is not None if file_client_args is None: @@ -853,6 +916,10 @@ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: Args: filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. """ return self.client.get(filepath) @@ -861,8 +928,11 @@ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: Args: filepath (str or Path): Path to read data. - encoding (str, optional): The encoding format used to open the - `filepath`. Default: 'utf-8'. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. """ return self.client.get_text(filepath, encoding) @@ -899,6 +969,9 @@ def exists(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ return self.client.exists(filepath) @@ -908,6 +981,10 @@ def isdir(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. """ return self.client.isdir(filepath) @@ -916,6 +993,10 @@ def isfile(self, filepath: Union[str, Path]) -> bool: Args: filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. """ return self.client.isfile(filepath) @@ -928,6 +1009,9 @@ def concat_paths(self, filepath: Union[str, Path], Args: filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. """ return self.client.concat_paths(filepath, *filepaths) @@ -935,7 +1019,12 @@ def concat_paths(self, filepath: Union[str, Path], def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: """Download data from ``filepath`` and write the data to local path. - If the ``filepath`` is a local path, just return itself. + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. .. warning:: ``get_local_path`` is an experimental interface that may change in @@ -948,6 +1037,9 @@ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: >>> file_client = FileClient(prefix='s3') >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: ... # do something here + + Yields: + Iterable[str]: Only yield one temporary path. """ with self.client.get_local_path(str(filepath)) as local_path: yield local_path @@ -972,6 +1064,9 @@ def list_dir_or_file(self, that we are interested in. Default: None. recursive (bool): If set to True, recursively scan the directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. """ yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index 1225b75a29..ce0a62ca78 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -21,7 +21,7 @@ def load(file, file_format=None, file_client_args=None, **kwargs): This method provides a unified api for loading data from serialized files. Note: - In v1.3.15 and later, ``load`` supports loading data from serialized + In v1.3.16 and later, ``load`` supports loading data from serialized files those can be storaged in different backends. Args: @@ -73,7 +73,7 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): and also supports custom arguments for each file format. Note: - In v1.3.15 and later, ``dump`` supports dumping data as strings or to + In v1.3.16 and later, ``dump`` supports dumping data as strings or to files which is saved to different backends. Args: diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py index 0f368ab9cb..f60f0d611b 100644 --- a/mmcv/fileio/parse.py +++ b/mmcv/fileio/parse.py @@ -14,7 +14,7 @@ def list_from_file(filename, """Load a text file and parse the content as a list of strings. Note: - In v1.3.15 and later, ``list_from_file`` supports loading a text file + In v1.3.16 and later, ``list_from_file`` supports loading a text file which can be storaged in different backends and parsing the content as a list for strings. @@ -63,7 +63,7 @@ def dict_from_file(filename, the following columns will be parsed as dict values. Note: - In v1.3.15 and later, ``dict_from_file`` supports loading a text file + In v1.3.16 and later, ``dict_from_file`` supports loading a text file which can be storaged in different backends and parsing the content as a dict. From f18a779489395bde17c0536b07e7f0e6c8f3a3fa Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 19 Oct 2021 16:12:18 +0800 Subject: [PATCH 38/46] polish docstring --- mmcv/fileio/file_client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 3287618030..575b89ee8b 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -914,6 +914,12 @@ def _register(backend_cls): def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: """Read data from a given ``filepath`` with 'rb' mode. + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + Args: filepath (str or Path): Path to read data. @@ -1039,7 +1045,7 @@ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: ... # do something here Yields: - Iterable[str]: Only yield one temporary path. + Iterable[str]: Only yield one path. """ with self.client.get_local_path(str(filepath)) as local_path: yield local_path From b6eb5d16f6358e8657cd21cd15f982de09df0ca4 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 19 Oct 2021 16:19:50 +0800 Subject: [PATCH 39/46] fix typo --- docs/understand_mmcv/io.md | 2 +- docs_zh_CN/understand_mmcv/io.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/understand_mmcv/io.md b/docs/understand_mmcv/io.md index f09a15a245..898bce5946 100644 --- a/docs/understand_mmcv/io.md +++ b/docs/understand_mmcv/io.md @@ -3,7 +3,7 @@ This module provides two universal API to load and dump files of different formats. ```{note} -Since v1.3.15, the IO modules support loading (dumping) data from (to) different backends, respectively. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330). +Since v1.3.16, the IO modules support loading (dumping) data from (to) different backends, respectively. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330). ``` ### Load and dump data diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index b6c2a8b224..161114c836 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -3,7 +3,7 @@ 文件输入输出模块提供了两个通用的 API 接口用于读取和保存不同格式的文件。 ```{note} -在 v1.3.15 及之后的版本中,IO 模块支持从不同后端读取数据和将数据保存至不同后端。更多细节请访问 PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330)。 +在 v1.3.16 及之后的版本中,IO 模块支持从不同后端读取数据并支持将数据至不同后端。更多细节请访问 PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330)。 ``` ### 读取和保存数据 From 150d504fac35bab354ec4ea88bb820f20a7ad163 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Tue, 19 Oct 2021 16:21:44 +0800 Subject: [PATCH 40/46] fix typo --- docs_zh_CN/understand_mmcv/io.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index 161114c836..95d111ccf8 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -69,7 +69,7 @@ class TxtHandler1(mmcv.BaseFileHandler): return str(obj) ``` -举 `PickleHandler` 为例 +以 `PickleHandler` 为例 ```python import pickle @@ -109,7 +109,7 @@ e ``` #### 从硬盘读取 -使用 `list_from_file` 读取 `a.txt` 。 +使用 `list_from_file` 读取 `a.txt` ```python >>> mmcv.list_from_file('a.txt') @@ -130,7 +130,7 @@ e 3 panda ``` -使用 `dict_from_file` 读取 `b.txt`。 +使用 `dict_from_file` 读取 `b.txt` ```python >>> mmcv.dict_from_file('b.txt') @@ -154,7 +154,7 @@ e ['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] ``` -使用 `dict_from_file` 读取 `b.txt`。 +使用 `dict_from_file` 读取 `b.txt` ```python >>> mmcv.dict_from_file('s3://bucket-name/b.txt') From 208ff82032962facbfbb7223a9eae7ef7a78baa8 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Wed, 20 Oct 2021 22:07:04 +0800 Subject: [PATCH 41/46] deprecate the conversion from Path to str --- mmcv/fileio/file_client.py | 14 +++----------- mmcv/fileio/handlers/base.py | 13 ++++++------- mmcv/fileio/handlers/pickle_handler.py | 2 +- mmcv/fileio/io.py | 4 ++-- 4 files changed, 12 insertions(+), 21 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 575b89ee8b..c4a90bd797 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -458,7 +458,6 @@ def get(self, filepath: Union[str, Path]) -> bytes: Returns: bytes: Expected bytes object. """ - filepath = str(filepath) with open(filepath, 'rb') as f: value_buf = f.read() return value_buf @@ -476,7 +475,6 @@ def get_text(self, Returns: str: Expected text reading from ``filepath``. """ - filepath = str(filepath) with open(filepath, 'r', encoding=encoding) as f: value_buf = f.read() return value_buf @@ -488,7 +486,6 @@ def put(self, obj: bytes, filepath: Union[str, Path]) -> None: obj (bytes): Data to be written. filepath (str or Path): Path to write data. """ - filepath = str(filepath) with open(filepath, 'wb') as f: f.write(obj) @@ -504,7 +501,6 @@ def put_text(self, encoding (str): The encoding format used to open the ``filepath``. Default: 'utf-8'. """ - filepath = str(filepath) with open(filepath, 'w', encoding=encoding) as f: f.write(obj) @@ -514,7 +510,6 @@ def remove(self, filepath: Union[str, Path]) -> None: Args: filepath (str or Path): Path to be removed. """ - filepath = str(filepath) os.remove(filepath) def exists(self, filepath: Union[str, Path]) -> bool: @@ -526,7 +521,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: Returns: bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ - return osp.exists(str(filepath)) + return osp.exists(filepath) def isdir(self, filepath: Union[str, Path]) -> bool: """Check whether a file path is a directory. @@ -539,7 +534,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a directory, ``False`` otherwise. """ - return osp.isdir(str(filepath)) + return osp.isdir(filepath) def isfile(self, filepath: Union[str, Path]) -> bool: """Check a ``filepath`` whether it is a file. @@ -551,7 +546,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a file, ``False`` otherwise. """ - return osp.isfile(str(filepath)) + return osp.isfile(filepath) def concat_paths(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: @@ -566,8 +561,6 @@ def concat_paths(self, filepath: Union[str, Path], Returns: str: The result of concatenation. """ - filepath = str(filepath) - filepaths = [str(path) for path in filepaths] return osp.join(filepath, *filepaths) @contextmanager @@ -599,7 +592,6 @@ def list_dir_or_file(self, Yields: Iterable[str]: A relative path to ``dir_path``. """ - dir_path = str(dir_path) if list_dir and suffix is not None: raise TypeError('`suffix` should be None when `list_dir` is True') diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py index 22d66d5b1b..5f28b0acc6 100644 --- a/mmcv/fileio/handlers/base.py +++ b/mmcv/fileio/handlers/base.py @@ -3,13 +3,12 @@ class BaseFileHandler(metaclass=ABCMeta): - # is_str_like_obj is a flag to mark which type of file object is processed, - # bytes-like object or str-like object. For example, pickle only process - # the bytes-like object and json only process the str-like object. The flag - # will be used to check which type of buffer is used. If str-like object, - # StringIO will be used. If bytes-like object, BytesIO will be used. The - # usage of the flag can be found in `mmcv.load` or `mmcv.dump` - is_str_like_obj = True + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + + str_like = True @abstractmethod def load_from_fileobj(self, file, **kwargs): diff --git a/mmcv/fileio/handlers/pickle_handler.py b/mmcv/fileio/handlers/pickle_handler.py index 648bf22b9c..b37c79bed4 100644 --- a/mmcv/fileio/handlers/pickle_handler.py +++ b/mmcv/fileio/handlers/pickle_handler.py @@ -6,7 +6,7 @@ class PickleHandler(BaseFileHandler): - is_str_like_obj = False + str_like = False def load_from_fileobj(self, file, **kwargs): return pickle.load(file, **kwargs) diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py index ce0a62ca78..aaefde58aa 100644 --- a/mmcv/fileio/io.py +++ b/mmcv/fileio/io.py @@ -53,7 +53,7 @@ def load(file, file_format=None, file_client_args=None, **kwargs): handler = file_handlers[file_format] if is_str(file): file_client = FileClient.infer_client(file_client_args, file) - if handler.is_str_like_obj: + if handler.str_like: with StringIO(file_client.get_text(file)) as f: obj = handler.load_from_fileobj(f, **kwargs) else: @@ -109,7 +109,7 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): return handler.dump_to_str(obj, **kwargs) elif is_str(file): file_client = FileClient.infer_client(file_client_args, file) - if handler.is_str_like_obj: + if handler.str_like: with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) file_client.put_text(f.getvalue(), file) From 9ecfc1270b2c0d56ef4da81a7517186ab980f2b7 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 22 Oct 2021 18:28:19 +0800 Subject: [PATCH 42/46] add docs for loading checkpoints with FileClient --- docs/understand_mmcv/io.md | 78 ++++++++++++++++++++++++++++++++ docs_zh_CN/understand_mmcv/io.md | 78 +++++++++++++++++++++++++++++++- mmcv/fileio/file_client.py | 27 +++++++++++ 3 files changed, 181 insertions(+), 2 deletions(-) diff --git a/docs/understand_mmcv/io.md b/docs/understand_mmcv/io.md index 898bce5946..f6c28dd425 100644 --- a/docs/understand_mmcv/io.md +++ b/docs/understand_mmcv/io.md @@ -167,3 +167,81 @@ Use `dict_from_file` to load the dict from `s3://bucket-name/b.txt`. >>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int) {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` + +### Load and dump checkpoints + +#### Load checkpoints from disk or save to disk + +We can read the checkpoints from disk or save to disk in the following way. + +```python +import torch + +filepath1 = '/path/of/your/checkpoint1.pth' +filepath2 = '/path/of/your/checkpoint2.pth' +# read from filepath1 +checkpoint = torch.load(filepath1) +# save to filepath2 +torch.save(checkpoint, filepath2) +``` + +MMCV provides many backends. `HardDiskBackend` is one of them and we can use it to read or save checkpoints. + +```python +import io +from mmcv.fileio.file_client import HardDiskBackend + +disk_backend = HardDiskBackend() +with io.BytesIO(disk_backend.get(filepath1)) as buffer: + checkpoint = torch.load(buffer) +with io.BytesIO() as buffer: + torch.save(checkpoint, f) + disk_backend.put(f.getvalue(), filepath2) +``` + +If we want to implement an interface which automatically select the corresponding +backend based on the file path, we can use the `FileClient`. +For example, we want to implement two methods for reading checkpoints as well as saving checkpoints, +which need to support different types of file paths, either disk paths, network paths or other paths. + +```python +from mmcv.fileio.file_client import FileClient + +def load_checkpoint(path): + file_client = FileClient.infer(uri=path) + with io.BytesIO(file_client.get(path)) as buffer: + checkpoint = torch.load(buffer) + return checkpoint + +def save_checkpoint(checkpoint, path): + with io.BytesIO() as buffer: + torch.save(checkpoint, buffer) + file_client.put(buffer.getvalue(), path) + +file_client = FileClient.infer_client(uri=filepath1) +checkpoint = load_checkpoint(filepath1) +save_checkpoint(checkpoint, filepath2) +``` + +#### Load checkpoints from the Internet + +```{note} +Currently, it only supports reading checkpoints from the Internet, and does not support saving checkpoints to the Internet. +``` + +```python +import io +import torch +from mmcv.fileio.file_client import HTTPBackend, FileClient + +filepath = 'http://path/of/your/checkpoint.pth' +checkpoint = torch.utils.model_zoo.load_url(filepath) + +http_backend = HTTPBackend() +with io.BytesIO(http_backend.get(filepath)) as buffer: + checkpoint = torch.load(buffer) + +file_client = FileClient.infer_client(uri=filepath) +with io.BytesIO(file_client.get(filepath)) as buffer: + checkpoint = torch.load(buffer) +``` diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index 95d111ccf8..8c708045b8 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -122,7 +122,7 @@ e ['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e'] ``` -同样, `b.txt` 也是文本文件,一共有3行内容。 +同样, `b.txt` 也是文本文件,一共有3行内容 ``` 1 cat @@ -141,7 +141,7 @@ e #### 从其他后端读取 -使用 `list_from_file` 读取 `s3://bucket-name/a.txt` 。 +使用 `list_from_file` 读取 `s3://bucket-name/a.txt` ```python >>> mmcv.list_from_file('s3://bucket-name/a.txt') @@ -162,3 +162,77 @@ e >>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int) {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} ``` + +### 读取和保存权重文件 +#### 从硬盘读取权重文件或者将权重文件保存至硬盘 + +我们可以通过下面的方式从磁盘读取权重文件或者将权重文件保存至磁盘 + +```python +import torch + +filepath1 = '/path/of/your/checkpoint1.pth' +filepath2 = '/path/of/your/checkpoint2.pth' +# 从 filepath1 读取权重文件 +checkpoint = torch.load(filepath1) +# 将权重文件保存至 filepath2 +torch.save(checkpoint, filepath2) +``` + +MMCV 提供了很多后端,`HardDiskBackend` 是其中一个,我们可以通过它来读取或者保存权重文件。 +```python +import io +from mmcv.fileio.file_client import HardDiskBackend + +disk_backend = HardDiskBackend() +with io.BytesIO(disk_backend.get(filepath1)) as buffer: + checkpoint = torch.load(buffer) +with io.BytesIO() as buffer: + torch.save(checkpoint, f) + disk_backend.put(f.getvalue(), filepath2) +``` + +如果我们想在接口中实现根据文件路径自动选择对应的后端,我们可以使用 `FileClient`。 +例如,我们想实现两个方法,分别是读取权重以及保存权重,它们需支持不同类型的文件路径,可以是磁盘路径,也可以是网络路径或者其他路径。 + +```python +from mmcv.fileio.file_client import FileClient + +def load_checkpoint(path): + file_client = FileClient.infer(uri=path) + with io.BytesIO(file_client.get(path)) as buffer: + checkpoint = torch.load(buffer) + return checkpoint + +def save_checkpoint(checkpoint, path): + with io.BytesIO() as buffer: + torch.save(checkpoint, buffer) + file_client.put(buffer.getvalue(), path) + +file_client = FileClient.infer_client(uri=filepath1) +checkpoint = load_checkpoint(filepath1) +save_checkpoint(checkpoint, filepath2) +``` + +#### 从网络远端读取权重文件 + +```{note} +目前只支持从网络远端读取权重文件,暂不支持将权重文件写入网络远端 +``` + +```python +import io +import torch +from mmcv.fileio.file_client import HTTPBackend, FileClient + +filepath = 'http://path/of/your/checkpoint.pth' +checkpoint = torch.utils.model_zoo.load_url(filepath) + +http_backend = HTTPBackend() +with io.BytesIO(http_backend.get(filepath)) as buffer: + checkpoint = torch.load(buffer) + +file_client = FileClient.infer_client(uri=filepath) +with io.BytesIO(file_client.get(filepath)) as buffer: + checkpoint = torch.load(buffer) +``` diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index c4a90bd797..18bb9eb9a3 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -192,6 +192,11 @@ def remove(self, filepath: Union[str, Path]) -> None: Args: filepath (str or Path): Path to be removed. """ + if not hasattr(self._client, 'delete'): + NotImplementedError( + ('Current version of Petrel has not supported the `delete` ' + 'method, please use a higher version or dev branch instead.')) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) self._client.delete(filepath) @@ -205,6 +210,13 @@ def exists(self, filepath: Union[str, Path]) -> bool: Returns: bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ + if not (hasattr(self._client, 'delete') + and hasattr(self._client, 'isdir')): + NotImplementedError( + ('Current version of Petrel has not supported the `contains` ' + '`isdir` method, please use a higher version or dev branch ' + 'instead.')) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) return self._client.contains(filepath) or self._client.isdir(filepath) @@ -220,6 +232,11 @@ def isdir(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a directory, ``False`` otherwise. """ + if not hasattr(self._client, 'isdir'): + NotImplementedError( + ('Current version of Petrel has not supported the `isdir` ' + 'method, please use a higher version or dev branch instead.')) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) return self._client.isdir(filepath) @@ -234,6 +251,11 @@ def isfile(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a file, ``False`` otherwise. """ + if not hasattr(self._client, 'contains'): + NotImplementedError( + ('Current version of Petrel has not supported the `contains` ' + 'method, please use a higher version or dev branch instead.')) + filepath = self._map_path(str(filepath)) filepath = self._format_path(filepath) return self._client.contains(filepath) @@ -318,6 +340,11 @@ def list_dir_or_file(self, Yields: Iterable[str]: A relative path to ``dir_path``. """ + if not hasattr(self._client, 'list'): + NotImplementedError( + ('Current version of Petrel has not supported the `list` ' + 'method, please use a higher version or dev branch instead.')) + dir_path = self._map_path(str(dir_path)) dir_path = self._format_path(dir_path) if list_dir and suffix is not None: From 38559f19eb4278648c0bb10009649ec1cf44fe0a Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 22 Oct 2021 22:10:07 +0800 Subject: [PATCH 43/46] refactor map_path --- docs_zh_CN/understand_mmcv/io.md | 2 ++ mmcv/fileio/file_client.py | 42 +++++++++++++++++--------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/docs_zh_CN/understand_mmcv/io.md b/docs_zh_CN/understand_mmcv/io.md index 8c708045b8..0e5002f828 100644 --- a/docs_zh_CN/understand_mmcv/io.md +++ b/docs_zh_CN/understand_mmcv/io.md @@ -164,6 +164,7 @@ e ``` ### 读取和保存权重文件 + #### 从硬盘读取权重文件或者将权重文件保存至硬盘 我们可以通过下面的方式从磁盘读取权重文件或者将权重文件保存至磁盘 @@ -180,6 +181,7 @@ torch.save(checkpoint, filepath2) ``` MMCV 提供了很多后端,`HardDiskBackend` 是其中一个,我们可以通过它来读取或者保存权重文件。 + ```python import io from mmcv.fileio.file_client import HardDiskBackend diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 18bb9eb9a3..0dc2d873b9 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -105,12 +105,14 @@ def __init__(self, assert isinstance(path_mapping, dict) or path_mapping is None self.path_mapping = path_mapping - def _map_path(self, filepath: str) -> str: - """Replace the prefix of ``filepath`` with :attr:`path_mapping`. + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. Args: filepath (str): Path to be mapped. """ + filepath = str(filepath) if self.path_mapping is not None: for k, v in self.path_mapping.items(): filepath = filepath.replace(k, v) @@ -140,7 +142,7 @@ def get(self, filepath: Union[str, Path]) -> memoryview: copying. The memoryview object can be converted to bytes by ``value_buf.tobytes()``. """ - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) value = self._client.Get(filepath) value_buf = memoryview(value) @@ -168,7 +170,7 @@ def put(self, obj: bytes, filepath: Union[str, Path]) -> None: obj (bytes): Data to be saved. filepath (str or Path): Path to write data. """ - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) self._client.put(filepath, obj) @@ -197,7 +199,7 @@ def remove(self, filepath: Union[str, Path]) -> None: ('Current version of Petrel has not supported the `delete` ' 'method, please use a higher version or dev branch instead.')) - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) self._client.delete(filepath) @@ -217,7 +219,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: '`isdir` method, please use a higher version or dev branch ' 'instead.')) - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) return self._client.contains(filepath) or self._client.isdir(filepath) @@ -237,7 +239,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: ('Current version of Petrel has not supported the `isdir` ' 'method, please use a higher version or dev branch instead.')) - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) return self._client.isdir(filepath) @@ -256,7 +258,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: ('Current version of Petrel has not supported the `contains` ' 'method, please use a higher version or dev branch instead.')) - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) return self._client.contains(filepath) @@ -268,16 +270,15 @@ def concat_paths(self, filepath: Union[str, Path], filepath (str or Path): Path to be concatenated. Returns: - str: The result of concatenation. + str: The result after concatenation. """ - formatted_paths = [self._format_path(self._map_path(str(filepath)))] + formatted_paths = [self._format_path(self._map_path(filepath))] for path in filepaths: - formatted_paths.append( - self._format_path(self._map_path(str(path)))) + formatted_paths.append(self._format_path(self._map_path(path))) return '/'.join(formatted_paths) @contextmanager - def get_local_path(self, filepath: str) -> Iterable[str]: + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: """Download a file from ``filepath`` and return a temporary path. ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It @@ -285,7 +286,7 @@ def get_local_path(self, filepath: str) -> Iterable[str]: ``with`` statement, the temporary path will be released. Args: - filepath (str): Download a file from ``filepath``. + filepath (str | Path): Download a file from ``filepath``. Examples: >>> client = PetrelBackend() @@ -297,7 +298,7 @@ def get_local_path(self, filepath: str) -> Iterable[str]: Yields: Iterable[str]: Only yield one temporary path. """ - filepath = self._map_path(str(filepath)) + filepath = self._map_path(filepath) filepath = self._format_path(filepath) assert self.isfile(filepath) try: @@ -329,7 +330,7 @@ def list_dir_or_file(self, suffix '/' which is consistent with other backends. Args: - dir_path (str | obj:`Path`): Path of the directory. + dir_path (str | Path): Path of the directory. list_dir (bool): List the directories. Default: True. list_file (bool): List the path of files. Default: True. suffix (str or tuple[str], optional): File suffix @@ -345,7 +346,7 @@ def list_dir_or_file(self, ('Current version of Petrel has not supported the `list` ' 'method, please use a higher version or dev branch instead.')) - dir_path = self._map_path(str(dir_path)) + dir_path = self._map_path(dir_path) dir_path = self._format_path(dir_path) if list_dir and suffix is not None: raise TypeError('`suffix` should be None when `list_dir` is True') @@ -591,7 +592,8 @@ def concat_paths(self, filepath: Union[str, Path], return osp.join(filepath, *filepaths) @contextmanager - def get_local_path(self, filepath: str) -> Iterable[str]: + def get_local_path( + self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]: """Only for unified API and do nothing.""" yield filepath @@ -608,7 +610,7 @@ def list_dir_or_file(self, :meth:`list_dir_or_file` returns the path relative to ``dir_path``. Args: - dir_path (str | obj:`Path`): Path of the directory. + dir_path (str | Path): Path of the directory. list_dir (bool): List the directories. Default: True. list_file (bool): List the path of files. Default: True. suffix (str or tuple[str], optional): File suffix @@ -1082,7 +1084,7 @@ def list_dir_or_file(self, :meth:`list_dir_or_file` returns the path relative to ``dir_path``. Args: - dir_path (str | obj:`Path`): Path of the directory. + dir_path (str | Path): Path of the directory. list_dir (bool): List the directories. Default: True. list_file (bool): List the path of files. Default: True. suffix (str or tuple[str], optional): File suffix From ea3238878007f47a3530d1e3b02642c1c1e7ed9e Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 23 Oct 2021 00:29:48 +0800 Subject: [PATCH 44/46] add _ensure_methods to ensure methods have been implemented --- mmcv/fileio/file_client.py | 49 +++++++++++++++++++++----------------- tests/test_fileclient.py | 10 ++++++++ 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 0dc2d873b9..640fd67b32 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -6,11 +6,13 @@ import tempfile import warnings from abc import ABCMeta, abstractmethod +from collections.abc import Sequence from contextlib import contextmanager from pathlib import Path from typing import Iterable, Iterator, Optional, Tuple, Union from urllib.request import urlopen +from mmcv.utils.misc import is_seq_of from mmcv.utils.path import is_filepath @@ -188,16 +190,33 @@ def put_text(self, """ self.put(bytes(obj, encoding=encoding), filepath) + def _ensure_methods(self, method_names: Union[str, Sequence]): + """Ensure that methods have been implemented before called. + + Args: + method_names (str | Sequence): The name of method or the list of + name of method. + """ + if isinstance(method_names, str): + method_names = [method_names] + else: + assert is_seq_of(method_names, str) + + for method_name in method_names: + if not (hasattr(self._client, method_name) + and callable(getattr(self._client, method_name))): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + f'the `{method_name}` method, please use a higher version' + ' or dev branch instead.')) + def remove(self, filepath: Union[str, Path]) -> None: """Remove a file. Args: filepath (str or Path): Path to be removed. """ - if not hasattr(self._client, 'delete'): - NotImplementedError( - ('Current version of Petrel has not supported the `delete` ' - 'method, please use a higher version or dev branch instead.')) + self._ensure_methods('delete') filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -212,12 +231,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: Returns: bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ - if not (hasattr(self._client, 'delete') - and hasattr(self._client, 'isdir')): - NotImplementedError( - ('Current version of Petrel has not supported the `contains` ' - '`isdir` method, please use a higher version or dev branch ' - 'instead.')) + self._ensure_methods(['contains', 'isdir']) filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -234,10 +248,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a directory, ``False`` otherwise. """ - if not hasattr(self._client, 'isdir'): - NotImplementedError( - ('Current version of Petrel has not supported the `isdir` ' - 'method, please use a higher version or dev branch instead.')) + self._ensure_methods('isdir') filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -253,10 +264,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a file, ``False`` otherwise. """ - if not hasattr(self._client, 'contains'): - NotImplementedError( - ('Current version of Petrel has not supported the `contains` ' - 'method, please use a higher version or dev branch instead.')) + self._ensure_methods('contains') filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -341,10 +349,7 @@ def list_dir_or_file(self, Yields: Iterable[str]: A relative path to ``dir_path``. """ - if not hasattr(self._client, 'list'): - NotImplementedError( - ('Current version of Petrel has not supported the `list` ' - 'method, please use a higher version or dev branch instead.')) + self._ensure_methods('list') dir_path = self._map_path(dir_path) dir_path = self._format_path(dir_path) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index c5c159ccb1..b6bdd66b59 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -346,6 +346,16 @@ def test_petrel_backend(self, backend, prefix): petrel_backend.put_text('petrel', petrel_path) mock_put.assert_called_once_with(petrel_path, b'petrel') + # test `_ensure_methods` + with pytest.raises(NotImplementedError): + petrel_backend.client._ensure_methods('unimplemented_method') + with pytest.raises(NotImplementedError): + # `contains` is implemented but `unimplemented_method` not + petrel_backend.client._ensure_methods( + ['contains', 'unimplemented_method']) + petrel_backend.client._ensure_methods('contains') + petrel_backend.client._ensure_methods(['contains', 'delete']) + # test `remove` with patch.object(petrel_backend.client._client, 'delete') as mock_delete: From a8cc11d58256909124ec7565d35e3abfe0cce7f8 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 23 Oct 2021 02:22:54 +0800 Subject: [PATCH 45/46] fix list_dir_or_file --- mmcv/fileio/file_client.py | 39 +++++++++++++++++++++++--------------- tests/test_fileclient.py | 23 ++++++++++++---------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 640fd67b32..0848fb82f2 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -190,7 +190,7 @@ def put_text(self, """ self.put(bytes(obj, encoding=encoding), filepath) - def _ensure_methods(self, method_names: Union[str, Sequence]): + def _ensure_method_implemented(self, method_names: Union[str, Sequence]): """Ensure that methods have been implemented before called. Args: @@ -216,7 +216,7 @@ def remove(self, filepath: Union[str, Path]) -> None: Args: filepath (str or Path): Path to be removed. """ - self._ensure_methods('delete') + self._ensure_method_implemented('delete') filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -231,7 +231,7 @@ def exists(self, filepath: Union[str, Path]) -> bool: Returns: bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ - self._ensure_methods(['contains', 'isdir']) + self._ensure_method_implemented(['contains', 'isdir']) filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -248,7 +248,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a directory, ``False`` otherwise. """ - self._ensure_methods('isdir') + self._ensure_method_implemented('isdir') filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -264,7 +264,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a file, ``False`` otherwise. """ - self._ensure_methods('contains') + self._ensure_method_implemented('contains') filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -280,7 +280,10 @@ def concat_paths(self, filepath: Union[str, Path], Returns: str: The result after concatenation. """ - formatted_paths = [self._format_path(self._map_path(filepath))] + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] for path in filepaths: formatted_paths.append(self._format_path(self._map_path(path))) return '/'.join(formatted_paths) @@ -349,12 +352,13 @@ def list_dir_or_file(self, Yields: Iterable[str]: A relative path to ``dir_path``. """ - self._ensure_methods('list') + self._ensure_method_implemented('list') dir_path = self._map_path(dir_path) dir_path = self._format_path(dir_path) if list_dir and suffix is not None: - raise TypeError('`suffix` should be None when `list_dir` is True') + raise TypeError( + '`list_dir` should be False when `suffix` is not None') if (suffix is not None) and not isinstance(suffix, (str, tuple)): raise TypeError('`suffix` must be a string or tuple of strings') @@ -369,18 +373,23 @@ def list_dir_or_file(self, def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): for path in self._client.list(dir_path): - # the `self.isdir` is not used here to determine if path is a - # directory, because `self.isdir` relies on `self._client.list` + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` if path.endswith('/'): # a directory path + next_dir_path = self.concat_paths(dir_path, path) if list_dir: - # exclude the last character '/' - rel_dir = path[len(root):-1] + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] yield rel_dir if recursive: - yield from _list_dir_or_file(path, list_dir, list_file, - suffix, recursive) + yield from _list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) else: # a file path - rel_path = path[len(root):] + absolute_path = self.concat_paths(dir_path, path) + rel_path = absolute_path[len(root):] if (suffix is None or rel_path.endswith(suffix)) and list_file: yield rel_path diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index b6bdd66b59..e2cd26f187 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -89,11 +89,9 @@ def isdir(self): def list(self, dir_path): for entry in os.scandir(dir_path): if not entry.name.startswith('.') and entry.is_file(): - path = entry.path.replace(os.sep, '/') - yield path + yield entry.name elif osp.isdir(entry.path): - path = entry.path.replace(os.sep, '/') - yield path + '/' + yield entry.name + '/' class MockMemcachedClient: @@ -346,15 +344,17 @@ def test_petrel_backend(self, backend, prefix): petrel_backend.put_text('petrel', petrel_path) mock_put.assert_called_once_with(petrel_path, b'petrel') - # test `_ensure_methods` + # test `_ensure_method_implemented` with pytest.raises(NotImplementedError): - petrel_backend.client._ensure_methods('unimplemented_method') + petrel_backend.client._ensure_method_implemented( + 'unimplemented_method') with pytest.raises(NotImplementedError): # `contains` is implemented but `unimplemented_method` not - petrel_backend.client._ensure_methods( + petrel_backend.client._ensure_method_implemented( ['contains', 'unimplemented_method']) - petrel_backend.client._ensure_methods('contains') - petrel_backend.client._ensure_methods(['contains', 'delete']) + petrel_backend.client._ensure_method_implemented('contains') + petrel_backend.client._ensure_method_implemented( + ['contains', 'delete']) # test `remove` with patch.object(petrel_backend.client._client, @@ -386,6 +386,8 @@ def test_petrel_backend(self, backend, prefix): # test `concat_paths` assert petrel_backend.concat_paths(petrel_dir, 'file') == \ f'{petrel_dir}/file' + assert petrel_backend.concat_paths(f'{petrel_dir}/', 'file') == \ + f'{petrel_dir}/file' assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \ f'{petrel_dir}/dir/file' @@ -421,7 +423,8 @@ def test_petrel_backend(self, backend, prefix): tmp_dir, list_file=False)) == set(['dir1', 'dir2']) with pytest.raises( TypeError, - match='`suffix` should be None when `list_dir` is True'): + match=('`list_dir` should be False when `suffix` is not ' + 'None')): # Exception is raised among the `list_dir_or_file` of client, # so we need to invode the client to trigger the exception petrel_backend.client.list_dir_or_file( From e66fe614ebd982df1a223ef21016053fba356e5c Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sat, 23 Oct 2021 16:32:39 +0800 Subject: [PATCH 46/46] rename _ensure_method_implemented to has_method --- mmcv/fileio/file_client.py | 54 +++++++++++++++--------------- mmcv/utils/__init__.py | 6 ++-- mmcv/utils/misc.py | 13 ++++++++ tests/test_fileclient.py | 62 ++++++++++++++++++++++++++++------- tests/test_utils/test_misc.py | 16 +++++++++ 5 files changed, 109 insertions(+), 42 deletions(-) diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index 0848fb82f2..a6c0f8b89e 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -6,13 +6,12 @@ import tempfile import warnings from abc import ABCMeta, abstractmethod -from collections.abc import Sequence from contextlib import contextmanager from pathlib import Path from typing import Iterable, Iterator, Optional, Tuple, Union from urllib.request import urlopen -from mmcv.utils.misc import is_seq_of +from mmcv.utils.misc import has_method from mmcv.utils.path import is_filepath @@ -190,33 +189,17 @@ def put_text(self, """ self.put(bytes(obj, encoding=encoding), filepath) - def _ensure_method_implemented(self, method_names: Union[str, Sequence]): - """Ensure that methods have been implemented before called. - - Args: - method_names (str | Sequence): The name of method or the list of - name of method. - """ - if isinstance(method_names, str): - method_names = [method_names] - else: - assert is_seq_of(method_names, str) - - for method_name in method_names: - if not (hasattr(self._client, method_name) - and callable(getattr(self._client, method_name))): - raise NotImplementedError( - ('Current version of Petrel Python SDK has not supported ' - f'the `{method_name}` method, please use a higher version' - ' or dev branch instead.')) - def remove(self, filepath: Union[str, Path]) -> None: """Remove a file. Args: filepath (str or Path): Path to be removed. """ - self._ensure_method_implemented('delete') + if not has_method(self._client, 'delete'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev' + ' branch instead.')) filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -231,7 +214,12 @@ def exists(self, filepath: Union[str, Path]) -> bool: Returns: bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. """ - self._ensure_method_implemented(['contains', 'isdir']) + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.')) filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -248,7 +236,11 @@ def isdir(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a directory, ``False`` otherwise. """ - self._ensure_method_implemented('isdir') + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.')) filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -264,7 +256,11 @@ def isfile(self, filepath: Union[str, Path]) -> bool: bool: Return ``True`` if ``filepath`` points to a file, ``False`` otherwise. """ - self._ensure_method_implemented('contains') + if not has_method(self._client, 'contains'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.')) filepath = self._map_path(filepath) filepath = self._format_path(filepath) @@ -352,7 +348,11 @@ def list_dir_or_file(self, Yields: Iterable[str]: A relative path to ``dir_path``. """ - self._ensure_method_implemented('list') + if not has_method(self._client, 'list'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.')) dir_path = self._map_path(dir_path) dir_path = self._format_path(dir_path) diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index baf8109f05..378a006843 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .config import Config, ConfigDict, DictAction from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - import_modules_from_strings, is_list_of, + has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, requires_executable, requires_package, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, @@ -33,7 +33,7 @@ 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', - 'is_method_overridden' + 'is_method_overridden', 'has_method' ] else: from .env import collect_env @@ -65,5 +65,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home' + '_get_cuda_home', 'has_method' ] diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py index b9fddfb2b5..aa7fd8393d 100644 --- a/mmcv/utils/misc.py +++ b/mmcv/utils/misc.py @@ -352,3 +352,16 @@ def is_method_overridden(method, base_class, derived_class): base_method = getattr(base_class, method) derived_method = getattr(derived_class, method) return derived_method != base_method + + +def has_method(obj: object, method: str) -> bool: + """Check whether the object has a method. + + Args: + method (str): The method name to check. + obj (object): The object to check. + + Returns: + bool: True if the object has the method else False. + """ + return hasattr(obj, method) and callable(getattr(obj, method)) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index e2cd26f187..d15483c94c 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -3,6 +3,7 @@ import sys import tempfile from contextlib import contextmanager +from copy import deepcopy from pathlib import Path from unittest.mock import MagicMock, patch @@ -10,6 +11,7 @@ import mmcv from mmcv import BaseStorageBackend, FileClient +from mmcv.utils import has_method sys.modules['ceph'] = MagicMock() sys.modules['petrel_client'] = MagicMock() @@ -52,6 +54,16 @@ def build_temporary_directory(): yield tmp_dir +@contextmanager +def delete_and_reset_method(obj, method): + method_obj = deepcopy(getattr(type(obj), method)) + try: + delattr(type(obj), method) + yield + finally: + setattr(type(obj), method, method_obj) + + class MockS3Client: def __init__(self, enable_mc=True): @@ -344,25 +356,32 @@ def test_petrel_backend(self, backend, prefix): petrel_backend.put_text('petrel', petrel_path) mock_put.assert_called_once_with(petrel_path, b'petrel') - # test `_ensure_method_implemented` - with pytest.raises(NotImplementedError): - petrel_backend.client._ensure_method_implemented( - 'unimplemented_method') - with pytest.raises(NotImplementedError): - # `contains` is implemented but `unimplemented_method` not - petrel_backend.client._ensure_method_implemented( - ['contains', 'unimplemented_method']) - petrel_backend.client._ensure_method_implemented('contains') - petrel_backend.client._ensure_method_implemented( - ['contains', 'delete']) - # test `remove` + assert has_method(petrel_backend.client._client, 'delete') + # raise Exception if `delete` is not implemented + with delete_and_reset_method(petrel_backend.client._client, 'delete'): + assert not has_method(petrel_backend.client._client, 'delete') + with pytest.raises(NotImplementedError): + petrel_backend.remove(petrel_path) + with patch.object(petrel_backend.client._client, 'delete') as mock_delete: petrel_backend.remove(petrel_path) mock_delete.assert_called_once_with(petrel_path) # test `exists` + assert has_method(petrel_backend.client._client, 'contains') + assert has_method(petrel_backend.client._client, 'isdir') + # raise Exception if `delete` is not implemented + with delete_and_reset_method(petrel_backend.client._client, + 'contains'), delete_and_reset_method( + petrel_backend.client._client, + 'isdir'): + assert not has_method(petrel_backend.client._client, 'contains') + assert not has_method(petrel_backend.client._client, 'isdir') + with pytest.raises(NotImplementedError): + petrel_backend.exists(petrel_path) + with patch.object( petrel_backend.client._client, 'contains', return_value=True) as mock_contains: @@ -370,6 +389,12 @@ def test_petrel_backend(self, backend, prefix): mock_contains.assert_called_once_with(petrel_path) # test `isdir` + assert has_method(petrel_backend.client._client, 'isdir') + with delete_and_reset_method(petrel_backend.client._client, 'isdir'): + assert not has_method(petrel_backend.client._client, 'isdir') + with pytest.raises(NotImplementedError): + petrel_backend.isdir(petrel_path) + with patch.object( petrel_backend.client._client, 'isdir', return_value=True) as mock_isdir: @@ -377,6 +402,13 @@ def test_petrel_backend(self, backend, prefix): mock_isdir.assert_called_once_with(petrel_dir) # test `isfile` + assert has_method(petrel_backend.client._client, 'contains') + with delete_and_reset_method(petrel_backend.client._client, + 'contains'): + assert not has_method(petrel_backend.client._client, 'contains') + with pytest.raises(NotImplementedError): + petrel_backend.isfile(petrel_path) + with patch.object( petrel_backend.client._client, 'contains', return_value=True) as mock_contains: @@ -404,6 +436,12 @@ def test_petrel_backend(self, backend, prefix): mock_contains.assert_called_once_with(petrel_path) # test `list_dir_or_file` + assert has_method(petrel_backend.client._client, 'list') + with delete_and_reset_method(petrel_backend.client._client, 'list'): + assert not has_method(petrel_backend.client._client, 'list') + with pytest.raises(NotImplementedError): + list(petrel_backend.list_dir_or_file(petrel_dir)) + with build_temporary_directory() as tmp_dir: # 1. list directories and files assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set( diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 67f80044b8..42a2227385 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -2,6 +2,7 @@ import pytest import mmcv +from mmcv.utils.misc import has_method def test_to_ntuple(): @@ -190,3 +191,18 @@ def foo1(): base_instance = Base() with pytest.raises(AssertionError): mmcv.is_method_overridden('foo1', base_instance, sub_instance) + + +def test_has_method(): + + class Foo: + + def __init__(self, name): + self.name = name + + def print_name(self): + print(self.name) + + foo = Foo('foo') + assert not has_method(foo, 'name') + assert has_method(foo, 'print_name')