diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index cd92218ce4..991c8a671c 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -205,7 +205,7 @@ def get_batch(data_iterator): # TODO: this is pretty hacky, find a better way if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None + return {key: None for key in ['input_ids', 'attention_mask', 'position_ids']} # get batches based on the TP rank you are on batch = get_batch_on_this_tp_rank(data_iterator)