Skip to content

Commit

Permalink
Merge 5940864 into 8cac7c2
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Sep 27, 2021
2 parents 8cac7c2 + 5940864 commit 81520af
Show file tree
Hide file tree
Showing 12 changed files with 668 additions and 76 deletions.
225 changes: 198 additions & 27 deletions mmcv/fileio/file_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os
import os.path as osp
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Optional, Union
from urllib.request import urlopen

from mmcv.utils.path import is_filepath


class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
Expand Down Expand Up @@ -49,7 +55,7 @@ def get(self, filepath):
value_buf = memoryview(value)
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding=None):
raise NotImplementedError


Expand All @@ -61,30 +67,59 @@ class PetrelBackend(BaseStorageBackend):
path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will
be replaced by `dst`. Default: None.
enable_mc (bool): whether to enable memcached support. Default: True.
enable_multi_cluster (bool): Whether to enable multiple clusters.
Default: False.
"""

def __init__(self, path_mapping=None, enable_mc=True):
def __init__(self,
path_mapping: Optional[dict] = None,
enable_mc: bool = True,
enable_multi_cluster: bool = False):
try:
from petrel_client import client
except ImportError:
raise ImportError('Please install petrel_client to enable '
'PetrelBackend.')

self._client = client.Client(enable_mc=enable_mc)
self._client = client.Client(
enable_mc=enable_mc, enable_multi_cluster=enable_multi_cluster)
assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping

def get(self, filepath):
filepath = str(filepath)
def _path_mapping(self, filepath: str) -> str:
if self.path_mapping is not None:
for k, v in self.path_mapping.items():
filepath = filepath.replace(k, v)
return filepath

def get(self, filepath: Union[str, Path]) -> memoryview:
    """Read the object at ``filepath`` and return it as a memoryview.

    The path is converted to ``str`` and passed through
    ``self._path_mapping`` before being handed to the client's ``Get``.

    Args:
        filepath (str | Path): Path of the object to read.

    Returns:
        memoryview: A view over the value returned by the client.
    """
    mapped_path = self._path_mapping(str(filepath))
    raw = self._client.Get(mapped_path)
    return memoryview(raw)

def get_text(self, filepath):
raise NotImplementedError
def get_text(self,
             filepath: Union[str, Path],
             encoding: str = 'utf-8') -> str:
    """Read the object at ``filepath`` and decode it to text.

    Args:
        filepath (str | Path): Path of the object to read.
        encoding (str): Encoding used to decode the bytes returned by
            ``self.get``. Default: 'utf-8'.

    Returns:
        str: The decoded content.
    """
    content = self.get(filepath)
    return str(content, encoding=encoding)

def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
    """Upload ``obj`` to ``filepath`` through the client.

    The path is converted to ``str`` and passed through
    ``self._path_mapping`` first. Note the client's ``put`` takes the path
    before the data.

    Args:
        obj (bytes): Data to store.
        filepath (str | Path): Destination path.
    """
    mapped_path = self._path_mapping(str(filepath))
    self._client.put(mapped_path, obj)

def put_text(self,
             obj: str,
             filepath: Union[str, Path],
             encoding: str = 'utf-8') -> None:
    """Encode ``obj`` and store it at ``filepath`` via ``self.put``.

    Args:
        obj (str): Text to store.
        filepath (str | Path): Destination path.
        encoding (str): Encoding used to turn ``obj`` into bytes.
            Default: 'utf-8'.
    """
    payload = bytes(obj, encoding=encoding)
    destination = str(filepath)
    self.put(payload, destination)

def remove(self, filepath: Union[str, Path]) -> None:
    """Delete the object at ``filepath`` through the client.

    The path is converted to ``str`` and passed through
    ``self._path_mapping`` before calling the client's ``delete``.

    Args:
        filepath (str | Path): Path of the object to delete.
    """
    mapped_path = self._path_mapping(str(filepath))
    self._client.delete(mapped_path)

def check_exist(self, filepath: Union[str, Path]) -> bool:
    """Report whether ``filepath`` exists.

    Placeholder implementation: it does not inspect ``filepath`` and always
    returns ``True``.

    Args:
        filepath (str | Path): Path to check (currently unused).

    Returns:
        bool: Always ``True``.
    """
    # TODO, need other team to support the feature
    return True


class MemcachedBackend(BaseStorageBackend):
Expand Down Expand Up @@ -121,7 +156,7 @@ def get(self, filepath):
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding=None):
raise NotImplementedError


Expand Down Expand Up @@ -173,7 +208,7 @@ def get(self, filepath):
value_buf = txn.get(filepath.encode('ascii'))
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding=None):
raise NotImplementedError


Expand All @@ -186,12 +221,30 @@ def get(self, filepath):
value_buf = f.read()
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding='utf-8'):
filepath = str(filepath)
with open(filepath, 'r') as f:
with open(filepath, 'r', encoding=encoding) as f:
value_buf = f.read()
return value_buf

def put(self, obj, filepath):
    """Write the bytes-like ``obj`` to the local file ``filepath``.

    The file is opened in binary write mode, so any existing content is
    overwritten.

    Args:
        obj: Bytes-like data to write.
        filepath: Destination path; converted with ``str()``.
    """
    target = str(filepath)
    with open(target, mode='wb') as stream:
        stream.write(obj)

def put_text(self, obj, filepath, encoding='utf-8'):
    """Write the string ``obj`` to the local file ``filepath``.

    The file is opened in text write mode, so any existing content is
    overwritten.

    Args:
        obj (str): Text to write.
        filepath: Destination path; converted with ``str()``.
        encoding (str): Encoding used to open the file. Default: 'utf-8'.
    """
    target = str(filepath)
    with open(target, mode='w', encoding=encoding) as stream:
        stream.write(obj)

def remove(self, filepath: Union[str, Path]) -> None:
    """Remove the local file at ``filepath``.

    Args:
        filepath (str | Path): Path of the file to delete; converted with
            ``str()`` and passed to ``os.remove``.
    """
    os.remove(str(filepath))

def check_exist(self, filepath: Union[str, Path]) -> bool:
    """Return whether ``filepath`` exists on the local file system.

    Args:
        filepath (str | Path): Path to check; converted with ``str()``.

    Returns:
        bool: Result of ``os.path.exists`` for the path.
    """
    local_path = str(filepath)
    return osp.exists(local_path)


class HTTPBackend(BaseStorageBackend):
"""HTTP and HTTPS storage backend."""
Expand All @@ -200,21 +253,42 @@ def get(self, filepath):
value_buf = urlopen(filepath).read()
return value_buf

def get_text(self, filepath):
def get_text(self, filepath, encoding='utf-8'):
value_buf = urlopen(filepath).read()
return value_buf.decode('utf-8')
return value_buf.decode(encoding)


class FileClient:
"""A general file client to access files in different backend.
The client loads a file or text in a specified backend from its path
and return it as a binary file. it can also register other backend
accessor with a given name and backend class.
and return it as a binary or text file. There are two ways to choose a
backend: the name of the backend and the prefixes of the path. Although
both can be used to choose a storage backend, ``backend`` has a higher
priority; that is, if both are set, the storage backend will be chosen by
the ``backend`` argument. If both are `None`, the disk backend will be
chosen. Note that it can also register other backend accessors with a
given name, prefixes, and backend class.
Attributes:
Args:
backend (str): The storage backend type. Options are "disk", "ceph",
"memcached", "lmdb" and "http".
"memcached", "lmdb", "http" and "petrel". Default: None.
prefixes (str or list[str] or tuple[str]): The prefixes of the
registered storage backend. Options are "s3", "http", "https".
Default: None.
.. versionadded:: 1.3.14
The *prefixes* parameter.
Example:
>>> # only set backend
>>> file_client = FileClient(backend='ceph')
>>> # only set prefixes
>>> file_client = FileClient(prefixes='s3')
>>> # set both backend and prefixes but use backend to choose client
>>> file_client = FileClient(backend='ceph', prefixes='s3')
Attributes:
client (:obj:`BaseStorageBackend`): The backend object.
"""

Expand All @@ -226,17 +300,83 @@ class FileClient:
'petrel': PetrelBackend,
'http': HTTPBackend,
}
_prefix_to_backends = {
's3': PetrelBackend,
'http': HTTPBackend,
'https': HTTPBackend,
}

def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
def __init__(self, backend=None, prefixes=None, **kwargs):
if backend is None and prefixes is None:
backend = 'disk'
if backend is not None and backend not in self._backends:
raise ValueError(
f'Backend {backend} is not supported. Currently supported ones'
f' are {list(self._backends.keys())}')
self.backend = backend
self.client = self._backends[backend](**kwargs)
if prefixes is not None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
assert isinstance(prefixes, (list, tuple))

if not set(prefixes).issubset(self._prefix_to_backends.keys()):
raise ValueError(
f'prefixes {prefixes} is not supported. Currently '
'supported ones are '
f'{list(self._prefix_to_backends.keys())}')

if backend is not None:
self.client = self._backends[backend](**kwargs)
else:
for prefix in prefixes:
self.client = self._prefix_to_backends[prefix](**kwargs)
break

for backend_name, backend_cls in self._backends.items():
if isinstance(self.client, backend_cls):
self.backend_name = backend_name
break

@staticmethod
def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
    """Parse the prefix of a uri.

    Args:
        uri (str | Path): Uri to be parsed that contains the file prefix.

    Returns:
        str | None: The text before the first ``://`` (with any leading
        ``clusterName:`` part stripped), or ``None`` when ``uri`` contains
        no ``://``.
    """
    assert is_filepath(uri)
    uri = str(uri)
    if '://' not in uri:
        return None
    # Split on the first '://' only: a uri whose remainder contains
    # another '://' must not break a two-target unpacking.
    prefix = uri.partition('://')[0]
    # In the case of ceph, the prefix may contain the cluster name
    # like clusterName:s3. Keep only the part after the last ':'.
    if ':' in prefix:
        prefix = prefix.rpartition(':')[2]
    return prefix

@classmethod
def infer_client(cls,
                 file_client_args: Optional[dict] = None,
                 uri: Optional[Union[str, Path]] = None) -> 'FileClient':
    """Infer a suitable file client based on the URI and arguments.

    When ``file_client_args`` is given it takes precedence and is expanded
    into the constructor; otherwise the prefix parsed from ``uri`` selects
    the client. At least one of the two must be provided.

    Args:
        file_client_args (dict): Arguments to instantiate a FileClient.
            Default: None.
        uri (str | Path, optional): Uri to be parsed that contains the file
            prefix. Default: None.

    Returns:
        FileClient: The instantiated client.
    """
    assert file_client_args is not None or uri is not None
    if file_client_args is not None:
        return cls(**file_client_args)
    file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
    return cls(prefixes=file_prefix)

@classmethod
def _register_backend(cls, name, backend, force=False):
def _register_backend(cls, name, backend, force=False, prefixes=None):
if not isinstance(name, str):
raise TypeError('the backend name should be a string, '
f'but got {type(name)}')
Expand All @@ -252,9 +392,21 @@ def _register_backend(cls, name, backend, force=False):
'add "force=True" if you want to override it')

cls._backends[name] = backend
if prefixes is not None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
assert isinstance(prefixes, (list, tuple))
for prefix in prefixes:
if (prefix not in cls._prefix_to_backends) or force:
cls._prefix_to_backends[prefix] = backend
else:
raise KeyError(
f'{prefix} is already registered as a storage backend,'
' add "force=True" if you want to override it')

@classmethod
def register_backend(cls, name, backend=None, force=False):
def register_backend(cls, name, backend=None, force=False, prefixes=None):
"""Register a backend to FileClient.
This method can be used as a normal class method or a decorator.
Expand Down Expand Up @@ -292,19 +444,38 @@ def get_text(self, filepath):
Defaults to None.
force (bool, optional): Whether to override the backend if the name
has already been registered. Defaults to False.
prefixes (str or list[str] or tuple[str]): The prefix of the
registered storage backend.
.. versionadded:: 1.3.14
The *prefixes* parameter.
"""
if backend is not None:
cls._register_backend(name, backend, force=force)
cls._register_backend(
name, backend, force=force, prefixes=prefixes)
return

def _register(backend_cls):
cls._register_backend(name, backend_cls, force=force)
cls._register_backend(
name, backend_cls, force=force, prefixes=prefixes)
return backend_cls

return _register

def get(self, filepath):
    """Read ``filepath`` through the backend client.

    Args:
        filepath: Path handed unchanged to ``self.client.get``.

    Returns:
        Whatever the backend client's ``get`` returns.
    """
    backend_client = self.client
    return backend_client.get(filepath)

def get_text(self, filepath):
return self.client.get_text(filepath)
def get_text(self, filepath, encoding='utf-8'):
return self.client.get_text(filepath, encoding)

def put(self, obj, filepath):
    """Write ``obj`` to ``filepath`` through the backend client.

    Args:
        obj: Data handed unchanged to ``self.client.put``.
        filepath: Destination path handed unchanged to the backend.
    """
    backend_client = self.client
    backend_client.put(obj, filepath)

def put_text(self, obj, filepath, encoding='utf-8'):
    """Write the string ``obj`` to ``filepath`` through the backend client.

    Args:
        obj (str): Text to write.
        filepath: Destination path handed unchanged to the backend.
        encoding (str): Encoding forwarded to the backend's ``put_text``,
            mirroring how ``get_text`` forwards its ``encoding``.
            Default: 'utf-8'.
    """
    self.client.put_text(obj, filepath, encoding)

def remove(self, filepath):
    """Remove ``filepath`` through the backend client.

    Args:
        filepath: Path handed unchanged to ``self.client.remove``.
    """
    backend_client = self.client
    backend_client.remove(filepath)

def check_exist(self, filepath):
    """Ask the backend client whether ``filepath`` exists.

    Args:
        filepath: Path handed unchanged to ``self.client.check_exist``.

    Returns:
        Whatever the backend client's ``check_exist`` returns.
    """
    backend_client = self.client
    return backend_client.check_exist(filepath)
7 changes: 7 additions & 0 deletions mmcv/fileio/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@


class BaseFileHandler(metaclass=ABCMeta):
# is_str_like_obj is a flag marking which type of file object is processed:
# a bytes-like object or a str-like object. For example, pickle only
# processes bytes-like objects and json only processes str-like objects. The
# flag is used to decide which type of buffer to use: StringIO for str-like
# objects and BytesIO for bytes-like objects. The usage of the flag can be
# found in `mmcv.load` or `mmcv.dump`.
is_str_like_obj = True

@abstractmethod
def load_from_fileobj(self, file, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions mmcv/fileio/handlers/pickle_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

class PickleHandler(BaseFileHandler):

is_str_like_obj = False

def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)

Expand Down
Loading

0 comments on commit 81520af

Please sign in to comment.