Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Loading objects from different backends and dumping objects to different backends #1330

Merged
merged 50 commits into from
Oct 23, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
b2a2257
[Feature] Choose storage backend by the prefix of filepath
zhouzaida Sep 9, 2021
073f73e
refactor FileClient and add unittest
zhouzaida Sep 10, 2021
dfb9fc4
support loading from different backends
zhouzaida Sep 11, 2021
48cfdad
polish docstring
zhouzaida Sep 21, 2021
c2c9fc0
fix unittet
zhouzaida Sep 21, 2021
d641a8c
rename attribute str_like_obj to is_str_like_obj
zhouzaida Sep 22, 2021
68f0ab6
add infer_client method
zhouzaida Sep 23, 2021
31caf8e
add check_exist method
zhouzaida Sep 23, 2021
7e7a80f
rename var client to file_client
zhouzaida Sep 24, 2021
aa8274b
polish docstring
zhouzaida Sep 26, 2021
bb4712d
add join_paths method
zhouzaida Sep 27, 2021
2409531
Merge branch 'master' of https://github.com/open-mmlab/mmcv into load…
zhouzaida Sep 27, 2021
d4b6d96
remove join_paths and add _format_path
zhouzaida Sep 28, 2021
824cff3
Merge branch 'master' of https://github.com/open-mmlab/mmcv into load…
zhouzaida Oct 3, 2021
767f7fb
enhance unittest
zhouzaida Oct 3, 2021
b930678
refactor unittest
zhouzaida Oct 3, 2021
1752698
singleton pattern
zhouzaida Oct 4, 2021
fb9567c
fix test_clientio.py
zhouzaida Oct 4, 2021
00505f8
deprecate CephBackend
zhouzaida Oct 4, 2021
225d3a6
enhance docstring
zhouzaida Oct 6, 2021
22644da
refactor unittest for petrel
zhouzaida Oct 6, 2021
058b7e8
refactor unittest for disk backend
zhouzaida Oct 6, 2021
1692678
update io.md
zhouzaida Oct 6, 2021
01b9807
add concat_paths method
zhouzaida Oct 6, 2021
fed5a39
improve docstring
zhouzaida Oct 8, 2021
4959687
improve docstring
zhouzaida Oct 8, 2021
aea920a
add isdir and copyfile for file backend
zhouzaida Oct 10, 2021
6412103
delete copyfile and add get_local_path
zhouzaida Oct 11, 2021
c557ca3
Merge branch 'master' of https://github.com/open-mmlab/mmcv into load…
zhouzaida Oct 12, 2021
eeda74c
remove isdir method of petrel
zhouzaida Oct 12, 2021
ad52428
fix typo
zhouzaida Oct 12, 2021
941a884
add comment and polish docstring
zhouzaida Oct 13, 2021
198a465
polish docstring
zhouzaida Oct 14, 2021
e0d6a83
rename _path_mapping to _map_path
zhouzaida Oct 15, 2021
ae0cdd3
polish docstring and fix typo
zhouzaida Oct 15, 2021
a2e0162
refactor get_local_path
zhouzaida Oct 16, 2021
50ba26f
add list_dir_or_file for FileClient
zhouzaida Oct 17, 2021
4ad3bf5
add list_dir_or_file for PetrelBackend
zhouzaida Oct 18, 2021
df207d1
fix windows ci
zhouzaida Oct 18, 2021
d29a88d
Add return docstring
zhouzaida Oct 19, 2021
f18a779
polish docstring
zhouzaida Oct 19, 2021
b6eb5d1
fix typo
zhouzaida Oct 19, 2021
150d504
fix typo
zhouzaida Oct 19, 2021
208ff82
deprecate the conversion from Path to str
zhouzaida Oct 20, 2021
9ecfc12
add docs for loading checkpoints with FileClient
zhouzaida Oct 22, 2021
38559f1
refactor map_path
zhouzaida Oct 22, 2021
ea32388
add _ensure_methods to ensure methods have been implemented
zhouzaida Oct 22, 2021
a8cc11d
fix list_dir_or_file
zhouzaida Oct 22, 2021
e66fe61
rename _ensure_method_implemented to has_method
zhouzaida Oct 23, 2021
6987038
fix conflict
zhouzaida Oct 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 128 additions & 24 deletions mmcv/fileio/file_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Optional, Union
from urllib.request import urlopen


Expand Down Expand Up @@ -49,7 +51,7 @@ def get(self, filepath):
value_buf = memoryview(value)
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding=None):
raise NotImplementedError


Expand All @@ -61,20 +63,26 @@ class PetrelBackend(BaseStorageBackend):
path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will
be replaced by `dst`. Default: None.
enable_mc (bool): whether to enable memcached support. Default: True.
enable_multi_cluster (bool): Whether to enable multi clusters.
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
Default: False.
"""

def __init__(self, path_mapping=None, enable_mc=True):
def __init__(self,
path_mapping: Optional[dict] = None,
enable_mc: bool = True,
enable_multi_cluster: bool = False):
try:
from petrel_client import client
except ImportError:
raise ImportError('Please install petrel_client to enable '
'PetrelBackend.')

self._client = client.Client(enable_mc=enable_mc)
self._client = client.Client(
enable_mc=enable_mc, enable_multi_cluster=enable_multi_cluster)
assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping

def get(self, filepath):
def get(self, filepath: Union[str, Path]) -> memoryview:
filepath = str(filepath)
if self.path_mapping is not None:
for k, v in self.path_mapping.items():
Expand All @@ -83,8 +91,23 @@ def get(self, filepath):
value_buf = memoryview(value)
return value_buf

def get_text(self, filepath):
raise NotImplementedError
def get_text(self,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> str:
return str(self.get(filepath), encoding=encoding)
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
filepath = str(filepath)
if self.path_mapping is not None:
for k, v in self.path_mapping.items():
filepath = filepath.replace(k, v)
self._client.put(filepath, obj)

def put_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
self.put(bytes(obj, encoding=encoding), filepath)


class MemcachedBackend(BaseStorageBackend):
Expand Down Expand Up @@ -121,7 +144,7 @@ def get(self, filepath):
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding=None):
raise NotImplementedError


Expand Down Expand Up @@ -173,7 +196,7 @@ def get(self, filepath):
value_buf = txn.get(filepath.encode('ascii'))
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding=None):
raise NotImplementedError


Expand All @@ -186,12 +209,22 @@ def get(self, filepath):
value_buf = f.read()
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding='utf-8'):
filepath = str(filepath)
with open(filepath, 'r') as f:
with open(filepath, 'r', encoding=encoding) as f:
value_buf = f.read()
return value_buf

def put(self, obj, filepath):
filepath = str(filepath)
with open(filepath, 'wb') as f:
f.write(obj)

def put_text(self, obj, filepath, encoding='utf-8'):
filepath = str(filepath)
with open(filepath, 'w', encoding=encoding) as f:
f.write(obj)


class HTTPBackend(BaseStorageBackend):
"""HTTP and HTTPS storage bachend."""
Expand All @@ -200,9 +233,9 @@ def get(self, filepath):
value_buf = urlopen(filepath).read()
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding='utf-8'):
value_buf = urlopen(filepath).read()
return value_buf.decode('utf-8')
return value_buf.decode(encoding)


class FileClient:
Expand All @@ -212,9 +245,20 @@ class FileClient:
and return it as a binary file. it can also register other backend
accessor with a given name and backend class.

Attributes:
Args:
backend (str): The storage backend type. Options are "disk", "ceph",
"memcached", "lmdb" and "http".
"memcached", "lmdb", "http" and "petrel". Default: None.
prefixes (str or list[str] or tuple[str]): The prefixes of the
registered storage backend. Both backend and prefixes can be used
to choose a storage backend, but backend has a higher priority that
is if they are all set, the storage backend will be chosen by the
backend rather than prefixes. If backend and prefixes are all
`None`. The dist backend is be chosen. Default: None.
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved

.. versionadd:: 1.3.14
The *prefixes* parameter.

Attributes:
client (:obj:`BaseStorageBackend`): The backend object.
"""

Expand All @@ -226,17 +270,52 @@ class FileClient:
'petrel': PetrelBackend,
'http': HTTPBackend,
}
_prefix_to_backends = {
's3': PetrelBackend,
'http': HTTPBackend,
'https': HTTPBackend,
}

def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
def __init__(self, backend=None, prefixes=None, **kwargs):
if backend is None and prefixes is None:
backend = 'disk'
if backend is not None and backend not in self._backends:
raise ValueError(
f'Backend {backend} is not supported. Currently supported ones'
f' are {list(self._backends.keys())}')
self.backend = backend
self.client = self._backends[backend](**kwargs)
if prefixes is not None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
assert isinstance(prefixes, (list, tuple))

if not set(prefixes).issubset(self._prefix_to_backends.keys()):
raise ValueError(
f'prefixes {prefixes} is not supported. Currently '
'supported ones are '
f'{list(self._prefix_to_backends.keys())}')

if backend is not None:
self.client = self._backends[backend](**kwargs)
else:
for prefix in prefixes:
self.client = self._prefix_to_backends[prefix](**kwargs)
break

@staticmethod
def parse_uri_prefix(uri):
uri = str(uri)
if '://' not in uri:
return None
else:
prefix, _ = uri.split('://')
# clusterName:s3://
if ':' in prefix:
_, prefix = prefix.split(':')
return prefix

@classmethod
def _register_backend(cls, name, backend, force=False):
def _register_backend(cls, name, backend, force=False, prefixes=None):
if not isinstance(name, str):
raise TypeError('the backend name should be a string, '
f'but got {type(name)}')
Expand All @@ -252,9 +331,21 @@ def _register_backend(cls, name, backend, force=False):
'add "force=True" if you want to override it')

cls._backends[name] = backend
if prefixes is not None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
assert isinstance(prefixes, (list, tuple))
for prefix in prefixes:
if (prefix not in cls._prefix_to_backends) or force:
cls._prefix_to_backends[prefix] = backend
else:
raise KeyError(
f'{prefix} is already registered as a storage backend,'
' add "force=True" if you want to override it')

@classmethod
def register_backend(cls, name, backend=None, force=False):
def register_backend(cls, name, backend=None, force=False, prefixes=None):
"""Register a backend to FileClient.

This method can be used as a normal class method or a decorator.
Expand Down Expand Up @@ -292,19 +383,32 @@ def get_text(self, filepath):
Defaults to None.
force (bool, optional): Whether to override the backend if the name
has already been registered. Defaults to False.
prefixes (str or list[str] or tuple[str]): The prefix of the
registered storage backend.

.. versionadd:: 1.3.14
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
The *prefixes* parameter.
"""
if backend is not None:
cls._register_backend(name, backend, force=force)
cls._register_backend(
name, backend, force=force, prefixes=prefixes)
return

def _register(backend_cls):
cls._register_backend(name, backend_cls, force=force)
cls._register_backend(
name, backend_cls, force=force, prefixes=prefixes)
return backend_cls

return _register

def get(self, filepath):
return self.client.get(filepath)

def get_text(self, filepath):
return self.client.get_text(filepath)
def get_text(self, filepath, encoding='utf-8'):
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
return self.client.get_text(filepath, encoding)

def put(self, obj, filepath):
self.client.put(obj, filepath)

def put_text(self, obj, filepath):
self.client.put_text(obj, filepath)
1 change: 1 addition & 0 deletions mmcv/fileio/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


class BaseFileHandler(metaclass=ABCMeta):
str_like_obj = True

@abstractmethod
def load_from_fileobj(self, file, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions mmcv/fileio/handlers/pickle_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

class PickleHandler(BaseFileHandler):

str_like_obj = False

def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)

Expand Down
39 changes: 34 additions & 5 deletions mmcv/fileio/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO
from pathlib import Path

from ..utils import is_list_of, is_str
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler

file_handlers = {
Expand All @@ -13,7 +15,7 @@
}


def load(file, file_format=None, **kwargs):
def load(file, file_format=None, file_client_args=None, **kwargs):
"""Load data from json/yaml/pickle files.

This method provides a unified api for loading data from serialized files.
Expand All @@ -25,6 +27,8 @@ def load(file, file_format=None, **kwargs):
inferred from the file extension, otherwise use the specified one.
Currently supported formats include "json", "yaml/yml" and
"pickle/pkl".
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.

Returns:
The content from the file.
Expand All @@ -36,17 +40,28 @@ def load(file, file_format=None, **kwargs):
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')

if file_client_args is None:
file_prefix = FileClient.parse_uri_prefix(file)
client = FileClient(prefixes=file_prefix)
else:
client = FileClient(**file_client_args)

handler = file_handlers[file_format]
if is_str(file):
obj = handler.load_from_path(file, **kwargs)
if handler.str_like_obj:
with StringIO(client.get_text(file)) as f:
obj = handler.load_from_fileobj(f, **kwargs)
else:
with BytesIO(client.get(file)) as f:
obj = handler.load_from_fileobj(f, **kwargs)
elif hasattr(file, 'read'):
obj = handler.load_from_fileobj(file, **kwargs)
else:
raise TypeError('"file" must be a filepath str or a file-object')
return obj


def dump(obj, file=None, file_format=None, **kwargs):
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
"""Dump data to json/yaml/pickle strings or files.

This method provides a unified api for dumping data as strings or to files,
Expand All @@ -58,7 +73,8 @@ def dump(obj, file=None, file_format=None, **kwargs):
specified, then the object is dump to a str, otherwise to a file
specified by the filename or file-like object.
file_format (str, optional): Same as :func:`load`.

file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
Returns:
bool: True for success, False otherwise.
"""
Expand All @@ -73,11 +89,24 @@ def dump(obj, file=None, file_format=None, **kwargs):
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')

if file_client_args is None:
file_prefix = FileClient.parse_uri_prefix(file)
client = FileClient(prefixes=file_prefix)
else:
client = FileClient(**file_client_args)

handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
handler.dump_to_path(obj, file, **kwargs)
if handler.str_like_obj:
f = StringIO()
handler.dump_to_fileobj(obj, f, **kwargs)
client.put_text(f.getvalue(), file)
else:
f = BytesIO()
handler.dump_to_fileobj(obj, f, **kwargs)
client.put(f.getvalue(), file)
elif hasattr(file, 'write'):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
Expand Down
Loading