Tensor size error when using wan_14b_text_to_video_tensor_parallel on more than 2 GPUs #495

@meareabc

Description

When I use more than 2 GPUs (4 or 6), I get a tensor size error, but with 2 GPUs it works well. Is there a solution to this problem?

this is my data setting:

dataloader = torch.utils.data.DataLoader(
    ToyDataset([
        {
            "prompt": "....",
            "negative_prompt": "....",
            "num_inference_steps": 500,
            "seed": 0,
            "tiled": False,
            "height": 720,
            "width": 480,
            "output_path": "video_test1.mp4",
        },
    ]),
    collate_fn=lambda x: x,
    num_workers=64,
    pin_memory=True,
)

When running `CUDA_VISIBLE_DEVICES="4,5,6,7" python ./examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py`, the following error occurred:
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 166, in <module>
[rank2]: trainer.test(model, dataloader)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 775, in test
[rank2]: return call._call_and_handle_interrupt(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank2]: return trainer_fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 817, in _test_impl
[rank2]: results = self._run(model, ckpt_path=ckpt_path)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank2]: results = self._run_stage()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1049, in _run_stage
[rank2]: return self._evaluation_loop.run()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank2]: return loop_run(self, *args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
[rank2]: self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
[rank2]: output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank2]: output = fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 425, in test_step
[rank2]: return self.lightning_module.test_step(*args, **kwargs)
[rank2]: File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 120, in test_step
[rank2]: video = self.pipe(**data)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 286, in __call__
[rank2]: noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 407, in model_fn_wan_video
[rank2]: x = block(x, context, t_mod, freqs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]: return inner()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank2]: result = forward_call(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 216, in forward
[rank2]: x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]: return inner()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank2]: result = forward_call(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 141, in forward
[rank2]: q = rope_apply(q, freqs, self.num_heads)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 93, in rope_apply
[rank2]: x_out = torch.view_as_real(x_out * freqs).flatten(2)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank2]: return disable_fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank2]: return fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 346, in __torch_dispatch__
[rank2]: return DTensor._op_dispatcher.dispatch(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 170, in dispatch
[rank2]: self.sharding_propagator.propagate(op_info)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank2]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
[rank2]: return self.cache(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 219, in propagate_op_sharding_non_cached
[rank2]: out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 123, in _propagate_tensor_meta_non_cached
[rank2]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
[rank2]: return self._op(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank2]: return fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank2]: return self.dispatch(func, types, args, kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank2]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank2]: output = self._dispatch_impl(func, types, args, kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
[rank2]: r = func(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
[rank2]: return self._op(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank2]: result = fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 143, in _fn
[rank2]: result = fn(**bound.arguments)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1095, in _ref
[rank2]: a, b = _maybe_broadcast(a, b)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/__init__.py", line 437, in _maybe_broadcast
[rank2]: common_shape = _broadcast_shapes(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/__init__.py", line 425, in _broadcast_shapes
[rank2]: torch._check(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/__init__.py", line 1656, in _check
[rank2]: _check_with(RuntimeError, cond, message)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/__init__.py", line 1638, in _check_with
[rank2]: raise error_type(message_evaluated)
[rank2]: RuntimeError: Attempting to broadcast a dimension of length 28350 at -3! Mismatching argument at index 1 had torch.Size([28350, 1, 64]); but expected shape should be broadcastable to [1, 28352, 40, 64]
Testing DataLoader 0: 0%| | 0/3 [00:03<?, ?it/s]
[W327 12:46:40.142966460 ProcessGroup.cpp:266] Warning: At the time of process termination, there are still 1 unwaited collective calls. Please review your program to ensure that:

  1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,
  2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective called under `with allow_inflight_collective_as_graph_input_ctx():`,

before the output tensors of the collective are used. (function ~WorkRegistry)

(the same warning is then repeated by the three other ranks)

When running `CUDA_VISIBLE_DEVICES="4,5" python ./examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py`, it works well.
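For context, here is a hedged back-of-the-envelope check that reproduces both sequence lengths in the error message. It assumes Wan's 8x spatial VAE downsampling, 2x2 patchification, and the example's default 81 frames with 4x temporal compression (assumptions inferred from the model family, not confirmed by the log). Under those assumptions the token count is 28350, which is divisible by 2 (so 2 GPUs works) but not by 4, so padding it up to a multiple of the world size yields 28352 while the unpadded RoPE `freqs` tensor keeps length 28350, matching the broadcast mismatch.

```python
# Hypothetical reconstruction of the 28350 vs 28352 shapes in the error.
# Assumed factors: 8x VAE spatial downsampling, 2x2 patchify, 81 frames
# with 4x temporal compression. These are assumptions for illustration.
height, width, num_frames = 720, 480, 81

latent_frames = (num_frames - 1) // 4 + 1                   # temporal compression
tokens_per_frame = (height // 8 // 2) * (width // 8 // 2)   # spatial grid after patchify
seq_len = latent_frames * tokens_per_frame

world_size = 4
padded = -(-seq_len // world_size) * world_size             # round up to a multiple of world_size

print(seq_len, padded)            # 28350 28352
print(seq_len % 2, seq_len % 4)   # 0 2 -> even split on 2 GPUs, remainder on 4
```

If this reading is right, picking a resolution/frame count whose token count is divisible by the GPU count (or padding `freqs` the same way the hidden states are padded) would be the direction to look in.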
