diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py
index e8a6cbdb08a..16847c1743f 100644
--- a/mmcv/fileio/file_client.py
+++ b/mmcv/fileio/file_client.py
@@ -1,8 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
+import os
+import os.path as osp
 from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Optional, Union
 from urllib.request import urlopen
 
+from mmcv.utils.path import is_filepath
+
 
 class BaseStorageBackend(metaclass=ABCMeta):
     """Abstract class of storage backends.
@@ -49,7 +55,7 @@ def get(self, filepath):
         value_buf = memoryview(value)
         return value_buf
 
-    def get_text(self, filepath):
+    def get_text(self, filepath, encoding=None):
         raise NotImplementedError
 
 
@@ -61,30 +67,59 @@ class PetrelBackend(BaseStorageBackend):
             path. When `path_mapping={'src': 'dst'}`, `src` in `filepath`
             will be replaced by `dst`. Default: None.
         enable_mc (bool): whether to enable memcached support. Default: True.
+        enable_multi_cluster (bool): Whether to enable multiple clusters.
+            Default: False.
     """
 
-    def __init__(self, path_mapping=None, enable_mc=True):
+    def __init__(self,
+                 path_mapping: Optional[dict] = None,
+                 enable_mc: bool = True,
+                 enable_multi_cluster: bool = False):
         try:
             from petrel_client import client
         except ImportError:
             raise ImportError('Please install petrel_client to enable '
                               'PetrelBackend.')
 
-        self._client = client.Client(enable_mc=enable_mc)
+        self._client = client.Client(
+            enable_mc=enable_mc, enable_multi_cluster=enable_multi_cluster)
         assert isinstance(path_mapping, dict) or path_mapping is None
         self.path_mapping = path_mapping
 
-    def get(self, filepath):
-        filepath = str(filepath)
+    def _path_mapping(self, filepath: str) -> str:
         if self.path_mapping is not None:
             for k, v in self.path_mapping.items():
                 filepath = filepath.replace(k, v)
+        return filepath
+
+    def get(self, filepath: Union[str, Path]) -> memoryview:
+        filepath = self._path_mapping(str(filepath))
         value = self._client.Get(filepath)
         value_buf = memoryview(value)
         return value_buf
 
-    def get_text(self, filepath):
-        raise NotImplementedError
+    def get_text(self,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> str:
+        return str(self.get(filepath), encoding=encoding)
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        filepath = self._path_mapping(str(filepath))
+        self._client.put(filepath, obj)
+
+    def put_text(self,
+                 obj: str,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> None:
+        self.put(bytes(obj, encoding=encoding), str(filepath))
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        filepath = self._path_mapping(str(filepath))
+        self._client.delete(filepath)
+
+    def check_exist(self, filepath: Union[str, Path]) -> bool:
+        # TODO: needs support from the petrel_client team for this feature
+        return True
 
 
 class MemcachedBackend(BaseStorageBackend):
@@ -121,7 +156,7 @@ def get(self, filepath):
         value_buf = mc.ConvertBuffer(self._mc_buffer)
         return value_buf
 
-    def get_text(self, filepath):
+    def get_text(self, filepath, encoding=None):
         raise NotImplementedError
 
 
@@ -173,7 +208,7 @@ def get(self, filepath):
             value_buf = txn.get(filepath.encode('ascii'))
         return value_buf
 
-    def get_text(self, filepath):
+    def get_text(self, filepath, encoding=None):
         raise NotImplementedError
 
 
@@ -186,12 +221,30 @@ def get(self, filepath):
             value_buf = f.read()
         return value_buf
 
-    def get_text(self, filepath):
+    def get_text(self, filepath, encoding='utf-8'):
         filepath = str(filepath)
-        with open(filepath, 'r') as f:
+        with open(filepath, 'r', encoding=encoding) as f:
             value_buf = f.read()
         return value_buf
 
+    def put(self, obj, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def put_text(self, obj, filepath, encoding='utf-8'):
+        filepath = str(filepath)
+        with open(filepath, 'w', encoding=encoding) as f:
+            f.write(obj)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file."""
+        filepath = str(filepath)
+        os.remove(filepath)
+
+    def check_exist(self, filepath: Union[str, Path]) -> bool:
+        return osp.exists(str(filepath))
+
 
 class HTTPBackend(BaseStorageBackend):
     """HTTP and HTTPS storage bachend."""
 
@@ -200,21 +253,42 @@ def get(self, filepath):
         value_buf = urlopen(filepath).read()
         return value_buf
 
-    def get_text(self, filepath):
+    def get_text(self, filepath, encoding='utf-8'):
         value_buf = urlopen(filepath).read()
-        return value_buf.decode('utf-8')
+        return value_buf.decode(encoding)
 
 
 class FileClient:
     """A general file client to access files in different backend.
 
     The client loads a file or text in a specified backend from its path
-    and return it as a binary file. it can also register other backend
-    accessor with a given name and backend class.
+    and returns it as a binary or text file. There are two ways to choose a
+    backend: the name of the backend and the prefix of the path. Although
+    both of them can be used to choose a storage backend, ``backend`` has a
+    higher priority: if both are set, the storage backend will be chosen by
+    the ``backend`` argument. If both are ``None``, the disk backend will be
+    chosen. Note that it can also register other backend accessors with a
+    given name, prefixes, and backend class.
 
-    Attributes:
+    Args:
         backend (str): The storage backend type. Options are "disk", "ceph",
-            "memcached", "lmdb" and "http".
+            "memcached", "lmdb", "http" and "petrel". Default: None.
+        prefixes (str or list[str] or tuple[str]): The prefixes of the
+            registered storage backend. Options are "s3", "http", "https".
+            Default: None.
+
+    .. versionadded:: 1.3.14
+       The *prefixes* parameter.
+
+    Example:
+        >>> # only set backend
+        >>> file_client = FileClient(backend='ceph')
+        >>> # only set prefixes
+        >>> file_client = FileClient(prefixes='s3')
+        >>> # set both backend and prefixes but use backend to choose client
+        >>> file_client = FileClient(backend='ceph', prefixes='s3')
+
+    Attributes:
         client (:obj:`BaseStorageBackend`): The backend object.
     """
 
     _backends = {
@@ -226,17 +300,83 @@ class FileClient:
         'petrel': PetrelBackend,
         'http': HTTPBackend,
     }
+    _prefix_to_backends = {
+        's3': PetrelBackend,
+        'http': HTTPBackend,
+        'https': HTTPBackend,
+    }
 
-    def __init__(self, backend='disk', **kwargs):
-        if backend not in self._backends:
+    def __init__(self, backend=None, prefixes=None, **kwargs):
+        if backend is None and prefixes is None:
+            backend = 'disk'
+        if backend is not None and backend not in self._backends:
             raise ValueError(
                 f'Backend {backend} is not supported. Currently supported ones'
                 f' are {list(self._backends.keys())}')
-        self.backend = backend
-        self.client = self._backends[backend](**kwargs)
+        if prefixes is not None:
+            if isinstance(prefixes, str):
+                prefixes = [prefixes]
+            else:
+                assert isinstance(prefixes, (list, tuple))
+
+            if not set(prefixes).issubset(self._prefix_to_backends.keys()):
+                raise ValueError(
+                    f'prefixes {prefixes} are not supported. Currently '
+                    'supported ones are '
+                    f'{list(self._prefix_to_backends.keys())}')
+
+        if backend is not None:
+            self.client = self._backends[backend](**kwargs)
+        else:
+            for prefix in prefixes:
+                self.client = self._prefix_to_backends[prefix](**kwargs)
+                break
+
+        for backend_name, backend_cls in self._backends.items():
+            if isinstance(self.client, backend_cls):
+                self.backend_name = backend_name
+                break
+
+    @staticmethod
+    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+        """Parse the prefix of a uri.
+
+        Args:
+            uri (str | Path): Uri to be parsed that contains the file prefix.
+        """
+        assert is_filepath(uri)
+        uri = str(uri)
+        if '://' not in uri:
+            return None
+        else:
+            prefix, _ = uri.split('://')
+            # In the case of ceph, the prefix may contain the cluster name
+            # like clusterName:s3
+            if ':' in prefix:
+                _, prefix = prefix.split(':')
+            return prefix
+
+    @classmethod
+    def infer_client(cls,
+                     file_client_args: Optional[dict] = None,
+                     uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+        """Infer a suitable file client based on the URI and arguments.
+
+        Args:
+            file_client_args (dict): Arguments to instantiate a FileClient.
+                Default: None.
+            uri (str | Path, optional): Uri to be parsed that contains the
+                file prefix. Default: None.
+        """
+        assert file_client_args is not None or uri is not None
+        if file_client_args is None:
+            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
+            return cls(prefixes=file_prefix)
+        else:
+            return cls(**file_client_args)
 
     @classmethod
-    def _register_backend(cls, name, backend, force=False):
+    def _register_backend(cls, name, backend, force=False, prefixes=None):
         if not isinstance(name, str):
             raise TypeError('the backend name should be a string, '
                             f'but got {type(name)}')
@@ -252,9 +392,21 @@ def _register_backend(cls, name, backend, force=False):
                 'add "force=True" if you want to override it')
 
         cls._backends[name] = backend
+        if prefixes is not None:
+            if isinstance(prefixes, str):
+                prefixes = [prefixes]
+            else:
+                assert isinstance(prefixes, (list, tuple))
+            for prefix in prefixes:
+                if (prefix not in cls._prefix_to_backends) or force:
+                    cls._prefix_to_backends[prefix] = backend
+                else:
+                    raise KeyError(
+                        f'{prefix} is already registered as a storage backend,'
+                        ' add "force=True" if you want to override it')
 
     @classmethod
-    def register_backend(cls, name, backend=None, force=False):
+    def register_backend(cls, name, backend=None, force=False, prefixes=None):
         """Register a backend to FileClient.
 
         This method can be used as a normal class method or a decorator.
@@ -292,13 +444,20 @@ def get_text(self, filepath):
                 Defaults to None.
             force (bool, optional): Whether to override the backend if the name
                 has already been registered. Defaults to False.
+            prefixes (str or list[str] or tuple[str]): The prefix of the
+                registered storage backend.
+
+                .. versionadded:: 1.3.14
+                   The *prefixes* parameter.
         """
         if backend is not None:
-            cls._register_backend(name, backend, force=force)
+            cls._register_backend(
+                name, backend, force=force, prefixes=prefixes)
             return
 
         def _register(backend_cls):
-            cls._register_backend(name, backend_cls, force=force)
+            cls._register_backend(
+                name, backend_cls, force=force, prefixes=prefixes)
             return backend_cls
 
         return _register
@@ -306,5 +465,17 @@ def _register(backend_cls):
     def get(self, filepath):
         return self.client.get(filepath)
 
-    def get_text(self, filepath):
-        return self.client.get_text(filepath)
+    def get_text(self, filepath, encoding='utf-8'):
+        return self.client.get_text(filepath, encoding)
+
+    def put(self, obj, filepath):
+        self.client.put(obj, filepath)
+
+    def put_text(self, obj, filepath):
+        self.client.put_text(obj, filepath)
+
+    def remove(self, filepath):
+        self.client.remove(filepath)
+
+    def check_exist(self, filepath):
+        return self.client.check_exist(filepath)
diff --git a/mmcv/fileio/handlers/base.py b/mmcv/fileio/handlers/base.py
index 235727557ca..22d66d5b1b1 100644
--- a/mmcv/fileio/handlers/base.py
+++ b/mmcv/fileio/handlers/base.py
@@ -3,6 +3,13 @@
 
 class BaseFileHandler(metaclass=ABCMeta):
+    # `is_str_like_obj` marks which type of file object a handler processes:
+    # a str-like object or a bytes-like object. For example, pickle only
+    # processes bytes-like objects while json only processes str-like
+    # objects. The flag is used to choose the buffer type: StringIO for
+    # str-like objects and BytesIO for bytes-like objects. Its usage can be
+    # found in `mmcv.load` or `mmcv.dump`.
+    is_str_like_obj = True
 
     @abstractmethod
     def load_from_fileobj(self, file, **kwargs):
diff --git a/mmcv/fileio/handlers/pickle_handler.py b/mmcv/fileio/handlers/pickle_handler.py
index 02504599577..648bf22b9c1 100644
--- a/mmcv/fileio/handlers/pickle_handler.py
+++ b/mmcv/fileio/handlers/pickle_handler.py
@@ -6,6 +6,8 @@
 
 class PickleHandler(BaseFileHandler):
 
+    is_str_like_obj = False
+
     def load_from_fileobj(self, file, **kwargs):
         return pickle.load(file, **kwargs)
diff --git a/mmcv/fileio/io.py b/mmcv/fileio/io.py
index 015d36e8089..c5206f208e4 100644
--- a/mmcv/fileio/io.py
+++ b/mmcv/fileio/io.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
 from pathlib import Path
 
 from ..utils import is_list_of, is_str
+from .file_client import FileClient
 from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
 
 file_handlers = {
@@ -13,7 +15,7 @@
 }
 
 
-def load(file, file_format=None, **kwargs):
+def load(file, file_format=None, file_client_args=None, **kwargs):
     """Load data from json/yaml/pickle files.
 
     This method provides a unified api for loading data from serialized files.
@@ -25,6 +27,8 @@ def load(file, file_format=None, **kwargs):
             inferred from the file extension, otherwise use the specified one.
             Currently supported formats include "json", "yaml/yml" and
             "pickle/pkl".
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
 
     Returns:
         The content from the file.
@@ -38,7 +42,13 @@ def load(file, file_format=None, **kwargs):
 
     handler = file_handlers[file_format]
     if is_str(file):
-        obj = handler.load_from_path(file, **kwargs)
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.is_str_like_obj:
+            with StringIO(file_client.get_text(file)) as f:
+                obj = handler.load_from_fileobj(f, **kwargs)
+        else:
+            with BytesIO(file_client.get(file)) as f:
+                obj = handler.load_from_fileobj(f, **kwargs)
     elif hasattr(file, 'read'):
         obj = handler.load_from_fileobj(file, **kwargs)
     else:
@@ -46,7 +56,7 @@ def load(file, file_format=None, **kwargs):
     return obj
 
 
-def dump(obj, file=None, file_format=None, **kwargs):
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
     """Dump data to json/yaml/pickle strings or files.
 
     This method provides a unified api for dumping data as strings or to files,
@@ -58,6 +68,8 @@ def dump(obj, file=None, file_format=None, **kwargs):
             specified, then the object is dump to a str, otherwise to a file
             specified by the filename or file-like object.
         file_format (str, optional): Same as :func:`load`.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
 
     Returns:
         bool: True for success, False otherwise.
@@ -77,7 +89,15 @@ def dump(obj, file=None, file_format=None, **kwargs):
     if file is None:
         return handler.dump_to_str(obj, **kwargs)
     elif is_str(file):
-        handler.dump_to_path(obj, file, **kwargs)
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.is_str_like_obj:
+            with StringIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put_text(f.getvalue(), file)
+        else:
+            with BytesIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put(f.getvalue(), file)
     elif hasattr(file, 'write'):
         handler.dump_to_fileobj(obj, file, **kwargs)
     else:
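With these hunks, `mmcv.load` and `mmcv.dump` route string paths through a FileClient and an in-memory StringIO/BytesIO buffer instead of opening the path directly. A sketch of the resulting behaviour, with placeholder paths; the s3 calls assume petrel_client is available:

    import mmcv

    # local files behave exactly as before
    mmcv.dump({'a': 1}, '/tmp/test.json')
    assert mmcv.load('/tmp/test.json') == {'a': 1}

    # remote files go through the backend inferred from the prefix,
    # or through an explicitly configured client
    mmcv.dump({'a': 1}, 's3://my_bucket/test.json')
    obj = mmcv.load(
        's3://my_bucket/test.json', file_client_args=dict(backend='petrel'))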
diff --git a/mmcv/fileio/parse.py b/mmcv/fileio/parse.py
index ffe86d3de9e..d7242c05010 100644
--- a/mmcv/fileio/parse.py
+++ b/mmcv/fileio/parse.py
@@ -1,5 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
+
+from io import StringIO
+
+from .file_client import FileClient
+
+
+def list_from_file(filename,
+                   prefix='',
+                   offset=0,
+                   max_num=0,
+                   encoding='utf-8',
+                   file_client_args=None):
     """Load a text file and parse the content as a list of strings.
 
     Args:
@@ -9,13 +20,16 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
         max_num (int): The maximum number of lines to be read,
             zeros and negatives mean no limitation.
         encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
 
     Returns:
         list[str]: A list of strings.
     """
     cnt = 0
     item_list = []
-    with open(filename, 'r', encoding=encoding) as f:
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
         for _ in range(offset):
             f.readline()
         for line in f:
@@ -26,7 +40,10 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
     return item_list
 
 
-def dict_from_file(filename, key_type=str):
+def dict_from_file(filename,
+                   key_type=str,
+                   encoding='utf-8',
+                   file_client_args=None):
     """Load a text file and parse the content as a dict.
 
     Each line of the text file will be two or more columns split by
@@ -37,12 +54,16 @@ def dict_from_file(filename, key_type=str):
         filename(str): Filename.
         key_type(type): Type of the dict keys. str is user by default and
             type conversion will be performed if specified.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
 
     Returns:
         dict: The parsed contents.
     """
     mapping = {}
-    with open(filename, 'r') as f:
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
         for line in f:
             items = line.rstrip('\n').split()
             assert len(items) >= 2
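`list_from_file` and `dict_from_file` gain the same `file_client_args` hook, so annotation files can live in object storage. For example, under the same placeholder-path assumptions:

    import mmcv

    # local text file, unchanged behaviour
    classes = mmcv.list_from_file('data/classes.txt')

    # remote text files; the client is inferred from the prefix
    # or configured explicitly
    classes = mmcv.list_from_file('s3://my_bucket/classes.txt')
    mapping = mmcv.dict_from_file(
        'http://example.com/mapping.txt', file_client_args=dict(backend='http'))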
diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py
index e266608c3ac..0c9436c46b7 100644
--- a/mmcv/runner/checkpoint.py
+++ b/mmcv/runner/checkpoint.py
@@ -323,7 +323,7 @@ def load_from_pavi(filename, map_location=None):
 
 
 @CheckpointLoader.register_scheme(prefixes='s3://')
-def load_from_ceph(filename, map_location=None, backend='ceph'):
+def load_from_ceph(filename, map_location=None, backend='petrel'):
     """load checkpoint through the file path prefixed with s3.  In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
@@ -331,20 +331,27 @@ def load_from_ceph(filename, map_location=None, backend='ceph'):
     Args:
         filename (str): checkpoint file path with s3 prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        backend (str): The storage backend type. Options are "disk", "ceph",
-            "memcached" and "lmdb". Default: 'ceph'
+        backend (str): The storage backend type. Options are 'ceph', 'petrel'.
+            Default: 'petrel'.
 
     Returns:
         dict or OrderedDict: The loaded checkpoint.
     """
-
-    allowed_backends = ['ceph']
+    allowed_backends = ['ceph', 'petrel']
    if backend not in allowed_backends:
         raise ValueError(f'Load from Backend {backend} is not supported.')
 
-    fileclient = FileClient(backend=backend)
-    buffer = io.BytesIO(fileclient.get(filename))
-    checkpoint = torch.load(buffer, map_location=map_location)
+    # CephBackend and PetrelBackend have the same prefix 's3://' and the
+    # latter will be chosen by default. If PetrelBackend cannot be
+    # instantiated successfully, CephBackend will be chosen instead.
+    try:
+        file_client = FileClient(backend=backend)
+    except ImportError:
+        allowed_backends.remove(backend)
+        file_client = FileClient(backend=allowed_backends[0])
+
+    with io.BytesIO(file_client.get(filename)) as buffer:
+        checkpoint = torch.load(buffer, map_location=map_location)
     return checkpoint
 
 
@@ -506,7 +513,6 @@ def load_checkpoint(model,
             pair of the regular expression operations. Default: strip
             the prefix 'module.' by [(r'^module\\.', '')].
-
     Returns:
         dict or OrderedDict: The loaded checkpoint.
     """
@@ -616,7 +622,11 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
     return destination
 
 
-def save_checkpoint(model, filename, optimizer=None, meta=None):
+def save_checkpoint(model,
+                    filename,
+                    optimizer=None,
+                    meta=None,
+                    file_client_args=None):
     """Save checkpoint to file.
 
     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
@@ -627,6 +637,8 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
         filename (str): Checkpoint filename.
         optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
         meta (dict, optional): Metadata to be saved in checkpoint.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
     """
     if meta is None:
         meta = {}
@@ -654,6 +666,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
             checkpoint['optimizer'][name] = optim.state_dict()
 
     if filename.startswith('pavi://'):
+        if file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" if filename starts with '
+                '"pavi://"')
         try:
             from pavi import modelcloud
             from pavi import exception
@@ -674,8 +690,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
             f.flush()
             model.create_file(checkpoint_file, name=model_name)
     else:
-        mmcv.mkdir_or_exist(osp.dirname(filename))
-        # immediately flush buffer
-        with open(filename, 'wb') as f:
+        file_client = FileClient.infer_client(file_client_args, filename)
+        if file_client.backend_name == 'disk':
+            mmcv.mkdir_or_exist(osp.dirname(filename))
+
+        with io.BytesIO() as f:
             torch.save(checkpoint, f)
-            f.flush()
+            file_client.put(f.getvalue(), filename)
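`save_checkpoint` now serializes into an in-memory buffer and hands the bytes to the file client, so checkpoints can be written straight to object storage. A sketch under the same placeholder-path assumptions:

    import torch.nn as nn
    from mmcv.runner import save_checkpoint

    model = nn.Linear(2, 2)  # any nn.Module

    # to disk: the parent directory is still created automatically
    save_checkpoint(model, '/path/of/work_dir/latest.pth')

    # to Petrel OSS: the buffer contents are uploaded via PetrelBackend.put
    save_checkpoint(
        model, 's3://my_bucket/latest.pth',
        file_client_args=dict(backend='petrel'))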
""" def __init__(self, @@ -38,6 +55,7 @@ def __init__(self, max_keep_ckpts=-1, save_last=True, sync_buffer=False, + file_client_args=None, **kwargs): self.interval = interval self.by_epoch = by_epoch @@ -47,11 +65,25 @@ def __init__(self, self.save_last = save_last self.args = kwargs self.sync_buffer = sync_buffer + self.file_client_args = file_client_args def before_run(self, runner): if not self.out_dir: self.out_dir = runner.work_dir + if self.out_dir != runner.work_dir: + # The final `self.out_dir` is the concatenation of `self.out_dir` + # and the last level directory of `runner.work_dir` + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = osp.join(self.out_dir, basename) + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + # disable the create_symlink option when the backend is not + # HardDiskBackend + if self.file_client.backend_name != 'disk': + self.args['create_symlink'] = False + def after_train_epoch(self, runner): if not self.by_epoch: return @@ -98,8 +130,8 @@ def _save_checkpoint(self, runner): for _step in redundant_ckpts: ckpt_path = os.path.join(self.out_dir, filename_tmpl.format(_step)) - if os.path.exists(ckpt_path): - os.remove(ckpt_path) + if self.file_client.check_exist(ckpt_path): + self.file_client.remove(ckpt_path) else: break diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py index 40a619e5ef1..d6d7839af9c 100644 --- a/mmcv/runner/hooks/logger/text.py +++ b/mmcv/runner/hooks/logger/text.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import datetime +import os import os.path as osp from collections import OrderedDict @@ -7,6 +8,8 @@ import torch.distributed as dist import mmcv +from mmcv.fileio.file_client import FileClient +from mmcv.utils import scandir from ..hook import HOOKS from .base import LoggerHook @@ -27,6 +30,20 @@ class TextLoggerHook(LoggerHook): interval_exp_name (int): Logging interval for experiment name. This feature is to help users conveniently get the experiment information from screen or log file. Default: 1000. + out_dir (str, optional): Logs are saved in `runner.work_dir` default. + If `out_dir` is specified, logs will be copied to a new directory + which is the concatenation of `out_dir` and the last level + directory of `runner.work_dir`. Default: None. + `New in version 1.3.15.` + out_suffix (str, list[str]): Those filenames ending with `out_suffix` + will be copied to `out_dir`. Default: ['.log.json', '.log', '.py']. + `New in version 1.3.15.` + keep_log (bool): Whether to keep local log when out_dir is specified. + If False, the local log will be removed. Default: True. + `New in version 1.3.15.` + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. 
diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py
index 40a619e5ef1..d6d7839af9c 100644
--- a/mmcv/runner/hooks/logger/text.py
+++ b/mmcv/runner/hooks/logger/text.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import datetime
+import os
 import os.path as osp
 from collections import OrderedDict
 
@@ -7,6 +8,8 @@
 import torch.distributed as dist
 
 import mmcv
+from mmcv.fileio.file_client import FileClient
+from mmcv.utils import scandir
 from ..hook import HOOKS
 from .base import LoggerHook
 
@@ -27,6 +30,20 @@ class TextLoggerHook(LoggerHook):
         interval_exp_name (int): Logging interval for experiment name. This
             feature is to help users conveniently get the experiment
             information from screen or log file. Default: 1000.
+        out_dir (str, optional): Logs are saved in `runner.work_dir` by
+            default. If `out_dir` is specified, logs will be copied to a new
+            directory which is the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`. Default: None.
+            `New in version 1.3.15.`
+        out_suffix (str or list[str]): Those filenames ending with
+            `out_suffix` will be copied to `out_dir`.
+            Default: ['.log.json', '.log', '.py'].
+            `New in version 1.3.15.`
+        keep_log (bool): Whether to keep local log when `out_dir` is
+            specified. If False, the local log will be removed. Default: True.
+            `New in version 1.3.15.`
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+            `New in version 1.3.15.`
     """
 
     def __init__(self,
@@ -34,15 +51,37 @@ def __init__(self,
                  interval=10,
                  ignore_last=True,
                  reset_flag=False,
-                 interval_exp_name=1000):
+                 interval_exp_name=1000,
+                 out_dir=None,
+                 out_suffix=['.log.json', '.log', '.py'],
+                 keep_log=True,
+                 file_client_args=None):
         super(TextLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)
         self.by_epoch = by_epoch
         self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
+        if out_dir is None and file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not '
+                'specified.')
+        self.out_dir = out_dir
+        self.out_suffix = out_suffix
+        self.keep_log = keep_log
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(file_client_args,
+                                                       self.out_dir)
 
     def before_run(self, runner):
         super(TextLoggerHook, self).before_run(runner)
+
+        if self.out_dir is not None:
+            # The final `self.out_dir` is the concatenation of `self.out_dir`
+            # and the last level directory of `runner.work_dir`
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = osp.join(self.out_dir, basename)
+
         self.start_iter = runner.iter
         self.json_log_path = osp.join(runner.work_dir,
                                       f'{runner.timestamp}.log.json')
@@ -177,3 +216,15 @@ def log(self, runner):
             self._log_info(log_dict, runner)
         self._dump_log(log_dict, runner)
         return log_dict
+
+    def after_run(self, runner):
+        # copy or upload logs to self.out_dir
+        if self.out_dir is not None:
+            for filename in scandir(runner.work_dir, self.out_suffix, True):
+                local_filepath = osp.join(runner.work_dir, filename)
+                out_filepath = osp.join(self.out_dir, filename)
+                with open(local_filepath, 'r') as f:
+                    self.file_client.put_text(f.read(), out_filepath)
+
+                if not self.keep_log:
+                    os.remove(local_filepath)
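With the logger changes, logs are copied out once in `after_run` rather than streamed during training. A configuration sketch with a placeholder bucket; the petrel backend again assumes petrel_client is installed:

    from mmcv.runner.hooks import TextLoggerHook

    logger_hook = TextLoggerHook(
        out_dir='s3://my_bucket/logs',  # work_dir basename is appended
        keep_log=False,  # remove the local copies after uploading
        file_client_args=dict(backend='petrel'))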
diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py
index 80357cf31d3..223ec14720e 100644
--- a/tests/test_fileclient.py
+++ b/tests/test_fileclient.py
@@ -24,6 +24,18 @@ def Get(self, filepath):
         return content
 
 
+class MockPetrelClient:
+
+    def __init__(self, enable_mc=True, enable_multi_cluster=False):
+        self.enable_mc = enable_mc
+        self.enable_multi_cluster = enable_multi_cluster
+
+    def Get(self, filepath):
+        with open(filepath, 'rb') as f:
+            content = f.read()
+        return content
+
+
 class MockMemcachedClient:
 
     def __init__(self, server_list_cfg, client_cfg):
@@ -103,16 +115,11 @@ def test_ceph_backend(self):
         ceph_backend.client._client.Get.assert_called_with(
             str(self.img_path).replace(str(self.test_data_dir), ceph_path))
 
-    @patch('petrel_client.client.Client', MockS3Client)
-    def test_petrel_backend(self):
-        petrel_backend = FileClient('petrel')
-
-        # input path is Path object
-        with pytest.raises(NotImplementedError):
-            petrel_backend.get_text(self.text_path)
-        # input path is str
-        with pytest.raises(NotImplementedError):
-            petrel_backend.get_text(str(self.text_path))
+    @patch('petrel_client.client.Client', MockPetrelClient)
+    @pytest.mark.parametrize('backend,prefixes', [('petrel', None),
+                                                  (None, 's3')])
+    def test_petrel_backend(self, backend, prefixes):
+        petrel_backend = FileClient(backend=backend, prefixes=prefixes)
 
         # input path is Path object
         img_bytes = petrel_backend.get(self.img_path)
@@ -137,6 +144,11 @@ def test_petrel_backend(self):
         assert img.shape == self.img_shape
         petrel_backend.client._client.Get.assert_called_with(
             str(self.img_path).replace(str(self.test_data_dir), petrel_path))
+        # remove a file
+        petrel_backend.client._client.delete = MagicMock()
+        petrel_backend.remove(self.img_path)
+        petrel_backend.client._client.delete.assert_called_with(
+            str(self.img_path).replace(str(self.test_data_dir), petrel_path))
 
     @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
     @patch('mc.pyvector', MagicMock)
@@ -182,8 +194,10 @@ def test_lmdb_backend(self):
         img = mmcv.imfrombytes(img_bytes)
         assert img.shape == (120, 125, 3)
 
-    def test_http_backend(self):
-        http_backend = FileClient('http')
+    @pytest.mark.parametrize('backend,prefixes', [('http', None),
+                                                  (None, 'http')])
+    def test_http_backend(self, backend, prefixes):
+        http_backend = FileClient(backend=backend, prefixes=prefixes)
         img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
             'master/tests/data/color.jpg'
         text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
@@ -208,6 +222,48 @@ def test_http_backend(self):
             value_buf = http_backend.get_text(text_url)
         assert self.text_path.open('r').read() == value_buf
 
+    def test_parse_uri_prefix(self):
+        # input path is None
+        with pytest.raises(AssertionError):
+            FileClient.parse_uri_prefix(None)
+        # input path is list
+        with pytest.raises(AssertionError):
+            FileClient.parse_uri_prefix([])
+
+        # input path is Path object
+        assert FileClient.parse_uri_prefix(self.img_path) is None
+        # input path is str
+        assert FileClient.parse_uri_prefix(str(self.img_path)) is None
+
+        # input path starts with https
+        img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
+            'master/tests/data/color.jpg'
+        assert FileClient.parse_uri_prefix(img_url) == 'https'
+
+        # input path starts with s3
+        img_url = 's3://your_bucket/img.png'
+        assert FileClient.parse_uri_prefix(img_url) == 's3'
+
+        # input path starts with clusterName:s3
+        img_url = 'clusterName:s3://your_bucket/img.png'
+        assert FileClient.parse_uri_prefix(img_url) == 's3'
+
+    def test_infer_client(self):
+        # HardDiskBackend
+        file_client_args = {'backend': 'disk'}
+        client = FileClient.infer_client(file_client_args)
+        assert client.backend_name == 'disk'
+        client = FileClient.infer_client(uri=self.img_path)
+        assert client.backend_name == 'disk'
+
+        # PetrelBackend
+        file_client_args = {'backend': 'petrel'}
+        client = FileClient.infer_client(file_client_args)
+        assert client.backend_name == 'petrel'
+        uri = 's3://user_data'
+        client = FileClient.infer_client(uri=uri)
+        assert client.backend_name == 'petrel'
+
     def test_register_backend(self):
 
         # name must be a string
@@ -235,7 +291,7 @@ class ExampleBackend(BaseStorageBackend):
             def get(self, filepath):
                 return filepath
 
-            def get_text(self, filepath):
+            def get_text(self, filepath, encoding='utf-8'):
                 return filepath
 
         FileClient.register_backend('example', ExampleBackend)
@@ -249,7 +305,7 @@ class Example2Backend(BaseStorageBackend):
             def get(self, filepath):
                 return 'bytes2'
 
-            def get_text(self, filepath):
+            def get_text(self, filepath, encoding='utf-8'):
                 return 'text2'
 
         # force=False
@@ -267,7 +323,7 @@ class Example3Backend(BaseStorageBackend):
             def get(self, filepath):
                 return 'bytes3'
 
-            def get_text(self, filepath):
+            def get_text(self, filepath, encoding='utf-8'):
                 return 'text3'
 
         example_backend = FileClient('example3')
@@ -284,7 +340,7 @@ class Example4Backend(BaseStorageBackend):
             def get(self, filepath):
                 return 'bytes4'
 
-            def get_text(self, filepath):
+            def get_text(self, filepath, encoding='utf-8'):
                 return 'text4'
 
         @FileClient.register_backend(name='example3', force=True)
@@ -293,9 +349,78 @@ class Example5Backend(BaseStorageBackend):
             def get(self, filepath):
                 return 'bytes5'
 
-            def get_text(self, filepath):
+            def get_text(self, filepath, encoding='utf-8'):
                 return 'text5'
 
         example_backend = FileClient('example3')
         assert example_backend.get(self.img_path) == 'bytes5'
         assert example_backend.get_text(self.text_path) == 'text5'
+
+        # prefixes is a str
+        class Example6Backend(BaseStorageBackend):
+
+            def get(self, filepath):
+                return 'bytes6'
+
+            def get_text(self, filepath, encoding='utf-8'):
+                return 'text6'
+
+        FileClient.register_backend(
+            'example4',
+            Example6Backend,
+            force=True,
+            prefixes='example4_prefix')
+        example_backend = FileClient('example4')
+        assert example_backend.get(self.img_path) == 'bytes6'
+        assert example_backend.get_text(self.text_path) == 'text6'
+        example_backend = FileClient(prefixes='example4_prefix')
+        assert example_backend.get(self.img_path) == 'bytes6'
+        assert example_backend.get_text(self.text_path) == 'text6'
+        example_backend = FileClient('example4', prefixes='example4_prefix')
+        assert example_backend.get(self.img_path) == 'bytes6'
+        assert example_backend.get_text(self.text_path) == 'text6'
+
+        # prefixes is a list of str
+        class Example7Backend(BaseStorageBackend):
+
+            def get(self, filepath):
+                return 'bytes7'
+
+            def get_text(self, filepath, encoding='utf-8'):
+                return 'text7'
+
+        FileClient.register_backend(
+            'example5',
+            Example7Backend,
+            force=True,
+            prefixes=['example5_prefix1', 'example5_prefix2'])
+        example_backend = FileClient('example5')
+        assert example_backend.get(self.img_path) == 'bytes7'
+        assert example_backend.get_text(self.text_path) == 'text7'
+        example_backend = FileClient(prefixes='example5_prefix1')
+        assert example_backend.get(self.img_path) == 'bytes7'
+        assert example_backend.get_text(self.text_path) == 'text7'
+        example_backend = FileClient(prefixes='example5_prefix2')
+        assert example_backend.get(self.img_path) == 'bytes7'
+        assert example_backend.get_text(self.text_path) == 'text7'
+
+        # backend has a higher priority than prefixes
+        class Example8Backend(BaseStorageBackend):
+
+            def get(self, filepath):
+                return 'bytes8'
+
+            def get_text(self, filepath, encoding='utf-8'):
+                return 'text8'
+
+        FileClient.register_backend(
+            'example6',
+            Example8Backend,
+            force=True,
+            prefixes='example6_prefix')
+        example_backend = FileClient('example6')
+        assert example_backend.get(self.img_path) == 'bytes8'
+        assert example_backend.get_text(self.text_path) == 'text8'
+        example_backend = FileClient('example6', prefixes='example4_prefix')
+        assert example_backend.get(self.img_path) == 'bytes8'
+        assert example_backend.get_text(self.text_path) == 'text8'
diff --git a/tests/test_fileio.py b/tests/test_fileio.py
index a9d70f515a2..5b701d6da9d 100644
--- a/tests/test_fileio.py
+++ b/tests/test_fileio.py
@@ -1,11 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
 import os.path as osp
+import sys
 import tempfile
+from unittest.mock import MagicMock, patch
 
 import pytest
 
 import mmcv
+from mmcv.fileio.file_client import HTTPBackend, PetrelBackend
+
+sys.modules['petrel_client'] = MagicMock()
+sys.modules['petrel_client.client'] = MagicMock()
 
 
 def _test_handler(file_format, test_obj, str_checker, mode='r+'):
@@ -13,7 +19,7 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'):
     dump_str = mmcv.dump(test_obj, file_format=file_format)
     str_checker(dump_str)
 
-    # load/dump with filenames
+    # load/dump with filenames from disk
     tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_dump')
     mmcv.dump(test_obj, tmp_filename, file_format=file_format)
     assert osp.isfile(tmp_filename)
@@ -21,6 +27,13 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'):
     assert load_obj == test_obj
     os.remove(tmp_filename)
 
+    # load/dump with filename from petrel
+    method = 'put' if 'b' in mode else 'put_text'
+    with patch.object(PetrelBackend, method, return_value=None) as mock_method:
+        filename = 's3://path/of/your/file'
+        mmcv.dump(test_obj, filename, file_format=file_format)
+        mock_method.assert_called()
+
     # json load/dump with a file-like object
     with tempfile.NamedTemporaryFile(mode, delete=False) as f:
         tmp_filename = f.name
@@ -122,6 +135,7 @@ def dump_to_str(self, obj, **kwargs):
 
 
 def test_list_from_file():
+    # get list from disk
     filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
     filelist = mmcv.list_from_file(filename)
     assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg']
@@ -134,10 +148,64 @@ def test_list_from_file():
     filelist = mmcv.list_from_file(filename, offset=3, max_num=3)
     assert filelist == ['4.jpg', '5.jpg']
 
+    # get list from http
+    with patch.object(
+            HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
+        filename = 'http://path/of/your/file'
+        filelist = mmcv.list_from_file(
+            filename, file_client_args={'backend': 'http'})
+        assert filelist == ['1.jpg', '2.jpg', '3.jpg']
+        filelist = mmcv.list_from_file(
+            filename, file_client_args={'prefixes': 'http'})
+        assert filelist == ['1.jpg', '2.jpg', '3.jpg']
+        filelist = mmcv.list_from_file(filename)
+        assert filelist == ['1.jpg', '2.jpg', '3.jpg']
+
+    # get list from petrel
+    with patch.object(
+            PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
+        filename = 's3://path/of/your/file'
+        filelist = mmcv.list_from_file(
+            filename, file_client_args={'backend': 'petrel'})
+        assert filelist == ['1.jpg', '2.jpg', '3.jpg']
+        filelist = mmcv.list_from_file(
+            filename, file_client_args={'prefixes': 's3'})
+        assert filelist == ['1.jpg', '2.jpg', '3.jpg']
+        filelist = mmcv.list_from_file(filename)
+        assert filelist == ['1.jpg', '2.jpg', '3.jpg']
+
 
 def test_dict_from_file():
+    # get dict from disk
     filename = osp.join(osp.dirname(__file__), 'data/mapping.txt')
     mapping = mmcv.dict_from_file(filename)
     assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
     mapping = mmcv.dict_from_file(filename, key_type=int)
     assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
+
+    # get dict from http
+    with patch.object(
+            HTTPBackend, 'get_text',
+            return_value='1 cat\n2 dog cow\n3 panda'):
+        filename = 'http://path/of/your/file'
+        mapping = mmcv.dict_from_file(
+            filename, file_client_args={'backend': 'http'})
+        assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
+        mapping = mmcv.dict_from_file(
+            filename, file_client_args={'prefixes': 'http'})
+        assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
+        mapping = mmcv.dict_from_file(filename)
+        assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
+
+    # get dict from petrel
+    with patch.object(
+            PetrelBackend, 'get_text',
+            return_value='1 cat\n2 dog cow\n3 panda'):
+        filename = 's3://path/of/your/file'
+        mapping = mmcv.dict_from_file(
+            filename, file_client_args={'backend': 'petrel'})
+        assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
+        mapping = mmcv.dict_from_file(
+            filename, file_client_args={'prefixes': 's3'})
+        assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
+        mapping = mmcv.dict_from_file(filename)
+        assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py
index 75aa9ddd75b..9856724318c 100644
--- a/tests/test_runner/test_checkpoint.py
+++ b/tests/test_runner/test_checkpoint.py
@@ -1,17 +1,22 @@
 import sys
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
 import torch.nn as nn
+import torch.optim as optim
 from torch.nn.parallel import DataParallel
 
+from mmcv.fileio.file_client import PetrelBackend
 from mmcv.parallel.registry import MODULE_WRAPPERS
 from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
                                     get_state_dict, load_checkpoint,
-                                    load_from_pavi)
+                                    load_from_pavi, save_checkpoint)
+
+sys.modules['petrel_client'] = MagicMock()
+sys.modules['petrel_client.client'] = MagicMock()
 
 
 @MODULE_WRAPPERS.register_module()
@@ -392,3 +397,36 @@ def load_from_abc(filename, map_location):
     filename = 'a/b/c/d'
     loader = CheckpointLoader._get_checkpoint_loader(filename)
     assert loader.__name__ == 'load_from_abc'
+
+
+def test_save_checkpoint(tmp_path):
+    model = Model()
+    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+    # meta is not a dict
+    with pytest.raises(TypeError):
+        save_checkpoint(model, '/path/of/your/filename', meta='invalid type')
+
+    # 1. save to disk
+    filename = str(tmp_path / 'checkpoint1.pth')
+    save_checkpoint(model, filename)
+
+    filename = str(tmp_path / 'checkpoint2.pth')
+    save_checkpoint(model, filename, optimizer)
+
+    filename = str(tmp_path / 'checkpoint3.pth')
+    save_checkpoint(model, filename, meta={'test': 'test'})
+
+    filename = str(tmp_path / 'checkpoint4.pth')
+    save_checkpoint(model, filename, file_client_args={'backend': 'disk'})
+
+    # 2. save to petrel oss
+    with patch.object(PetrelBackend, 'put') as mock_method:
+        filename = 's3://path/of/your/checkpoint1.pth'
+        save_checkpoint(model, filename)
+        mock_method.assert_called()
+
+    with patch.object(PetrelBackend, 'put') as mock_method:
+        filename = 's3://path//of/your/checkpoint2.pth'
+        save_checkpoint(
+            model, filename, file_client_args={'backend': 'petrel'})
+        mock_method.assert_called()
diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py
index bb0e7585049..f4a97c887c1 100644
--- a/tests/test_runner/test_hooks.py
+++ b/tests/test_runner/test_hooks.py
@@ -12,7 +12,7 @@
 import shutil
 import sys
 import tempfile
-from unittest.mock import MagicMock, call
+from unittest.mock import MagicMock, call, patch
 
 import pytest
 import torch
@@ -20,6 +20,7 @@
 from torch.nn.init import constant_
 from torch.utils.data import DataLoader
 
+from mmcv.fileio.file_client import PetrelBackend
 from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
                          Fp16OptimizerHook,
                          GradientCumulativeFp16OptimizerHook,
@@ -34,8 +35,11 @@
                                      OneCycleLrUpdaterHook,
                                      StepLrUpdaterHook)
 
+sys.modules['petrel_client'] = MagicMock()
+sys.modules['petrel_client.client'] = MagicMock()
 
-def test_checkpoint_hook():
+
+def test_checkpoint_hook(tmp_path):
     """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
 
     # test epoch based runner
@@ -49,6 +53,23 @@ def test_checkpoint_hook():
         runner.work_dir, 'epoch_1.pth')
     shutil.rmtree(runner.work_dir)
 
+    # test petrel oss when type of runner is EpochBasedRunner
+    runner = _build_demo_runner('EpochBasedRunner', max_epochs=4)
+    runner.meta = dict()
+    out_dir = 's3://user/data'
+    with patch.object(PetrelBackend, 'put') as mock_put, \
+            patch.object(PetrelBackend, 'remove') as mock_remove:
+        checkpointhook = CheckpointHook(
+            interval=1, out_dir=out_dir, by_epoch=True, max_keep_ckpts=2)
+        runner.register_hook(checkpointhook)
+        runner.run([loader], [('train', 1)])
+        basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+        assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
+            out_dir, basename, 'epoch_4.pth')
+    mock_put.assert_called()
+    mock_remove.assert_called()
+    shutil.rmtree(runner.work_dir)
+
     # test iter based runner
     runner = _build_demo_runner(
         'IterBasedRunner', max_iters=1, max_epochs=None)
@@ -60,6 +81,24 @@ def test_checkpoint_hook():
         runner.work_dir, 'iter_1.pth')
     shutil.rmtree(runner.work_dir)
 
+    # test petrel oss when type of runner is IterBasedRunner
+    runner = _build_demo_runner(
+        'IterBasedRunner', max_iters=4, max_epochs=None)
+    runner.meta = dict()
+    out_dir = 's3://user/data'
+    with patch.object(PetrelBackend, 'put') as mock_put, \
+            patch.object(PetrelBackend, 'remove') as mock_remove:
+        checkpointhook = CheckpointHook(
+            interval=1, out_dir=out_dir, by_epoch=False, max_keep_ckpts=2)
+        runner.register_hook(checkpointhook)
+        runner.run([loader], [('train', 1)])
+        basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+        assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
+            out_dir, basename, 'iter_4.pth')
+    mock_put.assert_called()
+    mock_remove.assert_called()
+    shutil.rmtree(runner.work_dir)
+
 
 def test_ema_hook():
     """xdoctest -m tests/test_hooks.py test_ema_hook."""