From 5bf634f3852b05256d2f004c24312331fe17bba8 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 18 Sep 2023 15:38:34 +0800 Subject: [PATCH] [Fix] fix infer.py _init_collate incompatibility problem --- mmengine/infer/infer.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/mmengine/infer/infer.py b/mmengine/infer/infer.py index 322d885224..39cb3df040 100644 --- a/mmengine/infer/infer.py +++ b/mmengine/infer/infer.py @@ -529,7 +529,22 @@ def preprocess(self, inputs, batch_size, **kwargs): """ try: with FUNCTIONS.switch_scope_and_registry(self.scope) as registry: - collate_fn = registry.get(cfg.test_dataloader.collate_fn) + collate_fn_cfg = cfg.test_dataloader.collate_fn + if isinstance(collate_fn_cfg, dict): + collate_fn_type = collate_fn_cfg.pop('type') + if isinstance(collate_fn_type, str): + collate_fn = registry.get(collate_fn_type) + else: + collate_fn = collate_fn_type + from functools import partial + collate_fn = partial(collate_fn, + **collate_fn_cfg) # type: ignore + elif callable(collate_fn_cfg): + collate_fn = collate_fn_cfg + else: + raise TypeError( + 'collate_fn should be a dict or callable object, ' + f'but got {collate_fn_cfg}') except AttributeError: collate_fn = pseudo_collate return collate_fn # type: ignore