diff --git a/examples/pytorch/llm/src/utils/model.py b/examples/pytorch/llm/src/utils/model.py index 2459232129..2ab075c35d 100644 --- a/examples/pytorch/llm/src/utils/model.py +++ b/examples/pytorch/llm/src/utils/model.py @@ -10,7 +10,7 @@ from torch import dtype as Dtype from swift import get_logger -from .utils import is_local_master +from .utils import is_dist, is_local_master logger = get_logger() @@ -314,14 +314,14 @@ def get_model_tokenizer(model_type: str, model_dir = kwargs.pop('model_dir', None) if model_dir is None: - if not is_local_master(): + if is_dist() and not is_local_master(): dist.barrier() model_dir = model_id if not os.path.exists(model_id): revision = data.get('revision', 'master') model_dir = snapshot_download( model_id, revision, ignore_file_pattern=ignore_file_pattern) - if is_local_master(): + if is_dist() and is_local_master(): dist.barrier() model, tokenizer = get_function(model_dir, torch_dtype, load_model,