[Public release 26/04] Introducing EPv2: faster EP, and Engram/PP/CP support #605

Open
LyricZhao wants to merge 4 commits into main from epv2-release

@LyricZhao
Collaborator

As hardware, networking, and model architectures have evolved, DeepEP V1 accumulated too much legacy baggage and too many performance issues. Today, we are excited to introduce DeepEP V2. It includes a complete refactoring of Expert Parallelism that achieves extreme performance with several times fewer SM resources than V1, while supporting significantly larger scale-up and scale-out domains. It also adds experimental 0 SM Engram, 0 SM Pipeline Parallelism, and 0 SM Context Parallelism all-gather.

We are also happy to announce that we have switched from the NVSHMEM backend to the more lightweight NCCL Gin backend.


New Features

  • Fully JIT (Just-In-Time compilation)
  • NCCL Gin backend
    • Header-only & lightweight
    • Able to reuse existing NCCL communicators
  • EPv2
    • High-throughput and low-latency APIs unified into a single interface, with a new GEMM layout
    • Larger scale-up & scale-out domain support (up to EP2048)
    • Analytical SM & QP count calculation — no more auto-tuning needed
    • Both hybrid & direct modes remain supported
    • For V3-like legacy training, SM usage reduced from 24 to 4 - 6 while maintaining equivalent or better performance
  • 0 SM Engram (with RDMA)
  • 0 SM PP (with RDMA)
  • 0 SM CP (with Copy Engine)

Notes

  • Buffer memory consumption is larger than in V1
  • 0 SM RDMA low-latency EP is no longer supported
  • Engram, PP, and CP are experimental features

Still On-going Features

  • Elastic GPU & CPU buffers: A contiguous virtual address space that maps to a hybrid of GPU and CPU physical memory under the hood, enabling fully automatic and transparent Engram or imbalanced EP
  • Reducing intermediate buffer sizes by leveraging EP replay to handle load imbalance
  • All-gather updates and reduce-scatter implementations for DP & TP
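
To make the "elastic GPU & CPU buffers" idea concrete, here is a toy model of one contiguous address range whose pages are backed by different kinds of physical memory. It is only an illustration of the concept described in the first bullet: the class name `ToyElasticSpace`, the page size, and the `"gpu"`/`"cpu"` labels are all hypothetical and say nothing about how DeepEP implements or will implement this feature.

```python
from typing import List, Tuple

# Hypothetical page granularity, chosen only for illustration.
PAGE_SIZE = 2 * 1024 * 1024


class ToyElasticSpace:
    """Toy model: contiguous virtual offsets, mixed physical backing."""

    def __init__(self, num_gpu_pages: int, num_cpu_pages: int,
                 page_size: int = PAGE_SIZE) -> None:
        self.page_size = page_size
        # The first pages are backed by GPU memory and the rest by CPU
        # memory; callers still see a single contiguous range.
        self.backing: List[str] = ["gpu"] * num_gpu_pages + ["cpu"] * num_cpu_pages

    @property
    def size(self) -> int:
        """Total size of the virtual range in bytes."""
        return len(self.backing) * self.page_size

    def locate(self, offset: int) -> Tuple[str, int, int]:
        """Resolve a virtual offset to (backing kind, page index, offset in page)."""
        if not 0 <= offset < self.size:
            raise IndexError(f"offset {offset} outside [0, {self.size})")
        page, within = divmod(offset, self.page_size)
        return self.backing[page], page, within


space = ToyElasticSpace(num_gpu_pages=2, num_cpu_pages=1)
print(space.locate(0))              # resolves to a GPU-backed page
print(space.locate(2 * PAGE_SIZE))  # resolves to the CPU-backed page
```

The point of the sketch is that the caller only ever handles offsets into one range; which memory actually serves a page is a lookup detail hidden behind `locate`.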

Performance

Following V3's configuration, we tested with 8K tokens per batch, 7168 hidden dimensions, top 8 experts, FP8 dispatching, and BF16 combining, and obtained the following results:

| Arch | NIC type | Topo | Dispatch bottleneck bandwidth | Combine bottleneck bandwidth | #SMs |
|---|---|---|---|---|---|
| SM90 | CX7 | EP 8 x 2 | 90 GB/s (RDMA) | 81 GB/s (RDMA) | 12 |
| SM90 | CX7 | EP 8 x 4 | 61 GB/s (RDMA) | 61 GB/s (RDMA) | 6 |
| SM100 | CX7 | EP 8 x 2 | 90 GB/s (RDMA) | 91 GB/s (RDMA) | 12 |
| SM100 | N/A | EP 8 | 726 GB/s (NVLink) | 740 GB/s (NVLink) | 64 (Max perf) |
| SM100 | N/A | EP 8 | 643 GB/s (NVLink) | 675 GB/s (NVLink) | 24 (Min #SM) |

Note that the results are logical bandwidth. For example, in the EP 8 x 2 case, the 90 GB/s figure includes local-rank traffic.
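
As a rough illustration of what "logical" means here, the sketch below estimates how much of a logical RDMA figure would actually cross the NIC. It rests on two assumptions that are not stated in the PR: that routing is uniform, and that the "local" share is one out of the number of scale-out ranks (2 for EP 8 x 2, 4 for EP 8 x 4). The function name is hypothetical.

```python
def estimated_nic_bandwidth(logical_gb_per_s: float, num_scaleout_ranks: int) -> float:
    """Estimate the part of a logical bandwidth figure that leaves the node.

    Assumes uniform routing, so 1 / num_scaleout_ranks of the logical
    traffic is destined for the local rank and never touches the NIC.
    """
    if num_scaleout_ranks < 1:
        raise ValueError("num_scaleout_ranks must be at least 1")
    remote_fraction = (num_scaleout_ranks - 1) / num_scaleout_ranks
    return logical_gb_per_s * remote_fraction


# Dispatch figures from the SM90 rows of the table.
for topo, logical, ranks in [("EP 8 x 2", 90.0, 2), ("EP 8 x 4", 61.0, 4)]:
    nic = estimated_nic_bandwidth(logical, ranks)
    print(f"{topo}: {logical} GB/s logical, about {nic} GB/s on the NIC")
```

Under these assumptions both SM90 rows land at roughly 45 to 46 GB/s of NIC traffic, which would be close to the 50 GB/s line rate of a 400 Gb/s NIC. Treat this as a plausibility check rather than a statement of how the benchmark counts bytes.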

Compared with V1, V2 achieves up to 1.3x the peak performance while using up to 4x fewer SMs.

We omit results for larger EP configurations for the time being, but encourage interested users to benchmark them directly. Based on our internal experience, we expect the kernel to continue saturating hardware bandwidth at scale.
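
For a sense of scale, the benchmark configuration above can be turned into payload sizes. The sketch below is back-of-the-envelope arithmetic under stated assumptions: "8K tokens" is taken as 8192, FP8 is 1 byte and BF16 is 2 bytes per element, scaling factors and metadata are ignored, and `copies_per_token=8` is an upper bound in which every token is sent to 8 distinct destinations. The helper names are hypothetical.

```python
FP8_BYTES = 1   # bytes per element when dispatching in FP8
BF16_BYTES = 2  # bytes per element when combining in BF16


def payload_bytes(num_tokens: int, hidden: int, bytes_per_elem: int,
                  copies_per_token: int) -> int:
    """Bytes moved if each token is sent copies_per_token times."""
    return num_tokens * hidden * bytes_per_elem * copies_per_token


def transfer_ms(num_bytes: int, bandwidth_gb_per_s: float) -> float:
    """Time in milliseconds to move num_bytes at the given bandwidth (1 GB = 1e9 bytes)."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3


# Assumed reading of the configuration: 8192 tokens, hidden size 7168, top 8.
dispatch = payload_bytes(8192, 7168, FP8_BYTES, copies_per_token=8)
combine = payload_bytes(8192, 7168, BF16_BYTES, copies_per_token=8)

print(f"dispatch upper bound: {dispatch / 1e6:.1f} MB per rank")
print(f"combine upper bound:  {combine / 1e6:.1f} MB per rank")
print(f"dispatch at 90 GB/s:  {transfer_ms(dispatch, 90.0):.2f} ms")
```

Because tokens routed to several experts on the same rank may need to be sent only once, real traffic can be lower than this bound.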


Contributors

LyricZhao requested a review from sphish on April 23, 2026, 05:54
@alpha-baby
Contributor

The build failed on CUDA 12.8.

Dependencies:

nvidia-nccl-cu12                         2.30.4
nvidia-nvshmem-cu12                      3.5.19

export PATH=/usr/local/cuda/bin:$PATH
export EP_NVSHMEM_ROOT_DIR=/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem
export EP_NCCL_ROOT_DIR=/opt/conda/lib/python3.10/site-packages/nvidia/nccl
python setup.py bdist_wheel
/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Build summary:
 > Sources: ['csrc/python_api.cpp', 'csrc/kernels/legacy/layout.cu', 'csrc/kernels/legacy/intranode.cu', 'csrc/kernels/legacy/internode.cu', 'csrc/kernels/legacy/internode_ll.cu', 'csrc/kernels/backend/nvshmem.cu', 'csrc/kernels/backend/nccl.cu', 'csrc/kernels/backend/cuda_driver.cu']
 > Includes: ['/root/AntDeepEP-v2-bak/deep_ep/include', '/root/AntDeepEP-v2-bak/third-party/fmt/include', '/usr/local/cuda/include/cccl', '/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include', '/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include']
 > Libraries: ['/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/lib']
 > Compilation flags: {'cxx': ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes', '-DDISABLE_AGGRESSIVE_PTX_INSTRS'], 'nvcc': ['-O3', '-Xcompiler', '-O3', '--extended-lambda', '--diag-suppress=128,2417', '-rdc=true', '--ptxas-options=--register-usage-level=10', '-DDISABLE_AGGRESSIVE_PTX_INSTRS'], 'nvcc_dlink': ['-dlink', '-L/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/lib', '-lnvshmem_device']}
 > Link flags: ['-lcuda', '-l:libnvshmem_host.so', '-l:libnvshmem_device.a', '-Wl,-rpath,/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/lib', '-l:libnccl.so', '-Wl,-rpath,/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib']
 > Arch list: 9.0
 > NVSHMEM path: /opt/conda/lib/python3.10/site-packages/nvidia/nvshmem
 > NCCL path: /opt/conda/lib/python3.10/site-packages/nvidia/nccl
 > Persistent envs:
   > EP_NCCL_ROOT_DIR: /opt/conda/lib/python3.10/site-packages/nvidia/nccl

running bdist_wheel
running build
running build_py
copying deep_ep/__init__.py -> build/lib.linux-x86_64-cpython-310/deep_ep
creating build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/envs.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/refs.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/find_pkgs.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/gate.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/semantic.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/event.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/testing.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/__init__.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/math.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/comm.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
creating build/lib.linux-x86_64-cpython-310/deep_ep/buffers
copying deep_ep/buffers/elastic.py -> build/lib.linux-x86_64-cpython-310/deep_ep/buffers
copying deep_ep/buffers/__init__.py -> build/lib.linux-x86_64-cpython-310/deep_ep/buffers
copying deep_ep/buffers/legacy.py -> build/lib.linux-x86_64-cpython-310/deep_ep/buffers
creating build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/math.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/ptx.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/comm.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/compiled.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/handle.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/exception.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/layout.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
creating build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/hybrid_dispatch.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/combine_utils.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/combine.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/barrier.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/dispatch.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/hybrid_combine.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/dispatch_copy_epilogue.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/dispatch_deterministic_prologue.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/pp_send_recv.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/combine_reduce_epilogue.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/engram_fetch.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
running build_ext
W0423 20:44:59.975000 77927 site-packages/torch/utils/cpp_extension.py:531] There are no g++ version bounds defined for CUDA version 12.8
building 'deep_ep._C' extension
creating /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend
creating /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy
[1/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/layout.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/layout.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/layout.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[2/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/intranode.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/intranode.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/intranode.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[3/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nvshmem.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/backend/nvshmem.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nvshmem.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[4/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/cuda_driver.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/backend/cuda_driver.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/cuda_driver.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[5/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nccl.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/backend/nccl.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nccl.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[6/9] c++ -MMD -MF /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o.d -pthread -B /opt/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/python_api.cpp -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o -O3 -Wno-deprecated-declarations -Wno-unused-variable -Wno-sign-compare -Wno-reorder -Wno-attributes -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -std=c++17
FAILED: /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o 
c++ -MMD -MF /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o.d -pthread -B /opt/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/python_api.cpp -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o -O3 -Wno-deprecated-declarations -Wno-unused-variable -Wno-sign-compare -Wno-reorder -Wno-attributes -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -std=c++17
In file included from /usr/local/cuda/include/cuda/__ptx/instructions/barrier_cluster.h:26,
                 from /usr/local/cuda/include/cuda/ptx:72,
                 from /usr/local/cuda/include/cuda/barrier:24,
                 from /root/AntDeepEP-v2-bak/deep_ep/include/deep_ep/common/ptx.cuh:3,
                 from /root/AntDeepEP-v2-bak/deep_ep/include/deep_ep/common/layout.cuh:6,
                 from /root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:8,
                 from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘uint32_t cuda::ptx::__4::__as_ptr_smem(const void*)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:42:44: error: ‘__cvta_generic_to_shared’ was not declared in this scope
   42 |   return static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__ptr));
      |                                            ^~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘uint64_t cuda::ptx::__4::__as_ptr_gmem(const void*)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:62:44: error: ‘__cvta_generic_to_global’ was not declared in this scope
   62 |   return static_cast<_CUDA_VSTD::uint64_t>(__cvta_generic_to_global(__ptr));
      |                                            ^~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘_Tp* cuda::ptx::__4::__from_ptr_smem(size_t)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:74:33: error: there are no arguments to ‘__cvta_shared_to_generic’ that depend on a template parameter, so a declaration of ‘__cvta_shared_to_generic’ must be available [-fpermissive]
   74 |   return reinterpret_cast<_Tp*>(__cvta_shared_to_generic(__ptr));
      |                                 ^~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:74:33: note: (if you use ‘-fpermissive’, G++ will accept your code, but allowing the use of an undeclared name is deprecated)
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘_Tp* cuda::ptx::__4::__from_ptr_gmem(size_t)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:95:33: error: there are no arguments to ‘__cvta_global_to_generic’ that depend on a template parameter, so a declaration of ‘__cvta_global_to_generic’ must be available [-fpermissive]
   95 |   return reinterpret_cast<_Tp*>(__cvta_global_to_generic(__ptr));
      |                                 ^~~~~~~~~~~~~~~~~~~~~~~~
In file included from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp: In member function ‘std::function<at::Tensor()> deep_ep::elastic::ElasticBuffer::engram_fetch(const at::Tensor&, int) const’:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:255:20: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
  255 |         return [=, this]() {
      |                    ^~~~
In file included from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp: In member function ‘std::pair<std::vector<at::Tensor>, std::function<void()> > deep_ep::elastic::ElasticBuffer::all_gather(const std::vector<at::Tensor>&)’:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:461:27: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
  461 |         auto handle = [=, this]() {
      |                           ^~~~
In file included from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:8:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp: In member function ‘void deep_ep::legacy::Buffer::clean_low_latency_buffer(int, int, int)’:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp:1438:35: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
 1438 |         auto check_boundary = [=, this](void* ptr, size_t num_bytes) {
      |                                   ^~~~
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp: In member function ‘std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, std::optional<deep_ep::EventHandle>, std::optional<std::function<void()> > > deep_ep::legacy::Buffer::low_latency_dispatch(const at::Tensor&, const at::Tensor&, const std::optional<at::Tensor>&, const std::optional<at::Tensor>&, int, int, bool, bool, bool, bool, bool)’:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp:1545:29: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
 1545 |         auto launcher = [=, this](int phases) {
      |                             ^~~~
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp: In member function ‘std::tuple<at::Tensor, std::optional<deep_ep::EventHandle>, std::optional<std::function<void()> > > deep_ep::legacy::Buffer::low_latency_combine(const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const std::optional<at::Tensor>&, int, int, bool, bool, bool, bool, const std::optional<at::Tensor>&)’:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp:1668:29: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
 1668 |         auto launcher = [=, this](int phases) {
      |                             ^~~~
[7/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode_ll.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/internode_ll.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode_ll.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[8/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/internode.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2597, in _run_ninja_build
    subprocess.run(
  File "/opt/conda/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/AntDeepEP-v2-bak/setup.py", line 170, in <module>
    setuptools.setup(
  File "/opt/conda/lib/python3.10/site-packages/setuptools/__init__.py", line 115, in setup
    return distutils.core.setup(**attrs)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 186, in setup
    return run_commands(dist)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 202, in run_commands
    dist.run_commands()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1002, in run_commands
    self.run_command(cmd)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/dist.py", line 1102, in run_command
    super().run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1021, in run_command
    cmd_obj.run()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/command/bdist_wheel.py", line 370, in run
    self.run_command("build")
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 357, in run_command
    self.distribution.run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/dist.py", line 1102, in run_command
    super().run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1021, in run_command
    cmd_obj.run()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build.py", line 135, in run
    self.run_command(cmd_name)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 357, in run_command
    self.distribution.run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/dist.py", line 1102, in run_command
    super().run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1021, in run_command
    cmd_obj.run()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 96, in run
    _build_ext.run(self)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 368, in run
    self.build_extensions()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1082, in build_extensions
    build_ext.build_extensions(self)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 484, in build_extensions
    self._build_extensions_serial()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 510, in _build_extensions_serial
    self.build_extension(ext)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 261, in build_extension
    _build_ext.build_extension(self, ext)
  File "/opt/conda/lib/python3.10/site-packages/Cython/Distutils/build_ext.py", line 135, in build_extension
    super(build_ext, self).build_extension(ext)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 565, in build_extension
    objects = self.compiler.compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 866, in unix_wrap_ninja_compile
    _write_ninja_file_and_compile_objects(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2223, in _write_ninja_file_and_compile_objects
    _run_ninja_build(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2614, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error compiling objects for extension
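
In the log above, all four `error:` diagnostics come from `cuda/__ptx/ptx_helper_functions.h`, which is reached from `csrc/python_api.cpp` through `csrc/elastic/buffer.hpp`, `deep_ep/common/layout.cuh`, and `deep_ep/common/ptx.cuh`. They occur in step [6/9], the only step run with `c++` rather than `nvcc`. When triaging a log of this size, it can help to pull out just the error lines. The sketch below does that for GCC-style diagnostics; the function and type names are hypothetical and not part of DeepEP or PyTorch.

```python
import re
from typing import List, NamedTuple


class CompilerError(NamedTuple):
    path: str
    line: int
    column: int
    message: str


# Matches GCC-style diagnostics such as "file.h:42:44: error: message".
_ERROR_RE = re.compile(
    r"^(?P<path>\S+?):(?P<line>\d+):(?P<col>\d+): error: (?P<msg>.*)$"
)


def extract_errors(log_text: str) -> List[CompilerError]:
    """Return every 'error:' diagnostic in a build log, skipping warnings and notes."""
    errors = []
    for raw in log_text.splitlines():
        match = _ERROR_RE.match(raw.strip())
        if match:
            errors.append(CompilerError(
                path=match["path"],
                line=int(match["line"]),
                column=int(match["col"]),
                message=match["msg"],
            ))
    return errors


# Three lines copied from the log above: one error, its source echo, one warning.
SAMPLE = """\
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:42:44: error: ‘__cvta_generic_to_shared’ was not declared in this scope
   42 |   return static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__ptr));
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:255:20: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
"""

for err in extract_errors(SAMPLE):
    print(f"{err.path}:{err.line}: {err.message}")
```

Running it over the full log would surface the four `__cvta_*` errors without the surrounding compile commands and the setuptools traceback.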
