Conversation

@zpcore (Contributor) commented Sep 5, 2025

This PR, combined with pytorch/pytorch#162294, will allow a normal (non-AP) model to load a checkpoint saved from an AP model.

The next step is to support the AP model loading a checkpoint saved from a non-AP model. This requires us to redistribute certain DTensor model weights to match AP's requested device order.
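
For reference, a minimal sketch of the intended save/load flow using the DCP APIs (ap_model, plain_model, and the checkpoint path are illustrative placeholders, not code from this PR):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

# Phase 1: save from the AP-parallelized model; its parameters are DTensors
# whose placements may use a non-default device order.
dcp.save(get_model_state_dict(ap_model), checkpoint_id="outputs/checkpoint/step-10")

# Phase 2: load the same checkpoint into a plain (non-AP) model; DCP reshards
# the saved DTensor shards onto the destination model's layout.
state_dict = get_model_state_dict(plain_model)
dcp.load(state_dict, checkpoint_id="outputs/checkpoint/step-10")
set_model_state_dict(plain_model, state_dict)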

Loss curve tests

Tested with torchtitan's autoparallel branch:

echo "Phase 1: Training with checkpoint saving..."
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
./run_train.sh \
  --model.name llama3_auto_parallel \
  --parallelism.tensor_parallel_degree 2 \
  --training.seed=42 \
  --training.steps 10 \
  --checkpoint.interval 10 \
  --checkpoint.enable \
  --metrics.enable_tensorboard

# Test 2: Resume training from checkpoint to test state dict loading
echo "Phase 2: Resuming training from checkpoint..."
./run_train.sh \
  --model.name llama3 \
  --parallelism.tensor_parallel_degree 2 \
  --training.seed=42 \
  --training.steps 20 \
  --checkpoint.enable 

The solver for debug_model triggers a reversed device order for the weight output.weight, which makes it a great example for testing DCP correctness :)

Loss curve results

Test on torchtitan

Below is the training loss info under the impact of device ordering (the checkpoint is saved at step 10):


[baseline and save checkpoint] AP run for 20 steps, with device order reversed in some model parameters:

[rank0]:[titan] 2025-09-03 17:45:59,645 - root - INFO - step: 1 loss: 8.0585 grad_norm: 5.8274 memory: 0.50GiB(0.63%) tps: 76 tflops: 0.01 mfu: 0.00%

[rank0]:[titan] 2025-09-03 17:45:59,754 - root - INFO - step: 2 loss: 7.7667 grad_norm: 6.0692 memory: 0.51GiB(0.64%) tps: 75,292 tflops: 5.38 mfu: 1.73%
[rank0]:[titan] 2025-09-03 17:45:59,841 - root - INFO - step: 3 loss: 7.0808 grad_norm: 7.5739 memory: 0.51GiB(0.64%) tps: 95,272 tflops: 6.81 mfu: 2.18%
[rank0]:[titan] 2025-09-03 17:45:59,924 - root - INFO - step: 4 loss: 6.3458 grad_norm: 8.9201 memory: 0.51GiB(0.64%) tps: 98,904 tflops: 7.07 mfu: 2.27%
[rank0]:[titan] 2025-09-03 17:46:00,016 - root - INFO - step: 5 loss: 5.4611 grad_norm: 10.1237 memory: 0.51GiB(0.64%) tps: 89,369 tflops: 6.39 mfu: 2.05%
[rank0]:[titan] 2025-09-03 17:46:00,108 - root - INFO - step: 6 loss: 4.8803 grad_norm: 9.9377 memory: 0.51GiB(0.64%) tps: 89,653 tflops: 6.41 mfu: 2.05%
[rank0]:[titan] 2025-09-03 17:46:00,191 - root - INFO - step: 7 loss: 4.4837 grad_norm: 9.4927 memory: 0.51GiB(0.64%) tps: 98,606 tflops: 7.05 mfu: 2.26%
[rank0]:[titan] 2025-09-03 17:46:00,277 - root - INFO - step: 8 loss: 4.2165 grad_norm: 8.8526 memory: 0.51GiB(0.64%) tps: 95,909 tflops: 6.86 mfu: 2.20%
[rank0]:[titan] 2025-09-03 17:46:00,365 - root - INFO - step: 9 loss: 4.1620 grad_norm: 7.4629 memory: 0.51GiB(0.64%) tps: 93,990 tflops: 6.72 mfu: 2.15%
[rank0]:[titan] 2025-09-03 17:46:00,457 - root - INFO - step: 10 loss: 3.8862 grad_norm: 6.9355 memory: 0.51GiB(0.64%) tps: 89,328 tflops: 6.39 mfu: 2.05%
[rank0]:[titan] 2025-09-03 17:46:00,457 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-09-03 17:46:02,631 - root - INFO - [GC] GC collection invoked by checkpointer. 1.58 seconds
[rank0]:[titan] 2025-09-03 17:46:02,636 - root - INFO - Finished saving the checkpoint (or staging if async is enabled)in 2.18 seconds.
[rank0]:[titan] 2025-09-03 17:46:03,624 - root - INFO - step: 11 loss: 3.7420 grad_norm: 5.9874 memory: 0.51GiB(0.64%) tps: 2,587 tflops: 0.18 mfu: 0.06%
[rank0]:[titan] 2025-09-03 17:46:03,714 - root - INFO - step: 12 loss: 3.6704 grad_norm: 5.2465 memory: 0.51GiB(0.64%) tps: 91,359 tflops: 6.53 mfu: 2.09%
[rank0]:[titan] 2025-09-03 17:46:03,810 - root - INFO - step: 13 loss: 3.6030 grad_norm: 4.8465 memory: 0.51GiB(0.64%) tps: 85,948 tflops: 6.15 mfu: 1.97%
[rank0]:[titan] 2025-09-03 17:46:03,908 - root - INFO - step: 14 loss: 3.4915 grad_norm: 4.6389 memory: 0.51GiB(0.64%) tps: 83,760 tflops: 5.99 mfu: 1.92%
[rank0]:[titan] 2025-09-03 17:46:04,008 - root - INFO - step: 15 loss: 3.4761 grad_norm: 4.0641 memory: 0.51GiB(0.64%) tps: 82,286 tflops: 5.88 mfu: 1.89%
[rank0]:[titan] 2025-09-03 17:46:04,119 - root - INFO - step: 16 loss: 3.4595 grad_norm: 3.6987 memory: 0.51GiB(0.64%) tps: 74,104 tflops: 5.30 mfu: 1.70%
[rank0]:[titan] 2025-09-03 17:46:04,222 - root - INFO - step: 17 loss: 3.3147 grad_norm: 3.7014 memory: 0.51GiB(0.64%) tps: 80,567 tflops: 5.76 mfu: 1.85%
[rank0]:[titan] 2025-09-03 17:46:04,316 - root - INFO - step: 18 loss: 3.2904 grad_norm: 3.6725 memory: 0.51GiB(0.64%) tps: 87,181 tflops: 6.23 mfu: 2.00%
[rank0]:[titan] 2025-09-03 17:46:04,417 - root - INFO - step: 19 loss: 3.3941 grad_norm: 3.1091 memory: 0.51GiB(0.64%) tps: 81,892 tflops: 5.86 mfu: 1.88%
[rank0]:[titan] 2025-09-03 17:46:04,509 - root - INFO - step: 20 loss: 3.3503 grad_norm: 3.1719 memory: 0.51GiB(0.64%) tps: 89,046 tflops: 6.37 mfu: 2.04%


[mismatched loss] Without the changes to handle device order in DCP, the loss of the non-AP model resumed from the step-10 checkpoint:

[rank0]:[titan] 2025-09-04 11:21:24,805 - root - INFO - Training starts at step 11

[rank0]:[titan] 2025-09-04 11:21:27,777 - root - INFO - step: 11 loss: 3.8366 grad_norm: 1.4309 memory: 0.66GiB(0.83%) tps: 1,627 tflops: 0.12 mfu: 0.04%
[rank0]:[titan] 2025-09-04 11:21:27,777 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-04 11:21:28,036 - root - INFO - step: 12 loss: 3.7414 grad_norm: 2.1150 memory: 0.66GiB(0.84%) tps: 31,662 tflops: 2.26 mfu: 0.73%
[rank0]:[titan] 2025-09-04 11:21:28,262 - root - INFO - step: 13 loss: 3.5467 grad_norm: 2.5674 memory: 0.66GiB(0.84%) tps: 36,418 tflops: 2.60 mfu: 0.83%
[rank0]:[titan] 2025-09-04 11:21:28,493 - root - INFO - step: 14 loss: 3.5658 grad_norm: 1.3706 memory: 0.66GiB(0.84%) tps: 35,440 tflops: 2.53 mfu: 0.81%
[rank0]:[titan] 2025-09-04 11:21:28,718 - root - INFO - step: 15 loss: 3.3821 grad_norm: 1.4201 memory: 0.66GiB(0.84%) tps: 36,624 tflops: 2.62 mfu: 0.84%
[rank0]:[titan] 2025-09-04 11:21:28,958 - root - INFO - step: 16 loss: 3.2567 grad_norm: 0.9934 memory: 0.66GiB(0.84%) tps: 34,191 tflops: 2.45 mfu: 0.78%
[rank0]:[titan] 2025-09-04 11:21:29,180 - root - INFO - step: 17 loss: 3.1408 grad_norm: 1.1183 memory: 0.66GiB(0.84%) tps: 37,020 tflops: 2.65 mfu: 0.85%
[rank0]:[titan] 2025-09-04 11:21:29,405 - root - INFO - step: 18 loss: 3.0618 grad_norm: 1.0499 memory: 0.66GiB(0.84%) tps: 36,534 tflops: 2.61 mfu: 0.84%
[rank0]:[titan] 2025-09-04 11:21:29,623 - root - INFO - step: 19 loss: 3.1537 grad_norm: 0.5060 memory: 0.66GiB(0.84%) tps: 37,660 tflops: 2.69 mfu: 0.86%
[rank0]:[titan] 2025-09-04 11:21:29,854 - root - INFO - step: 20 loss: 2.9996 grad_norm: 0.8434 memory: 0.66GiB(0.84%) tps: 35,553 tflops: 2.54 mfu: 0.81%


[expected loss] With the changes to handle device order in DCP, the loss of the non-AP model resumed from the step-10 checkpoint:

[rank0]:[titan] 2025-09-05 11:29:33,805 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds

[rank0]:[titan] 2025-09-05 11:29:33,805 - root - INFO - Finished loading the checkpoint in 2.82 seconds.
[rank0]:[titan] 2025-09-05 11:29:33,805 - root - INFO - Training starts at step 11
[rank0]:[titan] 2025-09-05 11:29:36,013 - root - INFO - step: 11 loss: 3.7421 grad_norm: 1.4967 memory: 0.66GiB(0.84%) tps: 1,495 tflops: 0.11 mfu: 0.03%
[rank0]:[titan] 2025-09-05 11:29:36,236 - root - INFO - step: 12 loss: 3.6703 grad_norm: 1.3112 memory: 0.66GiB(0.84%) tps: 36,758 tflops: 2.63 mfu: 0.84%
[rank0]:[titan] 2025-09-05 11:29:36,458 - root - INFO - step: 13 loss: 3.6030 grad_norm: 1.2107 memory: 0.66GiB(0.84%) tps: 37,046 tflops: 2.65 mfu: 0.85%
[rank0]:[titan] 2025-09-05 11:29:36,679 - root - INFO - step: 14 loss: 3.4914 grad_norm: 1.1602 memory: 0.66GiB(0.84%) tps: 37,225 tflops: 2.66 mfu: 0.85%
[rank0]:[titan] 2025-09-05 11:29:36,929 - root - INFO - step: 15 loss: 3.4761 grad_norm: 1.0157 memory: 0.66GiB(0.84%) tps: 32,771 tflops: 2.34 mfu: 0.75%
[rank0]:[titan] 2025-09-05 11:29:37,156 - root - INFO - step: 16 loss: 3.4595 grad_norm: 0.9246 memory: 0.66GiB(0.84%) tps: 36,257 tflops: 2.59 mfu: 0.83%
[rank0]:[titan] 2025-09-05 11:29:37,389 - root - INFO - step: 17 loss: 3.3148 grad_norm: 0.9257 memory: 0.66GiB(0.84%) tps: 35,129 tflops: 2.51 mfu: 0.81%
[rank0]:[titan] 2025-09-05 11:29:37,603 - root - INFO - step: 18 loss: 3.2908 grad_norm: 0.9189 memory: 0.66GiB(0.84%) tps: 38,427 tflops: 2.75 mfu: 0.88%
[rank0]:[titan] 2025-09-05 11:29:37,829 - root - INFO - step: 19 loss: 3.3948 grad_norm: 0.7784 memory: 0.66GiB(0.84%) tps: 36,436 tflops: 2.61 mfu: 0.84%
[rank0]:[titan] 2025-09-05 11:29:38,052 - root - INFO - step: 20 loss: 3.3514 grad_norm: 0.7942 memory: 0.66GiB(0.84%) tps: 36,773 tflops: 2.63 mfu: 0.84%

There is still a slight mismatch compared to the pure AP 20-step run. I think this is due to the mixed-precision dtype casts in the module forward in AP. The error should be negligible.

Test on the example script

Run:

python examples/example_dcp.py

The test runs the AP model and saves a checkpoint. Then a non-AP model without any sharding placement resumes from that checkpoint. We need to make sure the losses are consistent.
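
A minimal sketch of that consistency check (the helper and variable names below are illustrative, not the exact contents of examples/example_dcp.py):

import torch

def run_steps(model, optimizer, batches, loss_fn):
    # Train for a few steps and record the per-step loss values.
    losses = []
    for x, y in batches:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().float().item())
    return losses

# Reference: keep training the AP model past the checkpoint step.
ap_losses = run_steps(ap_model, ap_optimizer, resume_batches, loss_fn)

# Candidate: the plain model that loaded the AP checkpoint via DCP.
plain_losses = run_steps(plain_model, plain_optimizer, resume_batches, loss_fn)

# The two curves should agree up to numerical noise.
torch.testing.assert_close(plain_losses, ap_losses, rtol=1e-4, atol=1e-4)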

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Sep 5, 2025
@zpcore requested review from XilunWu, ezyang and fmassa on September 5, 2025
@zpcore (Contributor Author) commented Sep 5, 2025

I am unable to make DCP work without the changes in PyTorch upstream. Since this PR is a must to support DCP, we can merge this one first while waiting for upstream. We can either wait for the ghstack pytorch/pytorch#161775 (it still needs some changes but seems easier to get approved) or for pytorch/pytorch#162294 (which may take longer).

@zpcore (Contributor Author) commented Sep 5, 2025

Note that we need to update the aten.empty_like strategy to copy the device_order attribute, as in pytorch/pytorch#162294. Without that change, if we move the AP-generated model to a CUDA device with parallel_mod.to_empty(device="cuda"), the device order information is lost, resulting in an incorrect initial state.
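
A minimal sketch of the failure mode, assuming the device_order attribute on the DTensor spec proposed in pytorch/pytorch#162294 (the helper and the attribute access are illustrative; _spec is a private field used here only for inspection):

from torch.distributed.tensor import DTensor

def snapshot_sharding(mod):
    # Record placements plus the (assumed) device_order for every DTensor parameter.
    info = {}
    for name, p in mod.named_parameters():
        if isinstance(p, DTensor):
            spec = p._spec
            info[name] = (spec.placements, getattr(spec, "device_order", None))
    return info

before = snapshot_sharding(parallel_mod)
parallel_mod.to_empty(device="cuda")  # materializes parameters via aten.empty_like under the hood
after = snapshot_sharding(parallel_mod)
assert before == after, "device_order was dropped during to_empty()"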

@fmassa (Contributor) left a comment

Thanks for the PR!

I made a few comments which I think would simplify the PR a bit; let me know if I understood things right.

Also, I think it would be useful if we write a small test for this functionality.

My original thought was that we don't need to call into the whole AutoParallel constructor for the test; instead we can just try calling apply_sharding_to_model directly with a graph and the placements.

But tests could be added later, as we don't have any specific testing for this part yet.

@zpcore (Contributor Author) commented Sep 11, 2025

I added the test test_distributed_checkpoint.py. The test uses a fake process group to generate the AP model and its sharding_placement, which is then used for the multiprocess run on CUDA devices. This way we can get different sharding_placements with an arbitrary fake world size.

The test runs the AP model and saves a checkpoint. Then a non-AP model without any sharding placement resumes from that checkpoint. We need to make sure the losses are consistent.

Currently I disabled the assertion check; we need to land pytorch/pytorch#162294 first to support DCP with device order. Locally I have tested with that PR applied and the test passes.
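
A minimal sketch of the fake-process-group setup, assuming the plan is derived on a single process before the real multi-GPU run (the world size and wiring are illustrative):

import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# A fake backend lets a single process pretend to have a large world size,
# which is enough for AutoParallel to produce a non-trivial sharding_placement.
dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=256)

# ... build the AP model here and record its sharding_placement ...

dist.destroy_process_group()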

include:
  - name: 4xlargegpu
-   runs-on: linux.g5.4xlarge.nvidia.gpu
+   runs-on: linux.g5.12xlarge.nvidia.gpu
Contributor:

🤔

@zpcore (Contributor Author):

We need 4 GPUs to run the test and compare the loss. Looking forward to your single-device simulation :)



def main():
    fake_world_size = 256
Contributor:

This is 256 because otherwise you don't get a nontrivial plan? And @fmassa's plan was to let you specifically force a sharding order or something, right?

@zpcore (Contributor Author):

Right, this will trigger the reversed device order placement. I don't have a good solution for forcing a sharding order, because a specific graph pattern (not just fixing a single node's placement) is needed to trigger the redistribute call that adjusts the device ordering.


if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
    new_local_tensor = new_local_tensor.wait()

Contributor:

the bugfix that is being upstreamed, I guess

@zpcore (Contributor Author):

Sadly, I still see that loss.backward() for a DTensor module triggers the missing wait() warning, though that is not caused by this PR. I need to make a follow-up fix.

Anyway, we will remove redistribute_tensor.py from AutoParallel soon!

@ezyang (Contributor) left a comment

Approving to unblock

@zpcore merged commit 19e5306 into main Sep 12, 2025 (6 checks passed)
@zpcore deleted the piz/dcp_dev_order branch September 12, 2025 15:55