diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py index e541e76e634fc6..ff9ff2ee2a9c3e 100755 --- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py @@ -48,8 +48,10 @@ def _prepare_for_model(self): logger.info("mp's parameters is ready") def _pre_forward(self, *inputs, **kwargs): - mp_configs = self._strategy.hybrid_configs["mp_configs"] - need_broadcast_data = mp_configs.need_broadcast_data + need_broadcast_data = True + if self._strategy is not None: + mp_configs = self._strategy.hybrid_configs["mp_configs"] + need_broadcast_data = mp_configs.need_broadcast_data if need_broadcast_data: logger.debug("mp start broadcast input data") return broadcast_input_data(self._hcg, *inputs, **kwargs)