Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Upload checkpoints and logs to ceph #1375

Merged
merged 103 commits into from
Oct 24, 2021
Merged
Show file tree
Hide file tree
Changes from 100 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
b2a2257
[Feature] Choose storage backend by the prefix of filepath
zhouzaida Sep 9, 2021
073f73e
refactor FileClient and add unittest
zhouzaida Sep 10, 2021
dfb9fc4
support loading from different backends
zhouzaida Sep 11, 2021
48cfdad
polish docstring
zhouzaida Sep 21, 2021
c2c9fc0
fix unittet
zhouzaida Sep 21, 2021
d641a8c
rename attribute str_like_obj to is_str_like_obj
zhouzaida Sep 22, 2021
bb45ee5
[Docs] Upload checkpoint to petrel oss
zhouzaida Sep 22, 2021
68f0ab6
add infer_client method
zhouzaida Sep 23, 2021
2f56b3c
merge load-from-backend
zhouzaida Sep 23, 2021
3fe48b8
Support uploading checkpoint to petrel oss
zhouzaida Sep 23, 2021
31caf8e
add check_exist method
zhouzaida Sep 23, 2021
d202465
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Sep 23, 2021
f75511e
refactor CheckpointHook
zhouzaida Sep 23, 2021
8461a37
support uploading logs to ceph
zhouzaida Sep 24, 2021
7e7a80f
rename var client to file_client
zhouzaida Sep 24, 2021
44829ca
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Sep 24, 2021
aa8274b
polish docstring
zhouzaida Sep 26, 2021
fd70556
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Sep 26, 2021
6af56b1
enhance load_from_ceph
zhouzaida Sep 26, 2021
9600d18
refactor load_from_ceph
zhouzaida Sep 26, 2021
6983002
refactor TextLoggerHook
zhouzaida Sep 26, 2021
e91c93e
change the meaning of out_dir argument
zhouzaida Sep 27, 2021
5940864
fix test_checkpoint_hook.py
zhouzaida Sep 27, 2021
bb4712d
add join_paths method
zhouzaida Sep 27, 2021
2409531
Merge branch 'master' of https://github.com/open-mmlab/mmcv into load…
zhouzaida Sep 27, 2021
c1ae3ef
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Sep 27, 2021
d4b6d96
remove join_paths and add _format_path
zhouzaida Sep 28, 2021
824cff3
Merge branch 'master' of https://github.com/open-mmlab/mmcv into load…
zhouzaida Oct 3, 2021
767f7fb
enhance unittest
zhouzaida Oct 3, 2021
a929e90
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 3, 2021
b930678
refactor unittest
zhouzaida Oct 3, 2021
6704a15
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 3, 2021
8020802
add a unittest for EvalHook when file backend is petrel
zhouzaida Oct 4, 2021
1752698
singleton pattern
zhouzaida Oct 4, 2021
fb9567c
fix test_clientio.py
zhouzaida Oct 4, 2021
00505f8
deprecate CephBackend
zhouzaida Oct 4, 2021
5fdcedc
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 4, 2021
9d0b6ec
add warning in load_from_ceph
zhouzaida Oct 4, 2021
6980bf3
fix type of out_suffix
zhouzaida Oct 5, 2021
225d3a6
enhance docstring
zhouzaida Oct 6, 2021
22644da
refactor unittest for petrel
zhouzaida Oct 6, 2021
058b7e8
refactor unittest for disk backend
zhouzaida Oct 6, 2021
1692678
update io.md
zhouzaida Oct 6, 2021
01b9807
add concat_paths method
zhouzaida Oct 6, 2021
9491bac
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 6, 2021
b44bfc6
fix CI
zhouzaida Oct 6, 2021
016c879
mock check_exist
zhouzaida Oct 7, 2021
fed5a39
improve docstring
zhouzaida Oct 8, 2021
4959687
improve docstring
zhouzaida Oct 8, 2021
fb0e21f
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 8, 2021
9a0c535
improve docstring
zhouzaida Oct 8, 2021
e0dcad9
improve docstring
zhouzaida Oct 8, 2021
aea920a
add isdir and copyfile for file backend
zhouzaida Oct 10, 2021
4eda86f
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 10, 2021
6412103
delete copyfile and add get_local_path
zhouzaida Oct 11, 2021
c557ca3
Merge branch 'master' of https://github.com/open-mmlab/mmcv into load…
zhouzaida Oct 12, 2021
eeda74c
remove isdir method of petrel
zhouzaida Oct 12, 2021
ad52428
fix typo
zhouzaida Oct 12, 2021
b846cdb
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 12, 2021
54597f4
rename check_exists to exists
zhouzaida Oct 12, 2021
097bae5
refactor code and polish docstring
zhouzaida Oct 12, 2021
9f78448
fix windows ci
zhouzaida Oct 12, 2021
941a884
add comment and polish docstring
zhouzaida Oct 13, 2021
3105687
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 13, 2021
198a465
polish docstring
zhouzaida Oct 14, 2021
cda02b7
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 14, 2021
7ad8ab7
polish docstring
zhouzaida Oct 14, 2021
e0d6a83
rename _path_mapping to _map_path
zhouzaida Oct 15, 2021
59aa354
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 15, 2021
ae0cdd3
polish docstring and fix typo
zhouzaida Oct 15, 2021
bd3b322
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 15, 2021
a2e0162
refactor get_local_path
zhouzaida Oct 16, 2021
bfc5ecc
fix conflict
zhouzaida Oct 16, 2021
50ba26f
add list_dir_or_file for FileClient
zhouzaida Oct 17, 2021
4ad3bf5
add list_dir_or_file for PetrelBackend
zhouzaida Oct 18, 2021
df207d1
fix windows ci
zhouzaida Oct 18, 2021
d29a88d
Add return docstring
zhouzaida Oct 19, 2021
f18a779
polish docstring
zhouzaida Oct 19, 2021
7b7a380
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 19, 2021
669cfee
fix typo
zhouzaida Oct 19, 2021
b6eb5d1
fix typo
zhouzaida Oct 19, 2021
0c8fbc3
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 19, 2021
150d504
fix typo
zhouzaida Oct 19, 2021
b371c83
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 19, 2021
23ba993
fix error when mocking PetrelBackend
zhouzaida Oct 20, 2021
208ff82
deprecate the conversion from Path to str
zhouzaida Oct 20, 2021
9ecfc12
add docs for loading checkpoints with FileClient
zhouzaida Oct 22, 2021
947d549
rename keep_log to keep_local
zhouzaida Oct 22, 2021
e921770
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 22, 2021
38559f1
refactor map_path
zhouzaida Oct 22, 2021
e75685c
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 22, 2021
ea32388
add _ensure_methods to ensure methods have been implemented
zhouzaida Oct 22, 2021
f413480
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 22, 2021
a8cc11d
fix list_dir_or_file
zhouzaida Oct 22, 2021
7ef1652
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 22, 2021
e66fe61
rename _ensure_method_implemented to has_method
zhouzaida Oct 23, 2021
196a56d
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 23, 2021
6987038
fix conflict
zhouzaida Oct 23, 2021
aa5611c
Merge branch 'load-from-backend' into upload-ckpt-to-ceph
zhouzaida Oct 23, 2021
100bf55
fix conflict
zhouzaida Oct 23, 2021
a13d078
refactor
zhouzaida Oct 24, 2021
26c8127
polish information
zhouzaida Oct 24, 2021
39a58a3
format information
zhouzaida Oct 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmcv/fileio/file_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class CephBackend(BaseStorageBackend):
will be replaced by ``dst``. Default: None.

.. warning::
:class:`CephBackend` will be deprecated, please use
:class:`PetrelBackend` instead
:class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
"""

def __init__(self, path_mapping=None):
Expand Down Expand Up @@ -579,7 +579,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool:
return osp.isdir(filepath)

def isfile(self, filepath: Union[str, Path]) -> bool:
"""Check a ``filepath`` whether it is a file.
"""Check whether a file path is a file.

Args:
filepath (str or Path): Path to be checked whether it is a file.
Expand Down Expand Up @@ -714,7 +714,7 @@ class FileClient:
Note that It can also register other backend accessor with a given name,
prefixes, and backend class. In addition, We use the singleton pattern to
avoid repeated object creation. If the arguments are the same, the same
object is returned.
object will be returned.

Args:
backend (str, optional): The storage backend type. Options are "disk",
Expand Down
56 changes: 42 additions & 14 deletions mmcv/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,28 +323,43 @@ def load_from_pavi(filename, map_location=None):


@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='ceph'):
def load_from_ceph(filename, map_location=None, backend='petrel'):
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.

Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str): The storage backend type. Options are "disk", "ceph",
"memcached" and "lmdb". Default: 'ceph'
backend (str, optional): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.

.. warning::
:class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.

Returns:
dict or OrderedDict: The loaded checkpoint.
"""

allowed_backends = ['ceph']
allowed_backends = ['ceph', 'petrel']
if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.')

fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
if backend == 'ceph':
warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead')

# CephClient and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend can not be instantiated
# successfully, the CephClient will be chosen.
try:
file_client = FileClient(backend=backend)
except ImportError:
allowed_backends.remove(backend)
file_client = FileClient(backend=allowed_backends[0])

with io.BytesIO(file_client.get(filename)) as buffer:
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint


Expand Down Expand Up @@ -506,7 +521,6 @@ def load_checkpoint(model,
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].


Returns:
dict or OrderedDict: The loaded checkpoint.
"""
Expand Down Expand Up @@ -616,7 +630,11 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
def save_checkpoint(model,
filename,
optimizer=None,
meta=None,
file_client_args=None):
"""Save checkpoint to file.

The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
Expand All @@ -627,6 +645,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
"""
if meta is None:
meta = {}
Expand Down Expand Up @@ -654,6 +676,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
checkpoint['optimizer'][name] = optim.state_dict()

if filename.startswith('pavi://'):
if file_client_args is not None:
raise ValueError(
'file_client_args should be "None" if filename starts with'
f'"pavi://", but got {file_client_args}')
try:
from pavi import modelcloud
from pavi import exception
Expand All @@ -674,8 +700,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
file_client = FileClient.infer_client(file_client_args, filename)
if file_client.backend_name == 'disk':
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
mmcv.mkdir_or_exist(osp.dirname(filename))

with io.BytesIO() as f:
torch.save(checkpoint, f)
f.flush()
file_client.put(f.getvalue(), filename)
65 changes: 52 additions & 13 deletions mmcv/runner/hooks/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp

from mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook

Expand All @@ -18,16 +19,32 @@ class CheckpointHook(Hook):
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default. If
specified, the ``out_dir`` will be the concatenation of ``out_dir``
and the last level directory of ``runner.work_dir``.
`Changed in version 1.3.16.`
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be saved
regardless of interval.
sync_buffer (bool): Whether to synchronize buffers in different
gpus. Default: False.
save_last (bool, optional): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
sync_buffer (bool, optional): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`

.. warning::
Before v1.3.16, the ``out_dir`` argument indicates the path where the
checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
root directory and the final path to save checkpoint is the
concatenation of ``out_dir`` and the last level directory of
``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
and the value of ``runner.work_dir`` is "/path/of/B", then the final
path will be "/path/of/A/B".
"""

def __init__(self,
Expand All @@ -38,6 +55,7 @@ def __init__(self,
max_keep_ckpts=-1,
save_last=True,
sync_buffer=False,
file_client_args=None,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
Expand All @@ -47,11 +65,31 @@ def __init__(self,
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args

def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir

self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)

# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.concat_paths(
self.out_dir, basename)

runner.logger.info(f'checkpoints will be saved to {self.out_dir}')

# disable the create_symlink option when the backend is not
# HardDiskBackend
if self.file_client.backend_name != 'disk':
self.args['create_symlink'] = False
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

def after_train_epoch(self, runner):
if not self.by_epoch:
return
Expand Down Expand Up @@ -81,8 +119,9 @@ def _save_checkpoint(self, runner):
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
runner.meta.setdefault('hook_msgs', dict())
runner.meta['hook_msgs']['last_ckpt'] = os.path.join(
self.out_dir, cur_ckpt_filename)
runner.meta['hook_msgs'][
'last_ckpt'] = self.file_client.concat_paths(
self.out_dir, cur_ckpt_filename)
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
Expand All @@ -96,10 +135,10 @@ def _save_checkpoint(self, runner):
-self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts:
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(_step))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
ckpt_path = self.file_client.concat_paths(
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
self.out_dir, filename_tmpl.format(_step))
if self.file_client.isfile(ckpt_path):
self.file_client.remove(ckpt_path)
else:
break

Expand Down
50 changes: 45 additions & 5 deletions mmcv/runner/hooks/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import warnings
from math import inf
Expand All @@ -8,6 +7,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader

from mmcv.fileio import FileClient
from mmcv.utils import is_seq_of
from .hook import Hook
from .logger import LoggerHook
Expand Down Expand Up @@ -54,6 +54,14 @@ class EvalHook(Hook):
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
out_dir (str, optional): The root directory to save checkpoints. If not
specified, `runner.work_dir` will be used by default. If specified,
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
`New in version 1.3.16.`
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
`New in version 1.3.16.`
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.

Expand Down Expand Up @@ -84,6 +92,8 @@ def __init__(self,
test_fn=None,
greater_keys=None,
less_keys=None,
out_dir=None,
file_client_args=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
Expand Down Expand Up @@ -137,6 +147,9 @@ def __init__(self,
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)

self.out_dir = out_dir
self.file_client_args = file_client_args

def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.

Expand Down Expand Up @@ -187,6 +200,21 @@ def _init_rule(self, rule, key_indicator):
self.compare_func = self.rule_map[self.rule]

def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir

self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)

# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.concat_paths(
self.out_dir, basename)

if self.save_best is not None:
if runner.meta is None:
warnings.warn('runner.meta is None. Creating an empty one.')
Expand Down Expand Up @@ -299,15 +327,17 @@ def _save_ckpt(self, runner, key_score):
best_score = key_score
runner.meta['hook_msgs']['best_score'] = best_score

if self.best_ckpt_path and osp.isfile(self.best_ckpt_path):
os.remove(self.best_ckpt_path)
if self.best_ckpt_path and self.file_client.isfile(
self.best_ckpt_path):
self.file_client.remove(self.best_ckpt_path)
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name)
self.best_ckpt_path = self.file_client.concat_paths(
self.out_dir, best_ckpt_name)
runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path

runner.save_checkpoint(
runner.work_dir, best_ckpt_name, create_symlink=False)
self.out_dir, best_ckpt_name, create_symlink=False)
runner.logger.info(
f'Now best checkpoint is saved as {best_ckpt_name}.')
runner.logger.info(
Expand Down Expand Up @@ -378,6 +408,12 @@ class DistEvalHook(EvalHook):
broadcast_bn_buffer (bool): Whether to broadcast the
buffer(running_mean and running_var) of rank 0 to other rank
before evaluation. Default: True.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, `runner.work_dir` will be used by default. If specified,
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""
Expand All @@ -395,6 +431,8 @@ def __init__(self,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
out_dir=None,
file_client_args=None,
**eval_kwargs):

if test_fn is None:
Expand All @@ -411,6 +449,8 @@ def __init__(self,
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
out_dir=out_dir,
file_client_args=file_client_args,
**eval_kwargs)

self.broadcast_bn_buffer = broadcast_bn_buffer
Expand Down
Loading