diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index b369261482..5d9cb7a423 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union import torch +import torch.distributed as dist from swift.llm import TEMPLATE_MAPPING from swift.llm.train import SwiftSft @@ -60,6 +61,7 @@ def run(self): try: self.trainer.train(train_dataset, val_dataset, data_collator) + dist.barrier() # Ensure all weights are saved completely finally: # Visualization if is_last_rank():