From fe1fbcdae904649d54ee8b33e77b9f521572df48 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 12 Oct 2023 18:23:53 -0500 Subject: [PATCH] fix mp bug (#58037) --- .../distributed/fleet/meta_parallel/tensor_parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py index e541e76e634fc6..ff9ff2ee2a9c3e 100755 --- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py @@ -48,8 +48,10 @@ def _prepare_for_model(self): logger.info("mp's parameters is ready") def _pre_forward(self, *inputs, **kwargs): - mp_configs = self._strategy.hybrid_configs["mp_configs"] - need_broadcast_data = mp_configs.need_broadcast_data + need_broadcast_data = True + if self._strategy is not None: + mp_configs = self._strategy.hybrid_configs["mp_configs"] + need_broadcast_data = mp_configs.need_broadcast_data if need_broadcast_data: logger.debug("mp start broadcast input data") return broadcast_input_data(self._hcg, *inputs, **kwargs)