From 6b632c7c18540ff4335876775975bfe1403ea2c4 Mon Sep 17 00:00:00 2001 From: Lux Industries Date: Sat, 16 May 2026 12:03:06 -0700 Subject: [PATCH 1/5] cmake: add find_package(lux-gpu-kernels) discovery hook The proprietary high-performance GPU kernels (Metal/CUDA/WGSL) for crypto schemes, AIVM, and FHE have been extracted into the private repository lux-private/gpu-kernels (commercial license required). This commit adds a non-breaking discovery hook to the top-level CMakeLists.txt: - find_package(lux-gpu-kernels CONFIG QUIET) probes for the private install via CMAKE_PREFIX_PATH. - When found, the in-tree gpu/ subdirs are symlinked from the install root, allowing per-scheme CMakeLists to compile the real GPU drivers from the canonical private source. - When not found, CRYPTO_ENABLE_{CUDA,METAL,WGSL} are forced OFF and the build proceeds with the in-tree CPU implementations unchanged. The cevm-genesis-parity test passes byte-equal in this mode (state root + genesis hash match). In-tree gpu/ subdirs remain in this commit to keep historical consumers building unchanged; a follow-up will gate per-scheme add_library(... gpu/...) calls behind CRYPTO_ENABLE_* and delete the in-tree kernels. Commercial license inquiries: licensing@lux.network --- CMakeLists.txt | 48 +++++++++++++++++++++++++++++++++++++++++++++++- README.md | 15 ++++++++++----- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b56a97..1310d3c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,52 @@ include(LuxAlgorithm) # or (currently bn254, modexp). add_subdirectory(deps) +# ----------------------------------------------------------------------------- +# lux-gpu-kernels (proprietary) discovery +# ----------------------------------------------------------------------------- +# Public luxcpp/crypto ships CPU implementations only. The GPU kernel sources +# (Metal/CUDA/WGSL) live in the proprietary `lux-private/gpu-kernels` repo. 
+# When the private repo is installed (CMAKE_PREFIX_PATH points at it), this +# block symlinks each scheme's gpu/ directory into the source tree so +# the existing per-scheme CMakeLists works unmodified. Without it, all GPU +# drivers are disabled and the build is CPU-only. +find_package(lux-gpu-kernels CONFIG QUIET) +if(lux-gpu-kernels_FOUND) + message(STATUS "lux-gpu-kernels: FOUND (${lux-gpu-kernels_DIR}) — GPU drivers enabled") + # ${lux-gpu-kernels_DIR} == <prefix>/lib/cmake/lux-gpu-kernels → step up 3 levels to <prefix>/. + get_filename_component(_LUX_GPU_KERNELS_PREFIX "${lux-gpu-kernels_DIR}/../../.." ABSOLUTE) + set(_LUX_GPU_KERNELS_INCLUDE "${_LUX_GPU_KERNELS_PREFIX}/include/lux-gpu-kernels") + file(GLOB _LUX_GPU_SCHEMES LIST_DIRECTORIES true RELATIVE "${_LUX_GPU_KERNELS_INCLUDE}/crypto" + "${_LUX_GPU_KERNELS_INCLUDE}/crypto/*") + foreach(_alg IN LISTS _LUX_GPU_SCHEMES) + set(_alg_priv "${_LUX_GPU_KERNELS_INCLUDE}/crypto/${_alg}") + if(IS_DIRECTORY "${_alg_priv}" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${_alg}") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/${_alg}/gpu") + foreach(_backend cuda metal wgsl) + if(IS_DIRECTORY "${_alg_priv}/${_backend}" + AND NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${_alg}/gpu/${_backend}") + file(CREATE_LINK "${_alg_priv}/${_backend}" + "${CMAKE_CURRENT_SOURCE_DIR}/${_alg}/gpu/${_backend}" + SYMBOLIC) + endif() + endforeach() + endif() + endforeach() + # math/ntt/cuda lives at a non-scheme path; treat as a standalone symlink. 
+ if(IS_DIRECTORY "${_LUX_GPU_KERNELS_INCLUDE}/crypto/math/ntt/cuda" + AND NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/math/ntt/cuda") + file(CREATE_LINK + "${_LUX_GPU_KERNELS_INCLUDE}/crypto/math/ntt/cuda" + "${CMAKE_CURRENT_SOURCE_DIR}/math/ntt/cuda" + SYMBOLIC) + endif() +else() + message(STATUS "lux-gpu-kernels: NOT FOUND — CPU-only build (Metal/CUDA/WGSL drivers disabled)") + set(CRYPTO_ENABLE_CUDA OFF CACHE BOOL "Build CUDA drivers" FORCE) + set(CRYPTO_ENABLE_METAL OFF CACHE BOOL "Build Metal drivers" FORCE) + set(CRYPTO_ENABLE_WGSL OFF CACHE BOOL "Build WGSL drivers" FORCE) +endif() + option(CRYPTO_ENABLE_CUDA "Build CUDA drivers" OFF) option(CRYPTO_ENABLE_METAL "Build Metal drivers" OFF) option(CRYPTO_BUILD_TESTS "Build crypto tests" ON) @@ -73,7 +119,7 @@ option(CRYPTO_BUILD_TESTS "Build crypto tests" ON) option(CRYPTO_RINGTAIL_DKG_EXPERIMENTAL "Expose ringtail_dkg_* C-ABI (UNSAFE — see RED-DKG-REVIEW.md)" OFF) -if(APPLE) +if(APPLE AND lux-gpu-kernels_FOUND) set(CRYPTO_ENABLE_METAL ON CACHE BOOL "Build Metal drivers" FORCE) endif() diff --git a/README.md b/README.md index 482093b..86fb96f 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,17 @@ # luxcpp/crypto Canonical native cryptographic primitives for the Lux / Hanzo / Zoo -ecosystem. CPU + GPU implementations live here and only here. - -* CPU: portable C++17 / C++20, no third-party crypto libraries -* GPU: CUDA, Metal, WGSL kernels with byte-equal CPU↔GPU output +ecosystem. Public surface ships CPU implementations only. + +* CPU: portable C++17 / C++20, no third-party crypto libraries (this repo) +* GPU: CUDA, Metal, WGSL kernels — **proprietary**, distributed via + `lux-private/gpu-kernels`. Public CMake build auto-detects via + `find_package(lux-gpu-kernels CONFIG QUIET)` and disables Metal/CUDA/WGSL + drivers when absent (CPU-only fallback). 
Commercial license: + licensing@lux.network * C ABI: `include/lux/crypto/<alg>.h`, callable from Go (cgo) and - Rust (bindgen / lux-crypto-sys) + Rust (bindgen / lux-crypto-sys) — identical surface whether the GPU + kernels are linked or not The Go entry point is `github.com/luxfi/crypto`; the GPU device router is `github.com/luxfi/accel`. From ae575c73fce1d9fa8e0ee59727d9a134146f2b72 Mon Sep 17 00:00:00 2001 From: Hanzo AI Date: Sat, 16 May 2026 15:29:25 -0700 Subject: [PATCH 2/5] cmake: per-scheme CRYPTO_ENABLE_<BACKEND> gating so gpu/ can be deleted MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The find_package(lux-gpu-kernels) hook (6b632c7c) symlinks each scheme's gpu/<backend>/ from lux-private at configure time and force-disables CRYPTO_ENABLE_{CUDA,METAL,WGSL} when the private repo is absent. That made the gates exist as cache vars but every per-scheme CMakeLists still unconditionally referenced gpu/<backend>/<file>.{cpp,cu,wgsl} via add_library / file(READ) — so deleting the in-tree gpu/ subtrees would break configure + build even with the option flipped off. This commit wraps every gpu-touching add_library + file(READ) + set(...) block in its scheme's per-backend CRYPTO_ENABLE_* guard: aead, blake2b, bls (stages 1-4 + combined-miller + cuda-stub + wgsl-driver), cggmp21, frost, gpukit (cuda driver + multi_pippenger cuda source), lamport, math (lattice_ring_cuda), ntt (large-N cuda/metal/wgsl), ripemd160, secp256k1 (batch_inv + ecrecover cuda + wgsl), sha256, plus fhe (host CUDA shims + stub for lattice_ring_cuda_* C ABI when lux-private absent). 
Top-level CMakeLists tests gated parallel to the lib gates: sha256_cuda_test, sha256_wgpu_test, frost_presign_test, cggmp21_presign_test, ripemd160_cuda_test, ripemd160_wgpu_test, bn254_gpu_determinism_test + bn254_pairing_consts codegen, modexp_karatsuba_gpu_test, kzg_gpu_determinism_test, ringtail_lattice_ring_bench, ringtail_lattice_ring_sweep_bench (needed METAL+WGSL because source statically references both), pedersen_tree_{metal,cuda,wgpu}_determinism_test, batch_inv_{cuda,wgsl}_test, banderwagon_cuda_determinism_test, banderwagon_wgsl_determinism_test, ntt_large_test, gpukit tests (all-backends-required because the source statically resolves gpukit_*_{metal,cuda,wgsl}). Drive-by fixes to the umbrella crypto target loop: - skip "math" (it's a substrate with math_codec/math_modarith/... — no umbrella math target ever existed, link line was emitting -lmath) - if(TARGET ${_alg}) so the loop tolerates lazy/absent targets New file fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp provides CPU-only stub bodies for the six extern C lattice_ring_cuda_* symbols the FHE host dispatcher references unconditionally. Available() returns 0 so the dispatcher routes to the CPU oracle; every NTT/MUL entry point returns -1 NOTIMPL so any accidental call surfaces immediately. CPU-only configure + build: PASS (crypto static + all linkable tests). 
--- CMakeLists.txt | 582 +++++++++--------- aead/CMakeLists.txt | 124 ++-- blake2b/CMakeLists.txt | 31 +- bls/CMakeLists.txt | 62 +- cggmp21/CMakeLists.txt | 24 +- fhe/CMakeLists.txt | 33 +- .../backends/cuda/lattice_ring_cuda_stub.cpp | 58 ++ frost/CMakeLists.txt | 15 +- gpukit/CMakeLists.txt | 39 +- lamport/CMakeLists.txt | 22 +- math/CMakeLists.txt | 43 +- ntt/CMakeLists.txt | 46 +- ripemd160/CMakeLists.txt | 62 +- secp256k1/CMakeLists.txt | 74 +-- sha256/CMakeLists.txt | 62 +- 15 files changed, 651 insertions(+), 626 deletions(-) create mode 100644 fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 1310d3c..2c7d905 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -268,7 +268,15 @@ target_include_directories(crypto target_compile_features(crypto PUBLIC cxx_std_20) set_target_properties(crypto PROPERTIES POSITION_INDEPENDENT_CODE ON) foreach(_alg ${CRYPTO_ALGS}) - target_link_libraries(crypto PUBLIC ${_alg}) + # `math` is a substrate that exposes math_codec / math_modarith / math_poly + # / math_ntt / math_sample — there is no umbrella `math` target. Skip it + # here; consumers link the specific math_* lib(s) they need. + if(_alg STREQUAL "math") + continue() + endif() + if(TARGET ${_alg}) + target_link_libraries(crypto PUBLIC ${_alg}) + endif() endforeach() add_library(lux::crypto ALIAS crypto) @@ -335,39 +343,34 @@ if(CRYPTO_BUILD_TESTS) add_test(NAME secp256k1_ecrecover_pipeline_test COMMAND secp256k1_ecrecover_pipeline_test) - # secp256k1 batch_inv CPU<->CUDA determinism. On non-CUDA hosts the - # driver returns NOTIMPL and the test self-skips via CRYPTO_HAS_CUDA. + # secp256k1 batch_inv CPU<->CUDA determinism. Gated on CRYPTO_ENABLE_CUDA + # — the driver lives in lux-private/gpu-kernels. 
if(CRYPTO_ENABLE_CUDA) add_executable(batch_inv_cuda_test secp256k1/test/batch_inv_cuda_test.cu) set_target_properties(batch_inv_cuda_test PROPERTIES CUDA_ARCHITECTURES "75" ) - else() - set_source_files_properties(secp256k1/test/batch_inv_cuda_test.cu - PROPERTIES LANGUAGE CXX) - add_executable(batch_inv_cuda_test - secp256k1/test/batch_inv_cuda_test.cu) - target_compile_features(batch_inv_cuda_test PRIVATE cxx_std_20) + target_link_libraries(batch_inv_cuda_test PRIVATE + secp256k1_cpu secp256k1_batch_inv_cuda) + target_include_directories(batch_inv_cuda_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp + ) + add_test(NAME batch_inv_cuda_test COMMAND batch_inv_cuda_test) + endif() + + # secp256k1 batch_inv CPU<->WGSL determinism. Gated on CRYPTO_ENABLE_WGSL. + if(CRYPTO_ENABLE_WGSL) + add_executable(batch_inv_wgsl_test + secp256k1/test/batch_inv_wgsl_test.cpp) + target_compile_features(batch_inv_wgsl_test PRIVATE cxx_std_20) + target_link_libraries(batch_inv_wgsl_test PRIVATE + secp256k1_cpu secp256k1_batch_inv_wgsl) + target_include_directories(batch_inv_wgsl_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp + ) + add_test(NAME batch_inv_wgsl_test COMMAND batch_inv_wgsl_test) endif() - target_link_libraries(batch_inv_cuda_test PRIVATE - secp256k1_cpu secp256k1_batch_inv_cuda) - target_include_directories(batch_inv_cuda_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp - ) - add_test(NAME batch_inv_cuda_test COMMAND batch_inv_cuda_test) - - # secp256k1 batch_inv CPU<->WGSL determinism. Self-skips when - # CRYPTO_HAS_WGSL / CRYPTO_SECP256K1_BATCH_INV_WGSL unset. 
- add_executable(batch_inv_wgsl_test - secp256k1/test/batch_inv_wgsl_test.cpp) - target_compile_features(batch_inv_wgsl_test PRIVATE cxx_std_20) - target_link_libraries(batch_inv_wgsl_test PRIVATE - secp256k1_cpu secp256k1_batch_inv_wgsl) - target_include_directories(batch_inv_wgsl_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp - ) - add_test(NAME batch_inv_wgsl_test COMMAND batch_inv_wgsl_test) add_executable(attestation_test attestation/test/attestation_test.cpp) target_link_libraries(attestation_test PRIVATE attestation_cpu keccak_cpu) @@ -392,50 +395,52 @@ if(CRYPTO_BUILD_TESTS) add_test(NAME sha256_test COMMAND sha256_test) # SHA-256 CPU↔CUDA byte-equality (100/100 vectors covering all - # padding edges). Stub returns 0 / "not available" on Apple/non-CUDA - # hosts so the test still passes structurally. - add_executable(sha256_cuda_test sha256/test/sha256_cuda_test.cpp) - target_link_libraries(sha256_cuda_test PRIVATE sha256_cpu sha256_batch_cuda) - target_include_directories(sha256_cuda_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/sha256/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/sha256/gpu/cuda - ) - add_test(NAME sha256_cuda_test COMMAND sha256_cuda_test) + # padding edges). Gated on CRYPTO_ENABLE_CUDA — driver lives in + # lux-private/gpu-kernels. + if(CRYPTO_ENABLE_CUDA) + add_executable(sha256_cuda_test sha256/test/sha256_cuda_test.cpp) + target_link_libraries(sha256_cuda_test PRIVATE sha256_cpu sha256_batch_cuda) + target_include_directories(sha256_cuda_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/sha256/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sha256/gpu/cuda + ) + add_test(NAME sha256_cuda_test COMMAND sha256_cuda_test) + endif() - # SHA-256 CPU↔WGSL byte-equality. Stub returns 0 / "not available" - # when the wgpu-native runtime is missing. 
- add_executable(sha256_wgpu_test sha256/test/sha256_wgpu_test.cpp) - target_link_libraries(sha256_wgpu_test PRIVATE sha256_cpu sha256_batch_wgpu) - target_include_directories(sha256_wgpu_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/sha256/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/sha256/gpu/wgsl - ) - add_test(NAME sha256_wgpu_test COMMAND sha256_wgpu_test) - - # FROST batched pre-signing kernel: CPU canonical body byte-equal to - # the CUDA host polyfill across 100 random batches and a deterministic - # KAT (zero-seed). Covers M=10 signers x N=64 slots = 640 commitments - # plus N=100 random shapes. - add_executable(frost_presign_test frost/test/frost_presign_test.cpp) - target_link_libraries(frost_presign_test PRIVATE frost_cpu frost_cuda) - target_include_directories(frost_presign_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/frost/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp - ) - add_test(NAME frost_presign_test COMMAND frost_presign_test) - - # CGGMP21 batched pre-signing kernel: secp256k1 portion (R = k*G) - # wired and byte-equal to the CUDA host polyfill across 100 random - # batches. Paillier ciphertext + ZK proof are reserved bytes - # (status=0xFF) until the 2048-bit Karatsuba modexp primitive lands; - # the test asserts the wire layout is frozen. - add_executable(cggmp21_presign_test cggmp21/test/cggmp21_presign_test.cpp) - target_link_libraries(cggmp21_presign_test PRIVATE cggmp21_cpu cggmp21_cuda) - target_include_directories(cggmp21_presign_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/cggmp21/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp - ) - add_test(NAME cggmp21_presign_test COMMAND cggmp21_presign_test) + # SHA-256 CPU↔WGSL byte-equality. Gated on CRYPTO_ENABLE_WGSL. 
+ if(CRYPTO_ENABLE_WGSL) + add_executable(sha256_wgpu_test sha256/test/sha256_wgpu_test.cpp) + target_link_libraries(sha256_wgpu_test PRIVATE sha256_cpu sha256_batch_wgpu) + target_include_directories(sha256_wgpu_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/sha256/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sha256/gpu/wgsl + ) + add_test(NAME sha256_wgpu_test COMMAND sha256_wgpu_test) + endif() + + # FROST batched pre-signing kernel CPU↔CUDA byte-equality. Gated on + # CRYPTO_ENABLE_CUDA — kernel polyfill lives in lux-private/gpu-kernels. + if(CRYPTO_ENABLE_CUDA) + add_executable(frost_presign_test frost/test/frost_presign_test.cpp) + target_link_libraries(frost_presign_test PRIVATE frost_cpu frost_cuda) + target_include_directories(frost_presign_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/frost/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp + ) + add_test(NAME frost_presign_test COMMAND frost_presign_test) + endif() + + # CGGMP21 batched pre-signing kernel CPU↔CUDA byte-equality. Gated on + # CRYPTO_ENABLE_CUDA — kernel polyfill lives in lux-private/gpu-kernels. + if(CRYPTO_ENABLE_CUDA) + add_executable(cggmp21_presign_test cggmp21/test/cggmp21_presign_test.cpp) + target_link_libraries(cggmp21_presign_test PRIVATE cggmp21_cpu cggmp21_cuda) + target_include_directories(cggmp21_presign_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/cggmp21/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/secp256k1/cpp + ) + add_test(NAME cggmp21_presign_test COMMAND cggmp21_presign_test) + endif() # Paillier 2048-bit KAT — keygen (1024-bit p,q via MR-40) + 8 encrypt/ # decrypt round-trips + Π^enc accept + Π^enc reject (tampered K) + @@ -458,28 +463,30 @@ if(CRYPTO_BUILD_TESTS) target_link_libraries(ripemd160_test PRIVATE ripemd160) add_test(NAME ripemd160_test COMMAND ripemd160_test) - # RIPEMD-160 CPU↔CUDA byte-equality (100/100 vectors covering all - # padding edges). Stub returns 0 / "not available" on Apple/non-CUDA - # hosts so the test still passes structurally. 
- add_executable(ripemd160_cuda_test ripemd160/test/ripemd160_cuda_test.cpp) - target_link_libraries(ripemd160_cuda_test PRIVATE - ripemd160_cpu ripemd160_batch_cuda) - target_include_directories(ripemd160_cuda_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/gpu/cuda - ) - add_test(NAME ripemd160_cuda_test COMMAND ripemd160_cuda_test) - - # RIPEMD-160 CPU↔WGSL byte-equality. Stub returns 0 / "not available" - # when the wgpu-native runtime is missing. - add_executable(ripemd160_wgpu_test ripemd160/test/ripemd160_wgpu_test.cpp) - target_link_libraries(ripemd160_wgpu_test PRIVATE - ripemd160_cpu ripemd160_batch_wgpu) - target_include_directories(ripemd160_wgpu_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/gpu/wgsl - ) - add_test(NAME ripemd160_wgpu_test COMMAND ripemd160_wgpu_test) + # RIPEMD-160 CPU↔CUDA byte-equality. Gated on CRYPTO_ENABLE_CUDA — + # driver lives in lux-private/gpu-kernels. + if(CRYPTO_ENABLE_CUDA) + add_executable(ripemd160_cuda_test ripemd160/test/ripemd160_cuda_test.cpp) + target_link_libraries(ripemd160_cuda_test PRIVATE + ripemd160_cpu ripemd160_batch_cuda) + target_include_directories(ripemd160_cuda_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/gpu/cuda + ) + add_test(NAME ripemd160_cuda_test COMMAND ripemd160_cuda_test) + endif() + + # RIPEMD-160 CPU↔WGSL byte-equality. Gated on CRYPTO_ENABLE_WGSL. + if(CRYPTO_ENABLE_WGSL) + add_executable(ripemd160_wgpu_test ripemd160/test/ripemd160_wgpu_test.cpp) + target_link_libraries(ripemd160_wgpu_test PRIVATE + ripemd160_cpu ripemd160_batch_wgpu) + target_include_directories(ripemd160_wgpu_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ripemd160/gpu/wgsl + ) + add_test(NAME ripemd160_wgpu_test COMMAND ripemd160_wgpu_test) + endif() # BLAKE2b (RFC 7693): wired CPU implementation, RFC test vectors. 
add_executable(blake2b_test blake2b/test/blake2b_test.cpp) @@ -493,37 +500,38 @@ if(CRYPTO_BUILD_TESTS) # bn254 GPU determinism harness: dispatches the same vectors through CUDA # and WebGPU drivers and asserts byte-equality with the CPU oracle. - # Without a real device the drivers run the CPU oracle directly so the - # wire format is still exercised end-to-end. Covers Fp_mul, G1Add, G1Mul, - # SVDW, Fp2_mul, Fp12_mul, MillerIter (cyclo-sqr^100), pairing. - add_executable(bn254_gpu_determinism_test - bn254/test/bn254_gpu_determinism_test.cpp - bn254/gpu/cuda/bn254_driver_cuda.cpp - bn254/gpu/wgsl/bn254_driver_wgpu.cpp) - target_link_libraries(bn254_gpu_determinism_test PRIVATE bn254_cpu) - target_include_directories(bn254_gpu_determinism_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bn254/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bn254 - ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/wgsl) - add_test(NAME bn254_gpu_determinism_test COMMAND bn254_gpu_determinism_test) - - # Codegen: bn254 GPU pairing constants emitted from CPU body. - # Produces bn254_pairing_consts_{cuda.cuh,wgsl.wgslh} so CPU is the single - # producer of Frobenius coefficients across both backends. 
- add_executable(bn254_gen_pairing_constants - bn254/test/tools/gen_pairing_constants.cpp) - set(_bn_consts_cuda ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh) - set(_bn_consts_wgsl ${CMAKE_CURRENT_BINARY_DIR}/bn254_pairing_consts_wgsl.wgslh) - set(_bn_wgsl_kernel ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/wgsl/bn254.wgsl) - add_custom_command( - OUTPUT ${_bn_consts_cuda} ${_bn_consts_wgsl} - COMMAND bn254_gen_pairing_constants ${_bn_consts_cuda} ${_bn_consts_wgsl} ${_bn_wgsl_kernel} - DEPENDS bn254_gen_pairing_constants ${_bn_wgsl_kernel} - COMMENT "Regenerating bn254 pairing GPU constants from CPU body (and verifying inlined WGSL kernel constants)") - add_custom_target(bn254_pairing_consts ALL - DEPENDS ${_bn_consts_cuda} ${_bn_consts_wgsl}) - add_dependencies(bn254_gpu_determinism_test bn254_pairing_consts) + # Gated on CRYPTO_ENABLE_CUDA OR CRYPTO_ENABLE_WGSL — the driver sources + # live in lux-private/gpu-kernels (symlinked into bn254/gpu/{cuda,wgsl}/). + if(CRYPTO_ENABLE_CUDA OR CRYPTO_ENABLE_WGSL) + add_executable(bn254_gpu_determinism_test + bn254/test/bn254_gpu_determinism_test.cpp + bn254/gpu/cuda/bn254_driver_cuda.cpp + bn254/gpu/wgsl/bn254_driver_wgpu.cpp) + target_link_libraries(bn254_gpu_determinism_test PRIVATE bn254_cpu) + target_include_directories(bn254_gpu_determinism_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bn254/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/bn254 + ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/wgsl) + add_test(NAME bn254_gpu_determinism_test COMMAND bn254_gpu_determinism_test) + + # Codegen: bn254 GPU pairing constants emitted from CPU body. + # Produces bn254_pairing_consts_{cuda.cuh,wgsl.wgslh} so CPU is the + # single producer of Frobenius coefficients across both backends. 
+ add_executable(bn254_gen_pairing_constants + bn254/test/tools/gen_pairing_constants.cpp) + set(_bn_consts_cuda ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh) + set(_bn_consts_wgsl ${CMAKE_CURRENT_BINARY_DIR}/bn254_pairing_consts_wgsl.wgslh) + set(_bn_wgsl_kernel ${CMAKE_CURRENT_SOURCE_DIR}/bn254/gpu/wgsl/bn254.wgsl) + add_custom_command( + OUTPUT ${_bn_consts_cuda} ${_bn_consts_wgsl} + COMMAND bn254_gen_pairing_constants ${_bn_consts_cuda} ${_bn_consts_wgsl} ${_bn_wgsl_kernel} + DEPENDS bn254_gen_pairing_constants ${_bn_wgsl_kernel} + COMMENT "Regenerating bn254 pairing GPU constants from CPU body (and verifying inlined WGSL kernel constants)") + add_custom_target(bn254_pairing_consts ALL + DEPENDS ${_bn_consts_cuda} ${_bn_consts_wgsl}) + add_dependencies(bn254_gpu_determinism_test bn254_pairing_consts) + endif() # modexp (EIP-198) + evm256_{add,mul}mod KAT. add_executable(modexp_kat_test modexp/test/modexp_kat_test.cpp) @@ -548,42 +556,38 @@ if(CRYPTO_BUILD_TESTS) target_include_directories(modexp_karatsuba_bench PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/modexp/cpp) - # modexp Karatsuba GPU determinism test. Compiles the CUDA kernel .cu as - # plain C++ (host polyfill) and verifies byte-equality against the CPU - # oracle at 16/32/64 limbs (1024/2048/4096 bit). When CRYPTO_ENABLE_CUDA - # is ON the same .cu compiles as a real CUDA TU; the test runs against - # the host-side polyfill body so the wire format is exercised end-to-end - # without requiring an NVIDIA device in CI. 
- add_executable(modexp_karatsuba_gpu_test - modexp/test/modexp_karatsuba_gpu_test.cpp - modexp/gpu/cuda/modexp_karatsuba.cu) - set_source_files_properties(modexp/gpu/cuda/modexp_karatsuba.cu - PROPERTIES LANGUAGE CXX) - target_link_libraries(modexp_karatsuba_gpu_test PRIVATE modexp_cpu) - target_include_directories(modexp_karatsuba_gpu_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/modexp/cpp) - target_compile_features(modexp_karatsuba_gpu_test PRIVATE cxx_std_20) - add_test(NAME modexp_karatsuba_gpu_test - COMMAND modexp_karatsuba_gpu_test) - - # kzg (EIP-4844) GPU determinism: 100 deterministic blobs per op - # (blob_to_kzg_commitment, compute_blob_kzg_proof, verify_kzg_proof) per - # backend (cuda, wgpu). On hosts without a real device the driver runs - # the CPU oracle path so the round-trip is exercised end-to-end and the - # harness reports 100/100. On CI with hardware the same vectors flow - # through the actual kernel and byte-equality is asserted. Reuses BLS12- - # 381 G1+Fp arithmetic patterns from bls/gpu/{cuda,wgsl}/. Optional - # EIP-4844 KAT block runs when CRYPTO_KZG_KAT_DIR is set. - add_executable(kzg_gpu_determinism_test - kzg/test/kzg_gpu_determinism_test.cpp - kzg/gpu/cuda/kzg_driver_cuda.cpp - kzg/gpu/wgsl/kzg_driver_wgpu.cpp) - target_include_directories(kzg_gpu_determinism_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/kzg/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/kzg - ${CMAKE_CURRENT_SOURCE_DIR}/bls/gpu/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/bls/gpu/wgsl) - add_test(NAME kzg_gpu_determinism_test COMMAND kzg_gpu_determinism_test) + # modexp Karatsuba GPU determinism test. Gated on CRYPTO_ENABLE_CUDA — + # the .cu source lives in lux-private/gpu-kernels (symlinked into + # modexp/gpu/cuda/). 
+ if(CRYPTO_ENABLE_CUDA) + add_executable(modexp_karatsuba_gpu_test + modexp/test/modexp_karatsuba_gpu_test.cpp + modexp/gpu/cuda/modexp_karatsuba.cu) + set_source_files_properties(modexp/gpu/cuda/modexp_karatsuba.cu + PROPERTIES LANGUAGE CXX) + target_link_libraries(modexp_karatsuba_gpu_test PRIVATE modexp_cpu) + target_include_directories(modexp_karatsuba_gpu_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/modexp/cpp) + target_compile_features(modexp_karatsuba_gpu_test PRIVATE cxx_std_20) + add_test(NAME modexp_karatsuba_gpu_test + COMMAND modexp_karatsuba_gpu_test) + endif() + + # kzg (EIP-4844) GPU determinism: 100 deterministic blobs per op. + # Gated on CRYPTO_ENABLE_CUDA OR CRYPTO_ENABLE_WGSL — driver sources + # live in lux-private/gpu-kernels (symlinked into kzg/gpu/{cuda,wgsl}/). + if(CRYPTO_ENABLE_CUDA OR CRYPTO_ENABLE_WGSL) + add_executable(kzg_gpu_determinism_test + kzg/test/kzg_gpu_determinism_test.cpp + kzg/gpu/cuda/kzg_driver_cuda.cpp + kzg/gpu/wgsl/kzg_driver_wgpu.cpp) + target_include_directories(kzg_gpu_determinism_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/kzg/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kzg + ${CMAKE_CURRENT_SOURCE_DIR}/bls/gpu/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/bls/gpu/wgsl) + add_test(NAME kzg_gpu_determinism_test COMMAND kzg_gpu_determinism_test) + endif() # IPA (Bulletproofs over Banderwagon): 5 valid prover/verifier KAT # roundtrips at N=256 (in-domain + out-of-domain) byte-equal vs @@ -1074,17 +1078,25 @@ if(CRYPTO_BUILD_TESTS) target_compile_options(ringtail_blake_bench PRIVATE -O3) # CPU vs GPU NTT throughput benchmark for canonical Ringtail M1 - # (Q=0x1000000004A01, N=256). Compares lattice_ring scalar CPU against - # Metal + WGSL backends on the same montgomery_ntt.json constants. - # Informational; not a ctest assertion target. 
- add_executable(ringtail_lattice_ring_bench ringtail/test/lattice_ring_bench.cpp) - target_link_libraries(ringtail_lattice_ring_bench PRIVATE - lattice_ring_cpu) - target_include_directories(ringtail_lattice_ring_bench PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/gpu/metal - ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/gpu/wgsl) - target_compile_options(ringtail_lattice_ring_bench PRIVATE -O3) + # (Q=0x1000000004A01, N=256). The bench source unconditionally calls + # the Metal + WGSL entry points, so it only builds when both backends + # are enabled. Informational; not a ctest assertion target. + if(CRYPTO_ENABLE_METAL AND CRYPTO_ENABLE_WGSL) + add_executable(ringtail_lattice_ring_bench ringtail/test/lattice_ring_bench.cpp) + target_link_libraries(ringtail_lattice_ring_bench PRIVATE + lattice_ring_cpu) + target_include_directories(ringtail_lattice_ring_bench PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/math/ntt/metal + ${CMAKE_CURRENT_SOURCE_DIR}/math/ntt/wgsl) + target_compile_options(ringtail_lattice_ring_bench PRIVATE -O3) + if(TARGET lattice_ring_metal) + target_link_libraries(ringtail_lattice_ring_bench PRIVATE lattice_ring_metal) + endif() + if(TARGET lattice_ring_wgpu) + target_link_libraries(ringtail_lattice_ring_bench PRIVATE lattice_ring_wgpu) + endif() + endif() # End-to-end Sign+Verify throughput bench across CPU / Metal / WGSL. # Backend selected via CRYPTO_BACKEND env var; defaults to CPU. 
@@ -1096,35 +1108,26 @@ if(CRYPTO_BUILD_TESTS) target_include_directories(ringtail_sign_throughput_bench PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/cpp) target_compile_options(ringtail_sign_throughput_bench PRIVATE -O3) - if(APPLE AND CRYPTO_ENABLE_METAL AND TARGET lattice_ring_metal) - target_link_libraries(ringtail_lattice_ring_bench PRIVATE - lattice_ring_metal) - endif() - if(CRYPTO_ENABLE_WGSL AND TARGET lattice_ring_wgpu) - target_link_libraries(ringtail_lattice_ring_bench PRIVATE - lattice_ring_wgpu) - endif() # Extended sweep bench: 10 batch sizes from 1 to 65536, TSV output for # the Blockchain Scaling Laws paper. Same backends as - # ringtail_lattice_ring_bench but spans a wider range to find the - # crossover and the killer batch. - add_executable(ringtail_lattice_ring_sweep_bench - ringtail/test/lattice_ring_sweep_bench.cpp) - target_link_libraries(ringtail_lattice_ring_sweep_bench PRIVATE - lattice_ring_cpu) - target_include_directories(ringtail_lattice_ring_sweep_bench PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/gpu/metal - ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/gpu/wgsl) - target_compile_options(ringtail_lattice_ring_sweep_bench PRIVATE -O3) - if(APPLE AND CRYPTO_ENABLE_METAL AND TARGET lattice_ring_metal) + # ringtail_lattice_ring_bench. Built only when both GPU backends are on. 
+ if(CRYPTO_ENABLE_METAL AND CRYPTO_ENABLE_WGSL) + add_executable(ringtail_lattice_ring_sweep_bench + ringtail/test/lattice_ring_sweep_bench.cpp) target_link_libraries(ringtail_lattice_ring_sweep_bench PRIVATE - lattice_ring_metal) - endif() - if(CRYPTO_ENABLE_WGSL AND TARGET lattice_ring_wgpu) - target_link_libraries(ringtail_lattice_ring_sweep_bench PRIVATE - lattice_ring_wgpu) + lattice_ring_cpu) + target_include_directories(ringtail_lattice_ring_sweep_bench PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/ringtail/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/math/ntt/metal + ${CMAKE_CURRENT_SOURCE_DIR}/math/ntt/wgsl) + target_compile_options(ringtail_lattice_ring_sweep_bench PRIVATE -O3) + if(TARGET lattice_ring_metal) + target_link_libraries(ringtail_lattice_ring_sweep_bench PRIVATE lattice_ring_metal) + endif() + if(TARGET lattice_ring_wgpu) + target_link_libraries(ringtail_lattice_ring_sweep_bench PRIVATE lattice_ring_wgpu) + endif() endif() # banderwagon MSM doc-presence regression guard: asserts the variable-time @@ -1153,7 +1156,7 @@ if(CRYPTO_BUILD_TESTS) # Per-backend determinism tests for the tree-reduce kernel. Each backend # registers unconditionally; the tests skip with success if their runtime # isn't available so CTest still surfaces the lane on every host. 
- if(APPLE) + if(APPLE AND CRYPTO_ENABLE_METAL AND TARGET pedersen_metal) add_executable(pedersen_tree_metal_determinism_test pedersen/test/pedersen_tree_metal_determinism_test.mm) target_link_libraries(pedersen_tree_metal_determinism_test PRIVATE @@ -1172,35 +1175,39 @@ if(CRYPTO_BUILD_TESTS) TIMEOUT 1800) endif() - add_executable(pedersen_tree_cuda_determinism_test - pedersen/test/pedersen_tree_cuda_determinism_test.cpp) - target_link_libraries(pedersen_tree_cuda_determinism_test PRIVATE - pedersen_cpu pedersen pedersen_cuda) - target_include_directories(pedersen_tree_cuda_determinism_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bn254/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/gpu/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/c-abi - ) - add_test(NAME pedersen_tree_cuda_determinism_test - COMMAND pedersen_tree_cuda_determinism_test) - set_tests_properties(pedersen_tree_cuda_determinism_test PROPERTIES - TIMEOUT 1800) - - add_executable(pedersen_tree_wgpu_determinism_test - pedersen/test/pedersen_tree_wgpu_determinism_test.cpp) - target_link_libraries(pedersen_tree_wgpu_determinism_test PRIVATE - pedersen_cpu pedersen pedersen_wgpu) - target_include_directories(pedersen_tree_wgpu_determinism_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/bn254/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/gpu/wgsl - ${CMAKE_CURRENT_SOURCE_DIR}/c-abi - ) - add_test(NAME pedersen_tree_wgpu_determinism_test - COMMAND pedersen_tree_wgpu_determinism_test) - set_tests_properties(pedersen_tree_wgpu_determinism_test PROPERTIES - TIMEOUT 1800) + if(CRYPTO_ENABLE_CUDA AND TARGET pedersen_cuda) + add_executable(pedersen_tree_cuda_determinism_test + pedersen/test/pedersen_tree_cuda_determinism_test.cpp) + target_link_libraries(pedersen_tree_cuda_determinism_test PRIVATE + pedersen_cpu pedersen pedersen_cuda) + target_include_directories(pedersen_tree_cuda_determinism_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/bn254/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/gpu/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/c-abi + ) + add_test(NAME pedersen_tree_cuda_determinism_test + COMMAND pedersen_tree_cuda_determinism_test) + set_tests_properties(pedersen_tree_cuda_determinism_test PROPERTIES + TIMEOUT 1800) + endif() + + if(CRYPTO_ENABLE_WGSL AND TARGET pedersen_wgpu) + add_executable(pedersen_tree_wgpu_determinism_test + pedersen/test/pedersen_tree_wgpu_determinism_test.cpp) + target_link_libraries(pedersen_tree_wgpu_determinism_test PRIVATE + pedersen_cpu pedersen pedersen_wgpu) + target_include_directories(pedersen_tree_wgpu_determinism_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/bn254/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/pedersen/gpu/wgsl + ${CMAKE_CURRENT_SOURCE_DIR}/c-abi + ) + add_test(NAME pedersen_tree_wgpu_determinism_test + COMMAND pedersen_tree_wgpu_determinism_test) + set_tests_properties(pedersen_tree_wgpu_determinism_test PROPERTIES + TIMEOUT 1800) + endif() if(CRYPTO_ENABLE_METAL AND APPLE) find_program(XCRUN xcrun) @@ -1517,10 +1524,9 @@ if(CRYPTO_BUILD_TESTS) DEPENDS ${_bw_const_metalh} ${_bw_const_cuh} ${_bw_const_wgslh}) # ========================================================================= - # Banderwagon CUDA backend. nvcc-compiled when CRYPTO_ENABLE_CUDA=ON; - # otherwise the same .cu file compiles as plain C++ via the host polyfill. - # Either path runs the identical kernel body and yields byte-equal output - # to the CPU oracle by construction. + # Banderwagon CUDA backend. Built only when CRYPTO_ENABLE_CUDA=ON + # (lux-gpu-kernels found + NVCC). The .cu source lives in + # lux-private/gpu-kernels and is symlinked into banderwagon/gpu/cuda/. 
# ========================================================================= if(CRYPTO_ENABLE_CUDA) add_library(banderwagon_cuda STATIC @@ -1530,76 +1536,72 @@ if(CRYPTO_BUILD_TESTS) set_target_properties(banderwagon_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_SEPARABLE_COMPILATION ON) - else() - add_library(banderwagon_cuda STATIC - banderwagon/gpu/cuda/banderwagon.cu) - set_source_files_properties(banderwagon/gpu/cuda/banderwagon.cu - PROPERTIES LANGUAGE CXX) - target_compile_features(banderwagon_cuda PUBLIC cxx_std_20) - set_target_properties(banderwagon_cuda PROPERTIES - POSITION_INDEPENDENT_CODE ON) + target_include_directories(banderwagon_cuda + PUBLIC + $ + PRIVATE + $) + add_dependencies(banderwagon_cuda banderwagon_gpu_constants) + add_library(lux::banderwagon_cuda ALIAS banderwagon_cuda) + + add_executable(banderwagon_cuda_determinism_test + banderwagon/test/banderwagon_cuda_determinism_test.cpp) + target_link_libraries(banderwagon_cuda_determinism_test + PRIVATE banderwagon_cpu banderwagon_cuda) + target_include_directories(banderwagon_cuda_determinism_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/gpu/cuda) + add_test(NAME banderwagon_cuda_determinism_test + COMMAND banderwagon_cuda_determinism_test) endif() - target_include_directories(banderwagon_cuda - PUBLIC - $ - PRIVATE - $) - add_dependencies(banderwagon_cuda banderwagon_gpu_constants) - add_library(lux::banderwagon_cuda ALIAS banderwagon_cuda) - - add_executable(banderwagon_cuda_determinism_test - banderwagon/test/banderwagon_cuda_determinism_test.cpp) - target_link_libraries(banderwagon_cuda_determinism_test - PRIVATE banderwagon_cpu banderwagon_cuda) - target_include_directories(banderwagon_cuda_determinism_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/gpu/cuda) - add_test(NAME banderwagon_cuda_determinism_test - COMMAND banderwagon_cuda_determinism_test) # 
========================================================================= - # Banderwagon WGSL backend. Host polyfill mirrors the .wgsl kernel byte- - # for-byte using only u32 ops (WGSL has no native u64). Byte-equal to the - # CPU oracle by construction; no wgpu runtime needed in CI. + # Banderwagon WGSL backend. Built only when CRYPTO_ENABLE_WGSL=ON + # (lux-gpu-kernels found). The driver source lives in + # lux-private/gpu-kernels and is symlinked into banderwagon/gpu/wgsl/. # ========================================================================= - add_library(banderwagon_wgsl STATIC - banderwagon/gpu/wgsl/banderwagon_driver.cpp) - target_include_directories(banderwagon_wgsl - PUBLIC - $ - PRIVATE - $) - target_compile_features(banderwagon_wgsl PUBLIC cxx_std_20) - set_target_properties(banderwagon_wgsl PROPERTIES POSITION_INDEPENDENT_CODE ON) - add_dependencies(banderwagon_wgsl banderwagon_gpu_constants) - add_library(lux::banderwagon_wgsl ALIAS banderwagon_wgsl) - - add_executable(banderwagon_wgsl_determinism_test - banderwagon/test/banderwagon_wgsl_determinism_test.cpp) - target_link_libraries(banderwagon_wgsl_determinism_test - PRIVATE banderwagon_cpu banderwagon_wgsl) - target_include_directories(banderwagon_wgsl_determinism_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/gpu/wgsl) - add_test(NAME banderwagon_wgsl_determinism_test - COMMAND banderwagon_wgsl_determinism_test) + if(CRYPTO_ENABLE_WGSL) + add_library(banderwagon_wgsl STATIC + banderwagon/gpu/wgsl/banderwagon_driver.cpp) + target_include_directories(banderwagon_wgsl + PUBLIC + $ + PRIVATE + $) + target_compile_features(banderwagon_wgsl PUBLIC cxx_std_20) + set_target_properties(banderwagon_wgsl PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_dependencies(banderwagon_wgsl banderwagon_gpu_constants) + add_library(lux::banderwagon_wgsl ALIAS banderwagon_wgsl) + + add_executable(banderwagon_wgsl_determinism_test + 
banderwagon/test/banderwagon_wgsl_determinism_test.cpp) + target_link_libraries(banderwagon_wgsl_determinism_test + PRIVATE banderwagon_cpu banderwagon_wgsl) + target_include_directories(banderwagon_wgsl_determinism_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/banderwagon/gpu/wgsl) + add_test(NAME banderwagon_wgsl_determinism_test + COMMAND banderwagon_wgsl_determinism_test) + endif() # ========================================================================= # NTT large-N (six-step) round-trip + cross-backend byte-equality. # Covers N in {2^17, 2^18, 2^19, 2^20} for both Cyclone-FFT prime - # (Q = 119*2^23 + 1) and TFHE q = 2^64. CPU oracle plus CUDA/Metal/WGSL - # drivers (currently CPU-fallback on hosts without devices). + # (Q = 119*2^23 + 1) and TFHE q = 2^64. Gated on all three GPU backends + # being enabled — the host drivers live in lux-private/gpu-kernels. # ========================================================================= - add_executable(ntt_large_test ntt/test/ntt_large_test.cpp) - target_link_libraries(ntt_large_test PRIVATE - ntt_cpu - ntt_large_gpu_cuda - ntt_large_gpu_metal - ntt_large_gpu_wgsl) - target_include_directories(ntt_large_test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/ntt/cpp) - add_test(NAME ntt_large_test COMMAND ntt_large_test) - set_tests_properties(ntt_large_test PROPERTIES TIMEOUT 600) + if(CRYPTO_ENABLE_CUDA AND CRYPTO_ENABLE_METAL AND CRYPTO_ENABLE_WGSL) + add_executable(ntt_large_test ntt/test/ntt_large_test.cpp) + target_link_libraries(ntt_large_test PRIVATE + ntt_cpu + ntt_large_gpu_cuda + ntt_large_gpu_metal + ntt_large_gpu_wgsl) + target_include_directories(ntt_large_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/ntt/cpp) + add_test(NAME ntt_large_test COMMAND ntt_large_test) + set_tests_properties(ntt_large_test PROPERTIES TIMEOUT 600) + endif() # ========================================================================= # Per-algorithm KAT tests (LP-178 audit a90c7a8a52ef12c96). 
diff --git a/aead/CMakeLists.txt b/aead/CMakeLists.txt index 01bf1c6..c09ac9e 100644 --- a/aead/CMakeLists.txt +++ b/aead/CMakeLists.txt @@ -39,28 +39,21 @@ endif() # ============================================================================= # CUDA drivers + determinism test. # -# The .cu files compile only when CRYPTO_ENABLE_CUDA is ON (NVCC required). -# The host driver compiles on every host: with LUX_AEAD_HAVE_CUDA defined it -# launches the kernels; without, it returns -1 from every entry point so the -# determinism test prints a "build-only" banner on Apple. +# Built only when CRYPTO_ENABLE_CUDA=ON (lux-gpu-kernels found + NVCC). The +# host driver and .cu kernels live in lux-private/gpu-kernels; consumer tests +# at the top-level CMakeLists are gated on the same flag. # ============================================================================= - -# Always build the host driver (stub or real). On Apple this gives us the -# build-only path (compiles, links, returns -1). On Linux+CUDA CI we set -# CRYPTO_ENABLE_CUDA=ON which flips the driver into real-dispatch mode -# and additionally compiles the .cu kernels. 
- -add_library(aead_cuda_driver STATIC gpu/cuda/aead_driver_cuda.cpp) -target_include_directories(aead_cuda_driver - PUBLIC - $ - $ -) -target_compile_features(aead_cuda_driver PUBLIC cxx_std_20) -set_target_properties(aead_cuda_driver PROPERTIES POSITION_INDEPENDENT_CODE ON) - if(CRYPTO_ENABLE_CUDA) enable_language(CUDA) + add_library(aead_cuda_driver STATIC gpu/cuda/aead_driver_cuda.cpp) + target_include_directories(aead_cuda_driver + PUBLIC + $ + $ + ) + target_compile_features(aead_cuda_driver PUBLIC cxx_std_20) + set_target_properties(aead_cuda_driver PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_library(aead_cuda_kernels STATIC gpu/cuda/chacha20_poly1305.cu gpu/cuda/aes_gcm.cu) @@ -76,60 +69,53 @@ if(CRYPTO_ENABLE_CUDA) target_link_libraries(aead_cuda_driver PUBLIC CUDA::cudart) endif() -# Test executables are added at the top-level CMakeLists.txt where -# enable_testing() has been called before us. The library targets above -# are sufficient for the parent to wire test binaries. - # ============================================================================= # WGSL driver (compiles WGSL via wgpu-native or Dawn; outputs bound through -# C ABI). On Apple, byte-equality is enforced if wgpu-native is on the system; -# otherwise the test self-skips. On Linux/CI, same behavior with Dawn. +# C ABI). Built only when CRYPTO_ENABLE_WGSL=ON. # ============================================================================= +if(CRYPTO_ENABLE_WGSL) + # Concatenate WGSL sources into a header that the host driver embeds. + set(AEAD_WGSL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl") + set(AEAD_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/aead_wgsl_sources.h") + file(WRITE "${AEAD_WGSL_HEADER}" "// Auto-generated. 
Do not edit.\n#pragma once\n\n") + foreach(_SRC chacha20_poly1305.wgsl aes_gcm.wgsl) + get_filename_component(_NAME ${_SRC} NAME_WE) + string(REGEX REPLACE "^chacha20_poly1305$" "ChaCha20Poly1305" _SHORT "${_NAME}") + string(REGEX REPLACE "^aes_gcm$" "AesGcm" _SHORT "${_SHORT}") + file(READ "${AEAD_WGSL_DIR}/${_SRC}" _CONTENT) + file(APPEND "${AEAD_WGSL_HEADER}" + "// ---- ${_SRC} ----\n" + "static constexpr char kAEAD_WGSL_${_SHORT}[] = R\"AEADWGSL(\n${_CONTENT}\n)AEADWGSL\";\n\n") + endforeach() -# Concatenate WGSL sources into a header that the host driver embeds. -set(AEAD_WGSL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl") -set(AEAD_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/aead_wgsl_sources.h") -file(WRITE "${AEAD_WGSL_HEADER}" "// Auto-generated. Do not edit.\n#pragma once\n\n") -foreach(_SRC chacha20_poly1305.wgsl aes_gcm.wgsl) - get_filename_component(_NAME ${_SRC} NAME_WE) - string(REGEX REPLACE "^chacha20_poly1305$" "ChaCha20Poly1305" _SHORT "${_NAME}") - string(REGEX REPLACE "^aes_gcm$" "AesGcm" _SHORT "${_SHORT}") - file(READ "${AEAD_WGSL_DIR}/${_SRC}" _CONTENT) - file(APPEND "${AEAD_WGSL_HEADER}" - "// ---- ${_SRC} ----\n" - "static constexpr char kAEAD_WGSL_${_SHORT}[] = R\"AEADWGSL(\n${_CONTENT}\n)AEADWGSL\";\n\n") -endforeach() - -# Locate wgpu-native (Homebrew) or Dawn. -set(_AEAD_WGPU_FOUND FALSE) -find_path(_AEAD_WGPU_INCLUDE webgpu.h - HINTS /opt/homebrew/include /usr/local/include /usr/include) -find_library(_AEAD_WGPU_LIB NAMES wgpu_native wgpu - HINTS /opt/homebrew/lib /usr/local/lib /usr/lib) -if(_AEAD_WGPU_INCLUDE AND _AEAD_WGPU_LIB) - set(_AEAD_WGPU_FOUND TRUE) -endif() + # Locate wgpu-native (Homebrew) or Dawn. 
+ set(_AEAD_WGPU_FOUND FALSE) + find_path(_AEAD_WGPU_INCLUDE webgpu.h + HINTS /opt/homebrew/include /usr/local/include /usr/include) + find_library(_AEAD_WGPU_LIB NAMES wgpu_native wgpu + HINTS /opt/homebrew/lib /usr/local/lib /usr/lib) + if(_AEAD_WGPU_INCLUDE AND _AEAD_WGPU_LIB) + set(_AEAD_WGPU_FOUND TRUE) + endif() -add_library(aead_wgpu_driver STATIC gpu/wgsl/aead_driver_wgpu.cpp) -target_include_directories(aead_wgpu_driver - PRIVATE - "${CMAKE_CURRENT_BINARY_DIR}" - "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl" - "${CMAKE_CURRENT_SOURCE_DIR}/cpp" -) -target_compile_features(aead_wgpu_driver PUBLIC cxx_std_20) -set_target_properties(aead_wgpu_driver PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_library(aead_wgpu_driver STATIC gpu/wgsl/aead_driver_wgpu.cpp) + target_include_directories(aead_wgpu_driver + PRIVATE + "${CMAKE_CURRENT_BINARY_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp" + ) + target_compile_features(aead_wgpu_driver PUBLIC cxx_std_20) + set_target_properties(aead_wgpu_driver PROPERTIES POSITION_INDEPENDENT_CODE ON) -if(_AEAD_WGPU_FOUND) - target_include_directories(aead_wgpu_driver PRIVATE ${_AEAD_WGPU_INCLUDE}) - target_link_libraries(aead_wgpu_driver PUBLIC ${_AEAD_WGPU_LIB}) - target_compile_definitions(aead_wgpu_driver PRIVATE - LUX_AEAD_HAS_WEBGPU=1 - LUX_AEAD_HAS_WGPU_NATIVE=1) - message(STATUS "[aead-wgsl] wgpu-native: ${_AEAD_WGPU_LIB}") -else() - message(STATUS "[aead-wgsl] wgpu-native: NOT FOUND (driver compiles to stub)") + if(_AEAD_WGPU_FOUND) + target_include_directories(aead_wgpu_driver PRIVATE ${_AEAD_WGPU_INCLUDE}) + target_link_libraries(aead_wgpu_driver PUBLIC ${_AEAD_WGPU_LIB}) + target_compile_definitions(aead_wgpu_driver PRIVATE + LUX_AEAD_HAS_WEBGPU=1 + LUX_AEAD_HAS_WGPU_NATIVE=1) + message(STATUS "[aead-wgsl] wgpu-native: ${_AEAD_WGPU_LIB}") + else() + message(STATUS "[aead-wgsl] wgpu-native: NOT FOUND (driver compiles to stub)") + endif() endif() - -# See note above: aead_wgsl_determinism_test 
is added at the top-level -# CMakeLists.txt after enable_testing() has run. diff --git a/blake2b/CMakeLists.txt b/blake2b/CMakeLists.txt index 0d1c200..010ef10 100644 --- a/blake2b/CMakeLists.txt +++ b/blake2b/CMakeLists.txt @@ -24,30 +24,21 @@ if(APPLE AND CRYPTO_ENABLE_METAL) ) endif() -# Batched BLAKE2b CUDA kernel + host-emulation entry point. -# When CRYPTO_ENABLE_CUDA=ON the .cu file is compiled by nvcc with the -# real __global__ kernel exposed as blake2b_jobs. When CUDA is off (default -# CPU-only build) the same .cu file compiles as host C++ via the -# __CUDA_ARCH__ shim and exposes blake2b_batch_cuda_host for the -# determinism test, which is byte-equal to the device kernel by construction. +# Batched BLAKE2b CUDA kernel + host-emulation entry point. Built only when +# CRYPTO_ENABLE_CUDA=ON (lux-gpu-kernels found + NVCC). The .cu source lives +# in lux-private/gpu-kernels and is symlinked into gpu/cuda/ at configure time. if(CRYPTO_ENABLE_CUDA) add_library(blake2b_batch_cuda STATIC gpu/cuda/blake2b.cu) set_target_properties(blake2b_batch_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_SEPARABLE_COMPILATION ON) -else() - add_library(blake2b_batch_cuda STATIC gpu/cuda/blake2b.cu) - set_source_files_properties(gpu/cuda/blake2b.cu PROPERTIES LANGUAGE CXX) - target_compile_features(blake2b_batch_cuda PUBLIC cxx_std_20) - set_target_properties(blake2b_batch_cuda PROPERTIES - POSITION_INDEPENDENT_CODE ON) endif() -# Batched BLAKE2b WGSL host-emulation entry point. The .wgsl kernel is -# shipped alongside; the host emulator mirrors it byte-for-byte (same -# vec2-emulated u64 math) so determinism testing succeeds on systems -# without a wgpu runtime. -add_library(blake2b_batch_wgsl STATIC gpu/wgsl/blake2b_wgsl_host.cpp) -target_compile_features(blake2b_batch_wgsl PUBLIC cxx_std_20) -set_target_properties(blake2b_batch_wgsl PROPERTIES - POSITION_INDEPENDENT_CODE ON) +# Batched BLAKE2b WGSL host-emulation entry point. 
Built only when +# CRYPTO_ENABLE_WGSL=ON (lux-gpu-kernels found + wgpu-native present). +if(CRYPTO_ENABLE_WGSL) + add_library(blake2b_batch_wgsl STATIC gpu/wgsl/blake2b_wgsl_host.cpp) + target_compile_features(blake2b_batch_wgsl PUBLIC cxx_std_20) + set_target_properties(blake2b_batch_wgsl PROPERTIES + POSITION_INDEPENDENT_CODE ON) +endif() diff --git a/bls/CMakeLists.txt b/bls/CMakeLists.txt index 70c4c5b..b86ecfb 100644 --- a/bls/CMakeLists.txt +++ b/bls/CMakeLists.txt @@ -68,7 +68,10 @@ if(APPLE) COMMENT "Generating Fp-tower test vectors via blst") add_custom_target(bls_fp_tower_vectors DEPENDS "${BLS_VEC_HEADER}") - # 2) Metal library compilation + # 2) Metal library compilation. Gated on CRYPTO_ENABLE_METAL — + # the .metal sources live in lux-private/gpu-kernels and are + # symlinked into gpu/metal/ at configure time. + if(CRYPTO_ENABLE_METAL) set(BLS_METAL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gpu/metal") set(BLS_METAL_FILES "${BLS_METAL_DIR}/bls_fp2.metal" @@ -356,6 +359,7 @@ if(APPLE) message(STATUS "[bls-stage3] enabled (Metal)") message(STATUS "[bls-stage3] final_exp metallib: ${BLS_FE_METALLIB}") message(STATUS "[bls-stage3] pairing metallib: ${BLS_PAIR_METALLIB}") + endif() # CRYPTO_ENABLE_METAL (stages 1-3 Metal scaffolding) # ========================================================= # Stage 4 — CUDA + WGSL parity ports of Stages 1-3. @@ -364,7 +368,7 @@ if(APPLE) # * subgroup rejection vectors via blst predicate. 
# ========================================================= - # ---- Subgroup rejection vector oracle + test ---- + # ---- Subgroup rejection vector oracle + test (CPU + blst) ---- add_executable(bls_subgroup_oracle test/bls_subgroup_oracle.cpp) target_link_libraries(bls_subgroup_oracle PRIVATE bls_stage1_blst) target_compile_options(bls_subgroup_oracle PRIVATE -O3) @@ -386,7 +390,10 @@ if(APPLE) add_test(NAME bls-subgroup-test COMMAND bls_subgroup_test) add_custom_target(bls-subgroup-test DEPENDS bls_subgroup_test bls_subgroup_vectors) - # ---- WGSL Fp-tower test (byte-equal to CPU oracle on Apple) ---- + # ---- WGSL Fp-tower test (byte-equal to CPU oracle on Apple). + # Gated on CRYPTO_ENABLE_WGSL — the .wgsl sources live in + # lux-private/gpu-kernels and are symlinked into gpu/wgsl/. ---- + if(CRYPTO_ENABLE_WGSL) set(BLS_WGSL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl") set(BLS_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/bls_wgsl_sources.h") file(WRITE ${BLS_WGSL_HEADER} "// Auto-generated. Do not edit.\n#pragma once\n\n") @@ -442,22 +449,21 @@ if(APPLE) add_test(NAME bls-fp-tower-wgsl-test COMMAND bls_fp_tower_wgsl_test) add_custom_target(bls-fp-tower-wgsl-test DEPENDS bls_fp_tower_wgsl_test bls_fp_tower_vectors) + endif() # CRYPTO_ENABLE_WGSL (stage 4 WGSL) - # ---- CUDA stub driver (build-only on Apple; gated for Linux+CUDA CI) ---- - # nvcc not available on Apple; we compile only the host driver in stub - # mode and link it to a smoke test that confirms the stub paths return - # -1. The Linux+CUDA CI runner sets BLS_HAVE_CUDA and recompiles - # the .cu files via nvcc. + # ---- CUDA stub driver. Gated on CRYPTO_ENABLE_CUDA — the .cpp + # host shim lives in lux-private/gpu-kernels and is symlinked into + # gpu/cuda/. The Linux+CUDA CI lane sets BLS_HAVE_CUDA and + # recompiles the .cu files via nvcc. 
---- + if(CRYPTO_ENABLE_CUDA) add_library(bls_cuda_stub STATIC gpu/cuda/bls_driver_cuda.cpp) target_include_directories(bls_cuda_stub PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda") target_compile_options(bls_cuda_stub PRIVATE -O3) target_compile_features(bls_cuda_stub PRIVATE cxx_std_17) + endif() # CRYPTO_ENABLE_CUDA (stage 4 CUDA stub) - message(STATUS "[bls-stage4] enabled") - message(STATUS "[bls-stage4] subgroup vectors: ${BLS_SUBG_VEC}") - message(STATUS "[bls-stage4] wgsl sources concatenated: ${BLS_WGSL_HEADER}") - message(STATUS "[bls-stage4] cuda driver: build-only on Apple (CI: nvcc-compiled)") + message(STATUS "[bls-stage4] subgroup vectors: ${BLS_SUBG_VEC}") # ========================================================= # Stage 6 — IRTF BLS signature primitives (keygen / sk_to_pk @@ -547,10 +553,10 @@ if(APPLE) # BLS_PAIR_METAL_FILES below). # ========================================================= - # ---- Add the combined-miller .metal source to the pairing metallib ---- - # The pairing metallib already aggregates miller + final_exp + pairing; - # we add bls_combined_miller.metal so the runtime loader finds the - # k_combined_miller_reduce kernel alongside the per-bit Miller kernels. + # ---- Add the combined-miller .metal source to the pairing metallib. + # Gated on CRYPTO_ENABLE_METAL — the .metal sources live in + # lux-private/gpu-kernels and are symlinked into gpu/metal/. ---- + if(CRYPTO_ENABLE_METAL) set(BLS_CMIL_AIR "${CMAKE_CURRENT_BINARY_DIR}/bls_combined_miller_pair.air") add_custom_command( OUTPUT "${BLS_CMIL_AIR}" @@ -631,7 +637,14 @@ if(APPLE) DEPENDS bls_combined_miller_metal_test bls_combined_miller_metallib) - # ---- CUDA stub test (runs in stub mode on Apple) ---- + message(STATUS "[bls-combined-miller] enabled — Metal") + message(STATUS "[bls-combined-miller] metallib: ${BLS_COMBINED_METALLIB}") + endif() # CRYPTO_ENABLE_METAL (combined-miller Metal) + + # ---- CUDA stub test. 
Gated on CRYPTO_ENABLE_CUDA — the .cpp host + # shim lives in lux-private/gpu-kernels and is symlinked into + # gpu/cuda/. ---- + if(CRYPTO_ENABLE_CUDA) add_library(bls_combined_miller_cuda_stub STATIC gpu/cuda/bls_combined_miller_driver.cpp) target_compile_options(bls_combined_miller_cuda_stub PRIVATE -O3) @@ -651,8 +664,10 @@ if(APPLE) COMMAND bls_combined_miller_cuda_test) add_custom_target(bls-combined-miller-cuda-test DEPENDS bls_combined_miller_cuda_test) + endif() # CRYPTO_ENABLE_CUDA (combined-miller CUDA stub) - # ---- WGSL test (SKIP unless WGSL miller driver linked) ---- + # ---- WGSL test. Gated on CRYPTO_ENABLE_WGSL. ---- + if(CRYPTO_ENABLE_WGSL) add_executable(bls_combined_miller_wgsl_test test/bls_combined_miller_wgsl_test.cpp) target_compile_options(bls_combined_miller_wgsl_test PRIVATE -O3) @@ -663,16 +678,7 @@ if(APPLE) COMMAND bls_combined_miller_wgsl_test) add_custom_target(bls-combined-miller-wgsl-test DEPENDS bls_combined_miller_wgsl_test) - - # ---- Aggregate target name as referenced by build commands ---- - add_custom_target(bls_combined_miller_all - DEPENDS bls_combined_miller_metal_test - bls_combined_miller_cuda_test - bls_combined_miller_wgsl_test - bls_combined_miller_metallib) - - message(STATUS "[bls-combined-miller] enabled — Metal/CUDA/WGSL") - message(STATUS "[bls-combined-miller] metallib: ${BLS_COMBINED_METALLIB}") + endif() # CRYPTO_ENABLE_WGSL (combined-miller WGSL) endif() else() message(STATUS "[bls-stage1] skipped — blst not found at ${BLST_DIR}") diff --git a/cggmp21/CMakeLists.txt b/cggmp21/CMakeLists.txt index c9d59ce..543ac9e 100644 --- a/cggmp21/CMakeLists.txt +++ b/cggmp21/CMakeLists.txt @@ -24,15 +24,17 @@ target_include_directories(cggmp21 PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../secp256k1/cpp ${CMAKE_CURRENT_SOURCE_DIR}/../modexp/cpp) -# CUDA backend (host polyfill — same .cu compiles as plain C++ when CUDA is -# disabled, exposing cggmp21_presign_cuda_host as the byte-equal oracle that -# forwards to the CPU 
canonical body). -add_library(cggmp21_cuda STATIC gpu/cuda/cggmp21_presign.cu) -set_source_files_properties(gpu/cuda/cggmp21_presign.cu PROPERTIES LANGUAGE CXX) -target_compile_features(cggmp21_cuda PUBLIC cxx_std_20) -set_target_properties(cggmp21_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_link_libraries(cggmp21_cuda PUBLIC cggmp21_cpu) -target_include_directories(cggmp21_cuda PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../secp256k1/cpp) +# CUDA backend (host polyfill). Built only when CRYPTO_ENABLE_CUDA=ON +# (lux-gpu-kernels found). The .cu source lives in lux-private/gpu-kernels +# and is symlinked into gpu/cuda/ at configure time. +if(CRYPTO_ENABLE_CUDA) + add_library(cggmp21_cuda STATIC gpu/cuda/cggmp21_presign.cu) + set_source_files_properties(gpu/cuda/cggmp21_presign.cu PROPERTIES LANGUAGE CXX) + target_compile_features(cggmp21_cuda PUBLIC cxx_std_20) + set_target_properties(cggmp21_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(cggmp21_cuda PUBLIC cggmp21_cpu) + target_include_directories(cggmp21_cuda PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../secp256k1/cpp) +endif() diff --git a/fhe/CMakeLists.txt b/fhe/CMakeLists.txt index 81d08cb..a422a15 100644 --- a/fhe/CMakeLists.txt +++ b/fhe/CMakeLists.txt @@ -47,9 +47,10 @@ set(FHE_CPP_SOURCES cpp/backends/cpu/ntt_cpu.cpp cpp/backends/cpu/keyswitch_cpu.cpp cpp/backends/cpu/bootstrap_cpu.cpp - # CUDA host drivers compile on every host (CRYPTO_HAS_CUDA gates the - # actual GPU dispatch). They link the CPU oracle for the no-device - # path so the FHE umbrella library always provides the full API. + # CUDA host drivers compile on every host. They internally reference the + # lattice_ring_cuda_* C symbols which are provided by either the real + # lattice_ring_cuda target (when CRYPTO_ENABLE_CUDA=ON) or by a CPU-only + # stub (added below) when lux-gpu-kernels is not present. 
cpp/backends/cuda/cuda_driver.cpp cpp/backends/cuda/cuda_ntt_kernel.cpp cpp/backends/cuda/cuda_keyswitch_kernel.cpp @@ -120,9 +121,29 @@ target_include_directories(fhe PUBLIC # LP-107: FHE consumes the unified Lux Montgomery + NTT body. The # cpu/ntt_cpu wrapper delegates to lattice_ring_cpu's NTTStandard / # INTTStandard. The cuda/cuda_ntt_kernel.cpp host dispatcher delegates -# to lattice_ring_cuda when a device is present. -target_link_libraries(fhe_cpu PUBLIC lattice_ring_cpu lattice_ring_cuda) -target_link_libraries(fhe PUBLIC lattice_ring_cpu lattice_ring_cuda) +# to lattice_ring_cuda when a device is present (target only exists when +# CRYPTO_ENABLE_CUDA=ON, i.e. lux-gpu-kernels is installed). +target_link_libraries(fhe_cpu PUBLIC lattice_ring_cpu) +target_link_libraries(fhe PUBLIC lattice_ring_cpu) + +if(TARGET lattice_ring_cuda) + target_link_libraries(fhe_cpu PUBLIC lattice_ring_cuda) + target_link_libraries(fhe PUBLIC lattice_ring_cuda) +else() + # CPU-only build (lux-gpu-kernels absent): fhe_cpu's cuda_ntt_kernel.cpp + # references extern "C" lattice_ring_cuda_{ntt,intt,available,is_device_present} + # symbols. Provide a minimal stub TU so libfhe_cpu.a is linkable. Stubs + # return "not available" (0) so the host dispatcher routes to the CPU + # oracle exactly like it does on a runtime without a GPU. + add_library(fhe_lattice_ring_cuda_stub STATIC + cpp/backends/cuda/lattice_ring_cuda_stub.cpp + ) + target_compile_features(fhe_lattice_ring_cuda_stub PUBLIC cxx_std_20) + set_target_properties(fhe_lattice_ring_cuda_stub PROPERTIES + POSITION_INDEPENDENT_CODE ON) + target_link_libraries(fhe_cpu PUBLIC fhe_lattice_ring_cuda_stub) + target_link_libraries(fhe PUBLIC fhe_lattice_ring_cuda_stub) +endif() # CUDA device-kernel static archive. Linked into `fhe_cpu` so callers # of the umbrella library transparently pull the GPU symbols. 
diff --git a/fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp b/fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp new file mode 100644 index 0000000..0f73931 --- /dev/null +++ b/fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: BSD-3-Clause-Eco +// CPU-only stub for lattice_ring_cuda_* extern C symbols. +// +// The real implementation lives in lux-private/gpu-kernels at +// math/ntt/cuda/lattice_ring.cu and is linked when CRYPTO_ENABLE_CUDA=ON +// brings the lattice_ring_cuda target into the build. Without lux-private +// installed the symbols are missing at link time, which breaks fhe_cpu's +// cuda_ntt_kernel.cpp host dispatcher even though it never runs on this +// host (the dispatcher gates dispatch behind lattice_ring_cuda_available()==1). +// +// This stub returns 0 from _available so the dispatcher always routes to +// the CPU oracle, and emits NOTIMPL (-1) from every NTT/IMForm/MUL entry +// point so any accidental call surfaces immediately rather than silently +// computing wrong values. 
+ +#include <cstdint> + +extern "C" { + +int lattice_ring_cuda_available(void) { return 0; } + +int lattice_ring_cuda_mform(const uint64_t*, uint64_t*, + uint32_t, uint64_t, uint64_t, uint64_t) { + return -1; +} + +int lattice_ring_cuda_imform(const uint64_t*, uint64_t*, + uint32_t, uint64_t, uint64_t) { + return -1; +} + +int lattice_ring_cuda_mul_coeffs_montgomery(const uint64_t*, const uint64_t*, + uint64_t*, uint32_t, + uint64_t, uint64_t) { + return -1; +} + +int lattice_ring_cuda_mul_coeffs_montgomery_then_add(const uint64_t*, + const uint64_t*, + uint64_t*, + uint32_t, + uint64_t, uint64_t) { + return -1; +} + +int lattice_ring_cuda_ntt(const uint64_t*, const uint64_t*, + uint64_t*, uint32_t, + uint64_t, uint64_t, uint64_t) { + return -1; +} + +int lattice_ring_cuda_intt(const uint64_t*, const uint64_t*, + uint64_t*, uint32_t, + uint64_t, uint64_t, uint64_t) { + return -1; +} + +} // extern "C" diff --git a/frost/CMakeLists.txt b/frost/CMakeLists.txt index 4423d6c..8acd472 100644 --- a/frost/CMakeLists.txt +++ b/frost/CMakeLists.txt @@ -28,13 +28,16 @@ target_include_directories(frost PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../secp256k1/cpp) # ============================================================================= -# CUDA backend (host polyfill). Same .cu compiles as plain C++ when CUDA is -# disabled and exposes frost_presign_cuda_host() as the byte-equal oracle. +# CUDA backend (host polyfill). Built only when CRYPTO_ENABLE_CUDA=ON +# (lux-gpu-kernels found). The .cu source lives in lux-private/gpu-kernels +# and is symlinked into gpu/cuda/ at configure time. 
# ============================================================================= -add_library(frost_cuda STATIC gpu/cuda/frost_presign.cu) -set_source_files_properties(gpu/cuda/frost_presign.cu PROPERTIES LANGUAGE CXX) -target_compile_features(frost_cuda PUBLIC cxx_std_20) -set_target_properties(frost_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) +if(CRYPTO_ENABLE_CUDA) + add_library(frost_cuda STATIC gpu/cuda/frost_presign.cu) + set_source_files_properties(gpu/cuda/frost_presign.cu PROPERTIES LANGUAGE CXX) + target_compile_features(frost_cuda PUBLIC cxx_std_20) + set_target_properties(frost_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() # Test wiring lives in the top-level CMakeLists.txt under the # `if(CRYPTO_BUILD_TESTS)` block — that block is the only point where diff --git a/gpukit/CMakeLists.txt b/gpukit/CMakeLists.txt index 784478a..24bf87b 100644 --- a/gpukit/CMakeLists.txt +++ b/gpukit/CMakeLists.txt @@ -56,10 +56,6 @@ if(APPLE AND CRYPTO_ENABLE_METAL) endif() if(CRYPTO_ENABLE_CUDA OR GPUKIT_ENABLE_CUDA) list(APPEND _GPUKIT_MP_SOURCES gpu/cuda/multi_pippenger.cu) -else() - set_source_files_properties(gpu/cuda/multi_pippenger.cu - PROPERTIES LANGUAGE CXX) - list(APPEND _GPUKIT_MP_SOURCES gpu/cuda/multi_pippenger.cu) endif() if(CRYPTO_ENABLE_WGSL OR GPUKIT_ENABLE_WGSL) list(APPEND _GPUKIT_MP_SOURCES gpu/wgsl/multi_pippenger_driver.cpp) @@ -163,7 +159,9 @@ if(APPLE AND CRYPTO_ENABLE_METAL) endif() # ---- CUDA driver ----------------------------------------------------------- - +# Built only when CRYPTO_ENABLE_CUDA=ON (lux-gpu-kernels found + NVCC). The +# .cu sources live in lux-private/gpu-kernels and are symlinked into +# gpu/cuda/ at configure time. 
if(CRYPTO_ENABLE_CUDA OR GPUKIT_ENABLE_CUDA) set(_GPUKIT_CUDA_SOURCES gpu/cuda/prefix_sum.cu @@ -185,31 +183,6 @@ if(CRYPTO_ENABLE_CUDA OR GPUKIT_ENABLE_CUDA) ) target_link_libraries(gpukit_cuda PUBLIC gpukit_cpu) add_library(lux::gpukit_cuda ALIAS gpukit_cuda) -else() - # On Apple/no-CUDA hosts we still need the symbols. The .cu files are - # written so they emit NOTIMPL stubs when __CUDACC__/GPUKIT_HAS_CUDA is - # not defined; compile them as regular C++ TUs. - set(_GPUKIT_CUDA_STUBS - gpu/cuda/prefix_sum.cu - gpu/cuda/compaction.cu - gpu/cuda/radix_sort.cu - gpu/cuda/batch_inversion.cu - gpu/cuda/merkle_compose.cu - gpu/cuda/transcript_root.cu - gpu/cuda/ntt.cu - ) - foreach(_src ${_GPUKIT_CUDA_STUBS}) - set_source_files_properties(${_src} PROPERTIES LANGUAGE CXX) - endforeach() - add_library(gpukit_cuda STATIC ${_GPUKIT_CUDA_STUBS}) - target_include_directories(gpukit_cuda PUBLIC - $ - $ - ) - target_compile_features(gpukit_cuda PUBLIC cxx_std_20) - set_target_properties(gpukit_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_link_libraries(gpukit_cuda PUBLIC gpukit_cpu) - add_library(lux::gpukit_cuda ALIAS gpukit_cuda) endif() # ---- WGSL driver ----------------------------------------------------------- @@ -251,8 +224,12 @@ endif() add_library(lux::gpukit ALIAS gpukit) # ---- Tests ----------------------------------------------------------------- +# Test sources directly call gpukit_*_{cuda,metal,wgsl} symbols so they only +# build when ALL backend targets are present (CRYPTO_ENABLE_CUDA + METAL + +# WGSL). On a CPU-only build (lux-gpu-kernels not installed) these tests +# are simply not built. 
-if(CRYPTO_BUILD_TESTS) +if(CRYPTO_BUILD_TESTS AND TARGET gpukit_cuda AND TARGET gpukit_metal AND TARGET gpukit_wgsl) enable_testing() set(_GPUKIT_TESTS prefix_sum diff --git a/lamport/CMakeLists.txt b/lamport/CMakeLists.txt index 4f3d20a..913d2ba 100644 --- a/lamport/CMakeLists.txt +++ b/lamport/CMakeLists.txt @@ -26,13 +26,17 @@ target_include_directories(lamport PUBLIC $ ) -# Always-built host oracles for the CUDA + WGSL kernels. These exist so the -# determinism tests can run on hosts without an NVIDIA / WebGPU runtime — -# the oracle is line-for-line the same arithmetic as the kernel. -add_library(lamport_cuda_oracle STATIC gpu/cuda/lamport_cuda_oracle.cpp) -target_compile_features(lamport_cuda_oracle PUBLIC cxx_std_20) -set_target_properties(lamport_cuda_oracle PROPERTIES POSITION_INDEPENDENT_CODE ON) +# Host oracles for the CUDA + WGSL kernels. Built only when the corresponding +# backend is enabled (lux-gpu-kernels found). The oracle sources live in +# lux-private/gpu-kernels and are symlinked into gpu/{cuda,wgsl}/ at configure. 
+if(CRYPTO_ENABLE_CUDA) + add_library(lamport_cuda_oracle STATIC gpu/cuda/lamport_cuda_oracle.cpp) + target_compile_features(lamport_cuda_oracle PUBLIC cxx_std_20) + set_target_properties(lamport_cuda_oracle PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() -add_library(lamport_wgsl_oracle STATIC gpu/wgsl/lamport_wgsl_oracle.cpp) -target_compile_features(lamport_wgsl_oracle PUBLIC cxx_std_20) -set_target_properties(lamport_wgsl_oracle PROPERTIES POSITION_INDEPENDENT_CODE ON) +if(CRYPTO_ENABLE_WGSL) + add_library(lamport_wgsl_oracle STATIC gpu/wgsl/lamport_wgsl_oracle.cpp) + target_compile_features(lamport_wgsl_oracle PUBLIC cxx_std_20) + set_target_properties(lamport_wgsl_oracle PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() diff --git a/math/CMakeLists.txt b/math/CMakeLists.txt index 6560aac..0919c70 100644 --- a/math/CMakeLists.txt +++ b/math/CMakeLists.txt @@ -45,27 +45,23 @@ endif() # ============================================================================= # CUDA — NVIDIA NTT. Same six-op surface, same byte-equality contract. -# CRYPTO_ENABLE_CUDA=ON: lattice_ring.cu compiles via nvcc. -# CRYPTO_ENABLE_CUDA=OFF: same .cu compiles as plain C++ host polyfill. +# Built only when CRYPTO_ENABLE_CUDA=ON (lux-gpu-kernels found + NVCC). +# The .cu source lives in lux-private/gpu-kernels and is symlinked into +# ntt/cuda/ at configure time. 
# ============================================================================= -add_library(lattice_ring_cuda STATIC - ntt/cuda/lattice_ring.cu - ntt/cuda/lattice_ring_cuda.cpp) if(CRYPTO_ENABLE_CUDA) + add_library(lattice_ring_cuda STATIC + ntt/cuda/lattice_ring.cu + ntt/cuda/lattice_ring_cuda.cpp) set_source_files_properties(ntt/cuda/lattice_ring.cu PROPERTIES LANGUAGE CUDA) set_target_properties(lattice_ring_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_SEPARABLE_COMPILATION ON) -else() - set_source_files_properties(ntt/cuda/lattice_ring.cu - PROPERTIES LANGUAGE CXX) - set_target_properties(lattice_ring_cuda PROPERTIES - POSITION_INDEPENDENT_CODE ON) + target_compile_features(lattice_ring_cuda PUBLIC cxx_std_20) + target_include_directories(lattice_ring_cuda PUBLIC + $) endif() -target_compile_features(lattice_ring_cuda PUBLIC cxx_std_20) -target_include_directories(lattice_ring_cuda PUBLIC - $) # ============================================================================= # WGSL — WebGPU NTT via wgpu-native. Same byte-equality contract. @@ -159,17 +155,30 @@ target_link_libraries(math_ntt INTERFACE lattice_ring_cpu math_modarith math_par add_library(math_ntt_c_abi STATIC ntt/c-abi/c_math_ntt.cpp) target_compile_features(math_ntt_c_abi PUBLIC cxx_std_20) set_target_properties(math_ntt_c_abi PROPERTIES POSITION_INDEPENDENT_CODE ON) +# Base include path. Backend-specific include dirs are added conditionally +# below — only when lux-private/gpu-kernels has symlinked them into place. 
target_include_directories(math_ntt_c_abi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/ntt/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/ntt/metal - ${CMAKE_CURRENT_SOURCE_DIR}/ntt/wgsl ) +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/ntt/cuda) + target_include_directories(math_ntt_c_abi PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/ntt/cuda) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/ntt/metal) + target_include_directories(math_ntt_c_abi PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/ntt/metal) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/ntt/wgsl) + target_include_directories(math_ntt_c_abi PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/ntt/wgsl) +endif() # Link the GPU drivers we forward to. Each is conditionally added by the # ringtail CMakeLists; if a driver target is missing the C-ABI just won't # resolve those symbols at link time, which is fine because the Go-side # build tag for that backend won't be set either. -target_link_libraries(math_ntt_c_abi PUBLIC lattice_ring_cuda) +if(TARGET lattice_ring_cuda) + target_link_libraries(math_ntt_c_abi PUBLIC lattice_ring_cuda) +endif() if(APPLE AND TARGET lattice_ring_metal) target_link_libraries(math_ntt_c_abi PUBLIC lattice_ring_metal) endif() diff --git a/ntt/CMakeLists.txt b/ntt/CMakeLists.txt index bff6392..57836b1 100644 --- a/ntt/CMakeLists.txt +++ b/ntt/CMakeLists.txt @@ -9,26 +9,30 @@ lux_add_algorithm( target_include_directories(ntt_cpu PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) target_include_directories(ntt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) -# GPU host-side drivers for the large-N path. These compile as plain C++ on -# every host (no CUDA/Metal/wgpu runtime required); when CRYPTO_ENABLE_* -# is on the corresponding kernel files in gpu/{cuda,metal,wgsl}/ are added -# to the link line. Today the drivers fall through to the CPU oracle so -# byte-equality is exercised on every CI runner. 
-add_library(ntt_large_gpu_cuda STATIC gpu/cuda/ntt_large.cu) -set_source_files_properties(gpu/cuda/ntt_large.cu PROPERTIES LANGUAGE CXX) -target_compile_features(ntt_large_gpu_cuda PUBLIC cxx_std_20) -target_include_directories(ntt_large_gpu_cuda PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) -set_target_properties(ntt_large_gpu_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_link_libraries(ntt_large_gpu_cuda PUBLIC ntt_cpu) +# GPU host-side drivers for the large-N path. Built only when the corresponding +# CRYPTO_ENABLE_<BACKEND> is ON (lux-gpu-kernels found). The driver sources +# live in lux-private/gpu-kernels and are symlinked into gpu/{cuda,metal,wgsl}/. +if(CRYPTO_ENABLE_CUDA) + add_library(ntt_large_gpu_cuda STATIC gpu/cuda/ntt_large.cu) + set_source_files_properties(gpu/cuda/ntt_large.cu PROPERTIES LANGUAGE CXX) + target_compile_features(ntt_large_gpu_cuda PUBLIC cxx_std_20) + target_include_directories(ntt_large_gpu_cuda PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) + set_target_properties(ntt_large_gpu_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(ntt_large_gpu_cuda PUBLIC ntt_cpu) +endif() -add_library(ntt_large_gpu_metal STATIC gpu/metal/ntt_large_driver.cpp) -target_compile_features(ntt_large_gpu_metal PUBLIC cxx_std_20) -target_include_directories(ntt_large_gpu_metal PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) -set_target_properties(ntt_large_gpu_metal PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_link_libraries(ntt_large_gpu_metal PUBLIC ntt_cpu) +if(CRYPTO_ENABLE_METAL) + add_library(ntt_large_gpu_metal STATIC gpu/metal/ntt_large_driver.cpp) + target_compile_features(ntt_large_gpu_metal PUBLIC cxx_std_20) + target_include_directories(ntt_large_gpu_metal PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) + set_target_properties(ntt_large_gpu_metal PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(ntt_large_gpu_metal PUBLIC ntt_cpu) +endif() -add_library(ntt_large_gpu_wgsl STATIC gpu/wgsl/ntt_large_driver.cpp) 
-target_compile_features(ntt_large_gpu_wgsl PUBLIC cxx_std_20) -target_include_directories(ntt_large_gpu_wgsl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) -set_target_properties(ntt_large_gpu_wgsl PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_link_libraries(ntt_large_gpu_wgsl PUBLIC ntt_cpu) +if(CRYPTO_ENABLE_WGSL) + add_library(ntt_large_gpu_wgsl STATIC gpu/wgsl/ntt_large_driver.cpp) + target_compile_features(ntt_large_gpu_wgsl PUBLIC cxx_std_20) + target_include_directories(ntt_large_gpu_wgsl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cpp) + set_target_properties(ntt_large_gpu_wgsl PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(ntt_large_gpu_wgsl PUBLIC ntt_cpu) +endif() diff --git a/ripemd160/CMakeLists.txt b/ripemd160/CMakeLists.txt index e82e3cd..7e295e1 100644 --- a/ripemd160/CMakeLists.txt +++ b/ripemd160/CMakeLists.txt @@ -25,22 +25,18 @@ if(APPLE AND CRYPTO_ENABLE_METAL) endif() # ============================================================================= -# CUDA driver (host TU + .cu kernel). Build modes: -# * CRYPTO_ENABLE_CUDA=ON -> compile ripemd160.cu via nvcc; link to a -# real driver that talks to the CUDA runtime. -# * CRYPTO_ENABLE_CUDA=OFF -> compile only the host shim in stub mode so -# the test harness can call the C ABI on -# Apple/non-CUDA hosts (returns 0 -# "available?" -> false; skips compare). +# CUDA driver (host TU + .cu kernel). Built only when CRYPTO_ENABLE_CUDA=ON +# (lux-gpu-kernels found). Otherwise the target is absent; consumer tests +# are gated on CRYPTO_ENABLE_CUDA at the top-level CMakeLists. 
# ============================================================================= -add_library(ripemd160_batch_cuda STATIC - gpu/cuda/ripemd160_driver.cpp) -target_include_directories(ripemd160_batch_cuda PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda) -target_compile_features(ripemd160_batch_cuda PUBLIC cxx_std_20) -set_target_properties(ripemd160_batch_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) - if(CRYPTO_ENABLE_CUDA) + add_library(ripemd160_batch_cuda STATIC + gpu/cuda/ripemd160_driver.cpp) + target_include_directories(ripemd160_batch_cuda PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda) + target_compile_features(ripemd160_batch_cuda PUBLIC cxx_std_20) + set_target_properties(ripemd160_batch_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(ripemd160_batch_cuda PRIVATE LUX_RIPEMD160_HAVE_CUDA=1) target_sources(ripemd160_batch_cuda PRIVATE gpu/cuda/ripemd160.cu) set_source_files_properties(gpu/cuda/ripemd160.cu PROPERTIES LANGUAGE CUDA) @@ -49,28 +45,28 @@ if(CRYPTO_ENABLE_CUDA) endif() # ============================================================================= -# WGSL driver (WebGPU). The WGSL source is concatenated into a header at -# build time so the driver embeds the kernel source as a string literal. -# Compiles to stub mode unless wgpu-native (or Dawn) is found on the host. +# WGSL driver (WebGPU). Built only when CRYPTO_ENABLE_WGSL=ON (lux-gpu-kernels +# found + wgpu-native present). The WGSL source is concatenated into a header +# at build time so the driver embeds the kernel source as a string literal. # ============================================================================= -set(_RIPEMD160_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/ripemd160_wgsl_sources.h") -file(WRITE ${_RIPEMD160_WGSL_HEADER} - "// Auto-generated. 
Do not edit.\n" - "#pragma once\n\n") -file(READ "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/ripemd160.wgsl" _RIPEMD160_WGSL_SRC) -file(APPEND ${_RIPEMD160_WGSL_HEADER} - "static constexpr char kRIPEMD160_WGSL_Source[] = R\"RMDWGSL(\n" - "${_RIPEMD160_WGSL_SRC}\n" - ")RMDWGSL\";\n") +if(CRYPTO_ENABLE_WGSL) + set(_RIPEMD160_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/ripemd160_wgsl_sources.h") + file(WRITE ${_RIPEMD160_WGSL_HEADER} + "// Auto-generated. Do not edit.\n" + "#pragma once\n\n") + file(READ "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/ripemd160.wgsl" _RIPEMD160_WGSL_SRC) + file(APPEND ${_RIPEMD160_WGSL_HEADER} + "static constexpr char kRIPEMD160_WGSL_Source[] = R\"RMDWGSL(\n" + "${_RIPEMD160_WGSL_SRC}\n" + ")RMDWGSL\";\n") -add_library(ripemd160_batch_wgpu STATIC gpu/wgsl/ripemd160_driver_wgpu.cpp) -target_include_directories(ripemd160_batch_wgpu - PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -target_compile_features(ripemd160_batch_wgpu PUBLIC cxx_std_20) -set_target_properties(ripemd160_batch_wgpu PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_library(ripemd160_batch_wgpu STATIC gpu/wgsl/ripemd160_driver_wgpu.cpp) + target_include_directories(ripemd160_batch_wgpu + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl + PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + target_compile_features(ripemd160_batch_wgpu PUBLIC cxx_std_20) + set_target_properties(ripemd160_batch_wgpu PROPERTIES POSITION_INDEPENDENT_CODE ON) -if(CRYPTO_ENABLE_WGSL) find_path(_RIPEMD160_WGPU_INCLUDE webgpu.h HINTS /opt/homebrew/include /usr/local/include /usr/include) find_library(_RIPEMD160_WGPU_LIB NAMES wgpu_native wgpu diff --git a/secp256k1/CMakeLists.txt b/secp256k1/CMakeLists.txt index 72fcd47..10e8a62 100644 --- a/secp256k1/CMakeLists.txt +++ b/secp256k1/CMakeLists.txt @@ -45,15 +45,14 @@ if(APPLE AND CRYPTO_ENABLE_METAL) endif() # ---- CUDA --------------------------------------------------------------- -# Two-source library: kernel TU + driver TU. 
When CRYPTO_ENABLE_CUDA is OFF -# the .cu files compile as plain C++ and emit NOTIMPL stubs (sentinel -100) -# so the umbrella library still links on macOS / non-CUDA hosts. CI runs on -# hanzo-build-linux-amd64 with nvcc to build the real device path. -set(_SECP256K1_BATCH_INV_CUDA_SOURCES - gpu/cuda/secp256k1_batch_inv.cu - gpu/cuda/secp256k1_batch_inv_driver.cu -) +# Built only when CRYPTO_ENABLE_CUDA=ON (lux-gpu-kernels found + NVCC). The +# .cu sources live in lux-private/gpu-kernels and are symlinked into +# gpu/cuda/ at configure time. if(CRYPTO_ENABLE_CUDA) + set(_SECP256K1_BATCH_INV_CUDA_SOURCES + gpu/cuda/secp256k1_batch_inv.cu + gpu/cuda/secp256k1_batch_inv_driver.cu + ) add_library(secp256k1_batch_inv_cuda STATIC ${_SECP256K1_BATCH_INV_CUDA_SOURCES}) set_target_properties(secp256k1_batch_inv_cuda PROPERTIES @@ -64,45 +63,29 @@ if(CRYPTO_ENABLE_CUDA) target_include_directories(secp256k1_batch_inv_cuda PUBLIC $ ) -else() - foreach(_src ${_SECP256K1_BATCH_INV_CUDA_SOURCES}) - set_source_files_properties(${_src} PROPERTIES LANGUAGE CXX) - endforeach() - add_library(secp256k1_batch_inv_cuda STATIC - ${_SECP256K1_BATCH_INV_CUDA_SOURCES}) - target_include_directories(secp256k1_batch_inv_cuda PUBLIC +endif() + +# ---- WGSL --------------------------------------------------------------- +# Built only when CRYPTO_ENABLE_WGSL=ON. +if(CRYPTO_ENABLE_WGSL) + add_library(secp256k1_batch_inv_wgsl STATIC + gpu/wgsl/secp256k1_batch_inv_driver.cpp) + target_include_directories(secp256k1_batch_inv_wgsl PUBLIC $ ) - target_compile_features(secp256k1_batch_inv_cuda PUBLIC cxx_std_20) - set_target_properties(secp256k1_batch_inv_cuda PROPERTIES + target_compile_features(secp256k1_batch_inv_wgsl PUBLIC cxx_std_20) + set_target_properties(secp256k1_batch_inv_wgsl PROPERTIES POSITION_INDEPENDENT_CODE ON ) endif() -# ---- WGSL --------------------------------------------------------------- -# Single-source: the host-side dispatch wrapper. 
Real WebGPU dispatch is -# enabled by CRYPTO_HAS_DAWN at compile-time (set by CI when Dawn / -# wgpu-native is available). Otherwise the entry point returns NOTIMPL. -add_library(secp256k1_batch_inv_wgsl STATIC - gpu/wgsl/secp256k1_batch_inv_driver.cpp) -target_include_directories(secp256k1_batch_inv_wgsl PUBLIC - $ -) -target_compile_features(secp256k1_batch_inv_wgsl PUBLIC cxx_std_20) -set_target_properties(secp256k1_batch_inv_wgsl PROPERTIES - POSITION_INDEPENDENT_CODE ON -) - # ---- secp256k1 ecrecover CUDA driver ------------------------------------ -# Host-side CUDA driver paired with the kernel in gpu/cuda/secp256k1_recover.cu. -# Symbol exposed: secp256k1_ecrecover_address_batch_cuda — looked up via -# dlsym from cpp/ecrecover.cpp when LUX_SECP256K1_BACKEND=cuda. On non-CUDA -# hosts compiles as a CXX NOTIMPL stub so the umbrella library still links. -set(_SECP256K1_RECOVER_CUDA_SOURCES - gpu/cuda/secp256k1_recover.cu - gpu/cuda/secp256k1_first_party_cuda_driver.cu -) +# Built only when CRYPTO_ENABLE_CUDA=ON. 
if(CRYPTO_ENABLE_CUDA) + set(_SECP256K1_RECOVER_CUDA_SOURCES + gpu/cuda/secp256k1_recover.cu + gpu/cuda/secp256k1_first_party_cuda_driver.cu + ) add_library(secp256k1_recover_cuda STATIC ${_SECP256K1_RECOVER_CUDA_SOURCES}) set_target_properties(secp256k1_recover_cuda PROPERTIES @@ -114,19 +97,6 @@ if(CRYPTO_ENABLE_CUDA) $ ) target_compile_definitions(secp256k1_recover_cuda PRIVATE CRYPTO_HAS_CUDA=1) -else() - foreach(_src ${_SECP256K1_RECOVER_CUDA_SOURCES}) - set_source_files_properties(${_src} PROPERTIES LANGUAGE CXX) - endforeach() - add_library(secp256k1_recover_cuda STATIC - ${_SECP256K1_RECOVER_CUDA_SOURCES}) - target_include_directories(secp256k1_recover_cuda PUBLIC - $ - ) - target_compile_features(secp256k1_recover_cuda PUBLIC cxx_std_20) - set_target_properties(secp256k1_recover_cuda PROPERTIES - POSITION_INDEPENDENT_CODE ON - ) endif() # Note: test executables (batch_inv_cuda_test / batch_inv_wgsl_test) are diff --git a/sha256/CMakeLists.txt b/sha256/CMakeLists.txt index fee012f..f6c42bf 100644 --- a/sha256/CMakeLists.txt +++ b/sha256/CMakeLists.txt @@ -25,22 +25,18 @@ if(APPLE AND CRYPTO_ENABLE_METAL) endif() # ============================================================================= -# CUDA driver (host TU + .cu kernel). Build modes: -# * CRYPTO_ENABLE_CUDA=ON -> compile sha256.cu via nvcc; link to a real -# driver that talks to the CUDA runtime. -# * CRYPTO_ENABLE_CUDA=OFF -> compile only the host shim in stub mode -# so the test harness can call the C ABI -# on Apple/non-CUDA hosts (returns 0 -# "available?" -> false; skips compare). +# CUDA driver (host TU + .cu kernel). Built only when CRYPTO_ENABLE_CUDA=ON +# (lux-gpu-kernels found). Otherwise the target is absent; consumer tests +# are gated on CRYPTO_ENABLE_CUDA at the top-level CMakeLists. 
# ============================================================================= -add_library(sha256_batch_cuda STATIC - gpu/cuda/sha256_driver.cpp) -target_include_directories(sha256_batch_cuda PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda) -target_compile_features(sha256_batch_cuda PUBLIC cxx_std_20) -set_target_properties(sha256_batch_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) - if(CRYPTO_ENABLE_CUDA) + add_library(sha256_batch_cuda STATIC + gpu/cuda/sha256_driver.cpp) + target_include_directories(sha256_batch_cuda PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda) + target_compile_features(sha256_batch_cuda PUBLIC cxx_std_20) + set_target_properties(sha256_batch_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(sha256_batch_cuda PRIVATE LUX_SHA256_HAVE_CUDA=1) # Compile the kernel via nvcc and link as a CUDA TU. target_sources(sha256_batch_cuda PRIVATE gpu/cuda/sha256.cu) @@ -50,28 +46,28 @@ if(CRYPTO_ENABLE_CUDA) endif() # ============================================================================= -# WGSL driver (WebGPU). The WGSL source is concatenated into a header at -# build time so the driver embeds the kernel source as a string literal. -# Compiles to stub mode unless wgpu-native (or Dawn) is found on the host. +# WGSL driver (WebGPU). Built only when CRYPTO_ENABLE_WGSL=ON (lux-gpu-kernels +# found + wgpu-native present). The WGSL source is concatenated into a header +# at build time so the driver embeds the kernel source as a string literal. # ============================================================================= -set(_SHA256_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/sha256_wgsl_sources.h") -file(WRITE ${_SHA256_WGSL_HEADER} - "// Auto-generated. 
Do not edit.\n" - "#pragma once\n\n") -file(READ "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/sha256.wgsl" _SHA256_WGSL_SRC) -file(APPEND ${_SHA256_WGSL_HEADER} - "static constexpr char kSHA256_WGSL_Source[] = R\"SHAWGSL(\n" - "${_SHA256_WGSL_SRC}\n" - ")SHAWGSL\";\n") +if(CRYPTO_ENABLE_WGSL) + set(_SHA256_WGSL_HEADER "${CMAKE_CURRENT_BINARY_DIR}/sha256_wgsl_sources.h") + file(WRITE ${_SHA256_WGSL_HEADER} + "// Auto-generated. Do not edit.\n" + "#pragma once\n\n") + file(READ "${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/sha256.wgsl" _SHA256_WGSL_SRC) + file(APPEND ${_SHA256_WGSL_HEADER} + "static constexpr char kSHA256_WGSL_Source[] = R\"SHAWGSL(\n" + "${_SHA256_WGSL_SRC}\n" + ")SHAWGSL\";\n") -add_library(sha256_batch_wgpu STATIC gpu/wgsl/sha256_driver_wgpu.cpp) -target_include_directories(sha256_batch_wgpu - PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -target_compile_features(sha256_batch_wgpu PUBLIC cxx_std_20) -set_target_properties(sha256_batch_wgpu PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_library(sha256_batch_wgpu STATIC gpu/wgsl/sha256_driver_wgpu.cpp) + target_include_directories(sha256_batch_wgpu + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl + PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + target_compile_features(sha256_batch_wgpu PUBLIC cxx_std_20) + set_target_properties(sha256_batch_wgpu PROPERTIES POSITION_INDEPENDENT_CODE ON) -if(CRYPTO_ENABLE_WGSL) find_path(_SHA256_WGPU_INCLUDE webgpu.h HINTS /opt/homebrew/include /usr/local/include /usr/include) find_library(_SHA256_WGPU_LIB NAMES wgpu_native wgpu From 702f40a5b7845c650dd91eeed9a6ab91504d2c0b Mon Sep 17 00:00:00 2001 From: Hanzo AI Date: Sat, 16 May 2026 15:34:05 -0700 Subject: [PATCH 3/5] crypto: delete in-tree gpu/ subtrees (now sourced from lux-private) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes 303 files across 29 schemes' gpu/{cuda,metal,wgsl}/ subtrees and math/ntt/cuda/. 
The kernels now live in lux-private/gpu-kernels and are symlinked back into <scheme>/gpu/<backend>/ by the find_package hook in crypto/CMakeLists.txt (PR #1) when lux-gpu-kernels is found at configure. math/ntt/cuda/lattice_ring_driver.h kept as a minimal public C ABI declaration so fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp can resolve the extern "C" lattice_ring_cuda_* signatures it dispatches against — the bodies come from either the real lattice_ring_cuda target (CUDA on) or fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp (CUDA off). pedersen/CMakeLists.txt: wrap WGSL block in CRYPTO_ENABLE_WGSL and drop the CUDA else() stub branch (cleaned up alongside the rest, missed in the previous gating commit). CPU-only verification (no CMAKE_PREFIX_PATH, lux-private absent): - configure: clean - build: 446/446 targets - ctest: 70/70 passing (35.25s wall, includes pulsar/fhe/bls/slhdsa KATs) --- aead/gpu/cuda/aead_driver_cuda.cpp | 205 -- aead/gpu/cuda/aead_driver_cuda.h | 61 - aead/gpu/cuda/aes_gcm.cu | 427 ----- aead/gpu/cuda/chacha20_poly1305.cu | 299 --- aead/gpu/metal/aead_batch.metal | 297 --- aead/gpu/metal/aead_batch_driver.mm | 131 -- aead/gpu/metal/aes_gcm.metal | 452 ----- aead/gpu/metal/aes_gcm_driver.mm | 133 -- aead/gpu/wgsl/aead_driver_wgpu.cpp | 348 ---- aead/gpu/wgsl/aead_driver_wgpu.h | 45 - aead/gpu/wgsl/aes_gcm.wgsl | 456 ----- aead/gpu/wgsl/chacha20_poly1305.wgsl | 382 ---- banderwagon/gpu/cuda/banderwagon.cu | 552 ------ banderwagon/gpu/cuda/banderwagon_driver.cpp | 129 -- banderwagon/gpu/cuda/banderwagon_driver.h | 41 - banderwagon/gpu/metal/banderwagon.metal | 487 ----- banderwagon/gpu/metal/banderwagon_driver.h | 77 - banderwagon/gpu/metal/banderwagon_driver.mm | 164 -- banderwagon/gpu/metal/banderwagon_msm.metal | 51 - banderwagon/gpu/wgsl/banderwagon.wgsl | 460 ----- banderwagon/gpu/wgsl/banderwagon_driver.cpp | 450 ----- banderwagon/gpu/wgsl/banderwagon_driver.h | 41 - blake2b/gpu/cuda/blake2b.cu | 197 -- blake2b/gpu/metal/blake2b_batch.metal | 140 -- 
blake2b/gpu/metal/blake2b_batch_driver.mm | 104 - blake2b/gpu/wgsl/blake2b.wgsl | 236 --- blake2b/gpu/wgsl/blake2b_wgsl_host.cpp | 215 --- blake3/gpu/cuda/blake3.cu | 315 --- blake3/gpu/metal/blake3.metal | 340 ---- blake3/gpu/metal/blake3_authored.metal | 624 ------ blake3/gpu/metal/blake3_batch.metal | 334 ---- blake3/gpu/metal/blake3_batch_driver.mm | 104 - blake3/gpu/metal/blake3_driver.h | 275 --- blake3/gpu/metal/blake3_driver.mm | 1083 ----------- blake3/gpu/wgsl/blake3.wgsl | 155 -- bls/gpu/cuda/bls.cu | 667 ------- bls/gpu/cuda/bls_combined_miller.cu | 46 - bls/gpu/cuda/bls_combined_miller_driver.cpp | 179 -- bls/gpu/cuda/bls_combined_miller_driver.h | 44 - bls/gpu/cuda/bls_driver_cuda.cpp | 278 --- bls/gpu/cuda/bls_driver_cuda.h | 25 - bls/gpu/cuda/bls_final_exp.cu | 47 - bls/gpu/cuda/bls_fp12.cu | 53 - bls/gpu/cuda/bls_fp12.cuh | 191 -- bls/gpu/cuda/bls_fp2.cu | 59 - bls/gpu/cuda/bls_fp2.cuh | 86 - bls/gpu/cuda/bls_fp6.cu | 41 - bls/gpu/cuda/bls_fp6.cuh | 191 -- bls/gpu/cuda/bls_fp_ops.cuh | 206 -- bls/gpu/cuda/bls_g2.cu | 41 - bls/gpu/cuda/bls_g2.cuh | 182 -- bls/gpu/cuda/bls_miller.cu | 96 - bls/gpu/cuda/bls_miller.cuh | 188 -- bls/gpu/cuda/bls_pairing.cu | 42 - bls/gpu/metal/bls.metal | 742 -------- bls/gpu/metal/bls_authored.metal | 647 ------- bls/gpu/metal/bls_combined_miller.metal | 70 - bls/gpu/metal/bls_combined_miller_driver.h | 44 - bls/gpu/metal/bls_combined_miller_driver.mm | 275 --- bls/gpu/metal/bls_driver.h | 246 --- bls/gpu/metal/bls_driver.mm | 688 ------- bls/gpu/metal/bls_final_exp.metal | 112 -- bls/gpu/metal/bls_fp12.metal | 300 --- bls/gpu/metal/bls_fp2.metal | 178 -- bls/gpu/metal/bls_fp6.metal | 278 --- bls/gpu/metal/bls_fp_ops.h.metal | 193 -- bls/gpu/metal/bls_g2.metal | 299 --- bls/gpu/metal/bls_miller.metal | 394 ---- bls/gpu/metal/bls_pairing.metal | 86 - bls/gpu/metal/msm.metal | 665 ------- bls/gpu/wgsl/bls.wgsl | 395 ---- bls/gpu/wgsl/bls_combined_miller.wgsl | 66 - bls/gpu/wgsl/bls_driver_wgpu.cpp | 275 --- 
bls/gpu/wgsl/bls_driver_wgpu.h | 22 - bls/gpu/wgsl/bls_fp12.wgsl | 718 ------- bls/gpu/wgsl/bls_fp2.wgsl | 172 -- bls/gpu/wgsl/bls_fp6.wgsl | 342 ---- bls/gpu/wgsl/bls_fp_ops.wgsl | 322 ---- bls/gpu/wgsl/bls_fp_tower_kernels.wgsl | 244 --- bn254/gpu/cuda/bn254.cu | 618 ------ bn254/gpu/cuda/bn254_driver_cuda.cpp | 315 --- bn254/gpu/cuda/bn254_driver_cuda.h | 55 - bn254/gpu/cuda/bn254_pairing.cuh | 890 --------- bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh | 36 - bn254/gpu/metal/bn254.metal | 569 ------ bn254/gpu/metal/zk_metal.h | 360 ---- bn254/gpu/metal/zk_metal.mm | 988 ---------- bn254/gpu/wgsl/bn254.wgsl | 1684 ----------------- bn254/gpu/wgsl/bn254_driver_wgpu.cpp | 221 --- bn254/gpu/wgsl/bn254_driver_wgpu.h | 32 - cggmp21/gpu/cuda/cggmp21.cu | 347 ---- cggmp21/gpu/cuda/cggmp21_presign.cu | 96 - cggmp21/gpu/metal/cggmp21.metal | 366 ---- cggmp21/gpu/metal/cggmp21_presign.metal | 53 - cggmp21/gpu/wgsl/cggmp21.wgsl | 132 -- cggmp21/gpu/wgsl/cggmp21_presign.wgsl | 26 - ed25519/gpu/cuda/ed25519.cu | 451 ----- ed25519/gpu/metal/ed25519.metal | 501 ----- ed25519/gpu/metal/ed25519_batch.metal | 437 ----- ed25519/gpu/metal/ed25519_batch_driver.mm | 93 - ed25519/gpu/wgsl/ed25519.wgsl | 188 -- evm256/gpu/cuda/evm256.cu | 454 ----- evm256/gpu/metal/evm256.metal | 416 ---- evm256/gpu/wgsl/evm256.wgsl | 162 -- frost/gpu/cuda/frost.cu | 402 ---- frost/gpu/cuda/frost_presign.cu | 520 ----- frost/gpu/metal/frost.metal | 422 ----- frost/gpu/metal/frost_aggregate.metal | 439 ----- frost/gpu/metal/frost_nonce.metal | 613 ------ frost/gpu/metal/frost_presign.metal | 533 ------ frost/gpu/metal/shamir_interpolate.metal | 620 ------ frost/gpu/wgsl/frost.wgsl | 84 - frost/gpu/wgsl/frost_presign.wgsl | 91 - gpukit/gpu/cuda/batch_inversion.cu | 16 - gpukit/gpu/cuda/compaction.cu | 69 - gpukit/gpu/cuda/merkle_compose.cu | 10 - gpukit/gpu/cuda/multi_pippenger.cu | 59 - gpukit/gpu/cuda/ntt.cu | 13 - gpukit/gpu/cuda/prefix_sum.cu | 95 - gpukit/gpu/cuda/radix_sort.cu | 82 - 
gpukit/gpu/cuda/transcript_root.cu | 10 - .../curve_traits/banderwagon_traits.h.metal | 23 - .../curve_traits/bls12_381_g1_traits.h.metal | 25 - .../gpu/curve_traits/bn254_g1_traits.h.metal | 23 - .../gpu/curve_traits/secp256k1_traits.h.metal | 23 - gpukit/gpu/metal/batch_inversion.metal | 18 - gpukit/gpu/metal/batch_inversion_driver.mm | 24 - gpukit/gpu/metal/compaction.metal | 38 - gpukit/gpu/metal/compaction_driver.mm | 107 -- gpukit/gpu/metal/merkle_compose.metal | 10 - gpukit/gpu/metal/merkle_compose_driver.mm | 18 - gpukit/gpu/metal/multi_pippenger.metal | 77 - .../metal/multi_pippenger_banderwagon.metal | 28 - .../metal/multi_pippenger_bls12_381_g1.metal | 27 - .../gpu/metal/multi_pippenger_bn254_g1.metal | 27 - gpukit/gpu/metal/multi_pippenger_driver.mm | 30 - .../gpu/metal/multi_pippenger_secp256k1.metal | 35 - gpukit/gpu/metal/ntt.metal | 13 - gpukit/gpu/metal/ntt_driver.mm | 21 - gpukit/gpu/metal/prefix_sum.metal | 102 - gpukit/gpu/metal/prefix_sum_driver.mm | 124 -- gpukit/gpu/metal/radix_sort.metal | 72 - gpukit/gpu/metal/radix_sort_driver.mm | 25 - gpukit/gpu/metal/transcript_root.metal | 10 - gpukit/gpu/metal/transcript_root_driver.mm | 18 - gpukit/gpu/wgsl/batch_inversion.wgsl | 3 - gpukit/gpu/wgsl/batch_inversion_driver.cpp | 9 - gpukit/gpu/wgsl/compaction.wgsl | 25 - gpukit/gpu/wgsl/compaction_driver.cpp | 9 - gpukit/gpu/wgsl/merkle_compose.wgsl | 3 - gpukit/gpu/wgsl/merkle_compose_driver.cpp | 7 - gpukit/gpu/wgsl/multi_pippenger.wgsl | 52 - gpukit/gpu/wgsl/multi_pippenger_driver.cpp | 16 - gpukit/gpu/wgsl/ntt.wgsl | 3 - gpukit/gpu/wgsl/ntt_driver.cpp | 9 - gpukit/gpu/wgsl/prefix_sum.wgsl | 47 - gpukit/gpu/wgsl/prefix_sum_driver.cpp | 20 - gpukit/gpu/wgsl/radix_sort.wgsl | 29 - gpukit/gpu/wgsl/radix_sort_driver.cpp | 8 - gpukit/gpu/wgsl/transcript_root.wgsl | 3 - gpukit/gpu/wgsl/transcript_root_driver.cpp | 7 - ipa/gpu/metal/ipa_driver.h | 43 - ipa/gpu/metal/ipa_driver.mm | 40 - keccak/gpu/cuda/keccak.cu | 116 -- keccak/gpu/metal/keccak.metal | 
166 -- keccak/gpu/metal/keccak_batch.metal | 135 -- keccak/gpu/wgsl/keccak.wgsl | 226 --- kzg/gpu/cuda/kzg.cu | 247 --- kzg/gpu/cuda/kzg_driver_cuda.cpp | 157 -- kzg/gpu/cuda/kzg_driver_cuda.h | 31 - kzg/gpu/metal/kzg.metal | 437 ----- kzg/gpu/wgsl/kzg.wgsl | 266 --- kzg/gpu/wgsl/kzg_driver_wgpu.cpp | 69 - kzg/gpu/wgsl/kzg_driver_wgpu.h | 26 - lamport/gpu/cuda/lamport.cu | 90 - lamport/gpu/cuda/lamport_cuda_oracle.cpp | 93 - lamport/gpu/metal/lamport_batch.metal | 92 - lamport/gpu/metal/lamport_batch_driver.mm | 111 -- lamport/gpu/metal/lamport_driver.h | 27 - lamport/gpu/metal/lamport_driver.mm | 78 - lamport/gpu/wgsl/lamport.wgsl | 100 - lamport/gpu/wgsl/lamport_wgsl_oracle.cpp | 91 - math/ntt/cuda/lattice_ring.cu | 782 -------- math/ntt/cuda/lattice_ring_cuda.cpp | 53 - math/ntt/cuda/lattice_ring_cuda.hpp | 58 - math/ntt/cuda/lattice_ring_driver.h | 35 +- mldsa/gpu/cuda/mldsa.cu | 247 --- mldsa/gpu/metal/mldsa_batch.metal | 369 ---- mldsa/gpu/metal/mldsa_batch_driver.mm | 222 --- mldsa/gpu/wgsl/mldsa.wgsl | 124 -- mlkem/gpu/cuda/mlkem.cu | 258 --- mlkem/gpu/metal/mlkem_batch.metal | 355 ---- mlkem/gpu/metal/mlkem_batch_driver.mm | 207 -- mlkem/gpu/wgsl/mlkem.wgsl | 198 -- modexp/gpu/cuda/modexp_karatsuba.cu | 104 - modexp/gpu/metal/modexp_karatsuba.metal | 151 -- modexp/gpu/metal/modular.metal | 327 ---- modexp/gpu/wgsl/modexp_karatsuba.wgsl | 114 -- ntt/gpu/cuda/four_step_ntt.cu | 452 ----- ntt/gpu/cuda/ntt.cu | 242 --- ntt/gpu/cuda/ntt_kernels.cu | 479 ----- ntt/gpu/cuda/ntt_large.cu | 48 - ntt/gpu/cuda/ntt_metal_kernel.cu | 208 -- ntt/gpu/cuda/ntt_unified_memory.cu | 365 ---- ntt/gpu/cuda/twiddle_cache.cu | 288 --- ntt/gpu/metal/four_step_ntt.metal | 963 ---------- ntt/gpu/metal/ntt.metal | 237 --- ntt/gpu/metal/ntt_kernels.metal | 572 ------ ntt/gpu/metal/ntt_large.metal | 121 -- ntt/gpu/metal/ntt_large_driver.cpp | 30 - ntt/gpu/metal/ntt_metal_kernel.metal | 373 ---- ntt/gpu/metal/ntt_unified_memory.metal | 746 -------- ntt/gpu/metal/twiddle_cache.metal | 
545 ------ ntt/gpu/wgsl/four_step_ntt.wgsl | 143 -- ntt/gpu/wgsl/ntt.wgsl | 155 -- ntt/gpu/wgsl/ntt_kernels.wgsl | 211 --- ntt/gpu/wgsl/ntt_large.wgsl | 77 - ntt/gpu/wgsl/ntt_large_driver.cpp | 26 - ntt/gpu/wgsl/ntt_metal_kernel.wgsl | 146 -- ntt/gpu/wgsl/ntt_unified_memory.wgsl | 142 -- ntt/gpu/wgsl/twiddle_cache.wgsl | 359 ---- pedersen/CMakeLists.txt | 60 +- pedersen/gpu/cuda/pedersen.cu | 562 ------ pedersen/gpu/cuda/pedersen_driver_cuda.cpp | 110 -- pedersen/gpu/cuda/pedersen_driver_cuda.h | 41 - pedersen/gpu/cuda/pedersen_tree.cu | 491 ----- pedersen/gpu/cuda/pedersen_tree_driver.cpp | 82 - pedersen/gpu/cuda/pedersen_tree_driver.h | 36 - pedersen/gpu/metal/pedersen.metal | 663 ------- pedersen/gpu/metal/pedersen_driver.h | 38 - pedersen/gpu/metal/pedersen_driver.mm | 141 -- pedersen/gpu/metal/pedersen_tree.metal | 478 ----- pedersen/gpu/metal/pedersen_tree_driver.h | 42 - pedersen/gpu/metal/pedersen_tree_driver.mm | 116 -- pedersen/gpu/wgsl/pedersen.wgsl | 560 ------ pedersen/gpu/wgsl/pedersen_driver_wgpu.cpp | 304 --- pedersen/gpu/wgsl/pedersen_driver_wgpu.h | 30 - pedersen/gpu/wgsl/pedersen_tree.wgsl | 535 ------ pedersen/gpu/wgsl/pedersen_tree_driver.cpp | 271 --- pedersen/gpu/wgsl/pedersen_tree_driver.h | 33 - poly_mul/gpu/cuda/poly_mul.cu | 287 --- poly_mul/gpu/metal/poly_mul.metal | 351 ---- poly_mul/gpu/metal/poly_mul_batch.metal | 66 - poly_mul/gpu/metal/poly_mul_batch_driver.mm | 89 - poly_mul/gpu/wgsl/poly_mul.wgsl | 165 -- poseidon/gpu/cuda/poseidon2_bn254.cu | 376 ---- poseidon/gpu/cuda/poseidon2_driver.h | 38 - poseidon/gpu/metal/attestation.metal | 459 ----- poseidon/gpu/metal/goldilocks.metal | 412 ---- poseidon/gpu/metal/poseidon.metal | 426 ----- poseidon/gpu/metal/poseidon2_bn254.metal | 327 ---- poseidon/gpu/metal/poseidon2_driver.h | 41 - poseidon/gpu/metal/poseidon2_driver.mm | 80 - poseidon/gpu/metal/poseidon2_t2_batch.metal | 380 ---- .../gpu/metal/poseidon2_t2_batch_driver.mm | 74 - poseidon/gpu/wgsl/poseidon2_bn254.wgsl | 467 ----- 
poseidon/gpu/wgsl/poseidon2_driver.cpp | 350 ---- poseidon/gpu/wgsl/poseidon2_driver.h | 37 - ringtail/gpu/cuda/ringtail.cu | 240 --- ringtail/gpu/metal/ringtail.metal | 262 --- ringtail/gpu/metal/ringtail_ops.metal | 378 ---- ringtail/gpu/metal/ringtail_sign.metal | 568 ------ ringtail/gpu/metal/ringtail_verify.metal | 569 ------ ringtail/gpu/wgsl/ringtail.wgsl | 147 -- ripemd160/gpu/cuda/ripemd160.cu | 191 -- ripemd160/gpu/cuda/ripemd160_driver.cpp | 117 -- ripemd160/gpu/cuda/ripemd160_driver.h | 39 - ripemd160/gpu/metal/ripemd160_batch.metal | 172 -- ripemd160/gpu/metal/ripemd160_batch_driver.mm | 104 - ripemd160/gpu/wgsl/ripemd160.wgsl | 199 -- ripemd160/gpu/wgsl/ripemd160_driver_wgpu.cpp | 284 --- ripemd160/gpu/wgsl/ripemd160_driver_wgpu.h | 36 - secp256k1/gpu/cuda/secp256k1.cu | 685 ------- secp256k1/gpu/cuda/secp256k1_batch_inv.cu | 330 ---- .../gpu/cuda/secp256k1_batch_inv_driver.cu | 30 - .../cuda/secp256k1_first_party_cuda_driver.cu | 134 -- secp256k1/gpu/cuda/secp256k1_recover.cu | 685 ------- secp256k1/gpu/metal/secp256k1.metal | 561 ------ secp256k1/gpu/metal/secp256k1_authored.metal | 1397 -------------- secp256k1/gpu/metal/secp256k1_batch_inv.metal | 211 --- .../gpu/metal/secp256k1_batch_inv_driver.mm | 92 - secp256k1/gpu/metal/secp256k1_driver.h | 527 ------ secp256k1/gpu/metal/secp256k1_driver.mm | 1087 ----------- .../gpu/metal/secp256k1_first_party_driver.mm | 121 -- secp256k1/gpu/metal/secp256k1_recover.metal | 852 --------- secp256k1/gpu/wgsl/secp256k1.wgsl | 695 ------- secp256k1/gpu/wgsl/secp256k1_batch_inv.wgsl | 393 ---- .../gpu/wgsl/secp256k1_batch_inv_driver.cpp | 186 -- secp256k1/gpu/wgsl/secp256k1_recover.wgsl | 695 ------- sha256/gpu/cuda/sha256.cu | 148 -- sha256/gpu/cuda/sha256_driver.cpp | 117 -- sha256/gpu/cuda/sha256_driver.h | 40 - sha256/gpu/metal/sha256_batch.metal | 135 -- sha256/gpu/metal/sha256_batch_driver.mm | 111 -- sha256/gpu/wgsl/sha256.wgsl | 196 -- sha256/gpu/wgsl/sha256_driver_wgpu.cpp | 284 --- 
sha256/gpu/wgsl/sha256_driver_wgpu.h | 36 - slhdsa/gpu/cuda/slhdsa.cu | 376 ---- slhdsa/gpu/metal/slhdsa.metal | 415 ---- slhdsa/gpu/metal/slhdsa_driver.h | 249 --- slhdsa/gpu/metal/slhdsa_driver.mm | 204 -- slhdsa/gpu/wgsl/slhdsa.wgsl | 231 --- sr25519/gpu/cuda/sr25519.cu | 393 ---- sr25519/gpu/metal/sr25519.metal | 404 ---- sr25519/gpu/wgsl/sr25519.wgsl | 95 - 305 files changed, 37 insertions(+), 72556 deletions(-) delete mode 100644 aead/gpu/cuda/aead_driver_cuda.cpp delete mode 100644 aead/gpu/cuda/aead_driver_cuda.h delete mode 100644 aead/gpu/cuda/aes_gcm.cu delete mode 100644 aead/gpu/cuda/chacha20_poly1305.cu delete mode 100644 aead/gpu/metal/aead_batch.metal delete mode 100644 aead/gpu/metal/aead_batch_driver.mm delete mode 100644 aead/gpu/metal/aes_gcm.metal delete mode 100644 aead/gpu/metal/aes_gcm_driver.mm delete mode 100644 aead/gpu/wgsl/aead_driver_wgpu.cpp delete mode 100644 aead/gpu/wgsl/aead_driver_wgpu.h delete mode 100644 aead/gpu/wgsl/aes_gcm.wgsl delete mode 100644 aead/gpu/wgsl/chacha20_poly1305.wgsl delete mode 100644 banderwagon/gpu/cuda/banderwagon.cu delete mode 100644 banderwagon/gpu/cuda/banderwagon_driver.cpp delete mode 100644 banderwagon/gpu/cuda/banderwagon_driver.h delete mode 100644 banderwagon/gpu/metal/banderwagon.metal delete mode 100644 banderwagon/gpu/metal/banderwagon_driver.h delete mode 100644 banderwagon/gpu/metal/banderwagon_driver.mm delete mode 100644 banderwagon/gpu/metal/banderwagon_msm.metal delete mode 100644 banderwagon/gpu/wgsl/banderwagon.wgsl delete mode 100644 banderwagon/gpu/wgsl/banderwagon_driver.cpp delete mode 100644 banderwagon/gpu/wgsl/banderwagon_driver.h delete mode 100644 blake2b/gpu/cuda/blake2b.cu delete mode 100644 blake2b/gpu/metal/blake2b_batch.metal delete mode 100644 blake2b/gpu/metal/blake2b_batch_driver.mm delete mode 100644 blake2b/gpu/wgsl/blake2b.wgsl delete mode 100644 blake2b/gpu/wgsl/blake2b_wgsl_host.cpp delete mode 100644 blake3/gpu/cuda/blake3.cu delete mode 100644 
blake3/gpu/metal/blake3.metal delete mode 100644 blake3/gpu/metal/blake3_authored.metal delete mode 100644 blake3/gpu/metal/blake3_batch.metal delete mode 100644 blake3/gpu/metal/blake3_batch_driver.mm delete mode 100644 blake3/gpu/metal/blake3_driver.h delete mode 100644 blake3/gpu/metal/blake3_driver.mm delete mode 100644 blake3/gpu/wgsl/blake3.wgsl delete mode 100644 bls/gpu/cuda/bls.cu delete mode 100644 bls/gpu/cuda/bls_combined_miller.cu delete mode 100644 bls/gpu/cuda/bls_combined_miller_driver.cpp delete mode 100644 bls/gpu/cuda/bls_combined_miller_driver.h delete mode 100644 bls/gpu/cuda/bls_driver_cuda.cpp delete mode 100644 bls/gpu/cuda/bls_driver_cuda.h delete mode 100644 bls/gpu/cuda/bls_final_exp.cu delete mode 100644 bls/gpu/cuda/bls_fp12.cu delete mode 100644 bls/gpu/cuda/bls_fp12.cuh delete mode 100644 bls/gpu/cuda/bls_fp2.cu delete mode 100644 bls/gpu/cuda/bls_fp2.cuh delete mode 100644 bls/gpu/cuda/bls_fp6.cu delete mode 100644 bls/gpu/cuda/bls_fp6.cuh delete mode 100644 bls/gpu/cuda/bls_fp_ops.cuh delete mode 100644 bls/gpu/cuda/bls_g2.cu delete mode 100644 bls/gpu/cuda/bls_g2.cuh delete mode 100644 bls/gpu/cuda/bls_miller.cu delete mode 100644 bls/gpu/cuda/bls_miller.cuh delete mode 100644 bls/gpu/cuda/bls_pairing.cu delete mode 100644 bls/gpu/metal/bls.metal delete mode 100644 bls/gpu/metal/bls_authored.metal delete mode 100644 bls/gpu/metal/bls_combined_miller.metal delete mode 100644 bls/gpu/metal/bls_combined_miller_driver.h delete mode 100644 bls/gpu/metal/bls_combined_miller_driver.mm delete mode 100644 bls/gpu/metal/bls_driver.h delete mode 100644 bls/gpu/metal/bls_driver.mm delete mode 100644 bls/gpu/metal/bls_final_exp.metal delete mode 100644 bls/gpu/metal/bls_fp12.metal delete mode 100644 bls/gpu/metal/bls_fp2.metal delete mode 100644 bls/gpu/metal/bls_fp6.metal delete mode 100644 bls/gpu/metal/bls_fp_ops.h.metal delete mode 100644 bls/gpu/metal/bls_g2.metal delete mode 100644 bls/gpu/metal/bls_miller.metal delete mode 100644 
bls/gpu/metal/bls_pairing.metal delete mode 100644 bls/gpu/metal/msm.metal delete mode 100644 bls/gpu/wgsl/bls.wgsl delete mode 100644 bls/gpu/wgsl/bls_combined_miller.wgsl delete mode 100644 bls/gpu/wgsl/bls_driver_wgpu.cpp delete mode 100644 bls/gpu/wgsl/bls_driver_wgpu.h delete mode 100644 bls/gpu/wgsl/bls_fp12.wgsl delete mode 100644 bls/gpu/wgsl/bls_fp2.wgsl delete mode 100644 bls/gpu/wgsl/bls_fp6.wgsl delete mode 100644 bls/gpu/wgsl/bls_fp_ops.wgsl delete mode 100644 bls/gpu/wgsl/bls_fp_tower_kernels.wgsl delete mode 100644 bn254/gpu/cuda/bn254.cu delete mode 100644 bn254/gpu/cuda/bn254_driver_cuda.cpp delete mode 100644 bn254/gpu/cuda/bn254_driver_cuda.h delete mode 100644 bn254/gpu/cuda/bn254_pairing.cuh delete mode 100644 bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh delete mode 100644 bn254/gpu/metal/bn254.metal delete mode 100644 bn254/gpu/metal/zk_metal.h delete mode 100644 bn254/gpu/metal/zk_metal.mm delete mode 100644 bn254/gpu/wgsl/bn254.wgsl delete mode 100644 bn254/gpu/wgsl/bn254_driver_wgpu.cpp delete mode 100644 bn254/gpu/wgsl/bn254_driver_wgpu.h delete mode 100644 cggmp21/gpu/cuda/cggmp21.cu delete mode 100644 cggmp21/gpu/cuda/cggmp21_presign.cu delete mode 100644 cggmp21/gpu/metal/cggmp21.metal delete mode 100644 cggmp21/gpu/metal/cggmp21_presign.metal delete mode 100644 cggmp21/gpu/wgsl/cggmp21.wgsl delete mode 100644 cggmp21/gpu/wgsl/cggmp21_presign.wgsl delete mode 100644 ed25519/gpu/cuda/ed25519.cu delete mode 100644 ed25519/gpu/metal/ed25519.metal delete mode 100644 ed25519/gpu/metal/ed25519_batch.metal delete mode 100644 ed25519/gpu/metal/ed25519_batch_driver.mm delete mode 100644 ed25519/gpu/wgsl/ed25519.wgsl delete mode 100644 evm256/gpu/cuda/evm256.cu delete mode 100644 evm256/gpu/metal/evm256.metal delete mode 100644 evm256/gpu/wgsl/evm256.wgsl delete mode 100644 frost/gpu/cuda/frost.cu delete mode 100644 frost/gpu/cuda/frost_presign.cu delete mode 100644 frost/gpu/metal/frost.metal delete mode 100644 
frost/gpu/metal/frost_aggregate.metal delete mode 100644 frost/gpu/metal/frost_nonce.metal delete mode 100644 frost/gpu/metal/frost_presign.metal delete mode 100644 frost/gpu/metal/shamir_interpolate.metal delete mode 100644 frost/gpu/wgsl/frost.wgsl delete mode 100644 frost/gpu/wgsl/frost_presign.wgsl delete mode 100644 gpukit/gpu/cuda/batch_inversion.cu delete mode 100644 gpukit/gpu/cuda/compaction.cu delete mode 100644 gpukit/gpu/cuda/merkle_compose.cu delete mode 100644 gpukit/gpu/cuda/multi_pippenger.cu delete mode 100644 gpukit/gpu/cuda/ntt.cu delete mode 100644 gpukit/gpu/cuda/prefix_sum.cu delete mode 100644 gpukit/gpu/cuda/radix_sort.cu delete mode 100644 gpukit/gpu/cuda/transcript_root.cu delete mode 100644 gpukit/gpu/curve_traits/banderwagon_traits.h.metal delete mode 100644 gpukit/gpu/curve_traits/bls12_381_g1_traits.h.metal delete mode 100644 gpukit/gpu/curve_traits/bn254_g1_traits.h.metal delete mode 100644 gpukit/gpu/curve_traits/secp256k1_traits.h.metal delete mode 100644 gpukit/gpu/metal/batch_inversion.metal delete mode 100644 gpukit/gpu/metal/batch_inversion_driver.mm delete mode 100644 gpukit/gpu/metal/compaction.metal delete mode 100644 gpukit/gpu/metal/compaction_driver.mm delete mode 100644 gpukit/gpu/metal/merkle_compose.metal delete mode 100644 gpukit/gpu/metal/merkle_compose_driver.mm delete mode 100644 gpukit/gpu/metal/multi_pippenger.metal delete mode 100644 gpukit/gpu/metal/multi_pippenger_banderwagon.metal delete mode 100644 gpukit/gpu/metal/multi_pippenger_bls12_381_g1.metal delete mode 100644 gpukit/gpu/metal/multi_pippenger_bn254_g1.metal delete mode 100644 gpukit/gpu/metal/multi_pippenger_driver.mm delete mode 100644 gpukit/gpu/metal/multi_pippenger_secp256k1.metal delete mode 100644 gpukit/gpu/metal/ntt.metal delete mode 100644 gpukit/gpu/metal/ntt_driver.mm delete mode 100644 gpukit/gpu/metal/prefix_sum.metal delete mode 100644 gpukit/gpu/metal/prefix_sum_driver.mm delete mode 100644 gpukit/gpu/metal/radix_sort.metal delete mode 
100644 gpukit/gpu/metal/radix_sort_driver.mm delete mode 100644 gpukit/gpu/metal/transcript_root.metal delete mode 100644 gpukit/gpu/metal/transcript_root_driver.mm delete mode 100644 gpukit/gpu/wgsl/batch_inversion.wgsl delete mode 100644 gpukit/gpu/wgsl/batch_inversion_driver.cpp delete mode 100644 gpukit/gpu/wgsl/compaction.wgsl delete mode 100644 gpukit/gpu/wgsl/compaction_driver.cpp delete mode 100644 gpukit/gpu/wgsl/merkle_compose.wgsl delete mode 100644 gpukit/gpu/wgsl/merkle_compose_driver.cpp delete mode 100644 gpukit/gpu/wgsl/multi_pippenger.wgsl delete mode 100644 gpukit/gpu/wgsl/multi_pippenger_driver.cpp delete mode 100644 gpukit/gpu/wgsl/ntt.wgsl delete mode 100644 gpukit/gpu/wgsl/ntt_driver.cpp delete mode 100644 gpukit/gpu/wgsl/prefix_sum.wgsl delete mode 100644 gpukit/gpu/wgsl/prefix_sum_driver.cpp delete mode 100644 gpukit/gpu/wgsl/radix_sort.wgsl delete mode 100644 gpukit/gpu/wgsl/radix_sort_driver.cpp delete mode 100644 gpukit/gpu/wgsl/transcript_root.wgsl delete mode 100644 gpukit/gpu/wgsl/transcript_root_driver.cpp delete mode 100644 ipa/gpu/metal/ipa_driver.h delete mode 100644 ipa/gpu/metal/ipa_driver.mm delete mode 100644 keccak/gpu/cuda/keccak.cu delete mode 100644 keccak/gpu/metal/keccak.metal delete mode 100644 keccak/gpu/metal/keccak_batch.metal delete mode 100644 keccak/gpu/wgsl/keccak.wgsl delete mode 100644 kzg/gpu/cuda/kzg.cu delete mode 100644 kzg/gpu/cuda/kzg_driver_cuda.cpp delete mode 100644 kzg/gpu/cuda/kzg_driver_cuda.h delete mode 100644 kzg/gpu/metal/kzg.metal delete mode 100644 kzg/gpu/wgsl/kzg.wgsl delete mode 100644 kzg/gpu/wgsl/kzg_driver_wgpu.cpp delete mode 100644 kzg/gpu/wgsl/kzg_driver_wgpu.h delete mode 100644 lamport/gpu/cuda/lamport.cu delete mode 100644 lamport/gpu/cuda/lamport_cuda_oracle.cpp delete mode 100644 lamport/gpu/metal/lamport_batch.metal delete mode 100644 lamport/gpu/metal/lamport_batch_driver.mm delete mode 100644 lamport/gpu/metal/lamport_driver.h delete mode 100644 
lamport/gpu/metal/lamport_driver.mm delete mode 100644 lamport/gpu/wgsl/lamport.wgsl delete mode 100644 lamport/gpu/wgsl/lamport_wgsl_oracle.cpp delete mode 100644 math/ntt/cuda/lattice_ring.cu delete mode 100644 math/ntt/cuda/lattice_ring_cuda.cpp delete mode 100644 math/ntt/cuda/lattice_ring_cuda.hpp delete mode 100644 mldsa/gpu/cuda/mldsa.cu delete mode 100644 mldsa/gpu/metal/mldsa_batch.metal delete mode 100644 mldsa/gpu/metal/mldsa_batch_driver.mm delete mode 100644 mldsa/gpu/wgsl/mldsa.wgsl delete mode 100644 mlkem/gpu/cuda/mlkem.cu delete mode 100644 mlkem/gpu/metal/mlkem_batch.metal delete mode 100644 mlkem/gpu/metal/mlkem_batch_driver.mm delete mode 100644 mlkem/gpu/wgsl/mlkem.wgsl delete mode 100644 modexp/gpu/cuda/modexp_karatsuba.cu delete mode 100644 modexp/gpu/metal/modexp_karatsuba.metal delete mode 100644 modexp/gpu/metal/modular.metal delete mode 100644 modexp/gpu/wgsl/modexp_karatsuba.wgsl delete mode 100644 ntt/gpu/cuda/four_step_ntt.cu delete mode 100644 ntt/gpu/cuda/ntt.cu delete mode 100644 ntt/gpu/cuda/ntt_kernels.cu delete mode 100644 ntt/gpu/cuda/ntt_large.cu delete mode 100644 ntt/gpu/cuda/ntt_metal_kernel.cu delete mode 100644 ntt/gpu/cuda/ntt_unified_memory.cu delete mode 100644 ntt/gpu/cuda/twiddle_cache.cu delete mode 100644 ntt/gpu/metal/four_step_ntt.metal delete mode 100644 ntt/gpu/metal/ntt.metal delete mode 100644 ntt/gpu/metal/ntt_kernels.metal delete mode 100644 ntt/gpu/metal/ntt_large.metal delete mode 100644 ntt/gpu/metal/ntt_large_driver.cpp delete mode 100644 ntt/gpu/metal/ntt_metal_kernel.metal delete mode 100644 ntt/gpu/metal/ntt_unified_memory.metal delete mode 100644 ntt/gpu/metal/twiddle_cache.metal delete mode 100644 ntt/gpu/wgsl/four_step_ntt.wgsl delete mode 100644 ntt/gpu/wgsl/ntt.wgsl delete mode 100644 ntt/gpu/wgsl/ntt_kernels.wgsl delete mode 100644 ntt/gpu/wgsl/ntt_large.wgsl delete mode 100644 ntt/gpu/wgsl/ntt_large_driver.cpp delete mode 100644 ntt/gpu/wgsl/ntt_metal_kernel.wgsl delete mode 100644 
ntt/gpu/wgsl/ntt_unified_memory.wgsl delete mode 100644 ntt/gpu/wgsl/twiddle_cache.wgsl delete mode 100644 pedersen/gpu/cuda/pedersen.cu delete mode 100644 pedersen/gpu/cuda/pedersen_driver_cuda.cpp delete mode 100644 pedersen/gpu/cuda/pedersen_driver_cuda.h delete mode 100644 pedersen/gpu/cuda/pedersen_tree.cu delete mode 100644 pedersen/gpu/cuda/pedersen_tree_driver.cpp delete mode 100644 pedersen/gpu/cuda/pedersen_tree_driver.h delete mode 100644 pedersen/gpu/metal/pedersen.metal delete mode 100644 pedersen/gpu/metal/pedersen_driver.h delete mode 100644 pedersen/gpu/metal/pedersen_driver.mm delete mode 100644 pedersen/gpu/metal/pedersen_tree.metal delete mode 100644 pedersen/gpu/metal/pedersen_tree_driver.h delete mode 100644 pedersen/gpu/metal/pedersen_tree_driver.mm delete mode 100644 pedersen/gpu/wgsl/pedersen.wgsl delete mode 100644 pedersen/gpu/wgsl/pedersen_driver_wgpu.cpp delete mode 100644 pedersen/gpu/wgsl/pedersen_driver_wgpu.h delete mode 100644 pedersen/gpu/wgsl/pedersen_tree.wgsl delete mode 100644 pedersen/gpu/wgsl/pedersen_tree_driver.cpp delete mode 100644 pedersen/gpu/wgsl/pedersen_tree_driver.h delete mode 100644 poly_mul/gpu/cuda/poly_mul.cu delete mode 100644 poly_mul/gpu/metal/poly_mul.metal delete mode 100644 poly_mul/gpu/metal/poly_mul_batch.metal delete mode 100644 poly_mul/gpu/metal/poly_mul_batch_driver.mm delete mode 100644 poly_mul/gpu/wgsl/poly_mul.wgsl delete mode 100644 poseidon/gpu/cuda/poseidon2_bn254.cu delete mode 100644 poseidon/gpu/cuda/poseidon2_driver.h delete mode 100644 poseidon/gpu/metal/attestation.metal delete mode 100644 poseidon/gpu/metal/goldilocks.metal delete mode 100644 poseidon/gpu/metal/poseidon.metal delete mode 100644 poseidon/gpu/metal/poseidon2_bn254.metal delete mode 100644 poseidon/gpu/metal/poseidon2_driver.h delete mode 100644 poseidon/gpu/metal/poseidon2_driver.mm delete mode 100644 poseidon/gpu/metal/poseidon2_t2_batch.metal delete mode 100644 poseidon/gpu/metal/poseidon2_t2_batch_driver.mm delete 
mode 100644 poseidon/gpu/wgsl/poseidon2_bn254.wgsl delete mode 100644 poseidon/gpu/wgsl/poseidon2_driver.cpp delete mode 100644 poseidon/gpu/wgsl/poseidon2_driver.h delete mode 100644 ringtail/gpu/cuda/ringtail.cu delete mode 100644 ringtail/gpu/metal/ringtail.metal delete mode 100644 ringtail/gpu/metal/ringtail_ops.metal delete mode 100644 ringtail/gpu/metal/ringtail_sign.metal delete mode 100644 ringtail/gpu/metal/ringtail_verify.metal delete mode 100644 ringtail/gpu/wgsl/ringtail.wgsl delete mode 100644 ripemd160/gpu/cuda/ripemd160.cu delete mode 100644 ripemd160/gpu/cuda/ripemd160_driver.cpp delete mode 100644 ripemd160/gpu/cuda/ripemd160_driver.h delete mode 100644 ripemd160/gpu/metal/ripemd160_batch.metal delete mode 100644 ripemd160/gpu/metal/ripemd160_batch_driver.mm delete mode 100644 ripemd160/gpu/wgsl/ripemd160.wgsl delete mode 100644 ripemd160/gpu/wgsl/ripemd160_driver_wgpu.cpp delete mode 100644 ripemd160/gpu/wgsl/ripemd160_driver_wgpu.h delete mode 100644 secp256k1/gpu/cuda/secp256k1.cu delete mode 100644 secp256k1/gpu/cuda/secp256k1_batch_inv.cu delete mode 100644 secp256k1/gpu/cuda/secp256k1_batch_inv_driver.cu delete mode 100644 secp256k1/gpu/cuda/secp256k1_first_party_cuda_driver.cu delete mode 100644 secp256k1/gpu/cuda/secp256k1_recover.cu delete mode 100644 secp256k1/gpu/metal/secp256k1.metal delete mode 100644 secp256k1/gpu/metal/secp256k1_authored.metal delete mode 100644 secp256k1/gpu/metal/secp256k1_batch_inv.metal delete mode 100644 secp256k1/gpu/metal/secp256k1_batch_inv_driver.mm delete mode 100644 secp256k1/gpu/metal/secp256k1_driver.h delete mode 100644 secp256k1/gpu/metal/secp256k1_driver.mm delete mode 100644 secp256k1/gpu/metal/secp256k1_first_party_driver.mm delete mode 100644 secp256k1/gpu/metal/secp256k1_recover.metal delete mode 100644 secp256k1/gpu/wgsl/secp256k1.wgsl delete mode 100644 secp256k1/gpu/wgsl/secp256k1_batch_inv.wgsl delete mode 100644 secp256k1/gpu/wgsl/secp256k1_batch_inv_driver.cpp delete mode 100644 
secp256k1/gpu/wgsl/secp256k1_recover.wgsl delete mode 100644 sha256/gpu/cuda/sha256.cu delete mode 100644 sha256/gpu/cuda/sha256_driver.cpp delete mode 100644 sha256/gpu/cuda/sha256_driver.h delete mode 100644 sha256/gpu/metal/sha256_batch.metal delete mode 100644 sha256/gpu/metal/sha256_batch_driver.mm delete mode 100644 sha256/gpu/wgsl/sha256.wgsl delete mode 100644 sha256/gpu/wgsl/sha256_driver_wgpu.cpp delete mode 100644 sha256/gpu/wgsl/sha256_driver_wgpu.h delete mode 100644 slhdsa/gpu/cuda/slhdsa.cu delete mode 100644 slhdsa/gpu/metal/slhdsa.metal delete mode 100644 slhdsa/gpu/metal/slhdsa_driver.h delete mode 100644 slhdsa/gpu/metal/slhdsa_driver.mm delete mode 100644 slhdsa/gpu/wgsl/slhdsa.wgsl delete mode 100644 sr25519/gpu/cuda/sr25519.cu delete mode 100644 sr25519/gpu/metal/sr25519.metal delete mode 100644 sr25519/gpu/wgsl/sr25519.wgsl diff --git a/aead/gpu/cuda/aead_driver_cuda.cpp b/aead/gpu/cuda/aead_driver_cuda.cpp deleted file mode 100644 index b88455b..0000000 --- a/aead/gpu/cuda/aead_driver_cuda.cpp +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for batched AEAD ciphers (ChaCha20-Poly1305 + -// AES-256-GCM). One dispatch per call; one thread per message. -// -// Build modes: -// 1. With CUDA toolkit (LUX_AEAD_HAVE_CUDA defined): launches the -// first-party kernels in chacha20_poly1305.cu / aes_gcm.cu. Output is -// byte-equal to lux::crypto::aead::chacha20_poly1305::encrypt() and -// lux::crypto::aead::aes_256_gcm::encrypt() in cpp/aead.cpp. -// -// 2. Without CUDA (LUX_AEAD_HAVE_CUDA not defined): stubs return -1. The -// determinism test prints a [skipped on Apple] banner. The CI runner -// with a real CUDA device sets LUX_AEAD_HAVE_CUDA and exercises the -// same byte-equality vectors used by the Metal harness. 
- -#include "aead_driver_cuda.h" - -#include -#include - -#ifdef LUX_AEAD_HAVE_CUDA -#include - -extern "C" { - -struct AeadJob; - -__global__ void chacha20_poly1305_jobs( - const AeadJob* jobs, - const uint8_t* keys, - const uint8_t* nonces, - const uint8_t* inputs_arena, - uint8_t* outputs_arena, - uint32_t n_jobs); - -__global__ void aes_gcm_jobs( - const AeadJob* jobs, - const uint8_t* keys, - const uint8_t* nonces, - const uint8_t* inputs_arena, - uint8_t* outputs_arena, - uint32_t n_jobs); - -} // extern "C" - -namespace { - -constexpr unsigned kThreadsPerBlock = 64u; - -unsigned grid_for(unsigned n) { - return (n + kThreadsPerBlock - 1u) / kThreadsPerBlock; -} - -// Common dispatch path: marshal host -> device, launch kernel, copy back. -// Returns 0 on success, negative on failure. Used for both ChaCha and AES. -template -int dispatch_aead( - const uint8_t* keys, size_t keys_bytes, // n * 32 - const uint8_t* nonces, size_t nonces_bytes, // n * 12 - const uint8_t* inputs_arena, size_t inputs_bytes, - const void* jobs, size_t jobs_bytes, - uint8_t* outputs, size_t outputs_bytes, - unsigned n, - Launcher launch) { - - // Substitute a single non-empty byte for empty inputs_arena (CUDA - // requires non-NULL for cudaMemcpy; matches Metal driver convention). - static const uint8_t kEmpty = 0; - const uint8_t* in_ptr = inputs_arena ? inputs_arena : &kEmpty; - size_t in_len = inputs_bytes > 0 ? 
inputs_bytes : 1; - - void *d_keys=nullptr, *d_nonces=nullptr, *d_in=nullptr, *d_jobs=nullptr, - *d_out=nullptr; - - auto cleanup = [&]() { - if (d_keys) cudaFree(d_keys); - if (d_nonces) cudaFree(d_nonces); - if (d_in) cudaFree(d_in); - if (d_jobs) cudaFree(d_jobs); - if (d_out) cudaFree(d_out); - }; - - if (cudaMalloc(&d_keys, keys_bytes) != cudaSuccess) { cleanup(); return -10; } - if (cudaMalloc(&d_nonces, nonces_bytes) != cudaSuccess) { cleanup(); return -11; } - if (cudaMalloc(&d_in, in_len) != cudaSuccess) { cleanup(); return -12; } - if (cudaMalloc(&d_jobs, jobs_bytes) != cudaSuccess) { cleanup(); return -13; } - if (cudaMalloc(&d_out, outputs_bytes)!= cudaSuccess) { cleanup(); return -14; } - - if (cudaMemcpy(d_keys, keys, keys_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { cleanup(); return -20; } - if (cudaMemcpy(d_nonces, nonces, nonces_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { cleanup(); return -21; } - if (cudaMemcpy(d_in, in_ptr, in_len, cudaMemcpyHostToDevice) != cudaSuccess) { cleanup(); return -22; } - if (cudaMemcpy(d_jobs, jobs, jobs_bytes, cudaMemcpyHostToDevice) != cudaSuccess) { cleanup(); return -23; } - if (cudaMemset(d_out, 0, outputs_bytes) != cudaSuccess) { cleanup(); return -24; } - - launch(grid_for(n), kThreadsPerBlock, - d_jobs, d_keys, d_nonces, d_in, d_out, n); - - if (cudaDeviceSynchronize() != cudaSuccess) { cleanup(); return -30; } - if (cudaMemcpy(outputs, d_out, outputs_bytes, cudaMemcpyDeviceToHost) != cudaSuccess) { cleanup(); return -31; } - - cleanup(); - return 0; -} - -} // namespace - -extern "C" { - -int lux_aead_cuda_available(void) { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 
1 : 0; -} - -int aead_chacha20poly1305_batch_cuda( - const uint8_t* keys, - const uint8_t* nonces, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const void* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len) { - if (n == 0) return 0; - if (!keys || !nonces || !jobs || !outputs_arena) return -1; - if (!lux_aead_cuda_available()) return -2; - - return dispatch_aead( - keys, n * 32, - nonces, n * 12, - inputs_arena, inputs_arena_len, - jobs, n * 32, // sizeof(AeadJob) = 8 * uint32_t = 32 bytes - outputs_arena, outputs_arena_len, - (unsigned)n, - [](unsigned grid, unsigned tg, - void* d_jobs, void* d_keys, void* d_nonces, void* d_in, void* d_out, - unsigned n) { - chacha20_poly1305_jobs<<>>( - static_cast(d_jobs), - static_cast(d_keys), - static_cast(d_nonces), - static_cast(d_in), - static_cast(d_out), - n); - }); -} - -int aead_aes_256_gcm_batch_cuda( - const uint8_t* keys, - const uint8_t* ivs, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const void* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len) { - if (n == 0) return 0; - if (!keys || !ivs || !jobs || !outputs_arena) return -1; - if (!lux_aead_cuda_available()) return -2; - - return dispatch_aead( - keys, n * 32, - ivs, n * 12, - inputs_arena, inputs_arena_len, - jobs, n * 32, - outputs_arena, outputs_arena_len, - (unsigned)n, - [](unsigned grid, unsigned tg, - void* d_jobs, void* d_keys, void* d_nonces, void* d_in, void* d_out, - unsigned n) { - aes_gcm_jobs<<>>( - static_cast(d_jobs), - static_cast(d_keys), - static_cast(d_nonces), - static_cast(d_in), - static_cast(d_out), - n); - }); -} - -} // extern "C" - -#else // LUX_AEAD_HAVE_CUDA not defined: stub mode - -extern "C" { - -int lux_aead_cuda_available(void) { return 0; } - -int aead_chacha20poly1305_batch_cuda( - const uint8_t*, const uint8_t*, const uint8_t*, size_t, - const void*, size_t, uint8_t*, size_t) { - return -1; -} - -int aead_aes_256_gcm_batch_cuda( - const uint8_t*, 
const uint8_t*, const uint8_t*, size_t, - const void*, size_t, uint8_t*, size_t) { - return -1; -} - -} // extern "C" - -#endif // LUX_AEAD_HAVE_CUDA diff --git a/aead/gpu/cuda/aead_driver_cuda.h b/aead/gpu/cuda/aead_driver_cuda.h deleted file mode 100644 index f969627..0000000 --- a/aead/gpu/cuda/aead_driver_cuda.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C ABI for the CUDA AEAD batch driver. The driver compiles in two -// modes: -// * LUX_AEAD_HAVE_CUDA defined -> real CUDA dispatch -// * not defined -> stub mode, every entry returns -1 -// -// All entry points return 0 on success, negative on failure (matching the -// Metal driver convention in gpu/metal/*_driver.mm). - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Returns 1 if a CUDA-capable device is reachable, 0 otherwise. -int lux_aead_cuda_available(void); - -// Encrypt n ChaCha20-Poly1305 messages in a single GPU dispatch. -// -// keys -- n * 32 bytes (ChaCha20 keys) -// nonces -- n * 12 bytes (RFC 8439 IETF nonces) -// inputs_arena -- packed (aad || plaintext) per message -// inputs_arena_len -- total bytes of inputs_arena (may be 0) -// jobs -- n AeadJob records (32 bytes each, see *.cu structs) -// n -- number of messages -// outputs_arena -- caller-allocated; receives (ciphertext || 16-byte tag) -// outputs_arena_len-- total capacity of outputs_arena -// -// Returns 0 on success. -int aead_chacha20poly1305_batch_cuda( - const uint8_t* keys, - const uint8_t* nonces, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const void* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len); - -// Encrypt n AES-256-GCM messages in a single GPU dispatch. -// Layout matches aead_chacha20poly1305_batch_cuda; nonces[] holds 12-byte -// IVs (NIST SP 800-38D 96-bit IV path). 
-int aead_aes_256_gcm_batch_cuda( - const uint8_t* keys, - const uint8_t* ivs, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const void* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len); - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/aead/gpu/cuda/aes_gcm.cu b/aead/gpu/cuda/aes_gcm.cu deleted file mode 100644 index d4abad4..0000000 --- a/aead/gpu/cuda/aes_gcm.cu +++ /dev/null @@ -1,427 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Batched AES-256-GCM (NIST SP 800-38D, 96-bit IV). One thread per -// (key, iv, aad, plaintext) message; output is byte-equal to -// lux::crypto::aead::aes_256_gcm::encrypt() in cpp/aead.cpp. -// -// Per-message fanout. Constant-time S-box (Boyar-Peralta) -- table-free, -// no data-dependent branches. Constant-time GHASH (128 iterations always). - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __host__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -extern "C" { - -struct AeadJob { - uint32_t aad_offset; - uint32_t aad_len; - uint32_t pt_offset; - uint32_t pt_len; - uint32_t ct_offset; - uint32_t tag_offset; - uint32_t key_offset; - uint32_t nonce_offset; -}; - -} // extern "C" - -// AES-256: 14 rounds, expanded round-key schedule = 240 bytes (15 round keys). -#define AES256_ROUNDS 14u -#define AES256_KS_BYTES 240u - -// --------------------------------------------------------------------------- -// AES S-box -- Boyar-Peralta circuit (J. Cryptol. 2010). Byte-for-byte port -// of lux::crypto::aead::aes::aes_sbox in cpp/aead.cpp. 
-// --------------------------------------------------------------------------- -__device__ static inline uint8_t aes_sbox(uint8_t x) { - uint8_t U0 = (x >> 7) & 1u; - uint8_t U1 = (x >> 6) & 1u; - uint8_t U2 = (x >> 5) & 1u; - uint8_t U3 = (x >> 4) & 1u; - uint8_t U4 = (x >> 3) & 1u; - uint8_t U5 = (x >> 2) & 1u; - uint8_t U6 = (x >> 1) & 1u; - uint8_t U7 = x & 1u; - - uint8_t T1 = U0 ^ U3; - uint8_t T2 = U0 ^ U5; - uint8_t T3 = U0 ^ U6; - uint8_t T4 = U3 ^ U5; - uint8_t T5 = U4 ^ U6; - uint8_t T6 = T1 ^ T5; - uint8_t T7 = U1 ^ U2; - uint8_t T8 = U7 ^ T6; - uint8_t T9 = U7 ^ T7; - uint8_t T10 = T6 ^ T7; - uint8_t T11 = U1 ^ U5; - uint8_t T12 = U2 ^ U5; - uint8_t T13 = T3 ^ T4; - uint8_t T14 = T6 ^ T11; - uint8_t T15 = T5 ^ T11; - uint8_t T16 = T5 ^ T12; - uint8_t T17 = T9 ^ T16; - uint8_t T18 = U3 ^ U7; - uint8_t T19 = T7 ^ T18; - uint8_t T20 = T1 ^ T19; - uint8_t T21 = U6 ^ U7; - uint8_t T22 = T7 ^ T21; - uint8_t T23 = T2 ^ T22; - uint8_t T24 = T2 ^ T10; - uint8_t T25 = T20 ^ T17; - uint8_t T26 = T3 ^ T16; - uint8_t T27 = T1 ^ T12; - - uint8_t M1 = T13 & T6; - uint8_t M2 = T23 & T8; - uint8_t M3 = T14 ^ M1; - uint8_t M4 = T19 & U7; - uint8_t M5 = M4 ^ M1; - uint8_t M6 = T3 & T16; - uint8_t M7 = T22 & T9; - uint8_t M8 = T26 ^ M6; - uint8_t M9 = T20 & T17; - uint8_t M10 = M9 ^ M6; - uint8_t M11 = T1 & T15; - uint8_t M12 = T4 & T27; - uint8_t M13 = M12 ^ M11; - uint8_t M14 = T2 & T10; - uint8_t M15 = M14 ^ M11; - uint8_t M16 = M3 ^ M2; - uint8_t M17 = M5 ^ T24; - uint8_t M18 = M8 ^ M7; - uint8_t M19 = M10 ^ M15; - uint8_t M20 = M16 ^ M13; - uint8_t M21 = M17 ^ M15; - uint8_t M22 = M18 ^ M13; - uint8_t M23 = M19 ^ T25; - uint8_t M24 = M22 ^ M23; - uint8_t M25 = M22 & M20; - uint8_t M26 = M21 ^ M25; - uint8_t M27 = M20 ^ M21; - uint8_t M28 = M23 ^ M25; - uint8_t M29 = M28 & M27; - uint8_t M30 = M26 & M24; - uint8_t M31 = M20 & M23; - uint8_t M32 = M27 & M31; - uint8_t M33 = M27 ^ M25; - uint8_t M34 = M21 & M22; - uint8_t M35 = M24 & M34; - uint8_t M36 = M24 ^ M25; - 
uint8_t M37 = M21 ^ M29; - uint8_t M38 = M32 ^ M33; - uint8_t M39 = M23 ^ M30; - uint8_t M40 = M35 ^ M36; - uint8_t M41 = M38 ^ M40; - uint8_t M42 = M37 ^ M39; - uint8_t M43 = M37 ^ M38; - uint8_t M44 = M39 ^ M40; - uint8_t M45 = M42 ^ M41; - uint8_t M46 = M44 & T6; - uint8_t M47 = M40 & T8; - uint8_t M48 = M39 & U7; - uint8_t M49 = M43 & T16; - uint8_t M50 = M38 & T9; - uint8_t M51 = M37 & T17; - uint8_t M52 = M42 & T15; - uint8_t M53 = M45 & T27; - uint8_t M54 = M41 & T10; - uint8_t M55 = M44 & T13; - uint8_t M56 = M40 & T23; - uint8_t M57 = M39 & T19; - uint8_t M58 = M43 & T3; - uint8_t M59 = M38 & T22; - uint8_t M60 = M37 & T20; - uint8_t M61 = M42 & T1; - uint8_t M62 = M45 & T4; - uint8_t M63 = M41 & T2; - - uint8_t L0 = M61 ^ M62; - uint8_t L1 = M50 ^ M56; - uint8_t L2 = M46 ^ M48; - uint8_t L3 = M47 ^ M55; - uint8_t L4 = M54 ^ M58; - uint8_t L5 = M49 ^ M61; - uint8_t L6 = M62 ^ L5; - uint8_t L7 = M46 ^ L3; - uint8_t L8 = M51 ^ M59; - uint8_t L9 = M52 ^ M53; - uint8_t L10 = M53 ^ L4; - uint8_t L11 = M60 ^ L2; - uint8_t L12 = M48 ^ M51; - uint8_t L13 = M50 ^ L0; - uint8_t L14 = M52 ^ M61; - uint8_t L15 = M55 ^ L1; - uint8_t L16 = M56 ^ L0; - uint8_t L17 = M57 ^ L1; - uint8_t L18 = M58 ^ L8; - uint8_t L19 = M63 ^ L4; - uint8_t L20 = L0 ^ L1; - uint8_t L21 = L1 ^ L7; - uint8_t L22 = L3 ^ L12; - uint8_t L23 = L18 ^ L2; - uint8_t L24 = L15 ^ L9; - uint8_t L25 = L6 ^ L10; - uint8_t L26 = L7 ^ L9; - uint8_t L27 = L8 ^ L10; - uint8_t L28 = L11 ^ L14; - uint8_t L29 = L11 ^ L17; - - uint8_t S0 = L6 ^ L24; - uint8_t S1 = L16 ^ L26; S1 ^= 1u; - uint8_t S2 = L19 ^ L28; S2 ^= 1u; - uint8_t S3 = L6 ^ L21; - uint8_t S4 = L20 ^ L22; - uint8_t S5 = L25 ^ L29; - uint8_t S6 = L13 ^ L27; S6 ^= 1u; - uint8_t S7 = L6 ^ L23; S7 ^= 1u; - - return (uint8_t)( - ((S0 & 1u) << 7) | - ((S1 & 1u) << 6) | - ((S2 & 1u) << 5) | - ((S3 & 1u) << 4) | - ((S4 & 1u) << 3) | - ((S5 & 1u) << 2) | - ((S6 & 1u) << 1) | - ( S7 & 1u)); -} - -// FIPS 197 round-constants for AES-256 (only Rcon[0..6] 
needed). -__device__ static const uint8_t RCON[7] = { - 0x01u, 0x02u, 0x04u, 0x08u, 0x10u, 0x20u, 0x40u -}; - -__device__ static inline uint8_t xtime(uint8_t x) { - return (uint8_t)((x << 1) ^ (((x >> 7) & 1u) * 0x1bu)); -} - -// --------------------------------------------------------------------------- -// AES-256 key expansion (FIPS 197 §5.2). Output: 240-byte round-key schedule. -// --------------------------------------------------------------------------- -__device__ static inline void aes256_expand_key(const uint8_t* key, uint8_t* rk) { - for (uint32_t i = 0u; i < 32u; ++i) rk[i] = key[i]; - - for (uint32_t i = 8u; i < 60u; ++i) { - uint8_t t0 = rk[(i - 1u) * 4u + 0u]; - uint8_t t1 = rk[(i - 1u) * 4u + 1u]; - uint8_t t2 = rk[(i - 1u) * 4u + 2u]; - uint8_t t3 = rk[(i - 1u) * 4u + 3u]; - - if ((i & 7u) == 0u) { - uint8_t r0 = t1, r1 = t2, r2 = t3, r3 = t0; - t0 = (uint8_t)(aes_sbox(r0) ^ RCON[(i / 8u) - 1u]); - t1 = aes_sbox(r1); - t2 = aes_sbox(r2); - t3 = aes_sbox(r3); - } else if ((i & 7u) == 4u) { - t0 = aes_sbox(t0); - t1 = aes_sbox(t1); - t2 = aes_sbox(t2); - t3 = aes_sbox(t3); - } - - rk[i * 4u + 0u] = (uint8_t)(rk[(i - 8u) * 4u + 0u] ^ t0); - rk[i * 4u + 1u] = (uint8_t)(rk[(i - 8u) * 4u + 1u] ^ t1); - rk[i * 4u + 2u] = (uint8_t)(rk[(i - 8u) * 4u + 2u] ^ t2); - rk[i * 4u + 3u] = (uint8_t)(rk[(i - 8u) * 4u + 3u] ^ t3); - } -} - -// --------------------------------------------------------------------------- -// AES-256 encrypt one 16-byte block (FIPS 197 §5.1). State layout column- -// major: s[c*4 + r]. Byte-for-byte equivalent to aes::encrypt_block in -// cpp/aead.cpp. 
-// --------------------------------------------------------------------------- -__device__ static inline void aes256_encrypt_block(const uint8_t* rk, - const uint8_t* in, - uint8_t* out) { - uint8_t s[16]; - for (uint32_t i = 0u; i < 16u; ++i) s[i] = (uint8_t)(in[i] ^ rk[i]); - - for (uint32_t round = 1u; round < AES256_ROUNDS; ++round) { - for (uint32_t i = 0u; i < 16u; ++i) s[i] = aes_sbox(s[i]); - - // ShiftRows (column-major: row r at indices r, r+4, r+8, r+12). - uint8_t t; - t = s[1]; s[1] = s[5]; s[5] = s[9]; s[9] = s[13]; s[13] = t; - t = s[2]; s[2] = s[10]; s[10] = t; - t = s[6]; s[6] = s[14]; s[14] = t; - t = s[15]; s[15] = s[11]; s[11] = s[7]; s[7] = s[3]; s[3] = t; - - // MixColumns. - for (uint32_t c = 0u; c < 4u; ++c) { - uint8_t a0 = s[c*4u + 0u]; - uint8_t a1 = s[c*4u + 1u]; - uint8_t a2 = s[c*4u + 2u]; - uint8_t a3 = s[c*4u + 3u]; - uint8_t x = a0 ^ a1 ^ a2 ^ a3; - uint8_t y0 = a0; - s[c*4u + 0u] = (uint8_t)(a0 ^ x ^ xtime(a0 ^ a1)); - s[c*4u + 1u] = (uint8_t)(a1 ^ x ^ xtime(a1 ^ a2)); - s[c*4u + 2u] = (uint8_t)(a2 ^ x ^ xtime(a2 ^ a3)); - s[c*4u + 3u] = (uint8_t)(a3 ^ x ^ xtime(a3 ^ y0)); - } - - for (uint32_t i = 0u; i < 16u; ++i) s[i] ^= rk[round * 16u + i]; - } - - // Final round (no MixColumns). - for (uint32_t i = 0u; i < 16u; ++i) s[i] = aes_sbox(s[i]); - { - uint8_t t; - t = s[1]; s[1] = s[5]; s[5] = s[9]; s[9] = s[13]; s[13] = t; - t = s[2]; s[2] = s[10]; s[10] = t; - t = s[6]; s[6] = s[14]; s[14] = t; - t = s[15]; s[15] = s[11]; s[11] = s[7]; s[7] = s[3]; s[3] = t; - } - for (uint32_t i = 0u; i < 16u; ++i) { - out[i] = (uint8_t)(s[i] ^ rk[AES256_ROUNDS * 16u + i]); - } -} - -// --------------------------------------------------------------------------- -// GHASH multiplication in GF(2^128) with reduction polynomial -// x^128 + x^7 + x^2 + x + 1 (NIST SP 800-38D §6.3). Bits are GCM-style -// (bit 0 = MSB of byte 0). Constant-time: 128 iterations always. 
-// --------------------------------------------------------------------------- -__device__ static inline void ghash_mul(uint8_t* z, const uint8_t* h) { - uint8_t v[16]; - for (uint32_t i = 0u; i < 16u; ++i) v[i] = h[i]; - uint8_t r[16]; - for (uint32_t i = 0u; i < 16u; ++i) r[i] = 0u; - - for (uint32_t i = 0u; i < 128u; ++i) { - uint8_t zbit = (uint8_t)((z[i >> 3] >> (7u - (i & 7u))) & 1u); - uint8_t mask = (uint8_t)(0u - (uint32_t)zbit); - for (uint32_t j = 0u; j < 16u; ++j) r[j] ^= (uint8_t)(v[j] & mask); - - uint8_t lsb = (uint8_t)(v[15] & 1u); - for (uint32_t j = 15u; j > 0u; --j) { - v[j] = (uint8_t)((v[j] >> 1) | ((v[j-1u] & 1u) << 7)); - } - v[0] >>= 1; - uint8_t rmask = (uint8_t)(0u - (uint32_t)lsb); - v[0] ^= (uint8_t)(0xe1u & rmask); - } - - for (uint32_t i = 0u; i < 16u; ++i) z[i] = r[i]; -} - -__device__ static inline void inc32(uint8_t* ctr) { - uint32_t c = ((uint32_t)ctr[12] << 24) | ((uint32_t)ctr[13] << 16) - | ((uint32_t)ctr[14] << 8) | (uint32_t)ctr[15]; - c += 1u; - ctr[12] = (uint8_t)(c >> 24); - ctr[13] = (uint8_t)(c >> 16); - ctr[14] = (uint8_t)(c >> 8); - ctr[15] = (uint8_t)(c); -} - -__device__ static inline void ghash_update(uint8_t* y, const uint8_t* h, - const uint8_t* arena, - uint32_t off, uint32_t len) { - uint32_t pos = 0u; - while (len - pos >= 16u) { - for (uint32_t i = 0u; i < 16u; ++i) y[i] ^= arena[off + pos + i]; - ghash_mul(y, h); - pos += 16u; - } - uint32_t rem = len - pos; - if (rem > 0u) { - uint8_t buf[16]; - for (uint32_t i = 0u; i < 16u; ++i) buf[i] = 0u; - for (uint32_t i = 0u; i < rem; ++i) buf[i] = arena[off + pos + i]; - for (uint32_t i = 0u; i < 16u; ++i) y[i] ^= buf[i]; - ghash_mul(y, h); - } -} - -// --------------------------------------------------------------------------- -// Kernel: one thread per AEAD job. AES-256-GCM seal. 
-// --------------------------------------------------------------------------- -extern "C" __global__ void aes_gcm_jobs( - const AeadJob* __restrict__ jobs, - const uint8_t* __restrict__ keys, - const uint8_t* __restrict__ nonces, - const uint8_t* __restrict__ inputs_arena, - uint8_t* __restrict__ outputs_arena, - uint32_t n_jobs) -{ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n_jobs) return; - - AeadJob job = jobs[gid]; - const uint8_t* key = keys + job.key_offset; - const uint8_t* iv = nonces + job.nonce_offset; - - // Key schedule. - uint8_t rk[AES256_KS_BYTES]; - aes256_expand_key(key, rk); - - // H = AES_K(0^128). - uint8_t H[16]; - { - uint8_t zero[16]; - for (uint32_t i = 0u; i < 16u; ++i) zero[i] = 0u; - aes256_encrypt_block(rk, zero, H); - } - - // J0 = IV || 0x00000001 (96-bit IV path). - uint8_t J0[16]; - for (uint32_t i = 0u; i < 12u; ++i) J0[i] = iv[i]; - J0[12] = 0u; J0[13] = 0u; J0[14] = 0u; J0[15] = 1u; - - // Encrypt plaintext under counter starting at inc32(J0). - { - uint8_t ctr[16]; - for (uint32_t i = 0u; i < 16u; ++i) ctr[i] = J0[i]; - inc32(ctr); - - uint32_t pos = 0u; - while (pos < job.pt_len) { - uint8_t ks[16]; - aes256_encrypt_block(rk, ctr, ks); - uint32_t take = job.pt_len - pos; - if (take > 16u) take = 16u; - for (uint32_t i = 0u; i < take; ++i) { - outputs_arena[job.ct_offset + pos + i] = - (uint8_t)(inputs_arena[job.pt_offset + pos + i] ^ ks[i]); - } - inc32(ctr); - pos += take; - } - } - - // GHASH over (aad || pad || ct || pad || lens_in_bits). - uint8_t Y[16]; - for (uint32_t i = 0u; i < 16u; ++i) Y[i] = 0u; - ghash_update(Y, H, inputs_arena, job.aad_offset, job.aad_len); - ghash_update(Y, H, outputs_arena, job.ct_offset, job.pt_len); - { - uint8_t lens[16]; - uint64_t la = (uint64_t)job.aad_len * 8ul; - uint64_t lc = (uint64_t)job.pt_len * 8ul; - // BE encoding. 
- for (uint32_t i = 0u; i < 8u; ++i) lens[i] = (uint8_t)(la >> (8u * (7u - i))); - for (uint32_t i = 0u; i < 8u; ++i) lens[8u + i] = (uint8_t)(lc >> (8u * (7u - i))); - for (uint32_t i = 0u; i < 16u; ++i) Y[i] ^= lens[i]; - ghash_mul(Y, H); - } - - // Tag = GHASH XOR AES_K(J0). - { - uint8_t s[16]; - aes256_encrypt_block(rk, J0, s); - for (uint32_t i = 0u; i < 16u; ++i) { - outputs_arena[job.tag_offset + i] = (uint8_t)(Y[i] ^ s[i]); - } - } -} diff --git a/aead/gpu/cuda/chacha20_poly1305.cu b/aead/gpu/cuda/chacha20_poly1305.cu deleted file mode 100644 index 7417c50..0000000 --- a/aead/gpu/cuda/chacha20_poly1305.cu +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Batched ChaCha20-Poly1305 (RFC 8439). One thread per (key, nonce, aad, -// plaintext) message; output is byte-equal to lux::crypto::aead:: -// chacha20_poly1305::encrypt() in cpp/aead.cpp. -// -// Per-message fanout: typical AEAD workload of many TLS-record-sized -// messages. Per-block fanout is a future kernel. Layout matches -// gpu/metal/aead_batch.metal exactly. - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __host__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -extern "C" { - -struct AeadJob { - uint32_t aad_offset; - uint32_t aad_len; - uint32_t pt_offset; - uint32_t pt_len; - uint32_t ct_offset; - uint32_t tag_offset; - uint32_t key_offset; - uint32_t nonce_offset; -}; - -} // extern "C" - -// --------------------------------------------------------------------------- -// ChaCha20 (RFC 8439 §2.3) -- direct port of metal/aead_batch.metal. 
-// --------------------------------------------------------------------------- - -__device__ static inline uint32_t rotl32_d(uint32_t x, uint32_t n) { - return (x << n) | (x >> (32u - n)); -} - -__device__ static inline void quarter(uint32_t& a, uint32_t& b, - uint32_t& c, uint32_t& d) { - a += b; d ^= a; d = rotl32_d(d, 16u); - c += d; b ^= c; b = rotl32_d(b, 12u); - a += b; d ^= a; d = rotl32_d(d, 8u); - c += d; b ^= c; b = rotl32_d(b, 7u); -} - -__device__ static inline uint32_t load32_le_g(const uint8_t* p) { - return (uint32_t)p[0] - | ((uint32_t)p[1] << 8) - | ((uint32_t)p[2] << 16) - | ((uint32_t)p[3] << 24); -} - -__device__ static inline void store32_le_t(uint8_t* p, uint32_t v) { - p[0] = (uint8_t)(v ); - p[1] = (uint8_t)(v >> 8); - p[2] = (uint8_t)(v >> 16); - p[3] = (uint8_t)(v >> 24); -} - -__device__ static inline void chacha20_block(const uint8_t* key, - const uint8_t* nonce, - uint32_t counter, - uint8_t* out64) { - uint32_t s[16]; - s[ 0] = 0x61707865u; s[ 1] = 0x3320646eu; - s[ 2] = 0x79622d32u; s[ 3] = 0x6b206574u; - s[ 4] = load32_le_g(key + 0); s[ 5] = load32_le_g(key + 4); - s[ 6] = load32_le_g(key + 8); s[ 7] = load32_le_g(key + 12); - s[ 8] = load32_le_g(key + 16); s[ 9] = load32_le_g(key + 20); - s[10] = load32_le_g(key + 24); s[11] = load32_le_g(key + 28); - s[12] = counter; - s[13] = load32_le_g(nonce + 0); - s[14] = load32_le_g(nonce + 4); - s[15] = load32_le_g(nonce + 8); - - uint32_t v[16]; - for (uint32_t i = 0; i < 16; ++i) v[i] = s[i]; - for (uint32_t i = 0; i < 10; ++i) { - quarter(v[0], v[4], v[ 8], v[12]); - quarter(v[1], v[5], v[ 9], v[13]); - quarter(v[2], v[6], v[10], v[14]); - quarter(v[3], v[7], v[11], v[15]); - quarter(v[0], v[5], v[10], v[15]); - quarter(v[1], v[6], v[11], v[12]); - quarter(v[2], v[7], v[ 8], v[13]); - quarter(v[3], v[4], v[ 9], v[14]); - } - for (uint32_t i = 0; i < 16; ++i) { - store32_le_t(out64 + i * 4, v[i] + s[i]); - } -} - -// 
--------------------------------------------------------------------------- -// Poly1305 (RFC 8439 §2.5, radix 2^26). -// --------------------------------------------------------------------------- - -struct Poly { - uint32_t r[5]; - uint32_t s[4]; - uint32_t h[5]; -}; - -__device__ static inline uint32_t load32_le_t(const uint8_t* p) { - return (uint32_t)p[0] - | ((uint32_t)p[1] << 8) - | ((uint32_t)p[2] << 16) - | ((uint32_t)p[3] << 24); -} - -__device__ static inline void poly_init(Poly& st, const uint8_t* key) { - uint32_t c0 = load32_le_t(key + 0) & 0x0fffffffu; - uint32_t c1 = load32_le_t(key + 4) & 0x0ffffffcu; - uint32_t c2 = load32_le_t(key + 8) & 0x0ffffffcu; - uint32_t c3 = load32_le_t(key + 12) & 0x0ffffffcu; - st.r[0] = c0 & 0x3ffffffu; - st.r[1] = ((c0 >> 26) | (c1 << 6)) & 0x3ffffffu; - st.r[2] = ((c1 >> 20) | (c2 << 12)) & 0x3ffffffu; - st.r[3] = ((c2 >> 14) | (c3 << 18)) & 0x3ffffffu; - st.r[4] = (c3 >> 8) & 0x3ffffffu; - - st.s[0] = load32_le_t(key + 16); - st.s[1] = load32_le_t(key + 20); - st.s[2] = load32_le_t(key + 24); - st.s[3] = load32_le_t(key + 28); - for (uint32_t i = 0; i < 5; ++i) st.h[i] = 0u; -} - -__device__ static inline void poly_block(Poly& st, const uint8_t* m, - uint32_t hibit) { - uint32_t t0 = load32_le_t(m + 0); - uint32_t t1 = load32_le_t(m + 4); - uint32_t t2 = load32_le_t(m + 8); - uint32_t t3 = load32_le_t(m + 12); - - uint64_t h0 = (uint64_t)st.h[0] + ( t0 & 0x3ffffffu); - uint64_t h1 = (uint64_t)st.h[1] + (((t0 >> 26) | (t1 << 6)) & 0x3ffffffu); - uint64_t h2 = (uint64_t)st.h[2] + (((t1 >> 20) | (t2 << 12)) & 0x3ffffffu); - uint64_t h3 = (uint64_t)st.h[3] + (((t2 >> 14) | (t3 << 18)) & 0x3ffffffu); - uint64_t h4 = (uint64_t)st.h[4] + ( (t3 >> 8) | (uint64_t)hibit); - - uint64_t r0 = st.r[0]; uint64_t r1 = st.r[1]; - uint64_t r2 = st.r[2]; uint64_t r3 = st.r[3]; - uint64_t r4 = st.r[4]; - uint64_t s1 = r1 * 5UL; uint64_t s2 = r2 * 5UL; - uint64_t s3 = r3 * 5UL; uint64_t s4 = r4 * 5UL; - - uint64_t d0 = h0*r0 + h1*s4 + 
h2*s3 + h3*s2 + h4*s1; - uint64_t d1 = h0*r1 + h1*r0 + h2*s4 + h3*s3 + h4*s2; - uint64_t d2 = h0*r2 + h1*r1 + h2*r0 + h3*s4 + h4*s3; - uint64_t d3 = h0*r3 + h1*r2 + h2*r1 + h3*r0 + h4*s4; - uint64_t d4 = h0*r4 + h1*r3 + h2*r2 + h3*r1 + h4*r0; - - uint64_t c; - c = d0 >> 26; d0 &= 0x3ffffffUL; d1 += c; - c = d1 >> 26; d1 &= 0x3ffffffUL; d2 += c; - c = d2 >> 26; d2 &= 0x3ffffffUL; d3 += c; - c = d3 >> 26; d3 &= 0x3ffffffUL; d4 += c; - c = d4 >> 26; d4 &= 0x3ffffffUL; d0 += c * 5UL; - c = d0 >> 26; d0 &= 0x3ffffffUL; d1 += c; - - st.h[0] = (uint32_t)d0; - st.h[1] = (uint32_t)d1; - st.h[2] = (uint32_t)d2; - st.h[3] = (uint32_t)d3; - st.h[4] = (uint32_t)d4; -} - -__device__ static inline void poly_finalize(Poly& st, uint8_t* tag) { - uint32_t h0 = st.h[0], h1 = st.h[1], h2 = st.h[2], h3 = st.h[3], h4 = st.h[4]; - uint32_t c; - c = h1 >> 26; h1 &= 0x3ffffffu; h2 += c; - c = h2 >> 26; h2 &= 0x3ffffffu; h3 += c; - c = h3 >> 26; h3 &= 0x3ffffffu; h4 += c; - c = h4 >> 26; h4 &= 0x3ffffffu; h0 += c * 5u; - c = h0 >> 26; h0 &= 0x3ffffffu; h1 += c; - - uint32_t g0 = h0 + 5u; c = g0 >> 26; g0 &= 0x3ffffffu; - uint32_t g1 = h1 + c; c = g1 >> 26; g1 &= 0x3ffffffu; - uint32_t g2 = h2 + c; c = g2 >> 26; g2 &= 0x3ffffffu; - uint32_t g3 = h3 + c; c = g3 >> 26; g3 &= 0x3ffffffu; - uint32_t g4 = h4 + c - (1u << 26); - - uint32_t mask = (g4 >> 31) - 1u; - g0 &= mask; g1 &= mask; g2 &= mask; g3 &= mask; g4 &= mask; - uint32_t nm = ~mask; - h0 = (h0 & nm) | g0; - h1 = (h1 & nm) | g1; - h2 = (h2 & nm) | g2; - h3 = (h3 & nm) | g3; - h4 = (h4 & nm) | g4; - - uint32_t f0 = h0 | (h1 << 26); - uint32_t f1 = (h1 >> 6) | (h2 << 20); - uint32_t f2 = (h2 >> 12) | (h3 << 14); - uint32_t f3 = (h3 >> 18) | (h4 << 8); - - uint64_t t = (uint64_t)f0 + (uint64_t)st.s[0]; - store32_le_t(tag + 0, (uint32_t)t); - t = (t >> 32) + (uint64_t)f1 + (uint64_t)st.s[1]; - store32_le_t(tag + 4, (uint32_t)t); - t = (t >> 32) + (uint64_t)f2 + (uint64_t)st.s[2]; - store32_le_t(tag + 8, (uint32_t)t); - t = (t >> 32) + 
(uint64_t)f3 + (uint64_t)st.s[3]; - store32_le_t(tag + 12, (uint32_t)t); -} - -// Absorb a buffer with RFC 8439 AEAD framing: full 16-byte blocks, final -// partial block zero-padded to 16 bytes (NO 0x01 marker; the AEAD -// construction handles framing). -__device__ static inline void absorb_padded(Poly& st, - const uint8_t* arena, - uint32_t off, uint32_t len) { - uint32_t pos = 0; - while (len - pos >= 16u) { - uint8_t buf[16]; - for (uint32_t i = 0; i < 16; ++i) buf[i] = arena[off + pos + i]; - poly_block(st, buf, 1u << 24); - pos += 16u; - } - uint32_t rem = len - pos; - if (rem > 0u) { - uint8_t buf[16]; - for (uint32_t i = 0; i < 16; ++i) buf[i] = 0; - for (uint32_t i = 0; i < rem; ++i) buf[i] = arena[off + pos + i]; - poly_block(st, buf, 1u << 24); - } -} - -// --------------------------------------------------------------------------- -// Kernel: one thread per ChaCha20-Poly1305 AEAD job. -// --------------------------------------------------------------------------- - -extern "C" __global__ void chacha20_poly1305_jobs( - const AeadJob* __restrict__ jobs, - const uint8_t* __restrict__ keys, - const uint8_t* __restrict__ nonces, - const uint8_t* __restrict__ inputs_arena, - uint8_t* __restrict__ outputs_arena, - uint32_t n_jobs) -{ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n_jobs) return; - - AeadJob job = jobs[gid]; - const uint8_t* key = keys + job.key_offset; - const uint8_t* nonce = nonces + job.nonce_offset; - - // Derive Poly1305 one-time key from ChaCha20 block 0. - uint8_t poly_key[32]; - { - uint8_t ks[64]; - chacha20_block(key, nonce, 0u, ks); - for (uint32_t i = 0; i < 32u; ++i) poly_key[i] = ks[i]; - } - - // Encrypt plaintext with counter starting at 1. 
- { - uint8_t ks[64]; - uint32_t counter = 1u; - uint32_t pos = 0u; - while (pos < job.pt_len) { - chacha20_block(key, nonce, counter, ks); - uint32_t take = job.pt_len - pos; - if (take > 64u) take = 64u; - for (uint32_t i = 0; i < take; ++i) { - outputs_arena[job.ct_offset + pos + i] = - inputs_arena[job.pt_offset + pos + i] ^ ks[i]; - } - pos += take; - ++counter; - } - } - - // MAC over (aad || pad || ct || pad || lens). - Poly st; - poly_init(st, poly_key); - absorb_padded(st, inputs_arena, job.aad_offset, job.aad_len); - absorb_padded(st, outputs_arena, job.ct_offset, job.pt_len); - { - uint8_t lens[16]; - uint64_t la = (uint64_t)job.aad_len; - uint64_t lc = (uint64_t)job.pt_len; - for (uint32_t i = 0; i < 8u; ++i) lens[i] = (uint8_t)(la >> (8u * i)); - for (uint32_t i = 0; i < 8u; ++i) lens[8u + i] = (uint8_t)(lc >> (8u * i)); - poly_block(st, lens, 1u << 24); - } - uint8_t tag[16]; - poly_finalize(st, tag); - for (uint32_t i = 0; i < 16u; ++i) { - outputs_arena[job.tag_offset + i] = tag[i]; - } -} diff --git a/aead/gpu/metal/aead_batch.metal b/aead/gpu/metal/aead_batch.metal deleted file mode 100644 index f533718..0000000 --- a/aead/gpu/metal/aead_batch.metal +++ /dev/null @@ -1,297 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Batched ChaCha20-Poly1305 (RFC 8439). One thread per (key, nonce, aad, -// plaintext) message; output is byte-equal to lux::crypto::aead:: -// chacha20_poly1305::encrypt() in cpp/aead.cpp. -// -// Per-message fanout: useful for the typical AEAD workload of many -// TLS-record-sized messages. For very large per-message payloads we'd want -// per-block fanout instead; that's a future kernel. -// -// Layout: -// * keys[i*32..] -- 32-byte ChaCha20 key for message i -// * nonces[i*12..] -- 12-byte nonce for message i -// * inputs_arena[..] -- packed concatenation of (aad || plaintext) per -// message; offsets/lens given by the per-message -// job table. 
-// * outputs_arena[..] -- packed concatenation of (ciphertext || tag(16)) -// per message. - -#include -using namespace metal; - -struct AeadJob { - uint32_t aad_offset; // byte offset into inputs_arena - uint32_t aad_len; - uint32_t pt_offset; // byte offset into inputs_arena - uint32_t pt_len; - uint32_t ct_offset; // byte offset into outputs_arena (ciphertext) - uint32_t tag_offset; // byte offset into outputs_arena (16 byte tag) - uint32_t key_offset; // byte offset into keys[] (always i*32) - uint32_t nonce_offset; // byte offset into nonces[] (always i*12) -}; - -// --------------------------------------------------------------------------- -// ChaCha20 -// --------------------------------------------------------------------------- - -inline uint rotl32_d(uint x, uint n) { return (x << n) | (x >> (32u - n)); } - -inline void quarter(thread uint& a, thread uint& b, thread uint& c, thread uint& d) { - a += b; d ^= a; d = rotl32_d(d, 16u); - c += d; b ^= c; b = rotl32_d(b, 12u); - a += b; d ^= a; d = rotl32_d(d, 8u); - c += d; b ^= c; b = rotl32_d(b, 7u); -} - -inline uint load32_le_d(const device uint8_t* p) { - return (uint)p[0] - | ((uint)p[1] << 8) - | ((uint)p[2] << 16) - | ((uint)p[3] << 24); -} - -inline void store32_le_t(thread uint8_t* p, uint v) { - p[0] = (uint8_t)(v ); - p[1] = (uint8_t)(v >> 8); - p[2] = (uint8_t)(v >> 16); - p[3] = (uint8_t)(v >> 24); -} - -inline void chacha20_block(const device uint8_t* key, - const device uint8_t* nonce, - uint counter, - thread uint8_t* out64) { - uint s[16]; - s[ 0] = 0x61707865u; s[ 1] = 0x3320646eu; - s[ 2] = 0x79622d32u; s[ 3] = 0x6b206574u; - s[ 4] = load32_le_d(key + 0); s[ 5] = load32_le_d(key + 4); - s[ 6] = load32_le_d(key + 8); s[ 7] = load32_le_d(key + 12); - s[ 8] = load32_le_d(key + 16); s[ 9] = load32_le_d(key + 20); - s[10] = load32_le_d(key + 24); s[11] = load32_le_d(key + 28); - s[12] = counter; - s[13] = load32_le_d(nonce + 0); - s[14] = load32_le_d(nonce + 4); - s[15] = load32_le_d(nonce + 
8); - - uint v[16]; - for (uint i = 0; i < 16; ++i) v[i] = s[i]; - for (uint i = 0; i < 10; ++i) { - quarter(v[0], v[4], v[ 8], v[12]); - quarter(v[1], v[5], v[ 9], v[13]); - quarter(v[2], v[6], v[10], v[14]); - quarter(v[3], v[7], v[11], v[15]); - quarter(v[0], v[5], v[10], v[15]); - quarter(v[1], v[6], v[11], v[12]); - quarter(v[2], v[7], v[ 8], v[13]); - quarter(v[3], v[4], v[ 9], v[14]); - } - for (uint i = 0; i < 16; ++i) { - store32_le_t(out64 + i * 4, v[i] + s[i]); - } -} - -// --------------------------------------------------------------------------- -// Poly1305 (radix 2^26) -// --------------------------------------------------------------------------- -// -// State held entirely in thread-private storage; we operate on full 16-byte -// blocks (caller is responsible for zero-padding partial tails as required -// by RFC 8439 §2.8 for the AEAD construction). - -struct Poly { - uint r[5]; - uint s[4]; - uint h[5]; -}; - -inline uint load32_le_t(thread const uint8_t* p) { - return (uint)p[0] - | ((uint)p[1] << 8) - | ((uint)p[2] << 16) - | ((uint)p[3] << 24); -} - -inline void poly_init(thread Poly& st, thread const uint8_t* key) { - uint c0 = load32_le_t(key + 0) & 0x0fffffffu; - uint c1 = load32_le_t(key + 4) & 0x0ffffffcu; - uint c2 = load32_le_t(key + 8) & 0x0ffffffcu; - uint c3 = load32_le_t(key + 12) & 0x0ffffffcu; - st.r[0] = c0 & 0x3ffffffu; - st.r[1] = ((c0 >> 26) | (c1 << 6)) & 0x3ffffffu; - st.r[2] = ((c1 >> 20) | (c2 << 12)) & 0x3ffffffu; - st.r[3] = ((c2 >> 14) | (c3 << 18)) & 0x3ffffffu; - st.r[4] = (c3 >> 8) & 0x3ffffffu; - - st.s[0] = load32_le_t(key + 16); - st.s[1] = load32_le_t(key + 20); - st.s[2] = load32_le_t(key + 24); - st.s[3] = load32_le_t(key + 28); - for (uint i = 0; i < 5; ++i) st.h[i] = 0u; -} - -inline void poly_block(thread Poly& st, thread const uint8_t* m, uint hibit) { - uint t0 = load32_le_t(m + 0); - uint t1 = load32_le_t(m + 4); - uint t2 = load32_le_t(m + 8); - uint t3 = load32_le_t(m + 12); - - ulong h0 = 
(ulong)st.h[0] + ( t0 & 0x3ffffffu); - ulong h1 = (ulong)st.h[1] + (((t0 >> 26) | (t1 << 6)) & 0x3ffffffu); - ulong h2 = (ulong)st.h[2] + (((t1 >> 20) | (t2 << 12)) & 0x3ffffffu); - ulong h3 = (ulong)st.h[3] + (((t2 >> 14) | (t3 << 18)) & 0x3ffffffu); - ulong h4 = (ulong)st.h[4] + ( (t3 >> 8) | (ulong)hibit); - - ulong r0 = st.r[0]; ulong r1 = st.r[1]; - ulong r2 = st.r[2]; ulong r3 = st.r[3]; - ulong r4 = st.r[4]; - ulong s1 = r1 * 5UL; ulong s2 = r2 * 5UL; - ulong s3 = r3 * 5UL; ulong s4 = r4 * 5UL; - - ulong d0 = h0*r0 + h1*s4 + h2*s3 + h3*s2 + h4*s1; - ulong d1 = h0*r1 + h1*r0 + h2*s4 + h3*s3 + h4*s2; - ulong d2 = h0*r2 + h1*r1 + h2*r0 + h3*s4 + h4*s3; - ulong d3 = h0*r3 + h1*r2 + h2*r1 + h3*r0 + h4*s4; - ulong d4 = h0*r4 + h1*r3 + h2*r2 + h3*r1 + h4*r0; - - ulong c; - c = d0 >> 26; d0 &= 0x3ffffffUL; d1 += c; - c = d1 >> 26; d1 &= 0x3ffffffUL; d2 += c; - c = d2 >> 26; d2 &= 0x3ffffffUL; d3 += c; - c = d3 >> 26; d3 &= 0x3ffffffUL; d4 += c; - c = d4 >> 26; d4 &= 0x3ffffffUL; d0 += c * 5UL; - c = d0 >> 26; d0 &= 0x3ffffffUL; d1 += c; - - st.h[0] = (uint)d0; - st.h[1] = (uint)d1; - st.h[2] = (uint)d2; - st.h[3] = (uint)d3; - st.h[4] = (uint)d4; -} - -inline void poly_finalize(thread Poly& st, thread uint8_t* tag) { - uint h0 = st.h[0], h1 = st.h[1], h2 = st.h[2], h3 = st.h[3], h4 = st.h[4]; - uint c; - c = h1 >> 26; h1 &= 0x3ffffffu; h2 += c; - c = h2 >> 26; h2 &= 0x3ffffffu; h3 += c; - c = h3 >> 26; h3 &= 0x3ffffffu; h4 += c; - c = h4 >> 26; h4 &= 0x3ffffffu; h0 += c * 5u; - c = h0 >> 26; h0 &= 0x3ffffffu; h1 += c; - - uint g0 = h0 + 5u; c = g0 >> 26; g0 &= 0x3ffffffu; - uint g1 = h1 + c; c = g1 >> 26; g1 &= 0x3ffffffu; - uint g2 = h2 + c; c = g2 >> 26; g2 &= 0x3ffffffu; - uint g3 = h3 + c; c = g3 >> 26; g3 &= 0x3ffffffu; - uint g4 = h4 + c - (1u << 26); - - uint mask = (g4 >> 31) - 1u; - g0 &= mask; g1 &= mask; g2 &= mask; g3 &= mask; g4 &= mask; - uint nm = ~mask; - h0 = (h0 & nm) | g0; - h1 = (h1 & nm) | g1; - h2 = (h2 & nm) | g2; - h3 = (h3 & nm) | g3; - h4 = 
(h4 & nm) | g4; - - uint f0 = h0 | (h1 << 26); - uint f1 = (h1 >> 6) | (h2 << 20); - uint f2 = (h2 >> 12) | (h3 << 14); - uint f3 = (h3 >> 18) | (h4 << 8); - - ulong t = (ulong)f0 + (ulong)st.s[0]; - store32_le_t(tag + 0, (uint)t); - t = (t >> 32) + (ulong)f1 + (ulong)st.s[1]; - store32_le_t(tag + 4, (uint)t); - t = (t >> 32) + (ulong)f2 + (ulong)st.s[2]; - store32_le_t(tag + 8, (uint)t); - t = (t >> 32) + (ulong)f3 + (ulong)st.s[3]; - store32_le_t(tag + 12, (uint)t); -} - -// Absorb a buffer with RFC 8439 AEAD framing: full 16-byte blocks, final -// partial block zero-padded to 16 bytes (NO 0x01 marker). -inline void absorb_padded(thread Poly& st, - const device uint8_t* arena, - uint off, uint len) { - uint pos = 0; - while (len - pos >= 16) { - // Copy 16 bytes from device memory to thread memory. - thread uint8_t buf[16]; - for (uint i = 0; i < 16; ++i) buf[i] = arena[off + pos + i]; - poly_block(st, buf, 1u << 24); - pos += 16; - } - uint rem = len - pos; - if (rem > 0) { - thread uint8_t buf[16]; - for (uint i = 0; i < 16; ++i) buf[i] = 0; - for (uint i = 0; i < rem; ++i) buf[i] = arena[off + pos + i]; - poly_block(st, buf, 1u << 24); - } -} - -// --------------------------------------------------------------------------- -// Kernel: one thread per AEAD job. 
-// --------------------------------------------------------------------------- - -kernel void aead_jobs( - device const AeadJob* jobs [[buffer(0)]], - device const uint8_t* keys [[buffer(1)]], - device const uint8_t* nonces [[buffer(2)]], - device const uint8_t* inputs_arena [[buffer(3)]], - device uint8_t* outputs_arena [[buffer(4)]], - device const uint& n_jobs [[buffer(5)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid >= n_jobs) return; - - AeadJob job = jobs[gid]; - const device uint8_t* key = keys + job.key_offset; - const device uint8_t* nonce = nonces + job.nonce_offset; - - // ---- Derive Poly1305 one-time key from ChaCha20 block 0 --------------- - thread uint8_t poly_key[32]; - { - thread uint8_t ks[64]; - chacha20_block(key, nonce, 0u, ks); - for (uint i = 0; i < 32; ++i) poly_key[i] = ks[i]; - } - - // ---- Encrypt plaintext with counter starting at 1 --------------------- - { - thread uint8_t ks[64]; - uint counter = 1u; - uint pos = 0; - while (pos < job.pt_len) { - chacha20_block(key, nonce, counter, ks); - uint take = job.pt_len - pos; - if (take > 64) take = 64; - for (uint i = 0; i < take; ++i) { - outputs_arena[job.ct_offset + pos + i] = - inputs_arena[job.pt_offset + pos + i] ^ ks[i]; - } - pos += take; - ++counter; - } - } - - // ---- Compute MAC over (aad || pad || ct || pad || lens) --------------- - Poly st; - poly_init(st, poly_key); - absorb_padded(st, inputs_arena, job.aad_offset, job.aad_len); - absorb_padded(st, outputs_arena, job.ct_offset, job.pt_len); - { - thread uint8_t lens[16]; - ulong la = (ulong)job.aad_len; - ulong lc = (ulong)job.pt_len; - for (uint i = 0; i < 8; ++i) lens[i] = (uint8_t)(la >> (8u * i)); - for (uint i = 0; i < 8; ++i) lens[8 + i] = (uint8_t)(lc >> (8u * i)); - poly_block(st, lens, 1u << 24); - } - thread uint8_t tag[16]; - poly_finalize(st, tag); - for (uint i = 0; i < 16; ++i) { - outputs_arena[job.tag_offset + i] = tag[i]; - } -} diff --git a/aead/gpu/metal/aead_batch_driver.mm 
b/aead/gpu/metal/aead_batch_driver.mm deleted file mode 100644 index 3de491d..0000000 --- a/aead/gpu/metal/aead_batch_driver.mm +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched ChaCha20-Poly1305 (RFC 8439). macOS / iOS only. -// Loads aead_batch.metallib, dispatches `aead_jobs` with one thread per -// message. Output is byte-equal to lux::crypto::aead::chacha20_poly1305:: -// encrypt() in cpp/aead.cpp. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include -#include - -namespace { - -// Mirror the Metal struct AeadJob in metal/aead_batch.metal. -struct AeadJobGPU { - uint32_t aad_offset; - uint32_t aad_len; - uint32_t pt_offset; - uint32_t pt_len; - uint32_t ct_offset; - uint32_t tag_offset; - uint32_t key_offset; - uint32_t nonce_offset; -}; - -} // namespace - -// Encrypt n messages in a single GPU dispatch. -// -// Inputs: -// keys -- n * 32 bytes (ChaCha20 keys, packed) -// nonces -- n * 12 bytes (nonces, packed) -// inputs_arena -- packed (aad || plaintext) per message; offsets/lens in jobs -// inputs_arena_len -- total bytes in inputs_arena -// jobs -- n AeadJobGPU records -// -// Outputs: -// outputs_arena -- caller-allocated; receives (ciphertext || tag) per message -// at the offsets specified in jobs -// outputs_arena_len -- total capacity of outputs_arena -// -// Returns 0 on success, negative on failure. 
-extern "C" int aead_chacha20poly1305_batch_metal( - const uint8_t* keys, - const uint8_t* nonces, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const AeadJobGPU* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len, - const char* metallib_path) { - - if (n == 0) return 0; - if (!keys || !nonces || !inputs_arena || !jobs || !outputs_arena || - !metallib_path) { - return -1; - } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"aead_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id jobs_buf = [device newBufferWithBytes:jobs - length:n * sizeof(AeadJobGPU) - options:MTLResourceStorageModeShared]; - id keys_buf = [device newBufferWithBytes:keys - length:n * 32 - options:MTLResourceStorageModeShared]; - id nonces_buf = [device newBufferWithBytes:nonces - length:n * 12 - options:MTLResourceStorageModeShared]; - id inputs_buf = [device newBufferWithBytes:inputs_arena - length:inputs_arena_len - options:MTLResourceStorageModeShared]; - id outputs_buf = [device newBufferWithLength:outputs_arena_len - options:MTLResourceStorageModeShared]; - // Zero outputs_arena so any unused regions stay deterministic. 
- std::memset([outputs_buf contents], 0, outputs_arena_len); - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:keys_buf offset:0 atIndex:1]; - [enc setBuffer:nonces_buf offset:0 atIndex:2]; - [enc setBuffer:inputs_buf offset:0 atIndex:3]; - [enc setBuffer:outputs_buf offset:0 atIndex:4]; - [enc setBuffer:n_buf offset:0 atIndex:5]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs_arena, [outputs_buf contents], outputs_arena_len); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/aead/gpu/metal/aes_gcm.metal b/aead/gpu/metal/aes_gcm.metal deleted file mode 100644 index 889b06c..0000000 --- a/aead/gpu/metal/aes_gcm.metal +++ /dev/null @@ -1,452 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Batched AES-256-GCM (NIST SP 800-38D, 96-bit IV). One thread per -// (key, iv, aad, plaintext) message; output is byte-equal to -// lux::crypto::aead::aes_256_gcm::encrypt() in cpp/aead.cpp. -// -// Per-message fanout: useful for the typical AEAD workload of many -// TLS-record-sized messages. Per-block fanout is a future kernel. -// -// Layout matches gpu/metal/aead_batch.metal (struct AeadJob; same -// keys/nonces/inputs_arena/outputs_arena packing). -// -// Constant-time S-box: Boyar-Peralta (J. Cryptol. 2010), ported byte-for-byte -// from the CPU body. No table lookups, no data-dependent branches. 
-// Constant-time GHASH: 128 iterations regardless of input bits. - -#include -using namespace metal; - -struct AeadJob { - uint32_t aad_offset; - uint32_t aad_len; - uint32_t pt_offset; - uint32_t pt_len; - uint32_t ct_offset; - uint32_t tag_offset; - uint32_t key_offset; // i*32 (AES-256 key) - uint32_t nonce_offset; // i*12 (96-bit IV) -}; - -// AES-256: 14 rounds, expanded round-key schedule = 240 bytes (15 round keys). -constant uint AES256_ROUNDS = 14u; -constant uint AES256_KS_BYTES = 240u; - -// --------------------------------------------------------------------------- -// AES S-box -- Boyar-Peralta circuit. Byte-for-byte port of -// lux::crypto::aead::aes::aes_sbox in cpp/aead.cpp. -// --------------------------------------------------------------------------- -inline uint8_t aes_sbox(uint8_t x) { - uint8_t U0 = (x >> 7) & 1u; - uint8_t U1 = (x >> 6) & 1u; - uint8_t U2 = (x >> 5) & 1u; - uint8_t U3 = (x >> 4) & 1u; - uint8_t U4 = (x >> 3) & 1u; - uint8_t U5 = (x >> 2) & 1u; - uint8_t U6 = (x >> 1) & 1u; - uint8_t U7 = x & 1u; - - uint8_t T1 = U0 ^ U3; - uint8_t T2 = U0 ^ U5; - uint8_t T3 = U0 ^ U6; - uint8_t T4 = U3 ^ U5; - uint8_t T5 = U4 ^ U6; - uint8_t T6 = T1 ^ T5; - uint8_t T7 = U1 ^ U2; - uint8_t T8 = U7 ^ T6; - uint8_t T9 = U7 ^ T7; - uint8_t T10 = T6 ^ T7; - uint8_t T11 = U1 ^ U5; - uint8_t T12 = U2 ^ U5; - uint8_t T13 = T3 ^ T4; - uint8_t T14 = T6 ^ T11; - uint8_t T15 = T5 ^ T11; - uint8_t T16 = T5 ^ T12; - uint8_t T17 = T9 ^ T16; - uint8_t T18 = U3 ^ U7; - uint8_t T19 = T7 ^ T18; - uint8_t T20 = T1 ^ T19; - uint8_t T21 = U6 ^ U7; - uint8_t T22 = T7 ^ T21; - uint8_t T23 = T2 ^ T22; - uint8_t T24 = T2 ^ T10; - uint8_t T25 = T20 ^ T17; - uint8_t T26 = T3 ^ T16; - uint8_t T27 = T1 ^ T12; - - uint8_t M1 = T13 & T6; - uint8_t M2 = T23 & T8; - uint8_t M3 = T14 ^ M1; - uint8_t M4 = T19 & U7; - uint8_t M5 = M4 ^ M1; - uint8_t M6 = T3 & T16; - uint8_t M7 = T22 & T9; - uint8_t M8 = T26 ^ M6; - uint8_t M9 = T20 & T17; - uint8_t M10 = M9 ^ M6; - uint8_t M11 
= T1 & T15; - uint8_t M12 = T4 & T27; - uint8_t M13 = M12 ^ M11; - uint8_t M14 = T2 & T10; - uint8_t M15 = M14 ^ M11; - uint8_t M16 = M3 ^ M2; - uint8_t M17 = M5 ^ T24; - uint8_t M18 = M8 ^ M7; - uint8_t M19 = M10 ^ M15; - uint8_t M20 = M16 ^ M13; - uint8_t M21 = M17 ^ M15; - uint8_t M22 = M18 ^ M13; - uint8_t M23 = M19 ^ T25; - uint8_t M24 = M22 ^ M23; - uint8_t M25 = M22 & M20; - uint8_t M26 = M21 ^ M25; - uint8_t M27 = M20 ^ M21; - uint8_t M28 = M23 ^ M25; - uint8_t M29 = M28 & M27; - uint8_t M30 = M26 & M24; - uint8_t M31 = M20 & M23; - uint8_t M32 = M27 & M31; - uint8_t M33 = M27 ^ M25; - uint8_t M34 = M21 & M22; - uint8_t M35 = M24 & M34; - uint8_t M36 = M24 ^ M25; - uint8_t M37 = M21 ^ M29; - uint8_t M38 = M32 ^ M33; - uint8_t M39 = M23 ^ M30; - uint8_t M40 = M35 ^ M36; - uint8_t M41 = M38 ^ M40; - uint8_t M42 = M37 ^ M39; - uint8_t M43 = M37 ^ M38; - uint8_t M44 = M39 ^ M40; - uint8_t M45 = M42 ^ M41; - uint8_t M46 = M44 & T6; - uint8_t M47 = M40 & T8; - uint8_t M48 = M39 & U7; - uint8_t M49 = M43 & T16; - uint8_t M50 = M38 & T9; - uint8_t M51 = M37 & T17; - uint8_t M52 = M42 & T15; - uint8_t M53 = M45 & T27; - uint8_t M54 = M41 & T10; - uint8_t M55 = M44 & T13; - uint8_t M56 = M40 & T23; - uint8_t M57 = M39 & T19; - uint8_t M58 = M43 & T3; - uint8_t M59 = M38 & T22; - uint8_t M60 = M37 & T20; - uint8_t M61 = M42 & T1; - uint8_t M62 = M45 & T4; - uint8_t M63 = M41 & T2; - - uint8_t L0 = M61 ^ M62; - uint8_t L1 = M50 ^ M56; - uint8_t L2 = M46 ^ M48; - uint8_t L3 = M47 ^ M55; - uint8_t L4 = M54 ^ M58; - uint8_t L5 = M49 ^ M61; - uint8_t L6 = M62 ^ L5; - uint8_t L7 = M46 ^ L3; - uint8_t L8 = M51 ^ M59; - uint8_t L9 = M52 ^ M53; - uint8_t L10 = M53 ^ L4; - uint8_t L11 = M60 ^ L2; - uint8_t L12 = M48 ^ M51; - uint8_t L13 = M50 ^ L0; - uint8_t L14 = M52 ^ M61; - uint8_t L15 = M55 ^ L1; - uint8_t L16 = M56 ^ L0; - uint8_t L17 = M57 ^ L1; - uint8_t L18 = M58 ^ L8; - uint8_t L19 = M63 ^ L4; - uint8_t L20 = L0 ^ L1; - uint8_t L21 = L1 ^ L7; - uint8_t L22 = L3 ^ L12; 
- uint8_t L23 = L18 ^ L2; - uint8_t L24 = L15 ^ L9; - uint8_t L25 = L6 ^ L10; - uint8_t L26 = L7 ^ L9; - uint8_t L27 = L8 ^ L10; - uint8_t L28 = L11 ^ L14; - uint8_t L29 = L11 ^ L17; - - uint8_t S0 = L6 ^ L24; - uint8_t S1 = L16 ^ L26; S1 ^= 1u; - uint8_t S2 = L19 ^ L28; S2 ^= 1u; - uint8_t S3 = L6 ^ L21; - uint8_t S4 = L20 ^ L22; - uint8_t S5 = L25 ^ L29; - uint8_t S6 = L13 ^ L27; S6 ^= 1u; - uint8_t S7 = L6 ^ L23; S7 ^= 1u; - - return (uint8_t)( - ((S0 & 1u) << 7) | - ((S1 & 1u) << 6) | - ((S2 & 1u) << 5) | - ((S3 & 1u) << 4) | - ((S4 & 1u) << 3) | - ((S5 & 1u) << 2) | - ((S6 & 1u) << 1) | - ( S7 & 1u)); -} - -// FIPS 197 round-constants for AES-256 (only Rcon[0..6] needed). -constant uint8_t RCON[7] = { 0x01u, 0x02u, 0x04u, 0x08u, 0x10u, 0x20u, 0x40u }; - -inline uint8_t xtime(uint8_t x) { - return (uint8_t)((x << 1) ^ (((x >> 7) & 1u) * 0x1bu)); -} - -// --------------------------------------------------------------------------- -// AES-256 key expansion (FIPS 197 §5.2). Output: 240-byte round-key schedule. -// 60 32-bit words; each word is 4 bytes. 
-// --------------------------------------------------------------------------- -inline void aes256_expand_key(const device uint8_t* key, thread uint8_t* rk) { - for (uint i = 0u; i < 32u; ++i) rk[i] = key[i]; - - for (uint i = 8u; i < 60u; ++i) { - uint8_t t0 = rk[(i - 1u) * 4u + 0u]; - uint8_t t1 = rk[(i - 1u) * 4u + 1u]; - uint8_t t2 = rk[(i - 1u) * 4u + 2u]; - uint8_t t3 = rk[(i - 1u) * 4u + 3u]; - - if ((i & 7u) == 0u) { - uint8_t r0 = t1, r1 = t2, r2 = t3, r3 = t0; - t0 = (uint8_t)(aes_sbox(r0) ^ RCON[(i / 8u) - 1u]); - t1 = aes_sbox(r1); - t2 = aes_sbox(r2); - t3 = aes_sbox(r3); - } else if ((i & 7u) == 4u) { - t0 = aes_sbox(t0); - t1 = aes_sbox(t1); - t2 = aes_sbox(t2); - t3 = aes_sbox(t3); - } - - rk[i * 4u + 0u] = (uint8_t)(rk[(i - 8u) * 4u + 0u] ^ t0); - rk[i * 4u + 1u] = (uint8_t)(rk[(i - 8u) * 4u + 1u] ^ t1); - rk[i * 4u + 2u] = (uint8_t)(rk[(i - 8u) * 4u + 2u] ^ t2); - rk[i * 4u + 3u] = (uint8_t)(rk[(i - 8u) * 4u + 3u] ^ t3); - } -} - -// --------------------------------------------------------------------------- -// AES-256 encrypt one 16-byte block (FIPS 197 §5.1). State layout is column- -// major: s[c*4 + r]. Byte-for-byte equivalent to aes::encrypt_block in -// cpp/aead.cpp. -// --------------------------------------------------------------------------- -inline void aes256_encrypt_block(thread const uint8_t* rk, - thread const uint8_t* in, - thread uint8_t* out) { - uint8_t s[16]; - for (uint i = 0u; i < 16u; ++i) s[i] = (uint8_t)(in[i] ^ rk[i]); - - for (uint round = 1u; round < AES256_ROUNDS; ++round) { - // SubBytes. - for (uint i = 0u; i < 16u; ++i) s[i] = aes_sbox(s[i]); - - // ShiftRows (column-major: row r at indices r, r+4, r+8, r+12). - uint8_t t; - // Row 1: rotate left by 1. - t = s[1]; s[1] = s[5]; s[5] = s[9]; s[9] = s[13]; s[13] = t; - // Row 2: rotate left by 2. - t = s[2]; s[2] = s[10]; s[10] = t; - t = s[6]; s[6] = s[14]; s[14] = t; - // Row 3: rotate left by 3 == rotate right by 1. 
- t = s[15]; s[15] = s[11]; s[11] = s[7]; s[7] = s[3]; s[3] = t; - - // MixColumns. - for (uint c = 0u; c < 4u; ++c) { - uint8_t a0 = s[c*4u + 0u]; - uint8_t a1 = s[c*4u + 1u]; - uint8_t a2 = s[c*4u + 2u]; - uint8_t a3 = s[c*4u + 3u]; - uint8_t x = a0 ^ a1 ^ a2 ^ a3; - uint8_t y0 = a0; - s[c*4u + 0u] = (uint8_t)(a0 ^ x ^ xtime(a0 ^ a1)); - s[c*4u + 1u] = (uint8_t)(a1 ^ x ^ xtime(a1 ^ a2)); - s[c*4u + 2u] = (uint8_t)(a2 ^ x ^ xtime(a2 ^ a3)); - s[c*4u + 3u] = (uint8_t)(a3 ^ x ^ xtime(a3 ^ y0)); - } - - // AddRoundKey. - for (uint i = 0u; i < 16u; ++i) s[i] ^= rk[round * 16u + i]; - } - - // Final round: SubBytes + ShiftRows + AddRoundKey (no MixColumns). - for (uint i = 0u; i < 16u; ++i) s[i] = aes_sbox(s[i]); - { - uint8_t t; - t = s[1]; s[1] = s[5]; s[5] = s[9]; s[9] = s[13]; s[13] = t; - t = s[2]; s[2] = s[10]; s[10] = t; - t = s[6]; s[6] = s[14]; s[14] = t; - t = s[15]; s[15] = s[11]; s[11] = s[7]; s[7] = s[3]; s[3] = t; - } - for (uint i = 0u; i < 16u; ++i) { - out[i] = (uint8_t)(s[i] ^ rk[AES256_ROUNDS * 16u + i]); - } -} - -// --------------------------------------------------------------------------- -// GHASH multiplication in GF(2^128) with reduction polynomial -// x^128 + x^7 + x^2 + x + 1 (NIST SP 800-38D §6.3). Bits are GCM-style -// (bit 0 = MSB of byte 0). Constant-time: 128 iterations always. 
-// --------------------------------------------------------------------------- -inline void ghash_mul(thread uint8_t* z, thread const uint8_t* h) { - uint8_t v[16]; - for (uint i = 0u; i < 16u; ++i) v[i] = h[i]; - uint8_t r[16]; - for (uint i = 0u; i < 16u; ++i) r[i] = 0u; - - for (uint i = 0u; i < 128u; ++i) { - uint8_t zbit = (uint8_t)((z[i >> 3] >> (7u - (i & 7u))) & 1u); - uint8_t mask = (uint8_t)(0u - (uint)zbit); // 0xff if 1, else 0 - for (uint j = 0u; j < 16u; ++j) r[j] ^= (uint8_t)(v[j] & mask); - - uint8_t lsb = (uint8_t)(v[15] & 1u); - for (uint j = 15u; j > 0u; --j) { - v[j] = (uint8_t)((v[j] >> 1) | ((v[j-1u] & 1u) << 7)); - } - v[0] >>= 1; - uint8_t rmask = (uint8_t)(0u - (uint)lsb); - v[0] ^= (uint8_t)(0xe1u & rmask); - } - - for (uint i = 0u; i < 16u; ++i) z[i] = r[i]; -} - -// inc32 on the rightmost 32 bits (big-endian) of a 16-byte counter. -inline void inc32(thread uint8_t* ctr) { - uint c = ((uint)ctr[12] << 24) | ((uint)ctr[13] << 16) - | ((uint)ctr[14] << 8) | (uint)ctr[15]; - c += 1u; - ctr[12] = (uint8_t)(c >> 24); - ctr[13] = (uint8_t)(c >> 16); - ctr[14] = (uint8_t)(c >> 8); - ctr[15] = (uint8_t)(c); -} - -// Absorb a buffer (from device memory) into GHASH, padding the final partial -// block with zeros. -inline void ghash_update(thread uint8_t* y, thread const uint8_t* h, - const device uint8_t* arena, - uint off, uint len) { - uint pos = 0u; - while (len - pos >= 16u) { - for (uint i = 0u; i < 16u; ++i) y[i] ^= arena[off + pos + i]; - ghash_mul(y, h); - pos += 16u; - } - uint rem = len - pos; - if (rem > 0u) { - uint8_t buf[16]; - for (uint i = 0u; i < 16u; ++i) buf[i] = 0u; - for (uint i = 0u; i < rem; ++i) buf[i] = arena[off + pos + i]; - for (uint i = 0u; i < 16u; ++i) y[i] ^= buf[i]; - ghash_mul(y, h); - } -} - -// Same, but absorbs from a writable arena (used after we have written -// ciphertext to outputs_arena). 
-inline void ghash_update_dev_w(thread uint8_t* y, thread const uint8_t* h, - device const uint8_t* arena, - uint off, uint len) { - uint pos = 0u; - while (len - pos >= 16u) { - for (uint i = 0u; i < 16u; ++i) y[i] ^= arena[off + pos + i]; - ghash_mul(y, h); - pos += 16u; - } - uint rem = len - pos; - if (rem > 0u) { - uint8_t buf[16]; - for (uint i = 0u; i < 16u; ++i) buf[i] = 0u; - for (uint i = 0u; i < rem; ++i) buf[i] = arena[off + pos + i]; - for (uint i = 0u; i < 16u; ++i) y[i] ^= buf[i]; - ghash_mul(y, h); - } -} - -// --------------------------------------------------------------------------- -// Kernel: one thread per AEAD job. AES-256-GCM seal. -// --------------------------------------------------------------------------- - -kernel void aes_gcm_jobs( - device const AeadJob* jobs [[buffer(0)]], - device const uint8_t* keys [[buffer(1)]], - device const uint8_t* nonces [[buffer(2)]], - device const uint8_t* inputs_arena [[buffer(3)]], - device uint8_t* outputs_arena [[buffer(4)]], - device const uint& n_jobs [[buffer(5)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid >= n_jobs) return; - - AeadJob job = jobs[gid]; - const device uint8_t* key = keys + job.key_offset; - const device uint8_t* iv = nonces + job.nonce_offset; - - // ---- Key schedule ------------------------------------------------------ - uint8_t rk[AES256_KS_BYTES]; - aes256_expand_key(key, rk); - - // ---- H = AES_K(0^128) -------------------------------------------------- - uint8_t H[16]; - { - uint8_t zero[16]; - for (uint i = 0u; i < 16u; ++i) zero[i] = 0u; - aes256_encrypt_block(rk, zero, H); - } - - // ---- J0 = IV || 0x00000001 (96-bit IV path) ---------------------------- - uint8_t J0[16]; - for (uint i = 0u; i < 12u; ++i) J0[i] = iv[i]; - J0[12] = 0u; J0[13] = 0u; J0[14] = 0u; J0[15] = 1u; - - // ---- Encrypt plaintext under counter starting at inc32(J0) ------------- - { - uint8_t ctr[16]; - for (uint i = 0u; i < 16u; ++i) ctr[i] = J0[i]; - inc32(ctr); - - uint pos = 0u; - 
while (pos < job.pt_len) { - uint8_t ks[16]; - aes256_encrypt_block(rk, ctr, ks); - uint take = job.pt_len - pos; - if (take > 16u) take = 16u; - for (uint i = 0u; i < take; ++i) { - outputs_arena[job.ct_offset + pos + i] = - (uint8_t)(inputs_arena[job.pt_offset + pos + i] ^ ks[i]); - } - inc32(ctr); - pos += take; - } - } - - // ---- GHASH over (aad || pad || ct || pad || lens_in_bits) -------------- - uint8_t Y[16]; - for (uint i = 0u; i < 16u; ++i) Y[i] = 0u; - ghash_update(Y, H, inputs_arena, job.aad_offset, job.aad_len); - ghash_update_dev_w(Y, H, outputs_arena, job.ct_offset, job.pt_len); - { - uint8_t lens[16]; - ulong la = (ulong)job.aad_len * 8ul; - ulong lc = (ulong)job.pt_len * 8ul; - // BE encoding. - for (uint i = 0u; i < 8u; ++i) lens[i] = (uint8_t)(la >> (8u * (7u - i))); - for (uint i = 0u; i < 8u; ++i) lens[8u + i] = (uint8_t)(lc >> (8u * (7u - i))); - for (uint i = 0u; i < 16u; ++i) Y[i] ^= lens[i]; - ghash_mul(Y, H); - } - - // ---- Tag = GHASH XOR AES_K(J0) ----------------------------------------- - { - uint8_t s[16]; - aes256_encrypt_block(rk, J0, s); - for (uint i = 0u; i < 16u; ++i) { - outputs_arena[job.tag_offset + i] = (uint8_t)(Y[i] ^ s[i]); - } - } -} diff --git a/aead/gpu/metal/aes_gcm_driver.mm b/aead/gpu/metal/aes_gcm_driver.mm deleted file mode 100644 index 304fdb0..0000000 --- a/aead/gpu/metal/aes_gcm_driver.mm +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched AES-256-GCM (NIST SP 800-38D, 96-bit IV). -// macOS / iOS only. Loads aes_gcm.metallib, dispatches `aes_gcm_jobs` with -// one thread per message. Output is byte-equal to lux::crypto::aead:: -// aes_256_gcm::encrypt() in cpp/aead.cpp. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include - -namespace { - -// Mirror the Metal struct AeadJob in metal/aes_gcm.metal. 
-struct AeadJobGPU { - uint32_t aad_offset; - uint32_t aad_len; - uint32_t pt_offset; - uint32_t pt_len; - uint32_t ct_offset; - uint32_t tag_offset; - uint32_t key_offset; - uint32_t nonce_offset; -}; - -} // namespace - -// Encrypt n messages in a single GPU dispatch. -// -// Inputs: -// keys -- n * 32 bytes (AES-256 keys, packed) -// ivs -- n * 12 bytes (96-bit IVs, packed) -// inputs_arena -- packed (aad || plaintext) per message; offsets/lens in jobs -// inputs_arena_len -- total bytes in inputs_arena -// jobs -- n AeadJobGPU records -// -// Outputs: -// outputs_arena -- caller-allocated; receives (ciphertext || tag) per -// message at the offsets specified in jobs. -// outputs_arena_len-- total capacity of outputs_arena. -// -// Returns 0 on success, negative on failure. -extern "C" int aead_aes_256_gcm_batch_metal( - const uint8_t* keys, - const uint8_t* ivs, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const AeadJobGPU* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len, - const char* metallib_path) { - - if (n == 0) return 0; - if (!keys || !ivs || !jobs || !outputs_arena || !metallib_path) return -1; - // inputs_arena may legally be NULL only if every message has aad_len=0 - // and pt_len=0 (i.e. inputs_arena_len==0). Metal requires a non-NULL - // bytes pointer; substitute a single-byte fallback. - static const uint8_t kEmpty = 0; - const uint8_t* in_ptr = inputs_arena ? inputs_arena : &kEmpty; - size_t in_len = inputs_arena_len > 0 ? 
inputs_arena_len : 1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"aes_gcm_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id jobs_buf = [device newBufferWithBytes:jobs - length:n * sizeof(AeadJobGPU) - options:MTLResourceStorageModeShared]; - id keys_buf = [device newBufferWithBytes:keys - length:n * 32 - options:MTLResourceStorageModeShared]; - id ivs_buf = [device newBufferWithBytes:ivs - length:n * 12 - options:MTLResourceStorageModeShared]; - id inputs_buf = [device newBufferWithBytes:in_ptr - length:in_len - options:MTLResourceStorageModeShared]; - id outputs_buf = [device newBufferWithLength:outputs_arena_len - options:MTLResourceStorageModeShared]; - // Zero outputs_arena so any unused regions stay deterministic. - std::memset([outputs_buf contents], 0, outputs_arena_len); - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:keys_buf offset:0 atIndex:1]; - [enc setBuffer:ivs_buf offset:0 atIndex:2]; - [enc setBuffer:inputs_buf offset:0 atIndex:3]; - [enc setBuffer:outputs_buf offset:0 atIndex:4]; - [enc setBuffer:n_buf offset:0 atIndex:5]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? 
tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs_arena, [outputs_buf contents], outputs_arena_len); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/aead/gpu/wgsl/aead_driver_wgpu.cpp b/aead/gpu/wgsl/aead_driver_wgpu.cpp deleted file mode 100644 index 681dfb9..0000000 --- a/aead/gpu/wgsl/aead_driver_wgpu.cpp +++ /dev/null @@ -1,348 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WebGPU/WGSL host driver for batched AEAD ciphers (ChaCha20-Poly1305 + -// AES-256-GCM). Builds against wgpu-native (or Dawn). Output is byte-equal -// to the CPU body in cpp/aead.cpp. - -#include "aead_driver_wgpu.h" - -#if defined(LUX_AEAD_HAS_WEBGPU) - -#include -#if defined(LUX_AEAD_HAS_WGPU_NATIVE) -# include -#endif - -#include "aead_wgsl_sources.h" - -#include -#include -#include -#include -#include -#include - -namespace { - -WGPUStringView sv(const char* s) { - WGPUStringView v{}; - v.data = s; - v.length = (s == nullptr) ? 
0 : std::strlen(s); - return v; -} - -void drain(WGPUInstance inst, WGPUDevice dev) { - if (inst) wgpuInstanceProcessEvents(inst); -#if defined(LUX_AEAD_HAS_WGPU_NATIVE) - if (dev) wgpuDevicePoll(dev, /*wait=*/WGPU_TRUE, nullptr); -#else - (void)dev; -#endif -} - -bool wait_map(WGPUInstance inst, WGPUDevice dev, WGPUBuffer buf, - WGPUMapMode mode, size_t off, size_t size) { - struct State { std::atomic done{false}; WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error}; } s; - WGPUBufferMapCallbackInfo cb{}; - cb.mode = WGPUCallbackMode_AllowProcessEvents; - cb.callback = [](WGPUMapAsyncStatus st, WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - p->status = st; - p->done.store(true, std::memory_order_release); - }; - cb.userdata1 = &s; - wgpuBufferMapAsync(buf, mode, off, size, cb); - for (int spin = 0; spin < 8192; ++spin) { - if (s.done.load(std::memory_order_acquire)) break; - drain(inst, dev); - } - return s.done.load() && s.status == WGPUMapAsyncStatus_Success; -} - -struct Engine { - WGPUInstance instance{nullptr}; - WGPUAdapter adapter{nullptr}; - WGPUDevice device{nullptr}; - WGPUQueue queue{nullptr}; - WGPUShaderModule chacha_module{nullptr}; - WGPUShaderModule aes_module{nullptr}; - WGPUComputePipeline chacha_pipe{nullptr}; - WGPUComputePipeline aes_pipe{nullptr}; - bool initialized{false}; -}; - -Engine& engine() { static Engine e; return e; } - -bool init_engine() { - Engine& e = engine(); - if (e.initialized) return true; - - WGPUInstanceDescriptor idesc{}; - e.instance = wgpuCreateInstance(&idesc); - if (!e.instance) return false; - - struct AS { std::atomic done{false}; WGPUAdapter ad{nullptr}; } as; - WGPURequestAdapterOptions ropt{}; - ropt.powerPreference = WGPUPowerPreference_HighPerformance; - WGPURequestAdapterCallbackInfo rcb{}; - rcb.mode = WGPUCallbackMode_AllowProcessEvents; - rcb.callback = [](WGPURequestAdapterStatus st, WGPUAdapter ad, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == 
WGPURequestAdapterStatus_Success) p->ad = ad; - p->done.store(true, std::memory_order_release); - }; - rcb.userdata1 = &as; - wgpuInstanceRequestAdapter(e.instance, &ropt, rcb); - for (int spin = 0; spin < 8192; ++spin) { - if (as.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!as.ad) { - std::fprintf(stderr, "aead/wgpu: no adapter\n"); - return false; - } - e.adapter = as.ad; - - struct DS { std::atomic done{false}; WGPUDevice dev{nullptr}; } ds; - WGPUDeviceDescriptor ddesc{}; - WGPURequestDeviceCallbackInfo dcb{}; - dcb.mode = WGPUCallbackMode_AllowProcessEvents; - dcb.callback = [](WGPURequestDeviceStatus st, WGPUDevice dev, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestDeviceStatus_Success) p->dev = dev; - p->done.store(true, std::memory_order_release); - }; - dcb.userdata1 = &ds; - wgpuAdapterRequestDevice(e.adapter, &ddesc, dcb); - for (int spin = 0; spin < 8192; ++spin) { - if (ds.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!ds.dev) { - std::fprintf(stderr, "aead/wgpu: no device\n"); - return false; - } - e.device = ds.dev; - e.queue = wgpuDeviceGetQueue(e.device); - if (!e.queue) return false; - - auto compile = [&](const char* src, const char* label) -> WGPUShaderModule { - WGPUShaderSourceWGSL wgsl{}; - wgsl.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl.code = sv(src); - WGPUShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl.chain; - smd.label = sv(label); - return wgpuDeviceCreateShaderModule(e.device, &smd); - }; - - e.chacha_module = compile(kAEAD_WGSL_ChaCha20Poly1305, "aead_chacha20_poly1305"); - if (!e.chacha_module) { - std::fprintf(stderr, "aead/wgpu: chacha20_poly1305 compile failed\n"); - return false; - } - e.aes_module = compile(kAEAD_WGSL_AesGcm, "aead_aes_256_gcm"); - if (!e.aes_module) { - std::fprintf(stderr, "aead/wgpu: aes_gcm compile failed\n"); - return false; - } - - auto make_pipe = 
[&](WGPUShaderModule mod, const char* entry) -> WGPUComputePipeline { - WGPUComputePipelineDescriptor cpd{}; - cpd.compute.module = mod; - cpd.compute.entryPoint = sv(entry); - cpd.label = sv(entry); - return wgpuDeviceCreateComputePipeline(e.device, &cpd); - }; - e.chacha_pipe = make_pipe(e.chacha_module, "chacha20_poly1305_jobs"); - if (!e.chacha_pipe) { - std::fprintf(stderr, "aead/wgpu: chacha pipeline failed\n"); - return false; - } - e.aes_pipe = make_pipe(e.aes_module, "aes_gcm_jobs"); - if (!e.aes_pipe) { - std::fprintf(stderr, "aead/wgpu: aes pipeline failed\n"); - return false; - } - - e.initialized = true; - return true; -} - -WGPUBuffer make_buf(Engine& e, size_t bytes, WGPUBufferUsage usage) { - WGPUBufferDescriptor bd{}; - // Round up to 4 (WGSL storage alignment). - bd.size = (bytes + 3) & ~size_t(3); - if (bd.size == 0) bd.size = 4; - bd.usage = usage; - return wgpuDeviceCreateBuffer(e.device, &bd); -} - -// One unified dispatch path. The caller picks the pipeline and supplies the -// job-record stride (same for both AEAD ciphers: 8 * uint32_t = 32 bytes). -int dispatch_aead(WGPUComputePipeline pipeline, - const uint8_t* keys, size_t keys_bytes, - const uint8_t* nonces, size_t nonces_bytes, - const uint8_t* inputs_arena, size_t inputs_bytes, - const void* jobs, size_t jobs_bytes, - uint8_t* outputs, size_t outputs_bytes, - uint32_t n) { - Engine& e = engine(); - - WGPUBuffer buf_jobs = make_buf(e, jobs_bytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_keys = make_buf(e, keys_bytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_nonces = make_buf(e, nonces_bytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_in = make_buf(e, inputs_bytes ? 
inputs_bytes : 4, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_out = make_buf(e, outputs_bytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopySrc | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_params = make_buf(e, 16, WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_read = make_buf(e, outputs_bytes, WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst); - if (!buf_jobs || !buf_keys || !buf_nonces || !buf_in || !buf_out || - !buf_params || !buf_read) { - return -3; - } - - // wgpuQueueWriteBuffer requires copy size multiple of 4. Round each - // payload up by copying into a 4-aligned staging vector. - auto padded = [](const uint8_t* src, size_t n) { - size_t up = (n + 3) & ~size_t(3); - std::vector v(up, 0); - if (n > 0 && src != nullptr) std::memcpy(v.data(), src, n); - return v; - }; - { - auto j = padded(static_cast(jobs), jobs_bytes); - wgpuQueueWriteBuffer(e.queue, buf_jobs, 0, j.data(), j.size()); - } - { - auto k = padded(keys, keys_bytes); - wgpuQueueWriteBuffer(e.queue, buf_keys, 0, k.data(), k.size()); - } - { - auto n_pad = padded(nonces, nonces_bytes); - wgpuQueueWriteBuffer(e.queue, buf_nonces, 0, n_pad.data(), n_pad.size()); - } - if (inputs_bytes > 0) { - auto in_pad = padded(inputs_arena, inputs_bytes); - wgpuQueueWriteBuffer(e.queue, buf_in, 0, in_pad.data(), in_pad.size()); - } else { - uint32_t z = 0; - wgpuQueueWriteBuffer(e.queue, buf_in, 0, &z, 4); - } - // Zero outputs (matches Metal driver: deterministic unused regions). 
- { - size_t up = (outputs_bytes + 3) & ~size_t(3); - std::vector zeros(up, 0); - wgpuQueueWriteBuffer(e.queue, buf_out, 0, zeros.data(), up); - } - uint32_t params[4] = { n, 0, 0, 0 }; - wgpuQueueWriteBuffer(e.queue, buf_params, 0, params, 16); - - auto align4 = [](size_t n) { return (n + 3) & ~size_t(3); }; - WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pipeline, 0); - WGPUBindGroupEntry bge[6] = {}; - bge[0].binding = 0; bge[0].buffer = buf_jobs; bge[0].size = align4(jobs_bytes); - bge[1].binding = 1; bge[1].buffer = buf_keys; bge[1].size = align4(keys_bytes); - bge[2].binding = 2; bge[2].buffer = buf_nonces; bge[2].size = align4(nonces_bytes); - bge[3].binding = 3; bge[3].buffer = buf_in; bge[3].size = align4(inputs_bytes ? inputs_bytes : 4); - bge[4].binding = 4; bge[4].buffer = buf_out; bge[4].size = align4(outputs_bytes); - bge[5].binding = 5; bge[5].buffer = buf_params; bge[5].size = 16; - WGPUBindGroupDescriptor bgd{}; - bgd.layout = bgl; - bgd.entryCount = 6; - bgd.entries = bge; - WGPUBindGroup bg = wgpuDeviceCreateBindGroup(e.device, &bgd); - if (!bg) return -4; - - WGPUCommandEncoderDescriptor ced{}; - WGPUCommandEncoder ce = wgpuDeviceCreateCommandEncoder(e.device, &ced); - WGPUComputePassDescriptor cpd{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpd); - wgpuComputePassEncoderSetPipeline(cpe, pipeline); - wgpuComputePassEncoderSetBindGroup(cpe, 0, bg, 0, nullptr); - // Workgroup size = 64; dispatch ceil(n / 64) groups. 
- const uint32_t groups = (n + 63u) / 64u; - wgpuComputePassEncoderDispatchWorkgroups(cpe, groups, 1, 1); - wgpuComputePassEncoderEnd(cpe); - const size_t out_pad = (outputs_bytes + 3) & ~size_t(3); - wgpuCommandEncoderCopyBufferToBuffer(ce, buf_out, 0, buf_read, 0, out_pad); - WGPUCommandBufferDescriptor cbd{}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(ce, &cbd); - wgpuQueueSubmit(e.queue, 1, &cmd); - - if (!wait_map(e.instance, e.device, buf_read, WGPUMapMode_Read, 0, out_pad)) { - return -5; - } - const void* mapped = wgpuBufferGetConstMappedRange(buf_read, 0, out_pad); - std::memcpy(outputs, mapped, outputs_bytes); - wgpuBufferUnmap(buf_read); - - wgpuComputePassEncoderRelease(cpe); - wgpuCommandEncoderRelease(ce); - wgpuCommandBufferRelease(cmd); - wgpuBindGroupRelease(bg); - wgpuBindGroupLayoutRelease(bgl); - wgpuBufferRelease(buf_jobs); - wgpuBufferRelease(buf_keys); - wgpuBufferRelease(buf_nonces); - wgpuBufferRelease(buf_in); - wgpuBufferRelease(buf_out); - wgpuBufferRelease(buf_params); - wgpuBufferRelease(buf_read); - - return 0; -} - -} // namespace - -extern "C" int lux_aead_wgpu_available(void) { - return init_engine() ? 
1 : 0; -} - -extern "C" int aead_chacha20poly1305_batch_wgpu( - const uint8_t* keys, const uint8_t* nonces, - const uint8_t* inputs_arena, size_t inputs_arena_len, - const void* jobs, size_t n, - uint8_t* outputs_arena, size_t outputs_arena_len) { - if (n == 0) return 0; - if (!keys || !nonces || !jobs || !outputs_arena) return -1; - if (!init_engine()) return -2; - return dispatch_aead(engine().chacha_pipe, - keys, n * 32, - nonces, n * 12, - inputs_arena, inputs_arena_len, - jobs, n * 32, - outputs_arena, outputs_arena_len, - (uint32_t)n); -} - -extern "C" int aead_aes_256_gcm_batch_wgpu( - const uint8_t* keys, const uint8_t* ivs, - const uint8_t* inputs_arena, size_t inputs_arena_len, - const void* jobs, size_t n, - uint8_t* outputs_arena, size_t outputs_arena_len) { - if (n == 0) return 0; - if (!keys || !ivs || !jobs || !outputs_arena) return -1; - if (!init_engine()) return -2; - return dispatch_aead(engine().aes_pipe, - keys, n * 32, - ivs, n * 12, - inputs_arena, inputs_arena_len, - jobs, n * 32, - outputs_arena, outputs_arena_len, - (uint32_t)n); -} - -#else // LUX_AEAD_HAS_WEBGPU not defined - -extern "C" int lux_aead_wgpu_available(void) { return 0; } -extern "C" int aead_chacha20poly1305_batch_wgpu( - const uint8_t*, const uint8_t*, const uint8_t*, size_t, - const void*, size_t, uint8_t*, size_t) { return -1; } -extern "C" int aead_aes_256_gcm_batch_wgpu( - const uint8_t*, const uint8_t*, const uint8_t*, size_t, - const void*, size_t, uint8_t*, size_t) { return -1; } - -#endif diff --git a/aead/gpu/wgsl/aead_driver_wgpu.h b/aead/gpu/wgsl/aead_driver_wgpu.h deleted file mode 100644 index ad29617..0000000 --- a/aead/gpu/wgsl/aead_driver_wgpu.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C ABI for the WebGPU/WGSL batched AEAD driver. 
The driver -// compiles in two modes: -// * LUX_AEAD_HAS_WEBGPU defined -> real wgpu-native dispatch -// * not defined -> stub mode, every entry returns -1 -// -// All entry points return 0 on success, negative on failure (matching -// the Metal and CUDA driver conventions). - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_aead_wgpu_available(void); - -int aead_chacha20poly1305_batch_wgpu( - const uint8_t* keys, - const uint8_t* nonces, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const void* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len); - -int aead_aes_256_gcm_batch_wgpu( - const uint8_t* keys, - const uint8_t* ivs, - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const void* jobs, - size_t n, - uint8_t* outputs_arena, - size_t outputs_arena_len); - -#ifdef __cplusplus -} -#endif diff --git a/aead/gpu/wgsl/aes_gcm.wgsl b/aead/gpu/wgsl/aes_gcm.wgsl deleted file mode 100644 index eddcc8a..0000000 --- a/aead/gpu/wgsl/aes_gcm.wgsl +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Batched AES-256-GCM (NIST SP 800-38D, 96-bit IV) compute shader. -// One thread per (key, iv, aad, plaintext) message; output is byte-equal -// to lux::crypto::aead::aes_256_gcm::encrypt() in cpp/aead.cpp and to -// gpu/metal/aes_gcm.metal. -// -// WGSL has no u8: byte arrays are packed into u32 storage buffers -// (LSB-first within each word). Constant-time S-box (Boyar-Peralta) and -// constant-time GHASH (128 iters) are preserved. 
- -struct AeadJob { - aad_offset: u32, - aad_len: u32, - pt_offset: u32, - pt_len: u32, - ct_offset: u32, - tag_offset: u32, - key_offset: u32, // i*32 - nonce_offset: u32, // i*12 -} - -struct Params { - n_jobs: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var jobs: array; -@group(0) @binding(1) var keys: array; -@group(0) @binding(2) var nonces: array; -@group(0) @binding(3) var inputs: array; -@group(0) @binding(4) var outputs: array; -@group(0) @binding(5) var params: Params; - -// ---- Byte access --------------------------------------------------------- - -fn rd_in(byte_off: u32) -> u32 { - let w = byte_off >> 2u; - let s = (byte_off & 3u) * 8u; - return (inputs[w] >> s) & 0xffu; -} -fn rd_out(byte_off: u32) -> u32 { - let w = byte_off >> 2u; - let s = (byte_off & 3u) * 8u; - return (outputs[w] >> s) & 0xffu; -} -fn wr_out(byte_off: u32, v: u32) { - let w = byte_off >> 2u; - let s = (byte_off & 3u) * 8u; - let mask = 0xffu << s; - outputs[w] = (outputs[w] & ~mask) | ((v & 0xffu) << s); -} -fn rd_msg(arena: u32, byte_off: u32) -> u32 { - if (arena == 0u) { return rd_in(byte_off); } - return rd_out(byte_off); -} - -fn rd_key_byte(key_byte_off: u32, i: u32) -> u32 { - let abs = key_byte_off + i; - let w = abs >> 2u; - let s = (abs & 3u) * 8u; - return (keys[w] >> s) & 0xffu; -} -fn rd_iv_byte(nonce_byte_off: u32, i: u32) -> u32 { - let abs = nonce_byte_off + i; - let w = abs >> 2u; - let s = (abs & 3u) * 8u; - return (nonces[w] >> s) & 0xffu; -} - -// ---- AES S-box (Boyar-Peralta circuit, J. Cryptol. 2010) ----------------- -// Byte-for-byte port of lux::crypto::aead::aes::aes_sbox. 
- -fn aes_sbox(x: u32) -> u32 { - let U0 = (x >> 7u) & 1u; - let U1 = (x >> 6u) & 1u; - let U2 = (x >> 5u) & 1u; - let U3 = (x >> 4u) & 1u; - let U4 = (x >> 3u) & 1u; - let U5 = (x >> 2u) & 1u; - let U6 = (x >> 1u) & 1u; - let U7 = x & 1u; - - let T1 = U0 ^ U3; - let T2 = U0 ^ U5; - let T3 = U0 ^ U6; - let T4 = U3 ^ U5; - let T5 = U4 ^ U6; - let T6 = T1 ^ T5; - let T7 = U1 ^ U2; - let T8 = U7 ^ T6; - let T9 = U7 ^ T7; - let T10 = T6 ^ T7; - let T11 = U1 ^ U5; - let T12 = U2 ^ U5; - let T13 = T3 ^ T4; - let T14 = T6 ^ T11; - let T15 = T5 ^ T11; - let T16 = T5 ^ T12; - let T17 = T9 ^ T16; - let T18 = U3 ^ U7; - let T19 = T7 ^ T18; - let T20 = T1 ^ T19; - let T21 = U6 ^ U7; - let T22 = T7 ^ T21; - let T23 = T2 ^ T22; - let T24 = T2 ^ T10; - let T25 = T20 ^ T17; - let T26 = T3 ^ T16; - let T27 = T1 ^ T12; - - let M1 = T13 & T6; - let M2 = T23 & T8; - let M3 = T14 ^ M1; - let M4 = T19 & U7; - let M5 = M4 ^ M1; - let M6 = T3 & T16; - let M7 = T22 & T9; - let M8 = T26 ^ M6; - let M9 = T20 & T17; - let M10 = M9 ^ M6; - let M11 = T1 & T15; - let M12 = T4 & T27; - let M13 = M12 ^ M11; - let M14 = T2 & T10; - let M15 = M14 ^ M11; - let M16 = M3 ^ M2; - let M17 = M5 ^ T24; - let M18 = M8 ^ M7; - let M19 = M10 ^ M15; - let M20 = M16 ^ M13; - let M21 = M17 ^ M15; - let M22 = M18 ^ M13; - let M23 = M19 ^ T25; - let M24 = M22 ^ M23; - let M25 = M22 & M20; - let M26 = M21 ^ M25; - let M27 = M20 ^ M21; - let M28 = M23 ^ M25; - let M29 = M28 & M27; - let M30 = M26 & M24; - let M31 = M20 & M23; - let M32 = M27 & M31; - let M33 = M27 ^ M25; - let M34 = M21 & M22; - let M35 = M24 & M34; - let M36 = M24 ^ M25; - let M37 = M21 ^ M29; - let M38 = M32 ^ M33; - let M39 = M23 ^ M30; - let M40 = M35 ^ M36; - let M41 = M38 ^ M40; - let M42 = M37 ^ M39; - let M43 = M37 ^ M38; - let M44 = M39 ^ M40; - let M45 = M42 ^ M41; - let M46 = M44 & T6; - let M47 = M40 & T8; - let M48 = M39 & U7; - let M49 = M43 & T16; - let M50 = M38 & T9; - let M51 = M37 & T17; - let M52 = M42 & T15; - let M53 = M45 & 
T27; - let M54 = M41 & T10; - let M55 = M44 & T13; - let M56 = M40 & T23; - let M57 = M39 & T19; - let M58 = M43 & T3; - let M59 = M38 & T22; - let M60 = M37 & T20; - let M61 = M42 & T1; - let M62 = M45 & T4; - let M63 = M41 & T2; - - let L0 = M61 ^ M62; - let L1 = M50 ^ M56; - let L2 = M46 ^ M48; - let L3 = M47 ^ M55; - let L4 = M54 ^ M58; - let L5 = M49 ^ M61; - let L6 = M62 ^ L5; - let L7 = M46 ^ L3; - let L8 = M51 ^ M59; - let L9 = M52 ^ M53; - let L10 = M53 ^ L4; - let L11 = M60 ^ L2; - let L12 = M48 ^ M51; - let L13 = M50 ^ L0; - let L14 = M52 ^ M61; - let L15 = M55 ^ L1; - let L16 = M56 ^ L0; - let L17 = M57 ^ L1; - let L18 = M58 ^ L8; - let L19 = M63 ^ L4; - let L20 = L0 ^ L1; - let L21 = L1 ^ L7; - let L22 = L3 ^ L12; - let L23 = L18 ^ L2; - let L24 = L15 ^ L9; - let L25 = L6 ^ L10; - let L26 = L7 ^ L9; - let L27 = L8 ^ L10; - let L28 = L11 ^ L14; - let L29 = L11 ^ L17; - - let S0 = L6 ^ L24; - var S1 = L16 ^ L26; S1 = S1 ^ 1u; - var S2 = L19 ^ L28; S2 = S2 ^ 1u; - let S3 = L6 ^ L21; - let S4 = L20 ^ L22; - let S5 = L25 ^ L29; - var S6 = L13 ^ L27; S6 = S6 ^ 1u; - var S7 = L6 ^ L23; S7 = S7 ^ 1u; - - return ((S0 & 1u) << 7u) - | ((S1 & 1u) << 6u) - | ((S2 & 1u) << 5u) - | ((S3 & 1u) << 4u) - | ((S4 & 1u) << 3u) - | ((S5 & 1u) << 2u) - | ((S6 & 1u) << 1u) - | (S7 & 1u); -} - -const RCON: array = array( - 0x01u, 0x02u, 0x04u, 0x08u, 0x10u, 0x20u, 0x40u -); - -fn xtime(x: u32) -> u32 { - return ((x << 1u) ^ (((x >> 7u) & 1u) * 0x1bu)) & 0xffu; -} - -// ---- AES-256 key expansion (FIPS 197 §5.2) ------------------------------- -// 240 bytes of round-key material, stored byte-by-byte. 
- -var rk_state: array; - -fn aes256_expand_key(key_byte_off: u32) { - for (var i = 0u; i < 32u; i = i + 1u) { - rk_state[i] = rd_key_byte(key_byte_off, i); - } - for (var i = 8u; i < 60u; i = i + 1u) { - var t0 = rk_state[(i - 1u) * 4u + 0u]; - var t1 = rk_state[(i - 1u) * 4u + 1u]; - var t2 = rk_state[(i - 1u) * 4u + 2u]; - var t3 = rk_state[(i - 1u) * 4u + 3u]; - - if ((i & 7u) == 0u) { - let r0 = t1; let r1 = t2; let r2 = t3; let r3 = t0; - t0 = aes_sbox(r0) ^ RCON[(i / 8u) - 1u]; - t1 = aes_sbox(r1); - t2 = aes_sbox(r2); - t3 = aes_sbox(r3); - } else if ((i & 7u) == 4u) { - t0 = aes_sbox(t0); - t1 = aes_sbox(t1); - t2 = aes_sbox(t2); - t3 = aes_sbox(t3); - } - - rk_state[i * 4u + 0u] = (rk_state[(i - 8u) * 4u + 0u] ^ t0) & 0xffu; - rk_state[i * 4u + 1u] = (rk_state[(i - 8u) * 4u + 1u] ^ t1) & 0xffu; - rk_state[i * 4u + 2u] = (rk_state[(i - 8u) * 4u + 2u] ^ t2) & 0xffu; - rk_state[i * 4u + 3u] = (rk_state[(i - 8u) * 4u + 3u] ^ t3) & 0xffu; - } -} - -// ---- AES-256 encrypt one 16-byte block (FIPS 197 §5.1) ------------------- - -var aes_state: array; - -fn aes256_encrypt_block(in_buf: array) -> array { - for (var i = 0u; i < 16u; i = i + 1u) { - aes_state[i] = (in_buf[i] ^ rk_state[i]) & 0xffu; - } - for (var round = 1u; round < 14u; round = round + 1u) { - for (var i = 0u; i < 16u; i = i + 1u) { - aes_state[i] = aes_sbox(aes_state[i]); - } - // ShiftRows. - var t: u32; - t = aes_state[1]; aes_state[1] = aes_state[5]; aes_state[5] = aes_state[9]; aes_state[9] = aes_state[13]; aes_state[13] = t; - t = aes_state[2]; aes_state[2] = aes_state[10]; aes_state[10] = t; - t = aes_state[6]; aes_state[6] = aes_state[14]; aes_state[14] = t; - t = aes_state[15]; aes_state[15] = aes_state[11]; aes_state[11] = aes_state[7]; aes_state[7] = aes_state[3]; aes_state[3] = t; - // MixColumns. 
- for (var c = 0u; c < 4u; c = c + 1u) { - let a0 = aes_state[c*4u + 0u]; - let a1 = aes_state[c*4u + 1u]; - let a2 = aes_state[c*4u + 2u]; - let a3 = aes_state[c*4u + 3u]; - let x = a0 ^ a1 ^ a2 ^ a3; - let y0 = a0; - aes_state[c*4u + 0u] = (a0 ^ x ^ xtime(a0 ^ a1)) & 0xffu; - aes_state[c*4u + 1u] = (a1 ^ x ^ xtime(a1 ^ a2)) & 0xffu; - aes_state[c*4u + 2u] = (a2 ^ x ^ xtime(a2 ^ a3)) & 0xffu; - aes_state[c*4u + 3u] = (a3 ^ x ^ xtime(a3 ^ y0)) & 0xffu; - } - for (var i = 0u; i < 16u; i = i + 1u) { - aes_state[i] = (aes_state[i] ^ rk_state[round * 16u + i]) & 0xffu; - } - } - // Final round. - for (var i = 0u; i < 16u; i = i + 1u) { - aes_state[i] = aes_sbox(aes_state[i]); - } - var t: u32; - t = aes_state[1]; aes_state[1] = aes_state[5]; aes_state[5] = aes_state[9]; aes_state[9] = aes_state[13]; aes_state[13] = t; - t = aes_state[2]; aes_state[2] = aes_state[10]; aes_state[10] = t; - t = aes_state[6]; aes_state[6] = aes_state[14]; aes_state[14] = t; - t = aes_state[15]; aes_state[15] = aes_state[11]; aes_state[11] = aes_state[7]; aes_state[7] = aes_state[3]; aes_state[3] = t; - - var out_buf: array; - for (var i = 0u; i < 16u; i = i + 1u) { - out_buf[i] = (aes_state[i] ^ rk_state[14u * 16u + i]) & 0xffu; - } - return out_buf; -} - -// ---- GHASH (NIST SP 800-38D §6.3, constant-time 128 iterations) ---------- - -fn ghash_mul(z_in: array, h_in: array) -> array { - var z = z_in; - var v = h_in; - var r: array; - for (var i = 0u; i < 16u; i = i + 1u) { r[i] = 0u; } - for (var i = 0u; i < 128u; i = i + 1u) { - let zbit = (z[i >> 3u] >> (7u - (i & 7u))) & 1u; - let mask = (0u - zbit) & 0xffu; - for (var j = 0u; j < 16u; j = j + 1u) { - r[j] = (r[j] ^ (v[j] & mask)) & 0xffu; - } - let lsb = v[15] & 1u; - // shift v right by 1 bit. 
- for (var j = 15u; j > 0u; j = j - 1u) { - v[j] = ((v[j] >> 1u) | ((v[j-1u] & 1u) << 7u)) & 0xffu; - } - v[0] = (v[0] >> 1u) & 0xffu; - let rmask = (0u - lsb) & 0xffu; - v[0] = (v[0] ^ (0xe1u & rmask)) & 0xffu; - } - return r; -} - -fn inc32(ctr_in: array) -> array { - var ctr = ctr_in; - var c = (ctr[12] << 24u) | (ctr[13] << 16u) | (ctr[14] << 8u) | ctr[15]; - c = c + 1u; - ctr[12] = (c >> 24u) & 0xffu; - ctr[13] = (c >> 16u) & 0xffu; - ctr[14] = (c >> 8u) & 0xffu; - ctr[15] = c & 0xffu; - return ctr; -} - -fn ghash_update(y_in: array, h: array, arena: u32, - off: u32, len: u32) -> array { - var y = y_in; - var pos: u32 = 0u; - while (len - pos >= 16u) { - for (var i = 0u; i < 16u; i = i + 1u) { - y[i] = (y[i] ^ rd_msg(arena, off + pos + i)) & 0xffu; - } - y = ghash_mul(y, h); - pos = pos + 16u; - } - let rem = len - pos; - if (rem > 0u) { - var blk: array; - for (var i = 0u; i < 16u; i = i + 1u) { blk[i] = 0u; } - for (var i = 0u; i < rem; i = i + 1u) { - blk[i] = rd_msg(arena, off + pos + i); - } - for (var i = 0u; i < 16u; i = i + 1u) { - y[i] = (y[i] ^ blk[i]) & 0xffu; - } - y = ghash_mul(y, h); - } - return y; -} - -@compute @workgroup_size(64) -fn aes_gcm_jobs(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - if (i >= params.n_jobs) { return; } - - let job = jobs[i]; - - aes256_expand_key(job.key_offset); - - // H = AES_K(0^128). - var zero: array; - for (var k = 0u; k < 16u; k = k + 1u) { zero[k] = 0u; } - let H = aes256_encrypt_block(zero); - - // J0 = IV || 0x00000001. - var J0: array; - for (var k = 0u; k < 12u; k = k + 1u) { - J0[k] = rd_iv_byte(job.nonce_offset, k); - } - J0[12] = 0u; J0[13] = 0u; J0[14] = 0u; J0[15] = 1u; - - // Encrypt plaintext under counter starting at inc32(J0). 
- var ctr = inc32(J0); - var pos: u32 = 0u; - while (pos < job.pt_len) { - let ks = aes256_encrypt_block(ctr); - var take = job.pt_len - pos; - if (take > 16u) { take = 16u; } - for (var k = 0u; k < take; k = k + 1u) { - let pt_b = rd_in(job.pt_offset + pos + k); - wr_out(job.ct_offset + pos + k, pt_b ^ ks[k]); - } - ctr = inc32(ctr); - pos = pos + take; - } - - // GHASH over (aad || pad || ct || pad || lens_in_bits). - var Y: array; - for (var k = 0u; k < 16u; k = k + 1u) { Y[k] = 0u; } - Y = ghash_update(Y, H, 0u, job.aad_offset, job.aad_len); - Y = ghash_update(Y, H, 1u, job.ct_offset, job.pt_len); - { - // Lengths block (BE, in bits). aad_len and pt_len are < 2^32, so - // upper 4 bytes of each 8-byte field are always zero. - var lens: array; - let la_bits = job.aad_len * 8u; - let lc_bits = job.pt_len * 8u; - // Upper 4 bytes of la_bits == 0 (since la_bits < 2^35 here, but we - // still emit u32 BE for the low 32 bits and zero for the high). - lens[0] = 0u; lens[1] = 0u; lens[2] = 0u; lens[3] = 0u; - lens[4] = (la_bits >> 24u) & 0xffu; - lens[5] = (la_bits >> 16u) & 0xffu; - lens[6] = (la_bits >> 8u) & 0xffu; - lens[7] = la_bits & 0xffu; - lens[ 8] = 0u; lens[ 9] = 0u; lens[10] = 0u; lens[11] = 0u; - lens[12] = (lc_bits >> 24u) & 0xffu; - lens[13] = (lc_bits >> 16u) & 0xffu; - lens[14] = (lc_bits >> 8u) & 0xffu; - lens[15] = lc_bits & 0xffu; - for (var k = 0u; k < 16u; k = k + 1u) { - Y[k] = (Y[k] ^ lens[k]) & 0xffu; - } - Y = ghash_mul(Y, H); - } - - // Tag = GHASH XOR AES_K(J0). - let s = aes256_encrypt_block(J0); - for (var k = 0u; k < 16u; k = k + 1u) { - wr_out(job.tag_offset + k, Y[k] ^ s[k]); - } -} diff --git a/aead/gpu/wgsl/chacha20_poly1305.wgsl b/aead/gpu/wgsl/chacha20_poly1305.wgsl deleted file mode 100644 index 1d1aa53..0000000 --- a/aead/gpu/wgsl/chacha20_poly1305.wgsl +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Batched ChaCha20-Poly1305 (RFC 8439) compute shader. One thread per -// (key, nonce, aad, plaintext) message. Output is byte-equal to -// lux::crypto::aead::chacha20_poly1305::encrypt() in cpp/aead.cpp and to -// gpu/metal/aead_batch.metal. -// -// WGSL has no u8 nor u64. Bytes are packed into u32 storage buffers -// (LSB-first within each word). Poly1305 mirrors the Metal radix-2^26 -// limb layout, with all u64 arithmetic emulated via vec2 (lo, hi). - -struct AeadJob { - aad_offset: u32, - aad_len: u32, - pt_offset: u32, - pt_len: u32, - ct_offset: u32, - tag_offset: u32, - key_offset: u32, - nonce_offset: u32, -} - -struct Params { - n_jobs: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var jobs: array; -@group(0) @binding(1) var keys: array; -@group(0) @binding(2) var nonces: array; -@group(0) @binding(3) var inputs: array; -@group(0) @binding(4) var outputs: array; -@group(0) @binding(5) var params: Params; - -// ---- u64 emulation ------------------------------------------------------- - -fn u64_make(lo: u32) -> vec2 { return vec2(lo, 0u); } -fn u64_const_lo(lo: u32) -> vec2 { return vec2(lo, 0u); } -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - var carry: u32 = 0u; - if (lo < a.x) { carry = 1u; } - let hi = a.y + b.y + carry; - return vec2(lo, hi); -} -fn u64_shr_26(a: vec2) -> vec2 { - // a >> 26: lo = (a.x >> 26) | (a.y << 6); hi = a.y >> 26. 
- let lo = (a.x >> 26u) | (a.y << 6u); - let hi = a.y >> 26u; - return vec2(lo, hi); -} -fn u64_low26(a: vec2) -> u32 { - return a.x & 0x3ffffffu; -} -fn u64_low32(a: vec2) -> u32 { - return a.x; -} -fn u64_shr_32(a: vec2) -> vec2 { - return vec2(a.y, 0u); -} - -// 32 x 32 -> 64 -fn mul32_64(a: u32, b: u32) -> vec2 { - let al = a & 0xffffu; - let ah = a >> 16u; - let bl = b & 0xffffu; - let bh = b >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = (ll >> 16u) + (lh & 0xffffu) + (hl & 0xffffu); - let lo = (mid << 16u) | (ll & 0xffffu); - let hi = hh + (lh >> 16u) + (hl >> 16u) + (mid >> 16u); - return vec2(lo, hi); -} - -// ---- Byte access --------------------------------------------------------- - -fn rd_in(byte_off: u32) -> u32 { - let w = byte_off >> 2u; - let s = (byte_off & 3u) * 8u; - return (inputs[w] >> s) & 0xffu; -} -fn rd_out(byte_off: u32) -> u32 { - let w = byte_off >> 2u; - let s = (byte_off & 3u) * 8u; - return (outputs[w] >> s) & 0xffu; -} -fn wr_out(byte_off: u32, v: u32) { - let w = byte_off >> 2u; - let s = (byte_off & 3u) * 8u; - let mask = 0xffu << s; - outputs[w] = (outputs[w] & ~mask) | ((v & 0xffu) << s); -} -fn rd_msg(arena: u32, byte_off: u32) -> u32 { - if (arena == 0u) { return rd_in(byte_off); } - return rd_out(byte_off); -} - -fn rd_key32(key_byte_off: u32, idx: u32) -> u32 { - let w = (key_byte_off + idx * 4u) >> 2u; - return keys[w]; -} -fn rd_nonce32(nonce_byte_off: u32, idx: u32) -> u32 { - let w = (nonce_byte_off + idx * 4u) >> 2u; - return nonces[w]; -} - -// ---- ChaCha20 ------------------------------------------------------------ - -fn rotl32(x: u32, n: u32) -> u32 { - return (x << n) | (x >> (32u - n)); -} - -var v_state: array; -var v_init: array; - -fn quarter(ai: u32, bi: u32, ci: u32, di: u32) { - var a = v_state[ai]; var b = v_state[bi]; - var c = v_state[ci]; var d = v_state[di]; - a = a + b; d = d ^ a; d = rotl32(d, 16u); - c = c + d; b = b ^ c; b = rotl32(b, 12u); - a = 
a + b; d = d ^ a; d = rotl32(d, 8u); - c = c + d; b = b ^ c; b = rotl32(b, 7u); - v_state[ai] = a; v_state[bi] = b; - v_state[ci] = c; v_state[di] = d; -} - -fn chacha20_block(key_off: u32, nonce_off: u32, counter: u32) { - v_init[ 0] = 0x61707865u; v_init[ 1] = 0x3320646eu; - v_init[ 2] = 0x79622d32u; v_init[ 3] = 0x6b206574u; - v_init[ 4] = rd_key32(key_off, 0u); v_init[ 5] = rd_key32(key_off, 1u); - v_init[ 6] = rd_key32(key_off, 2u); v_init[ 7] = rd_key32(key_off, 3u); - v_init[ 8] = rd_key32(key_off, 4u); v_init[ 9] = rd_key32(key_off, 5u); - v_init[10] = rd_key32(key_off, 6u); v_init[11] = rd_key32(key_off, 7u); - v_init[12] = counter; - v_init[13] = rd_nonce32(nonce_off, 0u); - v_init[14] = rd_nonce32(nonce_off, 1u); - v_init[15] = rd_nonce32(nonce_off, 2u); - for (var i = 0u; i < 16u; i = i + 1u) { v_state[i] = v_init[i]; } - for (var r = 0u; r < 10u; r = r + 1u) { - quarter(0u, 4u, 8u, 12u); - quarter(1u, 5u, 9u, 13u); - quarter(2u, 6u, 10u, 14u); - quarter(3u, 7u, 11u, 15u); - quarter(0u, 5u, 10u, 15u); - quarter(1u, 6u, 11u, 12u); - quarter(2u, 7u, 8u, 13u); - quarter(3u, 4u, 9u, 14u); - } - for (var i = 0u; i < 16u; i = i + 1u) { v_state[i] = v_state[i] + v_init[i]; } -} - -fn ks_byte(i: u32) -> u32 { - let w = i >> 2u; - let s = (i & 3u) * 8u; - return (v_state[w] >> s) & 0xffu; -} - -// ---- Poly1305 (radix-2^26, mirrors metal/aead_batch.metal) ---------------- - -var p_r: array; // 5 x 26-bit limbs -var p_s: array; // 4 x 32-bit s words -var p_h: array; // 5 x 26-bit limbs - -fn poly_init_from_block0() { - let c0 = v_state[0] & 0x0fffffffu; - let c1 = v_state[1] & 0x0ffffffcu; - let c2 = v_state[2] & 0x0ffffffcu; - let c3 = v_state[3] & 0x0ffffffcu; - p_r[0] = c0 & 0x3ffffffu; - p_r[1] = ((c0 >> 26) | (c1 << 6)) & 0x3ffffffu; - p_r[2] = ((c1 >> 20) | (c2 << 12)) & 0x3ffffffu; - p_r[3] = ((c2 >> 14) | (c3 << 18)) & 0x3ffffffu; - p_r[4] = (c3 >> 8) & 0x3ffffffu; - p_s[0] = v_state[4]; - p_s[1] = v_state[5]; - p_s[2] = v_state[6]; - p_s[3] = v_state[7]; 
- for (var i = 0u; i < 5u; i = i + 1u) { p_h[i] = 0u; } -} - -// Compute one Poly1305 block: h = (h + m) * r mod (2^130 - 5). -// Layout is identical to metal/aead_batch.metal poly_block. -fn poly_block_words(t0: u32, t1: u32, t2: u32, t3: u32, hibit: u32) { - // h += unpacked m (5 limbs of 26 bits). - let h0 = p_h[0] + ( t0 & 0x3ffffffu); - let h1 = p_h[1] + (((t0 >> 26) | (t1 << 6)) & 0x3ffffffu); - let h2 = p_h[2] + (((t1 >> 20) | (t2 << 12)) & 0x3ffffffu); - let h3 = p_h[3] + (((t2 >> 14) | (t3 << 18)) & 0x3ffffffu); - // For the high limb: h4 += ((t3 >> 8) | hibit). hibit is 1u<<24 always. - let h4 = p_h[4] + ((t3 >> 8) | hibit); - - let r0 = p_r[0]; let r1 = p_r[1]; let r2 = p_r[2]; - let r3 = p_r[3]; let r4 = p_r[4]; - let s1 = r1 * 5u; let s2 = r2 * 5u; - let s3 = r3 * 5u; let s4 = r4 * 5u; - - // 64-bit accumulators d_i = sum of (h_j * coeff) - // Each h_j * coeff is up to 32-bit * 30-bit = 62-bit; sum of 5 fits in 65 bits - // with a tiny excess; we must use a true 64-bit add chain. 
- var d0 = u64_add(u64_add(u64_add(u64_add( - mul32_64(h0, r0), - mul32_64(h1, s4)), - mul32_64(h2, s3)), - mul32_64(h3, s2)), - mul32_64(h4, s1)); - var d1 = u64_add(u64_add(u64_add(u64_add( - mul32_64(h0, r1), - mul32_64(h1, r0)), - mul32_64(h2, s4)), - mul32_64(h3, s3)), - mul32_64(h4, s2)); - var d2 = u64_add(u64_add(u64_add(u64_add( - mul32_64(h0, r2), - mul32_64(h1, r1)), - mul32_64(h2, r0)), - mul32_64(h3, s4)), - mul32_64(h4, s3)); - var d3 = u64_add(u64_add(u64_add(u64_add( - mul32_64(h0, r3), - mul32_64(h1, r2)), - mul32_64(h2, r1)), - mul32_64(h3, r0)), - mul32_64(h4, s4)); - var d4 = u64_add(u64_add(u64_add(u64_add( - mul32_64(h0, r4), - mul32_64(h1, r3)), - mul32_64(h2, r2)), - mul32_64(h3, r1)), - mul32_64(h4, r0)); - - var c: vec2; - c = u64_shr_26(d0); d0.x = u64_low26(d0); d0.y = 0u; d1 = u64_add(d1, c); - c = u64_shr_26(d1); d1.x = u64_low26(d1); d1.y = 0u; d2 = u64_add(d2, c); - c = u64_shr_26(d2); d2.x = u64_low26(d2); d2.y = 0u; d3 = u64_add(d3, c); - c = u64_shr_26(d3); d3.x = u64_low26(d3); d3.y = 0u; d4 = u64_add(d4, c); - // d4_hi -> *5 -> d0 - c = u64_shr_26(d4); d4.x = u64_low26(d4); d4.y = 0u; - // c may have nonzero hi; multiply by 5 carefully. - // Since h values were at most ~2^27 (h_i + m_i, where p_h[i] < 2^26 and m_i < 2^26), - // after mul each d_i is < 5 * 2^27 * 2^27 = 5 * 2^54, comfortably within 64 bits. - // c after >>26 is < 2^38 (safe). Multiplying by 5 stays < 2^41. - let c_lo = c.x; let c_hi = c.y; - // c * 5 in u64. - let c5_lo_full = mul32_64(c_lo, 5u); - var c5 = c5_lo_full; - if (c_hi != 0u) { - // c_hi * 5 contributes to high. 
- c5.y = c5.y + c_hi * 5u; - } - d0 = u64_add(d0, c5); - c = u64_shr_26(d0); d0.x = u64_low26(d0); d0.y = 0u; d1 = u64_add(d1, c); - - p_h[0] = u64_low32(d0); - p_h[1] = u64_low32(d1); - p_h[2] = u64_low32(d2); - p_h[3] = u64_low32(d3); - p_h[4] = u64_low32(d4); -} - -fn poly_block_bytes(arena: u32, byte_off: u32, byte_len: u32, full: bool) { - var b: array; - var consumed = byte_len; - if (full) { consumed = 16u; } - for (var i = 0u; i < 16u; i = i + 1u) { - if (i < consumed) { b[i] = rd_msg(arena, byte_off + i); } - else { b[i] = 0u; } - } - let t0 = b[ 0] | (b[ 1] << 8u) | (b[ 2] << 16u) | (b[ 3] << 24u); - let t1 = b[ 4] | (b[ 5] << 8u) | (b[ 6] << 16u) | (b[ 7] << 24u); - let t2 = b[ 8] | (b[ 9] << 8u) | (b[10] << 16u) | (b[11] << 24u); - let t3 = b[12] | (b[13] << 8u) | (b[14] << 16u) | (b[15] << 24u); - poly_block_words(t0, t1, t2, t3, 1u << 24u); -} - -fn absorb_padded(arena: u32, off: u32, len: u32) { - var pos: u32 = 0u; - while (len - pos >= 16u) { - poly_block_bytes(arena, off + pos, 16u, true); - pos = pos + 16u; - } - let rem = len - pos; - if (rem > 0u) { - poly_block_bytes(arena, off + pos, rem, false); - } -} - -fn poly_finalize_to_tag(out_byte_off: u32) { - var h0 = p_h[0]; var h1 = p_h[1]; var h2 = p_h[2]; - var h3 = p_h[3]; var h4 = p_h[4]; - var c: u32; - c = h1 >> 26u; h1 = h1 & 0x3ffffffu; h2 = h2 + c; - c = h2 >> 26u; h2 = h2 & 0x3ffffffu; h3 = h3 + c; - c = h3 >> 26u; h3 = h3 & 0x3ffffffu; h4 = h4 + c; - c = h4 >> 26u; h4 = h4 & 0x3ffffffu; h0 = h0 + c * 5u; - c = h0 >> 26u; h0 = h0 & 0x3ffffffu; h1 = h1 + c; - - var g0 = h0 + 5u; c = g0 >> 26u; g0 = g0 & 0x3ffffffu; - var g1 = h1 + c; c = g1 >> 26u; g1 = g1 & 0x3ffffffu; - var g2 = h2 + c; c = g2 >> 26u; g2 = g2 & 0x3ffffffu; - var g3 = h3 + c; c = g3 >> 26u; g3 = g3 & 0x3ffffffu; - let g4 = h4 + c - (1u << 26u); - - let mask = (g4 >> 31u) - 1u; - let nm = ~mask; - h0 = (h0 & nm) | (g0 & mask); - h1 = (h1 & nm) | (g1 & mask); - h2 = (h2 & nm) | (g2 & mask); - h3 = (h3 & nm) | (g3 & mask); 
- h4 = (h4 & nm) | (g4 & mask); - - let f0 = h0 | (h1 << 26u); - let f1 = (h1 >> 6u) | (h2 << 20u); - let f2 = (h2 >> 12u) | (h3 << 14u); - let f3 = (h3 >> 18u) | (h4 << 8u); - - var t = u64_add(u64_make(f0), u64_make(p_s[0])); - let w0 = t.x; - wr_out(out_byte_off + 0u, (w0 >> 0u) & 0xffu); - wr_out(out_byte_off + 1u, (w0 >> 8u) & 0xffu); - wr_out(out_byte_off + 2u, (w0 >> 16u) & 0xffu); - wr_out(out_byte_off + 3u, (w0 >> 24u) & 0xffu); - t = u64_add(u64_add(u64_shr_32(t), u64_make(f1)), u64_make(p_s[1])); - let w1 = t.x; - wr_out(out_byte_off + 4u, (w1 >> 0u) & 0xffu); - wr_out(out_byte_off + 5u, (w1 >> 8u) & 0xffu); - wr_out(out_byte_off + 6u, (w1 >> 16u) & 0xffu); - wr_out(out_byte_off + 7u, (w1 >> 24u) & 0xffu); - t = u64_add(u64_add(u64_shr_32(t), u64_make(f2)), u64_make(p_s[2])); - let w2 = t.x; - wr_out(out_byte_off + 8u, (w2 >> 0u) & 0xffu); - wr_out(out_byte_off + 9u, (w2 >> 8u) & 0xffu); - wr_out(out_byte_off + 10u, (w2 >> 16u) & 0xffu); - wr_out(out_byte_off + 11u, (w2 >> 24u) & 0xffu); - t = u64_add(u64_add(u64_shr_32(t), u64_make(f3)), u64_make(p_s[3])); - let w3 = t.x; - wr_out(out_byte_off + 12u, (w3 >> 0u) & 0xffu); - wr_out(out_byte_off + 13u, (w3 >> 8u) & 0xffu); - wr_out(out_byte_off + 14u, (w3 >> 16u) & 0xffu); - wr_out(out_byte_off + 15u, (w3 >> 24u) & 0xffu); -} - -@compute @workgroup_size(64) -fn chacha20_poly1305_jobs(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - if (i >= params.n_jobs) { return; } - - let job = jobs[i]; - - chacha20_block(job.key_offset, job.nonce_offset, 0u); - poly_init_from_block0(); - - var counter: u32 = 1u; - var pos: u32 = 0u; - while (pos < job.pt_len) { - chacha20_block(job.key_offset, job.nonce_offset, counter); - var take = job.pt_len - pos; - if (take > 64u) { take = 64u; } - for (var k = 0u; k < take; k = k + 1u) { - let pt_b = rd_in(job.pt_offset + pos + k); - let ks_b = ks_byte(k); - wr_out(job.ct_offset + pos + k, pt_b ^ ks_b); - } - pos = pos + take; - counter = counter + 1u; - } - - 
absorb_padded(0u, job.aad_offset, job.aad_len); - absorb_padded(1u, job.ct_offset, job.pt_len); - - // Lengths block: aad_len LE u64 || pt_len LE u64. - let la = job.aad_len; - let lc = job.pt_len; - poly_block_words(la, 0u, lc, 0u, 1u << 24u); - - poly_finalize_to_tag(job.tag_offset); -} diff --git a/banderwagon/gpu/cuda/banderwagon.cu b/banderwagon/gpu/cuda/banderwagon.cu deleted file mode 100644 index c86d3aa..0000000 --- a/banderwagon/gpu/cuda/banderwagon.cu +++ /dev/null @@ -1,552 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party CUDA kernel for Banderwagon group operations. -// -// Mechanical port of banderwagon/gpu/metal/banderwagon.metal -- byte-for-byte -// equivalent to lux::banderwagon::Element {add, double_self, scalar_mul} in -// banderwagon/cpp/element.cpp (twisted Edwards a*x^2 + y^2 = 1 + d*x^2*y^2 -// over the BLS12-381 scalar field; a = -5; d = canonical gnark constant). -// -// The constant table (modulus, Montgomery R/R^2/qInvNeg, curve a/d, generator) -// is emitted from the CPU body via banderwagon_gen_gpu_constants -> -// banderwagon_const.cuh. There is exactly one source of truth across CPU, -// Metal, CUDA, and WGSL. -// -// Compile with nvcc compute_60+ on Linux/CUDA hosts. On hosts without nvcc -// (Apple, plain g++/clang++), this file compiles via a host-side polyfill -// that elides __device__/__global__/__host__ qualifiers and lets the kernel -// run on the CPU oracle path. Either path runs the identical kernel body and -// yields byte-equal output by construction. - -#include -#include - -// Polyfill: when not building with nvcc, neutralize device qualifiers so the -// kernel body is plain C++. The driver below dispatches the same body via a -// host-side loop in that mode. 
-#ifndef __CUDA_ARCH__ -# ifndef LUX_BANDERWAGON_CUDA_HOST_POLYFILL -# define LUX_BANDERWAGON_CUDA_HOST_POLYFILL 1 -# endif -#endif - -#if LUX_BANDERWAGON_CUDA_HOST_POLYFILL -# define __device__ -# define __global__ -# define __host__ -# define __forceinline__ inline -#endif - -#include "banderwagon_const.cuh" // FP_Q_*, FP_R*_*, CURVE_*_*, FP_QINV_NEG, ... - -// ============================================================================= -// 64x64 -> 128 multiply. nvcc emits __umul64hi on the GPU; the host polyfill -// uses __int128 (or a portable 32-bit fallback). -// ============================================================================= -__device__ __forceinline__ void mul64(unsigned long long a, - unsigned long long b, - unsigned long long &lo, - unsigned long long &hi) { -#if defined(__CUDA_ARCH__) - lo = a * b; - hi = __umul64hi(a, b); -#elif defined(__SIZEOF_INT128__) - unsigned __int128 t = (unsigned __int128)a * b; - lo = (unsigned long long)t; - hi = (unsigned long long)(t >> 64); -#else - unsigned long long al = a & 0xffffffffULL, ah = a >> 32; - unsigned long long bl = b & 0xffffffffULL, bh = b >> 32; - unsigned long long ll = al * bl; - unsigned long long lh = al * bh; - unsigned long long hl = ah * bl; - unsigned long long hh = ah * bh; - unsigned long long mid = - (ll >> 32) + (lh & 0xffffffffULL) + (hl & 0xffffffffULL); - lo = (ll & 0xffffffffULL) | (mid << 32); - hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -#endif -} - -__device__ __forceinline__ unsigned long long adc(unsigned long long a, - unsigned long long b, - unsigned long long &carry) { - unsigned long long s = a + b; - unsigned long long c1 = (s < a) ? 1ULL : 0ULL; - unsigned long long s2 = s + carry; - unsigned long long c2 = (s2 < s) ? 
1ULL : 0ULL; - carry = c1 + c2; - return s2; -} - -__device__ __forceinline__ unsigned long long sbb(unsigned long long a, - unsigned long long b, - unsigned long long &borrow) { - unsigned long long d = a - b; - unsigned long long b1 = (a < b) ? 1ULL : 0ULL; - unsigned long long d2 = d - borrow; - unsigned long long b2 = (d < borrow) ? 1ULL : 0ULL; - borrow = b1 + b2; - return d2; -} - -struct Fp { unsigned long long l0, l1, l2, l3; }; -struct Pt { Fp X; Fp Y; Fp Z; }; - -__device__ __forceinline__ void fp_cond_sub_q(Fp &a) { - unsigned long long br = 0; - unsigned long long r0 = sbb(a.l0, FP_Q_0, br); - unsigned long long r1 = sbb(a.l1, FP_Q_1, br); - unsigned long long r2 = sbb(a.l2, FP_Q_2, br); - unsigned long long r3 = sbb(a.l3, FP_Q_3, br); - unsigned long long mask = br - 1ULL; - a.l0 = (a.l0 & ~mask) | (r0 & mask); - a.l1 = (a.l1 & ~mask) | (r1 & mask); - a.l2 = (a.l2 & ~mask) | (r2 & mask); - a.l3 = (a.l3 & ~mask) | (r3 & mask); -} - -__device__ __forceinline__ void fp_cond_add_q(Fp &a, unsigned long long mask) { - unsigned long long c = 0; - a.l0 = adc(a.l0, FP_Q_0 & mask, c); - a.l1 = adc(a.l1, FP_Q_1 & mask, c); - a.l2 = adc(a.l2, FP_Q_2 & mask, c); - a.l3 = adc(a.l3, FP_Q_3 & mask, c); -} - -__device__ __forceinline__ Fp fp_add(const Fp &a, const Fp &b) { - Fp r; - unsigned long long c = 0; - r.l0 = adc(a.l0, b.l0, c); - r.l1 = adc(a.l1, b.l1, c); - r.l2 = adc(a.l2, b.l2, c); - r.l3 = adc(a.l3, b.l3, c); - fp_cond_sub_q(r); - return r; -} - -__device__ __forceinline__ Fp fp_sub(const Fp &a, const Fp &b) { - Fp r; - unsigned long long br = 0; - r.l0 = sbb(a.l0, b.l0, br); - r.l1 = sbb(a.l1, b.l1, br); - r.l2 = sbb(a.l2, b.l2, br); - r.l3 = sbb(a.l3, b.l3, br); - fp_cond_add_q(r, 0ULL - br); - return r; -} - -__device__ __forceinline__ Fp fp_mul(const Fp &a, const Fp &b) { - const unsigned long long xl[4] = {a.l0, a.l1, a.l2, a.l3}; - const unsigned long long yl[4] = {b.l0, b.l1, b.l2, b.l3}; - const unsigned long long qq[4] = {FP_Q_0, FP_Q_1, FP_Q_2, 
FP_Q_3}; - - unsigned long long t[5] = {0, 0, 0, 0, 0}; - for (int i = 0; i < 4; ++i) { - const unsigned long long yi = yl[i]; - - unsigned long long cy = 0; - for (int j = 0; j < 4; ++j) { - unsigned long long lo, hi; - mul64(xl[j], yi, lo, hi); - unsigned long long c1 = 0; - unsigned long long s = adc(t[j], lo, c1); - unsigned long long c2 = 0; - unsigned long long s2 = adc(s, cy, c2); - t[j] = s2; - cy = hi + c1 + c2; - } - unsigned long long carry_out = 0; - t[4] = adc(t[4], cy, carry_out); - unsigned long long D = carry_out; - - unsigned long long m = t[0] * FP_QINV_NEG; - - cy = 0; - for (int j = 0; j < 4; ++j) { - unsigned long long lo, hi; - mul64(m, qq[j], lo, hi); - unsigned long long c1 = 0; - unsigned long long s = adc(t[j], lo, c1); - unsigned long long c2 = 0; - unsigned long long s2 = adc(s, cy, c2); - t[j] = s2; - cy = hi + c1 + c2; - } - carry_out = 0; - unsigned long long t3_new = adc(t[4], cy, carry_out); - unsigned long long t4_new = adc(0ULL, D, carry_out); - - t[0] = t[1]; t[1] = t[2]; t[2] = t[3]; - t[3] = t3_new; - t[4] = t4_new; - } - - Fp r; - r.l0 = t[0]; r.l1 = t[1]; r.l2 = t[2]; r.l3 = t[3]; - if (t[4] != 0) { - unsigned long long b = 0; - r.l0 = sbb(r.l0, FP_Q_0, b); - r.l1 = sbb(r.l1, FP_Q_1, b); - r.l2 = sbb(r.l2, FP_Q_2, b); - r.l3 = sbb(r.l3, FP_Q_3, b); - return r; - } - fp_cond_sub_q(r); - return r; -} - -__device__ __forceinline__ Fp fp_square(const Fp &a) { return fp_mul(a, a); } - -__device__ __forceinline__ Fp fp_zero() { - Fp r; r.l0=0; r.l1=0; r.l2=0; r.l3=0; return r; -} -__device__ __forceinline__ Fp fp_one() { - Fp r; r.l0=FP_R_0; r.l1=FP_R_1; r.l2=FP_R_2; r.l3=FP_R_3; return r; -} -__device__ __forceinline__ Fp curve_a_const() { - Fp r; r.l0=CURVE_A_0; r.l1=CURVE_A_1; r.l2=CURVE_A_2; r.l3=CURVE_A_3; - return r; -} -__device__ __forceinline__ Fp curve_d_const() { - Fp r; r.l0=CURVE_D_0; r.l1=CURVE_D_1; r.l2=CURVE_D_2; r.l3=CURVE_D_3; - return r; -} - -__device__ __forceinline__ Pt pt_identity() { - Pt p; p.X = fp_zero(); 
p.Y = fp_one(); p.Z = fp_one(); return p; -} - -__device__ __forceinline__ Pt pt_add(const Pt &p1, const Pt &p2) { - Fp d_const = curve_d_const(); - Fp a_const = curve_a_const(); - Fp A = fp_mul(p1.Z, p2.Z); - Fp B = fp_square(A); - Fp C = fp_mul(p1.X, p2.X); - Fp D = fp_mul(p1.Y, p2.Y); - Fp E = fp_mul(d_const, fp_mul(C, D)); - Fp F = fp_sub(B, E); - Fp G = fp_add(B, E); - Fp H = fp_add(p1.X, p1.Y); - Fp I = fp_add(p2.X, p2.Y); - - Pt r; - Fp t = fp_mul(H, I); - t = fp_sub(t, C); - t = fp_sub(t, D); - t = fp_mul(t, A); - r.X = fp_mul(t, F); - - Fp aC = fp_mul(a_const, C); - Fp t2 = fp_sub(D, aC); - t2 = fp_mul(t2, A); - r.Y = fp_mul(t2, G); - - r.Z = fp_mul(F, G); - return r; -} - -__device__ __forceinline__ Pt pt_double(const Pt &p) { - Fp a_const = curve_a_const(); - Fp XY = fp_add(p.X, p.Y); - Fp B = fp_square(XY); - Fp C = fp_square(p.X); - Fp D = fp_square(p.Y); - Fp E = fp_mul(a_const, C); - Fp F = fp_add(E, D); - Fp H = fp_square(p.Z); - Fp twoH = fp_add(H, H); - Fp J = fp_sub(F, twoH); - - Pt r; - Fp t = fp_sub(B, C); - t = fp_sub(t, D); - r.X = fp_mul(t, J); - r.Y = fp_mul(F, fp_sub(E, D)); - r.Z = fp_mul(F, J); - return r; -} - -__device__ __forceinline__ void pt_cmov(Pt &dst, const Pt &src, - unsigned long long mask) { - dst.X.l0 = (dst.X.l0 & ~mask) | (src.X.l0 & mask); - dst.X.l1 = (dst.X.l1 & ~mask) | (src.X.l1 & mask); - dst.X.l2 = (dst.X.l2 & ~mask) | (src.X.l2 & mask); - dst.X.l3 = (dst.X.l3 & ~mask) | (src.X.l3 & mask); - dst.Y.l0 = (dst.Y.l0 & ~mask) | (src.Y.l0 & mask); - dst.Y.l1 = (dst.Y.l1 & ~mask) | (src.Y.l1 & mask); - dst.Y.l2 = (dst.Y.l2 & ~mask) | (src.Y.l2 & mask); - dst.Y.l3 = (dst.Y.l3 & ~mask) | (src.Y.l3 & mask); - dst.Z.l0 = (dst.Z.l0 & ~mask) | (src.Z.l0 & mask); - dst.Z.l1 = (dst.Z.l1 & ~mask) | (src.Z.l1 & mask); - dst.Z.l2 = (dst.Z.l2 & ~mask) | (src.Z.l2 & mask); - dst.Z.l3 = (dst.Z.l3 & ~mask) | (src.Z.l3 & mask); -} - -__device__ __forceinline__ Pt pt_scalar_mul(const Pt &p, - const std::uint8_t *s_le) { - Pt acc = 
pt_identity(); - Pt base = p; - for (int byte_idx = 0; byte_idx < 32; ++byte_idx) { - std::uint8_t b = s_le[byte_idx]; - for (int bit = 0; bit < 8; ++bit) { - unsigned long long one_or_zero = - (unsigned long long)((b >> bit) & 1u); - unsigned long long mask = 0ULL - one_or_zero; - Pt sum = pt_add(acc, base); - pt_cmov(acc, sum, mask); - base = pt_double(base); - } - } - return acc; -} - -__device__ __forceinline__ Fp read_fp_limbs(const std::uint8_t *p) { - Fp r; - auto rd = [&](int o) -> unsigned long long { - return ((unsigned long long)p[o]) | ((unsigned long long)p[o+1] << 8) - | ((unsigned long long)p[o+2] << 16) | ((unsigned long long)p[o+3] << 24) - | ((unsigned long long)p[o+4] << 32) | ((unsigned long long)p[o+5] << 40) - | ((unsigned long long)p[o+6] << 48) | ((unsigned long long)p[o+7] << 56); - }; - r.l0 = rd(0); r.l1 = rd(8); r.l2 = rd(16); r.l3 = rd(24); - return r; -} - -__device__ __forceinline__ void write_fp_limbs(const Fp &x, std::uint8_t *p) { - auto wr = [&](unsigned long long v, int o) { - for (int i = 0; i < 8; ++i) { p[o+i] = (std::uint8_t)(v & 0xff); v >>= 8; } - }; - wr(x.l0, 0); wr(x.l1, 8); wr(x.l2, 16); wr(x.l3, 24); -} - -__device__ __forceinline__ Pt read_pt(const std::uint8_t *p) { - Pt r; - r.X = read_fp_limbs(p); - r.Y = read_fp_limbs(p + 32); - r.Z = read_fp_limbs(p + 64); - return r; -} - -__device__ __forceinline__ void write_pt(const Pt &p, std::uint8_t *out) { - write_fp_limbs(p.X, out); - write_fp_limbs(p.Y, out + 32); - write_fp_limbs(p.Z, out + 64); -} - -__device__ static void bw_add_one(const std::uint8_t *pair_in, - std::uint8_t *out) { - Pt P = read_pt(pair_in); - Pt Q = read_pt(pair_in + 96); - Pt R = pt_add(P, Q); - write_pt(R, out); -} - -__device__ static void bw_double_one(const std::uint8_t *pt_in, - std::uint8_t *out) { - Pt P = read_pt(pt_in); - Pt R = pt_double(P); - write_pt(R, out); -} - -__device__ static void bw_smul_one(const std::uint8_t *pt_in, - const std::uint8_t *scalar_in, - std::uint8_t *out) { - 
Pt P = read_pt(pt_in); - Pt R = pt_scalar_mul(P, scalar_in); - write_pt(R, out); -} - -__device__ static void bw_msm_one(const std::uint8_t *pts, - const std::uint8_t *scalars_b, - std::uint8_t *out, - unsigned int n) { - Pt acc = pt_identity(); - for (unsigned int i = 0; i < n; ++i) { - Pt P = read_pt(pts + (size_t)i * 96); - Pt term = pt_scalar_mul(P, scalars_b + (size_t)i * 32); - acc = pt_add(acc, term); - } - write_pt(acc, out); -} - -#if !LUX_BANDERWAGON_CUDA_HOST_POLYFILL -extern "C" __global__ void banderwagon_add_kernel( - const std::uint8_t *pairs, std::uint8_t *outs, unsigned int n) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - bw_add_one(pairs + (size_t)i * 192, outs + (size_t)i * 96); -} -extern "C" __global__ void banderwagon_double_kernel( - const std::uint8_t *pts, std::uint8_t *outs, unsigned int n) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - bw_double_one(pts + (size_t)i * 96, outs + (size_t)i * 96); -} -extern "C" __global__ void banderwagon_smul_kernel( - const std::uint8_t *pts, const std::uint8_t *scalars, - std::uint8_t *outs, unsigned int n) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - bw_smul_one(pts + (size_t)i * 96, scalars + (size_t)i * 32, - outs + (size_t)i * 96); -} -extern "C" __global__ void banderwagon_msm_kernel( - const std::uint8_t *pts, const std::uint8_t *scalars, - std::uint8_t *outs, unsigned int n, unsigned int M) { - unsigned int b = blockIdx.x * blockDim.x + threadIdx.x; - if (b >= M) return; - bw_msm_one(pts, scalars + (size_t)b * n * 32, outs + (size_t)b * 96, n); -} -#endif - -extern "C" int banderwagon_cuda_add_batch(const std::uint8_t *pairs, - std::uint8_t *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pairs || !outs) return -1; - -#if LUX_BANDERWAGON_CUDA_HOST_POLYFILL - for (unsigned long i = 0; i < n; ++i) { - bw_add_one(pairs + (size_t)i * 192, outs + (size_t)i * 96); - } - return 0; 
-#else - std::uint8_t *d_in = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_in, (size_t)n * 192); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_out, (size_t)n * 96); - if (st != cudaSuccess) { cudaFree(d_in); return -2; } - st = cudaMemcpy(d_in, pairs, (size_t)n * 192, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -3; } - unsigned int tpb = 64; - unsigned int blocks = (unsigned int)((n + tpb - 1) / tpb); - banderwagon_add_kernel<<>>(d_in, d_out, (unsigned int)n); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -4; } - st = cudaMemcpy(outs, d_out, (size_t)n * 96, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); - return (st != cudaSuccess) ? -5 : 0; -#endif -} - -extern "C" int banderwagon_cuda_double_batch(const std::uint8_t *pts, - std::uint8_t *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pts || !outs) return -1; - -#if LUX_BANDERWAGON_CUDA_HOST_POLYFILL - for (unsigned long i = 0; i < n; ++i) { - bw_double_one(pts + (size_t)i * 96, outs + (size_t)i * 96); - } - return 0; -#else - std::uint8_t *d_in = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_in, (size_t)n * 96); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_out, (size_t)n * 96); - if (st != cudaSuccess) { cudaFree(d_in); return -2; } - st = cudaMemcpy(d_in, pts, (size_t)n * 96, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -3; } - unsigned int tpb = 64; - unsigned int blocks = (unsigned int)((n + tpb - 1) / tpb); - banderwagon_double_kernel<<>>(d_in, d_out, (unsigned int)n); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -4; } - st = cudaMemcpy(outs, d_out, (size_t)n * 96, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); - 
return (st != cudaSuccess) ? -5 : 0; -#endif -} - -extern "C" int banderwagon_cuda_smul_batch(const std::uint8_t *pts, - const std::uint8_t *scalars, - std::uint8_t *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pts || !scalars || !outs) return -1; - -#if LUX_BANDERWAGON_CUDA_HOST_POLYFILL - for (unsigned long i = 0; i < n; ++i) { - bw_smul_one(pts + (size_t)i * 96, - scalars + (size_t)i * 32, - outs + (size_t)i * 96); - } - return 0; -#else - std::uint8_t *d_pts = nullptr, *d_scl = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_pts, (size_t)n * 96); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_scl, (size_t)n * 32); - if (st != cudaSuccess) { cudaFree(d_pts); return -2; } - st = cudaMalloc(&d_out, (size_t)n * 96); - if (st != cudaSuccess) { cudaFree(d_pts); cudaFree(d_scl); return -2; } - st = cudaMemcpy(d_pts, pts, (size_t)n * 96, cudaMemcpyHostToDevice); - if (st == cudaSuccess) - st = cudaMemcpy(d_scl, scalars, (size_t)n * 32, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { - cudaFree(d_pts); cudaFree(d_scl); cudaFree(d_out); return -3; - } - unsigned int tpb = 32; - unsigned int blocks = (unsigned int)((n + tpb - 1) / tpb); - banderwagon_smul_kernel<<>>(d_pts, d_scl, d_out, - (unsigned int)n); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { - cudaFree(d_pts); cudaFree(d_scl); cudaFree(d_out); return -4; - } - st = cudaMemcpy(outs, d_out, (size_t)n * 96, cudaMemcpyDeviceToHost); - cudaFree(d_pts); cudaFree(d_scl); cudaFree(d_out); - return (st != cudaSuccess) ? 
-5 : 0; -#endif -} - -extern "C" int banderwagon_cuda_msm_batch(const std::uint8_t *pts, - const std::uint8_t *scalars, - std::uint8_t *outs, - unsigned long n, - unsigned long M) { - if (n == 0 || M == 0) return 0; - if (!pts || !scalars || !outs) return -1; - -#if LUX_BANDERWAGON_CUDA_HOST_POLYFILL - for (unsigned long b = 0; b < M; ++b) { - bw_msm_one(pts, - scalars + (size_t)b * n * 32, - outs + (size_t)b * 96, - (unsigned int)n); - } - return 0; -#else - std::uint8_t *d_pts = nullptr, *d_scl = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_pts, (size_t)n * 96); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_scl, (size_t)M * n * 32); - if (st != cudaSuccess) { cudaFree(d_pts); return -2; } - st = cudaMalloc(&d_out, (size_t)M * 96); - if (st != cudaSuccess) { cudaFree(d_pts); cudaFree(d_scl); return -2; } - st = cudaMemcpy(d_pts, pts, (size_t)n * 96, cudaMemcpyHostToDevice); - if (st == cudaSuccess) - st = cudaMemcpy(d_scl, scalars, (size_t)M * n * 32, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { - cudaFree(d_pts); cudaFree(d_scl); cudaFree(d_out); return -3; - } - unsigned int tpb = 32; - unsigned int blocks = (unsigned int)((M + tpb - 1) / tpb); - banderwagon_msm_kernel<<>>(d_pts, d_scl, d_out, - (unsigned int)n, (unsigned int)M); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { - cudaFree(d_pts); cudaFree(d_scl); cudaFree(d_out); return -4; - } - st = cudaMemcpy(outs, d_out, (size_t)M * 96, cudaMemcpyDeviceToHost); - cudaFree(d_pts); cudaFree(d_scl); cudaFree(d_out); - return (st != cudaSuccess) ? -5 : 0; -#endif -} diff --git a/banderwagon/gpu/cuda/banderwagon_driver.cpp b/banderwagon/gpu/cuda/banderwagon_driver.cpp deleted file mode 100644 index ede58e8..0000000 --- a/banderwagon/gpu/cuda/banderwagon_driver.cpp +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for Banderwagon group ops. -// -// Build modes: -// 1. With CUDA toolkit (LUX_BANDERWAGON_HAVE_CUDA defined): compiles -// banderwagon.cu via nvcc, dispatches kernels with one thread per work -// item (or per MSM). Byte-equal to lux::banderwagon::Element ops. -// 2. Without CUDA: stub mode. banderwagon_cuda_available() returns 0, -// every dispatch returns -1. - -#include "banderwagon_driver.h" - -#include - -#ifdef LUX_BANDERWAGON_HAVE_CUDA -#include - -extern "C" __global__ void banderwagon_add_batch( - const uint8_t* pairs, uint8_t* outs, uint32_t n); -extern "C" __global__ void banderwagon_double_batch( - const uint8_t* pts, uint8_t* outs, uint32_t n); -extern "C" __global__ void banderwagon_smul_batch( - const uint8_t* pts, const uint8_t* scalars, uint8_t* outs, uint32_t n); -extern "C" __global__ void banderwagon_msm_batch_naive( - const uint8_t* pts, const uint8_t* scalars, uint8_t* outs, - uint32_t n, uint32_t M); - -extern "C" int banderwagon_cuda_available(void) { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 
1 : 0; -} - -namespace { -struct DevPtr { - uint8_t* p = nullptr; - ~DevPtr() { if (p) cudaFree(p); } -}; -} // namespace - -extern "C" int banderwagon_cuda_add_batch( - const uint8_t* pairs, uint8_t* outs, size_t n) { - if (n == 0) return 0; - if (!pairs || !outs) return -1; - if (!banderwagon_cuda_available()) return -2; - - DevPtr d_in, d_out; - if (cudaMalloc((void**)&d_in.p, n * 192) != cudaSuccess) return -3; - if (cudaMalloc((void**)&d_out.p, n * 96) != cudaSuccess) return -3; - if (cudaMemcpy(d_in.p, pairs, n * 192, cudaMemcpyHostToDevice) != cudaSuccess) return -4; - unsigned tg = 64; - unsigned grid = unsigned((n + tg - 1) / tg); - banderwagon_add_batch<<>>(d_in.p, d_out.p, (uint32_t)n); - if (cudaDeviceSynchronize() != cudaSuccess) return -4; - if (cudaMemcpy(outs, d_out.p, n * 96, cudaMemcpyDeviceToHost) != cudaSuccess) return -4; - return 0; -} - -extern "C" int banderwagon_cuda_double_batch( - const uint8_t* pts, uint8_t* outs, size_t n) { - if (n == 0) return 0; - if (!pts || !outs) return -1; - if (!banderwagon_cuda_available()) return -2; - - DevPtr d_in, d_out; - if (cudaMalloc((void**)&d_in.p, n * 96) != cudaSuccess) return -3; - if (cudaMalloc((void**)&d_out.p, n * 96) != cudaSuccess) return -3; - if (cudaMemcpy(d_in.p, pts, n * 96, cudaMemcpyHostToDevice) != cudaSuccess) return -4; - unsigned tg = 64; - unsigned grid = unsigned((n + tg - 1) / tg); - banderwagon_double_batch<<>>(d_in.p, d_out.p, (uint32_t)n); - if (cudaDeviceSynchronize() != cudaSuccess) return -4; - if (cudaMemcpy(outs, d_out.p, n * 96, cudaMemcpyDeviceToHost) != cudaSuccess) return -4; - return 0; -} - -extern "C" int banderwagon_cuda_smul_batch( - const uint8_t* pts, const uint8_t* scalars, uint8_t* outs, size_t n) { - if (n == 0) return 0; - if (!pts || !scalars || !outs) return -1; - if (!banderwagon_cuda_available()) return -2; - - DevPtr d_pts, d_scl, d_out; - if (cudaMalloc((void**)&d_pts.p, n * 96) != cudaSuccess) return -3; - if (cudaMalloc((void**)&d_scl.p, n * 32) != 
cudaSuccess) return -3; - if (cudaMalloc((void**)&d_out.p, n * 96) != cudaSuccess) return -3; - if (cudaMemcpy(d_pts.p, pts, n * 96, cudaMemcpyHostToDevice) != cudaSuccess) return -4; - if (cudaMemcpy(d_scl.p, scalars, n * 32, cudaMemcpyHostToDevice) != cudaSuccess) return -4; - unsigned tg = 32; - unsigned grid = unsigned((n + tg - 1) / tg); - banderwagon_smul_batch<<>>(d_pts.p, d_scl.p, d_out.p, (uint32_t)n); - if (cudaDeviceSynchronize() != cudaSuccess) return -4; - if (cudaMemcpy(outs, d_out.p, n * 96, cudaMemcpyDeviceToHost) != cudaSuccess) return -4; - return 0; -} - -extern "C" int banderwagon_cuda_msm_batch( - const uint8_t* pts, const uint8_t* scalars, uint8_t* outs, - size_t n, size_t M) { - if (n == 0 || M == 0) return 0; - if (!pts || !scalars || !outs) return -1; - if (!banderwagon_cuda_available()) return -2; - - DevPtr d_pts, d_scl, d_out; - if (cudaMalloc((void**)&d_pts.p, n * 96) != cudaSuccess) return -3; - if (cudaMalloc((void**)&d_scl.p, M * n * 32) != cudaSuccess) return -3; - if (cudaMalloc((void**)&d_out.p, M * 96) != cudaSuccess) return -3; - if (cudaMemcpy(d_pts.p, pts, n * 96, cudaMemcpyHostToDevice) != cudaSuccess) return -4; - if (cudaMemcpy(d_scl.p, scalars, M * n * 32, cudaMemcpyHostToDevice) != cudaSuccess) return -4; - unsigned tg = 32; - unsigned grid = unsigned((M + tg - 1) / tg); - banderwagon_msm_batch_naive<<>>(d_pts.p, d_scl.p, d_out.p, - (uint32_t)n, (uint32_t)M); - if (cudaDeviceSynchronize() != cudaSuccess) return -4; - if (cudaMemcpy(outs, d_out.p, M * 96, cudaMemcpyDeviceToHost) != cudaSuccess) return -4; - return 0; -} - -#else // LUX_BANDERWAGON_HAVE_CUDA not defined: stub mode - -extern "C" int banderwagon_cuda_available(void) { return 0; } -extern "C" int banderwagon_cuda_add_batch(const uint8_t*, uint8_t*, size_t) { return -1; } -extern "C" int banderwagon_cuda_double_batch(const uint8_t*, uint8_t*, size_t) { return -1; } -extern "C" int banderwagon_cuda_smul_batch(const uint8_t*, const uint8_t*, uint8_t*, size_t) { 
return -1; } -extern "C" int banderwagon_cuda_msm_batch(const uint8_t*, const uint8_t*, uint8_t*, size_t, size_t) { return -1; } - -#endif diff --git a/banderwagon/gpu/cuda/banderwagon_driver.h b/banderwagon/gpu/cuda/banderwagon_driver.h deleted file mode 100644 index 120bb9e..0000000 --- a/banderwagon/gpu/cuda/banderwagon_driver.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA driver for Banderwagon group ops. Linux/CUDA on real GPUs; host-side -// polyfill when nvcc is unavailable so the CPU oracle test path stays -// exercised. Same encoding as the Metal driver. - -#ifndef LUX_BANDERWAGON_CUDA_DRIVER_H -#define LUX_BANDERWAGON_CUDA_DRIVER_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -int banderwagon_cuda_add_batch(const uint8_t *pairs, - uint8_t *outs, - unsigned long n); - -int banderwagon_cuda_double_batch(const uint8_t *pts, - uint8_t *outs, - unsigned long n); - -int banderwagon_cuda_smul_batch(const uint8_t *pts, - const uint8_t *scalars, - uint8_t *outs, - unsigned long n); - -int banderwagon_cuda_msm_batch(const uint8_t *pts, - const uint8_t *scalars, - uint8_t *outs, - unsigned long n, - unsigned long M); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/banderwagon/gpu/metal/banderwagon.metal b/banderwagon/gpu/metal/banderwagon.metal deleted file mode 100644 index b28b2ed..0000000 --- a/banderwagon/gpu/metal/banderwagon.metal +++ /dev/null @@ -1,487 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party Metal kernel for Banderwagon group operations. -// -// Byte-equal to lux::banderwagon::Element in banderwagon/cpp/element.cpp -// (twisted Edwards a*x^2 + y^2 = 1 + d*x^2*y^2 over the BLS12-381 scalar -// field; a = -5; d = canonical gnark constant). 
Constants come from the -// CPU body via banderwagon_gen_metal_constants -> banderwagon_const.metalh, -// so the GPU and CPU share exactly one source of truth -- drift is impossible. -// -// Kernels: -// * banderwagon_add_batch : (P_i, Q_i) -> P_i + Q_i in projective form -// * banderwagon_double_batch : P_i -> 2*P_i in projective form -// * banderwagon_smul_batch : (P_i, s_i) -> [s_i] P_i in projective form -// * banderwagon_msm_window : sum_{i in bucket} P_i (Pippenger inner) -// -// Encoding: each Element is 96 bytes = 32-byte BE Fp X || 32-byte BE Fp Y || -// 32-byte BE Fp Z (Montgomery reduced). Each Fr scalar is 32-byte LE in -// canonical form. -// -// All Fp arithmetic operates on Montgomery limbs (4 x ulong, little-endian). -// All operations are constant-time wrt secret bits (no data-dependent branches -// inside the inner loops). - -#include -using namespace metal; - -#include "banderwagon_const.metalh" - -// ============================================================================= -// 256-bit Montgomery field element layout. 4 x 64-bit limbs, LE. -// ============================================================================= - -struct Fp { - ulong l0, l1, l2, l3; -}; - -struct Pt { - Fp X; - Fp Y; - Fp Z; -}; - -// ============================================================================= -// 64x64 -> 128 multiply via Metal's native mulhi(u64,u64). On Apple silicon -// this maps to the hardware UMULH instruction. -// ============================================================================= - -inline void mul64(ulong a, ulong b, thread ulong &lo, thread ulong &hi) { - lo = a * b; - hi = mulhi(a, b); -} - -inline ulong adc(ulong a, ulong b, thread ulong &carry) { - ulong s = a + b; - ulong c1 = (s < a) ? 1UL : 0UL; - ulong s2 = s + carry; - ulong c2 = (s2 < s) ? 1UL : 0UL; - carry = c1 + c2; - return s2; -} - -inline ulong sbb(ulong a, ulong b, thread ulong &borrow) { - ulong d = a - b; - ulong b1 = (a < b) ? 
1UL : 0UL; - ulong d2 = d - borrow; - ulong b2 = (d < borrow) ? 1UL : 0UL; - borrow = b1 + b2; - return d2; -} - -// less_than_q: returns 1 if a < q, 0 otherwise. -inline int fp_less_than_q(thread const Fp &a) { - if (a.l3 != FP_Q_3) return (a.l3 < FP_Q_3) ? 1 : 0; - if (a.l2 != FP_Q_2) return (a.l2 < FP_Q_2) ? 1 : 0; - if (a.l1 != FP_Q_1) return (a.l1 < FP_Q_1) ? 1 : 0; - return (a.l0 < FP_Q_0) ? 1 : 0; -} - -inline void fp_cond_sub_q(thread Fp &a) { - ulong br = 0; - ulong r0 = sbb(a.l0, FP_Q_0, br); - ulong r1 = sbb(a.l1, FP_Q_1, br); - ulong r2 = sbb(a.l2, FP_Q_2, br); - ulong r3 = sbb(a.l3, FP_Q_3, br); - // br=1 => a < q (keep), br=0 => a >= q (take r). - ulong mask = br - 1UL; // br=1 -> 0, br=0 -> all-ones - a.l0 = (a.l0 & ~mask) | (r0 & mask); - a.l1 = (a.l1 & ~mask) | (r1 & mask); - a.l2 = (a.l2 & ~mask) | (r2 & mask); - a.l3 = (a.l3 & ~mask) | (r3 & mask); -} - -inline void fp_cond_add_q(thread Fp &a, ulong mask) { - ulong c = 0; - ulong add0 = FP_Q_0 & mask; - ulong add1 = FP_Q_1 & mask; - ulong add2 = FP_Q_2 & mask; - ulong add3 = FP_Q_3 & mask; - a.l0 = adc(a.l0, add0, c); - a.l1 = adc(a.l1, add1, c); - a.l2 = adc(a.l2, add2, c); - a.l3 = adc(a.l3, add3, c); -} - -// ============================================================================= -// Fp arithmetic. Mirrors fp.cpp byte-for-byte. CIOS Montgomery multiplication. 
-// ============================================================================= - -inline Fp fp_add(thread const Fp &a, thread const Fp &b) { - Fp r; - ulong c = 0; - r.l0 = adc(a.l0, b.l0, c); - r.l1 = adc(a.l1, b.l1, c); - r.l2 = adc(a.l2, b.l2, c); - r.l3 = adc(a.l3, b.l3, c); - fp_cond_sub_q(r); - return r; -} - -inline Fp fp_sub(thread const Fp &a, thread const Fp &b) { - Fp r; - ulong br = 0; - r.l0 = sbb(a.l0, b.l0, br); - r.l1 = sbb(a.l1, b.l1, br); - r.l2 = sbb(a.l2, b.l2, br); - r.l3 = sbb(a.l3, b.l3, br); - ulong mask = 0UL - br; // br=1 -> all-ones, br=0 -> 0 - fp_cond_add_q(r, mask); - return r; -} - -inline Fp fp_neg(thread const Fp &a) { - ulong nz = (a.l0 | a.l1 | a.l2 | a.l3); - ulong m = 0UL - ((nz | (0UL - nz)) >> 63); // all-ones if nz != 0, else 0 - Fp r; - ulong b = 0; - r.l0 = sbb(FP_Q_0, a.l0, b); - r.l1 = sbb(FP_Q_1, a.l1, b); - r.l2 = sbb(FP_Q_2, a.l2, b); - r.l3 = sbb(FP_Q_3, a.l3, b); - r.l0 &= m; r.l1 &= m; r.l2 &= m; r.l3 &= m; - return r; -} - -// CIOS Montgomery multiplication. Algorithm 2 of El Housni / Botrel. -// Mirrors banderwagon/cpp/fp.cpp::cios_mul byte-for-byte. 
-inline Fp fp_mul(thread const Fp &a, thread const Fp &b) { - const ulong xl[4] = {a.l0, a.l1, a.l2, a.l3}; - const ulong yl[4] = {b.l0, b.l1, b.l2, b.l3}; - const ulong qq[4] = {FP_Q_0, FP_Q_1, FP_Q_2, FP_Q_3}; - - ulong t[5] = {0, 0, 0, 0, 0}; - for (int i = 0; i < 4; ++i) { - const ulong yi = yl[i]; - - // t += a * b[i] (5 limbs) - ulong cy = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(xl[j], yi, lo, hi); - ulong c1 = 0; - ulong s = adc(t[j], lo, c1); - ulong c2 = 0; - ulong s2 = adc(s, cy, c2); - t[j] = s2; - cy = hi + c1 + c2; - } - ulong carry_out = 0; - t[4] = adc(t[4], cy, carry_out); - ulong D = carry_out; - - // m = t[0] * qInvNeg mod 2^64 - ulong m = t[0] * FP_QINV_NEG; - - // t = (t + m*q) >> 64 - cy = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(m, qq[j], lo, hi); - ulong c1 = 0; - ulong s = adc(t[j], lo, c1); - ulong c2 = 0; - ulong s2 = adc(s, cy, c2); - t[j] = s2; - cy = hi + c1 + c2; - } - carry_out = 0; - ulong t3_new = adc(t[4], cy, carry_out); - ulong t4_new = adc(0UL, D, carry_out); - - // shift down by one limb - t[0] = t[1]; t[1] = t[2]; t[2] = t[3]; - t[3] = t3_new; - t[4] = t4_new; - } - - Fp r; - r.l0 = t[0]; r.l1 = t[1]; r.l2 = t[2]; r.l3 = t[3]; - - // If t[4] != 0 we definitely overflow once: r -= q. - if (t[4] != 0) { - ulong b = 0; - r.l0 = sbb(r.l0, FP_Q_0, b); - r.l1 = sbb(r.l1, FP_Q_1, b); - r.l2 = sbb(r.l2, FP_Q_2, b); - r.l3 = sbb(r.l3, FP_Q_3, b); - return r; - } - fp_cond_sub_q(r); - return r; -} - -inline Fp fp_square(thread const Fp &a) { return fp_mul(a, a); } - -// ============================================================================= -// Constants (Montgomery form), pulled from the auto-generated header. 
-// ============================================================================= - -inline Fp fp_zero() { Fp r; r.l0 = 0; r.l1 = 0; r.l2 = 0; r.l3 = 0; return r; } -inline Fp fp_one() { Fp r; r.l0 = FP_R_0; r.l1 = FP_R_1; r.l2 = FP_R_2; r.l3 = FP_R_3; return r; } -inline Fp curve_a() { Fp r; r.l0 = CURVE_A_0; r.l1 = CURVE_A_1; r.l2 = CURVE_A_2; r.l3 = CURVE_A_3; return r; } -inline Fp curve_d() { Fp r; r.l0 = CURVE_D_0; r.l1 = CURVE_D_1; r.l2 = CURVE_D_2; r.l3 = CURVE_D_3; return r; } - -// ============================================================================= -// Banderwagon (= projective twisted Edwards) group ops. -// Unified addition (add-2008-bbjlp). Mirrors element.cpp::Element::add. -// A = Z1*Z2; B = A^2; C = X1*X2; D = Y1*Y2; E = d*C*D; F = B - E; G = B + E -// H = X1+Y1; I = X2+Y2 -// X3 = ((H*I) - C - D) * A * F -// Y3 = (D - a*C) * A * G -// Z3 = F * G -// ============================================================================= - -inline Pt pt_identity() { - Pt p; - p.X = fp_zero(); - p.Y = fp_one(); - p.Z = fp_one(); - return p; -} - -inline Pt pt_add(thread const Pt &p1, thread const Pt &p2) { - Fp d_const = curve_d(); - Fp a_const = curve_a(); - - Fp A = fp_mul(p1.Z, p2.Z); - Fp B = fp_square(A); - Fp C = fp_mul(p1.X, p2.X); - Fp D = fp_mul(p1.Y, p2.Y); - Fp E = fp_mul(d_const, fp_mul(C, D)); - Fp F = fp_sub(B, E); - Fp G = fp_add(B, E); - Fp H = fp_add(p1.X, p1.Y); - Fp I = fp_add(p2.X, p2.Y); - - Pt r; - Fp t = fp_mul(H, I); - t = fp_sub(t, C); - t = fp_sub(t, D); - t = fp_mul(t, A); - r.X = fp_mul(t, F); - - Fp aC = fp_mul(a_const, C); - Fp t2 = fp_sub(D, aC); - t2 = fp_mul(t2, A); - r.Y = fp_mul(t2, G); - - r.Z = fp_mul(F, G); - return r; -} - -// Dedicated doubling (dbl-2008-bbjlp). Mirrors element.cpp::Element::double_self. 
-// B = (X+Y)^2; C = X^2; D = Y^2; E = a*C; F = E + D; H = Z^2; J = F - 2H -// X3 = (B - C - D) * J; Y3 = F * (E - D); Z3 = F * J -inline Pt pt_double(thread const Pt &p) { - Fp a_const = curve_a(); - - Fp XY = fp_add(p.X, p.Y); - Fp B = fp_square(XY); - Fp C = fp_square(p.X); - Fp D = fp_square(p.Y); - Fp E = fp_mul(a_const, C); - Fp F = fp_add(E, D); - Fp H = fp_square(p.Z); - Fp twoH = fp_add(H, H); - Fp J = fp_sub(F, twoH); - - Pt r; - Fp t = fp_sub(B, C); - t = fp_sub(t, D); - r.X = fp_mul(t, J); - r.Y = fp_mul(F, fp_sub(E, D)); - r.Z = fp_mul(F, J); - return r; -} - -// Constant-time conditional move: dst = mask ? src : dst (mask is 0 or all-1). -inline void pt_cmov(thread Pt &dst, thread const Pt &src, ulong mask) { - dst.X.l0 = (dst.X.l0 & ~mask) | (src.X.l0 & mask); - dst.X.l1 = (dst.X.l1 & ~mask) | (src.X.l1 & mask); - dst.X.l2 = (dst.X.l2 & ~mask) | (src.X.l2 & mask); - dst.X.l3 = (dst.X.l3 & ~mask) | (src.X.l3 & mask); - dst.Y.l0 = (dst.Y.l0 & ~mask) | (src.Y.l0 & mask); - dst.Y.l1 = (dst.Y.l1 & ~mask) | (src.Y.l1 & mask); - dst.Y.l2 = (dst.Y.l2 & ~mask) | (src.Y.l2 & mask); - dst.Y.l3 = (dst.Y.l3 & ~mask) | (src.Y.l3 & mask); - dst.Z.l0 = (dst.Z.l0 & ~mask) | (src.Z.l0 & mask); - dst.Z.l1 = (dst.Z.l1 & ~mask) | (src.Z.l1 & mask); - dst.Z.l2 = (dst.Z.l2 & ~mask) | (src.Z.l2 & mask); - dst.Z.l3 = (dst.Z.l3 & ~mask) | (src.Z.l3 & mask); -} - -// Constant-time scalar multiplication. Mirrors Element::scalar_mul. -// Iterates LSB->MSB across canonical 32-byte LE scalar bytes. -// At each bit: compute acc + base unconditionally; cmov-select via mask. -// Always double the base. 
-inline Pt pt_scalar_mul(thread const Pt &p, device const uchar *s_le) { - Pt acc = pt_identity(); - Pt base = p; - for (int byte_idx = 0; byte_idx < 32; ++byte_idx) { - uchar b = s_le[byte_idx]; - for (int bit = 0; bit < 8; ++bit) { - ulong one_or_zero = (ulong)((b >> bit) & 1u); - ulong mask = 0UL - one_or_zero; - Pt sum = pt_add(acc, base); - pt_cmov(acc, sum, mask); - base = pt_double(base); - } - } - return acc; -} - -// Same scalar-mul but reads scalar from a `thread` 32-byte buffer. -inline Pt pt_scalar_mul_thread(thread const Pt &p, thread const uchar s_le[32]) { - Pt acc = pt_identity(); - Pt base = p; - for (int byte_idx = 0; byte_idx < 32; ++byte_idx) { - uchar b = s_le[byte_idx]; - for (int bit = 0; bit < 8; ++bit) { - ulong one_or_zero = (ulong)((b >> bit) & 1u); - ulong mask = 0UL - one_or_zero; - Pt sum = pt_add(acc, base); - pt_cmov(acc, sum, mask); - base = pt_double(base); - } - } - return acc; -} - -// ============================================================================= -// Bytes <-> Fp (Montgomery) and bytes <-> Pt encoders. -// -// Pt is serialized as 96 bytes = 32B BE Fp X || 32B BE Fp Y || 32B BE Fp Z, -// each 32-byte slice already a Montgomery-form Fp's canonical-image bytes. -// (We use the *direct* Mont-limb encoding for kernel I/O; the determinism -// test marshals CPU-side Element.X/Y/Z limbs verbatim.) -// -// For raw limbs we read 4 x 8 LE bytes per Fp from a 32-byte chunk treated -// as the Montgomery limb array (NOT the canonical-integer encoding from -// fp::to_bytes_be -- that path requires from_mont conversion). 
-// ============================================================================= - -inline Fp read_fp_limbs(device const uchar *p) { - Fp r; - r.l0 = ((ulong)p[0]) | ((ulong)p[1] << 8) | ((ulong)p[2] << 16) | ((ulong)p[3] << 24) - | ((ulong)p[4] << 32) | ((ulong)p[5] << 40) | ((ulong)p[6] << 48) | ((ulong)p[7] << 56); - r.l1 = ((ulong)p[8]) | ((ulong)p[9] << 8) | ((ulong)p[10] << 16) | ((ulong)p[11] << 24) - | ((ulong)p[12] << 32) | ((ulong)p[13] << 40) | ((ulong)p[14] << 48) | ((ulong)p[15] << 56); - r.l2 = ((ulong)p[16]) | ((ulong)p[17] << 8) | ((ulong)p[18] << 16) | ((ulong)p[19] << 24) - | ((ulong)p[20] << 32) | ((ulong)p[21] << 40) | ((ulong)p[22] << 48) | ((ulong)p[23] << 56); - r.l3 = ((ulong)p[24]) | ((ulong)p[25] << 8) | ((ulong)p[26] << 16) | ((ulong)p[27] << 24) - | ((ulong)p[28] << 32) | ((ulong)p[29] << 40) | ((ulong)p[30] << 48) | ((ulong)p[31] << 56); - return r; -} - -inline void write_fp_limbs(thread const Fp &x, device uchar *p) { - ulong v = x.l0; - for (int i = 0; i < 8; ++i) { p[i] = (uchar)(v & 0xff); v >>= 8; } - v = x.l1; - for (int i = 0; i < 8; ++i) { p[8 + i] = (uchar)(v & 0xff); v >>= 8; } - v = x.l2; - for (int i = 0; i < 8; ++i) { p[16 + i] = (uchar)(v & 0xff); v >>= 8; } - v = x.l3; - for (int i = 0; i < 8; ++i) { p[24 + i] = (uchar)(v & 0xff); v >>= 8; } -} - -inline Pt read_pt(device const uchar *p) { - Pt r; - r.X = read_fp_limbs(p); - r.Y = read_fp_limbs(p + 32); - r.Z = read_fp_limbs(p + 64); - return r; -} - -inline void write_pt(thread const Pt &p, device uchar *out) { - write_fp_limbs(p.X, out); - write_fp_limbs(p.Y, out + 32); - write_fp_limbs(p.Z, out + 64); -} - -// ============================================================================= -// Kernels. One thread per element. -// ============================================================================= - -// banderwagon_add_batch: out_i = P_i + Q_i. -// pairs: n * 192 bytes = [Pt P_i (96 B) || Pt Q_i (96 B)] for i=0..n-1. -// outs: n * 96 bytes. 
-kernel void banderwagon_add_batch( - device const uchar *pairs [[buffer(0)]], - device uchar *outs [[buffer(1)]], - constant uint &n [[buffer(2)]], - uint i [[thread_position_in_grid]]) -{ - if (i >= n) return; - Pt P = read_pt(pairs + i * 192); - Pt Q = read_pt(pairs + i * 192 + 96); - Pt R = pt_add(P, Q); - write_pt(R, outs + i * 96); -} - -// banderwagon_double_batch: out_i = 2 * P_i. -// pts: n * 96 bytes. -// outs: n * 96 bytes. -kernel void banderwagon_double_batch( - device const uchar *pts [[buffer(0)]], - device uchar *outs [[buffer(1)]], - constant uint &n [[buffer(2)]], - uint i [[thread_position_in_grid]]) -{ - if (i >= n) return; - Pt P = read_pt(pts + i * 96); - Pt R = pt_double(P); - write_pt(R, outs + i * 96); -} - -// banderwagon_smul_batch: out_i = scalar_i * P_i. -// pts: n * 96 bytes. -// scalars: n * 32 bytes (canonical LE Fr scalar bytes). -// outs: n * 96 bytes. -kernel void banderwagon_smul_batch( - device const uchar *pts [[buffer(0)]], - device const uchar *scalars [[buffer(1)]], - device uchar *outs [[buffer(2)]], - constant uint &n [[buffer(3)]], - uint i [[thread_position_in_grid]]) -{ - if (i >= n) return; - Pt P = read_pt(pts + i * 96); - Pt R = pt_scalar_mul(P, scalars + i * 32); - write_pt(R, outs + i * 96); -} - -// banderwagon_msm: result_b = sum_{i=0..n-1} scalar_i[b] * P_i, for batch b -// in [0..M). -// -// Layout: -// pts: n * 96 bytes (shared across batches). -// scalars: M * n * 32 bytes; scalar for batch b, point i lives at offset -// (b*n + i) * 32. -// outs: M * 96 bytes. -// n: points per MSM. -// M: number of independent MSMs (one thread per MSM). -// -// One thread per output MSM. Naive double-and-add over all (P_i, s_i); the -// outer batching gives us many independent MSMs in parallel. This trades -// per-MSM throughput for parallelism across batches and is byte-equal to a -// naive CPU MSM (= sum of s_i * P_i computed with the same Element::add / -// Element::scalar_mul primitives). 
-kernel void banderwagon_msm_batch_naive( - device const uchar *pts [[buffer(0)]], - device const uchar *scalars [[buffer(1)]], - device uchar *outs [[buffer(2)]], - constant uint &n [[buffer(3)]], - constant uint &M [[buffer(4)]], - uint b [[thread_position_in_grid]]) -{ - if (b >= M) return; - Pt acc = pt_identity(); - for (uint i = 0; i < n; ++i) { - Pt P = read_pt(pts + i * 96); - Pt term = pt_scalar_mul(P, scalars + (b * n + i) * 32); - acc = pt_add(acc, term); - } - write_pt(acc, outs + b * 96); -} diff --git a/banderwagon/gpu/metal/banderwagon_driver.h b/banderwagon/gpu/metal/banderwagon_driver.h deleted file mode 100644 index d3df8c6..0000000 --- a/banderwagon/gpu/metal/banderwagon_driver.h +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for Banderwagon group ops. macOS / iOS only. -// -// Byte-equal-by-construction to lux::banderwagon::Element {add, double_self, -// scalar_mul} in banderwagon/cpp/element.cpp. Constants are emitted from the -// CPU body via banderwagon_gen_metal_constants -> banderwagon_const.metalh -// and #include'd into banderwagon.metal, so there is exactly one source of -// truth for the modulus / Montgomery / curve constants. -// -// All Pt buffers carry 96 bytes per element (32B X || 32B Y || 32B Z, each -// the Montgomery-form Fp limbs in little-endian byte order). All Fr scalars -// carry 32 bytes in canonical little-endian (the same encoding as -// Fr::to_bytes_le). - -#ifndef LUX_BANDERWAGON_METAL_DRIVER_H -#define LUX_BANDERWAGON_METAL_DRIVER_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Run n point additions in one Metal dispatch. -// pairs : n * 192 bytes = [Pt P || Pt Q] per pair -// outs : n * 96 bytes = result per pair -// Returns 0 on success, negative on failure (-1 invalid arg, -2 device, -// -3 lib load, -4 function lookup, -5 pipeline create). 
-int banderwagon_metal_add_batch( - const uint8_t* pairs, - uint8_t* outs, - size_t n, - const char* metallib_path); - -// Run n point doublings in one Metal dispatch. -// pts : n * 96 bytes -// outs : n * 96 bytes -int banderwagon_metal_double_batch( - const uint8_t* pts, - uint8_t* outs, - size_t n, - const char* metallib_path); - -// Run n scalar multiplications in one Metal dispatch. -// pts : n * 96 bytes -// scalars : n * 32 bytes (canonical LE Fr) -// outs : n * 96 bytes -int banderwagon_metal_smul_batch( - const uint8_t* pts, - const uint8_t* scalars, - uint8_t* outs, - size_t n, - const char* metallib_path); - -// Run M independent MSMs in one Metal dispatch (one thread per MSM). -// pts : n * 96 bytes (shared across all M) -// scalars : M * n * 32 bytes (scalar for batch b, point i is at -// offset (b*n + i) * 32) -// outs : M * 96 bytes -// n : points per MSM -// M : number of MSMs -int banderwagon_metal_msm_batch( - const uint8_t* pts, - const uint8_t* scalars, - uint8_t* outs, - size_t n, - size_t M, - const char* metallib_path); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_BANDERWAGON_METAL_DRIVER_H diff --git a/banderwagon/gpu/metal/banderwagon_driver.mm b/banderwagon/gpu/metal/banderwagon_driver.mm deleted file mode 100644 index 25085a3..0000000 --- a/banderwagon/gpu/metal/banderwagon_driver.mm +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for Banderwagon group ops. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "banderwagon_driver.h" - -#include -#include -#include - -namespace { - -// Generic single-kernel dispatch helper. Sets up the Metal pipeline, binds -// the listed buffers in order, dispatches `count` threads (one per work item), -// waits for completion, then memcpy's the output buffer back to host. -// -// `bufs` is an array of (host_ptr, length, is_output) tuples. 
is_output -// pinned buffers are allocated and copied back; non-output are uploaded. -struct BufSpec { - const uint8_t* host_in; // null for output-only - uint8_t* host_out; // null for input-only - size_t length; - bool is_output; -}; - -int run_kernel(const char* kernel_name, - const BufSpec* bufs, size_t nbufs, - size_t threads, - const char* metallib_path) { - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - NSString* fname = [NSString stringWithUTF8String:kernel_name]; - id fn = [lib newFunctionWithName:fname]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - NSMutableArray>* mtlbufs = [NSMutableArray array]; - for (size_t i = 0; i < nbufs; ++i) { - const BufSpec& b = bufs[i]; - id buf; - if (b.is_output) { - buf = [device newBufferWithLength:b.length - options:MTLResourceStorageModeShared]; - } else { - buf = [device newBufferWithBytes:b.host_in - length:b.length - options:MTLResourceStorageModeShared]; - } - [mtlbufs addObject:buf]; - } - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - for (size_t i = 0; i < nbufs; ++i) { - [enc setBuffer:mtlbufs[i] offset:0 atIndex:i]; - } - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? 
tg_max : 64; - if (tg_w > threads) tg_w = threads; - MTLSize threads_per_grid = MTLSizeMake(threads, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - for (size_t i = 0; i < nbufs; ++i) { - const BufSpec& b = bufs[i]; - if (b.is_output && b.host_out) { - std::memcpy(b.host_out, [mtlbufs[i] contents], b.length); - } - } - } - return 0; -} - -} // namespace - -extern "C" int banderwagon_metal_add_batch( - const uint8_t* pairs, uint8_t* outs, size_t n, - const char* metallib_path) { - if (n == 0) return 0; - if (!pairs || !outs || !metallib_path) return -1; - - uint32_t n_u32 = (uint32_t)n; - BufSpec bufs[3] = { - { pairs, nullptr, n * 192, false }, - { nullptr, outs, n * 96, true }, - { (const uint8_t*)&n_u32, nullptr, sizeof(n_u32), false }, - }; - return run_kernel("banderwagon_add_batch", bufs, 3, n, metallib_path); -} - -extern "C" int banderwagon_metal_double_batch( - const uint8_t* pts, uint8_t* outs, size_t n, - const char* metallib_path) { - if (n == 0) return 0; - if (!pts || !outs || !metallib_path) return -1; - - uint32_t n_u32 = (uint32_t)n; - BufSpec bufs[3] = { - { pts, nullptr, n * 96, false }, - { nullptr, outs, n * 96, true }, - { (const uint8_t*)&n_u32, nullptr, sizeof(n_u32), false }, - }; - return run_kernel("banderwagon_double_batch", bufs, 3, n, metallib_path); -} - -extern "C" int banderwagon_metal_smul_batch( - const uint8_t* pts, const uint8_t* scalars, uint8_t* outs, size_t n, - const char* metallib_path) { - if (n == 0) return 0; - if (!pts || !scalars || !outs || !metallib_path) return -1; - - uint32_t n_u32 = (uint32_t)n; - BufSpec bufs[4] = { - { pts, nullptr, n * 96, false }, - { scalars, nullptr, n * 32, false }, - { nullptr, outs, n * 96, true }, - { (const uint8_t*)&n_u32, nullptr, sizeof(n_u32), false }, - }; - return run_kernel("banderwagon_smul_batch", bufs, 4, n, 
metallib_path); -} - -extern "C" int banderwagon_metal_msm_batch( - const uint8_t* pts, const uint8_t* scalars, uint8_t* outs, - size_t n, size_t M, const char* metallib_path) { - if (n == 0 || M == 0) return 0; - if (!pts || !scalars || !outs || !metallib_path) return -1; - - uint32_t n_u32 = (uint32_t)n; - uint32_t M_u32 = (uint32_t)M; - BufSpec bufs[5] = { - { pts, nullptr, n * 96, false }, - { scalars, nullptr, M * n * 32, false }, - { nullptr, outs, M * 96, true }, - { (const uint8_t*)&n_u32, nullptr, sizeof(n_u32), false }, - { (const uint8_t*)&M_u32, nullptr, sizeof(M_u32), false }, - }; - return run_kernel("banderwagon_msm_batch_naive", bufs, 5, M, metallib_path); -} - -#endif // __APPLE__ && __OBJC__ diff --git a/banderwagon/gpu/metal/banderwagon_msm.metal b/banderwagon/gpu/metal/banderwagon_msm.metal deleted file mode 100644 index 6b5ea90..0000000 --- a/banderwagon/gpu/metal/banderwagon_msm.metal +++ /dev/null @@ -1,51 +0,0 @@ -// ============================================================================= -// luxcpp/crypto/banderwagon -- Metal MSM kernel (WIP) -// ============================================================================= -// -// 256-bit Montgomery field arithmetic over q = BLS12-381 r, plus twisted -// Edwards add/double for Bandersnatch (a = -5), plus windowed Pippenger MSM. -// -// Status: source compiles with `xcrun -sdk macosx metal` and lays out the -// kernel signature; bucket-aggregation and bit-reduction step are not yet -// validated against the CPU oracle. The driver does not dispatch this kernel -// until validation lands; the C-ABI shim therefore falls back to the CPU -// body. The kernel is kept in-tree so the metallib build target exists for -// the follow-up validation PR. -// -// Copyright (C) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// ============================================================================= - -#include -using namespace metal; - -// 4 x 64-bit limb little-endian field element. Canonical (non-Montgomery) -// representation, mirrors Fp on the CPU side (cpp/banderwagon.cpp). -struct Fp { ulong4 v; }; - -// Banderwagon affine point. -struct PointAffine { Fp x; Fp y; }; - -// Window size (must match CPU msm_pippenger). -constant uint kWindow = 8u; -constant uint kBuckets = (1u << kWindow) - 1u; - -// MSM dispatch signature. Each thread handles one (scalar, point) pair within -// one window; outer windows are aggregated on the host. Bucket sums are -// reduced per workgroup via threadgroup memory. -// -// THIS KERNEL BODY IS A PLACEHOLDER. It does not perform the MSM; the driver -// returns NOTIMPL and the host code falls back to CPU. See banderwagon_driver.mm. -[[host_name("banderwagon_msm_window")]] -kernel void banderwagon_msm_window_kernel( - const device PointAffine* points [[ buffer(0) ]], - const device Fp* scalars [[ buffer(1) ]], - device PointAffine* bucket_out [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - constant uint& window_idx [[ buffer(4) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - // Placeholder; real bucket accumulation goes here. - bucket_out[gid] = points[gid]; -} diff --git a/banderwagon/gpu/wgsl/banderwagon.wgsl b/banderwagon/gpu/wgsl/banderwagon.wgsl deleted file mode 100644 index 8a312c4..0000000 --- a/banderwagon/gpu/wgsl/banderwagon.wgsl +++ /dev/null @@ -1,460 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party WGSL kernel for Banderwagon group operations. 
-// -// Mechanical port of banderwagon/gpu/metal/banderwagon.metal -- byte-for-byte -// equivalent to lux::banderwagon::Element {add, double_self, scalar_mul} in -// banderwagon/cpp/element.cpp (twisted Edwards a*x^2 + y^2 = 1 + d*x^2*y^2 -// over the BLS12-381 scalar field; a = -5; d = canonical gnark constant). -// -// WGSL has no native u64. Each 64-bit Montgomery limb is represented as a -// pair of u32 (lo, hi), and 64-bit ops (adc, sbb, mulhi) are reconstructed -// from 32-bit primitives. Same approach as poseidon/gpu/wgsl/poseidon2_bn254. -// -// The constant table comes from banderwagon_const.wgslh, which is emitted by -// the CPU body's internal accessors via banderwagon_gen_gpu_constants. There -// is exactly one source of truth across CPU, Metal, CUDA, and WGSL. -// -// The host driver concatenates banderwagon_const.wgslh in front of this file -// before submitting it to wgpu. The kernel below references FP_Q_*_LO/HI, -// FP_QINV_NEG_LO/HI, etc. - -// ============================================================================= -// 64-bit unsigned represented as (lo, hi) pair of u32. Same layout as the -// host polyfill in banderwagon_driver.cpp. 
-// ============================================================================= -struct U64 { lo: u32, hi: u32 }; - -fn u64_zero() -> U64 { return U64(0u, 0u); } - -fn u64_lt(a: U64, b: U64) -> bool { - if (a.hi != b.hi) { return a.hi < b.hi; } - return a.lo < b.lo; -} - -fn u64_eq(a: U64, b: U64) -> bool { - return a.lo == b.lo && a.hi == b.hi; -} - -struct U64Carry { v: U64, carry: u32 }; - -fn u64_add_carry(a: U64, b: U64, cin: u32) -> U64Carry { - let lo1 = a.lo + b.lo; - let c0: u32 = select(0u, 1u, lo1 < a.lo); - let lo = lo1 + cin; - let c1: u32 = select(0u, 1u, lo < lo1); - let hi1 = a.hi + b.hi; - let c2: u32 = select(0u, 1u, hi1 < a.hi); - let hi2 = hi1 + (c0 + c1); - let c3: u32 = select(0u, 1u, hi2 < hi1); - return U64Carry(U64(lo, hi2), c2 + c3); -} - -struct U64Borrow { v: U64, borrow: u32 }; - -fn u64_sub_borrow(a: U64, b: U64, bin: u32) -> U64Borrow { - let lo1: u32 = a.lo - b.lo; - let bor0: u32 = select(0u, 1u, a.lo < b.lo); - let lo: u32 = lo1 - bin; - let bor1: u32 = select(0u, 1u, lo1 < bin); - let hi1: u32 = a.hi - b.hi; - let bor2: u32 = select(0u, 1u, a.hi < b.hi); - let hi: u32 = hi1 - (bor0 + bor1); - let bor3: u32 = select(0u, 1u, hi1 < (bor0 + bor1)); - return U64Borrow(U64(lo, hi), bor2 + bor3); -} - -fn u32_mul64(a: u32, b: u32) -> U64 { - let al: u32 = a & 0xffffu; - let ah: u32 = a >> 16u; - let bl: u32 = b & 0xffffu; - let bh: u32 = b >> 16u; - let ll: u32 = al * bl; - let lh: u32 = al * bh; - let hl: u32 = ah * bl; - let hh: u32 = ah * bh; - let mid_a: u32 = (ll >> 16u) + (lh & 0xffffu); - let mid_b: u32 = mid_a + (hl & 0xffffu); - let mid_carry: u32 = (mid_b >> 16u); - let lo: u32 = (ll & 0xffffu) | (mid_b << 16u); - let hi: u32 = hh + (lh >> 16u) + (hl >> 16u) + mid_carry; - return U64(lo, hi); -} - -struct U128 { l0: u32, l1: u32, l2: u32, l3: u32 }; - -fn umul64(a: U64, b: U64) -> U128 { - let p_ll: U64 = u32_mul64(a.lo, b.lo); - let p_lh: U64 = u32_mul64(a.lo, b.hi); - let p_hl: U64 = u32_mul64(a.hi, b.lo); - let p_hh: U64 = 
u32_mul64(a.hi, b.hi); - let w0: u32 = p_ll.lo; - let s1a: u32 = p_ll.hi + p_lh.lo; - let c1a: u32 = select(0u, 1u, s1a < p_ll.hi); - let w1: u32 = s1a + p_hl.lo; - let c1b: u32 = select(0u, 1u, w1 < s1a); - let carry1: u32 = c1a + c1b; - let s2a: u32 = p_lh.hi + p_hl.hi; - let c2a: u32 = select(0u, 1u, s2a < p_lh.hi); - let s2b: u32 = s2a + p_hh.lo; - let c2b: u32 = select(0u, 1u, s2b < s2a); - let w2: u32 = s2b + carry1; - let c2c: u32 = select(0u, 1u, w2 < s2b); - let carry2: u32 = c2a + c2b + c2c; - let w3: u32 = p_hh.hi + carry2; - return U128(w0, w1, w2, w3); -} - -fn u64_mul_low(a: U64, b: U64) -> U64 { - let p_ll: U64 = u32_mul64(a.lo, b.lo); - let p_lh: U64 = u32_mul64(a.lo, b.hi); - let p_hl: U64 = u32_mul64(a.hi, b.lo); - let lo: u32 = p_ll.lo; - let hi: u32 = p_ll.hi + p_lh.lo + p_hl.lo; - return U64(lo, hi); -} - -// ============================================================================= -// 256-bit Fp in Montgomery form, four U64 limbs. -// ============================================================================= -struct Fp { l0: U64, l1: U64, l2: U64, l3: U64 }; -struct Pt { X: Fp, Y: Fp, Z: Fp }; - -fn fp_zero() -> Fp { return Fp(u64_zero(), u64_zero(), u64_zero(), u64_zero()); } - -fn fp_q() -> Fp { - return Fp(U64(FP_Q_0_LO, FP_Q_0_HI), U64(FP_Q_1_LO, FP_Q_1_HI), - U64(FP_Q_2_LO, FP_Q_2_HI), U64(FP_Q_3_LO, FP_Q_3_HI)); -} - -fn fp_one() -> Fp { - return Fp(U64(FP_R_0_LO, FP_R_0_HI), U64(FP_R_1_LO, FP_R_1_HI), - U64(FP_R_2_LO, FP_R_2_HI), U64(FP_R_3_LO, FP_R_3_HI)); -} - -fn fp_curve_a() -> Fp { - return Fp(U64(CURVE_A_0_LO, CURVE_A_0_HI), U64(CURVE_A_1_LO, CURVE_A_1_HI), - U64(CURVE_A_2_LO, CURVE_A_2_HI), U64(CURVE_A_3_LO, CURVE_A_3_HI)); -} - -fn fp_curve_d() -> Fp { - return Fp(U64(CURVE_D_0_LO, CURVE_D_0_HI), U64(CURVE_D_1_LO, CURVE_D_1_HI), - U64(CURVE_D_2_LO, CURVE_D_2_HI), U64(CURVE_D_3_LO, CURVE_D_3_HI)); -} - -fn fp_qinv() -> U64 { return U64(FP_QINV_NEG_LO, FP_QINV_NEG_HI); } - -fn fp_cond_sub_q(a_in: Fp) -> Fp { - let q = 
fp_q(); - let r0 = u64_sub_borrow(a_in.l0, q.l0, 0u); - let r1 = u64_sub_borrow(a_in.l1, q.l1, r0.borrow); - let r2 = u64_sub_borrow(a_in.l2, q.l2, r1.borrow); - let r3 = u64_sub_borrow(a_in.l3, q.l3, r2.borrow); - let br: u32 = r3.borrow; - let mask: u32 = select(0xffffffffu, 0u, br == 1u); - return Fp( - U64((a_in.l0.lo & ~mask) | (r0.v.lo & mask), - (a_in.l0.hi & ~mask) | (r0.v.hi & mask)), - U64((a_in.l1.lo & ~mask) | (r1.v.lo & mask), - (a_in.l1.hi & ~mask) | (r1.v.hi & mask)), - U64((a_in.l2.lo & ~mask) | (r2.v.lo & mask), - (a_in.l2.hi & ~mask) | (r2.v.hi & mask)), - U64((a_in.l3.lo & ~mask) | (r3.v.lo & mask), - (a_in.l3.hi & ~mask) | (r3.v.hi & mask)), - ); -} - -fn fp_cond_add_q(a_in: Fp, mask: u32) -> Fp { - let q = fp_q(); - let add0 = U64(q.l0.lo & mask, q.l0.hi & mask); - let add1 = U64(q.l1.lo & mask, q.l1.hi & mask); - let add2 = U64(q.l2.lo & mask, q.l2.hi & mask); - let add3 = U64(q.l3.lo & mask, q.l3.hi & mask); - let r0 = u64_add_carry(a_in.l0, add0, 0u); - let r1 = u64_add_carry(a_in.l1, add1, r0.carry); - let r2 = u64_add_carry(a_in.l2, add2, r1.carry); - let r3 = u64_add_carry(a_in.l3, add3, r2.carry); - return Fp(r0.v, r1.v, r2.v, r3.v); -} - -fn fp_add(a: Fp, b: Fp) -> Fp { - let r0 = u64_add_carry(a.l0, b.l0, 0u); - let r1 = u64_add_carry(a.l1, b.l1, r0.carry); - let r2 = u64_add_carry(a.l2, b.l2, r1.carry); - let r3 = u64_add_carry(a.l3, b.l3, r2.carry); - return fp_cond_sub_q(Fp(r0.v, r1.v, r2.v, r3.v)); -} - -fn fp_sub(a: Fp, b: Fp) -> Fp { - let r0 = u64_sub_borrow(a.l0, b.l0, 0u); - let r1 = u64_sub_borrow(a.l1, b.l1, r0.borrow); - let r2 = u64_sub_borrow(a.l2, b.l2, r1.borrow); - let r3 = u64_sub_borrow(a.l3, b.l3, r2.borrow); - let mask: u32 = select(0u, 0xffffffffu, r3.borrow == 1u); - return fp_cond_add_q(Fp(r0.v, r1.v, r2.v, r3.v), mask); -} - -fn fp_mul(a: Fp, b: Fp) -> Fp { - var t0 = u64_zero(); - var t1 = u64_zero(); - var t2 = u64_zero(); - var t3 = u64_zero(); - var t4 = u64_zero(); - let al = array(a.l0, a.l1, a.l2, a.l3); 
- let bl = array(b.l0, b.l1, b.l2, b.l3); - let qq = array(U64(FP_Q_0_LO, FP_Q_0_HI), - U64(FP_Q_1_LO, FP_Q_1_HI), - U64(FP_Q_2_LO, FP_Q_2_HI), - U64(FP_Q_3_LO, FP_Q_3_HI)); - let qinv = fp_qinv(); - - for (var i: i32 = 0; i < 4; i = i + 1) { - var cy: U64 = u64_zero(); - for (var j: i32 = 0; j < 4; j = j + 1) { - let prod = umul64(al[j], bl[i]); - let lo = U64(prod.l0, prod.l1); - let hi = U64(prod.l2, prod.l3); - var tj: U64 = u64_zero(); - if (j == 0) { tj = t0; } - else if (j == 1) { tj = t1; } - else if (j == 2) { tj = t2; } - else { tj = t3; } - let s = u64_add_carry(tj, lo, 0u); - let s2 = u64_add_carry(s.v, cy, 0u); - if (j == 0) { t0 = s2.v; } - else if (j == 1) { t1 = s2.v; } - else if (j == 2) { t2 = s2.v; } - else { t3 = s2.v; } - let cy1 = u64_add_carry(hi, U64(s.carry, 0u), 0u); - let cy2 = u64_add_carry(cy1.v, U64(s2.carry, 0u), 0u); - cy = cy2.v; - } - let t4u = u64_add_carry(t4, cy, 0u); - t4 = t4u.v; - let big_carry: u32 = t4u.carry; - - let m: U64 = u64_mul_low(t0, qinv); - - cy = u64_zero(); - for (var j: i32 = 0; j < 4; j = j + 1) { - let prod = umul64(m, qq[j]); - let lo = U64(prod.l0, prod.l1); - let hi = U64(prod.l2, prod.l3); - var tj: U64 = u64_zero(); - if (j == 0) { tj = t0; } - else if (j == 1) { tj = t1; } - else if (j == 2) { tj = t2; } - else { tj = t3; } - let s = u64_add_carry(tj, lo, 0u); - let s2 = u64_add_carry(s.v, cy, 0u); - if (j == 0) { t0 = s2.v; } - else if (j == 1) { t1 = s2.v; } - else if (j == 2) { t2 = s2.v; } - else { t3 = s2.v; } - let cy1 = u64_add_carry(hi, U64(s.carry, 0u), 0u); - let cy2 = u64_add_carry(cy1.v, U64(s2.carry, 0u), 0u); - cy = cy2.v; - } - let t3_step = u64_add_carry(t4, cy, 0u); - let t4_step = u64_add_carry(u64_zero(), - U64(big_carry, 0u), - t3_step.carry); - - t0 = t1; t1 = t2; t2 = t3; - t3 = t3_step.v; - t4 = t4_step.v; - } - - var r = Fp(t0, t1, t2, t3); - if (!(t4.lo == 0u && t4.hi == 0u)) { - let s0 = u64_sub_borrow(r.l0, U64(FP_Q_0_LO, FP_Q_0_HI), 0u); - let s1 = u64_sub_borrow(r.l1, 
U64(FP_Q_1_LO, FP_Q_1_HI), s0.borrow); - let s2 = u64_sub_borrow(r.l2, U64(FP_Q_2_LO, FP_Q_2_HI), s1.borrow); - let s3 = u64_sub_borrow(r.l3, U64(FP_Q_3_LO, FP_Q_3_HI), s2.borrow); - return Fp(s0.v, s1.v, s2.v, s3.v); - } - return fp_cond_sub_q(r); -} - -fn fp_square(a: Fp) -> Fp { return fp_mul(a, a); } - -fn pt_identity() -> Pt { return Pt(fp_zero(), fp_one(), fp_one()); } - -fn pt_add(p1: Pt, p2: Pt) -> Pt { - let d_const = fp_curve_d(); - let a_const = fp_curve_a(); - - let A = fp_mul(p1.Z, p2.Z); - let B = fp_square(A); - let C = fp_mul(p1.X, p2.X); - let D = fp_mul(p1.Y, p2.Y); - let CD = fp_mul(C, D); - let E = fp_mul(d_const, CD); - let F = fp_sub(B, E); - let G = fp_add(B, E); - let H = fp_add(p1.X, p1.Y); - let I = fp_add(p2.X, p2.Y); - - var t = fp_mul(H, I); - t = fp_sub(t, C); - t = fp_sub(t, D); - t = fp_mul(t, A); - let X3 = fp_mul(t, F); - - let aC = fp_mul(a_const, C); - var t2 = fp_sub(D, aC); - t2 = fp_mul(t2, A); - let Y3 = fp_mul(t2, G); - - let Z3 = fp_mul(F, G); - return Pt(X3, Y3, Z3); -} - -fn pt_double(p: Pt) -> Pt { - let a_const = fp_curve_a(); - let XY = fp_add(p.X, p.Y); - let B = fp_square(XY); - let C = fp_square(p.X); - let D = fp_square(p.Y); - let E = fp_mul(a_const, C); - let F = fp_add(E, D); - let H = fp_square(p.Z); - let twoH = fp_add(H, H); - let J = fp_sub(F, twoH); - - var t = fp_sub(B, C); - t = fp_sub(t, D); - let X3 = fp_mul(t, J); - let Y3 = fp_mul(F, fp_sub(E, D)); - let Z3 = fp_mul(F, J); - return Pt(X3, Y3, Z3); -} - -fn pt_cmov_select(dst: Pt, src: Pt, mask: u32) -> Pt { - return Pt( - Fp(U64((dst.X.l0.lo & ~mask) | (src.X.l0.lo & mask), - (dst.X.l0.hi & ~mask) | (src.X.l0.hi & mask)), - U64((dst.X.l1.lo & ~mask) | (src.X.l1.lo & mask), - (dst.X.l1.hi & ~mask) | (src.X.l1.hi & mask)), - U64((dst.X.l2.lo & ~mask) | (src.X.l2.lo & mask), - (dst.X.l2.hi & ~mask) | (src.X.l2.hi & mask)), - U64((dst.X.l3.lo & ~mask) | (src.X.l3.lo & mask), - (dst.X.l3.hi & ~mask) | (src.X.l3.hi & mask))), - Fp(U64((dst.Y.l0.lo & ~mask) 
| (src.Y.l0.lo & mask), - (dst.Y.l0.hi & ~mask) | (src.Y.l0.hi & mask)), - U64((dst.Y.l1.lo & ~mask) | (src.Y.l1.lo & mask), - (dst.Y.l1.hi & ~mask) | (src.Y.l1.hi & mask)), - U64((dst.Y.l2.lo & ~mask) | (src.Y.l2.lo & mask), - (dst.Y.l2.hi & ~mask) | (src.Y.l2.hi & mask)), - U64((dst.Y.l3.lo & ~mask) | (src.Y.l3.lo & mask), - (dst.Y.l3.hi & ~mask) | (src.Y.l3.hi & mask))), - Fp(U64((dst.Z.l0.lo & ~mask) | (src.Z.l0.lo & mask), - (dst.Z.l0.hi & ~mask) | (src.Z.l0.hi & mask)), - U64((dst.Z.l1.lo & ~mask) | (src.Z.l1.lo & mask), - (dst.Z.l1.hi & ~mask) | (src.Z.l1.hi & mask)), - U64((dst.Z.l2.lo & ~mask) | (src.Z.l2.lo & mask), - (dst.Z.l2.hi & ~mask) | (src.Z.l2.hi & mask)), - U64((dst.Z.l3.lo & ~mask) | (src.Z.l3.lo & mask), - (dst.Z.l3.hi & ~mask) | (src.Z.l3.hi & mask))), - ); -} - -fn read_fp(buf: ptr, read>, word_off: u32) -> Fp { - return Fp( - U64((*buf)[word_off + 0u], (*buf)[word_off + 1u]), - U64((*buf)[word_off + 2u], (*buf)[word_off + 3u]), - U64((*buf)[word_off + 4u], (*buf)[word_off + 5u]), - U64((*buf)[word_off + 6u], (*buf)[word_off + 7u]), - ); -} - -fn write_fp(buf: ptr, read_write>, word_off: u32, x: Fp) { - (*buf)[word_off + 0u] = x.l0.lo; (*buf)[word_off + 1u] = x.l0.hi; - (*buf)[word_off + 2u] = x.l1.lo; (*buf)[word_off + 3u] = x.l1.hi; - (*buf)[word_off + 4u] = x.l2.lo; (*buf)[word_off + 5u] = x.l2.hi; - (*buf)[word_off + 6u] = x.l3.lo; (*buf)[word_off + 7u] = x.l3.hi; -} - -fn read_pt(buf: ptr, read>, word_off: u32) -> Pt { - return Pt(read_fp(buf, word_off ), - read_fp(buf, word_off + 8u ), - read_fp(buf, word_off + 16u )); -} - -fn write_pt(buf: ptr, read_write>, word_off: u32, p: Pt) { - write_fp(buf, word_off , p.X); - write_fp(buf, word_off + 8u , p.Y); - write_fp(buf, word_off + 16u , p.Z); -} - -fn pt_scalar_mul(p: Pt, scalars: ptr, read>, word_off: u32) -> Pt { - var acc = pt_identity(); - var base = p; - for (var w: u32 = 0u; w < 8u; w = w + 1u) { - let word = (*scalars)[word_off + w]; - for (var byte_in_word: u32 = 0u; byte_in_word 
< 4u; byte_in_word = byte_in_word + 1u) { - let b: u32 = (word >> (byte_in_word * 8u)) & 0xffu; - for (var bit: u32 = 0u; bit < 8u; bit = bit + 1u) { - let one_or_zero: u32 = (b >> bit) & 1u; - let mask: u32 = select(0u, 0xffffffffu, one_or_zero == 1u); - let sum = pt_add(acc, base); - acc = pt_cmov_select(acc, sum, mask); - base = pt_double(base); - } - } - } - return acc; -} - -@group(0) @binding(0) var g_pts : array; -@group(0) @binding(1) var g_scalars : array; -@group(0) @binding(2) var g_outs : array; -@group(0) @binding(3) var g_n : u32; -@group(0) @binding(4) var g_M : u32; - -@compute @workgroup_size(64) -fn banderwagon_add_kernel(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - if (i >= g_n) { return; } - let off = i * 48u; - let P = read_pt(&g_pts, off); - let Q = read_pt(&g_pts, off + 24u); - let R = pt_add(P, Q); - write_pt(&g_outs, i * 24u, R); -} - -@compute @workgroup_size(64) -fn banderwagon_double_kernel(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - if (i >= g_n) { return; } - let off = i * 24u; - let P = read_pt(&g_pts, off); - let R = pt_double(P); - write_pt(&g_outs, off, R); -} - -@compute @workgroup_size(32) -fn banderwagon_smul_kernel(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - if (i >= g_n) { return; } - let pt_off = i * 24u; - let sc_off = i * 8u; - let P = read_pt(&g_pts, pt_off); - let R = pt_scalar_mul(P, &g_scalars, sc_off); - write_pt(&g_outs, pt_off, R); -} - -@compute @workgroup_size(32) -fn banderwagon_msm_kernel(@builtin(global_invocation_id) gid: vec3) { - let b = gid.x; - if (b >= g_M) { return; } - let n = g_n; - var acc = pt_identity(); - for (var i: u32 = 0u; i < n; i = i + 1u) { - let P = read_pt(&g_pts, i * 24u); - let term = pt_scalar_mul(P, &g_scalars, (b * n + i) * 8u); - acc = pt_add(acc, term); - } - write_pt(&g_outs, b * 24u, acc); -} diff --git a/banderwagon/gpu/wgsl/banderwagon_driver.cpp b/banderwagon/gpu/wgsl/banderwagon_driver.cpp deleted file mode 100644 index 
7c59d11..0000000 --- a/banderwagon/gpu/wgsl/banderwagon_driver.cpp +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL driver for Banderwagon group ops -- C++ host polyfill of the WGSL -// kernel. -// -// Mirrors banderwagon/gpu/wgsl/banderwagon.wgsl byte-for-byte: each 64-bit -// Montgomery limb is represented as a (lo, hi) pair of uint32_t, and every -// 64-bit op is reconstructed from u32 primitives. This is exactly what WGSL -// does; running the same arithmetic on the host gives byte-equal output to -// both the WGSL kernel and the CPU oracle. -// -// The constant table comes from banderwagon_const.wgslh (auto-generated by -// the CPU body via banderwagon_gen_gpu_constants), which is a polyglot header -// emitting both __cplusplus `static const uint32_t` arrays and WGSL `const` -// declarations. - -#include "banderwagon_driver.h" -#include "banderwagon_const.wgslh" - -#include -#include - -namespace { - -struct U64 { uint32_t lo, hi; }; - -inline U64 u64_zero() { return U64{0u, 0u}; } -inline bool u64_lt(U64 a, U64 b) { - if (a.hi != b.hi) return a.hi < b.hi; - return a.lo < b.lo; -} -inline bool u64_eq(U64 a, U64 b) { return a.lo == b.lo && a.hi == b.hi; } - -struct U64Carry { U64 v; uint32_t carry; }; -struct U64Borrow { U64 v; uint32_t borrow; }; - -inline U64Carry u64_add_carry(U64 a, U64 b, uint32_t cin) { - uint32_t lo1 = a.lo + b.lo; - uint32_t c0 = (lo1 < a.lo) ? 1u : 0u; - uint32_t lo = lo1 + cin; - uint32_t c1 = (lo < lo1) ? 1u : 0u; - uint32_t hi1 = a.hi + b.hi; - uint32_t c2 = (hi1 < a.hi) ? 1u : 0u; - uint32_t hi2 = hi1 + (c0 + c1); - uint32_t c3 = (hi2 < hi1) ? 1u : 0u; - return U64Carry{U64{lo, hi2}, c2 + c3}; -} - -inline U64Borrow u64_sub_borrow(U64 a, U64 b, uint32_t bin) { - uint32_t lo1 = a.lo - b.lo; - uint32_t bor0 = (a.lo < b.lo) ? 1u : 0u; - uint32_t lo = lo1 - bin; - uint32_t bor1 = (lo1 < bin) ? 
1u : 0u; - uint32_t hi1 = a.hi - b.hi; - uint32_t bor2 = (a.hi < b.hi) ? 1u : 0u; - uint32_t bor_in_hi = bor0 + bor1; - uint32_t hi = hi1 - bor_in_hi; - uint32_t bor3 = (hi1 < bor_in_hi) ? 1u : 0u; - return U64Borrow{U64{lo, hi}, bor2 + bor3}; -} - -inline U64 u32_mul64(uint32_t a, uint32_t b) { - uint32_t al = a & 0xffffu; - uint32_t ah = a >> 16; - uint32_t bl = b & 0xffffu; - uint32_t bh = b >> 16; - uint32_t ll = al * bl; - uint32_t lh = al * bh; - uint32_t hl = ah * bl; - uint32_t hh = ah * bh; - uint32_t mid_a = (ll >> 16) + (lh & 0xffffu); - uint32_t mid_b = mid_a + (hl & 0xffffu); - uint32_t mid_carry = (mid_b >> 16); - uint32_t lo = (ll & 0xffffu) | (mid_b << 16); - uint32_t hi = hh + (lh >> 16) + (hl >> 16) + mid_carry; - return U64{lo, hi}; -} - -struct U128 { uint32_t l0, l1, l2, l3; }; - -inline U128 umul64(U64 a, U64 b) { - U64 p_ll = u32_mul64(a.lo, b.lo); - U64 p_lh = u32_mul64(a.lo, b.hi); - U64 p_hl = u32_mul64(a.hi, b.lo); - U64 p_hh = u32_mul64(a.hi, b.hi); - uint32_t w0 = p_ll.lo; - uint32_t s1a = p_ll.hi + p_lh.lo; - uint32_t c1a = (s1a < p_ll.hi) ? 1u : 0u; - uint32_t w1 = s1a + p_hl.lo; - uint32_t c1b = (w1 < s1a) ? 1u : 0u; - uint32_t carry1 = c1a + c1b; - uint32_t s2a = p_lh.hi + p_hl.hi; - uint32_t c2a = (s2a < p_lh.hi) ? 1u : 0u; - uint32_t s2b = s2a + p_hh.lo; - uint32_t c2b = (s2b < s2a) ? 1u : 0u; - uint32_t w2 = s2b + carry1; - uint32_t c2c = (w2 < s2b) ? 
1u : 0u; - uint32_t carry2 = c2a + c2b + c2c; - uint32_t w3 = p_hh.hi + carry2; - return U128{w0, w1, w2, w3}; -} - -inline U64 u64_mul_low(U64 a, U64 b) { - U64 p_ll = u32_mul64(a.lo, b.lo); - U64 p_lh = u32_mul64(a.lo, b.hi); - U64 p_hl = u32_mul64(a.hi, b.lo); - uint32_t lo = p_ll.lo; - uint32_t hi = p_ll.hi + p_lh.lo + p_hl.lo; - return U64{lo, hi}; -} - -struct Fp { U64 l0, l1, l2, l3; }; -struct Pt { Fp X, Y, Z; }; - -inline Fp fp_zero() { return Fp{u64_zero(), u64_zero(), u64_zero(), u64_zero()}; } - -inline Fp fp_q() { - return Fp{ - U64{FP_Q_0_LO, FP_Q_0_HI}, U64{FP_Q_1_LO, FP_Q_1_HI}, - U64{FP_Q_2_LO, FP_Q_2_HI}, U64{FP_Q_3_LO, FP_Q_3_HI} - }; -} - -inline Fp fp_one() { - return Fp{ - U64{FP_R_0_LO, FP_R_0_HI}, U64{FP_R_1_LO, FP_R_1_HI}, - U64{FP_R_2_LO, FP_R_2_HI}, U64{FP_R_3_LO, FP_R_3_HI} - }; -} - -inline Fp fp_curve_a() { - return Fp{ - U64{CURVE_A_0_LO, CURVE_A_0_HI}, U64{CURVE_A_1_LO, CURVE_A_1_HI}, - U64{CURVE_A_2_LO, CURVE_A_2_HI}, U64{CURVE_A_3_LO, CURVE_A_3_HI} - }; -} - -inline Fp fp_curve_d() { - return Fp{ - U64{CURVE_D_0_LO, CURVE_D_0_HI}, U64{CURVE_D_1_LO, CURVE_D_1_HI}, - U64{CURVE_D_2_LO, CURVE_D_2_HI}, U64{CURVE_D_3_LO, CURVE_D_3_HI} - }; -} - -inline U64 fp_qinv() { return U64{FP_QINV_NEG_LO, FP_QINV_NEG_HI}; } - -inline Fp fp_cond_sub_q(const Fp &a) { - Fp q = fp_q(); - auto r0 = u64_sub_borrow(a.l0, q.l0, 0u); - auto r1 = u64_sub_borrow(a.l1, q.l1, r0.borrow); - auto r2 = u64_sub_borrow(a.l2, q.l2, r1.borrow); - auto r3 = u64_sub_borrow(a.l3, q.l3, r2.borrow); - uint32_t br = r3.borrow; - uint32_t mask = (br == 1u) ? 
0u : 0xffffffffu; - return Fp{ - U64{(a.l0.lo & ~mask) | (r0.v.lo & mask), - (a.l0.hi & ~mask) | (r0.v.hi & mask)}, - U64{(a.l1.lo & ~mask) | (r1.v.lo & mask), - (a.l1.hi & ~mask) | (r1.v.hi & mask)}, - U64{(a.l2.lo & ~mask) | (r2.v.lo & mask), - (a.l2.hi & ~mask) | (r2.v.hi & mask)}, - U64{(a.l3.lo & ~mask) | (r3.v.lo & mask), - (a.l3.hi & ~mask) | (r3.v.hi & mask)}, - }; -} - -inline Fp fp_cond_add_q(const Fp &a, uint32_t mask) { - Fp q = fp_q(); - U64 add0{q.l0.lo & mask, q.l0.hi & mask}; - U64 add1{q.l1.lo & mask, q.l1.hi & mask}; - U64 add2{q.l2.lo & mask, q.l2.hi & mask}; - U64 add3{q.l3.lo & mask, q.l3.hi & mask}; - auto r0 = u64_add_carry(a.l0, add0, 0u); - auto r1 = u64_add_carry(a.l1, add1, r0.carry); - auto r2 = u64_add_carry(a.l2, add2, r1.carry); - auto r3 = u64_add_carry(a.l3, add3, r2.carry); - return Fp{r0.v, r1.v, r2.v, r3.v}; -} - -inline Fp fp_add(const Fp &a, const Fp &b) { - auto r0 = u64_add_carry(a.l0, b.l0, 0u); - auto r1 = u64_add_carry(a.l1, b.l1, r0.carry); - auto r2 = u64_add_carry(a.l2, b.l2, r1.carry); - auto r3 = u64_add_carry(a.l3, b.l3, r2.carry); - return fp_cond_sub_q(Fp{r0.v, r1.v, r2.v, r3.v}); -} - -inline Fp fp_sub(const Fp &a, const Fp &b) { - auto r0 = u64_sub_borrow(a.l0, b.l0, 0u); - auto r1 = u64_sub_borrow(a.l1, b.l1, r0.borrow); - auto r2 = u64_sub_borrow(a.l2, b.l2, r1.borrow); - auto r3 = u64_sub_borrow(a.l3, b.l3, r2.borrow); - uint32_t mask = (r3.borrow == 1u) ? 
0xffffffffu : 0u; - return fp_cond_add_q(Fp{r0.v, r1.v, r2.v, r3.v}, mask); -} - -inline Fp fp_mul(const Fp &a, const Fp &b) { - U64 t[5] = {u64_zero(), u64_zero(), u64_zero(), u64_zero(), u64_zero()}; - const U64 al[4] = {a.l0, a.l1, a.l2, a.l3}; - const U64 bl[4] = {b.l0, b.l1, b.l2, b.l3}; - const U64 qq[4] = { - U64{FP_Q_0_LO, FP_Q_0_HI}, U64{FP_Q_1_LO, FP_Q_1_HI}, - U64{FP_Q_2_LO, FP_Q_2_HI}, U64{FP_Q_3_LO, FP_Q_3_HI}, - }; - U64 qinv = fp_qinv(); - - for (int i = 0; i < 4; ++i) { - U64 cy = u64_zero(); - for (int j = 0; j < 4; ++j) { - U128 prod = umul64(al[j], bl[i]); - U64 lo = U64{prod.l0, prod.l1}; - U64 hi = U64{prod.l2, prod.l3}; - auto s = u64_add_carry(t[j], lo, 0u); - auto s2 = u64_add_carry(s.v, cy, 0u); - t[j] = s2.v; - auto cy1 = u64_add_carry(hi, U64{s.carry, 0u}, 0u); - auto cy2 = u64_add_carry(cy1.v, U64{s2.carry, 0u}, 0u); - cy = cy2.v; - } - auto t4u = u64_add_carry(t[4], cy, 0u); - t[4] = t4u.v; - uint32_t big_carry = t4u.carry; - - U64 m = u64_mul_low(t[0], qinv); - - cy = u64_zero(); - for (int j = 0; j < 4; ++j) { - U128 prod = umul64(m, qq[j]); - U64 lo = U64{prod.l0, prod.l1}; - U64 hi = U64{prod.l2, prod.l3}; - auto s = u64_add_carry(t[j], lo, 0u); - auto s2 = u64_add_carry(s.v, cy, 0u); - t[j] = s2.v; - auto cy1 = u64_add_carry(hi, U64{s.carry, 0u}, 0u); - auto cy2 = u64_add_carry(cy1.v, U64{s2.carry, 0u}, 0u); - cy = cy2.v; - } - auto t3_step = u64_add_carry(t[4], cy, 0u); - auto t4_step = u64_add_carry(u64_zero(), - U64{big_carry, 0u}, - t3_step.carry); - t[0] = t[1]; - t[1] = t[2]; - t[2] = t[3]; - t[3] = t3_step.v; - t[4] = t4_step.v; - } - - Fp r{t[0], t[1], t[2], t[3]}; - if (!(t[4].lo == 0u && t[4].hi == 0u)) { - auto s0 = u64_sub_borrow(r.l0, U64{FP_Q_0_LO, FP_Q_0_HI}, 0u); - auto s1 = u64_sub_borrow(r.l1, U64{FP_Q_1_LO, FP_Q_1_HI}, s0.borrow); - auto s2 = u64_sub_borrow(r.l2, U64{FP_Q_2_LO, FP_Q_2_HI}, s1.borrow); - auto s3 = u64_sub_borrow(r.l3, U64{FP_Q_3_LO, FP_Q_3_HI}, s2.borrow); - return Fp{s0.v, s1.v, s2.v, s3.v}; - } 
- return fp_cond_sub_q(r); -} - -inline Fp fp_square(const Fp &a) { return fp_mul(a, a); } - -inline Pt pt_identity() { return Pt{fp_zero(), fp_one(), fp_one()}; } - -inline Pt pt_add(const Pt &p1, const Pt &p2) { - Fp d_const = fp_curve_d(); - Fp a_const = fp_curve_a(); - - Fp A = fp_mul(p1.Z, p2.Z); - Fp B = fp_square(A); - Fp C = fp_mul(p1.X, p2.X); - Fp D = fp_mul(p1.Y, p2.Y); - Fp E = fp_mul(d_const, fp_mul(C, D)); - Fp F = fp_sub(B, E); - Fp G = fp_add(B, E); - Fp H = fp_add(p1.X, p1.Y); - Fp I = fp_add(p2.X, p2.Y); - - Fp t = fp_mul(H, I); - t = fp_sub(t, C); - t = fp_sub(t, D); - t = fp_mul(t, A); - Fp X3 = fp_mul(t, F); - - Fp aC = fp_mul(a_const, C); - Fp t2 = fp_sub(D, aC); - t2 = fp_mul(t2, A); - Fp Y3 = fp_mul(t2, G); - - Fp Z3 = fp_mul(F, G); - return Pt{X3, Y3, Z3}; -} - -inline Pt pt_double(const Pt &p) { - Fp a_const = fp_curve_a(); - Fp XY = fp_add(p.X, p.Y); - Fp B = fp_square(XY); - Fp C = fp_square(p.X); - Fp D = fp_square(p.Y); - Fp E = fp_mul(a_const, C); - Fp F = fp_add(E, D); - Fp H = fp_square(p.Z); - Fp twoH = fp_add(H, H); - Fp J = fp_sub(F, twoH); - - Fp t = fp_sub(B, C); - t = fp_sub(t, D); - Fp X3 = fp_mul(t, J); - Fp Y3 = fp_mul(F, fp_sub(E, D)); - Fp Z3 = fp_mul(F, J); - return Pt{X3, Y3, Z3}; -} - -inline Pt pt_cmov_select(const Pt &dst, const Pt &src, uint32_t mask) { - auto sel = [mask](U64 a, U64 b) -> U64 { - return U64{(a.lo & ~mask) | (b.lo & mask), - (a.hi & ~mask) | (b.hi & mask)}; - }; - return Pt{ - Fp{sel(dst.X.l0, src.X.l0), sel(dst.X.l1, src.X.l1), - sel(dst.X.l2, src.X.l2), sel(dst.X.l3, src.X.l3)}, - Fp{sel(dst.Y.l0, src.Y.l0), sel(dst.Y.l1, src.Y.l1), - sel(dst.Y.l2, src.Y.l2), sel(dst.Y.l3, src.Y.l3)}, - Fp{sel(dst.Z.l0, src.Z.l0), sel(dst.Z.l1, src.Z.l1), - sel(dst.Z.l2, src.Z.l2), sel(dst.Z.l3, src.Z.l3)}, - }; -} - -inline Pt pt_scalar_mul(const Pt &p, const std::uint8_t s_le[32]) { - Pt acc = pt_identity(); - Pt base = p; - for (int byte_idx = 0; byte_idx < 32; ++byte_idx) { - std::uint8_t b = s_le[byte_idx]; - 
for (int bit = 0; bit < 8; ++bit) { - uint32_t one_or_zero = (uint32_t)((b >> bit) & 1u); - uint32_t mask = (one_or_zero == 1u) ? 0xffffffffu : 0u; - Pt sum = pt_add(acc, base); - acc = pt_cmov_select(acc, sum, mask); - base = pt_double(base); - } - } - return acc; -} - -inline U64 read_u64_le(const std::uint8_t *p) { - uint32_t lo = (uint32_t)p[0] - | ((uint32_t)p[1] << 8) - | ((uint32_t)p[2] << 16) - | ((uint32_t)p[3] << 24); - uint32_t hi = (uint32_t)p[4] - | ((uint32_t)p[5] << 8) - | ((uint32_t)p[6] << 16) - | ((uint32_t)p[7] << 24); - return U64{lo, hi}; -} - -inline void write_u64_le(std::uint8_t *p, U64 v) { - p[0] = (std::uint8_t)(v.lo); - p[1] = (std::uint8_t)(v.lo >> 8); - p[2] = (std::uint8_t)(v.lo >> 16); - p[3] = (std::uint8_t)(v.lo >> 24); - p[4] = (std::uint8_t)(v.hi); - p[5] = (std::uint8_t)(v.hi >> 8); - p[6] = (std::uint8_t)(v.hi >> 16); - p[7] = (std::uint8_t)(v.hi >> 24); -} - -inline Fp read_fp(const std::uint8_t *p) { - return Fp{read_u64_le(p), read_u64_le(p + 8), - read_u64_le(p + 16), read_u64_le(p + 24)}; -} - -inline void write_fp(std::uint8_t *p, const Fp &x) { - write_u64_le(p, x.l0); - write_u64_le(p + 8, x.l1); - write_u64_le(p + 16, x.l2); - write_u64_le(p + 24, x.l3); -} - -inline Pt read_pt(const std::uint8_t *p) { - return Pt{read_fp(p), read_fp(p + 32), read_fp(p + 64)}; -} - -inline void write_pt(std::uint8_t *p, const Pt &v) { - write_fp(p, v.X); - write_fp(p + 32, v.Y); - write_fp(p + 64, v.Z); -} - -} // namespace - -extern "C" int banderwagon_wgsl_add_batch(const std::uint8_t *pairs, - std::uint8_t *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pairs || !outs) return -1; - for (unsigned long i = 0; i < n; ++i) { - Pt P = read_pt(pairs + i * 192); - Pt Q = read_pt(pairs + i * 192 + 96); - Pt R = pt_add(P, Q); - write_pt(outs + i * 96, R); - } - return 0; -} - -extern "C" int banderwagon_wgsl_double_batch(const std::uint8_t *pts, - std::uint8_t *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pts || 
!outs) return -1; - for (unsigned long i = 0; i < n; ++i) { - Pt P = read_pt(pts + i * 96); - Pt R = pt_double(P); - write_pt(outs + i * 96, R); - } - return 0; -} - -extern "C" int banderwagon_wgsl_smul_batch(const std::uint8_t *pts, - const std::uint8_t *scalars, - std::uint8_t *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pts || !scalars || !outs) return -1; - for (unsigned long i = 0; i < n; ++i) { - Pt P = read_pt(pts + i * 96); - Pt R = pt_scalar_mul(P, scalars + i * 32); - write_pt(outs + i * 96, R); - } - return 0; -} - -extern "C" int banderwagon_wgsl_msm_batch(const std::uint8_t *pts, - const std::uint8_t *scalars, - std::uint8_t *outs, - unsigned long n, - unsigned long M) { - if (n == 0 || M == 0) return 0; - if (!pts || !scalars || !outs) return -1; - for (unsigned long b = 0; b < M; ++b) { - Pt acc = pt_identity(); - for (unsigned long i = 0; i < n; ++i) { - Pt P = read_pt(pts + i * 96); - Pt term = pt_scalar_mul(P, scalars + (b * n + i) * 32); - acc = pt_add(acc, term); - } - write_pt(outs + b * 96, acc); - } - return 0; -} diff --git a/banderwagon/gpu/wgsl/banderwagon_driver.h b/banderwagon/gpu/wgsl/banderwagon_driver.h deleted file mode 100644 index 13e6703..0000000 --- a/banderwagon/gpu/wgsl/banderwagon_driver.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL driver for Banderwagon group ops. Runs the algorithm via a C++ host -// polyfill that emulates the kernel's u32-only arithmetic (WGSL has no native -// u64). Same encoding as the Metal/CUDA drivers: 96-byte Pt, 32-byte LE Fr. 
- -#ifndef LUX_BANDERWAGON_WGSL_DRIVER_H -#define LUX_BANDERWAGON_WGSL_DRIVER_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -int banderwagon_wgsl_add_batch(const uint8_t *pairs, - uint8_t *outs, - unsigned long n); - -int banderwagon_wgsl_double_batch(const uint8_t *pts, - uint8_t *outs, - unsigned long n); - -int banderwagon_wgsl_smul_batch(const uint8_t *pts, - const uint8_t *scalars, - uint8_t *outs, - unsigned long n); - -int banderwagon_wgsl_msm_batch(const uint8_t *pts, - const uint8_t *scalars, - uint8_t *outs, - unsigned long n, - unsigned long M); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/blake2b/gpu/cuda/blake2b.cu b/blake2b/gpu/cuda/blake2b.cu deleted file mode 100644 index 034f807..0000000 --- a/blake2b/gpu/cuda/blake2b.cu +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLAKE2b-512 batch hashing — CUDA implementation (RFC 7693). -// Byte-equal to blake2b/c-abi/blake2b_full.cpp::hash() and to -// blake2b/gpu/metal/blake2b_batch.metal::blake2b_jobs. -// -// Algorithm: 12 rounds, 128-byte blocks, 64-byte digest. Counter `t` -// increments by the actual amount consumed (RFC 7693 sec 3.3). Final block -// flag inverts v[14]. SIGMA permutation cycles every 10 rounds (so rounds -// 10 and 11 reuse SIGMA[0] and SIGMA[1]). -// -// One thread per input. Layout matches the Metal/SHA-256/RIPEMD-160 drivers: -// caller fills a flat byte arena and per-input (offset, length) descriptors; -// outputs are 64-byte stride. -// -// When CRYPTO_ENABLE_CUDA=ON this file is fed to nvcc and exposes -// `blake2b_jobs` as a real __global__ kernel. When CUDA is off (default) the -// same file compiles as host C++ via the `__CUDA_ARCH__` shim and exposes -// `blake2b_batch_cuda_host` so the determinism test still runs 100/100 on -// non-CUDA hosts. The kernel body is shared — byte-equal by construction. 
- -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// IV (RFC 7693 sec 2.6 — same as SHA-512 IV). -// ============================================================================= -__device__ static const uint64_t IV[8] = { - 0x6A09E667F3BCC908ULL, 0xBB67AE8584CAA73BULL, - 0x3C6EF372FE94F82BULL, 0xA54FF53A5F1D36F1ULL, - 0x510E527FADE682D1ULL, 0x9B05688C2B3E6C1FULL, - 0x1F83D9ABFB41BD6BULL, 0x5BE0CD19137E2179ULL, -}; - -// ============================================================================= -// SIGMA permutation (RFC 7693 sec 2.7). Round i uses SIGMA[i % 10]. -// ============================================================================= -__device__ static const uint8_t SIGMA[10][16] = { - { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, - { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, - { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 }, - { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 }, - { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 }, - { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 }, - { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 }, - { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 }, - { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 }, - { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 }, -}; - -// ============================================================================= -// Mixing function G (RFC 7693 sec 3.1). 
-// ============================================================================= -__device__ static inline uint64_t rotr64(uint64_t x, unsigned n) { - return (x >> n) | (x << (64u - n)); -} - -__device__ static void g_mix(uint64_t v[16], - unsigned a, unsigned b, unsigned c, unsigned d, - uint64_t x, uint64_t y) { - v[a] = v[a] + v[b] + x; - v[d] = rotr64(v[d] ^ v[a], 32); - v[c] = v[c] + v[d]; - v[b] = rotr64(v[b] ^ v[c], 24); - v[a] = v[a] + v[b] + y; - v[d] = rotr64(v[d] ^ v[a], 16); - v[c] = v[c] + v[d]; - v[b] = rotr64(v[b] ^ v[c], 63); -} - -// ============================================================================= -// Compression function F (RFC 7693 sec 3.2). 12 rounds. -// ============================================================================= -__device__ static void blake2b_compress(uint64_t h[8], const uint64_t m[16], - uint64_t t0, uint64_t t1, bool last) { - uint64_t v[16]; - for (int i = 0; i < 8; ++i) v[i] = h[i]; - for (int i = 0; i < 8; ++i) v[i + 8] = IV[i]; - v[12] ^= t0; - v[13] ^= t1; - if (last) v[14] = ~v[14]; - - for (int r = 0; r < 12; ++r) { - const uint8_t* s = SIGMA[r % 10]; - g_mix(v, 0, 4, 8, 12, m[s[ 0]], m[s[ 1]]); - g_mix(v, 1, 5, 9, 13, m[s[ 2]], m[s[ 3]]); - g_mix(v, 2, 6, 10, 14, m[s[ 4]], m[s[ 5]]); - g_mix(v, 3, 7, 11, 15, m[s[ 6]], m[s[ 7]]); - g_mix(v, 0, 5, 10, 15, m[s[ 8]], m[s[ 9]]); - g_mix(v, 1, 6, 11, 12, m[s[10]], m[s[11]]); - g_mix(v, 2, 7, 8, 13, m[s[12]], m[s[13]]); - g_mix(v, 3, 4, 9, 14, m[s[14]], m[s[15]]); - } - for (int i = 0; i < 8; ++i) h[i] ^= v[i] ^ v[i + 8]; -} - -// ============================================================================= -// Kernel — one thread per input. 64-byte (= 8 × u64) digest per output slot. 
-// ============================================================================= -extern "C" __global__ void blake2b_jobs( - const uint8_t* __restrict__ inputs, - const uint32_t* __restrict__ input_offsets, - const uint32_t* __restrict__ input_lens, - uint8_t* __restrict__ outputs, - uint32_t num_jobs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_jobs) return; - - const uint8_t* in = inputs + input_offsets[tid]; - uint8_t* out = outputs + tid * 64u; - uint32_t len = input_lens[tid]; - - // Param block (RFC 7693 sec 2.5): digest_len=64, key_len=0, fanout=1, - // depth=1. h[0] ^= 0x01010040. - uint64_t h[8]; - for (int i = 0; i < 8; ++i) h[i] = IV[i]; - h[0] ^= 0x0000000001010040ULL; - - uint64_t t0 = 0, t1 = 0; - uint64_t m[16]; - uint8_t block[128]; - uint32_t pos = 0; - - // Stream all but the final block as non-last. - while ((len - pos) > 128u) { - for (int i = 0; i < 128; ++i) block[i] = in[pos + i]; - for (int i = 0; i < 16; ++i) { - uint64_t w = 0; - for (int b = 0; b < 8; ++b) - w |= (uint64_t)block[i * 8 + b] << (b * 8); - m[i] = w; - } - pos += 128u; - uint64_t t0_new = t0 + 128u; - if (t0_new < t0) ++t1; - t0 = t0_new; - blake2b_compress(h, m, t0, t1, false); - } - - // Final (possibly partial) block, zero-padded. - uint32_t rem = len - pos; - for (int i = 0; i < 128; ++i) block[i] = 0; - for (uint32_t i = 0; i < rem; ++i) block[i] = in[pos + i]; - for (int i = 0; i < 16; ++i) { - uint64_t w = 0; - for (int b = 0; b < 8; ++b) - w |= (uint64_t)block[i * 8 + b] << (b * 8); - m[i] = w; - } - uint64_t t0_new = t0 + (uint64_t)rem; - if (t0_new < t0) ++t1; - t0 = t0_new; - blake2b_compress(h, m, t0, t1, true); - - // Little-endian state -> output digest. - for (int i = 0; i < 8; ++i) { - for (int b = 0; b < 8; ++b) { - out[i * 8 + b] = (uint8_t)((h[i] >> (b * 8)) & 0xFFULL); - } - } -} - -// ============================================================================= -// Host-emulation entry. 
When this TU is compiled as plain C++ (CUDA disabled -// or unavailable on the build host) we replay the kernel sequentially per -// thread index. Same code, no GPU. Used by the determinism test to prove -// byte-equality with the CPU oracle on every host. -// ============================================================================= -#ifndef __CUDA_ARCH__ -extern "C" int blake2b_batch_cuda_host( - const uint8_t* data, - const uint32_t* offsets, - const uint32_t* lengths, - uint8_t* outputs, - uint32_t num_inputs) -{ - if (num_inputs == 0) return 0; - if (!data || !offsets || !lengths || !outputs) return -1; - - for (uint32_t tid = 0; tid < num_inputs; ++tid) { - blockIdx.x = tid; - blockDim.x = 1; - threadIdx.x = 0; - blake2b_jobs(data, offsets, lengths, outputs, num_inputs); - } - return 0; -} -#endif diff --git a/blake2b/gpu/metal/blake2b_batch.metal b/blake2b/gpu/metal/blake2b_batch.metal deleted file mode 100644 index f5d2442..0000000 --- a/blake2b/gpu/metal/blake2b_batch.metal +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched BLAKE2b-512 (RFC 7693). One thread per input. Byte-equal to -// blake2b/c-abi/blake2b_full.cpp::hash() — no key, no salt, no personalisation. -// -// 12 rounds, 128-byte blocks, 64-byte digest. Counter `t` increments by -// the actual amount consumed (per RFC 7693 sec 3.3). Final block flag set -// on the last call. 
- -#include -using namespace metal; - -constant ulong IV[8] = { - 0x6a09e667f3bcc908UL, 0xbb67ae8584caa73bUL, - 0x3c6ef372fe94f82bUL, 0xa54ff53a5f1d36f1UL, - 0x510e527fade682d1UL, 0x9b05688c2b3e6c1fUL, - 0x1f83d9abfb41bd6bUL, 0x5be0cd19137e2179UL, -}; - -constant uint8_t SIGMA[10][16] = { - { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, - { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, - { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 }, - { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 }, - { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 }, - { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 }, - { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 }, - { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 }, - { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 }, - { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 }, -}; - -inline ulong rotr64(ulong x, uint n) { return (x >> n) | (x << (64u - n)); } - -inline void g_mix(thread ulong* v, uint a, uint b, uint c, uint d, - ulong x, ulong y) { - v[a] = v[a] + v[b] + x; - v[d] = rotr64(v[d] ^ v[a], 32); - v[c] = v[c] + v[d]; - v[b] = rotr64(v[b] ^ v[c], 24); - v[a] = v[a] + v[b] + y; - v[d] = rotr64(v[d] ^ v[a], 16); - v[c] = v[c] + v[d]; - v[b] = rotr64(v[b] ^ v[c], 63); -} - -inline void blake2b_compress(thread ulong* h, thread const ulong* m, - ulong t0, ulong t1, bool last) { - ulong v[16]; - for (uint i = 0; i < 8; ++i) v[i] = h[i]; - for (uint i = 0; i < 8; ++i) v[i + 8] = IV[i]; - v[12] ^= t0; - v[13] ^= t1; - if (last) v[14] = ~v[14]; - - for (uint i = 0; i < 12; ++i) { - const constant uint8_t* s = SIGMA[i % 10]; - g_mix(v, 0, 4, 8, 12, m[s[ 0]], m[s[ 1]]); - g_mix(v, 1, 5, 9, 13, m[s[ 2]], m[s[ 3]]); - g_mix(v, 2, 6, 10, 14, m[s[ 4]], m[s[ 5]]); - g_mix(v, 3, 7, 11, 15, m[s[ 6]], m[s[ 7]]); - g_mix(v, 0, 5, 10, 15, m[s[ 8]], m[s[ 9]]); - g_mix(v, 1, 6, 11, 12, m[s[10]], m[s[11]]); - g_mix(v, 2, 7, 8, 13, m[s[12]], m[s[13]]); - g_mix(v, 3, 4, 9, 14, 
m[s[14]], m[s[15]]); - } - - for (uint i = 0; i < 8; ++i) h[i] ^= v[i] ^ v[i + 8]; -} - -struct Blake2bJobGPU { - uint input_offset; - uint input_len; - uint output_offset; - uint _pad; -}; - -kernel void blake2b_jobs( - device const Blake2bJobGPU* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_jobs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_jobs) return; - - Blake2bJobGPU j = jobs[tid]; - const device uchar* in = inputs + j.input_offset; - device uchar* out = outputs + j.output_offset; - - // Param block sec 2.5: digest_len=64, key_len=0, fanout=1, depth=1. - // h[0] ^= 0x01010040 (= 0x01010000 ^ 64). - ulong h[8]; - for (uint i = 0; i < 8; ++i) h[i] = IV[i]; - h[0] ^= 0x0000000001010040UL; - - // Stream: process all but the final block as non-last. - ulong t0 = 0, t1 = 0; - ulong m[16]; - uchar block[128]; - uint pos = 0; - - while ((j.input_len - pos) > 128u) { - for (uint i = 0; i < 128; ++i) block[i] = in[pos + i]; - for (uint i = 0; i < 16; ++i) { - ulong w = 0; - for (uint b = 0; b < 8; ++b) - w |= ulong(block[i * 8 + b]) << (b * 8); - m[i] = w; - } - pos += 128; - ulong t0_new = t0 + 128; - if (t0_new < t0) ++t1; - t0 = t0_new; - blake2b_compress(h, m, t0, t1, false); - } - - // Final (possibly partial) block, zero-padded. - uint rem = j.input_len - pos; - for (uint i = 0; i < 128; ++i) block[i] = 0; - for (uint i = 0; i < rem; ++i) block[i] = in[pos + i]; - for (uint i = 0; i < 16; ++i) { - ulong w = 0; - for (uint b = 0; b < 8; ++b) - w |= ulong(block[i * 8 + b]) << (b * 8); - m[i] = w; - } - ulong t0_new = t0 + (ulong)rem; - if (t0_new < t0) ++t1; - t0 = t0_new; - blake2b_compress(h, m, t0, t1, true); - - // Little-endian output. 
- for (uint i = 0; i < 8; ++i) { - for (uint b = 0; b < 8; ++b) { - out[i * 8 + b] = uchar((h[i] >> (b * 8)) & 0xFF); - } - } -} diff --git a/blake2b/gpu/metal/blake2b_batch_driver.mm b/blake2b/gpu/metal/blake2b_batch_driver.mm deleted file mode 100644 index 1310183..0000000 --- a/blake2b/gpu/metal/blake2b_batch_driver.mm +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched BLAKE2b-512 (RFC 7693). macOS / iOS only. -// Loads blake2b_batch.metallib, dispatches `blake2b_jobs` with one thread -// per input. Byte-equal to blake2b/c-abi/blake2b_full.cpp::hash(). - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include -#include - -namespace { - -struct Blake2bJobGPU { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t _pad; -}; - -} // namespace - -extern "C" int blake2b_batch_metal( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena, - const char* metallib_path) { - - if (n == 0) return 0; - if (!inputs_arena || !input_offsets || !input_lens || !outputs_arena || - !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"blake2b_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - std::vector jobs(n); - for (size_t i = 0; i < n; ++i) { - jobs[i].input_offset = input_offsets[i]; - jobs[i].input_len = input_lens[i]; - jobs[i].output_offset = (uint32_t)(i * 64); - 
jobs[i]._pad = 0; - } - - id jobs_buf = [device newBufferWithBytes:jobs.data() - length:jobs.size() * sizeof(Blake2bJobGPU) - options:MTLResourceStorageModeShared]; - id inputs_buf = [device newBufferWithBytes:inputs_arena - length:inputs_arena_len - options:MTLResourceStorageModeShared]; - id outputs_buf = [device newBufferWithLength:n * 64 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:inputs_buf offset:0 atIndex:1]; - [enc setBuffer:outputs_buf offset:0 atIndex:2]; - [enc setBuffer:n_buf offset:0 atIndex:3]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs_arena, [outputs_buf contents], n * 64); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/blake2b/gpu/wgsl/blake2b.wgsl b/blake2b/gpu/wgsl/blake2b.wgsl deleted file mode 100644 index edc2381..0000000 --- a/blake2b/gpu/wgsl/blake2b.wgsl +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLAKE2b-512 (RFC 7693) compute shader in WGSL. One thread per input. -// Byte-equal to blake2b/c-abi/blake2b_full.cpp::hash() and to -// blake2b/gpu/cuda/blake2b.cu and blake2b/gpu/metal/blake2b_batch.metal. -// -// WGSL has no native u64. Each 64-bit word is emulated as a vec2 -// (.x = lo, .y = hi). 
All BLAKE2b operations are expressed against this -// pair representation: -// -// xor64 — lane-wise xor -// add64 — full 64-bit add (carry from lo via unsigned wrap detect) -// rotr64(n) — four cases: n==0 (identity), n==32 (lane swap), -// n in (0,32) inner funnel, n in (32,64) outer funnel. -// -// 12 rounds, 128-byte block, little-endian message words, final-block flag -// inverts v[14]. Output is 64 bytes little-endian (8 × u64 lanes packed -// into 16 × u32 entries in the outputs[] storage buffer). - -struct HashInput { - offset: u32, - length: u32, -} - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var data: array; -@group(0) @binding(2) var outputs: array; - -// ============================================================================= -// IV (RFC 7693 sec 2.6) split into (lo, hi) u32 pairs. -// ============================================================================= -const IV_LO = array( - 0xF3BCC908u, 0x84CAA73Bu, 0xFE94F82Bu, 0x5F1D36F1u, - 0xADE682D1u, 0x2B3E6C1Fu, 0xFB41BD6Bu, 0x137E2179u, -); -const IV_HI = array( - 0x6A09E667u, 0xBB67AE85u, 0x3C6EF372u, 0xA54FF53Au, - 0x510E527Fu, 0x9B05688Cu, 0x1F83D9ABu, 0x5BE0CD19u, -); - -// SIGMA permutation (RFC 7693 sec 2.7). 
-const SIGMA = array, 10>( - array( 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u), - array(14u, 10u, 4u, 8u, 9u, 15u, 13u, 6u, 1u, 12u, 0u, 2u, 11u, 7u, 5u, 3u), - array(11u, 8u, 12u, 0u, 5u, 2u, 15u, 13u, 10u, 14u, 3u, 6u, 7u, 1u, 9u, 4u), - array( 7u, 9u, 3u, 1u, 13u, 12u, 11u, 14u, 2u, 6u, 5u, 10u, 4u, 0u, 15u, 8u), - array( 9u, 0u, 5u, 7u, 2u, 4u, 10u, 15u, 14u, 1u, 11u, 12u, 6u, 8u, 3u, 13u), - array( 2u, 12u, 6u, 10u, 0u, 11u, 8u, 3u, 4u, 13u, 7u, 5u, 15u, 14u, 1u, 9u), - array(12u, 5u, 1u, 15u, 14u, 13u, 4u, 10u, 0u, 7u, 6u, 3u, 9u, 2u, 8u, 11u), - array(13u, 11u, 7u, 14u, 12u, 1u, 3u, 9u, 5u, 0u, 15u, 4u, 8u, 6u, 2u, 10u), - array( 6u, 15u, 14u, 9u, 11u, 3u, 0u, 8u, 12u, 2u, 13u, 7u, 1u, 4u, 10u, 5u), - array(10u, 2u, 8u, 4u, 7u, 6u, 1u, 5u, 15u, 11u, 9u, 14u, 3u, 12u, 13u, 0u), -); - -// ============================================================================= -// 64-bit emulation. Each "u64" lives in a vec2 (lo, hi). -// ============================================================================= - -fn xor64(a: vec2, b: vec2) -> vec2 { - return vec2(a.x ^ b.x, a.y ^ b.y); -} - -fn add64(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - let carry = select(0u, 1u, lo < a.x); // unsigned wrap detect - let hi = a.y + b.y + carry; - return vec2(lo, hi); -} - -fn rotr64(v: vec2, n: u32) -> vec2 { - if (n == 0u) { return v; } - if (n == 32u) { return vec2(v.y, v.x); } - if (n < 32u) { - let lo = (v.x >> n) | (v.y << (32u - n)); - let hi = (v.y >> n) | (v.x << (32u - n)); - return vec2(lo, hi); - } - let m = n - 32u; - let lo = (v.y >> m) | (v.x << (32u - m)); - let hi = (v.x >> m) | (v.y << (32u - m)); - return vec2(lo, hi); -} - -fn not64(a: vec2) -> vec2 { - return vec2(~a.x, ~a.y); -} - -// ============================================================================= -// Working state — kept in private storage so dynamic indexing is legal. 
-// ============================================================================= -var v_lo: array; -var v_hi: array; -var h_lo: array; -var h_hi: array; -var m_lo: array; -var m_hi: array; - -fn g_mix(a: u32, b: u32, c: u32, d: u32, mx: u32, my: u32) { - var va = vec2(v_lo[a], v_hi[a]); - var vb = vec2(v_lo[b], v_hi[b]); - var vc = vec2(v_lo[c], v_hi[c]); - var vd = vec2(v_lo[d], v_hi[d]); - let mxv = vec2(m_lo[mx], m_hi[mx]); - let myv = vec2(m_lo[my], m_hi[my]); - - va = add64(add64(va, vb), mxv); - vd = rotr64(xor64(vd, va), 32u); - vc = add64(vc, vd); - vb = rotr64(xor64(vb, vc), 24u); - va = add64(add64(va, vb), myv); - vd = rotr64(xor64(vd, va), 16u); - vc = add64(vc, vd); - vb = rotr64(xor64(vb, vc), 63u); - - v_lo[a] = va.x; v_hi[a] = va.y; - v_lo[b] = vb.x; v_hi[b] = vb.y; - v_lo[c] = vc.x; v_hi[c] = vc.y; - v_lo[d] = vd.x; v_hi[d] = vd.y; -} - -fn compress(t0: vec2, t1: vec2, last_block: bool) { - for (var i = 0u; i < 8u; i = i + 1u) { - v_lo[i] = h_lo[i]; - v_hi[i] = h_hi[i]; - } - for (var i = 0u; i < 8u; i = i + 1u) { - v_lo[i + 8u] = IV_LO[i]; - v_hi[i + 8u] = IV_HI[i]; - } - v_lo[12] = v_lo[12] ^ t0.x; v_hi[12] = v_hi[12] ^ t0.y; - v_lo[13] = v_lo[13] ^ t1.x; v_hi[13] = v_hi[13] ^ t1.y; - if (last_block) { - v_lo[14] = ~v_lo[14]; - v_hi[14] = ~v_hi[14]; - } - - for (var r = 0u; r < 12u; r = r + 1u) { - let s = SIGMA[r % 10u]; - g_mix(0u, 4u, 8u, 12u, s[0u], s[1u]); - g_mix(1u, 5u, 9u, 13u, s[2u], s[3u]); - g_mix(2u, 6u, 10u, 14u, s[4u], s[5u]); - g_mix(3u, 7u, 11u, 15u, s[6u], s[7u]); - g_mix(0u, 5u, 10u, 15u, s[8u], s[9u]); - g_mix(1u, 6u, 11u, 12u, s[10u], s[11u]); - g_mix(2u, 7u, 8u, 13u, s[12u], s[13u]); - g_mix(3u, 4u, 9u, 14u, s[14u], s[15u]); - } - - for (var i = 0u; i < 8u; i = i + 1u) { - h_lo[i] = h_lo[i] ^ v_lo[i] ^ v_lo[i + 8u]; - h_hi[i] = h_hi[i] ^ v_hi[i] ^ v_hi[i + 8u]; - } -} - -// Read a single byte from the packed-u32 input arena (little-endian). 
-fn read_byte(byte_offset: u32) -> u32 { - let word_idx = byte_offset >> 2u; - let byte_pos = byte_offset & 3u; - return (data[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -@compute @workgroup_size(64) -fn blake2b_jobs(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= arrayLength(&inputs)) { return; } - - let inp = inputs[tid]; - let offset = inp.offset; - let len = inp.length; - - // Param block (sec 2.5): digest_len=64, key_len=0, fanout=1, depth=1. - // h[0] ^= 0x0000000001010040. - for (var i = 0u; i < 8u; i = i + 1u) { - h_lo[i] = IV_LO[i]; - h_hi[i] = IV_HI[i]; - } - h_lo[0] = h_lo[0] ^ 0x01010040u; - - var t0 = vec2(0u, 0u); - var t1 = vec2(0u, 0u); - var pos = 0u; - - // Stream all but the final block as non-last. - loop { - if (len - pos <= 128u) { break; } - - for (var w = 0u; w < 16u; w = w + 1u) { - var lo = 0u; - var hi = 0u; - for (var b = 0u; b < 4u; b = b + 1u) { - lo = lo | (read_byte(offset + pos + w * 8u + b) << (b * 8u)); - } - for (var b = 0u; b < 4u; b = b + 1u) { - hi = hi | (read_byte(offset + pos + w * 8u + 4u + b) << (b * 8u)); - } - m_lo[w] = lo; - m_hi[w] = hi; - } - - pos = pos + 128u; - let t0_new = add64(t0, vec2(128u, 0u)); - if (t0_new.y < t0.y) { t1 = add64(t1, vec2(1u, 0u)); } - t0 = t0_new; - compress(t0, t1, false); - } - - // Final (possibly partial) block, zero-padded. - let rem = len - pos; - for (var w = 0u; w < 16u; w = w + 1u) { - m_lo[w] = 0u; - m_hi[w] = 0u; - } - for (var i = 0u; i < rem; i = i + 1u) { - let byte_val = read_byte(offset + pos + i); - let word_idx = i >> 3u; - let byte_in_word = i & 7u; - if (byte_in_word < 4u) { - m_lo[word_idx] = m_lo[word_idx] | (byte_val << (byte_in_word * 8u)); - } else { - m_hi[word_idx] = m_hi[word_idx] | (byte_val << ((byte_in_word - 4u) * 8u)); - } - } - let t0_new = add64(t0, vec2(rem, 0u)); - if (t0_new.y < t0.y) { t1 = add64(t1, vec2(1u, 0u)); } - t0 = t0_new; - compress(t0, t1, true); - - // Output 64 bytes little-endian. 
outputs[] is u32-packed: each h[i] - // becomes 2 consecutive u32 lanes (lo, hi). - let out_base = tid * 16u; - for (var i = 0u; i < 8u; i = i + 1u) { - outputs[out_base + i * 2u] = h_lo[i]; - outputs[out_base + i * 2u + 1u] = h_hi[i]; - } -} diff --git a/blake2b/gpu/wgsl/blake2b_wgsl_host.cpp b/blake2b/gpu/wgsl/blake2b_wgsl_host.cpp deleted file mode 100644 index efcd358..0000000 --- a/blake2b/gpu/wgsl/blake2b_wgsl_host.cpp +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-emulation translation of blake2b.wgsl. Each operation here mirrors -// the WGSL kernel one-for-one — same constants, same control flow, same -// vec2-emulated u64 arithmetic — so byte-equality with the CPU oracle -// proves the WGSL kernel will produce identical output on a real GPU. -// -// The wire format matches the kernel's bind group: data is uploaded as -// packed u32 (little-endian), inputs are HashInput descriptors, outputs -// are 16 u32 lanes per 64-byte digest. When a real wgpu host driver lands -// the body of `blake2b_batch_wgsl_host` is replaced by the dispatch path; -// the harness side stays unchanged. - -#include -#include -#include - -namespace { - -// ============================================================================= -// IV (RFC 7693 sec 2.6) split (lo, hi). 
-// ============================================================================= -constexpr uint32_t IV_LO[8] = { - 0xF3BCC908u, 0x84CAA73Bu, 0xFE94F82Bu, 0x5F1D36F1u, - 0xADE682D1u, 0x2B3E6C1Fu, 0xFB41BD6Bu, 0x137E2179u, -}; -constexpr uint32_t IV_HI[8] = { - 0x6A09E667u, 0xBB67AE85u, 0x3C6EF372u, 0xA54FF53Au, - 0x510E527Fu, 0x9B05688Cu, 0x1F83D9ABu, 0x5BE0CD19u, -}; - -constexpr uint8_t SIGMA[10][16] = { - { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, - { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, - { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 }, - { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 }, - { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 }, - { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 }, - { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 }, - { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 }, - { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 }, - { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 }, -}; - -// ============================================================================= -// 64-bit emulation primitives — identical to the WGSL helpers in blake2b.wgsl. -// ============================================================================= -struct U64 { uint32_t lo; uint32_t hi; }; - -inline U64 xor64(U64 a, U64 b) { return U64{a.lo ^ b.lo, a.hi ^ b.hi}; } - -inline U64 add64(U64 a, U64 b) { - U64 r; - r.lo = a.lo + b.lo; - uint32_t carry = (r.lo < a.lo) ? 
1u : 0u; - r.hi = a.hi + b.hi + carry; - return r; -} - -inline U64 rotr64(U64 v, uint32_t n) { - if (n == 0u) return v; - if (n == 32u) return U64{v.hi, v.lo}; - if (n < 32u) { - U64 r; - r.lo = (v.lo >> n) | (v.hi << (32u - n)); - r.hi = (v.hi >> n) | (v.lo << (32u - n)); - return r; - } - uint32_t m = n - 32u; - U64 r; - r.lo = (v.hi >> m) | (v.lo << (32u - m)); - r.hi = (v.lo >> m) | (v.hi << (32u - m)); - return r; -} - -inline U64 not64(U64 a) { return U64{~a.lo, ~a.hi}; } - -inline uint32_t read_byte(const uint32_t* arena, uint32_t byte_offset) { - uint32_t word_idx = byte_offset >> 2u; - uint32_t byte_pos = byte_offset & 3u; - return (arena[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -struct State { - U64 v[16]; - U64 h[8]; - U64 m[16]; -}; - -inline void g_mix(State& s, uint32_t a, uint32_t b, uint32_t c, uint32_t d, - uint32_t mx, uint32_t my) { - s.v[a] = add64(add64(s.v[a], s.v[b]), s.m[mx]); - s.v[d] = rotr64(xor64(s.v[d], s.v[a]), 32u); - s.v[c] = add64(s.v[c], s.v[d]); - s.v[b] = rotr64(xor64(s.v[b], s.v[c]), 24u); - s.v[a] = add64(add64(s.v[a], s.v[b]), s.m[my]); - s.v[d] = rotr64(xor64(s.v[d], s.v[a]), 16u); - s.v[c] = add64(s.v[c], s.v[d]); - s.v[b] = rotr64(xor64(s.v[b], s.v[c]), 63u); -} - -inline void compress(State& s, U64 t0, U64 t1, bool last_block) { - for (int i = 0; i < 8; ++i) s.v[i] = s.h[i]; - for (int i = 0; i < 8; ++i) s.v[i + 8] = U64{IV_LO[i], IV_HI[i]}; - s.v[12] = xor64(s.v[12], t0); - s.v[13] = xor64(s.v[13], t1); - if (last_block) s.v[14] = not64(s.v[14]); - - for (int r = 0; r < 12; ++r) { - const uint8_t* sig = SIGMA[r % 10]; - g_mix(s, 0, 4, 8, 12, sig[ 0], sig[ 1]); - g_mix(s, 1, 5, 9, 13, sig[ 2], sig[ 3]); - g_mix(s, 2, 6, 10, 14, sig[ 4], sig[ 5]); - g_mix(s, 3, 7, 11, 15, sig[ 6], sig[ 7]); - g_mix(s, 0, 5, 10, 15, sig[ 8], sig[ 9]); - g_mix(s, 1, 6, 11, 12, sig[10], sig[11]); - g_mix(s, 2, 7, 8, 13, sig[12], sig[13]); - g_mix(s, 3, 4, 9, 14, sig[14], sig[15]); - } - for (int i = 0; i < 8; ++i) { - s.h[i] = 
xor64(s.h[i], xor64(s.v[i], s.v[i + 8])); - } -} - -} // namespace - -// ============================================================================= -// Public host entry — wire format matches the WGSL bind group: data is a -// flat byte arena (packed into u32 little-endian internally), per-input -// (offset, length) descriptors, and a 64-byte stride output buffer. -// ============================================================================= -extern "C" int blake2b_batch_wgsl_host( - const uint8_t* data, - size_t data_len, - const uint32_t* offsets, - const uint32_t* lengths, - uint8_t* outputs, - uint32_t num_inputs) -{ - if (num_inputs == 0) return 0; - if (!data || !offsets || !lengths || !outputs) return -1; - - // Pack into u32 words exactly the way wgpuQueueWriteBuffer would upload - // the byte arena into a `array` storage buffer. - size_t pad = (4 - (data_len & 3)) & 3; - size_t word_count = (data_len + pad) / 4; - if (word_count == 0) word_count = 1; - std::vector packed(word_count, 0u); - if (data_len > 0) std::memcpy(packed.data(), data, data_len); - - for (uint32_t tid = 0; tid < num_inputs; ++tid) { - uint32_t offset = offsets[tid]; - uint32_t len = lengths[tid]; - - State s{}; - for (int i = 0; i < 8; ++i) s.h[i] = U64{IV_LO[i], IV_HI[i]}; - s.h[0].lo ^= 0x01010040u; - - U64 t0{0u, 0u}; - U64 t1{0u, 0u}; - uint32_t pos = 0; - - while (len - pos > 128u) { - for (int w = 0; w < 16; ++w) { - uint32_t lo = 0, hi = 0; - for (int b = 0; b < 4; ++b) { - lo |= read_byte(packed.data(), offset + pos + w * 8u + b) << (b * 8u); - } - for (int b = 0; b < 4; ++b) { - hi |= read_byte(packed.data(), offset + pos + w * 8u + 4u + b) << (b * 8u); - } - s.m[w] = U64{lo, hi}; - } - pos += 128u; - U64 t0_new = add64(t0, U64{128u, 0u}); - if (t0_new.hi < t0.hi) t1 = add64(t1, U64{1u, 0u}); - t0 = t0_new; - compress(s, t0, t1, false); - } - - const uint32_t rem = len - pos; - for (int w = 0; w < 16; ++w) s.m[w] = U64{0u, 0u}; - for (uint32_t i = 0; i < rem; ++i) { - 
uint32_t byte_val = read_byte(packed.data(), offset + pos + i); - uint32_t word_idx = i >> 3u; - uint32_t byte_in_word = i & 7u; - if (byte_in_word < 4u) { - s.m[word_idx].lo |= byte_val << (byte_in_word * 8u); - } else { - s.m[word_idx].hi |= byte_val << ((byte_in_word - 4u) * 8u); - } - } - U64 t0_new = add64(t0, U64{rem, 0u}); - if (t0_new.hi < t0.hi) t1 = add64(t1, U64{1u, 0u}); - t0 = t0_new; - compress(s, t0, t1, true); - - // Emit 64 bytes little-endian: each h[i] = (lo, hi) -> 8 bytes. - uint8_t* out = outputs + tid * 64u; - for (int i = 0; i < 8; ++i) { - uint32_t lo = s.h[i].lo; - uint32_t hi = s.h[i].hi; - out[i * 8 + 0] = (uint8_t)(lo & 0xFFu); - out[i * 8 + 1] = (uint8_t)((lo >> 8) & 0xFFu); - out[i * 8 + 2] = (uint8_t)((lo >> 16) & 0xFFu); - out[i * 8 + 3] = (uint8_t)((lo >> 24) & 0xFFu); - out[i * 8 + 4] = (uint8_t)(hi & 0xFFu); - out[i * 8 + 5] = (uint8_t)((hi >> 8) & 0xFFu); - out[i * 8 + 6] = (uint8_t)((hi >> 16) & 0xFFu); - out[i * 8 + 7] = (uint8_t)((hi >> 24) & 0xFFu); - } - } - return 0; -} diff --git a/blake3/gpu/cuda/blake3.cu b/blake3/gpu/cuda/blake3.cu deleted file mode 100644 index 7c41934..0000000 --- a/blake3/gpu/cuda/blake3.cu +++ /dev/null @@ -1,315 +0,0 @@ -// BLAKE3 batch hash — CUDA implementation -// Matches blake3.metal output byte-for-byte -// One thread per hash - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// BLAKE3 constants -// ============================================================================= - -__device__ static const uint32_t BLAKE3_IV[8] = { - 0x6A09E667u, 0xBB67AE85u, 0x3C6EF372u, 0xA54FF53Au, - 0x510E527Fu, 0x9B05688Cu, 0x1F83D9ABu, 0x5BE0CD19u -}; - -__device__ static const uint32_t BLAKE3_CHUNK_START = 1u; -__device__ static const uint32_t BLAKE3_CHUNK_END = 2u; -__device__ static const 
uint32_t BLAKE3_ROOT = 8u; - -__device__ static const uint8_t BLAKE3_MSG_PERM[16] = { - 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 -}; - -// ============================================================================= -// BLAKE3 quarter-round G function -// ============================================================================= - -__device__ static inline uint32_t rotr32(uint32_t x, uint32_t n) { - return (x >> n) | (x << (32u - n)); -} - -__device__ static void blake3_g(uint32_t state[16], int a, int b, int c, int d, - uint32_t mx, uint32_t my) { - state[a] = state[a] + state[b] + mx; - state[d] = rotr32(state[d] ^ state[a], 16u); - state[c] = state[c] + state[d]; - state[b] = rotr32(state[b] ^ state[c], 12u); - state[a] = state[a] + state[b] + my; - state[d] = rotr32(state[d] ^ state[a], 8u); - state[c] = state[c] + state[d]; - state[b] = rotr32(state[b] ^ state[c], 7u); -} - -// ============================================================================= -// BLAKE3 round (column + diagonal) -// ============================================================================= - -__device__ static void blake3_round(uint32_t state[16], const uint32_t m[16]) { - // Columns - blake3_g(state, 0, 4, 8, 12, m[0], m[1]); - blake3_g(state, 1, 5, 9, 13, m[2], m[3]); - blake3_g(state, 2, 6, 10, 14, m[4], m[5]); - blake3_g(state, 3, 7, 11, 15, m[6], m[7]); - // Diagonals - blake3_g(state, 0, 5, 10, 15, m[8], m[9]); - blake3_g(state, 1, 6, 11, 12, m[10], m[11]); - blake3_g(state, 2, 7, 8, 13, m[12], m[13]); - blake3_g(state, 3, 4, 9, 14, m[14], m[15]); -} - -// ============================================================================= -// BLAKE3 compression function -// ============================================================================= - -__device__ static void blake3_compress(const uint32_t cv[8], - const uint8_t block[64], - uint64_t counter, - uint32_t block_len, - uint32_t flags, - uint32_t out[8]) { - // Load message words (little-endian) - 
uint32_t m[16]; - for (int i = 0; i < 16; i++) { - m[i] = (uint32_t)block[i * 4] - | ((uint32_t)block[i * 4 + 1] << 8) - | ((uint32_t)block[i * 4 + 2] << 16) - | ((uint32_t)block[i * 4 + 3] << 24); - } - - uint32_t state[16] = { - cv[0], cv[1], cv[2], cv[3], - cv[4], cv[5], cv[6], cv[7], - BLAKE3_IV[0], BLAKE3_IV[1], BLAKE3_IV[2], BLAKE3_IV[3], - (uint32_t)(counter & 0xFFFFFFFFu), - (uint32_t)(counter >> 32), - block_len, - flags - }; - - // 7 rounds with message permutation - for (int round = 0; round < 7; round++) { - blake3_round(state, m); - // Permute message words - uint32_t tmp[16]; - for (int i = 0; i < 16; i++) { - tmp[i] = m[BLAKE3_MSG_PERM[i]]; - } - for (int i = 0; i < 16; i++) m[i] = tmp[i]; - } - - // Output: state[0..7] XOR state[8..15] - for (int i = 0; i < 8; i++) { - out[i] = state[i] ^ state[i + 8]; - } -} - -// ============================================================================= -// Input descriptor -// ============================================================================= - -struct HashInput { - uint32_t offset; - uint32_t length; -}; - -// ============================================================================= -// Kernel: blake3_hash_batch -// ============================================================================= - -extern "C" __global__ void blake3_hash_batch( - const HashInput* __restrict__ inputs, - const uint8_t* __restrict__ data, - uint8_t* __restrict__ outputs, - const uint32_t num_inputs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_inputs) return; - - const uint32_t offset = inputs[tid].offset; - const uint32_t len = inputs[tid].length; - - // Process chunks (each chunk = 1024 bytes = 16 blocks of 64 bytes) - const uint32_t chunk_size = 1024; - uint32_t num_chunks = (len + chunk_size - 1) / chunk_size; - if (num_chunks == 0) num_chunks = 1; - - // Single chunk (most common case) - if (num_chunks == 1) { - uint32_t cv[8]; - for (int i = 0; i < 8; i++) cv[i] = BLAKE3_IV[i]; - - 
uint32_t remaining = len; - uint32_t pos = 0; - uint32_t block_idx = 0; - - while (remaining > 0 || block_idx == 0) { - uint8_t block[64] = {}; - uint32_t to_copy = (remaining > 64) ? 64 : remaining; - for (uint32_t i = 0; i < to_copy; i++) { - block[i] = data[offset + pos + i]; - } - - uint32_t flags = 0; - if (block_idx == 0) flags |= BLAKE3_CHUNK_START; - bool is_last = (remaining <= 64); - if (is_last) flags |= BLAKE3_CHUNK_END | BLAKE3_ROOT; - - uint32_t out[8]; - blake3_compress(cv, block, 0, to_copy, flags, out); - - if (is_last) { - uint8_t* dst = outputs + tid * 32; - for (int i = 0; i < 8; i++) { - dst[i * 4] = (uint8_t)(out[i] & 0xFF); - dst[i * 4 + 1] = (uint8_t)((out[i] >> 8) & 0xFF); - dst[i * 4 + 2] = (uint8_t)((out[i] >> 16) & 0xFF); - dst[i * 4 + 3] = (uint8_t)((out[i] >> 24) & 0xFF); - } - return; - } - - for (int i = 0; i < 8; i++) cv[i] = out[i]; - pos += to_copy; - remaining -= to_copy; - block_idx++; - } - } - - // Multi-chunk: process each chunk independently, then tree-hash the CVs - uint32_t cv_stack[20][8]; - int stack_depth = 0; - - for (uint32_t chunk = 0; chunk < num_chunks; chunk++) { - uint32_t cv[8]; - for (int i = 0; i < 8; i++) cv[i] = BLAKE3_IV[i]; - - uint32_t chunk_start = offset + chunk * chunk_size; - uint32_t chunk_len = (chunk == num_chunks - 1) ? (len - chunk * chunk_size) : chunk_size; - uint32_t remaining = chunk_len; - uint32_t pos = 0; - uint32_t block_idx = 0; - - while (remaining > 0 || block_idx == 0) { - uint8_t block[64] = {}; - uint32_t to_copy = (remaining > 64) ? 
64 : remaining; - for (uint32_t i = 0; i < to_copy; i++) { - block[i] = data[chunk_start + pos + i]; - } - - uint32_t flags = 0; - if (block_idx == 0) flags |= BLAKE3_CHUNK_START; - if (remaining <= 64) flags |= BLAKE3_CHUNK_END; - - uint32_t out[8]; - blake3_compress(cv, block, chunk, to_copy, flags, out); - - if (remaining <= 64) { - for (int i = 0; i < 8; i++) cv_stack[stack_depth][i] = out[i]; - stack_depth++; - - while (stack_depth >= 2) { - bool is_root = (chunk == num_chunks - 1) && (stack_depth == 2); - - uint8_t parent_block[64]; - for (int i = 0; i < 8; i++) { - uint32_t w = cv_stack[stack_depth - 2][i]; - parent_block[i * 4] = (uint8_t)(w & 0xFF); - parent_block[i * 4 + 1] = (uint8_t)((w >> 8) & 0xFF); - parent_block[i * 4 + 2] = (uint8_t)((w >> 16) & 0xFF); - parent_block[i * 4 + 3] = (uint8_t)((w >> 24) & 0xFF); - } - for (int i = 0; i < 8; i++) { - uint32_t w = cv_stack[stack_depth - 1][i]; - parent_block[32 + i * 4] = (uint8_t)(w & 0xFF); - parent_block[32 + i * 4 + 1] = (uint8_t)((w >> 8) & 0xFF); - parent_block[32 + i * 4 + 2] = (uint8_t)((w >> 16) & 0xFF); - parent_block[32 + i * 4 + 3] = (uint8_t)((w >> 24) & 0xFF); - } - - uint32_t parent_cv[8]; - for (int i = 0; i < 8; i++) parent_cv[i] = BLAKE3_IV[i]; - - uint32_t parent_flags = 4u; // PARENT flag - if (is_root) parent_flags |= BLAKE3_ROOT; - - uint32_t parent_out[8]; - blake3_compress(parent_cv, parent_block, 0, 64, parent_flags, parent_out); - - stack_depth -= 2; - - if (is_root) { - uint8_t* dst = outputs + tid * 32; - for (int i = 0; i < 8; i++) { - dst[i * 4] = (uint8_t)(parent_out[i] & 0xFF); - dst[i * 4 + 1] = (uint8_t)((parent_out[i] >> 8) & 0xFF); - dst[i * 4 + 2] = (uint8_t)((parent_out[i] >> 16) & 0xFF); - dst[i * 4 + 3] = (uint8_t)((parent_out[i] >> 24) & 0xFF); - } - return; - } - - for (int i = 0; i < 8; i++) cv_stack[stack_depth][i] = parent_out[i]; - stack_depth++; - break; // only merge one pair per chunk - } - break; - } - - for (int i = 0; i < 8; i++) cv[i] = out[i]; - pos 
+= to_copy; - remaining -= to_copy; - block_idx++; - } - } - - // Merge remaining stack entries - while (stack_depth >= 2) { - bool is_root = (stack_depth == 2); - uint8_t parent_block[64]; - for (int i = 0; i < 8; i++) { - uint32_t w = cv_stack[stack_depth - 2][i]; - parent_block[i * 4] = (uint8_t)(w & 0xFF); - parent_block[i * 4 + 1] = (uint8_t)((w >> 8) & 0xFF); - parent_block[i * 4 + 2] = (uint8_t)((w >> 16) & 0xFF); - parent_block[i * 4 + 3] = (uint8_t)((w >> 24) & 0xFF); - } - for (int i = 0; i < 8; i++) { - uint32_t w = cv_stack[stack_depth - 1][i]; - parent_block[32 + i * 4] = (uint8_t)(w & 0xFF); - parent_block[32 + i * 4 + 1] = (uint8_t)((w >> 8) & 0xFF); - parent_block[32 + i * 4 + 2] = (uint8_t)((w >> 16) & 0xFF); - parent_block[32 + i * 4 + 3] = (uint8_t)((w >> 24) & 0xFF); - } - - uint32_t parent_cv[8]; - for (int i = 0; i < 8; i++) parent_cv[i] = BLAKE3_IV[i]; - uint32_t parent_flags = 4u; - if (is_root) parent_flags |= BLAKE3_ROOT; - - uint32_t parent_out[8]; - blake3_compress(parent_cv, parent_block, 0, 64, parent_flags, parent_out); - - stack_depth -= 2; - if (is_root) { - uint8_t* dst = outputs + tid * 32; - for (int i = 0; i < 8; i++) { - dst[i * 4] = (uint8_t)(parent_out[i] & 0xFF); - dst[i * 4 + 1] = (uint8_t)((parent_out[i] >> 8) & 0xFF); - dst[i * 4 + 2] = (uint8_t)((parent_out[i] >> 16) & 0xFF); - dst[i * 4 + 3] = (uint8_t)((parent_out[i] >> 24) & 0xFF); - } - return; - } - for (int i = 0; i < 8; i++) cv_stack[stack_depth][i] = parent_out[i]; - stack_depth++; - } -} diff --git a/blake3/gpu/metal/blake3.metal b/blake3/gpu/metal/blake3.metal deleted file mode 100644 index 8ef8bb6..0000000 --- a/blake3/gpu/metal/blake3.metal +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file blake3.metal -/// Metal compute shader for parallel BLAKE3 hashing. -/// -/// BLAKE3 is a tree-based hash with 7 rounds per compression. 
Each 1024-byte -/// chunk is independently hashable, making it ideal for GPU parallelism. -/// -/// Kernel: blake3_hash_batch -/// - Each thread hashes one input independently -/// - Supports inputs up to ~64KB (single chunk for simplicity; tree mode -/// for longer inputs would use chunk-level parallelism) -/// - Output: 32 bytes per input -/// -/// Reference: https://github.com/BLAKE3-team/BLAKE3-spec - -#include -using namespace metal; - -// ============================================================================= -// BLAKE3 constants -// ============================================================================= - -constant uint BLAKE3_IV[8] = { - 0x6A09E667u, 0xBB67AE85u, 0x3C6EF372u, 0xA54FF53Au, - 0x510E527Fu, 0x9B05688Cu, 0x1F83D9ABu, 0x5BE0CD19u -}; - -// Domain separation flags -constant uint BLAKE3_CHUNK_START = 1u; -constant uint BLAKE3_CHUNK_END = 2u; -constant uint BLAKE3_ROOT = 8u; - -// Message word permutation (applied after each round) -constant uchar BLAKE3_MSG_PERM[16] = { - 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 -}; - -// ============================================================================= -// BLAKE3 quarter-round G function -// ============================================================================= - -inline uint rotr32(uint x, uint n) { - return (x >> n) | (x << (32u - n)); -} - -inline void blake3_g(thread uint state[16], int a, int b, int c, int d, - uint mx, uint my) { - state[a] = state[a] + state[b] + mx; - state[d] = rotr32(state[d] ^ state[a], 16u); - state[c] = state[c] + state[d]; - state[b] = rotr32(state[b] ^ state[c], 12u); - state[a] = state[a] + state[b] + my; - state[d] = rotr32(state[d] ^ state[a], 8u); - state[c] = state[c] + state[d]; - state[b] = rotr32(state[b] ^ state[c], 7u); -} - -// ============================================================================= -// BLAKE3 round (column + diagonal) -// ============================================================================= - 
-inline void blake3_round(thread uint state[16], thread const uint m[16]) { - // Columns - blake3_g(state, 0, 4, 8, 12, m[0], m[1]); - blake3_g(state, 1, 5, 9, 13, m[2], m[3]); - blake3_g(state, 2, 6, 10, 14, m[4], m[5]); - blake3_g(state, 3, 7, 11, 15, m[6], m[7]); - // Diagonals - blake3_g(state, 0, 5, 10, 15, m[8], m[9]); - blake3_g(state, 1, 6, 11, 12, m[10], m[11]); - blake3_g(state, 2, 7, 8, 13, m[12], m[13]); - blake3_g(state, 3, 4, 9, 14, m[14], m[15]); -} - -// ============================================================================= -// BLAKE3 compression function -// ============================================================================= - -/// Compress one 64-byte block. -/// cv: 8-word chaining value -/// block: 64-byte message block -/// counter: block counter within chunk -/// block_len: actual bytes in this block (0-64) -/// flags: domain separation flags -/// out: 8-word output chaining value -inline void blake3_compress(thread const uint cv[8], - thread const uchar block[64], - ulong counter, - uint block_len, - uint flags, - thread uint out[8]) { - // Load message words (little-endian) - uint m[16]; - for (int i = 0; i < 16; i++) { - m[i] = uint(block[i * 4]) - | (uint(block[i * 4 + 1]) << 8) - | (uint(block[i * 4 + 2]) << 16) - | (uint(block[i * 4 + 3]) << 24); - } - - uint state[16] = { - cv[0], cv[1], cv[2], cv[3], - cv[4], cv[5], cv[6], cv[7], - BLAKE3_IV[0], BLAKE3_IV[1], BLAKE3_IV[2], BLAKE3_IV[3], - uint(counter & 0xFFFFFFFFu), - uint(counter >> 32), - block_len, - flags - }; - - // 7 rounds with message permutation - for (int round = 0; round < 7; round++) { - blake3_round(state, m); - // Permute message words - uint tmp[16]; - for (int i = 0; i < 16; i++) { - tmp[i] = m[BLAKE3_MSG_PERM[i]]; - } - for (int i = 0; i < 16; i++) m[i] = tmp[i]; - } - - // Output: state[0..7] XOR state[8..15] - for (int i = 0; i < 8; i++) { - out[i] = state[i] ^ state[i + 8]; - } -} - -// 
============================================================================= -// Input descriptor -// ============================================================================= - -struct HashInput { - uint offset; - uint length; -}; - -// ============================================================================= -// Kernel: blake3_hash_batch -// ============================================================================= - -/// Each thread hashes one input to a 32-byte BLAKE3 digest. -/// For inputs <= 1024 bytes, this is a single-chunk hash. -/// For inputs > 1024 bytes, we process multiple chunks and merge via tree mode. -kernel void blake3_hash_batch( - device const HashInput* inputs [[buffer(0)]], - device const uchar* data [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_inputs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_inputs) return; - - const uint offset = inputs[tid].offset; - const uint len = inputs[tid].length; - - // Process chunks (each chunk = 1024 bytes = 16 blocks of 64 bytes) - const uint chunk_size = 1024; - uint num_chunks = (len + chunk_size - 1) / chunk_size; - if (num_chunks == 0) num_chunks = 1; - - // For single chunk (most common case), compute directly - if (num_chunks == 1) { - uint cv[8]; - for (int i = 0; i < 8; i++) cv[i] = BLAKE3_IV[i]; - - uint remaining = len; - uint pos = 0; - uint block_idx = 0; - - while (remaining > 0 || block_idx == 0) { - uchar block[64] = {}; - uint to_copy = (remaining > 64) ? 
64 : remaining; - for (uint i = 0; i < to_copy; i++) { - block[i] = data[offset + pos + i]; - } - - uint flags = 0; - if (block_idx == 0) flags |= BLAKE3_CHUNK_START; - bool is_last = (remaining <= 64); - if (is_last) flags |= BLAKE3_CHUNK_END | BLAKE3_ROOT; - - uint out[8]; - blake3_compress(cv, block, 0, to_copy, flags, out); - - if (is_last) { - // Write final hash - device uchar* dst = outputs + tid * 32; - for (int i = 0; i < 8; i++) { - dst[i * 4] = uchar(out[i] & 0xFF); - dst[i * 4 + 1] = uchar((out[i] >> 8) & 0xFF); - dst[i * 4 + 2] = uchar((out[i] >> 16) & 0xFF); - dst[i * 4 + 3] = uchar((out[i] >> 24) & 0xFF); - } - return; - } - - for (int i = 0; i < 8; i++) cv[i] = out[i]; - pos += to_copy; - remaining -= to_copy; - block_idx++; - } - } - - // Multi-chunk: process each chunk independently, then tree-hash the CVs - // Stack for tree hashing (max depth ~20 for 2^20 chunks) - uint cv_stack[20][8]; - int stack_depth = 0; - - for (uint chunk = 0; chunk < num_chunks; chunk++) { - uint cv[8]; - for (int i = 0; i < 8; i++) cv[i] = BLAKE3_IV[i]; - - uint chunk_start = offset + chunk * chunk_size; - uint chunk_len = (chunk == num_chunks - 1) ? (len - chunk * chunk_size) : chunk_size; - uint remaining = chunk_len; - uint pos = 0; - uint block_idx = 0; - - // Process blocks within this chunk - while (remaining > 0 || block_idx == 0) { - uchar block[64] = {}; - uint to_copy = (remaining > 64) ? 
64 : remaining; - for (uint i = 0; i < to_copy; i++) { - block[i] = data[chunk_start + pos + i]; - } - - uint flags = 0; - if (block_idx == 0) flags |= BLAKE3_CHUNK_START; - if (remaining <= 64) flags |= BLAKE3_CHUNK_END; - - uint out[8]; - blake3_compress(cv, block, chunk, to_copy, flags, out); - - if (remaining <= 64) { - // Push chunk CV onto stack and merge - for (int i = 0; i < 8; i++) cv_stack[stack_depth][i] = out[i]; - stack_depth++; - - // Merge pairs while we have complete pairs at same level - while (stack_depth >= 2) { - // Check if this is the very last merge (root) - bool is_root = (chunk == num_chunks - 1) && (stack_depth == 2); - - // Build parent block: left_cv[32] || right_cv[32] - uchar parent_block[64]; - for (int i = 0; i < 8; i++) { - uint w = cv_stack[stack_depth - 2][i]; - parent_block[i * 4] = uchar(w & 0xFF); - parent_block[i * 4 + 1] = uchar((w >> 8) & 0xFF); - parent_block[i * 4 + 2] = uchar((w >> 16) & 0xFF); - parent_block[i * 4 + 3] = uchar((w >> 24) & 0xFF); - } - for (int i = 0; i < 8; i++) { - uint w = cv_stack[stack_depth - 1][i]; - parent_block[32 + i * 4] = uchar(w & 0xFF); - parent_block[32 + i * 4 + 1] = uchar((w >> 8) & 0xFF); - parent_block[32 + i * 4 + 2] = uchar((w >> 16) & 0xFF); - parent_block[32 + i * 4 + 3] = uchar((w >> 24) & 0xFF); - } - - uint parent_cv[8]; - for (int i = 0; i < 8; i++) parent_cv[i] = BLAKE3_IV[i]; - - uint parent_flags = 4u; // PARENT flag - if (is_root) parent_flags |= BLAKE3_ROOT; - - uint parent_out[8]; - blake3_compress(parent_cv, parent_block, 0, 64, parent_flags, parent_out); - - stack_depth -= 2; - - if (is_root) { - device uchar* dst = outputs + tid * 32; - for (int i = 0; i < 8; i++) { - dst[i * 4] = uchar(parent_out[i] & 0xFF); - dst[i * 4 + 1] = uchar((parent_out[i] >> 8) & 0xFF); - dst[i * 4 + 2] = uchar((parent_out[i] >> 16) & 0xFF); - dst[i * 4 + 3] = uchar((parent_out[i] >> 24) & 0xFF); - } - return; - } - - for (int i = 0; i < 8; i++) cv_stack[stack_depth][i] = parent_out[i]; - 
stack_depth++; - break; // only merge one pair per chunk - } - break; - } - - for (int i = 0; i < 8; i++) cv[i] = out[i]; - pos += to_copy; - remaining -= to_copy; - block_idx++; - } - } - - // Merge remaining stack entries (right-to-left) - while (stack_depth >= 2) { - bool is_root = (stack_depth == 2); - uchar parent_block[64]; - for (int i = 0; i < 8; i++) { - uint w = cv_stack[stack_depth - 2][i]; - parent_block[i * 4] = uchar(w & 0xFF); - parent_block[i * 4 + 1] = uchar((w >> 8) & 0xFF); - parent_block[i * 4 + 2] = uchar((w >> 16) & 0xFF); - parent_block[i * 4 + 3] = uchar((w >> 24) & 0xFF); - } - for (int i = 0; i < 8; i++) { - uint w = cv_stack[stack_depth - 1][i]; - parent_block[32 + i * 4] = uchar(w & 0xFF); - parent_block[32 + i * 4 + 1] = uchar((w >> 8) & 0xFF); - parent_block[32 + i * 4 + 2] = uchar((w >> 16) & 0xFF); - parent_block[32 + i * 4 + 3] = uchar((w >> 24) & 0xFF); - } - - uint parent_cv[8]; - for (int i = 0; i < 8; i++) parent_cv[i] = BLAKE3_IV[i]; - uint parent_flags = 4u; - if (is_root) parent_flags |= BLAKE3_ROOT; - - uint parent_out[8]; - blake3_compress(parent_cv, parent_block, 0, 64, parent_flags, parent_out); - - stack_depth -= 2; - if (is_root) { - device uchar* dst = outputs + tid * 32; - for (int i = 0; i < 8; i++) { - dst[i * 4] = uchar(parent_out[i] & 0xFF); - dst[i * 4 + 1] = uchar((parent_out[i] >> 8) & 0xFF); - dst[i * 4 + 2] = uchar((parent_out[i] >> 16) & 0xFF); - dst[i * 4 + 3] = uchar((parent_out[i] >> 24) & 0xFF); - } - return; - } - for (int i = 0; i < 8; i++) cv_stack[stack_depth][i] = parent_out[i]; - stack_depth++; - } -} diff --git a/blake3/gpu/metal/blake3_authored.metal b/blake3/gpu/metal/blake3_authored.metal deleted file mode 100644 index b33fb6b..0000000 --- a/blake3/gpu/metal/blake3_authored.metal +++ /dev/null @@ -1,624 +0,0 @@ -// ============================================================================= -// Blake3 Metal Compute Shaders -// 
============================================================================= -// -// GPU-accelerated Blake3 hash function on Apple Silicon. -// Implements batch hashing for high-throughput applications. -// -// Blake3 Parameters: -// Block size: 64 bytes -// Output size: 256 bits (default), extensible to arbitrary length -// Rounds: 7 per compression -// -// Reference: https://github.com/BLAKE3-team/BLAKE3-specs -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// Blake3 Constants -// ============================================================================= - -// Initial state (IV from SHA-256) -constant uint32_t BLAKE3_IV[8] = { - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, - 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 -}; - -// Message permutation schedule -constant uint8_t MSG_PERMUTATION[16] = { - 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 -}; - -// Block length for chunk chaining -constant uint32_t BLOCK_LEN = 64; -constant uint32_t CHUNK_LEN = 1024; - -// Domain separation flags -constant uint32_t CHUNK_START = 1 << 0; -constant uint32_t CHUNK_END = 1 << 1; -constant uint32_t PARENT = 1 << 2; -constant uint32_t ROOT = 1 << 3; - -// ============================================================================= -// Blake3 State -// ============================================================================= - -struct Blake3State { - uint32_t cv[8]; // Chaining value - uint64_t chunk_counter; - uint8_t block[64]; - uint8_t block_len; - uint8_t blocks_compressed; - uint8_t flags; -}; - -struct Blake3Output { - uint32_t hash[8]; // 256-bit output -}; - -// ============================================================================= -// Rotation and Mixing Functions -// ============================================================================= - -inline uint32_t rotr32(uint32_t 
x, uint8_t n) { - return (x >> n) | (x << (32 - n)); -} - -// G function - quarter round -inline void g(thread uint32_t& a, thread uint32_t& b, thread uint32_t& c, thread uint32_t& d, - uint32_t mx, uint32_t my) { - a = a + b + mx; - d = rotr32(d ^ a, 16); - c = c + d; - b = rotr32(b ^ c, 12); - a = a + b + my; - d = rotr32(d ^ a, 8); - c = c + d; - b = rotr32(b ^ c, 7); -} - -// ============================================================================= -// Compression Function -// ============================================================================= - -inline void compress(thread uint32_t state[16], - thread const uint32_t cv[8], - thread const uint32_t block_words[16], - uint64_t counter, - uint32_t block_len, - uint32_t flags) { - // Initialize state - for (int i = 0; i < 8; i++) { - state[i] = cv[i]; - } - state[8] = BLAKE3_IV[0]; - state[9] = BLAKE3_IV[1]; - state[10] = BLAKE3_IV[2]; - state[11] = BLAKE3_IV[3]; - state[12] = (uint32_t)counter; - state[13] = (uint32_t)(counter >> 32); - state[14] = block_len; - state[15] = flags; - - // Message schedule - uint32_t m[16]; - for (int i = 0; i < 16; i++) { - m[i] = block_words[i]; - } - - // 7 rounds - for (int round = 0; round < 7; round++) { - // Column step - g(state[0], state[4], state[8], state[12], m[0], m[1]); - g(state[1], state[5], state[9], state[13], m[2], m[3]); - g(state[2], state[6], state[10], state[14], m[4], m[5]); - g(state[3], state[7], state[11], state[15], m[6], m[7]); - - // Diagonal step - g(state[0], state[5], state[10], state[15], m[8], m[9]); - g(state[1], state[6], state[11], state[12], m[10], m[11]); - g(state[2], state[7], state[8], state[13], m[12], m[13]); - g(state[3], state[4], state[9], state[14], m[14], m[15]); - - // Permute message for next round - uint32_t temp[16]; - for (int i = 0; i < 16; i++) { - temp[i] = m[MSG_PERMUTATION[i]]; - } - for (int i = 0; i < 16; i++) { - m[i] = temp[i]; - } - } - - // Finalize (XOR with input chaining value) - for (int i = 0; i < 8; 
i++) { - state[i] ^= state[i + 8]; - state[i + 8] ^= cv[i]; - } -} - -// ============================================================================= -// Hash Single Block Kernel -// ============================================================================= - -kernel void blake3_hash_block( - device const uint8_t* input [[buffer(0)]], - device uint32_t* output [[buffer(1)]], - device const uint32_t* input_lengths [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - uint32_t len = input_lengths[index]; - uint32_t offset = index * 64; // Assuming 64-byte aligned inputs - - // Load block words (little-endian) - uint32_t block_words[16]; - for (int i = 0; i < 16; i++) { - uint32_t word = 0; - for (int j = 0; j < 4; j++) { - uint32_t byte_idx = offset + i * 4 + j; - if (i * 4 + j < len) { - word |= ((uint32_t)input[byte_idx]) << (j * 8); - } - } - block_words[i] = word; - } - - // Compress with IV - uint32_t state[16]; - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - uint32_t flags = CHUNK_START | CHUNK_END | ROOT; - compress(state, cv, block_words, 0, len, flags); - - // Output first 8 words (256 bits) - uint32_t out_offset = index * 8; - for (int i = 0; i < 8; i++) { - output[out_offset + i] = state[i]; - } -} - -// ============================================================================= -// Batch Hash Kernel -// ============================================================================= - -kernel void blake3_batch_hash( - device const uint8_t* inputs [[buffer(0)]], - device uint32_t* outputs [[buffer(1)]], - device const uint32_t* offsets [[buffer(2)]], - device const uint32_t* lengths [[buffer(3)]], - uint batch_idx [[thread_position_in_grid]] -) { - uint32_t offset = offsets[batch_idx]; - uint32_t len = lengths[batch_idx]; - - // Initialize chaining value - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - // Process full blocks - uint64_t chunk_counter = 0; - uint32_t 
bytes_processed = 0; - - while (bytes_processed < len) { - uint32_t block_len = min(64u, len - bytes_processed); - - // Load block - uint32_t block_words[16] = {0}; - for (uint32_t i = 0; i < block_len; i++) { - uint32_t word_idx = i / 4; - uint32_t byte_pos = i % 4; - block_words[word_idx] |= ((uint32_t)inputs[offset + bytes_processed + i]) << (byte_pos * 8); - } - - // Determine flags - uint32_t flags = 0; - if (bytes_processed == 0) flags |= CHUNK_START; - if (bytes_processed + block_len >= len) flags |= CHUNK_END | ROOT; - - // Compress - uint32_t state[16]; - compress(state, cv, block_words, chunk_counter, block_len, flags); - - // Update chaining value - for (int i = 0; i < 8; i++) { - cv[i] = state[i]; - } - - bytes_processed += block_len; - if ((bytes_processed % CHUNK_LEN) == 0) { - chunk_counter++; - } - } - - // Output - uint32_t out_offset = batch_idx * 8; - for (int i = 0; i < 8; i++) { - outputs[out_offset + i] = cv[i]; - } -} - -// ============================================================================= -// Merkle Tree Root Kernel -// ============================================================================= - -// Compute parent node from two child hashes -kernel void blake3_merge_nodes( - device const uint32_t* left_hashes [[buffer(0)]], - device const uint32_t* right_hashes [[buffer(1)]], - device uint32_t* parent_hashes [[buffer(2)]], - constant uint32_t& num_pairs [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_pairs) return; - - // Load left and right child hashes - uint32_t block_words[16]; - for (int i = 0; i < 8; i++) { - block_words[i] = left_hashes[index * 8 + i]; - block_words[i + 8] = right_hashes[index * 8 + i]; - } - - // Compress with PARENT flag - uint32_t state[16]; - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - uint32_t flags = PARENT; - compress(state, cv, block_words, 0, 64, flags); - - // Output parent hash - uint32_t out_offset = index * 8; - for (int i = 
0; i < 8; i++) { - parent_hashes[out_offset + i] = state[i]; - } -} - -// ============================================================================= -// XOF (Extendable Output Function) Kernel -// ============================================================================= - -kernel void blake3_xof( - device const uint8_t* input [[buffer(0)]], - device uint8_t* output [[buffer(1)]], - constant uint32_t& input_len [[buffer(2)]], - constant uint32_t& output_len [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - // Each thread generates 64 bytes of output - uint32_t out_offset = index * 64; - if (out_offset >= output_len) return; - - // First, compute the base hash - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - // Load and hash input (simplified for short inputs) - uint32_t block_words[16] = {0}; - for (uint32_t i = 0; i < min(64u, input_len); i++) { - uint32_t word_idx = i / 4; - uint32_t byte_pos = i % 4; - block_words[word_idx] |= ((uint32_t)input[i]) << (byte_pos * 8); - } - - uint32_t state[16]; - uint32_t flags = CHUNK_START | CHUNK_END | ROOT; - compress(state, cv, block_words, 0, input_len, flags); - - // XOF: use counter to extend output - uint32_t xof_state[16]; - uint32_t xof_cv[8]; - for (int i = 0; i < 8; i++) { - xof_cv[i] = state[i]; - } - - // Extend using counter = block index - uint64_t counter = index; - uint32_t zero_block[16] = {0}; - compress(xof_state, xof_cv, zero_block, counter, 0, ROOT); - - // Output 64 bytes - uint32_t bytes_to_write = min(64u, output_len - out_offset); - for (uint32_t i = 0; i < bytes_to_write; i++) { - uint32_t word_idx = i / 4; - uint32_t byte_pos = i % 4; - output[out_offset + i] = (uint8_t)(xof_state[word_idx] >> (byte_pos * 8)); - } -} - -// ============================================================================= -// Enhanced Merkle Tree Operations -// ============================================================================= - -// Build one layer of Merkle 
tree from contiguous array -// Input: current_layer has `layer_size` hashes (each 8 uint32_t = 32 bytes) -// Output: next_layer has layer_size/2 parent hashes -kernel void blake3_merkle_layer( - device const uint32_t* current_layer [[buffer(0)]], - device uint32_t* next_layer [[buffer(1)]], - constant uint32_t& layer_size [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - if (index >= layer_size / 2) return; - - // Load left child (8 words at position 2*index) - uint32_t left_offset = 2 * index * 8; - uint32_t block_words[16]; - for (int i = 0; i < 8; i++) { - block_words[i] = current_layer[left_offset + i]; - } - - // Load right child (8 words at position 2*index + 1) - uint32_t right_offset = (2 * index + 1) * 8; - for (int i = 0; i < 8; i++) { - block_words[i + 8] = current_layer[right_offset + i]; - } - - // Compress with PARENT flag - uint32_t state[16]; - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - compress(state, cv, block_words, 0, 64, PARENT); - - // Write parent hash - uint32_t out_offset = index * 8; - for (int i = 0; i < 8; i++) { - next_layer[out_offset + i] = state[i]; - } -} - -// Hash leaves (raw data) to first layer of Merkle tree -// Each thread hashes one leaf of fixed size -kernel void blake3_hash_leaves( - device const uint8_t* leaf_data [[buffer(0)]], - device uint32_t* leaf_hashes [[buffer(1)]], - constant uint32_t& leaf_size [[buffer(2)]], // Size of each leaf in bytes - constant uint32_t& num_leaves [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_leaves) return; - - uint32_t offset = index * leaf_size; - - // Initialize chaining value - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - // Process leaf data in chunks - uint32_t bytes_processed = 0; - uint64_t chunk_counter = 0; - - while (bytes_processed < leaf_size) { - uint32_t block_len = min(64u, leaf_size - bytes_processed); - - // Load block - uint32_t block_words[16] = {0}; - 
for (uint32_t i = 0; i < block_len; i++) { - uint32_t word_idx = i / 4; - uint32_t byte_pos = i % 4; - block_words[word_idx] |= ((uint32_t)leaf_data[offset + bytes_processed + i]) << (byte_pos * 8); - } - - // Determine flags - uint32_t flags = 0; - if (bytes_processed == 0) flags |= CHUNK_START; - if (bytes_processed + block_len >= leaf_size) flags |= CHUNK_END; - - // Compress - uint32_t state[16]; - compress(state, cv, block_words, chunk_counter, block_len, flags); - - // Update chaining value - for (int i = 0; i < 8; i++) { - cv[i] = state[i]; - } - - bytes_processed += block_len; - } - - // Write leaf hash - uint32_t out_offset = index * 8; - for (int i = 0; i < 8; i++) { - leaf_hashes[out_offset + i] = cv[i]; - } -} - -// Verify a Merkle proof -// Computes root from leaf and sibling path, compares to expected -kernel void blake3_verify_merkle_proof( - device const uint32_t* leaf_hash [[buffer(0)]], // Leaf hash (8 words per proof) - device const uint32_t* sibling_path [[buffer(1)]], // Siblings (8 words each) - device const uint32_t* path_indices [[buffer(2)]], // 0=left, 1=right for each level - device const uint32_t* expected_root [[buffer(3)]], // Expected root (8 words per proof) - device uint32_t* results [[buffer(4)]], // 1=valid, 0=invalid - constant uint32_t& path_len [[buffer(5)]], - uint proof_idx [[thread_position_in_grid]] -) { - // Load leaf hash - uint32_t current[8]; - uint32_t leaf_offset = proof_idx * 8; - for (int i = 0; i < 8; i++) { - current[i] = leaf_hash[leaf_offset + i]; - } - - // Traverse up the tree - for (uint32_t level = 0; level < path_len; level++) { - uint32_t sibling_offset = (proof_idx * path_len + level) * 8; - uint32_t idx = path_indices[proof_idx * path_len + level]; - - // Prepare block: [left, right] based on index - uint32_t block_words[16]; - if (idx == 0) { - // Current is left child - for (int i = 0; i < 8; i++) { - block_words[i] = current[i]; - block_words[i + 8] = sibling_path[sibling_offset + i]; - } - } else { - 
// Current is right child - for (int i = 0; i < 8; i++) { - block_words[i] = sibling_path[sibling_offset + i]; - block_words[i + 8] = current[i]; - } - } - - // Compress - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - uint32_t state[16]; - compress(state, cv, block_words, 0, 64, PARENT); - - // Update current - for (int i = 0; i < 8; i++) { - current[i] = state[i]; - } - } - - // Compare with expected root - uint32_t root_offset = proof_idx * 8; - bool valid = true; - for (int i = 0; i < 8; i++) { - if (current[i] != expected_root[root_offset + i]) { - valid = false; - break; - } - } - - results[proof_idx] = valid ? 1 : 0; -} - -// Batch Merkle proof verification (multiple proofs in parallel) -kernel void blake3_batch_verify_proofs( - device const uint32_t* leaf_hashes [[buffer(0)]], // All leaf hashes - device const uint32_t* all_siblings [[buffer(1)]], // All sibling paths - device const uint32_t* all_indices [[buffer(2)]], // All path indices - device const uint32_t* expected_root [[buffer(3)]], // Single root for all - device uint32_t* results [[buffer(4)]], // Per-proof results - constant uint32_t& path_len [[buffer(5)]], - constant uint32_t& num_proofs [[buffer(6)]], - uint proof_idx [[thread_position_in_grid]] -) { - if (proof_idx >= num_proofs) return; - - // Load leaf hash for this proof - uint32_t current[8]; - uint32_t leaf_offset = proof_idx * 8; - for (int i = 0; i < 8; i++) { - current[i] = leaf_hashes[leaf_offset + i]; - } - - // Traverse path - for (uint32_t level = 0; level < path_len; level++) { - uint32_t sibling_offset = (proof_idx * path_len + level) * 8; - uint32_t idx = all_indices[proof_idx * path_len + level]; - - uint32_t block_words[16]; - if (idx == 0) { - for (int i = 0; i < 8; i++) { - block_words[i] = current[i]; - block_words[i + 8] = all_siblings[sibling_offset + i]; - } - } else { - for (int i = 0; i < 8; i++) { - block_words[i] = all_siblings[sibling_offset + i]; - block_words[i + 8] = current[i]; 
- } - } - - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - uint32_t state[16]; - compress(state, cv, block_words, 0, 64, PARENT); - - for (int i = 0; i < 8; i++) { - current[i] = state[i]; - } - } - - // Compare with shared root - bool valid = true; - for (int i = 0; i < 8; i++) { - if (current[i] != expected_root[i]) { - valid = false; - break; - } - } - - results[proof_idx] = valid ? 1 : 0; -} - -// KDF (Key Derivation Function) using Blake3 -kernel void blake3_derive_key( - device const uint8_t* context [[buffer(0)]], // Context string - device const uint8_t* key_material [[buffer(1)]], - device uint8_t* derived_key [[buffer(2)]], - constant uint32_t& context_len [[buffer(3)]], - constant uint32_t& key_len [[buffer(4)]], - constant uint32_t& output_len [[buffer(5)]], - uint index [[thread_position_in_grid]] -) { - if (index != 0) return; // Single-threaded for simplicity - - // Step 1: Hash context to get derive_key_context - uint32_t cv[8]; - for (int i = 0; i < 8; i++) { - cv[i] = BLAKE3_IV[i]; - } - - // Hash context (simplified for short contexts) - uint32_t context_words[16] = {0}; - for (uint32_t i = 0; i < min(64u, context_len); i++) { - uint32_t word_idx = i / 4; - uint32_t byte_pos = i % 4; - context_words[word_idx] |= ((uint32_t)context[i]) << (byte_pos * 8); - } - - uint32_t state[16]; - compress(state, cv, context_words, 0, context_len, CHUNK_START | CHUNK_END); - - // Use state as new IV - uint32_t key_cv[8]; - for (int i = 0; i < 8; i++) { - key_cv[i] = state[i]; - } - - // Step 2: Hash key material with derive_key_context - uint32_t key_words[16] = {0}; - for (uint32_t i = 0; i < min(64u, key_len); i++) { - uint32_t word_idx = i / 4; - uint32_t byte_pos = i % 4; - key_words[word_idx] |= ((uint32_t)key_material[i]) << (byte_pos * 8); - } - - compress(state, key_cv, key_words, 0, key_len, CHUNK_START | CHUNK_END | ROOT); - - // Output derived key - for (uint32_t i = 0; i < min(32u, output_len); i++) { - uint32_t 
word_idx = i / 4; - uint32_t byte_pos = i % 4; - derived_key[i] = (uint8_t)(state[word_idx] >> (byte_pos * 8)); - } -} diff --git a/blake3/gpu/metal/blake3_batch.metal b/blake3/gpu/metal/blake3_batch.metal deleted file mode 100644 index aab20c5..0000000 --- a/blake3/gpu/metal/blake3_batch.metal +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched BLAKE3 (BLAKE3 spec). One thread per input. Byte-equal to -// blake3/cpp/blake3.cpp::hash() for arbitrary-length inputs. -// -// Each input is hashed in plain mode (key = IV, base flags = 0). Tree mode -// is implemented in-kernel: each chunk is processed sequentially per thread, -// chunk CVs are merged via a per-thread stack with the canonical "fold on -// trailing zero of total chunk count" rule, and the final root output is -// emitted with the ROOT flag. -// -// Layout: caller fills a Blake3Job[] with (input_offset, input_len, -// output_offset). Inputs share a flat byte arena; outputs share a 32-byte -// stride arena. 
- -#include -using namespace metal; - -constant uint BLAKE3_IV[8] = { - 0x6A09E667u, 0xBB67AE85u, 0x3C6EF372u, 0xA54FF53Au, - 0x510E527Fu, 0x9B05688Cu, 0x1F83D9ABu, 0x5BE0CD19u, -}; - -constant uint CHUNK_START = 1u << 0; -constant uint CHUNK_END = 1u << 1; -constant uint PARENT = 1u << 2; -constant uint ROOT = 1u << 3; - -constant uchar MSG_PERM[16] = { - 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8, -}; - -inline uint rotr32(uint x, uint n) { - return (x >> n) | (x << (32u - n)); -} - -inline void g(thread uint* s, int a, int b, int c, int d, - uint mx, uint my) { - s[a] = s[a] + s[b] + mx; - s[d] = rotr32(s[d] ^ s[a], 16u); - s[c] = s[c] + s[d]; - s[b] = rotr32(s[b] ^ s[c], 12u); - s[a] = s[a] + s[b] + my; - s[d] = rotr32(s[d] ^ s[a], 8u); - s[c] = s[c] + s[d]; - s[b] = rotr32(s[b] ^ s[c], 7u); -} - -inline void round_fn(thread uint* s, thread uint* m) { - g(s, 0, 4, 8, 12, m[0], m[1]); - g(s, 1, 5, 9, 13, m[2], m[3]); - g(s, 2, 6, 10, 14, m[4], m[5]); - g(s, 3, 7, 11, 15, m[6], m[7]); - g(s, 0, 5, 10, 15, m[8], m[9]); - g(s, 1, 6, 11, 12, m[10], m[11]); - g(s, 2, 7, 8, 13, m[12], m[13]); - g(s, 3, 4, 9, 14, m[14], m[15]); -} - -inline void permute(thread uint* m) { - uint tmp[16]; - for (int i = 0; i < 16; ++i) tmp[i] = m[MSG_PERM[i]]; - for (int i = 0; i < 16; ++i) m[i] = tmp[i]; -} - -// compress: writes 16 words to out_state. Caller takes [0..7] for chaining -// values, [0..15] for XOF blocks. 
-inline void compress(thread const uint* cv, - thread const uchar* block, - ulong counter, - uint block_len, - uint flags, - thread uint* out_state) { - uint m[16]; - for (int i = 0; i < 16; ++i) { - m[i] = uint(block[i * 4 + 0]) - | (uint(block[i * 4 + 1]) << 8) - | (uint(block[i * 4 + 2]) << 16) - | (uint(block[i * 4 + 3]) << 24); - } - - uint s[16] = { - cv[0], cv[1], cv[2], cv[3], - cv[4], cv[5], cv[6], cv[7], - BLAKE3_IV[0], BLAKE3_IV[1], BLAKE3_IV[2], BLAKE3_IV[3], - uint(counter & 0xFFFFFFFFu), - uint(counter >> 32), - block_len, - flags, - }; - - round_fn(s, m); permute(m); - round_fn(s, m); permute(m); - round_fn(s, m); permute(m); - round_fn(s, m); permute(m); - round_fn(s, m); permute(m); - round_fn(s, m); permute(m); - round_fn(s, m); - - for (int i = 0; i < 8; ++i) { - out_state[i] = s[i] ^ s[i + 8]; - out_state[i + 8] = s[i + 8] ^ cv[i]; - } -} - -struct Blake3Job { - uint input_offset; - uint input_len; - uint output_offset; - uint _pad; -}; - -// Per-thread chunk state. -struct ChunkState { - uint cv[8]; - ulong chunk_counter; - uchar block[64]; - uint block_len; // 0..64 - uint blocks_compressed; // 0..15 mid-chunk - uint flags; -}; - -inline void chunk_init(thread ChunkState& cs, ulong cc, uint base_flags) { - for (int i = 0; i < 8; ++i) cs.cv[i] = BLAKE3_IV[i]; - cs.chunk_counter = cc; - for (int i = 0; i < 64; ++i) cs.block[i] = 0; - cs.block_len = 0u; - cs.blocks_compressed = 0u; - cs.flags = base_flags; -} - -inline uint chunk_start_flag(thread const ChunkState& cs) { - return cs.blocks_compressed == 0u ? 
CHUNK_START : 0u; -} - -inline uint chunk_len(thread const ChunkState& cs) { - return 64u * cs.blocks_compressed + cs.block_len; -} - -inline void chunk_update(thread ChunkState& cs, - device const uchar* data, uint pos, uint count) { - uint i = 0u; - while (i < count) { - if (cs.block_len == 64u) { - uint s[16]; - uchar tb[64]; - for (int k = 0; k < 64; ++k) tb[k] = cs.block[k]; - compress(cs.cv, tb, cs.chunk_counter, 64u, - cs.flags | chunk_start_flag(cs), s); - for (int k = 0; k < 8; ++k) cs.cv[k] = s[k]; - cs.blocks_compressed += 1u; - for (int k = 0; k < 64; ++k) cs.block[k] = 0; - cs.block_len = 0u; - } - uint want = 64u - cs.block_len; - uint take = (count - i < want) ? (count - i) : want; - for (uint k = 0u; k < take; ++k) { - cs.block[cs.block_len + k] = data[pos + i + k]; - } - cs.block_len += take; - i += take; - } -} - -// Compute chunk CV (non-root). Reads cs by-thread copy. -inline void chunk_chaining_value(thread const ChunkState& cs, thread uint* out) { - uint s[16]; - uchar tb[64]; - for (int k = 0; k < 64; ++k) tb[k] = cs.block[k]; - uint flags = cs.flags | chunk_start_flag(cs) | CHUNK_END; - compress(cs.cv, tb, cs.chunk_counter, cs.block_len, flags, s); - for (int k = 0; k < 8; ++k) out[k] = s[k]; -} - -inline void parent_block(thread const uint* l, thread const uint* r, - thread uchar* out) { - for (int i = 0; i < 8; ++i) { - uint w = l[i]; - out[i * 4 + 0] = uchar(w & 0xFFu); - out[i * 4 + 1] = uchar((w >> 8) & 0xFFu); - out[i * 4 + 2] = uchar((w >> 16) & 0xFFu); - out[i * 4 + 3] = uchar((w >> 24) & 0xFFu); - } - for (int i = 0; i < 8; ++i) { - uint w = r[i]; - out[32 + i * 4 + 0] = uchar(w & 0xFFu); - out[32 + i * 4 + 1] = uchar((w >> 8) & 0xFFu); - out[32 + i * 4 + 2] = uchar((w >> 16) & 0xFFu); - out[32 + i * 4 + 3] = uchar((w >> 24) & 0xFFu); - } -} - -inline void parent_cv_compute(thread const uint* l, thread const uint* r, - uint base_flags, thread uint* out) { - uchar pb[64]; - parent_block(l, r, pb); - uint s[16]; - uint kw[8]; - for 
(int i = 0; i < 8; ++i) kw[i] = BLAKE3_IV[i]; - compress(kw, pb, 0ul, 64u, PARENT | base_flags, s); - for (int i = 0; i < 8; ++i) out[i] = s[i]; -} - -// Stack-of-CVs Bao tree, max depth 54 covers 2^54 chunks. -constant uint BLAKE3_STACK_MAX = 54u; - -kernel void blake3_jobs( - device const Blake3Job* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_jobs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_jobs) return; - - Blake3Job j = jobs[tid]; - uint base_flags = 0u; - - // Single-chunk fast path: rare but common for short messages. - if (j.input_len <= 1024u) { - ChunkState cs; - chunk_init(cs, 0ul, base_flags); - chunk_update(cs, inputs, j.input_offset, j.input_len); - - // Root output: 16 words from compressing this chunk's final state - // with ROOT flag. For 32-byte digest we only emit the first 8 words. - uint s[16]; - uchar tb[64]; - for (int k = 0; k < 64; ++k) tb[k] = cs.block[k]; - uint flags = cs.flags | chunk_start_flag(cs) | CHUNK_END | ROOT; - compress(cs.cv, tb, cs.chunk_counter, cs.block_len, flags, s); - - device uchar* dst = outputs + j.output_offset; - for (int i = 0; i < 8; ++i) { - uint w = s[i]; - dst[i * 4 + 0] = uchar(w & 0xFFu); - dst[i * 4 + 1] = uchar((w >> 8) & 0xFFu); - dst[i * 4 + 2] = uchar((w >> 16) & 0xFFu); - dst[i * 4 + 3] = uchar((w >> 24) & 0xFFu); - } - return; - } - - // Multi-chunk: stack-of-CVs tree mode. - uint stack[BLAKE3_STACK_MAX][8]; - uint stack_len = 0u; - - ChunkState cs; - chunk_init(cs, 0ul, base_flags); - uint pos = 0u; - while (pos < j.input_len) { - if (chunk_len(cs) == 1024u) { - // Finish this chunk → push CV. - uint cv[8]; - chunk_chaining_value(cs, cv); - ulong this_idx = cs.chunk_counter; - // Merge while the trailing bit of (this_idx + 1) is 0. 
- ulong total = this_idx + 1ul; - uint cur[8]; - for (int k = 0; k < 8; ++k) cur[k] = cv[k]; - while ((total & 1ul) == 0ul) { - uint left[8]; - for (int k = 0; k < 8; ++k) left[k] = stack[stack_len - 1u][k]; - uint merged[8]; - parent_cv_compute(left, cur, base_flags, merged); - for (int k = 0; k < 8; ++k) cur[k] = merged[k]; - stack_len -= 1u; - total >>= 1ul; - } - for (int k = 0; k < 8; ++k) stack[stack_len][k] = cur[k]; - stack_len += 1u; - chunk_init(cs, this_idx + 1ul, base_flags); - } - uint want = 1024u - chunk_len(cs); - uint take = (j.input_len - pos < want) ? (j.input_len - pos) : want; - chunk_update(cs, inputs, j.input_offset + pos, take); - pos += take; - } - - // Finalize: walk stack folding right with current chunk's output. - // The very last merge gets ROOT flag. - if (stack_len == 0u) { - // Should not reach here in multi-chunk path; handle defensively. - uint s[16]; - uchar tb[64]; - for (int k = 0; k < 64; ++k) tb[k] = cs.block[k]; - uint flags = cs.flags | chunk_start_flag(cs) | CHUNK_END | ROOT; - compress(cs.cv, tb, cs.chunk_counter, cs.block_len, flags, s); - device uchar* dst = outputs + j.output_offset; - for (int i = 0; i < 8; ++i) { - uint w = s[i]; - dst[i * 4 + 0] = uchar(w & 0xFFu); - dst[i * 4 + 1] = uchar((w >> 8) & 0xFFu); - dst[i * 4 + 2] = uchar((w >> 16) & 0xFFu); - dst[i * 4 + 3] = uchar((w >> 24) & 0xFFu); - } - return; - } - - // Get current chunk's CV (non-root for now; root applied at last merge). - uint cur_cv[8]; - chunk_chaining_value(cs, cur_cv); - - // Walk stack from top down. The final merge replaces parent_cv_compute - // with a ROOT-flagged compress emitting the digest directly. - int idx = int(stack_len) - 1; - while (idx >= 0) { - uint left[8]; - for (int k = 0; k < 8; ++k) left[k] = stack[idx][k]; - if (idx == 0) { - // Root parent: ROOT-flagged compress, take first 8 words. 
- uchar pb[64]; - parent_block(left, cur_cv, pb); - uint s[16]; - uint kw[8]; - for (int k = 0; k < 8; ++k) kw[k] = BLAKE3_IV[k]; - compress(kw, pb, 0ul, 64u, PARENT | base_flags | ROOT, s); - device uchar* dst = outputs + j.output_offset; - for (int i = 0; i < 8; ++i) { - uint w = s[i]; - dst[i * 4 + 0] = uchar(w & 0xFFu); - dst[i * 4 + 1] = uchar((w >> 8) & 0xFFu); - dst[i * 4 + 2] = uchar((w >> 16) & 0xFFu); - dst[i * 4 + 3] = uchar((w >> 24) & 0xFFu); - } - return; - } - uint merged[8]; - parent_cv_compute(left, cur_cv, base_flags, merged); - for (int k = 0; k < 8; ++k) cur_cv[k] = merged[k]; - idx -= 1; - } -} diff --git a/blake3/gpu/metal/blake3_batch_driver.mm b/blake3/gpu/metal/blake3_batch_driver.mm deleted file mode 100644 index b024289..0000000 --- a/blake3/gpu/metal/blake3_batch_driver.mm +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched BLAKE3 (BLAKE3 spec). macOS / iOS only. -// Loads blake3_batch.metallib, dispatches `blake3_jobs` with one thread per -// input. Byte-equal to blake3/cpp/blake3.cpp::hash(). 
- -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include -#include - -namespace { - -struct Blake3JobGPU { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t _pad; -}; - -} // namespace - -extern "C" int blake3_batch_metal( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena, - const char* metallib_path) { - - if (n == 0) return 0; - if (!inputs_arena || !input_offsets || !input_lens || !outputs_arena || - !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"blake3_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - std::vector jobs(n); - for (size_t i = 0; i < n; ++i) { - jobs[i].input_offset = input_offsets[i]; - jobs[i].input_len = input_lens[i]; - jobs[i].output_offset = (uint32_t)(i * 32); - jobs[i]._pad = 0; - } - - id jobs_buf = [device newBufferWithBytes:jobs.data() - length:jobs.size() * sizeof(Blake3JobGPU) - options:MTLResourceStorageModeShared]; - id inputs_buf = [device newBufferWithBytes:inputs_arena - length:inputs_arena_len - options:MTLResourceStorageModeShared]; - id outputs_buf = [device newBufferWithLength:n * 32 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 
atIndex:0]; - [enc setBuffer:inputs_buf offset:0 atIndex:1]; - [enc setBuffer:outputs_buf offset:0 atIndex:2]; - [enc setBuffer:n_buf offset:0 atIndex:3]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs_arena, [outputs_buf contents], n * 32); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/blake3/gpu/metal/blake3_driver.h b/blake3/gpu/metal/blake3_driver.h deleted file mode 100644 index 5bc8d9a..0000000 --- a/blake3/gpu/metal/blake3_driver.h +++ /dev/null @@ -1,275 +0,0 @@ -// ============================================================================= -// Metal BLAKE3 - GPU Acceleration for BLAKE3 Hash -// ============================================================================= -// -// High-performance BLAKE3 hashing with GPU parallelization. -// Based on the official BLAKE3 specification. -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#pragma once -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ============================================================================= -// Context Management -// ============================================================================= - -typedef struct MetalBLAKE3Context MetalBLAKE3Context; - -MetalBLAKE3Context* metal_blake3_init(void); -void metal_blake3_destroy(MetalBLAKE3Context* ctx); -bool metal_blake3_available(void); - -// ============================================================================= -// Constants -// ============================================================================= - -#define BLAKE3_OUT_LEN 32 -#define BLAKE3_KEY_LEN 32 -#define BLAKE3_BLOCK_LEN 64 -#define BLAKE3_CHUNK_LEN 1024 - -// ============================================================================= -// Return Codes -// ============================================================================= - -typedef enum { - METAL_BLAKE3_SUCCESS = 0, - METAL_BLAKE3_ERROR_INIT = -1, - METAL_BLAKE3_ERROR_INVALID_INPUT = -2, - METAL_BLAKE3_ERROR_GPU_DISPATCH = -3, -} MetalBLAKE3Result; - -// ============================================================================= -// Simple Hash Interface -// ============================================================================= - -/** - * Hash data with BLAKE3. - * - * @param ctx Metal context - * @param output Output hash (32 bytes or more for XOF) - * @param out_len Output length (32 for standard, more for XOF) - * @param input Input data - * @param in_len Input length - */ -MetalBLAKE3Result metal_blake3_hash( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const uint8_t* input, - size_t in_len -); - -/** - * Batch hash - hash multiple inputs in parallel. - * Highly efficient for many small inputs. 
- * - * @param ctx Metal context - * @param outputs Output hashes (32 bytes each) - * @param inputs Array of input pointers - * @param in_lens Array of input lengths - * @param count Number of inputs - */ -MetalBLAKE3Result metal_blake3_batch_hash( - MetalBLAKE3Context* ctx, - uint8_t* outputs, - const uint8_t** inputs, - const size_t* in_lens, - uint32_t count -); - -/** - * Batch hash fixed-size inputs (optimized path). - * All inputs must be the same size. - * - * @param ctx Metal context - * @param outputs Output hashes - * @param inputs Contiguous input data - * @param in_len Size of each input - * @param count Number of inputs - */ -MetalBLAKE3Result metal_blake3_batch_hash_fixed( - MetalBLAKE3Context* ctx, - uint8_t* outputs, - const uint8_t* inputs, - size_t in_len, - uint32_t count -); - -// ============================================================================= -// Keyed Hash (MAC) -// ============================================================================= - -/** - * Keyed BLAKE3 hash (for MAC). - * - * @param ctx Metal context - * @param output Output MAC - * @param out_len Output length - * @param key 32-byte key - * @param input Input data - * @param in_len Input length - */ -MetalBLAKE3Result metal_blake3_keyed_hash( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const uint8_t key[BLAKE3_KEY_LEN], - const uint8_t* input, - size_t in_len -); - -/** - * Batch keyed hash with same key. - */ -MetalBLAKE3Result metal_blake3_batch_keyed_hash( - MetalBLAKE3Context* ctx, - uint8_t* outputs, - const uint8_t key[BLAKE3_KEY_LEN], - const uint8_t** inputs, - const size_t* in_lens, - uint32_t count -); - -// ============================================================================= -// Key Derivation (KDF) -// ============================================================================= - -/** - * Derive key using BLAKE3 KDF. 
- * - * @param ctx Metal context - * @param output Output key material - * @param out_len Output length - * @param context Context string - * @param context_len Context string length - * @param key_material Input key material - * @param km_len Key material length - */ -MetalBLAKE3Result metal_blake3_derive_key( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const char* context, - size_t context_len, - const uint8_t* key_material, - size_t km_len -); - -// ============================================================================= -// Streaming Interface (Incremental Hashing) -// ============================================================================= - -typedef struct MetalBLAKE3Hasher MetalBLAKE3Hasher; - -/** - * Create new BLAKE3 hasher. - */ -MetalBLAKE3Hasher* metal_blake3_hasher_new(MetalBLAKE3Context* ctx); - -/** - * Create keyed hasher. - */ -MetalBLAKE3Hasher* metal_blake3_hasher_new_keyed( - MetalBLAKE3Context* ctx, - const uint8_t key[BLAKE3_KEY_LEN] -); - -/** - * Create KDF hasher. - */ -MetalBLAKE3Hasher* metal_blake3_hasher_new_derive_key( - MetalBLAKE3Context* ctx, - const char* context, - size_t context_len -); - -/** - * Update hasher with more data. - */ -MetalBLAKE3Result metal_blake3_hasher_update( - MetalBLAKE3Hasher* hasher, - const uint8_t* input, - size_t in_len -); - -/** - * Finalize and get output. - */ -MetalBLAKE3Result metal_blake3_hasher_finalize( - MetalBLAKE3Hasher* hasher, - uint8_t* output, - size_t out_len -); - -/** - * Reset hasher for reuse. - */ -void metal_blake3_hasher_reset(MetalBLAKE3Hasher* hasher); - -/** - * Free hasher. - */ -void metal_blake3_hasher_free(MetalBLAKE3Hasher* hasher); - -// ============================================================================= -// Merkle Tree (BLAKE3-native parallelism) -// ============================================================================= - -/** - * Compute BLAKE3 Merkle tree root. - * Uses BLAKE3's native tree hashing mode. 
- * - * @param ctx Metal context - * @param root Output root (32 bytes) - * @param leaves Leaf data (each 32 bytes) - * @param count Number of leaves (power of 2) - */ -MetalBLAKE3Result metal_blake3_merkle_root( - MetalBLAKE3Context* ctx, - uint8_t root[BLAKE3_OUT_LEN], - const uint8_t* leaves, - uint32_t count -); - -/** - * Build full BLAKE3 Merkle tree. - */ -MetalBLAKE3Result metal_blake3_merkle_tree( - MetalBLAKE3Context* ctx, - uint8_t* nodes, - const uint8_t* leaves, - uint32_t count -); - -// ============================================================================= -// Large File Hashing (Streaming with GPU chunks) -// ============================================================================= - -/** - * Hash large file with GPU-accelerated chunk processing. - * Processes 1MB chunks in parallel on GPU. - * - * @param ctx Metal context - * @param output Output hash - * @param out_len Output length - * @param path File path - */ -MetalBLAKE3Result metal_blake3_hash_file( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const char* path -); - -#ifdef __cplusplus -} -#endif diff --git a/blake3/gpu/metal/blake3_driver.mm b/blake3/gpu/metal/blake3_driver.mm deleted file mode 100644 index 52118db..0000000 --- a/blake3/gpu/metal/blake3_driver.mm +++ /dev/null @@ -1,1083 +0,0 @@ -// ============================================================================= -// Metal BLAKE3 - GPU Acceleration for BLAKE3 Hash -// ============================================================================= -// -// High-performance BLAKE3 hashing with GPU parallelization. -// Based on the official BLAKE3 specification. -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#import -#import -#include "lux/crypto/metal_blake3.h" -#include -#include -#include -#include - -// ============================================================================= -// Metal Shader Source - BLAKE3 -// ============================================================================= - -static const char* BLAKE3_SHADER_SOURCE = R"( -#include -using namespace metal; - -// BLAKE3 constants -constant uint32_t BLAKE3_IV[8] = { - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, - 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 -}; - -// Domain separation flags -constant uint32_t CHUNK_START = 1; -constant uint32_t CHUNK_END = 2; -constant uint32_t PARENT = 4; -constant uint32_t ROOT = 8; -constant uint32_t KEYED_HASH = 16; -constant uint32_t DERIVE_KEY_CONTEXT = 32; -constant uint32_t DERIVE_KEY_MATERIAL = 64; - -// Message permutation schedule -constant uint8_t MSG_SCHEDULE[7][16] = { - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - {2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8}, - {3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1}, - {10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6}, - {12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4}, - {9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7}, - {11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13} -}; - -// Quarter round function -void g(thread uint32_t* state, int a, int b, int c, int d, uint32_t mx, uint32_t my) { - state[a] = state[a] + state[b] + mx; - state[d] = ((state[d] ^ state[a]) >> 16) | ((state[d] ^ state[a]) << 16); - state[c] = state[c] + state[d]; - state[b] = ((state[b] ^ state[c]) >> 12) | ((state[b] ^ state[c]) << 20); - state[a] = state[a] + state[b] + my; - state[d] = ((state[d] ^ state[a]) >> 8) | ((state[d] ^ state[a]) << 24); - state[c] = state[c] + state[d]; - state[b] = ((state[b] ^ state[c]) >> 7) | ((state[b] ^ state[c]) << 25); -} - -// BLAKE3 round function -void blake3_round(thread uint32_t* state, const 
thread uint32_t* msg, int round) { - constant uint8_t* s = MSG_SCHEDULE[round]; - - g(state, 0, 4, 8, 12, msg[s[0]], msg[s[1]]); - g(state, 1, 5, 9, 13, msg[s[2]], msg[s[3]]); - g(state, 2, 6, 10, 14, msg[s[4]], msg[s[5]]); - g(state, 3, 7, 11, 15, msg[s[6]], msg[s[7]]); - g(state, 0, 5, 10, 15, msg[s[8]], msg[s[9]]); - g(state, 1, 6, 11, 12, msg[s[10]], msg[s[11]]); - g(state, 2, 7, 8, 13, msg[s[12]], msg[s[13]]); - g(state, 3, 4, 9, 14, msg[s[14]], msg[s[15]]); -} - -// BLAKE3 compression function -void blake3_compress( - thread uint32_t* cv_out, - const thread uint32_t* cv_in, - const thread uint32_t* block, - uint64_t counter, - uint32_t block_len, - uint32_t flags -) { - uint32_t state[16]; - - // Initialize state - for (int i = 0; i < 8; i++) state[i] = cv_in[i]; - for (int i = 0; i < 4; i++) state[8 + i] = BLAKE3_IV[i]; - state[12] = (uint32_t)counter; - state[13] = (uint32_t)(counter >> 32); - state[14] = block_len; - state[15] = flags; - - // 7 rounds - for (int r = 0; r < 7; r++) { - blake3_round(state, block, r); - } - - // XOR output - for (int i = 0; i < 8; i++) { - cv_out[i] = state[i] ^ state[i + 8]; - } -} - -// Hash a single chunk (up to 1024 bytes) -void blake3_hash_chunk( - thread uint32_t* cv_out, - const device uint8_t* chunk, - uint32_t chunk_len, - const thread uint32_t* key, - uint64_t chunk_counter, - uint32_t flags -) { - uint32_t cv[8]; - for (int i = 0; i < 8; i++) cv[i] = key[i]; - - uint32_t blocks = (chunk_len + 63) / 64; - if (blocks == 0) blocks = 1; - - for (uint32_t b = 0; b < blocks; b++) { - uint32_t block[16] = {0}; - - // Load block (little-endian) - uint32_t offset = b * 64; - uint32_t remaining = (chunk_len > offset) ? 
chunk_len - offset : 0; - uint32_t to_copy = min(remaining, 64u); - - for (uint32_t i = 0; i < to_copy; i++) { - block[i / 4] |= ((uint32_t)chunk[offset + i]) << ((i % 4) * 8); - } - - uint32_t block_flags = flags; - if (b == 0) block_flags |= CHUNK_START; - if (b == blocks - 1) block_flags |= CHUNK_END; - - uint32_t block_len = min(to_copy, 64u); - blake3_compress(cv, cv, block, chunk_counter, block_len, block_flags); - } - - for (int i = 0; i < 8; i++) cv_out[i] = cv[i]; -} - -// Kernel: Hash fixed-size inputs in parallel -kernel void blake3_batch_hash_fixed( - device const uint8_t* inputs [[buffer(0)]], - device uint8_t* outputs [[buffer(1)]], - constant uint32_t& input_len [[buffer(2)]], - constant uint32_t* key [[buffer(3)]], - constant uint32_t& flags [[buffer(4)]], - uint32_t tid [[thread_position_in_grid]] -) { - uint32_t cv[8]; - - const device uint8_t* my_input = inputs + tid * input_len; - device uint8_t* my_output = outputs + tid * 32; - - // Hash the input as a single chunk - blake3_hash_chunk(cv, my_input, input_len, key, (uint64_t)tid, flags | ROOT); - - // Output (little-endian) - for (int i = 0; i < 8; i++) { - my_output[i * 4 + 0] = cv[i] & 0xFF; - my_output[i * 4 + 1] = (cv[i] >> 8) & 0xFF; - my_output[i * 4 + 2] = (cv[i] >> 16) & 0xFF; - my_output[i * 4 + 3] = (cv[i] >> 24) & 0xFF; - } -} - -// Kernel: Process chunks in parallel (for large files) -kernel void blake3_process_chunks( - device const uint8_t* data [[buffer(0)]], - device uint32_t* chunk_cvs [[buffer(1)]], - constant uint32_t* key [[buffer(2)]], - constant uint32_t& num_chunks [[buffer(3)]], - constant uint32_t& last_chunk_len [[buffer(4)]], - constant uint32_t& flags [[buffer(5)]], - uint32_t tid [[thread_position_in_grid]] -) { - if (tid >= num_chunks) return; - - uint32_t chunk_len = 1024; - if (tid == num_chunks - 1) { - chunk_len = last_chunk_len; - } - - const device uint8_t* chunk = data + (uint64_t)tid * 1024; - device uint32_t* cv_out = chunk_cvs + tid * 8; - - uint32_t 
cv[8]; - blake3_hash_chunk(cv, chunk, chunk_len, key, (uint64_t)tid, flags); - - for (int i = 0; i < 8; i++) { - cv_out[i] = cv[i]; - } -} - -// Kernel: Parent node hash (merge two children) -kernel void blake3_parent_hash( - device const uint32_t* children [[buffer(0)]], - device uint32_t* parents [[buffer(1)]], - constant uint32_t* key [[buffer(2)]], - constant uint32_t& flags [[buffer(3)]], - uint32_t tid [[thread_position_in_grid]] -) { - const device uint32_t* left = children + tid * 16; - const device uint32_t* right = children + tid * 16 + 8; - device uint32_t* parent = parents + tid * 8; - - uint32_t block[16]; - for (int i = 0; i < 8; i++) block[i] = left[i]; - for (int i = 0; i < 8; i++) block[8 + i] = right[i]; - - uint32_t cv[8]; - for (int i = 0; i < 8; i++) cv[i] = key[i]; - - blake3_compress(cv, cv, block, 0, 64, flags | PARENT); - - for (int i = 0; i < 8; i++) { - parent[i] = cv[i]; - } -} - -// Kernel: Merkle tree layer -kernel void blake3_merkle_layer( - device const uint8_t* children [[buffer(0)]], - device uint8_t* parents [[buffer(1)]], - constant uint32_t* key [[buffer(2)]], - uint32_t tid [[thread_position_in_grid]] -) { - const device uint8_t* left = children + tid * 64; - const device uint8_t* right = children + tid * 64 + 32; - device uint8_t* parent_out = parents + tid * 32; - - uint32_t block[16]; - - // Load children as little-endian words - for (int i = 0; i < 8; i++) { - block[i] = ((uint32_t)left[i*4]) | - ((uint32_t)left[i*4+1] << 8) | - ((uint32_t)left[i*4+2] << 16) | - ((uint32_t)left[i*4+3] << 24); - } - for (int i = 0; i < 8; i++) { - block[8+i] = ((uint32_t)right[i*4]) | - ((uint32_t)right[i*4+1] << 8) | - ((uint32_t)right[i*4+2] << 16) | - ((uint32_t)right[i*4+3] << 24); - } - - uint32_t cv[8]; - for (int i = 0; i < 8; i++) cv[i] = key[i]; - - blake3_compress(cv, cv, block, 0, 64, PARENT); - - for (int i = 0; i < 8; i++) { - parent_out[i*4] = cv[i] & 0xFF; - parent_out[i*4+1] = (cv[i] >> 8) & 0xFF; - parent_out[i*4+2] = (cv[i] 
>> 16) & 0xFF; - parent_out[i*4+3] = (cv[i] >> 24) & 0xFF; - } -} -)"; - -// ============================================================================= -// Context Structure -// ============================================================================= - -struct MetalBLAKE3Context { - id device; - id commandQueue; - id batchHashFixedPipeline; - id processChunksPipeline; - id parentHashPipeline; - id merkleLayerPipeline; - id ivBuffer; -}; - -struct MetalBLAKE3Hasher { - MetalBLAKE3Context* ctx; - uint32_t key[8]; - uint32_t flags; - std::vector buffer; - std::vector> chunk_cvs; - uint64_t chunk_counter; - uint32_t buf_len; -}; - -// ============================================================================= -// Context Management -// ============================================================================= - -extern "C" { - -MetalBLAKE3Context* metal_blake3_init(void) { - @autoreleasepool { - MetalBLAKE3Context* ctx = new MetalBLAKE3Context(); - - ctx->device = MTLCreateSystemDefaultDevice(); - if (!ctx->device) { - delete ctx; - return nullptr; - } - - ctx->commandQueue = [ctx->device newCommandQueue]; - if (!ctx->commandQueue) { - delete ctx; - return nullptr; - } - - // Compile shaders - NSError* error = nil; - NSString* source = [NSString stringWithUTF8String:BLAKE3_SHADER_SOURCE]; - id library = [ctx->device newLibraryWithSource:source options:nil error:&error]; - - if (!library) { - NSLog(@"BLAKE3 shader compilation failed: %@", error); - delete ctx; - return nullptr; - } - - // Create pipelines - id batchHashFixedFunc = [library newFunctionWithName:@"blake3_batch_hash_fixed"]; - if (batchHashFixedFunc) { - ctx->batchHashFixedPipeline = [ctx->device newComputePipelineStateWithFunction:batchHashFixedFunc error:&error]; - } - - id processChunksFunc = [library newFunctionWithName:@"blake3_process_chunks"]; - if (processChunksFunc) { - ctx->processChunksPipeline = [ctx->device newComputePipelineStateWithFunction:processChunksFunc error:&error]; - } - 
- id parentHashFunc = [library newFunctionWithName:@"blake3_parent_hash"]; - if (parentHashFunc) { - ctx->parentHashPipeline = [ctx->device newComputePipelineStateWithFunction:parentHashFunc error:&error]; - } - - id merkleLayerFunc = [library newFunctionWithName:@"blake3_merkle_layer"]; - if (merkleLayerFunc) { - ctx->merkleLayerPipeline = [ctx->device newComputePipelineStateWithFunction:merkleLayerFunc error:&error]; - } - - // BLAKE3 IV - uint32_t iv[8] = { - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, - 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 - }; - ctx->ivBuffer = [ctx->device newBufferWithBytes:iv length:32 options:MTLResourceStorageModeShared]; - - return ctx; - } -} - -void metal_blake3_destroy(MetalBLAKE3Context* ctx) { - if (ctx) { - delete ctx; - } -} - -bool metal_blake3_available(void) { - id device = MTLCreateSystemDefaultDevice(); - return device != nil; -} - -// ============================================================================= -// Helper: Software BLAKE3 compression for small inputs -// ============================================================================= - -static const uint32_t BLAKE3_IV[8] = { - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, - 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 -}; - -static const uint8_t MSG_SCHEDULE[7][16] = { - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - {2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8}, - {3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1}, - {10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6}, - {12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4}, - {9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7}, - {11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13} -}; - -static inline uint32_t rotr32(uint32_t x, int n) { - return (x >> n) | (x << (32 - n)); -} - -static void g_cpu(uint32_t* state, int a, int b, int c, int d, uint32_t mx, uint32_t my) { - state[a] = state[a] + state[b] + mx; - state[d] = rotr32(state[d] ^ state[a], 16); - 
state[c] = state[c] + state[d]; - state[b] = rotr32(state[b] ^ state[c], 12); - state[a] = state[a] + state[b] + my; - state[d] = rotr32(state[d] ^ state[a], 8); - state[c] = state[c] + state[d]; - state[b] = rotr32(state[b] ^ state[c], 7); -} - -static void blake3_round_cpu(uint32_t* state, const uint32_t* msg, int round) { - const uint8_t* s = MSG_SCHEDULE[round]; - g_cpu(state, 0, 4, 8, 12, msg[s[0]], msg[s[1]]); - g_cpu(state, 1, 5, 9, 13, msg[s[2]], msg[s[3]]); - g_cpu(state, 2, 6, 10, 14, msg[s[4]], msg[s[5]]); - g_cpu(state, 3, 7, 11, 15, msg[s[6]], msg[s[7]]); - g_cpu(state, 0, 5, 10, 15, msg[s[8]], msg[s[9]]); - g_cpu(state, 1, 6, 11, 12, msg[s[10]], msg[s[11]]); - g_cpu(state, 2, 7, 8, 13, msg[s[12]], msg[s[13]]); - g_cpu(state, 3, 4, 9, 14, msg[s[14]], msg[s[15]]); -} - -static void blake3_compress_cpu( - uint32_t* cv_out, - const uint32_t* cv_in, - const uint32_t* block, - uint64_t counter, - uint32_t block_len, - uint32_t flags -) { - uint32_t state[16]; - - for (int i = 0; i < 8; i++) state[i] = cv_in[i]; - for (int i = 0; i < 4; i++) state[8 + i] = BLAKE3_IV[i]; - state[12] = (uint32_t)counter; - state[13] = (uint32_t)(counter >> 32); - state[14] = block_len; - state[15] = flags; - - for (int r = 0; r < 7; r++) { - blake3_round_cpu(state, block, r); - } - - for (int i = 0; i < 8; i++) { - cv_out[i] = state[i] ^ state[i + 8]; - } -} - -static void blake3_hash_chunk_cpu( - uint32_t* cv_out, - const uint8_t* chunk, - uint32_t chunk_len, - const uint32_t* key, - uint64_t chunk_counter, - uint32_t flags -) { - uint32_t cv[8]; - for (int i = 0; i < 8; i++) cv[i] = key[i]; - - uint32_t blocks = (chunk_len + 63) / 64; - if (blocks == 0) blocks = 1; - - for (uint32_t b = 0; b < blocks; b++) { - uint32_t block[16] = {0}; - - uint32_t offset = b * 64; - uint32_t remaining = (chunk_len > offset) ? 
chunk_len - offset : 0; - uint32_t to_copy = std::min(remaining, 64u); - - for (uint32_t i = 0; i < to_copy; i++) { - block[i / 4] |= ((uint32_t)chunk[offset + i]) << ((i % 4) * 8); - } - - uint32_t block_flags = flags; - if (b == 0) block_flags |= 1; // CHUNK_START - if (b == blocks - 1) block_flags |= 2; // CHUNK_END - - blake3_compress_cpu(cv, cv, block, chunk_counter, std::min(to_copy, 64u), block_flags); - } - - for (int i = 0; i < 8; i++) cv_out[i] = cv[i]; -} - -// ============================================================================= -// Simple Hash Interface -// ============================================================================= - -MetalBLAKE3Result metal_blake3_hash( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const uint8_t* input, - size_t in_len -) { - if (!ctx || !output || out_len < 32) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - // For small inputs, use CPU - if (in_len <= 1024) { - uint32_t cv[8]; - blake3_hash_chunk_cpu(cv, input, (uint32_t)in_len, BLAKE3_IV, 0, 8); // ROOT flag - - for (int i = 0; i < 8; i++) { - output[i * 4 + 0] = cv[i] & 0xFF; - output[i * 4 + 1] = (cv[i] >> 8) & 0xFF; - output[i * 4 + 2] = (cv[i] >> 16) & 0xFF; - output[i * 4 + 3] = (cv[i] >> 24) & 0xFF; - } - - return METAL_BLAKE3_SUCCESS; - } - - // For larger inputs, use GPU - @autoreleasepool { - uint32_t num_chunks = (uint32_t)((in_len + 1023) / 1024); - uint32_t last_chunk_len = (uint32_t)(in_len % 1024); - if (last_chunk_len == 0) last_chunk_len = 1024; - - id inputBuffer = [ctx->device newBufferWithBytes:input - length:in_len - options:MTLResourceStorageModeShared]; - id chunkCVsBuffer = [ctx->device newBufferWithLength:num_chunks * 32 - options:MTLResourceStorageModeShared]; - - uint32_t flags = 0; - - // Process chunks - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->processChunksPipeline]; - [encoder 
setBuffer:inputBuffer offset:0 atIndex:0]; - [encoder setBuffer:chunkCVsBuffer offset:0 atIndex:1]; - [encoder setBuffer:ctx->ivBuffer offset:0 atIndex:2]; - [encoder setBytes:&num_chunks length:sizeof(uint32_t) atIndex:3]; - [encoder setBytes:&last_chunk_len length:sizeof(uint32_t) atIndex:4]; - [encoder setBytes:&flags length:sizeof(uint32_t) atIndex:5]; - - [encoder dispatchThreads:MTLSizeMake(num_chunks, 1, 1) - threadsPerThreadgroup:MTLSizeMake(std::min(num_chunks, (uint32_t)256), 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - // Merge chunks into tree - std::vector cvs(num_chunks * 8); - memcpy(cvs.data(), [chunkCVsBuffer contents], num_chunks * 32); - - while (num_chunks > 1) { - uint32_t pairs = num_chunks / 2; - std::vector parents(pairs * 8); - - for (uint32_t i = 0; i < pairs; i++) { - uint32_t block[16]; - for (int j = 0; j < 8; j++) block[j] = cvs[i * 16 + j]; - for (int j = 0; j < 8; j++) block[8 + j] = cvs[i * 16 + 8 + j]; - - uint32_t parent_flags = 4; // PARENT - if (pairs == 1 && (num_chunks % 2 == 0)) { - parent_flags |= 8; // ROOT - } - - uint32_t cv[8]; - for (int j = 0; j < 8; j++) cv[j] = BLAKE3_IV[j]; - blake3_compress_cpu(cv, cv, block, 0, 64, parent_flags); - - for (int j = 0; j < 8; j++) parents[i * 8 + j] = cv[j]; - } - - // Handle odd chunk - if (num_chunks % 2 == 1) { - for (int j = 0; j < 8; j++) { - parents.push_back(cvs[(num_chunks - 1) * 8 + j]); - } - num_chunks = pairs + 1; - } else { - num_chunks = pairs; - } - - cvs = std::move(parents); - } - - // Output root - for (int i = 0; i < 8; i++) { - output[i * 4 + 0] = cvs[i] & 0xFF; - output[i * 4 + 1] = (cvs[i] >> 8) & 0xFF; - output[i * 4 + 2] = (cvs[i] >> 16) & 0xFF; - output[i * 4 + 3] = (cvs[i] >> 24) & 0xFF; - } - - return METAL_BLAKE3_SUCCESS; - } -} - -MetalBLAKE3Result metal_blake3_batch_hash( - MetalBLAKE3Context* ctx, - uint8_t* outputs, - const uint8_t** inputs, - const size_t* in_lens, - uint32_t count -) { - if 
(!ctx || !outputs || !inputs || !in_lens || count == 0) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - // Process each input (could optimize by grouping by size) - for (uint32_t i = 0; i < count; i++) { - MetalBLAKE3Result result = metal_blake3_hash(ctx, outputs + i * 32, 32, inputs[i], in_lens[i]); - if (result != METAL_BLAKE3_SUCCESS) { - return result; - } - } - - return METAL_BLAKE3_SUCCESS; -} - -MetalBLAKE3Result metal_blake3_batch_hash_fixed( - MetalBLAKE3Context* ctx, - uint8_t* outputs, - const uint8_t* inputs, - size_t in_len, - uint32_t count -) { - if (!ctx || !outputs || !inputs || count == 0) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - @autoreleasepool { - id inputBuffer = [ctx->device newBufferWithBytes:inputs - length:in_len * count - options:MTLResourceStorageModeShared]; - id outputBuffer = [ctx->device newBufferWithLength:count * 32 - options:MTLResourceStorageModeShared]; - - uint32_t input_len_u32 = (uint32_t)in_len; - uint32_t flags = 0; - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->batchHashFixedPipeline]; - [encoder setBuffer:inputBuffer offset:0 atIndex:0]; - [encoder setBuffer:outputBuffer offset:0 atIndex:1]; - [encoder setBytes:&input_len_u32 length:sizeof(uint32_t) atIndex:2]; - [encoder setBuffer:ctx->ivBuffer offset:0 atIndex:3]; - [encoder setBytes:&flags length:sizeof(uint32_t) atIndex:4]; - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(std::min(count, (uint32_t)256), 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(outputs, [outputBuffer contents], count * 32); - - return METAL_BLAKE3_SUCCESS; - } -} - -// ============================================================================= -// Keyed Hash -// ============================================================================= - -MetalBLAKE3Result 
metal_blake3_keyed_hash( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const uint8_t key[BLAKE3_KEY_LEN], - const uint8_t* input, - size_t in_len -) { - if (!ctx || !output || !key || out_len < 32) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - // Convert key bytes to words - uint32_t key_words[8]; - for (int i = 0; i < 8; i++) { - key_words[i] = ((uint32_t)key[i*4]) | - ((uint32_t)key[i*4+1] << 8) | - ((uint32_t)key[i*4+2] << 16) | - ((uint32_t)key[i*4+3] << 24); - } - - uint32_t cv[8]; - blake3_hash_chunk_cpu(cv, input, (uint32_t)in_len, key_words, 0, 16 | 8); // KEYED_HASH | ROOT - - for (int i = 0; i < 8; i++) { - output[i * 4 + 0] = cv[i] & 0xFF; - output[i * 4 + 1] = (cv[i] >> 8) & 0xFF; - output[i * 4 + 2] = (cv[i] >> 16) & 0xFF; - output[i * 4 + 3] = (cv[i] >> 24) & 0xFF; - } - - return METAL_BLAKE3_SUCCESS; -} - -MetalBLAKE3Result metal_blake3_batch_keyed_hash( - MetalBLAKE3Context* ctx, - uint8_t* outputs, - const uint8_t key[BLAKE3_KEY_LEN], - const uint8_t** inputs, - const size_t* in_lens, - uint32_t count -) { - for (uint32_t i = 0; i < count; i++) { - MetalBLAKE3Result result = metal_blake3_keyed_hash(ctx, outputs + i * 32, 32, key, inputs[i], in_lens[i]); - if (result != METAL_BLAKE3_SUCCESS) { - return result; - } - } - return METAL_BLAKE3_SUCCESS; -} - -// ============================================================================= -// Key Derivation -// ============================================================================= - -MetalBLAKE3Result metal_blake3_derive_key( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const char* context, - size_t context_len, - const uint8_t* key_material, - size_t km_len -) { - if (!ctx || !output || !context || !key_material || out_len < 32) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - // First: hash context string with DERIVE_KEY_CONTEXT flag - uint32_t context_key[8]; - blake3_hash_chunk_cpu(context_key, (const uint8_t*)context, (uint32_t)context_len, 
BLAKE3_IV, 0, 32 | 8); // DERIVE_KEY_CONTEXT | ROOT - - // Second: hash key material with derived context key and DERIVE_KEY_MATERIAL flag - uint32_t cv[8]; - blake3_hash_chunk_cpu(cv, key_material, (uint32_t)km_len, context_key, 0, 64 | 8); // DERIVE_KEY_MATERIAL | ROOT - - for (int i = 0; i < 8; i++) { - output[i * 4 + 0] = cv[i] & 0xFF; - output[i * 4 + 1] = (cv[i] >> 8) & 0xFF; - output[i * 4 + 2] = (cv[i] >> 16) & 0xFF; - output[i * 4 + 3] = (cv[i] >> 24) & 0xFF; - } - - return METAL_BLAKE3_SUCCESS; -} - -// ============================================================================= -// Streaming Interface -// ============================================================================= - -MetalBLAKE3Hasher* metal_blake3_hasher_new(MetalBLAKE3Context* ctx) { - if (!ctx) return nullptr; - - MetalBLAKE3Hasher* hasher = new MetalBLAKE3Hasher(); - hasher->ctx = ctx; - for (int i = 0; i < 8; i++) hasher->key[i] = BLAKE3_IV[i]; - hasher->flags = 0; - hasher->buffer.reserve(1024); - hasher->chunk_counter = 0; - hasher->buf_len = 0; - - return hasher; -} - -MetalBLAKE3Hasher* metal_blake3_hasher_new_keyed( - MetalBLAKE3Context* ctx, - const uint8_t key[BLAKE3_KEY_LEN] -) { - if (!ctx || !key) return nullptr; - - MetalBLAKE3Hasher* hasher = new MetalBLAKE3Hasher(); - hasher->ctx = ctx; - - for (int i = 0; i < 8; i++) { - hasher->key[i] = ((uint32_t)key[i*4]) | - ((uint32_t)key[i*4+1] << 8) | - ((uint32_t)key[i*4+2] << 16) | - ((uint32_t)key[i*4+3] << 24); - } - - hasher->flags = 16; // KEYED_HASH - hasher->buffer.reserve(1024); - hasher->chunk_counter = 0; - hasher->buf_len = 0; - - return hasher; -} - -MetalBLAKE3Hasher* metal_blake3_hasher_new_derive_key( - MetalBLAKE3Context* ctx, - const char* context, - size_t context_len -) { - if (!ctx || !context) return nullptr; - - MetalBLAKE3Hasher* hasher = new MetalBLAKE3Hasher(); - hasher->ctx = ctx; - - // Derive key from context - blake3_hash_chunk_cpu(hasher->key, (const uint8_t*)context, (uint32_t)context_len, 
BLAKE3_IV, 0, 32 | 8); - - hasher->flags = 64; // DERIVE_KEY_MATERIAL - hasher->buffer.reserve(1024); - hasher->chunk_counter = 0; - hasher->buf_len = 0; - - return hasher; -} - -MetalBLAKE3Result metal_blake3_hasher_update( - MetalBLAKE3Hasher* hasher, - const uint8_t* input, - size_t in_len -) { - if (!hasher || (!input && in_len > 0)) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - size_t consumed = 0; - while (consumed < in_len) { - size_t space = 1024 - hasher->buf_len; - size_t to_copy = std::min(space, in_len - consumed); - - hasher->buffer.resize(hasher->buf_len + to_copy); - memcpy(hasher->buffer.data() + hasher->buf_len, input + consumed, to_copy); - hasher->buf_len += to_copy; - consumed += to_copy; - - if (hasher->buf_len == 1024) { - // Complete chunk - hash it - std::array cv; - blake3_hash_chunk_cpu(cv.data(), hasher->buffer.data(), 1024, hasher->key, hasher->chunk_counter, hasher->flags); - hasher->chunk_cvs.push_back(cv); - hasher->chunk_counter++; - hasher->buf_len = 0; - hasher->buffer.clear(); - } - } - - return METAL_BLAKE3_SUCCESS; -} - -MetalBLAKE3Result metal_blake3_hasher_finalize( - MetalBLAKE3Hasher* hasher, - uint8_t* output, - size_t out_len -) { - if (!hasher || !output || out_len < 32) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - // Finalize last chunk - std::array last_cv; - if (hasher->buf_len > 0 || hasher->chunk_cvs.empty()) { - uint32_t flags = hasher->flags; - if (hasher->chunk_cvs.empty()) flags |= 8; // ROOT if only chunk - blake3_hash_chunk_cpu(last_cv.data(), hasher->buffer.data(), hasher->buf_len, hasher->key, hasher->chunk_counter, flags); - hasher->chunk_cvs.push_back(last_cv); - } - - // Merge tree - std::vector> cvs = hasher->chunk_cvs; - - while (cvs.size() > 1) { - std::vector> parents; - - for (size_t i = 0; i + 1 < cvs.size(); i += 2) { - uint32_t block[16]; - for (int j = 0; j < 8; j++) block[j] = cvs[i][j]; - for (int j = 0; j < 8; j++) block[8 + j] = cvs[i + 1][j]; - - uint32_t parent_flags = 4; // 
PARENT - if (parents.empty() && i + 2 >= cvs.size()) { - parent_flags |= 8; // ROOT - } - - std::array parent; - for (int j = 0; j < 8; j++) parent[j] = hasher->key[j]; - blake3_compress_cpu(parent.data(), parent.data(), block, 0, 64, parent_flags); - parents.push_back(parent); - } - - if (cvs.size() % 2 == 1) { - parents.push_back(cvs.back()); - } - - cvs = std::move(parents); - } - - // Output - for (int i = 0; i < 8; i++) { - output[i * 4 + 0] = cvs[0][i] & 0xFF; - output[i * 4 + 1] = (cvs[0][i] >> 8) & 0xFF; - output[i * 4 + 2] = (cvs[0][i] >> 16) & 0xFF; - output[i * 4 + 3] = (cvs[0][i] >> 24) & 0xFF; - } - - return METAL_BLAKE3_SUCCESS; -} - -void metal_blake3_hasher_reset(MetalBLAKE3Hasher* hasher) { - if (hasher) { - hasher->buffer.clear(); - hasher->chunk_cvs.clear(); - hasher->chunk_counter = 0; - hasher->buf_len = 0; - } -} - -void metal_blake3_hasher_free(MetalBLAKE3Hasher* hasher) { - if (hasher) { - delete hasher; - } -} - -// ============================================================================= -// Merkle Tree -// ============================================================================= - -MetalBLAKE3Result metal_blake3_merkle_root( - MetalBLAKE3Context* ctx, - uint8_t root[BLAKE3_OUT_LEN], - const uint8_t* leaves, - uint32_t count -) { - if (!ctx || !root || !leaves || count == 0) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - if ((count & (count - 1)) != 0) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; // Must be power of 2 - } - - @autoreleasepool { - id nodesBuffer = [ctx->device newBufferWithBytes:leaves - length:count * 32 - options:MTLResourceStorageModeShared]; - - uint32_t level_size = count; - - while (level_size > 1) { - id parentsBuffer = [ctx->device newBufferWithLength:(level_size / 2) * 32 - options:MTLResourceStorageModeShared]; - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->merkleLayerPipeline]; - [encoder 
setBuffer:nodesBuffer offset:0 atIndex:0]; - [encoder setBuffer:parentsBuffer offset:0 atIndex:1]; - [encoder setBuffer:ctx->ivBuffer offset:0 atIndex:2]; - - uint32_t pairs = level_size / 2; - [encoder dispatchThreads:MTLSizeMake(pairs, 1, 1) - threadsPerThreadgroup:MTLSizeMake(std::min(pairs, (uint32_t)256), 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - nodesBuffer = parentsBuffer; - level_size /= 2; - } - - memcpy(root, [nodesBuffer contents], 32); - - return METAL_BLAKE3_SUCCESS; - } -} - -MetalBLAKE3Result metal_blake3_merkle_tree( - MetalBLAKE3Context* ctx, - uint8_t* nodes, - const uint8_t* leaves, - uint32_t count -) { - if (!ctx || !nodes || !leaves || count == 0) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - if ((count & (count - 1)) != 0) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - // Copy leaves to output - memcpy(nodes, leaves, count * 32); - - uint32_t offset = 0; - uint32_t level_size = count; - - while (level_size > 1) { - uint32_t pairs = level_size / 2; - uint8_t* children = nodes + offset * 32; - uint8_t* parents = nodes + (offset + level_size) * 32; - - for (uint32_t i = 0; i < pairs; i++) { - uint32_t block[16]; - - const uint8_t* left = children + i * 64; - const uint8_t* right = children + i * 64 + 32; - - for (int j = 0; j < 8; j++) { - block[j] = ((uint32_t)left[j*4]) | - ((uint32_t)left[j*4+1] << 8) | - ((uint32_t)left[j*4+2] << 16) | - ((uint32_t)left[j*4+3] << 24); - } - for (int j = 0; j < 8; j++) { - block[8+j] = ((uint32_t)right[j*4]) | - ((uint32_t)right[j*4+1] << 8) | - ((uint32_t)right[j*4+2] << 16) | - ((uint32_t)right[j*4+3] << 24); - } - - uint32_t cv[8]; - for (int j = 0; j < 8; j++) cv[j] = BLAKE3_IV[j]; - blake3_compress_cpu(cv, cv, block, 0, 64, 4); // PARENT - - uint8_t* parent = parents + i * 32; - for (int j = 0; j < 8; j++) { - parent[j*4] = cv[j] & 0xFF; - parent[j*4+1] = (cv[j] >> 8) & 0xFF; - parent[j*4+2] = (cv[j] >> 16) & 0xFF; - parent[j*4+3] 
= (cv[j] >> 24) & 0xFF; - } - } - - offset += level_size; - level_size /= 2; - } - - return METAL_BLAKE3_SUCCESS; -} - -// ============================================================================= -// File Hashing -// ============================================================================= - -MetalBLAKE3Result metal_blake3_hash_file( - MetalBLAKE3Context* ctx, - uint8_t* output, - size_t out_len, - const char* path -) { - if (!ctx || !output || !path || out_len < 32) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - std::ifstream file(path, std::ios::binary | std::ios::ate); - if (!file.is_open()) { - return METAL_BLAKE3_ERROR_INVALID_INPUT; - } - - size_t file_size = file.tellg(); - file.seekg(0); - - if (file_size <= 1024 * 1024) { - // Small file - read all at once - std::vector data(file_size); - file.read((char*)data.data(), file_size); - return metal_blake3_hash(ctx, output, out_len, data.data(), file_size); - } - - // Large file - stream with hasher - MetalBLAKE3Hasher* hasher = metal_blake3_hasher_new(ctx); - if (!hasher) { - return METAL_BLAKE3_ERROR_INIT; - } - - std::vector buffer(1024 * 1024); // 1MB chunks - - while (file) { - file.read((char*)buffer.data(), buffer.size()); - size_t bytes_read = file.gcount(); - if (bytes_read > 0) { - metal_blake3_hasher_update(hasher, buffer.data(), bytes_read); - } - } - - MetalBLAKE3Result result = metal_blake3_hasher_finalize(hasher, output, out_len); - metal_blake3_hasher_free(hasher); - - return result; -} - -} // extern "C" diff --git a/blake3/gpu/wgsl/blake3.wgsl b/blake3/gpu/wgsl/blake3.wgsl deleted file mode 100644 index cfce3d3..0000000 --- a/blake3/gpu/wgsl/blake3.wgsl +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLAKE3 hash compute shader in WGSL. -// -// One thread per hash. 
Each thread reads its input descriptor, processes -// chunks through the BLAKE3 compression function (7 rounds), and writes -// a 32-byte digest. -// -// BLAKE3 is tree-based and parallelizable by design. - -struct HashInput { - offset: u32, - length: u32, -} - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var data: array; -@group(0) @binding(2) var outputs: array; -@group(0) @binding(3) var params: vec4; // params.x = num_inputs - -const IV = array( - 0x6A09E667u, 0xBB67AE85u, 0x3C6EF372u, 0xA54FF53Au, - 0x510E527Fu, 0x9B05688Cu, 0x1F83D9ABu, 0x5BE0CD19u -); - -const CHUNK_START: u32 = 1u; -const CHUNK_END: u32 = 2u; -const ROOT: u32 = 8u; - -const MSG_PERM = array( - 2u, 6u, 3u, 10u, 7u, 0u, 4u, 13u, 1u, 11u, 12u, 5u, 9u, 14u, 15u, 8u -); - -fn rotr32(x: u32, n: u32) -> u32 { - return (x >> n) | (x << (32u - n)); -} - -fn blake3_g(state: ptr>, a: u32, b: u32, c: u32, d: u32, mx: u32, my: u32) { - (*state)[a] = (*state)[a] + (*state)[b] + mx; - (*state)[d] = rotr32((*state)[d] ^ (*state)[a], 16u); - (*state)[c] = (*state)[c] + (*state)[d]; - (*state)[b] = rotr32((*state)[b] ^ (*state)[c], 12u); - (*state)[a] = (*state)[a] + (*state)[b] + my; - (*state)[d] = rotr32((*state)[d] ^ (*state)[a], 8u); - (*state)[c] = (*state)[c] + (*state)[d]; - (*state)[b] = rotr32((*state)[b] ^ (*state)[c], 7u); -} - -fn blake3_round(state: ptr>, m: ptr>) { - blake3_g(state, 0u, 4u, 8u, 12u, (*m)[0u], (*m)[1u]); - blake3_g(state, 1u, 5u, 9u, 13u, (*m)[2u], (*m)[3u]); - blake3_g(state, 2u, 6u, 10u, 14u, (*m)[4u], (*m)[5u]); - blake3_g(state, 3u, 7u, 11u, 15u, (*m)[6u], (*m)[7u]); - blake3_g(state, 0u, 5u, 10u, 15u, (*m)[8u], (*m)[9u]); - blake3_g(state, 1u, 6u, 11u, 12u, (*m)[10u], (*m)[11u]); - blake3_g(state, 2u, 7u, 8u, 13u, (*m)[12u], (*m)[13u]); - blake3_g(state, 3u, 4u, 9u, 14u, (*m)[14u], (*m)[15u]); -} - -fn read_byte(byte_offset: u32) -> u32 { - let word_idx = byte_offset >> 2u; - let byte_pos = byte_offset & 3u; - return (data[word_idx] >> (byte_pos * 8u)) 
& 0xFFu; -} - -fn blake3_compress(cv: ptr>, - block: ptr>, - counter: u32, - block_len: u32, - flags: u32, - out: ptr>) { - var state: array; - state[0u] = (*cv)[0u]; state[1u] = (*cv)[1u]; state[2u] = (*cv)[2u]; state[3u] = (*cv)[3u]; - state[4u] = (*cv)[4u]; state[5u] = (*cv)[5u]; state[6u] = (*cv)[6u]; state[7u] = (*cv)[7u]; - state[8u] = IV[0]; state[9u] = IV[1]; state[10u] = IV[2]; state[11u] = IV[3]; - state[12u] = counter; - state[13u] = 0u; - state[14u] = block_len; - state[15u] = flags; - - var m: array; - for (var i = 0u; i < 16u; i = i + 1u) { - m[i] = (*block)[i]; - } - - for (var round = 0u; round < 7u; round = round + 1u) { - blake3_round(&state, &m); - var tmp: array; - for (var i = 0u; i < 16u; i = i + 1u) { - tmp[i] = m[MSG_PERM[i]]; - } - m = tmp; - } - - for (var i = 0u; i < 8u; i = i + 1u) { - (*out)[i] = state[i] ^ state[i + 8u]; - } -} - -@compute @workgroup_size(64) -fn blake3_hash_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - let inp = inputs[tid]; - let offset = inp.offset; - let len = inp.length; - - var cv: array; - for (var i = 0u; i < 8u; i = i + 1u) { cv[i] = IV[i]; } - - var remaining = len; - var pos = 0u; - var block_idx = 0u; - - loop { - if (remaining == 0u && block_idx > 0u) { break; } - - var block: array; - for (var i = 0u; i < 16u; i = i + 1u) { block[i] = 0u; } - - var to_copy = remaining; - if (to_copy > 64u) { to_copy = 64u; } - - // Load block bytes into u32 words (little-endian) - for (var i = 0u; i < to_copy; i = i + 1u) { - let byte_val = read_byte(offset + pos + i); - let word_idx = i >> 2u; - let byte_pos = i & 3u; - block[word_idx] = block[word_idx] | (byte_val << (byte_pos * 8u)); - } - - var flags = 0u; - if (block_idx == 0u) { flags = flags | CHUNK_START; } - let is_last = (remaining <= 64u); - if (is_last) { flags = flags | CHUNK_END | ROOT; } - - var out: array; - blake3_compress(&cv, &block, 0u, to_copy, flags, &out); - - if (is_last) { - let 
out_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { - outputs[out_base + i] = out[i]; - } - return; - } - - cv = out; - pos = pos + to_copy; - remaining = remaining - to_copy; - block_idx = block_idx + 1u; - } -} diff --git a/bls/gpu/cuda/bls.cu b/bls/gpu/cuda/bls.cu deleted file mode 100644 index 6598489..0000000 --- a/bls/gpu/cuda/bls.cu +++ /dev/null @@ -1,667 +0,0 @@ -// BLS12-381 batch signature verification — CUDA implementation -// Matches bls12_381.metal output byte-for-byte -// 384-bit Montgomery field arithmetic for G1 point operations - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -static inline void __syncthreads() {} -template static inline T atomicAdd(T* addr, T val) { T old = *addr; *addr += val; return old; } -#endif - -// ============================================================================= -// 384-bit unsigned integer (6 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint384 { - uint64_t limbs[6]; -}; - -// ============================================================================= -// BLS12-381 constants -// ============================================================================= - -__device__ static const uint384 BLS_P = {{ - 0xB9FEFFFFFFFFAAABULL, - 0x1EABFFFEB153FFFFULL, - 0x6730D2A0F6B0F624ULL, - 0x64774B84F38512BFULL, - 0x4B1BA7B6434BACD7ULL, - 0x1A0111EA397FE69AULL -}}; - -__device__ static const uint384 BLS_R2 = {{ - 0xF4DF1F341C341746ULL, - 0x0A76E6A609D104F1ULL, - 0x8DE5476C4C95B6D5ULL, - 0x67EB88A9939D83C0ULL, - 0x9A793E85B519952DULL, - 0x11988FE592CAE3AAULL -}}; - -__device__ static const uint384 BLS_R = {{ - 0x760900000002FFCDULL, - 0xEBF4000BC40C0002ULL, - 0x5F48985753C758BAULL, - 0x77CE585370525745ULL, - 0x5C071A97A256EC6DULL, - 0x15F65EC3FA80E493ULL -}}; - -__device__ static const uint64_t BLS_P_INV = 
0x89F3FFFCFFFCFFFDULL; - -__device__ static const uint384 G1_X = {{ - 0x5CB38790FD666E19ULL, - 0xF85DDE8F09FE5D5CULL, - 0x2C0B0A5CAFB74CD8ULL, - 0x95F7B3B14AAE717DULL, - 0x70E02F1AB69D14E3ULL, - 0x03C26A6D58B32048ULL -}}; - -__device__ static const uint384 G1_Y = {{ - 0xA402B931448DC5C8ULL, - 0xFBD6AA1ADEAD1CF6ULL, - 0x5B9D93D1BA1F5B57ULL, - 0x6DC08AFF5B3AF6DDULL, - 0xA4CF5B5C1B6CE90CULL, - 0x13F48FFF25F51018ULL -}}; - -__device__ static const uint384 ZERO384 = {{0, 0, 0, 0, 0, 0}}; - -// ============================================================================= -// 384-bit arithmetic -// ============================================================================= - -__device__ static int u384_cmp(uint384 a, uint384 b) { - for (int i = 5; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u384_is_zero(uint384 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | - a.limbs[3] | a.limbs[4] | a.limbs[5]) == 0; -} - -__device__ static uint384 u384_add(uint384 a, uint384 b, uint64_t& carry) { - uint384 r; - uint64_t c = 0; - for (int i = 0; i < 6; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -__device__ static uint384 u384_sub(uint384 a, uint384 b, uint64_t& borrow) { - uint384 r; - uint64_t bw = 0; - for (int i = 0; i < 6; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 
1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// ============================================================================= -// Montgomery arithmetic over Fp (384-bit) -// Uses __int128 for 64x64->128 multiply on CUDA -// ============================================================================= - -__device__ static uint384 mont_reduce_384(uint64_t t[12]) { - uint64_t a[13]; - for (int i = 0; i < 12; i++) a[i] = t[i]; - a[12] = 0; - - for (int i = 0; i < 6; i++) { - uint64_t u = a[i] * BLS_P_INV; - - uint64_t carry = 0; - for (int j = 0; j < 6; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)u * BLS_P.limbs[j]; - unsigned __int128 acc = prod + carry + a[i + j]; - a[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t u_lo = u & 0xFFFFFFFFULL; - uint64_t u_hi = u >> 32; - uint64_t m_lo = BLS_P.limbs[j] & 0xFFFFFFFFULL; - uint64_t m_hi = BLS_P.limbs[j] >> 32; - uint64_t ll = u_lo * m_lo; - uint64_t lh = u_lo * m_hi; - uint64_t hl = u_hi * m_lo; - uint64_t hh = u_hi * m_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; - if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; -#endif - } - for (int j = 6; i + j <= 12; j++) { - uint64_t sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1ULL : 0ULL; - a[i + j] = sum; - if (carry == 0) break; - } - } - - uint384 r; - r.limbs[0] = a[6]; - r.limbs[1] = a[7]; - r.limbs[2] = a[8]; - r.limbs[3] = a[9]; - r.limbs[4] = a[10]; - r.limbs[5] = a[11]; - - if (a[12] || u384_cmp(r, BLS_P) >= 0) { - uint64_t bw; - r = u384_sub(r, BLS_P, bw); - } - return r; -} - -__device__ static uint384 fp_mul(uint384 a, uint384 b) { - uint64_t t[12] = {}; - - for (int i = 0; i < 6; i++) { - uint64_t carry = 0; - for (int j = 0; j < 6; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a.limbs[i] * b.limbs[j]; - unsigned __int128 acc = prod + carry + t[i + j]; - t[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t a_lo = a.limbs[i] & 0xFFFFFFFFULL; - uint64_t a_hi = a.limbs[i] >> 32; - uint64_t b_lo = b.limbs[j] & 0xFFFFFFFFULL; - uint64_t b_hi = b.limbs[j] >> 32; - uint64_t ll = a_lo * b_lo; - uint64_t lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo; - uint64_t hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; -#endif - } - for (int j = 6; i + j < 12; j++) { - uint64_t sum = t[i + j] + carry; - carry = (sum < t[i + j]) ? 
1ULL : 0ULL; - t[i + j] = sum; - if (carry == 0) break; - } - } - - return mont_reduce_384(t); -} - -__device__ static uint384 fp_sqr(uint384 a) { - return fp_mul(a, a); -} - -__device__ static uint384 fp_add(uint384 a, uint384 b) { - uint64_t carry; - uint384 r = u384_add(a, b, carry); - if (carry || u384_cmp(r, BLS_P) >= 0) { - uint64_t bw; - r = u384_sub(r, BLS_P, bw); - } - return r; -} - -__device__ static uint384 fp_sub(uint384 a, uint384 b) { - uint64_t bw; - uint384 r = u384_sub(a, b, bw); - if (bw) { - uint64_t c; - r = u384_add(r, BLS_P, c); - } - return r; -} - -__device__ static uint384 fp_neg(uint384 a) { - if (u384_is_zero(a)) return a; - uint64_t bw; - return u384_sub(BLS_P, a, bw); -} - -__device__ static uint384 to_mont(uint384 a) { - return fp_mul(a, BLS_R2); -} - -__device__ static uint384 from_mont(uint384 a) { - uint64_t t[12] = {a.limbs[0], a.limbs[1], a.limbs[2], - a.limbs[3], a.limbs[4], a.limbs[5], - 0, 0, 0, 0, 0, 0}; - return mont_reduce_384(t); -} - -__device__ static uint384 fp_inv(uint384 a) { - uint384 exp = BLS_P; - exp.limbs[0] -= 2; - - uint384 result = BLS_R; - uint384 base = a; - - for (int i = 0; i < 6; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - return result; -} - -// ============================================================================= -// G1 point operations (Jacobian coordinates, Montgomery Fp) -// ============================================================================= - -struct G1Point { - uint384 x, y, z; -}; - -__device__ static G1Point g1_identity() { - G1Point p; - p.x = BLS_R; - p.y = BLS_R; - p.z = ZERO384; - return p; -} - -__device__ static bool g1_is_infinity(G1Point p) { - return u384_is_zero(p.z); -} - -__device__ static G1Point g1_double(G1Point p) { - if (g1_is_infinity(p)) return p; - - uint384 A = fp_sqr(p.y); - uint384 B = fp_mul(p.x, A); - uint384 C = fp_sqr(A); - - uint384 S = fp_add(B, 
B); - S = fp_add(S, S); - - uint384 X2 = fp_sqr(p.x); - uint384 M = fp_add(X2, fp_add(X2, X2)); - - uint384 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - - uint384 C8 = fp_add(C, C); - C8 = fp_add(C8, C8); - C8 = fp_add(C8, C8); - uint384 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - - uint384 Z3 = fp_mul(p.y, p.z); - Z3 = fp_add(Z3, Z3); - - G1Point r; - r.x = X3; r.y = Y3; r.z = Z3; - return r; -} - -__device__ static G1Point g1_add_mixed(G1Point P, uint384 Qx, uint384 Qy) { - if (g1_is_infinity(P)) { - G1Point r; - r.x = Qx; r.y = Qy; r.z = BLS_R; - return r; - } - - uint384 Z2 = fp_sqr(P.z); - uint384 U2 = fp_mul(Qx, Z2); - uint384 Z3 = fp_mul(Z2, P.z); - uint384 S2 = fp_mul(Qy, Z3); - - uint384 H = fp_sub(U2, P.x); - uint384 R = fp_sub(S2, P.y); - - if (u384_is_zero(H)) { - if (u384_is_zero(R)) - return g1_double(P); - return g1_identity(); - } - - uint384 H2 = fp_sqr(H); - uint384 H3 = fp_mul(H, H2); - uint384 U1H2 = fp_mul(P.x, H2); - - uint384 X3 = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - uint384 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - uint384 Zr = fp_mul(H, P.z); - - G1Point res; - res.x = X3; res.y = Y3; res.z = Zr; - return res; -} - -__device__ static G1Point g1_mul(uint384 k, uint384 Px, uint384 Py) { - G1Point result = g1_identity(); - - for (int i = 5; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = g1_double(result); - if ((k.limbs[i] >> bit) & 1) { - result = g1_add_mixed(result, Px, Py); - } - } - } - return result; -} - -__device__ static void g1_to_affine(G1Point p, uint384& ax, uint384& ay) { - if (g1_is_infinity(p)) { - ax = ZERO384; ay = ZERO384; - return; - } - uint384 z_inv = fp_inv(p.z); - uint384 z_inv2 = fp_sqr(z_inv); - uint384 z_inv3 = fp_mul(z_inv2, z_inv); - ax = fp_mul(p.x, z_inv2); - ay = fp_mul(p.y, z_inv3); -} - -// ============================================================================= -// BLS signature structures -// 
============================================================================= - -struct BLSSignature { - uint8_t data[48]; -}; - -struct BLSPublicKey { - uint8_t data[96]; -}; - -struct BLSMessage { - uint8_t data[32]; -}; - -// ============================================================================= -// Deserialization -// ============================================================================= - -__device__ static uint384 deserialize_fp(const uint8_t* data) { - uint384 r = {}; - for (int limb = 0; limb < 6; limb++) { - uint64_t val = 0; - for (int byte_idx = 0; byte_idx < 8; byte_idx++) { - int src = (5 - limb) * 8 + (7 - byte_idx); - if (src < 48) - val |= (uint64_t)data[src] << (byte_idx * 8); - } - r.limbs[limb] = val; - } - return r; -} - -__device__ static bool decompress_g1(uint384 x_raw, bool y_positive, - uint384& x_mont, uint384& y_mont) { - x_mont = to_mont(x_raw); - - uint384 x2 = fp_sqr(x_mont); - uint384 x3 = fp_mul(x2, x_mont); - uint384 b_mont = to_mont(uint384{{4, 0, 0, 0, 0, 0}}); - uint384 y2 = fp_add(x3, b_mont); - - // sqrt via a^((p+1)/4) since p = 3 mod 4 - uint384 exp = {{ - 0xEE7FBFFFFFFFEAAFULL, - 0x07AAFFFFAC54FFFFULL, - 0xD9CC34A83DAC3D89ULL, - 0xD91DD2E13CE144AFULL, - 0x92C6E9ED90D2EB35ULL, - 0x0680447A8E5FF9A6ULL - }}; - - uint384 y_cand = BLS_R; - uint384 base = y2; - for (int i = 0; i < 6; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) { - y_cand = fp_mul(y_cand, base); - } - base = fp_sqr(base); - } - } - - uint384 check = fp_sqr(y_cand); - if (u384_cmp(check, y2) != 0) - return false; - - uint384 y_normal = from_mont(y_cand); - bool is_positive = (y_normal.limbs[0] & 1) == 0; - if (is_positive != y_positive) { - y_mont = fp_neg(y_cand); - } else { - y_mont = y_cand; - } - - return true; -} - -// ============================================================================= -// BLS Verification kernel -// ============================================================================= - -extern 
"C" __global__ void bls_verify_batch( - const BLSSignature* __restrict__ sigs, - const BLSPublicKey* __restrict__ pubkeys, - const BLSMessage* __restrict__ messages, - uint32_t* __restrict__ results, - const uint32_t num_sigs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_sigs) return; - - const uint8_t* sig_data = sigs[tid].data; - - uint8_t flags = sig_data[0]; - bool compressed = (flags & 0x80) != 0; - bool infinity = (flags & 0x40) != 0; - bool y_sign = (flags & 0x20) != 0; - - if (infinity) { - results[tid] = 0; - return; - } - - if (!compressed) { - results[tid] = 0; - return; - } - - uint8_t clean_data[48]; - for (int i = 0; i < 48; i++) clean_data[i] = sig_data[i]; - clean_data[0] &= 0x1F; - - uint384 x_raw = {}; - for (int limb = 0; limb < 6; limb++) { - uint64_t val = 0; - for (int b = 0; b < 8; b++) { - int src = (5 - limb) * 8 + (7 - b); - if (src < 48) - val |= (uint64_t)clean_data[src] << (b * 8); - } - x_raw.limbs[limb] = val; - } - - uint384 sig_x, sig_y; - bool on_curve = decompress_g1(x_raw, !y_sign, sig_x, sig_y); - if (!on_curve) { - results[tid] = 0; - return; - } - - results[tid] = 0x3; // on_curve=1, needs_subgroup_check=1 -} - -// ============================================================================= -// BLS G1 Aggregation kernel -// ============================================================================= - -extern "C" __global__ void bls_aggregate_g1( - const BLSSignature* __restrict__ sigs, - uint8_t* __restrict__ agg_out, - uint32_t* __restrict__ counter, - const uint32_t num_sigs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t lid = threadIdx.x; - uint32_t tgid = blockIdx.x; - uint32_t tg_size = blockDim.x; - - // Each thread deserializes and decompresses one signature - G1Point local_sum = g1_identity(); - - if (tid < num_sigs) { - const uint8_t* sig_data = sigs[tid].data; - - uint8_t flags = sig_data[0]; - bool infinity = (flags & 0x40) != 0; - bool y_sign = (flags & 
0x20) != 0; - - if (!infinity) { - uint8_t clean_data[48]; - for (int i = 0; i < 48; i++) clean_data[i] = sig_data[i]; - clean_data[0] &= 0x1F; - - uint384 x_raw = {}; - for (int limb = 0; limb < 6; limb++) { - uint64_t val = 0; - for (int b = 0; b < 8; b++) { - int src = (5 - limb) * 8 + (7 - b); - if (src < 48) - val |= (uint64_t)clean_data[src] << (b * 8); - } - x_raw.limbs[limb] = val; - } - - uint384 sx, sy; - if (decompress_g1(x_raw, !y_sign, sx, sy)) { - local_sum.x = sx; - local_sum.y = sy; - local_sum.z = BLS_R; - } - } - } - - // Threadgroup reduction via shared memory - __shared__ uint384 shared_x[256]; - __shared__ uint384 shared_y[256]; - __shared__ uint384 shared_z[256]; - - shared_x[lid] = local_sum.x; - shared_y[lid] = local_sum.y; - shared_z[lid] = local_sum.z; - __syncthreads(); - - // Binary reduction - for (uint32_t stride = tg_size / 2; stride > 0; stride >>= 1) { - if (lid < stride) { - G1Point a; - a.x = shared_x[lid]; a.y = shared_y[lid]; a.z = shared_z[lid]; - - G1Point b; - b.x = shared_x[lid + stride]; b.y = shared_y[lid + stride]; b.z = shared_z[lid + stride]; - - if (!g1_is_infinity(b)) { - if (g1_is_infinity(a)) { - a = b; - } else { - uint384 Z1sq = fp_sqr(a.z); - uint384 Z2sq = fp_sqr(b.z); - uint384 U1 = fp_mul(a.x, Z2sq); - uint384 U2 = fp_mul(b.x, Z1sq); - uint384 S1 = fp_mul(a.y, fp_mul(Z2sq, b.z)); - uint384 S2 = fp_mul(b.y, fp_mul(Z1sq, a.z)); - - uint384 H = fp_sub(U2, U1); - uint384 R = fp_sub(S2, S1); - - if (u384_is_zero(H)) { - if (u384_is_zero(R)) { - a = g1_double(a); - } else { - a = g1_identity(); - } - } else { - uint384 H2 = fp_sqr(H); - uint384 H3 = fp_mul(H, H2); - uint384 U1H2 = fp_mul(U1, H2); - a.x = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - a.y = fp_sub(fp_mul(R, fp_sub(U1H2, a.x)), fp_mul(S1, H3)); - a.z = fp_mul(fp_mul(H, a.z), b.z); - } - } - } - - shared_x[lid] = a.x; - shared_y[lid] = a.y; - shared_z[lid] = a.z; - } - __syncthreads(); - } - - // Thread 0 of each block writes partial result - if 
(lid == 0) { - G1Point partial; - partial.x = shared_x[0]; partial.y = shared_y[0]; partial.z = shared_z[0]; - - uint384 ax, ay; - g1_to_affine(partial, ax, ay); - - uint384 ax_norm = from_mont(ax); - uint384 ay_norm = from_mont(ay); - - uint32_t tg_offset = tgid * 96; - for (int limb = 0; limb < 6; limb++) { - for (int b = 0; b < 8; b++) { - int dst = (5 - limb) * 8 + (7 - b); - if (dst < 48) { - agg_out[tg_offset + dst] = (uint8_t)((ax_norm.limbs[limb] >> (b * 8)) & 0xFF); - agg_out[tg_offset + 48 + dst] = (uint8_t)((ay_norm.limbs[limb] >> (b * 8)) & 0xFF); - } - } - } - - atomicAdd(counter, 1u); - } -} diff --git a/bls/gpu/cuda/bls_combined_miller.cu b/bls/gpu/cuda/bls_combined_miller.cu deleted file mode 100644 index 313ac77..0000000 --- a/bls/gpu/cuda/bls_combined_miller.cu +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA combined-pair Miller-loop pack — peer of bls_combined_miller.metal. -// -// One driver call dispatches the existing per-bit Miller kernels -// (k_miller_init / k_miller_add_T_and_line / k_miller_dbl_T_and_line / -// k_miller_sqr_ret / k_miller_fold_line / k_miller_finalize) over k -// pairs as N=k threads, then folds the k Fp12 outputs to a single -// product via canonical pairwise tree reduction -// (k_combined_miller_reduce below). -// -// Output is the pre-final-exponentiation Fp12 product -// -// prod_i miller_loop(Q_i, P_i) -// -// byte-equal the CPU reference. Caller applies final_exp() once. - -#include "bls_miller.cuh" - -extern "C" { - -// One round of canonical pairwise tree reduction. -// -// pairs == n / 2, carry == n & 1u, threads == pairs + carry. -// tid < pairs : out[tid] = in[2*tid] * in[2*tid+1] -// tid == pairs (carry): out[tid] = in[2*tid] (last element passes through) -// -// Determinism: the index map is canonical and matches -// tree_reduce_fp12 in cpp/bls_pairing.cpp + the Metal/WGSL peers. 
-__global__ void k_combined_miller_reduce(const Fp12* __restrict__ in_buf, - Fp12* __restrict__ out_buf, - unsigned pairs, - unsigned carry) -{ - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < pairs) { - out_buf[tid] = fp12_mul(in_buf[2u * tid], in_buf[2u * tid + 1u]); - return; - } - if (carry != 0u && tid == pairs) { - out_buf[tid] = in_buf[2u * tid]; - } -} - -} // extern "C" diff --git a/bls/gpu/cuda/bls_combined_miller_driver.cpp b/bls/gpu/cuda/bls_combined_miller_driver.cpp deleted file mode 100644 index 3e44226..0000000 --- a/bls/gpu/cuda/bls_combined_miller_driver.cpp +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for the combined-pair Miller-loop pack. -// Peer of bls/gpu/metal/bls_combined_miller_driver.mm. -// -// Reuses the same per-bit Miller kernels (k_miller_init, k_miller_add_T_and_line, -// k_miller_dbl_T_and_line, k_miller_sqr_ret, k_miller_fold_line, k_miller_finalize) -// in bls_miller.cu and adds the canonical pairwise tree-reduce kernel -// k_combined_miller_reduce from bls_combined_miller.cu. -// -// Build modes: -// - BLS_HAVE_CUDA defined : real CUDA path (CI runner with GPU). -// - BLS_HAVE_CUDA undefined: stub returns -2 ("CUDA unavailable"). 
- -#include "bls_combined_miller_driver.h" - -#include -#include -#include - -#ifdef BLS_HAVE_CUDA -#include - -extern "C" { -__global__ void k_miller_init(const void*, void*, void*, void*, unsigned); -__global__ void k_miller_add_T_and_line(const void*, void*, void*, const void*, unsigned); -__global__ void k_miller_dbl_T_and_line(void*, void*, const void*, unsigned); -__global__ void k_miller_sqr_ret(void*, unsigned); -__global__ void k_miller_fold_line(void*, const void*, unsigned); -__global__ void k_miller_finalize(void*, void*, unsigned); -__global__ void k_combined_miller_reduce(const void*, void*, unsigned, unsigned); -} - -namespace { - -constexpr size_t kP1Aff = 96; -constexpr size_t kP2Aff = 192; -constexpr size_t kInRow = kP2Aff + kP1Aff; -constexpr size_t kP2Bytes = 288; -constexpr size_t kFp2Bytes = 96; -constexpr size_t kFp12Bytes = 576; -constexpr size_t kLineBytes = 3 * kFp2Bytes; - -inline unsigned ceil_div(unsigned n, unsigned d) { return (n + d - 1) / d; } - -int device_present() -{ - int n = 0; - cudaError_t e = cudaGetDeviceCount(&n); - return (e == cudaSuccess && n > 0) ? 1 : 0; -} - -} // namespace - -extern "C" int bls_combined_miller_cuda_available(void) -{ - return device_present(); -} - -extern "C" int bls_combined_miller_cuda(const uint8_t* g1s, - const uint8_t* g2s, - size_t k, - uint8_t fp12_out[576]) -{ - if (fp12_out == nullptr) return -1; - if (k == 0) return -1; - if (g1s == nullptr || g2s == nullptr) return -1; - if (!device_present()) return -2; - - // Pack into MillerIn layout: Q || P per workitem. 
- std::vector in_packed(k * kInRow); - for (size_t i = 0; i < k; ++i) { - std::memcpy(in_packed.data() + i * kInRow, - g2s + i * kP2Aff, kP2Aff); - std::memcpy(in_packed.data() + i * kInRow + kP2Aff, - g1s + i * kP1Aff, kP1Aff); - } - - void *dIn=nullptr, *dT=nullptr, *dRet=nullptr, *dPx2=nullptr; - void *dLine=nullptr, *dMOut=nullptr, *dRed=nullptr; - - auto cleanup = [&]() { - if (dIn) cudaFree(dIn); - if (dT) cudaFree(dT); - if (dRet) cudaFree(dRet); - if (dPx2) cudaFree(dPx2); - if (dLine) cudaFree(dLine); - if (dMOut) cudaFree(dMOut); - if (dRed) cudaFree(dRed); - }; - - auto alloc = [&](void** p, size_t b) { - return cudaMalloc(p, b) == cudaSuccess; - }; - - if (!alloc(&dIn, k * kInRow) || - !alloc(&dT, k * kP2Bytes) || - !alloc(&dRet, k * kFp12Bytes) || - !alloc(&dPx2, k * kFp2Bytes) || - !alloc(&dLine, k * kLineBytes) || - !alloc(&dMOut, k * kFp12Bytes) || - !alloc(&dRed, k * kFp12Bytes)) { - cleanup(); - return -2; - } - - if (cudaMemcpy(dIn, in_packed.data(), k * kInRow, - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); - return -2; - } - - const unsigned tg = 16; - unsigned grid = ceil_div(static_cast(k), tg); - const unsigned phases[5] = { 2u, 3u, 9u, 32u, 16u }; - - // Miller loop on N=k workitems. Each kernel is one bit of the ate - // scalar; per-pair state stays resident in dT/dRet/dLine. - k_miller_init<<>>(dIn, dT, dRet, dPx2, static_cast(k)); - for (int phase = 0; phase < 5; ++phase) { - k_miller_add_T_and_line<<>>(dIn, dT, dLine, dPx2, - static_cast(k)); - k_miller_fold_line<<>>(dRet, dLine, - static_cast(k)); - for (unsigned r = 0; r < phases[phase]; ++r) { - k_miller_sqr_ret<<>>(dRet, static_cast(k)); - k_miller_dbl_T_and_line<<>>(dT, dLine, dPx2, - static_cast(k)); - k_miller_fold_line<<>>(dRet, dLine, - static_cast(k)); - } - } - k_miller_finalize<<>>(dRet, dMOut, static_cast(k)); - - // Canonical Fp12 tree reduction over the k outputs. 
- void* round_in = dMOut; - void* round_out = dRed; - size_t n = k; - while (n > 1) { - unsigned pairs = static_cast(n / 2); - unsigned carry = static_cast(n & 1u); - unsigned threads = pairs + carry; - unsigned r_grid = ceil_div(threads, tg); - k_combined_miller_reduce<<>>(round_in, round_out, - pairs, carry); - void* tmp = round_in; - round_in = round_out; - round_out = tmp; - n = pairs + carry; - } - - cudaError_t syncErr = cudaDeviceSynchronize(); - if (syncErr != cudaSuccess) { - cleanup(); - return -2; - } - - // round_in points at the buffer holding the final Fp12 product at slot 0. - if (cudaMemcpy(fp12_out, round_in, kFp12Bytes, - cudaMemcpyDeviceToHost) != cudaSuccess) { - cleanup(); - return -2; - } - - cleanup(); - return 0; -} - -#else // BLS_HAVE_CUDA undefined — stub mode. - -extern "C" int bls_combined_miller_cuda_available(void) { return 0; } - -extern "C" int bls_combined_miller_cuda(const uint8_t*, const uint8_t*, - size_t, uint8_t[576]) { - return -2; -} - -#endif // BLS_HAVE_CUDA diff --git a/bls/gpu/cuda/bls_combined_miller_driver.h b/bls/gpu/cuda/bls_combined_miller_driver.h deleted file mode 100644 index 57c3627..0000000 --- a/bls/gpu/cuda/bls_combined_miller_driver.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-driver entry for the combined-pair Miller-loop on CUDA. -// Peer of bls/gpu/metal/bls_combined_miller_driver.h. -// -// Output is the pre-final-exponentiation Fp12 product -// -// prod_i miller_loop(Q_i, P_i) for i in 0..k -// -// byte-equal the canonical CPU reference. Caller applies final_exp() once. - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// 0 if a CUDA device is present and the driver loaded successfully; -// 1 if CUDA is not available on this build (stub mode). -int bls_combined_miller_cuda_available(void); - -// Combined-pair Miller-loop on CUDA. 
-// -// g1s : k * 96 bytes (uncompressed G1 affines, blst_p1_affine layout) -// g2s : k * 192 bytes (uncompressed G2 affines, blst_p2_affine layout) -// k : number of pairs (>= 1) -// fp12_out : 576-byte Fp12 product (pre-final-exponentiation) -// -// Returns: -// 0 on success -// -1 on input error -// -2 on CUDA unavailable (stub mode or runtime init failure) -int bls_combined_miller_cuda(const uint8_t* g1s, - const uint8_t* g2s, - size_t k, - uint8_t fp12_out[576]); - -#ifdef __cplusplus -} -#endif diff --git a/bls/gpu/cuda/bls_driver_cuda.cpp b/bls/gpu/cuda/bls_driver_cuda.cpp deleted file mode 100644 index 7751a5e..0000000 --- a/bls/gpu/cuda/bls_driver_cuda.cpp +++ /dev/null @@ -1,278 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for BLS12-381 pairing kernels. -// -// Build modes: -// 1. With CUDA toolkit (BLS_HAVE_CUDA defined): -// - Loads PTX from a .cubin / .fatbin or invokes nvrtc to JIT compile, -// then dispatches the same kernel sequence as the Metal driver. -// - Mirrors run_miller / run_final_exp / run_pairing from -// bls_pairing_test.mm exactly so byte-equality holds. -// -// 2. Without CUDA (BLS_HAVE_CUDA not defined): -// - Provides stub functions that return -1 ("CUDA unavailable on this host"). -// - Test harness skips CUDA path and prints "[CUDA built; skipped on Apple]". -// -// Per-stage acceptance: the CI runner with a real CUDA device runs ctest with -// BLS_HAVE_CUDA=ON, which exercises the full byte-equality path against -// the same vectors_*.h headers Metal uses. On Apple, this file participates -// only in the build (proving headers are syntactically valid C++ for portable -// hosts) and is skipped at runtime. - -#include "bls_driver_cuda.h" - -#include -#include -#include -#include - -#ifdef BLS_HAVE_CUDA -#include - -// Forward declarations of CUDA kernels. Defined in the .cu files compiled by nvcc. 
-extern "C" { -// Fp2/Fp6/Fp12 (called by harness directly for unit tests) -__global__ void k_fp2_add(const void*, const void*, void*, unsigned); -__global__ void k_fp2_sub(const void*, const void*, void*, unsigned); -__global__ void k_fp2_mul(const void*, const void*, void*, unsigned); -__global__ void k_fp2_sqr(const void*, void*, unsigned); -__global__ void k_fp2_inv(const void*, void*, unsigned); -__global__ void k_fp2_conj(const void*, void*, unsigned); -__global__ void k_fp_inv_diag(const void*, void*, unsigned); -__global__ void k_fp6_add(const void*, const void*, void*, unsigned); -__global__ void k_fp6_sub(const void*, const void*, void*, unsigned); -__global__ void k_fp6_mul(const void*, const void*, void*, unsigned); -__global__ void k_fp6_sqr(const void*, void*, unsigned); -__global__ void k_fp6_inv(const void*, void*, unsigned); -__global__ void k_fp12_add(const void*, const void*, void*, unsigned); -__global__ void k_fp12_sub(const void*, const void*, void*, unsigned); -__global__ void k_fp12_mul(const void*, const void*, void*, unsigned); -__global__ void k_fp12_sqr(const void*, void*, unsigned); -__global__ void k_fp12_inv(const void*, void*, unsigned); -__global__ void k_fp12_conj(const void*, void*, unsigned); -__global__ void k_fp12_cyclo_sqr(const void*, void*, unsigned); -// G2 -__global__ void k_p2_jac_add(const void*, const void*, void*, unsigned); -__global__ void k_p2_jac_dbl(const void*, void*, unsigned); -__global__ void k_p2_mixed_add(const void*, const void*, void*, unsigned); -__global__ void k_p2_scalar_mult(const void*, void*, unsigned, unsigned); -// Miller -__global__ void k_miller_init(const void*, void*, void*, void*, unsigned); -__global__ void k_miller_add_T_and_line(const void*, void*, void*, const void*, unsigned); -__global__ void k_miller_dbl_T_and_line(void*, void*, const void*, unsigned); -__global__ void k_miller_sqr_ret(void*, unsigned); -__global__ void k_miller_fold_line(void*, const void*, unsigned); -__global__ void 
k_miller_finalize(void*, void*, unsigned); -// final_exp -__global__ void k_fe_inv(const void*, void*, unsigned); -__global__ void k_fe_cyclo_sqr(void*, unsigned); -__global__ void k_fe_mul(const void*, const void*, void*, unsigned); -__global__ void k_fe_conj(void*, unsigned); -__global__ void k_fe_frobenius(void*, unsigned, unsigned); -__global__ void k_fe_copy(const void*, void*, unsigned); -// pairing -__global__ void k_pair_one_init(void*, unsigned); -__global__ void k_pair_aggregate_step(const void*, void*, unsigned, unsigned); -__global__ void k_pair_eq_one(const void*, unsigned char*, unsigned); -} - -namespace bls_cuda { - -static unsigned compute_grid(unsigned n, unsigned tg) { return (n + tg - 1) / tg; } - -static int device_available() { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 1 : 0; -} - -// Generic dispatcher: launch a 3-buffer (a, b, out) kernel. -template -static int dispatch3(Fn kernel, const void* a, const void* b, void* out, - size_t bytes, unsigned n) { - void *dA = nullptr, *dB = nullptr, *dO = nullptr; - if (cudaMalloc(&dA, bytes) != cudaSuccess) return -1; - if (cudaMalloc(&dB, bytes) != cudaSuccess) { cudaFree(dA); return -1; } - if (cudaMalloc(&dO, bytes) != cudaSuccess) { cudaFree(dA); cudaFree(dB); return -1; } - cudaMemcpy(dA, a, bytes, cudaMemcpyHostToDevice); - cudaMemcpy(dB, b, bytes, cudaMemcpyHostToDevice); - unsigned tg = 32; unsigned grid = compute_grid(n, tg); - kernel<<>>(dA, dB, dO, n); - cudaDeviceSynchronize(); - cudaMemcpy(out, dO, bytes, cudaMemcpyDeviceToHost); - cudaFree(dA); cudaFree(dB); cudaFree(dO); - return 0; -} - -template -static int dispatch2(Fn kernel, const void* a, void* out, size_t bytes, unsigned n) { - void *dA = nullptr, *dO = nullptr; - if (cudaMalloc(&dA, bytes) != cudaSuccess) return -1; - if (cudaMalloc(&dO, bytes) != cudaSuccess) { cudaFree(dA); return -1; } - cudaMemcpy(dA, a, bytes, cudaMemcpyHostToDevice); - unsigned tg = 32; 
unsigned grid = compute_grid(n, tg); - kernel<<>>(dA, dO, n); - cudaDeviceSynchronize(); - cudaMemcpy(out, dO, bytes, cudaMemcpyDeviceToHost); - cudaFree(dA); cudaFree(dO); - return 0; -} - -} // namespace bls_cuda - -extern "C" { - -int bls_cuda_available(void) { return bls_cuda::device_available(); } - -int bls_cuda_fp2_mul(const void* a, const void* b, void* out, unsigned n) { - return bls_cuda::dispatch3(k_fp2_mul, a, b, out, 96 * n, n); -} -int bls_cuda_fp12_mul(const void* a, const void* b, void* out, unsigned n) { - return bls_cuda::dispatch3(k_fp12_mul, a, b, out, 576 * n, n); -} - -// Full pairing entry point. Inputs match Metal layout exactly: -// in: array of N * (P2Aff || P1Aff) = N * (192 + 96) = N * 288 bytes -// out: array of N * Fp12 = N * 576 bytes -// -// CI dispatches the same kernel sequence as run_pairing in -// bls_pairing_test.mm: miller_loop (6 kernels) -> final_exp (5 kernels). -// On byte-equality CI the result is byte-equal Metal and CPU oracle. -int bls_cuda_pairing(const void* in_buf, void* out_buf, unsigned N) { - if (!bls_cuda::device_available()) return -1; - constexpr size_t kP2Bytes = 288; - constexpr size_t kFp2Bytes = 96; - constexpr size_t kFp12Bytes = 576; - constexpr size_t kP2Aff = 192; - constexpr size_t kP1Aff = 96; - constexpr size_t kInRow = kP2Aff + kP1Aff; - constexpr size_t kLineBytes = 3 * kFp2Bytes; - - void *dIn=nullptr,*dT=nullptr,*dRet=nullptr,*dPx2=nullptr,*dLine=nullptr; - void *dMillerOut=nullptr,*dY0=nullptr,*dY1=nullptr,*dY2=nullptr,*dY3=nullptr,*dTmp=nullptr; - - auto alloc = [&](void** p, size_t b) { return cudaMalloc(p, b) == cudaSuccess; }; - - if (!alloc(&dIn, N*kInRow) || - !alloc(&dT, N*kP2Bytes) || - !alloc(&dRet, N*kFp12Bytes) || - !alloc(&dPx2, N*kFp2Bytes) || - !alloc(&dLine, N*kLineBytes) || - !alloc(&dMillerOut, N*kFp12Bytes) || - !alloc(&dY0, N*kFp12Bytes) || - !alloc(&dY1, N*kFp12Bytes) || - !alloc(&dY2, N*kFp12Bytes) || - !alloc(&dY3, N*kFp12Bytes) || - !alloc(&dTmp, N*kFp12Bytes)) { - return 
-1; - } - - cudaMemcpy(dIn, in_buf, N*kInRow, cudaMemcpyHostToDevice); - - unsigned tg = 16; - unsigned grid = bls_cuda::compute_grid(N, tg); - - // Miller-loop phase doubling counts (same as Metal). - const unsigned kPhases[5] = { 2u, 3u, 9u, 32u, 16u }; - - // init - k_miller_init<<>>(dIn, dT, dRet, dPx2, N); - cudaDeviceSynchronize(); - - for (int phase = 0; phase < 5; phase++) { - k_miller_add_T_and_line<<>>(dIn, dT, dLine, dPx2, N); - cudaDeviceSynchronize(); - k_miller_fold_line<<>>(dRet, dLine, N); - cudaDeviceSynchronize(); - for (unsigned k = 0; k < kPhases[phase]; k++) { - k_miller_sqr_ret<<>>(dRet, N); - cudaDeviceSynchronize(); - k_miller_dbl_T_and_line<<>>(dT, dLine, dPx2, N); - cudaDeviceSynchronize(); - k_miller_fold_line<<>>(dRet, dLine, N); - cudaDeviceSynchronize(); - } - } - k_miller_finalize<<>>(dRet, dMillerOut, N); - cudaDeviceSynchronize(); - - // final_exp easy part - k_fe_copy<<>>(dMillerOut, dY1, N); cudaDeviceSynchronize(); - k_fe_conj<<>>(dY1, N); cudaDeviceSynchronize(); - k_fe_inv<<>>(dMillerOut, dY2, N); cudaDeviceSynchronize(); - k_fe_mul<<>>(dY1, dY2, dRet, N); cudaDeviceSynchronize(); - k_fe_copy<<>>(dRet, dY2, N); cudaDeviceSynchronize(); - k_fe_frobenius<<>>(dY2, N, 2u); cudaDeviceSynchronize(); - k_fe_mul<<>>(dRet, dY2, dTmp, N); cudaDeviceSynchronize(); - k_fe_copy<<>>(dTmp, dRet, N); cudaDeviceSynchronize(); - - // hard part: see bls_pairing_test.mm:run_final_exp() — same dispatch order. 
- auto cyclo_sqr = [&](void* b) { k_fe_cyclo_sqr<<>>(b, N); cudaDeviceSynchronize(); }; - auto fe_mul = [&](void* a, void* b, void* c) { - k_fe_mul<<>>(a, b, c, N); cudaDeviceSynchronize(); - }; - auto fe_copy = [&](void* s, void* d) { - k_fe_copy<<>>(s, d, N); cudaDeviceSynchronize(); - }; - auto fe_conj = [&](void* b) { k_fe_conj<<>>(b, N); cudaDeviceSynchronize(); }; - auto fe_frob = [&](void* b, unsigned n_pow) { - k_fe_frobenius<<>>(b, N, n_pow); cudaDeviceSynchronize(); - }; - - auto raise_to_z_div_2 = [&](void* out, void* a, void* tmp) { - fe_copy(a, out); - cyclo_sqr(out); - auto mul_n_sqr = [&](unsigned n) { - fe_mul(out, a, tmp); fe_copy(tmp, out); - for (unsigned i = 0; i < n; i++) cyclo_sqr(out); - }; - mul_n_sqr(2); mul_n_sqr(3); mul_n_sqr(9); - mul_n_sqr(32); mul_n_sqr(15); - fe_conj(out); - }; - auto raise_to_z = [&](void* out, void* a, void* tmp) { - raise_to_z_div_2(out, a, tmp); cyclo_sqr(out); - }; - - fe_copy(dRet, dY0); cyclo_sqr(dY0); - raise_to_z(dY1, dY0, dTmp); - raise_to_z_div_2(dY2, dY1, dTmp); - fe_copy(dRet, dY3); fe_conj(dY3); - fe_mul(dY1, dY3, dTmp); fe_copy(dTmp, dY1); - fe_conj(dY1); - fe_mul(dY1, dY2, dTmp); fe_copy(dTmp, dY1); - raise_to_z(dY2, dY1, dTmp); - raise_to_z(dY3, dY2, dTmp); - fe_conj(dY1); - fe_mul(dY3, dY1, dTmp); fe_copy(dTmp, dY3); - fe_conj(dY1); - fe_frob(dY1, 3u); - fe_frob(dY2, 2u); - fe_mul(dY1, dY2, dTmp); fe_copy(dTmp, dY1); - raise_to_z(dY2, dY3, dTmp); - fe_mul(dY2, dY0, dTmp); fe_copy(dTmp, dY2); - fe_mul(dY2, dRet, dTmp); fe_copy(dTmp, dY2); - fe_mul(dY1, dY2, dTmp); fe_copy(dTmp, dY1); - fe_copy(dY3, dY2); - fe_frob(dY2, 1u); - fe_mul(dY1, dY2, dTmp); - cudaMemcpy(out_buf, dTmp, N*kFp12Bytes, cudaMemcpyDeviceToHost); - - cudaFree(dIn); cudaFree(dT); cudaFree(dRet); cudaFree(dPx2); cudaFree(dLine); - cudaFree(dMillerOut); cudaFree(dY0); cudaFree(dY1); cudaFree(dY2); cudaFree(dY3); cudaFree(dTmp); - return 0; -} - -} // extern "C" - -#else // BLS_HAVE_CUDA not defined: stub mode - -extern "C" { -int 
bls_cuda_available(void) { return 0; } -int bls_cuda_fp2_mul(const void*, const void*, void*, unsigned) { return -1; } -int bls_cuda_fp12_mul(const void*, const void*, void*, unsigned) { return -1; } -int bls_cuda_pairing(const void*, void*, unsigned) { return -1; } -} - -#endif // BLS_HAVE_CUDA diff --git a/bls/gpu/cuda/bls_driver_cuda.h b/bls/gpu/cuda/bls_driver_cuda.h deleted file mode 100644 index 008b978..0000000 --- a/bls/gpu/cuda/bls_driver_cuda.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI interface for the CUDA driver. The function names mirror the -// kernel set used by the Metal driver. On non-CUDA hosts every function -// returns -1 except bls_cuda_available() which returns 0. - -#pragma once -#ifdef __cplusplus -extern "C" { -#endif - -// 1 if a CUDA device is present and the runtime initialised successfully. -int bls_cuda_available(void); - -// Convenience batch ops used by the unit-test harness. -int bls_cuda_fp2_mul(const void* a, const void* b, void* out, unsigned n); -int bls_cuda_fp12_mul(const void* a, const void* b, void* out, unsigned n); - -// Full pairing. in is N * (P2Aff||P1Aff) = N * 288 bytes. out is N * Fp12 = N * 576 bytes. -int bls_cuda_pairing(const void* in_buf, void* out_buf, unsigned N); - -#ifdef __cplusplus -} -#endif diff --git a/bls/gpu/cuda/bls_final_exp.cu b/bls/gpu/cuda/bls_final_exp.cu deleted file mode 100644 index 70d4de5..0000000 --- a/bls/gpu/cuda/bls_final_exp.cu +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA final-exp kernels — mirror bls_final_exp.metal 1:1. 
- -#include "bls_fp12.cuh" - -extern "C" { - -__global__ void k_fe_inv(const Fp12* __restrict__ in_buf, Fp12* __restrict__ out_buf, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out_buf[tid] = fp12_inv(in_buf[tid]); -} - -__global__ void k_fe_cyclo_sqr(Fp12* __restrict__ ret_buf, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - ret_buf[tid] = fp12_cyclotomic_sqr(ret_buf[tid]); -} - -__global__ void k_fe_mul(const Fp12* __restrict__ a, const Fp12* __restrict__ b, - Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_mul(a[tid], b[tid]); -} - -__global__ void k_fe_conj(Fp12* __restrict__ ret_buf, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - ret_buf[tid] = fp12_conj(ret_buf[tid]); -} - -__global__ void k_fe_frobenius(Fp12* __restrict__ ret_buf, unsigned n, unsigned n_pow) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - ret_buf[tid] = fp12_frobenius(ret_buf[tid], n_pow); -} - -__global__ void k_fe_copy(const Fp12* __restrict__ src, Fp12* __restrict__ dst, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - dst[tid] = src[tid]; -} - -} // extern "C" diff --git a/bls/gpu/cuda/bls_fp12.cu b/bls/gpu/cuda/bls_fp12.cu deleted file mode 100644 index 07edae7..0000000 --- a/bls/gpu/cuda/bls_fp12.cu +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "bls_fp12.cuh" - -extern "C" { - -__global__ void k_fp12_add(const Fp12* __restrict__ a, const Fp12* __restrict__ b, - Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_add(a[tid], b[tid]); -} - -__global__ void k_fp12_sub(const Fp12* __restrict__ a, const Fp12* __restrict__ b, - Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_sub(a[tid], b[tid]); -} - -__global__ void k_fp12_mul(const Fp12* __restrict__ a, const Fp12* __restrict__ b, - Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_mul(a[tid], b[tid]); -} - -__global__ void k_fp12_sqr(const Fp12* __restrict__ a, Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_sqr(a[tid]); -} - -__global__ void k_fp12_inv(const Fp12* __restrict__ a, Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_inv(a[tid]); -} - -__global__ void k_fp12_conj(const Fp12* __restrict__ a, Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_conj(a[tid]); -} - -__global__ void k_fp12_cyclo_sqr(const Fp12* __restrict__ a, Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_cyclotomic_sqr(a[tid]); -} - -} // extern "C" diff --git a/bls/gpu/cuda/bls_fp12.cuh b/bls/gpu/cuda/bls_fp12.cuh deleted file mode 100644 index b76277f..0000000 --- a/bls/gpu/cuda/bls_fp12.cuh +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA port of bls_fp12.metal — Fp12 = Fp6[w] / (w^2 - v). -// Layout matches blst_fp12 byte-for-byte. - -#ifndef BLS_FP12_CUH -#define BLS_FP12_CUH - -#include "bls_fp6.cuh" - -struct Fp12 { Fp6 c0, c1; }; - -__device__ __forceinline__ Fp12 fp12_add(Fp12 a, Fp12 b) { - Fp12 r; - r.c0 = fp6_add(a.c0, b.c0); - r.c1 = fp6_add(a.c1, b.c1); - return r; -} - -__device__ __forceinline__ Fp12 fp12_sub(Fp12 a, Fp12 b) { - Fp12 r; - r.c0 = fp6_sub(a.c0, b.c0); - r.c1 = fp6_sub(a.c1, b.c1); - return r; -} - -__device__ __forceinline__ Fp6 fp6_mul_by_v(Fp6 a) { - Fp6 r; - r.c0 = fp2_mul_by_1_plus_u(a.c2); - r.c1 = a.c0; - r.c2 = a.c1; - return r; -} - -__device__ __forceinline__ Fp12 fp12_mul(Fp12 a, Fp12 b) { - Fp6 t0 = fp6_mul(a.c0, b.c0); - Fp6 t1 = fp6_mul(a.c1, b.c1); - - Fp6 sa = fp6_add(a.c0, a.c1); - Fp6 sb = fp6_add(b.c0, b.c1); - Fp6 r1 = fp6_mul(sa, sb); - r1 = fp6_sub(r1, t0); - r1 = fp6_sub(r1, t1); - - Fp6 r0 = fp6_add(t0, fp6_mul_by_v(t1)); - - Fp12 r; r.c0 = r0; r.c1 = r1; return r; -} - -__device__ __forceinline__ Fp12 fp12_sqr(Fp12 a) { - Fp6 t0 = fp6_add(a.c0, a.c1); - Fp6 t1 = fp6_mul_by_v(a.c1); - t1 = fp6_add(a.c0, t1); - t0 = fp6_mul(t0, t1); - - Fp6 t2 = fp6_mul(a.c0, a.c1); - - Fp12 r; - r.c1 = fp6_add(t2, t2); - - Fp6 r0 = fp6_sub(t0, t2); - r0 = fp6_sub(r0, fp6_mul_by_v(t2)); - r.c0 = r0; - return r; -} - -__device__ __forceinline__ Fp12 fp12_conj(Fp12 a) { - Fp12 r; - r.c0 = a.c0; - r.c1 = fp6_neg(a.c1); - return r; -} - -__device__ __forceinline__ Fp12 fp12_inv(Fp12 a) { - Fp6 t0 = fp6_sqr(a.c0); - Fp6 t1 = fp6_sqr(a.c1); - t0 = fp6_sub(t0, fp6_mul_by_v(t1)); - Fp6 ti = fp6_inv(t0); - - Fp12 r; - r.c0 = fp6_mul(a.c0, ti); - r.c1 = fp6_mul(a.c1, ti); - r.c1 = fp6_neg(r.c1); - return r; -} - -__device__ __forceinline__ void sqr_fp4(Fp2& r0, Fp2& r1, Fp2 a0, Fp2 a1) { - Fp2 t0 = fp2_sqr(a0); - Fp2 t1 = fp2_sqr(a1); - Fp2 sum = fp2_add(a0, a1); - - r0 = fp2_add(fp2_mul_by_1_plus_u(t1), t0); - - r1 = 
fp2_sqr(sum); - r1 = fp2_sub(r1, t0); - r1 = fp2_sub(r1, t1); -} - -__device__ __forceinline__ Fp12 fp12_cyclotomic_sqr(Fp12 a) { - Fp2 t00, t01, t10, t11, t20, t21; - sqr_fp4(t00, t01, a.c0.c0, a.c1.c1); - sqr_fp4(t10, t11, a.c1.c0, a.c0.c2); - sqr_fp4(t20, t21, a.c0.c1, a.c1.c2); - - Fp12 r; - Fp2 tmp = fp2_sub(t00, a.c0.c0); - r.c0.c0 = fp2_add(fp2_add(tmp, tmp), t00); - - tmp = fp2_sub(t10, a.c0.c1); - r.c0.c1 = fp2_add(fp2_add(tmp, tmp), t10); - - tmp = fp2_sub(t20, a.c0.c2); - r.c0.c2 = fp2_add(fp2_add(tmp, tmp), t20); - - tmp = fp2_mul_by_1_plus_u(t21); - Fp2 add = fp2_add(tmp, a.c1.c0); - r.c1.c0 = fp2_add(fp2_add(add, add), tmp); - - add = fp2_add(t01, a.c1.c1); - r.c1.c1 = fp2_add(fp2_add(add, add), t01); - - add = fp2_add(t11, a.c1.c2); - r.c1.c2 = fp2_add(fp2_add(add, add), t11); - - return r; -} - -// Frobenius coefficients for Fp12 — verbatim from blst. -__device__ __forceinline__ static uint384 FP12_FROB_RE_N1_dev() { - uint384 r = {{ - 0x07089552B319D465ULL, 0xC6695F92B50A8313ULL, 0x97E83CCCD117228FULL, - 0xA35BAECAB2DC29EEULL, 0x1CE393EA5DAACE4DULL, 0x08F2220FB0FB66EBULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP12_FROB_IM_N1_dev() { - uint384 r = {{ - 0xB2F66AAD4CE5D646ULL, 0x5842A06BFC497CECULL, 0xCF4895D42599D394ULL, - 0xC11B9CBA40A8E8D0ULL, 0x2E3813CBE5A0DE89ULL, 0x110EEFDA88847FAFULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP12_FROB_RE_N2_dev() { - uint384 r = {{ - 0xECFB361B798DBA3AULL, 0xC100DDB891865A2CULL, 0x0EC08FF1232BDA8EULL, - 0xD5C13CC6F1CA4721ULL, 0x47222A47BF7B5C04ULL, 0x0110F184E51C5F59ULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP12_FROB_IM_N2_dev() { - uint384 r = {{0,0,0,0,0,0}}; return r; -} -__device__ __forceinline__ static uint384 FP12_FROB_RE_N3_dev() { - uint384 r = {{ - 0x3E2F585DA55C9AD1ULL, 0x4294213D86C18183ULL, 0x382844C88B623732ULL, - 0x92AD2AFD19103E18ULL, 0x1D794E4FAC7CF0B9ULL, 0x0BD592FC7D825EC8ULL - }}; return r; -} -__device__ __forceinline__ 
static uint384 FP12_FROB_IM_N3_dev() { - uint384 r = {{ - 0x7BCFA7A25AA30FDAULL, 0xDC17DEC12A927E7CULL, 0x2F088DD86B4EBEF1ULL, - 0xD1CA2087DA74D4A7ULL, 0x2DA2596696CEBC1DULL, 0x0E2B7EEDBBFD87D2ULL - }}; return r; -} - -__device__ __forceinline__ Fp12 fp12_frobenius(Fp12 a, unsigned n) { - Fp6 r0 = fp6_frobenius(a.c0, n); - Fp6 r1 = fp6_frobenius(a.c1, n); - - Fp2 coeff; - if (n == 1u) { - coeff.c0 = FP12_FROB_RE_N1_dev(); coeff.c1 = FP12_FROB_IM_N1_dev(); - } else if (n == 2u) { - coeff.c0 = FP12_FROB_RE_N2_dev(); coeff.c1 = FP12_FROB_IM_N2_dev(); - } else { - coeff.c0 = FP12_FROB_RE_N3_dev(); coeff.c1 = FP12_FROB_IM_N3_dev(); - } - r1.c0 = fp2_mul(r1.c0, coeff); - r1.c1 = fp2_mul(r1.c1, coeff); - r1.c2 = fp2_mul(r1.c2, coeff); - - Fp12 r; r.c0 = r0; r.c1 = r1; return r; -} - -__device__ __forceinline__ Fp12 fp12_one() { - Fp2 zerop = fp2_zero(); - Fp2 onep = fp2_one(); - Fp6 one6; one6.c0 = onep; one6.c1 = zerop; one6.c2 = zerop; - Fp6 zero6; zero6.c0 = zerop; zero6.c1 = zerop; zero6.c2 = zerop; - Fp12 r; r.c0 = one6; r.c1 = zero6; return r; -} - -#endif // BLS_FP12_CUH diff --git a/bls/gpu/cuda/bls_fp2.cu b/bls/gpu/cuda/bls_fp2.cu deleted file mode 100644 index 90b7cdf..0000000 --- a/bls/gpu/cuda/bls_fp2.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA kernels for Fp2 — mirror Metal k_fp2_* kernels 1:1 (same buffer layout). 
- -#include "bls_fp2.cuh" - -extern "C" { - -__global__ void k_fp2_add(const Fp2* __restrict__ a, const Fp2* __restrict__ b, - Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp2_add(a[tid], b[tid]); -} - -__global__ void k_fp2_sub(const Fp2* __restrict__ a, const Fp2* __restrict__ b, - Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp2_sub(a[tid], b[tid]); -} - -__global__ void k_fp2_mul(const Fp2* __restrict__ a, const Fp2* __restrict__ b, - Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp2_mul(a[tid], b[tid]); -} - -__global__ void k_fp2_sqr(const Fp2* __restrict__ a, Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp2_sqr(a[tid]); -} - -__global__ void k_fp2_inv(const Fp2* __restrict__ a, Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp2_inv(a[tid]); -} - -__global__ void k_fp2_conj(const Fp2* __restrict__ a, Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp2_conj(a[tid]); -} - -// Diagnostic: raw Fp inversion. Reads first 48 B of Fp2 as Fp, returns inv in c0, zero in c1. -__global__ void k_fp_inv_diag(const Fp2* __restrict__ a, Fp2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - Fp2 r; - r.c0 = fp_inv(a[tid].c0); - r.c1 = ZERO384_dev(); - out[tid] = r; -} - -} // extern "C" diff --git a/bls/gpu/cuda/bls_fp2.cuh b/bls/gpu/cuda/bls_fp2.cuh deleted file mode 100644 index 369d3ea..0000000 --- a/bls/gpu/cuda/bls_fp2.cuh +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA port of bls_fp2.metal — Fp2 = Fp[u]/(u^2 + 1). -// Layout matches blst_fp2 byte-for-byte. - -#ifndef BLS_FP2_CUH -#define BLS_FP2_CUH - -#include "bls_fp_ops.cuh" - -struct Fp2 { uint384 c0, c1; }; - -__device__ __forceinline__ Fp2 fp2_add(Fp2 a, Fp2 b) { - Fp2 r; r.c0 = fp_add(a.c0, b.c0); r.c1 = fp_add(a.c1, b.c1); return r; -} - -__device__ __forceinline__ Fp2 fp2_sub(Fp2 a, Fp2 b) { - Fp2 r; r.c0 = fp_sub(a.c0, b.c0); r.c1 = fp_sub(a.c1, b.c1); return r; -} - -__device__ __forceinline__ Fp2 fp2_neg(Fp2 a) { - Fp2 r; r.c0 = fp_neg(a.c0); r.c1 = fp_neg(a.c1); return r; -} - -__device__ __forceinline__ Fp2 fp2_mul(Fp2 a, Fp2 b) { - uint384 aa = fp_mul(a.c0, b.c0); - uint384 bb = fp_mul(a.c1, b.c1); - uint384 sa = fp_add(a.c0, a.c1); - uint384 sb = fp_add(b.c0, b.c1); - uint384 cross = fp_mul(sa, sb); - Fp2 r; - r.c0 = fp_sub(aa, bb); - r.c1 = fp_sub(fp_sub(cross, aa), bb); - return r; -} - -__device__ __forceinline__ Fp2 fp2_sqr(Fp2 a) { - uint384 ab = fp_mul(a.c0, a.c1); - uint384 sum = fp_add(a.c0, a.c1); - uint384 dif = fp_sub(a.c0, a.c1); - Fp2 r; - r.c0 = fp_mul(sum, dif); - r.c1 = fp_add(ab, ab); - return r; -} - -__device__ __forceinline__ Fp2 fp2_conj(Fp2 a) { - Fp2 r; r.c0 = a.c0; r.c1 = fp_neg(a.c1); return r; -} - -__device__ __forceinline__ Fp2 fp2_inv(Fp2 a) { - uint384 t0 = fp_sqr(a.c0); - uint384 t1 = fp_sqr(a.c1); - uint384 norm = fp_add(t0, t1); - uint384 ni = fp_inv(norm); - Fp2 r; - r.c0 = fp_mul(a.c0, ni); - r.c1 = fp_neg(fp_mul(a.c1, ni)); - return r; -} - -__device__ __forceinline__ Fp2 fp2_frobenius(Fp2 a, unsigned n) { - return ((n & 1u) == 1u) ? 
fp2_conj(a) : a; -} - -__device__ __forceinline__ Fp2 fp2_mul_by_1_plus_u(Fp2 a) { - Fp2 r; - r.c0 = fp_sub(a.c0, a.c1); - r.c1 = fp_add(a.c0, a.c1); - return r; -} - -__device__ __forceinline__ bool fp2_is_zero(Fp2 a) { - return u384_is_zero(a.c0) && u384_is_zero(a.c1); -} - -__device__ __forceinline__ Fp2 fp2_one() { - Fp2 r; r.c0 = BLS_R_dev(); r.c1 = ZERO384_dev(); return r; -} - -__device__ __forceinline__ Fp2 fp2_zero() { - Fp2 r; r.c0 = ZERO384_dev(); r.c1 = ZERO384_dev(); return r; -} - -#endif // BLS_FP2_CUH diff --git a/bls/gpu/cuda/bls_fp6.cu b/bls/gpu/cuda/bls_fp6.cu deleted file mode 100644 index bb0b398..0000000 --- a/bls/gpu/cuda/bls_fp6.cu +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "bls_fp6.cuh" - -extern "C" { - -__global__ void k_fp6_add(const Fp6* __restrict__ a, const Fp6* __restrict__ b, - Fp6* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp6_add(a[tid], b[tid]); -} - -__global__ void k_fp6_sub(const Fp6* __restrict__ a, const Fp6* __restrict__ b, - Fp6* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp6_sub(a[tid], b[tid]); -} - -__global__ void k_fp6_mul(const Fp6* __restrict__ a, const Fp6* __restrict__ b, - Fp6* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp6_mul(a[tid], b[tid]); -} - -__global__ void k_fp6_sqr(const Fp6* __restrict__ a, Fp6* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp6_sqr(a[tid]); -} - -__global__ void k_fp6_inv(const Fp6* __restrict__ a, Fp6* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp6_inv(a[tid]); -} - -} // extern "C" diff 
--git a/bls/gpu/cuda/bls_fp6.cuh b/bls/gpu/cuda/bls_fp6.cuh deleted file mode 100644 index e4f6452..0000000 --- a/bls/gpu/cuda/bls_fp6.cuh +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA port of bls_fp6.metal — Fp6 = Fp2[v] / (v^3 - (u + 1)). -// Layout matches blst_fp6 byte-for-byte. - -#ifndef BLS_FP6_CUH -#define BLS_FP6_CUH - -#include "bls_fp2.cuh" - -struct Fp6 { Fp2 c0, c1, c2; }; - -__device__ __forceinline__ Fp6 fp6_add(Fp6 a, Fp6 b) { - Fp6 r; - r.c0 = fp2_add(a.c0, b.c0); - r.c1 = fp2_add(a.c1, b.c1); - r.c2 = fp2_add(a.c2, b.c2); - return r; -} - -__device__ __forceinline__ Fp6 fp6_sub(Fp6 a, Fp6 b) { - Fp6 r; - r.c0 = fp2_sub(a.c0, b.c0); - r.c1 = fp2_sub(a.c1, b.c1); - r.c2 = fp2_sub(a.c2, b.c2); - return r; -} - -__device__ __forceinline__ Fp6 fp6_neg(Fp6 a) { - Fp6 r; - r.c0 = fp2_neg(a.c0); - r.c1 = fp2_neg(a.c1); - r.c2 = fp2_neg(a.c2); - return r; -} - -__device__ __forceinline__ Fp6 fp6_mul(Fp6 a, Fp6 b) { - Fp2 t0 = fp2_mul(a.c0, b.c0); - Fp2 t1 = fp2_mul(a.c1, b.c1); - Fp2 t2 = fp2_mul(a.c2, b.c2); - - Fp2 sa12 = fp2_add(a.c1, a.c2); - Fp2 sb12 = fp2_add(b.c1, b.c2); - Fp2 r0 = fp2_mul(sa12, sb12); - r0 = fp2_sub(r0, t1); - r0 = fp2_sub(r0, t2); - r0 = fp2_mul_by_1_plus_u(r0); - r0 = fp2_add(r0, t0); - - Fp2 sa01 = fp2_add(a.c0, a.c1); - Fp2 sb01 = fp2_add(b.c0, b.c1); - Fp2 r1 = fp2_mul(sa01, sb01); - r1 = fp2_sub(r1, t0); - r1 = fp2_sub(r1, t1); - r1 = fp2_add(r1, fp2_mul_by_1_plus_u(t2)); - - Fp2 sa02 = fp2_add(a.c0, a.c2); - Fp2 sb02 = fp2_add(b.c0, b.c2); - Fp2 r2 = fp2_mul(sa02, sb02); - r2 = fp2_sub(r2, t0); - r2 = fp2_sub(r2, t2); - r2 = fp2_add(r2, t1); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -__device__ __forceinline__ Fp6 fp6_sqr(Fp6 a) { - Fp2 s0 = fp2_sqr(a.c0); - Fp2 m01 = fp2_mul(a.c0, a.c1); m01 = fp2_add(m01, m01); - Fp2 m12 = fp2_mul(a.c1, a.c2); m12 = fp2_add(m12, m12); - Fp2 s2 = fp2_sqr(a.c2); - - Fp2 sum = 
fp2_add(fp2_add(a.c0, a.c1), a.c2); - Fp2 r2 = fp2_sqr(sum); - r2 = fp2_sub(r2, s0); - r2 = fp2_sub(r2, s2); - r2 = fp2_sub(r2, m01); - r2 = fp2_sub(r2, m12); - - Fp2 r0 = fp2_mul_by_1_plus_u(m12); - r0 = fp2_add(r0, s0); - - Fp2 r1 = fp2_mul_by_1_plus_u(s2); - r1 = fp2_add(r1, m01); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -__device__ __forceinline__ Fp6 fp6_inv(Fp6 a) { - Fp2 c0 = fp2_sqr(a.c0); - Fp2 t = fp2_mul(a.c1, a.c2); - t = fp2_mul_by_1_plus_u(t); - c0 = fp2_sub(c0, t); - - Fp2 c1 = fp2_sqr(a.c2); - c1 = fp2_mul_by_1_plus_u(c1); - Fp2 t01 = fp2_mul(a.c0, a.c1); - c1 = fp2_sub(c1, t01); - - Fp2 c2 = fp2_sqr(a.c1); - Fp2 t02 = fp2_mul(a.c0, a.c2); - c2 = fp2_sub(c2, t02); - - Fp2 t1 = fp2_mul(c1, a.c2); - Fp2 t2 = fp2_mul(c2, a.c1); - Fp2 norm = fp2_add(t1, t2); - norm = fp2_mul_by_1_plus_u(norm); - norm = fp2_add(norm, fp2_mul(c0, a.c0)); - - Fp2 ni = fp2_inv(norm); - - Fp6 r; - r.c0 = fp2_mul(c0, ni); - r.c1 = fp2_mul(c1, ni); - r.c2 = fp2_mul(c2, ni); - return r; -} - -// Frobenius coefficients (Montgomery form), copied verbatim from Metal/blst. 
-__device__ __forceinline__ static uint384 FP6_FROB_C1_RE_N1_dev() { - uint384 r = {{0,0,0,0,0,0}}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C1_IM_N1_dev() { - uint384 r = {{ - 0xCD03C9E48671F071ULL, 0x5DAB22461FCDA5D2ULL, 0x587042AFD3851B95ULL, - 0x8EB60EBE01BACB9EULL, 0x03F97D6E83D050D2ULL, 0x18F0206554638741ULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C1_RE_N2_dev() { - uint384 r = {{ - 0x30F1361B798A64E8ULL, 0xF3B8DDAB7ECE5A2AULL, 0x16A8CA3AC61577F7ULL, - 0xC26A2FF874FD029BULL, 0x3636B76660701C6EULL, 0x051BA4AB241B6160ULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C1_IM_N2_dev() { - uint384 r = {{0,0,0,0,0,0}}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C1_RE_N3_dev() { - uint384 r = {{0,0,0,0,0,0}}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C1_IM_N3_dev() { - uint384 r = {{ - 0x760900000002FFFDULL, 0xEBF4000BC40C0002ULL, 0x5F48985753C758BAULL, - 0x77CE585370525745ULL, 0x5C071A97A256EC6DULL, 0x15F65EC3FA80E493ULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C2_N1_dev() { - uint384 r = {{ - 0x890DC9E4867545C3ULL, 0x2AF322533285A5D5ULL, 0x50880866309B7E2CULL, - 0xA20D1B8C7E881024ULL, 0x14E4F04FE2DB9068ULL, 0x14E56D3F1564853AULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C2_N2_dev() { - uint384 r = {{ - 0xCD03C9E48671F071ULL, 0x5DAB22461FCDA5D2ULL, 0x587042AFD3851B95ULL, - 0x8EB60EBE01BACB9EULL, 0x03F97D6E83D050D2ULL, 0x18F0206554638741ULL - }}; return r; -} -__device__ __forceinline__ static uint384 FP6_FROB_C2_N3_dev() { - uint384 r = {{ - 0x43F5FFFFFFFCAAAEULL, 0x32B7FFF2ED47FFFDULL, 0x07E83A49A2E99D69ULL, - 0xECA8F3318332BB7AULL, 0xEF148D1EA0F4C069ULL, 0x040AB3263EFF0206ULL - }}; return r; -} - -__device__ __forceinline__ Fp6 fp6_frobenius(Fp6 a, unsigned n) { - Fp2 r0 = fp2_frobenius(a.c0, n); - Fp2 r1 = fp2_frobenius(a.c1, n); - Fp2 r2 = fp2_frobenius(a.c2, n); - - Fp2 c1; uint384 
c2_real; - if (n == 1u) { - c1.c0 = FP6_FROB_C1_RE_N1_dev(); c1.c1 = FP6_FROB_C1_IM_N1_dev(); - c2_real = FP6_FROB_C2_N1_dev(); - } else if (n == 2u) { - c1.c0 = FP6_FROB_C1_RE_N2_dev(); c1.c1 = FP6_FROB_C1_IM_N2_dev(); - c2_real = FP6_FROB_C2_N2_dev(); - } else { - c1.c0 = FP6_FROB_C1_RE_N3_dev(); c1.c1 = FP6_FROB_C1_IM_N3_dev(); - c2_real = FP6_FROB_C2_N3_dev(); - } - - r1 = fp2_mul(r1, c1); - r2.c0 = fp_mul(r2.c0, c2_real); - r2.c1 = fp_mul(r2.c1, c2_real); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -#endif // BLS_FP6_CUH diff --git a/bls/gpu/cuda/bls_fp_ops.cuh b/bls/gpu/cuda/bls_fp_ops.cuh deleted file mode 100644 index 4989d8c..0000000 --- a/bls/gpu/cuda/bls_fp_ops.cuh +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA port of bls_fp_ops.h.metal — byte-equal Fp arithmetic for BLS12-381. -// -// All Fp values are stored in Montgomery form, 6 x 64-bit little-endian limbs, -// matching blst's vec384 / blst_fp exactly. Algorithms are 1:1 translations of -// the Metal reference (bls_fp_ops.h.metal) which is itself byte-equal to blst. -// -// Compile-time switch via __CUDA_ARCH__: when not compiled with nvcc, the -// __device__ annotation degrades to nothing so the header is portable for -// stand-in builds (the actual kernels of course only live in nvcc objects). 
- -#ifndef BLS_FP_OPS_CUH -#define BLS_FP_OPS_CUH - -#include - -#ifndef __CUDACC__ -#define __device__ -#define __host__ -#define __forceinline__ inline -#endif - -struct uint384 { uint64_t limbs[6]; }; - -__device__ __forceinline__ static const uint384 BLS_P_dev() { - uint384 r = {{ - 0xB9FEFFFFFFFFAAABULL, 0x1EABFFFEB153FFFFULL, 0x6730D2A0F6B0F624ULL, - 0x64774B84F38512BFULL, 0x4B1BA7B6434BACD7ULL, 0x1A0111EA397FE69AULL - }}; return r; -} -__device__ __forceinline__ static const uint384 BLS_R_dev() { - uint384 r = {{ - 0x760900000002FFFDULL, 0xEBF4000BC40C0002ULL, 0x5F48985753C758BAULL, - 0x77CE585370525745ULL, 0x5C071A97A256EC6DULL, 0x15F65EC3FA80E493ULL - }}; return r; -} -__device__ __forceinline__ static const uint384 BLS_R2_dev() { - uint384 r = {{ - 0xF4DF1F341C341746ULL, 0x0A76E6A609D104F1ULL, 0x8DE5476C4C95B6D5ULL, - 0x67EB88A9939D83C0ULL, 0x9A793E85B519952DULL, 0x11988FE592CAE3AAULL - }}; return r; -} -__device__ __forceinline__ static const uint384 ZERO384_dev() { - uint384 r = {{0,0,0,0,0,0}}; return r; -} -__device__ __forceinline__ static uint64_t BLS_P_INV_dev() { - return 0x89F3FFFCFFFCFFFDULL; -} - -__device__ __forceinline__ int u384_cmp(uint384 a, uint384 b) { - for (int i = 5; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ __forceinline__ bool u384_is_zero(uint384 a) { - return (a.limbs[0]|a.limbs[1]|a.limbs[2]|a.limbs[3]|a.limbs[4]|a.limbs[5]) == 0; -} - -__device__ __forceinline__ uint384 u384_add(uint384 a, uint384 b, uint64_t& carry) { - uint384 r; uint64_t c = 0; - for (int i = 0; i < 6; i++) { - uint64_t s1 = a.limbs[i] + c; - uint64_t c1 = (s1 < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t s2 = s1 + b.limbs[i]; - uint64_t c2 = (s2 < s1) ? 
1ULL : 0ULL; - r.limbs[i] = s2; - c = c1 + c2; - } - carry = c; - return r; -} - -__device__ __forceinline__ uint384 u384_sub(uint384 a, uint384 b, uint64_t& borrow) { - uint384 r; uint64_t bw = 0; - for (int i = 0; i < 6; i++) { - uint64_t d1 = a.limbs[i] - bw; - uint64_t b1 = (d1 > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t d2 = d1 - b.limbs[i]; - uint64_t b2 = (d2 > d1) ? 1ULL : 0ULL; - r.limbs[i] = d2; - bw = b1 + b2; - } - borrow = bw; - return r; -} - -// 64x64 -> 128 (lo, hi). On CUDA we have native __umul64hi. -__device__ __forceinline__ void mul64(uint64_t a, uint64_t b, uint64_t& lo, uint64_t& hi) { -#ifdef __CUDA_ARCH__ - lo = a * b; - hi = __umul64hi(a, b); -#else - uint64_t al = a & 0xFFFFFFFFULL, ah = a >> 32; - uint64_t bl = b & 0xFFFFFFFFULL, bh = b >> 32; - uint64_t ll = al*bl, lh = al*bh, hl = ah*bl, hh = ah*bh; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - hi = hh + (mid2 >> 32); -#endif -} - -// CIOS Montgomery reduction of 768-bit t -> t * R^(-1) mod p. -__device__ __forceinline__ uint384 mont_reduce_384(uint64_t t[12]) { - uint64_t a[13]; - for (int i = 0; i < 12; i++) a[i] = t[i]; - a[12] = 0; - const uint384 P = BLS_P_dev(); - const uint64_t P_INV = BLS_P_INV_dev(); - for (int i = 0; i < 6; i++) { - uint64_t u = a[i] * P_INV; - uint64_t carry = 0; - for (int j = 0; j < 6; j++) { - uint64_t lo, hi; mul64(u, P.limbs[j], lo, hi); - uint64_t s = lo + carry; if (s < lo) hi++; - lo = s; - s = a[i+j] + lo; if (s < a[i+j]) hi++; - a[i+j] = s; - carry = hi; - } - for (int j = 6; i+j <= 12; j++) { - uint64_t s = a[i+j] + carry; - carry = (s < a[i+j]) ? 
1ULL : 0ULL; - a[i+j] = s; - if (carry == 0) break; - } - } - uint384 r; - r.limbs[0]=a[6]; r.limbs[1]=a[7]; r.limbs[2]=a[8]; - r.limbs[3]=a[9]; r.limbs[4]=a[10]; r.limbs[5]=a[11]; - if (a[12] || u384_cmp(r, P) >= 0) { - uint64_t bw; r = u384_sub(r, P, bw); - } - return r; -} - -__device__ __forceinline__ uint384 fp_mul(uint384 a, uint384 b) { - uint64_t t[12] = {}; - for (int i = 0; i < 6; i++) { - uint64_t carry = 0; - for (int j = 0; j < 6; j++) { - uint64_t lo, hi; mul64(a.limbs[i], b.limbs[j], lo, hi); - uint64_t s = lo + carry; if (s < lo) hi++; - lo = s; - s = t[i+j] + lo; if (s < t[i+j]) hi++; - t[i+j] = s; - carry = hi; - } - for (int j = 6; i+j < 12; j++) { - uint64_t s = t[i+j] + carry; - carry = (s < t[i+j]) ? 1ULL : 0ULL; - t[i+j] = s; - if (carry == 0) break; - } - } - return mont_reduce_384(t); -} - -__device__ __forceinline__ uint384 fp_sqr(uint384 a) { return fp_mul(a, a); } - -__device__ __forceinline__ uint384 fp_add(uint384 a, uint384 b) { - uint64_t c; uint384 r = u384_add(a, b, c); - const uint384 P = BLS_P_dev(); - if (c || u384_cmp(r, P) >= 0) { - uint64_t bw; r = u384_sub(r, P, bw); - } - return r; -} - -__device__ __forceinline__ uint384 fp_sub(uint384 a, uint384 b) { - uint64_t bw; uint384 r = u384_sub(a, b, bw); - if (bw) { uint64_t c; r = u384_add(r, BLS_P_dev(), c); } - return r; -} - -__device__ __forceinline__ uint384 fp_neg(uint384 a) { - if (u384_is_zero(a)) return a; - uint64_t bw; return u384_sub(BLS_P_dev(), a, bw); -} - -// Fermat inversion. Same MSB->LSB binary square-and-multiply as Metal. -__device__ __forceinline__ uint384 fp_inv(uint384 a) { - uint384 exp = BLS_P_dev(); - exp.limbs[0] -= 2; - - uint384 result = BLS_R_dev(); - bool started = false; - for (int i = 5; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - if (started) result = fp_sqr(result); - if ((exp.limbs[i] >> bit) & 1) { - result = started ? 
fp_mul(result, a) : a; - started = true; - } - } - } - return result; -} - -#endif // BLS_FP_OPS_CUH diff --git a/bls/gpu/cuda/bls_g2.cu b/bls/gpu/cuda/bls_g2.cu deleted file mode 100644 index 626c1e5..0000000 --- a/bls/gpu/cuda/bls_g2.cu +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "bls_g2.cuh" - -struct P2ScalarIn { - P2 base; - unsigned char scalar[32]; -}; - -extern "C" { - -__global__ void k_p2_jac_add(const P2* __restrict__ a, const P2* __restrict__ b, - P2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = p2_jac_add(a[tid], b[tid]); -} - -__global__ void k_p2_jac_dbl(const P2* __restrict__ a, P2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = p2_jac_dbl(a[tid]); -} - -__global__ void k_p2_mixed_add(const P2* __restrict__ a, const P2Aff* __restrict__ b, - P2* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = p2_mixed_add(a[tid], b[tid]); -} - -__global__ void k_p2_scalar_mult(const P2ScalarIn* __restrict__ in, - P2Aff* __restrict__ out, - unsigned n, unsigned nbits) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = p2_scalar_mult(in[tid].base, in[tid].scalar, nbits); -} - -} // extern "C" diff --git a/bls/gpu/cuda/bls_g2.cuh b/bls/gpu/cuda/bls_g2.cuh deleted file mode 100644 index 45f514a..0000000 --- a/bls/gpu/cuda/bls_g2.cuh +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA port of bls_g2.metal — G2 = E'(Fp2). Layouts byte-equal blst. 
- -#ifndef BLS_G2_CUH -#define BLS_G2_CUH - -#include "bls_fp2.cuh" - -struct P2 { Fp2 X, Y, Z; }; -struct P2Aff { Fp2 X, Y; }; - -__device__ __forceinline__ P2 p2_jac_add(P2 p1, P2 p2) { - bool p1inf = fp2_is_zero(p1.Z); - Fp2 Z1Z1 = fp2_sqr(p1.Z); - - Fp2 pZ = fp2_mul(Z1Z1, p1.Z); - pZ = fp2_mul(pZ, p2.Y); - - bool p2inf = fp2_is_zero(p2.Z); - Fp2 Z2Z2 = fp2_sqr(p2.Z); - - Fp2 S1 = fp2_mul(Z2Z2, p2.Z); - S1 = fp2_mul(S1, p1.Y); - - pZ = fp2_sub(pZ, S1); - pZ = fp2_add(pZ, pZ); - - Fp2 U1 = fp2_mul(p1.X, Z2Z2); - Fp2 H = fp2_mul(p2.X, Z1Z1); - H = fp2_sub(H, U1); - - Fp2 I = fp2_add(H, H); - I = fp2_sqr(I); - - Fp2 J = fp2_mul(H, I); - S1 = fp2_mul(S1, J); - - Fp2 V = fp2_mul(U1, I); - - Fp2 pX = fp2_sqr(pZ); - pX = fp2_sub(pX, J); - pX = fp2_sub(pX, V); - pX = fp2_sub(pX, V); - - Fp2 pY = fp2_sub(V, pX); - pY = fp2_mul(pY, pZ); - pY = fp2_sub(pY, S1); - pY = fp2_sub(pY, S1); - - pZ = fp2_add(p1.Z, p2.Z); - pZ = fp2_sqr(pZ); - pZ = fp2_sub(pZ, Z1Z1); - pZ = fp2_sub(pZ, Z2Z2); - pZ = fp2_mul(pZ, H); - - P2 p3; p3.X = pX; p3.Y = pY; p3.Z = pZ; - - if (p2inf) p3 = p1; - if (p1inf) p3 = p2; - return p3; -} - -__device__ __forceinline__ P2 p2_jac_dbl(P2 p1) { - Fp2 A = fp2_sqr(p1.X); - Fp2 B = fp2_sqr(p1.Y); - Fp2 C = fp2_sqr(B); - - B = fp2_add(B, p1.X); - B = fp2_sqr(B); - B = fp2_sub(B, A); - B = fp2_sub(B, C); - B = fp2_add(B, B); - - Fp2 A3 = fp2_add(A, A); - A3 = fp2_add(A3, A); - - Fp2 pX = fp2_sqr(A3); - pX = fp2_sub(pX, B); - pX = fp2_sub(pX, B); - - Fp2 pZ = fp2_add(p1.Z, p1.Z); - pZ = fp2_mul(pZ, p1.Y); - - Fp2 C8 = fp2_add(C, C); - C8 = fp2_add(C8, C8); - C8 = fp2_add(C8, C8); - - Fp2 pY = fp2_sub(B, pX); - pY = fp2_mul(pY, A3); - pY = fp2_sub(pY, C8); - - P2 p3; p3.X = pX; p3.Y = pY; p3.Z = pZ; - return p3; -} - -__device__ __forceinline__ P2 p2_mixed_add(P2 p1, P2Aff p2) { - bool p1inf = fp2_is_zero(p1.Z); - Fp2 Z1Z1 = fp2_sqr(p1.Z); - - Fp2 pZ = fp2_mul(Z1Z1, p1.Z); - pZ = fp2_mul(pZ, p2.Y); - - bool p2inf = fp2_is_zero(p2.X) && fp2_is_zero(p2.Y); - - 
Fp2 H = fp2_mul(p2.X, Z1Z1); - H = fp2_sub(H, p1.X); - - Fp2 HH = fp2_sqr(H); - Fp2 I = fp2_add(HH, HH); - I = fp2_add(I, I); - - Fp2 pY_v = fp2_mul(p1.X, I); - Fp2 J = fp2_mul(H, I); - Fp2 Iy = fp2_mul(J, p1.Y); - - pZ = fp2_sub(pZ, p1.Y); - pZ = fp2_add(pZ, pZ); - - Fp2 pX = fp2_sqr(pZ); - pX = fp2_sub(pX, J); - pX = fp2_sub(pX, pY_v); - pX = fp2_sub(pX, pY_v); - - Fp2 pY = fp2_sub(pY_v, pX); - pY = fp2_mul(pY, pZ); - pY = fp2_sub(pY, Iy); - pY = fp2_sub(pY, Iy); - - pZ = fp2_add(p1.Z, H); - pZ = fp2_sqr(pZ); - pZ = fp2_sub(pZ, Z1Z1); - pZ = fp2_sub(pZ, HH); - - P2 p3; p3.X = pX; p3.Y = pY; p3.Z = pZ; - - if (p1inf) { - p3.X = p2.X; - p3.Y = p2.Y; - p3.Z = fp2_one(); - } - if (p2inf) p3 = p1; - return p3; -} - -__device__ __forceinline__ P2Aff p2_to_affine(P2 p) { - P2Aff a; - if (fp2_is_zero(p.Z)) { - a.X = fp2_zero(); - a.Y = fp2_zero(); - return a; - } - Fp2 Zi = fp2_inv(p.Z); - Fp2 Zi2 = fp2_sqr(Zi); - Fp2 Zi3 = fp2_mul(Zi2, Zi); - a.X = fp2_mul(p.X, Zi2); - a.Y = fp2_mul(p.Y, Zi3); - return a; -} - -__device__ __forceinline__ P2Aff p2_scalar_mult(P2 base, const unsigned char* scalar, unsigned nbits) { - P2 R; R.X = fp2_zero(); R.Y = fp2_zero(); R.Z = fp2_zero(); - - bool started = false; - for (int i = (int)nbits - 1; i >= 0; i--) { - unsigned char byte = scalar[i >> 3]; - unsigned bit = (unsigned)((byte >> (i & 7)) & 1u); - if (started) R = p2_jac_dbl(R); - if (bit) { - if (started) { - R = p2_jac_add(R, base); - } else { - R = base; - started = true; - } - } - } - return p2_to_affine(R); -} - -#endif // BLS_G2_CUH diff --git a/bls/gpu/cuda/bls_miller.cu b/bls/gpu/cuda/bls_miller.cu deleted file mode 100644 index 5e5b6d4..0000000 --- a/bls/gpu/cuda/bls_miller.cu +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA Miller-loop kernels — same kernel split as Metal (host orchestrates). 
- -#include "bls_miller.cuh" - -extern "C" { - -__global__ void k_miller_init(const MillerIn* __restrict__ in, - P2* __restrict__ T_buf, - Fp12* __restrict__ ret_buf, - Fp2* __restrict__ px2_buf, - unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - - P2Aff Q = in[tid].Q; - P1Aff P = in[tid].P; - - uint384 two_px = fp_add(P.X, P.X); - Fp2 Px2; - Px2.c0 = fp_neg(two_px); - Px2.c1 = fp_add(P.Y, P.Y); - px2_buf[tid] = Px2; - - P2 T; - T.X = Q.X; T.Y = Q.Y; T.Z = fp2_one(); - - Line L0 = line_dbl_dev(T, T); - L0 = line_by_Px2_dev(L0, Px2.c0, Px2.c1); - Fp12 ret = unpack_initial_line_dev(L0); - - T_buf[tid] = T; - ret_buf[tid] = ret; -} - -__global__ void k_miller_add_T_and_line(const MillerIn* __restrict__ in, - P2* __restrict__ T_buf, - LineBuf* __restrict__ line_buf, - const Fp2* __restrict__ px2_buf, - unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - P2 T = T_buf[tid]; - Fp2 Px2 = px2_buf[tid]; - - Line L = line_add_dev(T, T, in[tid].Q); - L = line_by_Px2_dev(L, Px2.c0, Px2.c1); - - T_buf[tid] = T; - LineBuf lb; lb.x = L.x; lb.y = L.y; lb.z = L.z; - line_buf[tid] = lb; -} - -__global__ void k_miller_dbl_T_and_line(P2* __restrict__ T_buf, - LineBuf* __restrict__ line_buf, - const Fp2* __restrict__ px2_buf, - unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - P2 T = T_buf[tid]; - Fp2 Px2 = px2_buf[tid]; - - Line Ld = line_dbl_dev(T, T); - Ld = line_by_Px2_dev(Ld, Px2.c0, Px2.c1); - - T_buf[tid] = T; - LineBuf lb; lb.x = Ld.x; lb.y = Ld.y; lb.z = Ld.z; - line_buf[tid] = lb; -} - -__global__ void k_miller_sqr_ret(Fp12* __restrict__ ret_buf, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - ret_buf[tid] = fp12_sqr(ret_buf[tid]); -} - -__global__ void k_miller_fold_line(Fp12* __restrict__ ret_buf, - const LineBuf* __restrict__ line_buf, - unsigned n) { - unsigned tid = blockIdx.x * blockDim.x 
+ threadIdx.x; - if (tid >= n) return; - LineBuf lb = line_buf[tid]; - Line L; L.x = lb.x; L.y = lb.y; L.z = lb.z; - ret_buf[tid] = fp12_mul_by_xy00z0_dev(ret_buf[tid], L); -} - -__global__ void k_miller_finalize(Fp12* __restrict__ ret_buf, - Fp12* __restrict__ out, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - out[tid] = fp12_conj(ret_buf[tid]); -} - -} // extern "C" diff --git a/bls/gpu/cuda/bls_miller.cuh b/bls/gpu/cuda/bls_miller.cuh deleted file mode 100644 index 96e9f97..0000000 --- a/bls/gpu/cuda/bls_miller.cuh +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA port of bls_miller.metal — BLS12-381 optimal-ate Miller loop. -// Same kernel split as Metal (host orchestrates dispatches). - -#ifndef BLS_MILLER_CUH -#define BLS_MILLER_CUH - -#include "bls_fp12.cuh" -#include "bls_g2.cuh" - -struct P1Aff { uint384 X, Y; }; -struct Line { Fp2 x, y, z; }; -struct LineBuf { Fp2 x, y, z; }; - -struct MillerIn { - P2Aff Q; - P1Aff P; -}; - -__device__ static Line line_dbl_dev(P2& T, P2 Q) { - Fp2 A = fp2_sqr(Q.X); - Fp2 B = fp2_sqr(Q.Y); - Fp2 ZZ = fp2_sqr(Q.Z); - Fp2 C = fp2_sqr(B); - - Fp2 D = fp2_add(Q.X, B); - D = fp2_sqr(D); - D = fp2_sub(D, A); - D = fp2_sub(D, C); - D = fp2_add(D, D); - - Fp2 E = fp2_add(A, A); - E = fp2_add(E, A); - - Fp2 F = fp2_sqr(E); - - Fp2 line0 = fp2_add(E, Q.X); - - Fp2 Tx = fp2_sub(F, D); - Tx = fp2_sub(Tx, D); - - Fp2 Tz = fp2_add(Q.Y, Q.Z); - Tz = fp2_sqr(Tz); - Tz = fp2_sub(Tz, B); - Tz = fp2_sub(Tz, ZZ); - - Fp2 C8 = fp2_add(C, C); - C8 = fp2_add(C8, C8); - C8 = fp2_add(C8, C8); - - Fp2 Ty = fp2_sub(D, Tx); - Ty = fp2_mul(Ty, E); - Ty = fp2_sub(Ty, C8); - - line0 = fp2_sqr(line0); - line0 = fp2_sub(line0, A); - line0 = fp2_sub(line0, F); - Fp2 B4 = fp2_add(B, B); - B4 = fp2_add(B4, B4); - line0 = fp2_sub(line0, B4); - - Fp2 line1 = fp2_mul(E, ZZ); - Fp2 line2 = fp2_mul(Tz, ZZ); - - T.X = Tx; T.Y = 
Ty; T.Z = Tz; - - Line L; L.x = line0; L.y = line1; L.z = line2; - return L; -} - -__device__ static Line line_add_dev(P2& T, P2 R, P2Aff Q) { - Fp2 Z1Z1 = fp2_sqr(R.Z); - Fp2 U2 = fp2_mul(Q.X, Z1Z1); - - Fp2 S2 = fp2_mul(Q.Y, R.Z); - S2 = fp2_mul(S2, Z1Z1); - - Fp2 H = fp2_sub(U2, R.X); - - Fp2 HH = fp2_sqr(H); - Fp2 I = fp2_add(HH, HH); - I = fp2_add(I, I); - - Fp2 J = fp2_mul(H, I); - - Fp2 r = fp2_sub(S2, R.Y); - r = fp2_add(r, r); - - Fp2 V = fp2_mul(R.X, I); - - Fp2 Tx = fp2_sqr(r); - Tx = fp2_sub(Tx, J); - Tx = fp2_sub(Tx, V); - Tx = fp2_sub(Tx, V); - - Fp2 Jy = fp2_mul(J, R.Y); - Fp2 Ty = fp2_sub(V, Tx); - Ty = fp2_mul(Ty, r); - Ty = fp2_sub(Ty, Jy); - Ty = fp2_sub(Ty, Jy); - - Fp2 Tz = fp2_add(R.Z, H); - Tz = fp2_sqr(Tz); - Tz = fp2_sub(Tz, Z1Z1); - Tz = fp2_sub(Tz, HH); - - Fp2 lineI = fp2_mul(r, Q.X); - Fp2 lineJ = fp2_mul(Q.Y, Tz); - lineI = fp2_sub(lineI, lineJ); - Fp2 line0 = fp2_add(lineI, lineI); - - T.X = Tx; T.Y = Ty; T.Z = Tz; - - Line L; L.x = line0; L.y = r; L.z = Tz; - return L; -} - -__device__ __forceinline__ Line line_by_Px2_dev(Line L, uint384 px_neg2, uint384 py_2) { - L.y.c0 = fp_mul(L.y.c0, px_neg2); - L.y.c1 = fp_mul(L.y.c1, px_neg2); - L.z.c0 = fp_mul(L.z.c0, py_2); - L.z.c1 = fp_mul(L.z.c1, py_2); - return L; -} - -__device__ __forceinline__ Fp6 fp6_mul_by_xy0_dev(Fp6 a, Fp2 b0, Fp2 b1) { - Fp2 t0 = fp2_mul(a.c0, b0); - Fp2 t1 = fp2_mul(a.c1, b1); - - Fp2 t3 = fp2_mul(a.c2, b1); - t3 = fp2_mul_by_1_plus_u(t3); - - Fp2 t4 = fp2_add(a.c0, a.c1); - Fp2 t5 = fp2_add(b0, b1); - Fp2 r1 = fp2_mul(t4, t5); - r1 = fp2_sub(r1, t0); - r1 = fp2_sub(r1, t1); - - Fp2 r2 = fp2_mul(a.c2, b0); - r2 = fp2_add(r2, t1); - - Fp2 r0 = fp2_add(t3, t0); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -__device__ __forceinline__ Fp6 fp6_mul_by_0y0_dev(Fp6 a, Fp2 b) { - Fp2 t = fp2_mul(a.c2, b); - Fp6 r; - r.c2 = fp2_mul(a.c1, b); - r.c1 = fp2_mul(a.c0, b); - r.c0 = fp2_mul_by_1_plus_u(t); - return r; -} - -__device__ static Fp12 
fp12_mul_by_xy00z0_dev(Fp12 a, Line L) { - Fp6 t0 = fp6_mul_by_xy0_dev(a.c0, L.x, L.y); - Fp6 t1 = fp6_mul_by_0y0_dev(a.c1, L.z); - - Fp2 b1_alt = fp2_add(L.y, L.z); - Fp6 sum = fp6_add(a.c0, a.c1); - Fp6 r1 = fp6_mul_by_xy0_dev(sum, L.x, b1_alt); - r1 = fp6_sub(r1, t0); - r1 = fp6_sub(r1, t1); - - Fp6 t1v; - t1v.c0 = fp2_mul_by_1_plus_u(t1.c2); - t1v.c1 = t1.c0; - t1v.c2 = t1.c1; - Fp6 r0 = fp6_add(t0, t1v); - - Fp12 r; r.c0 = r0; r.c1 = r1; return r; -} - -__device__ __forceinline__ Fp12 unpack_initial_line_dev(Line L) { - Fp12 ret; - ret.c0.c0 = L.x; - ret.c0.c1 = L.y; - ret.c0.c2 = fp2_zero(); - ret.c1.c0 = fp2_zero(); - ret.c1.c1 = L.z; - ret.c1.c2 = fp2_zero(); - return ret; -} - -#endif // BLS_MILLER_CUH diff --git a/bls/gpu/cuda/bls_pairing.cu b/bls/gpu/cuda/bls_pairing.cu deleted file mode 100644 index 7101793..0000000 --- a/bls/gpu/cuda/bls_pairing.cu +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA pairing helpers — mirror bls_pairing.metal. 
- -#include "bls_fp12.cuh" - -extern "C" { - -__global__ void k_pair_one_init(Fp12* __restrict__ acc, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - acc[tid] = fp12_one(); -} - -__global__ void k_pair_aggregate_step(const Fp12* __restrict__ src, Fp12* __restrict__ acc, - unsigned step_idx, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - acc[tid] = fp12_mul(acc[tid], src[step_idx]); -} - -__global__ void k_pair_eq_one(const Fp12* __restrict__ ret, - unsigned char* __restrict__ flag, unsigned n) { - unsigned tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n) return; - Fp12 one = fp12_one(); - Fp12 r = ret[tid]; - - bool eq = true; - Fp2 a[6] = { r.c0.c0, r.c0.c1, r.c0.c2, r.c1.c0, r.c1.c1, r.c1.c2 }; - Fp2 b[6] = { one.c0.c0, one.c0.c1, one.c0.c2, one.c1.c0, one.c1.c1, one.c1.c2 }; - for (unsigned i = 0; i < 6; i++) { - for (unsigned j = 0; j < 6; j++) { - if (a[i].c0.limbs[j] != b[i].c0.limbs[j]) { eq = false; } - if (a[i].c1.limbs[j] != b[i].c1.limbs[j]) { eq = false; } - } - } - flag[tid] = eq ? 1u : 0u; -} - -} // extern "C" diff --git a/bls/gpu/metal/bls.metal b/bls/gpu/metal/bls.metal deleted file mode 100644 index 79c610a..0000000 --- a/bls/gpu/metal/bls.metal +++ /dev/null @@ -1,742 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file bls12_381.metal -/// Metal compute shader for BLS12-381 signature verification. -/// -/// Implements 384-bit field arithmetic (Fp) in Montgomery form for -/// batch BLS signature verification in Quasar consensus. 
-/// -/// BLS12-381 curve parameters: -/// p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab -/// r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 -/// G1: y^2 = x^3 + 4 over Fp -/// G2: y^2 = x^3 + 4(1+i) over Fp2 = Fp[u]/(u^2+1) -/// -/// Operations: -/// bls_verify_batch: verify N BLS signatures in parallel (one per thread) -/// bls_aggregate_g1: aggregate N G1 points (signature aggregation) -/// -/// Each BLS signature verification requires: -/// 1. Deserialize signature (G1 point, 48 bytes compressed) -/// 2. Deserialize public key (G2 point, 96 bytes compressed) -/// 3. Hash message to G1 point (hash-to-curve, simplified) -/// 4. Pairing check: e(sig, G2_gen) == e(H(msg), pubkey) -/// -/// For the initial implementation, we focus on G1 point operations -/// and batch verification of the signature equation WITHOUT full -/// pairing (which requires Fp12 and Miller loop). Instead, we verify -/// the aggregated form: e(sum(sigs), G2_gen) == e(sum(H(msgs_i) * alpha_i), pubkey) -/// reducing to a single pairing check on the host, with the GPU doing -/// the heavy G1 scalar multiplications and aggregation. 
- -#include -using namespace metal; - -// ============================================================================= -// 384-bit unsigned integer (6 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint384 { - ulong limbs[6]; // limbs[0] = least significant -}; - -// ============================================================================= -// BLS12-381 constants -// ============================================================================= - -// Field modulus p (384 bits) -constant uint384 BLS_P = {{ - 0xB9FEFFFFFFFFAAABUL, - 0x1EABFFFEB153FFFFUL, - 0x6730D2A0F6B0F624UL, - 0x64774B84F38512BFUL, - 0x4B1BA7B6434BACD7UL, - 0x1A0111EA397FE69AUL -}}; - -// Montgomery R^2 mod p (for encoding to Montgomery form) -constant uint384 BLS_R2 = {{ - 0xF4DF1F341C341746UL, - 0x0A76E6A609D104F1UL, - 0x8DE5476C4C95B6D5UL, - 0x67EB88A9939D83C0UL, - 0x9A793E85B519952DUL, - 0x11988FE592CAE3AAUL -}}; - -// Montgomery R mod p -constant uint384 BLS_R = {{ - 0x760900000002FFCDUL, - 0xEBF4000BC40C0002UL, - 0x5F48985753C758BAUL, - 0x77CE585370525745UL, - 0x5C071A97A256EC6DUL, - 0x15F65EC3FA80E493UL -}}; - -// -p^(-1) mod 2^64 (Montgomery constant) -constant ulong BLS_P_INV = 0x89F3FFFCFFFCFFFDUL; - -// Generator G1 (affine, Montgomery form) -// G1_x = 0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb -// G1_y = 0x08b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1 -constant uint384 G1_X = {{ - 0x5CB38790FD666E19UL, - 0xF85DDE8F09FE5D5CUL, - 0x2C0B0A5CAFB74CD8UL, - 0x95F7B3B14AAE717DUL, - 0x70E02F1AB69D14E3UL, - 0x03C26A6D58B32048UL -}}; -constant uint384 G1_Y = {{ - 0xA402B931448DC5C8UL, - 0xFBD6AA1ADEAD1CF6UL, - 0x5B9D93D1BA1F5B57UL, - 0x6DC08AFF5B3AF6DDUL, - 0xA4CF5B5C1B6CE90CUL, - 0x13F48FFF25F51018UL -}}; - -// Zero -constant uint384 ZERO384 = {{0, 0, 0, 0, 0, 0}}; - -// 
============================================================================= -// 384-bit arithmetic -// ============================================================================= - -inline int u384_cmp(uint384 a, uint384 b) { - for (int i = 5; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u384_is_zero(uint384 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | - a.limbs[3] | a.limbs[4] | a.limbs[5]) == 0; -} - -inline uint384 u384_add(uint384 a, uint384 b, thread ulong& carry) { - uint384 r; - ulong c = 0; - for (int i = 0; i < 6; i++) { - ulong sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -inline uint384 u384_sub(uint384 a, uint384 b, thread ulong& borrow) { - uint384 r; - ulong bw = 0; - for (int i = 0; i < 6; i++) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// ============================================================================= -// 64x64 -> 128 bit multiplication (no native 128-bit on Metal) -// ============================================================================= - -/// Multiply two 64-bit values, return (lo, hi). 
-inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong a_lo = a & 0xFFFFFFFFUL; - ulong a_hi = a >> 32; - ulong b_lo = b & 0xFFFFFFFFUL; - ulong b_hi = b >> 32; - - ulong ll = a_lo * b_lo; - ulong lh = a_lo * b_hi; - ulong hl = a_hi * b_lo; - ulong hh = a_hi * b_hi; - - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -// ============================================================================= -// Montgomery arithmetic over Fp (384-bit) -// ============================================================================= - -/// Montgomery reduction of a 768-bit value t[12] mod p. -/// Returns t * R^(-1) mod p. -inline uint384 mont_reduce_384(ulong t[12]) { - ulong a[13]; - for (int i = 0; i < 12; i++) a[i] = t[i]; - a[12] = 0; - - for (int i = 0; i < 6; i++) { - ulong u = a[i] * BLS_P_INV; - - ulong carry = 0; - for (int j = 0; j < 6; j++) { - ulong lo, hi; - mul64(u, BLS_P.limbs[j], lo, hi); - - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - - sum = a[i + j] + lo; - if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - for (int j = 6; i + j <= 12; j++) { - ulong sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1UL : 0UL; - a[i + j] = sum; - if (carry == 0) break; - } - } - - uint384 r; - r.limbs[0] = a[6]; - r.limbs[1] = a[7]; - r.limbs[2] = a[8]; - r.limbs[3] = a[9]; - r.limbs[4] = a[10]; - r.limbs[5] = a[11]; - - if (a[12] || u384_cmp(r, BLS_P) >= 0) { - ulong bw; - r = u384_sub(r, BLS_P, bw); - } - return r; -} - -/// Montgomery multiplication: a * b * R^(-1) mod p -inline uint384 fp_mul(uint384 a, uint384 b) { - ulong t[12] = {}; - - for (int i = 0; i < 6; i++) { - ulong carry = 0; - for (int j = 0; j < 6; j++) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - for (int j = 6; i + j < 12; j++) { - ulong sum = t[i + j] + carry; - carry = (sum < t[i + j]) ? 1UL : 0UL; - t[i + j] = sum; - if (carry == 0) break; - } - } - - return mont_reduce_384(t); -} - -inline uint384 fp_sqr(uint384 a) { - return fp_mul(a, a); -} - -inline uint384 fp_add(uint384 a, uint384 b) { - ulong carry; - uint384 r = u384_add(a, b, carry); - if (carry || u384_cmp(r, BLS_P) >= 0) { - ulong bw; - r = u384_sub(r, BLS_P, bw); - } - return r; -} - -inline uint384 fp_sub(uint384 a, uint384 b) { - ulong bw; - uint384 r = u384_sub(a, b, bw); - if (bw) { - ulong c; - r = u384_add(r, BLS_P, c); - } - return r; -} - -inline uint384 fp_neg(uint384 a) { - if (u384_is_zero(a)) return a; - ulong bw; - return u384_sub(BLS_P, a, bw); -} - -/// Convert to Montgomery form: a * R mod p -inline uint384 to_mont(uint384 a) { - return fp_mul(a, BLS_R2); -} - -/// Convert from Montgomery form: aR * R^(-1) = a -inline uint384 from_mont(uint384 a) { - ulong t[12] = {a.limbs[0], a.limbs[1], a.limbs[2], - a.limbs[3], a.limbs[4], a.limbs[5], - 0, 0, 0, 0, 0, 0}; - return mont_reduce_384(t); -} - -/// Fermat inversion: a^(p-2) mod p -inline uint384 fp_inv(uint384 a) { - // p-2 as 6 limbs (BLS_P - 2) - uint384 exp = BLS_P; - exp.limbs[0] -= 2; - - uint384 result = 
BLS_R; // 1 in Montgomery form - uint384 base = a; - - for (int i = 0; i < 6; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - return result; -} - -// ============================================================================= -// G1 point operations (Jacobian coordinates, Montgomery Fp) -// ============================================================================= - -struct G1Point { - uint384 x, y, z; -}; - -inline G1Point g1_identity() { - G1Point p; - p.x = BLS_R; // 1 in Montgomery - p.y = BLS_R; - p.z = ZERO384; // Z=0 is identity - return p; -} - -inline bool g1_is_infinity(G1Point p) { - return u384_is_zero(p.z); -} - -/// G1 point doubling (a=0 for BLS12-381 G1: y^2 = x^3 + 4) -inline G1Point g1_double(G1Point p) { - if (g1_is_infinity(p)) return p; - - uint384 A = fp_sqr(p.y); - uint384 B = fp_mul(p.x, A); - uint384 C = fp_sqr(A); - - // S = 4*B - uint384 S = fp_add(B, B); - S = fp_add(S, S); - - // M = 3*X^2 (a=0) - uint384 X2 = fp_sqr(p.x); - uint384 M = fp_add(X2, fp_add(X2, X2)); - - // X3 = M^2 - 2*S - uint384 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - - // Y3 = M*(S - X3) - 8*C - uint384 C8 = fp_add(C, C); - C8 = fp_add(C8, C8); - C8 = fp_add(C8, C8); - uint384 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - - // Z3 = 2*Y*Z - uint384 Z3 = fp_mul(p.y, p.z); - Z3 = fp_add(Z3, Z3); - - G1Point r; - r.x = X3; r.y = Y3; r.z = Z3; - return r; -} - -/// G1 mixed addition (Q in affine, P in Jacobian) -inline G1Point g1_add_mixed(G1Point P, uint384 Qx, uint384 Qy) { - if (g1_is_infinity(P)) { - G1Point r; - r.x = Qx; r.y = Qy; r.z = BLS_R; - return r; - } - - uint384 Z2 = fp_sqr(P.z); - uint384 U2 = fp_mul(Qx, Z2); - uint384 Z3 = fp_mul(Z2, P.z); - uint384 S2 = fp_mul(Qy, Z3); - - uint384 H = fp_sub(U2, P.x); - uint384 R = fp_sub(S2, P.y); - - if (u384_is_zero(H)) { - if (u384_is_zero(R)) - return g1_double(P); - return g1_identity(); - } - - uint384 H2 = 
fp_sqr(H); - uint384 H3 = fp_mul(H, H2); - uint384 U1H2 = fp_mul(P.x, H2); - - uint384 X3 = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - uint384 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - uint384 Zr = fp_mul(H, P.z); - - G1Point res; - res.x = X3; res.y = Y3; res.z = Zr; - return res; -} - -/// Scalar multiplication k * P (affine base point) -inline G1Point g1_mul(uint384 k, uint384 Px, uint384 Py) { - G1Point result = g1_identity(); - - for (int i = 5; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = g1_double(result); - if ((k.limbs[i] >> bit) & 1) { - result = g1_add_mixed(result, Px, Py); - } - } - } - return result; -} - -/// Convert Jacobian -> affine -inline void g1_to_affine(G1Point p, thread uint384& ax, thread uint384& ay) { - if (g1_is_infinity(p)) { - ax = ZERO384; ay = ZERO384; - return; - } - uint384 z_inv = fp_inv(p.z); - uint384 z_inv2 = fp_sqr(z_inv); - uint384 z_inv3 = fp_mul(z_inv2, z_inv); - ax = fp_mul(p.x, z_inv2); - ay = fp_mul(p.y, z_inv3); -} - -// ============================================================================= -// BLS signature structures -// ============================================================================= - -/// Compressed BLS signature (G1 point, 48 bytes) -struct BLSSignature { - uchar data[48]; -}; - -/// Compressed BLS public key (G2 point, 96 bytes) -struct BLSPublicKey { - uchar data[96]; -}; - -/// Message hash for BLS verification (32 bytes, pre-hashed) -struct BLSMessage { - uchar data[32]; -}; - -// ============================================================================= -// Deserialization helpers -// ============================================================================= - -/// Deserialize a 48-byte compressed G1 point to uint384 x-coordinate. -/// Format: big-endian 48 bytes. Bit 383 = compression flag, bit 382 = infinity, -/// bit 381 = sign of y (0 = positive). 
-inline uint384 deserialize_fp(device const uchar* data) { - uint384 r = {}; - // Big-endian to little-endian limbs - for (int limb = 0; limb < 6; limb++) { - ulong val = 0; - for (int byte_idx = 0; byte_idx < 8; byte_idx++) { - // Byte position: (5 - limb) * 8 + (7 - byte_idx) - int src = (5 - limb) * 8 + (7 - byte_idx); - if (src < 48) - val |= (ulong)data[src] << (byte_idx * 8); - } - r.limbs[limb] = val; - } - return r; -} - -/// Decompress G1 point: recover y from x using curve equation y^2 = x^3 + 4 -inline bool decompress_g1(uint384 x_raw, bool y_positive, thread uint384& x_mont, thread uint384& y_mont) { - x_mont = to_mont(x_raw); - - // y^2 = x^3 + 4 - uint384 x2 = fp_sqr(x_mont); - uint384 x3 = fp_mul(x2, x_mont); - uint384 b_mont = to_mont(uint384{{4, 0, 0, 0, 0, 0}}); - uint384 y2 = fp_add(x3, b_mont); - - // Square root via Tonelli-Shanks. For BLS12-381, p = 3 mod 4, - // so sqrt(a) = a^((p+1)/4). - // (p+1)/4 computed offline - uint384 exp = {{ - 0xEE7FBFFFFFFFEAAFUL, - 0x07AAFFFFAC54FFFFUL, - 0xD9CC34A83DAC3D89UL, - 0xD91DD2E13CE144AFUL, - 0x92C6E9ED90D2EB35UL, - 0x0680447A8E5FF9A6UL - }}; - - uint384 y_cand = BLS_R; - uint384 base = y2; - for (int i = 0; i < 6; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) { - y_cand = fp_mul(y_cand, base); - } - base = fp_sqr(base); - } - } - - // Verify: y_cand^2 == y2 - uint384 check = fp_sqr(y_cand); - if (u384_cmp(check, y2) != 0) - return false; // Not on curve - - // Pick correct sign - uint384 y_normal = from_mont(y_cand); - bool is_positive = (y_normal.limbs[0] & 1) == 0; - if (is_positive != y_positive) { - y_mont = fp_neg(y_cand); - } else { - y_mont = y_cand; - } - - return true; -} - -// ============================================================================= -// BLS Verification kernel -// ============================================================================= - -/// Batch BLS signature verification. -/// Each thread verifies one signature independently. 
-/// -/// For initial implementation, this performs the G1 point operations needed -/// for verification (deserialization, decompression, point arithmetic). -/// The final pairing check is deferred to the host (requires Fp12 Miller loop). -/// -/// Output: -/// results[tid] = 1 if the G1 operations succeeded (point on curve, valid) -/// results[tid] = 0 if deserialization or curve check failed -kernel void bls_verify_batch( - device const BLSSignature* sigs [[buffer(0)]], - device const BLSPublicKey* pubkeys [[buffer(1)]], - device const BLSMessage* messages [[buffer(2)]], - device uint* results [[buffer(3)]], - constant uint& num_sigs [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - - // -- Deserialize signature (compressed G1 point) -------------------------- - device const uchar* sig_data = sigs[tid].data; - - // Check flags byte - uchar flags = sig_data[0]; - bool compressed = (flags & 0x80) != 0; - bool infinity = (flags & 0x40) != 0; - bool y_sign = (flags & 0x20) != 0; - - if (infinity) { - // Signature at infinity is invalid - results[tid] = 0; - return; - } - - if (!compressed) { - // We only handle compressed format - results[tid] = 0; - return; - } - - // Clear flag bits for x-coordinate deserialization - uchar clean_data[48]; - for (int i = 0; i < 48; i++) clean_data[i] = sig_data[i]; - clean_data[0] &= 0x1F; - - uint384 x_raw = {}; - for (int limb = 0; limb < 6; limb++) { - ulong val = 0; - for (int b = 0; b < 8; b++) { - int src = (5 - limb) * 8 + (7 - b); - if (src < 48) - val |= (ulong)clean_data[src] << (b * 8); - } - x_raw.limbs[limb] = val; - } - - // -- Decompress to affine G1 point ---------------------------------------- - uint384 sig_x, sig_y; - bool on_curve = decompress_g1(x_raw, !y_sign, sig_x, sig_y); - if (!on_curve) { - results[tid] = 0; - return; - } - - // -- Subgroup check deferred to host -- - // Full subgroup check (multiply by curve order r, verify identity point) - // requires ~250 point 
doublings which is too expensive on GPU per-thread. - // The host MUST perform subgroup check before accepting the verification - // result. We signal this via bit 1 of the result word: - // bit 0 = on-curve check passed - // bit 1 = subgroup check still required by host - results[tid] = 0x3; // on_curve=1, needs_subgroup_check=1 -} - -/// Aggregate N G1 signatures into one by summing the points. -/// This is the main GPU-accelerated operation for BLS aggregation. -/// -/// Input: N compressed G1 points (48 bytes each) -/// Output: 1 uncompressed G1 point (96 bytes: x[48] || y[48]) -/// -/// Uses parallel reduction within a threadgroup. -kernel void bls_aggregate_g1( - device const BLSSignature* sigs [[buffer(0)]], - device uchar* agg_out [[buffer(1)]], // 96 bytes - device atomic_uint* counter [[buffer(2)]], // Atomic completion counter - constant uint& num_sigs [[buffer(3)]], - uint tid [[thread_position_in_grid]], - uint tgid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]]) -{ - // Each thread deserializes and decompresses one signature - G1Point local_sum = g1_identity(); - - if (tid < num_sigs) { - device const uchar* sig_data = sigs[tid].data; - - uchar flags = sig_data[0]; - bool infinity = (flags & 0x40) != 0; - bool y_sign = (flags & 0x20) != 0; - - if (!infinity) { - uchar clean_data[48]; - for (int i = 0; i < 48; i++) clean_data[i] = sig_data[i]; - clean_data[0] &= 0x1F; - - uint384 x_raw = {}; - for (int limb = 0; limb < 6; limb++) { - ulong val = 0; - for (int b = 0; b < 8; b++) { - int src = (5 - limb) * 8 + (7 - b); - if (src < 48) - val |= (ulong)clean_data[src] << (b * 8); - } - x_raw.limbs[limb] = val; - } - - uint384 sx, sy; - if (decompress_g1(x_raw, !y_sign, sx, sy)) { - local_sum.x = sx; - local_sum.y = sy; - local_sum.z = BLS_R; // Affine: z=1 in Montgomery - } - } - } - - // Threadgroup reduction: sum all points in this threadgroup. 
- // Use threadgroup memory for inter-thread communication. - threadgroup uint384 shared_x[256]; - threadgroup uint384 shared_y[256]; - threadgroup uint384 shared_z[256]; - - shared_x[lid] = local_sum.x; - shared_y[lid] = local_sum.y; - shared_z[lid] = local_sum.z; - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Binary reduction - for (uint stride = tg_size / 2; stride > 0; stride >>= 1) { - if (lid < stride) { - G1Point a; - a.x = shared_x[lid]; a.y = shared_y[lid]; a.z = shared_z[lid]; - - G1Point b; - b.x = shared_x[lid + stride]; b.y = shared_y[lid + stride]; b.z = shared_z[lid + stride]; - - if (!g1_is_infinity(b)) { - if (g1_is_infinity(a)) { - a = b; - } else { - // Full Jacobian addition (both points are Jacobian) - uint384 Z1sq = fp_sqr(a.z); - uint384 Z2sq = fp_sqr(b.z); - uint384 U1 = fp_mul(a.x, Z2sq); - uint384 U2 = fp_mul(b.x, Z1sq); - uint384 S1 = fp_mul(a.y, fp_mul(Z2sq, b.z)); - uint384 S2 = fp_mul(b.y, fp_mul(Z1sq, a.z)); - - uint384 H = fp_sub(U2, U1); - uint384 R = fp_sub(S2, S1); - - if (u384_is_zero(H)) { - if (u384_is_zero(R)) { - a = g1_double(a); - } else { - a = g1_identity(); - } - } else { - uint384 H2 = fp_sqr(H); - uint384 H3 = fp_mul(H, H2); - uint384 U1H2 = fp_mul(U1, H2); - a.x = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - a.y = fp_sub(fp_mul(R, fp_sub(U1H2, a.x)), fp_mul(S1, H3)); - a.z = fp_mul(fp_mul(H, a.z), b.z); - } - } - } - - shared_x[lid] = a.x; - shared_y[lid] = a.y; - shared_z[lid] = a.z; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Thread 0 of each threadgroup writes partial result. - // The host sums partial results from each threadgroup. 
- if (lid == 0) { - G1Point partial; - partial.x = shared_x[0]; partial.y = shared_y[0]; partial.z = shared_z[0]; - - uint384 ax, ay; - g1_to_affine(partial, ax, ay); - - uint384 ax_norm = from_mont(ax); - uint384 ay_norm = from_mont(ay); - - // Write to output: big-endian 48 bytes x, then 48 bytes y - uint tg_offset = tgid * 96; - for (int limb = 0; limb < 6; limb++) { - for (int b = 0; b < 8; b++) { - int dst = (5 - limb) * 8 + (7 - b); - if (dst < 48) { - agg_out[tg_offset + dst] = uchar((ax_norm.limbs[limb] >> (b * 8)) & 0xFF); - agg_out[tg_offset + 48 + dst] = uchar((ay_norm.limbs[limb] >> (b * 8)) & 0xFF); - } - } - } - - atomic_fetch_add_explicit(counter, 1u, memory_order_relaxed); - } -} diff --git a/bls/gpu/metal/bls_authored.metal b/bls/gpu/metal/bls_authored.metal deleted file mode 100644 index 86a04ac..0000000 --- a/bls/gpu/metal/bls_authored.metal +++ /dev/null @@ -1,647 +0,0 @@ -// ============================================================================= -// BLS12-381 Metal Compute Shaders -// ============================================================================= -// -// GPU-accelerated elliptic curve operations for BLS12-381 on Apple Silicon. -// Implements G1 point operations for batch signature verification. -// -// BLS12-381 Parameters: -// p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab -// r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 -// G1: y^2 = x^3 + 4 over Fp -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// 384-bit Field Arithmetic (6 x 64-bit limbs) -// ============================================================================= - -// BLS12-381 base field prime p (6 limbs, little-endian) -constant uint64_t BLS_P[6] = { - 0xb9feffffffffaaab, - 0x1eabfffeb153ffff, - 0x6730d2a0f6b0f624, - 0x64774b84f38512bf, - 0x4b1ba7b6434bacd7, - 0x1a0111ea397fe69a -}; - -// Montgomery R^2 mod p (for converting to Montgomery form) -constant uint64_t BLS_R2[6] = { - 0xf4df1f341c341746, - 0x0a76e6a609d104f1, - 0x8de5476c4c95b6d5, - 0x67eb88a9939d83c0, - 0x9a793e85b519952d, - 0x11988fe592cae3aa -}; - -// Montgomery constant: -p^{-1} mod 2^64 -constant uint64_t BLS_INV = 0x89f3fffcfffcfffd; - -// Fp384 represented as 6 uint64 limbs -struct Fp384 { - uint64_t limbs[6]; -}; - -// G1 affine point -struct G1Affine { - Fp384 x; - Fp384 y; - bool infinity; -}; - -// G1 projective point (Jacobian coordinates) -struct G1Projective { - Fp384 x; - Fp384 y; - Fp384 z; -}; - -// ============================================================================= -// Multi-precision Arithmetic -// ============================================================================= - -// Add with carry -inline uint64_t adc(uint64_t a, uint64_t b, thread uint64_t& carry) { - uint64_t sum = a + b + carry; - carry = (sum < a) || ((sum == a) && (b > 0 || carry > 0)) ? 1 : 0; - // Simplified carry detection - if (carry == 0) { - if (sum < a || sum < b) carry = 1; - } - return sum; -} - -// Subtract with borrow -inline uint64_t sbb(uint64_t a, uint64_t b, thread uint64_t& borrow) { - uint64_t diff = a - b - borrow; - borrow = (a < b + borrow) ? 
1 : 0; - return diff; -} - -// 64x64 -> 128 bit multiplication (returns low and high parts) -inline void mul64(uint64_t a, uint64_t b, thread uint64_t& lo, thread uint64_t& hi) { - lo = a * b; - hi = mulhi(a, b); -} - -// Compare: returns -1 if a < b, 0 if a == b, 1 if a > b -inline int fp384_cmp(thread const Fp384& a, constant uint64_t* b) { - for (int i = 5; i >= 0; i--) { - if (a.limbs[i] < b[i]) return -1; - if (a.limbs[i] > b[i]) return 1; - } - return 0; -} - -// Conditional subtraction: if a >= p, compute a - p -inline void fp384_reduce(thread Fp384& a) { - if (fp384_cmp(a, BLS_P) >= 0) { - uint64_t borrow = 0; - for (int i = 0; i < 6; i++) { - a.limbs[i] = sbb(a.limbs[i], BLS_P[i], borrow); - } - } -} - -// Fp addition: c = a + b mod p -inline Fp384 fp384_add(thread const Fp384& a, thread const Fp384& b) { - Fp384 c; - uint64_t carry = 0; - - for (int i = 0; i < 6; i++) { - uint64_t sum = a.limbs[i] + b.limbs[i] + carry; - carry = (sum < a.limbs[i]) ? 1 : 0; - c.limbs[i] = sum; - } - - fp384_reduce(c); - return c; -} - -// Fp subtraction: c = a - b mod p -inline Fp384 fp384_sub(thread const Fp384& a, thread const Fp384& b) { - Fp384 c; - uint64_t borrow = 0; - - for (int i = 0; i < 6; i++) { - c.limbs[i] = sbb(a.limbs[i], b.limbs[i], borrow); - } - - // If underflow, add p - if (borrow) { - uint64_t carry = 0; - for (int i = 0; i < 6; i++) { - uint64_t sum = c.limbs[i] + BLS_P[i] + carry; - carry = (sum < c.limbs[i]) ? 1 : 0; - c.limbs[i] = sum; - } - } - - return c; -} - -// Montgomery reduction: given T < p*R, compute T*R^{-1} mod p -inline Fp384 fp384_mont_reduce(thread uint64_t t[12]) { - for (int i = 0; i < 6; i++) { - uint64_t m = t[i] * BLS_INV; - uint64_t carry = 0; - - for (int j = 0; j < 6; j++) { - uint64_t lo, hi; - mul64(m, BLS_P[j], lo, hi); - - uint64_t sum = t[i + j] + lo + carry; - carry = (sum < lo) ? 
hi + 1 : hi; - t[i + j] = sum; - } - - // Propagate carry - for (int j = i + 6; j < 12; j++) { - uint64_t sum = t[j] + carry; - carry = (sum < carry) ? 1 : 0; - t[j] = sum; - if (carry == 0) break; - } - } - - Fp384 result; - for (int i = 0; i < 6; i++) { - result.limbs[i] = t[i + 6]; - } - - fp384_reduce(result); - return result; -} - -// Montgomery multiplication: c = a * b * R^{-1} mod p -inline Fp384 fp384_mul(thread const Fp384& a, thread const Fp384& b) { - uint64_t t[12] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 6; i++) { - uint64_t carry = 0; - for (int j = 0; j < 6; j++) { - uint64_t lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - - uint64_t sum = t[i + j] + lo + carry; - carry = (sum < lo) ? hi + 1 : hi; - t[i + j] = sum; - } - t[i + 6] += carry; - } - - return fp384_mont_reduce(t); -} - -// Montgomery squaring (optimized) -inline Fp384 fp384_sqr(thread const Fp384& a) { - return fp384_mul(a, a); -} - -// Double: c = 2 * a mod p -inline Fp384 fp384_double(thread const Fp384& a) { - Fp384 c; - uint64_t carry = 0; - - for (int i = 0; i < 6; i++) { - uint64_t sum = (a.limbs[i] << 1) | carry; - carry = a.limbs[i] >> 63; - c.limbs[i] = sum; - } - - fp384_reduce(c); - return c; -} - -// Negate: c = -a mod p = p - a -inline Fp384 fp384_neg(thread const Fp384& a) { - bool is_zero = true; - for (int i = 0; i < 6; i++) { - if (a.limbs[i] != 0) { is_zero = false; break; } - } - - if (is_zero) return a; - - Fp384 c; - uint64_t borrow = 0; - for (int i = 0; i < 6; i++) { - c.limbs[i] = sbb(BLS_P[i], a.limbs[i], borrow); - } - - return c; -} - -// Check if zero -inline bool fp384_is_zero(thread const Fp384& a) { - for (int i = 0; i < 6; i++) { - if (a.limbs[i] != 0) return false; - } - return true; -} - -// ============================================================================= -// G1 Point Arithmetic (Jacobian Projective Coordinates) -// ============================================================================= - -// Point at infinity 
(identity element) -inline G1Projective g1_identity() { - G1Projective p; - for (int i = 0; i < 6; i++) { - p.x.limbs[i] = 0; - p.y.limbs[i] = i == 0 ? 1 : 0; // y = 1 in Montgomery form - p.z.limbs[i] = 0; - } - return p; -} - -// Check if point is at infinity (Z == 0) -inline bool g1_is_identity(thread const G1Projective& p) { - return fp384_is_zero(p.z); -} - -// Point doubling in Jacobian coordinates -// Formula: http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l -inline G1Projective g1_double(thread const G1Projective& p) { - if (g1_is_identity(p)) { - return p; - } - - // A = X1^2 - Fp384 a = fp384_sqr(p.x); - // B = Y1^2 - Fp384 b = fp384_sqr(p.y); - // C = B^2 - Fp384 c = fp384_sqr(b); - - // D = 2*((X1+B)^2 - A - C) - Fp384 xb = fp384_add(p.x, b); - Fp384 xb2 = fp384_sqr(xb); - Fp384 d = fp384_sub(xb2, a); - d = fp384_sub(d, c); - d = fp384_double(d); - - // E = 3*A - Fp384 e = fp384_add(a, a); - e = fp384_add(e, a); - - // F = E^2 - Fp384 f = fp384_sqr(e); - - // X3 = F - 2*D - Fp384 d2 = fp384_double(d); - G1Projective result; - result.x = fp384_sub(f, d2); - - // Y3 = E*(D - X3) - 8*C - Fp384 dx3 = fp384_sub(d, result.x); - Fp384 edx3 = fp384_mul(e, dx3); - Fp384 c8 = fp384_double(c); - c8 = fp384_double(c8); - c8 = fp384_double(c8); - result.y = fp384_sub(edx3, c8); - - // Z3 = 2*Y1*Z1 - Fp384 yz = fp384_mul(p.y, p.z); - result.z = fp384_double(yz); - - return result; -} - -// Mixed addition: R = P + Q where Q is affine (Z_Q = 1) -// More efficient when one point is in affine form -inline G1Projective g1_add_mixed(thread const G1Projective& p, thread const G1Affine& q) { - if (q.infinity) return p; - if (g1_is_identity(p)) { - G1Projective r; - r.x = q.x; - r.y = q.y; - // Z = 1 in Montgomery form - for (int i = 0; i < 6; i++) r.z.limbs[i] = BLS_R2[i]; // R mod p - return r; - } - - // Z1Z1 = Z1^2 - Fp384 z1z1 = fp384_sqr(p.z); - - // U2 = X2*Z1Z1 - Fp384 u2 = fp384_mul(q.x, z1z1); - - // S2 = Y2*Z1*Z1Z1 - Fp384 s2 = 
fp384_mul(p.z, z1z1); - s2 = fp384_mul(q.y, s2); - - // H = U2 - X1 - Fp384 h = fp384_sub(u2, p.x); - - // HH = H^2 - Fp384 hh = fp384_sqr(h); - - // I = 4*HH - Fp384 i = fp384_double(hh); - i = fp384_double(i); - - // J = H*I - Fp384 j = fp384_mul(h, i); - - // r = 2*(S2 - Y1) - Fp384 r = fp384_sub(s2, p.y); - r = fp384_double(r); - - // V = X1*I - Fp384 v = fp384_mul(p.x, i); - - // X3 = r^2 - J - 2*V - Fp384 r2 = fp384_sqr(r); - Fp384 v2 = fp384_double(v); - G1Projective result; - result.x = fp384_sub(r2, j); - result.x = fp384_sub(result.x, v2); - - // Y3 = r*(V - X3) - 2*Y1*J - Fp384 vx3 = fp384_sub(v, result.x); - Fp384 rvx3 = fp384_mul(r, vx3); - Fp384 y1j = fp384_mul(p.y, j); - y1j = fp384_double(y1j); - result.y = fp384_sub(rvx3, y1j); - - // Z3 = (Z1 + H)^2 - Z1Z1 - HH - Fp384 zh = fp384_add(p.z, h); - Fp384 zh2 = fp384_sqr(zh); - result.z = fp384_sub(zh2, z1z1); - result.z = fp384_sub(result.z, hh); - - return result; -} - -// Copy projective point (for address space conversion) -inline G1Projective g1_copy(constant G1Projective& src) { - G1Projective dst; - for (int i = 0; i < 6; i++) { - dst.x.limbs[i] = src.x.limbs[i]; - dst.y.limbs[i] = src.y.limbs[i]; - dst.z.limbs[i] = src.z.limbs[i]; - } - return dst; -} - -inline G1Projective g1_copy_tg(threadgroup G1Projective& src) { - G1Projective dst; - for (int i = 0; i < 6; i++) { - dst.x.limbs[i] = src.x.limbs[i]; - dst.y.limbs[i] = src.y.limbs[i]; - dst.z.limbs[i] = src.z.limbs[i]; - } - return dst; -} - -// Full point addition: R = P + Q (both projective) -inline G1Projective g1_add(thread const G1Projective& p, thread const G1Projective& q) { - if (g1_is_identity(p)) return q; - if (g1_is_identity(q)) return p; - - // Z1Z1 = Z1^2, Z2Z2 = Z2^2 - Fp384 z1z1 = fp384_sqr(p.z); - Fp384 z2z2 = fp384_sqr(q.z); - - // U1 = X1*Z2Z2, U2 = X2*Z1Z1 - Fp384 u1 = fp384_mul(p.x, z2z2); - Fp384 u2 = fp384_mul(q.x, z1z1); - - // S1 = Y1*Z2*Z2Z2, S2 = Y2*Z1*Z1Z1 - Fp384 s1 = fp384_mul(p.y, q.z); - s1 = fp384_mul(s1, 
z2z2); - Fp384 s2 = fp384_mul(q.y, p.z); - s2 = fp384_mul(s2, z1z1); - - // H = U2 - U1 - Fp384 h = fp384_sub(u2, u1); - - // Check if points are equal (H == 0 and S2 - S1 == 0) - bool h_zero = fp384_is_zero(h); - Fp384 s_diff = fp384_sub(s2, s1); - bool s_zero = fp384_is_zero(s_diff); - - if (h_zero && s_zero) { - return g1_double(p); // P == Q, use doubling - } - if (h_zero) { - return g1_identity(); // P == -Q, result is identity - } - - // I = (2*H)^2 - Fp384 h2 = fp384_double(h); - Fp384 i = fp384_sqr(h2); - - // J = H*I - Fp384 j = fp384_mul(h, i); - - // r = 2*(S2 - S1) - Fp384 r = fp384_double(s_diff); - - // V = U1*I - Fp384 v = fp384_mul(u1, i); - - // X3 = r^2 - J - 2*V - Fp384 r2 = fp384_sqr(r); - Fp384 v2 = fp384_double(v); - G1Projective result; - result.x = fp384_sub(r2, j); - result.x = fp384_sub(result.x, v2); - - // Y3 = r*(V - X3) - 2*S1*J - Fp384 vx3 = fp384_sub(v, result.x); - Fp384 rvx3 = fp384_mul(r, vx3); - Fp384 s1j = fp384_mul(s1, j); - s1j = fp384_double(s1j); - result.y = fp384_sub(rvx3, s1j); - - // Z3 = ((Z1 + Z2)^2 - Z1Z1 - Z2Z2) * H - Fp384 z12 = fp384_add(p.z, q.z); - Fp384 z12_2 = fp384_sqr(z12); - result.z = fp384_sub(z12_2, z1z1); - result.z = fp384_sub(result.z, z2z2); - result.z = fp384_mul(result.z, h); - - return result; -} - -// Scalar multiplication using double-and-add -// scalar is 256 bits (4 x 64-bit limbs) -inline G1Projective g1_scalar_mul(thread const G1Projective& p, constant uint64_t* scalar) { - G1Projective result = g1_identity(); - G1Projective base = p; - - // Process 256 bits - for (int limb = 0; limb < 4; limb++) { - uint64_t bits = scalar[limb]; - for (int bit = 0; bit < 64; bit++) { - if (bits & 1) { - result = g1_add(result, base); - } - base = g1_double(base); - bits >>= 1; - } - } - - return result; -} - -// ============================================================================= -// Metal Compute Kernels -// ============================================================================= - -// Batch 
point addition kernel -// Adds pairs of points in parallel -kernel void g1_batch_add( - device G1Projective* results [[buffer(0)]], - constant G1Projective* points_a [[buffer(1)]], - constant G1Projective* points_b [[buffer(2)]], - constant uint& count [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= count) return; - - G1Projective a = points_a[tid]; - G1Projective b = points_b[tid]; - results[tid] = g1_add(a, b); -} - -// Batch point doubling kernel -kernel void g1_batch_double( - device G1Projective* results [[buffer(0)]], - constant G1Projective* points [[buffer(1)]], - constant uint& count [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= count) return; - - G1Projective p = points[tid]; - results[tid] = g1_double(p); -} - -// Parallel scalar multiplication kernel -// Each thread computes one scalar multiplication -kernel void g1_batch_scalar_mul( - device G1Projective* results [[buffer(0)]], - constant G1Projective* points [[buffer(1)]], - constant uint64_t* scalars [[buffer(2)]], // 4 limbs per scalar - constant uint& count [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= count) return; - - G1Projective p = points[tid]; - constant uint64_t* scalar = scalars + tid * 4; - results[tid] = g1_scalar_mul(p, scalar); -} - -// Multi-scalar multiplication (MSM) using bucket method -// Suitable for batch signature verification -// This is a simplified version; production would use Pippenger's algorithm -kernel void g1_msm_accumulate( - device G1Projective* buckets [[buffer(0)]], - constant G1Projective* points [[buffer(1)]], - constant uint8_t* bucket_indices [[buffer(2)]], // Which bucket each point goes to - constant uint& num_points [[buffer(3)]], - constant uint& num_buckets [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_points) return; - - uint bucket_idx = bucket_indices[tid]; - if (bucket_idx >= num_buckets) return; - - G1Projective p = points[tid]; - - // Atomic-style 
accumulation (simplified - real impl needs proper sync) - // This accumulates point into the appropriate bucket - G1Projective current = buckets[bucket_idx]; - buckets[bucket_idx] = g1_add(current, p); -} - -// Reduce buckets for MSM final result -kernel void g1_msm_reduce( - device G1Projective* result [[buffer(0)]], - constant G1Projective* buckets [[buffer(1)]], - constant uint& num_buckets [[buffer(2)]], - constant uint& window_bits [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid != 0) return; // Single thread for final reduction - - G1Projective sum = g1_identity(); - G1Projective running = g1_identity(); - - // Process buckets from highest to lowest - for (int i = num_buckets - 1; i >= 0; i--) { - G1Projective bucket = g1_copy(buckets[i]); - running = g1_add(running, bucket); - sum = g1_add(sum, running); - } - - result[0] = sum; -} - -// Batch signature verification helper -// Computes: sum_i (r_i * P_i) where P_i are public keys and r_i are random scalars -// Used for verifying aggregate signatures efficiently -kernel void bls_batch_verify_msm( - device G1Projective* result [[buffer(0)]], - constant G1Affine* public_keys [[buffer(1)]], - constant uint64_t* random_scalars [[buffer(2)]], // 4 limbs per scalar - constant uint& count [[buffer(3)]], - uint tid [[thread_position_in_grid]], - uint threads_per_group [[threads_per_threadgroup]], - threadgroup G1Projective* shared_mem [[threadgroup(0)]]) -{ - G1Projective local_sum = g1_identity(); - - // Each thread processes multiple points with stride - uint total_threads = threads_per_group; - for (uint i = tid; i < count; i += total_threads) { - // Convert affine to projective - G1Projective p; - p.x = public_keys[i].x; - p.y = public_keys[i].y; - for (int j = 0; j < 6; j++) p.z.limbs[j] = BLS_R2[j]; // Z = 1 - - if (public_keys[i].infinity) { - p = g1_identity(); - } - - // Scalar multiplication - constant uint64_t* scalar = random_scalars + i * 4; - G1Projective scaled = g1_scalar_mul(p, 
scalar); - - // Accumulate - local_sum = g1_add(local_sum, scaled); - } - - // Store to shared memory - shared_mem[tid] = local_sum; - - // Barrier - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Tree reduction - for (uint stride = threads_per_group / 2; stride > 0; stride >>= 1) { - if (tid < stride) { - G1Projective a = g1_copy_tg(shared_mem[tid]); - G1Projective b = g1_copy_tg(shared_mem[tid + stride]); - shared_mem[tid] = g1_add(a, b); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Thread 0 writes final result - if (tid == 0) { - result[0] = g1_copy_tg(shared_mem[0]); - } -} diff --git a/bls/gpu/metal/bls_combined_miller.metal b/bls/gpu/metal/bls_combined_miller.metal deleted file mode 100644 index 854dbdb..0000000 --- a/bls/gpu/metal/bls_combined_miller.metal +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 combined-pair Miller-loop kernel pack. -// -// One-shot Miller loop over k pairs (P_i, Q_i) producing the -// Fp12 product prod_i miller_loop(P_i, Q_i) pre-final-exponentiation, -// byte-equal to the canonical CPU path -// for i in 0..k: ml[i] = blst_miller_loop(Q_i, P_i) -// product = tree_reduce_fp12(ml[0..k-1]) -// -// Layout reuses the existing Stage-3 kernel split (init / add_T / dbl_T / -// sqr_ret / fold_line / finalize from bls_miller.metal) — those already -// handle the per-bit Miller iteration on N=k workitems in one dispatch -// each. Fusing means: ONE driver call dispatches the whole 6-stage -// pipeline once with N=k, then folds the k Fp12 outputs to a single -// product via the canonical tree reduction below. -// -// New kernel introduced here: -// k_combined_miller_reduce — one round of pairwise tree-reduce on -// Fp12. Caller invokes it ceil(log2(k)) -// times with shrinking workitem counts. -// -// Determinism: the round-by-round shape is canonical (matches -// tree_reduce_fp12 in cpp/bls_pairing.cpp). 
Round k+1 multiplies -// adjacent outputs of round k; an odd count carries the last element -// forward unchanged. The final result lands at ret_buf[0]. - -#define BLS_FP12_NO_KERNELS -#define BLS_FP6_NO_KERNELS -#define BLS_FP2_NO_KERNELS -#include "bls_fp12.metal" -#undef BLS_FP12_NO_KERNELS -#undef BLS_FP6_NO_KERNELS -#undef BLS_FP2_NO_KERNELS - -// k_combined_miller_reduce — one round of canonical pairwise tree reduction. -// -// Reads in[2*tid] * in[2*tid+1] for tid < pairs. -// Writes out[tid]. -// If carry==1u and tid==pairs (one extra workitem dispatched), writes -// out[tid] = in[2*tid] (carry-forward of the last element when odd input). -// -// Caller pattern: -// n = k -// while n > 1: -// pairs = n / 2 -// odd = n & 1u -// dispatch(pairs + odd) with carry = odd -// swap(in_buf, out_buf) -// n = pairs + odd -// -// Determinism: the index map (in[2i], in[2i+1]) -> out[i] is canonical; -// odd carry is the last element verbatim. Matches tree_reduce_fp12. -kernel void k_combined_miller_reduce( - device const Fp12* in_buf [[buffer(0)]], - device Fp12* out_buf [[buffer(1)]], - constant uint& pairs [[buffer(2)]], - constant uint& carry [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid < pairs) { - out_buf[tid] = fp12_mul(in_buf[2u * tid], in_buf[2u * tid + 1u]); - return; - } - if (carry != 0u && tid == pairs) { - // Odd input: last element passes through unchanged. - out_buf[tid] = in_buf[2u * tid]; - } -} diff --git a/bls/gpu/metal/bls_combined_miller_driver.h b/bls/gpu/metal/bls_combined_miller_driver.h deleted file mode 100644 index 3086542..0000000 --- a/bls/gpu/metal/bls_combined_miller_driver.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host dispatcher entry point for the combined-pair Miller-loop kernel -// pack on Metal. 
See bls_combined_miller_driver.mm for the orchestration -// body and bls_combined_miller.metal for the fused-tail reduction kernel. -// -// Output is the pre-final-exponentiation Fp12 product -// -// prod_i miller_loop(Q_i, P_i) for i in 0..k -// -// byte-equal the canonical CPU reference (per-pair blst_miller_loop + -// canonical pairwise tree reduction). Caller applies final_exp() once -// after this call to obtain the e(...) verdict. - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Combined-pair Miller loop on Metal. -// -// g1s : k * 96 bytes (uncompressed G1 points, blst_p1_affine layout) -// g2s : k * 192 bytes (uncompressed G2 points, blst_p2_affine layout) -// k : number of pairs (>= 1) -// fp12_out : 576-byte output (blst_fp12 layout, pre-final-exp product) -// -// Returns: -// 0 on success -// -1 on input error (null pointer, k == 0) -// -2 on Metal initialisation failure (no device, missing metallib, -// missing kernel symbols). Caller should fall back to CPU. -int bls_combined_miller_metal(const uint8_t* g1s, - const uint8_t* g2s, - size_t k, - uint8_t fp12_out[576]); - -#ifdef __cplusplus -} -#endif diff --git a/bls/gpu/metal/bls_combined_miller_driver.mm b/bls/gpu/metal/bls_combined_miller_driver.mm deleted file mode 100644 index bab65d4..0000000 --- a/bls/gpu/metal/bls_combined_miller_driver.mm +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host dispatcher for the combined-pair Miller-loop pipeline on Metal. -// -// One driver call runs the full Miller loop (init / add_T / dbl_T / -// sqr_ret / fold_line / finalize, sequenced per the BLS12-381 ate -// scalar |x|) over k pairs as N=k workitems, then collapses the k -// Fp12 outputs to a single product via canonical pairwise tree -// reduction. 
The output is the pre-final-exponentiation Fp12 product -// -// prod_i miller_loop(Q_i, P_i) -// -// byte-equal the CPU path -// -// std::vector ml(k); -// for i: blst_miller_loop(&ml[i], &Q_i, &P_i); -// tree_reduce_fp12(ml); // round-by-round pairwise -// out = ml[0]; // 576-byte Fp12, NO final_exp. -// -// Caller is expected to apply final_exp() once after this routine -// (matches the multi_pair semantics used by tests). - -#import -#import - -#include "bls_combined_miller_driver.h" - -#include -#include -#include -#include -#include - -namespace { - -// Layout constants — must match bls_miller.metal struct sizes. -constexpr size_t kP1AffBytes = 96; // sizeof(P1Aff) — uint384 X, Y -constexpr size_t kP2AffBytes = 192; // sizeof(P2Aff) — Fp2 X, Y -constexpr size_t kMillerInBytes = kP2AffBytes + kP1AffBytes; // Q || P -constexpr size_t kP2Bytes = 288; // sizeof(P2) Jacobian -constexpr size_t kFp2Bytes = 96; -constexpr size_t kFp12Bytes = 576; -constexpr size_t kLineBytes = 3 * kFp2Bytes; - -// Miller-loop phase doubling counts (BLS12-381 ate scalar |x| bit pattern). -// Same as bls_miller_test.mm and bls_pairing_test.mm. -constexpr uint32_t kPhases[5] = { 2u, 3u, 9u, 32u, 16u }; - -struct Ctx { - id device; - id queue; - id library; - // Miller PSOs. - id mil_init; - id mil_add_T; - id mil_dbl_T; - id mil_sqr_ret; - id mil_fold; - id mil_finalize; - // Tree-reduce PSO. - id reduce; - // Status of last load. - int last_status; -}; - -std::mutex g_ctx_mu; -Ctx* g_ctx = nullptr; - -const char* env_metallib_path() -{ - if (const char* p = std::getenv("BLS_COMBINED_MILLER_METALLIB")) { - if (p[0] != '\0') return p; - } - return nullptr; -} - -id load_library(id dev, NSError** err_out) -{ - // Try caller-supplied env path first (CI/tests), then standard install - // locations, then fall back to nil. 
- NSMutableArray* paths = [NSMutableArray array]; - if (const char* envp = env_metallib_path()) { - [paths addObject:[NSString stringWithUTF8String:envp]]; - } - [paths addObject:@"/usr/local/share/lux/crypto/bls_combined_miller.metallib"]; - [paths addObject:@"/usr/local/share/lux/crypto/bls_pairing.metallib"]; - - NSFileManager* fm = [NSFileManager defaultManager]; - for (NSString* p in paths) { - if (p.length == 0) continue; - if (![fm fileExistsAtPath:p]) continue; - NSError* e = nil; - id lib = [dev newLibraryWithURL:[NSURL fileURLWithPath:p] - error:&e]; - if (lib) return lib; - if (err_out && *err_out == nil) *err_out = e; - } - return nil; -} - -id make_pso(id lib, - id dev, - const char* name, - int* status) -{ - NSError* e = nil; - id f = [lib newFunctionWithName:[NSString stringWithUTF8String:name]]; - if (!f) { - if (status) *status = -2; - return nil; - } - id pso = - [dev newComputePipelineStateWithFunction:f error:&e]; - if (!pso && status) *status = -3; - return pso; -} - -// Initialise (or reuse) the global context. Returns nullptr on failure; -// in that case the caller should fall back to the CPU multi_pair path. 
-Ctx* init_or_get_ctx() -{ - std::lock_guard g(g_ctx_mu); - if (g_ctx != nullptr) return g_ctx; - - Ctx* c = new Ctx{}; - c->device = MTLCreateSystemDefaultDevice(); - if (!c->device) { c->last_status = -1; g_ctx = c; return c; } - c->queue = [c->device newCommandQueue]; - - NSError* err = nil; - c->library = load_library(c->device, &err); - if (!c->library) { c->last_status = -2; g_ctx = c; return c; } - - int s = 0; - c->mil_init = make_pso(c->library, c->device, "k_miller_init", &s); - c->mil_add_T = make_pso(c->library, c->device, "k_miller_add_T_and_line", &s); - c->mil_dbl_T = make_pso(c->library, c->device, "k_miller_dbl_T_and_line", &s); - c->mil_sqr_ret = make_pso(c->library, c->device, "k_miller_sqr_ret", &s); - c->mil_fold = make_pso(c->library, c->device, "k_miller_fold_line", &s); - c->mil_finalize = make_pso(c->library, c->device, "k_miller_finalize", &s); - c->reduce = make_pso(c->library, c->device, "k_combined_miller_reduce", &s); - if (s != 0 || - !c->mil_init || !c->mil_add_T || !c->mil_dbl_T || - !c->mil_sqr_ret || !c->mil_fold || !c->mil_finalize || - !c->reduce) { - c->last_status = -3; - g_ctx = c; - return c; - } - c->last_status = 0; - g_ctx = c; - return c; -} - -void encode_dispatch(id enc, - id pso, - std::initializer_list> bufs, - size_t threads) -{ - [enc setComputePipelineState:pso]; - NSUInteger i = 0; - for (id b : bufs) { - [enc setBuffer:b offset:0 atIndex:i]; - ++i; - } - NSUInteger tg = MIN((NSUInteger)16, pso.maxTotalThreadsPerThreadgroup); - [enc dispatchThreads:MTLSizeMake(threads, 1, 1) - threadsPerThreadgroup:MTLSizeMake(tg, 1, 1)]; -} - -} // namespace - -extern "C" int bls_combined_miller_metal(const uint8_t* g1s, - const uint8_t* g2s, - size_t k, - uint8_t fp12_out[576]) -{ - if (fp12_out == nullptr) return -1; - if (k == 0) return -1; - if (g1s == nullptr || g2s == nullptr) return -1; - - Ctx* ctx = init_or_get_ctx(); - if (!ctx || ctx->last_status != 0) return -2; - - @autoreleasepool { - // ---- Pack inputs in 
MillerIn layout: Q || P per workitem. ---- - std::vector in_packed(k * kMillerInBytes); - for (size_t i = 0; i < k; ++i) { - std::memcpy(in_packed.data() + i * kMillerInBytes, - g2s + i * kP2AffBytes, kP2AffBytes); - std::memcpy(in_packed.data() + i * kMillerInBytes + kP2AffBytes, - g1s + i * kP1AffBytes, kP1AffBytes); - } - - id dev = ctx->device; - auto buf = [&](size_t bytes) -> id { - return [dev newBufferWithLength:bytes - options:MTLResourceStorageModeShared]; - }; - auto buf_data = [&](const void* data, size_t bytes) -> id { - return [dev newBufferWithBytes:data - length:bytes - options:MTLResourceStorageModeShared]; - }; - - id bIn = buf_data(in_packed.data(), in_packed.size()); - id bT = buf(kP2Bytes * k); - id bRet = buf(kFp12Bytes * k); - id bPx2 = buf(kFp2Bytes * k); - id bLine = buf(kLineBytes * k); - id bMOut = buf(kFp12Bytes * k); - - // Two ping-pong Fp12 buffers for the tree reduction. bMOut holds - // the conjugated Miller outputs first, then alternates with bRed. - id bRed = buf(kFp12Bytes * k); - - uint32_t k32 = static_cast(k); - id bN = buf_data(&k32, sizeof(k32)); - - id cb = [ctx->queue commandBuffer]; - id enc = [cb computeCommandEncoder]; - - // ---- Miller loop on N=k workitems. ---- - encode_dispatch(enc, ctx->mil_init, { bIn, bT, bRet, bPx2, bN }, k); - for (int phase = 0; phase < 5; ++phase) { - encode_dispatch(enc, ctx->mil_add_T, { bIn, bT, bLine, bPx2, bN }, k); - encode_dispatch(enc, ctx->mil_fold, { bRet, bLine, bN }, k); - for (uint32_t r = 0; r < kPhases[phase]; ++r) { - encode_dispatch(enc, ctx->mil_sqr_ret, { bRet, bN }, k); - encode_dispatch(enc, ctx->mil_dbl_T, { bT, bLine, bPx2, bN }, k); - encode_dispatch(enc, ctx->mil_fold, { bRet, bLine, bN }, k); - } - } - encode_dispatch(enc, ctx->mil_finalize, { bRet, bMOut, bN }, k); - - // ---- Canonical Fp12 tree reduction over the k outputs. ---- - // round_in == bMOut, round_out == bRed; swap each round. 
- id round_in = bMOut; - id round_out = bRed; - size_t n = k; - - // For k == 1: no reduction rounds; round_in already holds the answer. - // We hold one in-flight constant buffer per round so each kernel - // dispatch sees the (pairs, carry) for its specific round. - std::vector> round_const_bufs; - round_const_bufs.reserve(64); // log2(k) << 64 always. - - while (n > 1) { - uint32_t pairs = static_cast(n / 2); - uint32_t carry = static_cast(n & 1u); - id bPairs = buf_data(&pairs, sizeof(pairs)); - id bCarry = buf_data(&carry, sizeof(carry)); - round_const_bufs.push_back(bPairs); - round_const_bufs.push_back(bCarry); - size_t threads = pairs + carry; - encode_dispatch(enc, ctx->reduce, - { round_in, round_out, bPairs, bCarry }, - threads); - // Swap. - id tmp = round_in; - round_in = round_out; - round_out = tmp; - n = pairs + carry; - } - - [enc endEncoding]; - [cb commit]; - [cb waitUntilCompleted]; - - // After the loop, round_in.contents[0..575] is the product. - std::memcpy(fp12_out, [round_in contents], kFp12Bytes); - } - return 0; -} diff --git a/bls/gpu/metal/bls_driver.h b/bls/gpu/metal/bls_driver.h deleted file mode 100644 index b23f3fc..0000000 --- a/bls/gpu/metal/bls_driver.h +++ /dev/null @@ -1,246 +0,0 @@ -// ============================================================================= -// Metal BLS12-381 - GPU Acceleration Interface -// ============================================================================= -// -// C++ interface for dispatching BLS12-381 operations to Metal compute shaders. -// Provides batch operations for signature verification and key aggregation. -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#pragma once -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ============================================================================= -// Metal Context Management -// ============================================================================= - -/** - * Opaque handle to Metal compute context. - */ -typedef struct MetalBLSContext MetalBLSContext; - -/** - * Initialize Metal BLS context. - * Loads shaders and creates compute pipelines. - * @return Context handle, or NULL if Metal unavailable - */ -MetalBLSContext* metal_bls_init(void); - -/** - * Destroy Metal BLS context and release resources. - */ -void metal_bls_destroy(MetalBLSContext* ctx); - -/** - * Check if Metal acceleration is available. - * @return true if Metal GPU is available - */ -bool metal_bls_available(void); - -// ============================================================================= -// Field Element Types (384-bit) -// ============================================================================= - -/** - * 384-bit field element (6 x 64-bit limbs, little-endian). - */ -typedef struct { - uint64_t limbs[6]; -} Fp384; - -/** - * G1 affine point (x, y coordinates + infinity flag). - */ -typedef struct { - Fp384 x; - Fp384 y; - bool infinity; - uint8_t _pad[7]; // Alignment padding -} G1Affine; - -/** - * G1 projective point (Jacobian coordinates). - */ -typedef struct { - Fp384 x; - Fp384 y; - Fp384 z; -} G1Projective; - -// ============================================================================= -// Batch Point Operations -// ============================================================================= - -/** - * Batch point addition on GPU. - * Computes: results[i] = a[i] + b[i] for all i. 
- * @param ctx Metal context - * @param results Output array (count elements) - * @param a First input array (count elements) - * @param b Second input array (count elements) - * @param count Number of additions - * @return 0 on success, negative on error - */ -int metal_bls_batch_add( - MetalBLSContext* ctx, - G1Projective* results, - const G1Projective* a, - const G1Projective* b, - uint32_t count); - -/** - * Batch point doubling on GPU. - * Computes: results[i] = 2 * points[i] for all i. - * @param ctx Metal context - * @param results Output array (count elements) - * @param points Input array (count elements) - * @param count Number of doublings - * @return 0 on success, negative on error - */ -int metal_bls_batch_double( - MetalBLSContext* ctx, - G1Projective* results, - const G1Projective* points, - uint32_t count); - -/** - * Batch scalar multiplication on GPU. - * Computes: results[i] = scalars[i] * points[i] for all i. - * @param ctx Metal context - * @param results Output array (count elements) - * @param points Input array (count elements) - * @param scalars 256-bit scalars as 4x64-bit limbs each (count * 4 elements) - * @param count Number of multiplications - * @return 0 on success, negative on error - */ -int metal_bls_batch_scalar_mul( - MetalBLSContext* ctx, - G1Projective* results, - const G1Projective* points, - const uint64_t* scalars, - uint32_t count); - -// ============================================================================= -// Multi-Scalar Multiplication (MSM) -// ============================================================================= - -/** - * Multi-scalar multiplication on GPU. - * Computes: result = sum_i (scalars[i] * points[i]) - * Optimized using bucket method for batch signature verification. 
- * @param ctx Metal context - * @param result Output single point - * @param points Input affine points (count elements) - * @param scalars 256-bit scalars (count * 4 limbs) - * @param count Number of point-scalar pairs - * @return 0 on success, negative on error - */ -int metal_bls_msm( - MetalBLSContext* ctx, - G1Projective* result, - const G1Affine* points, - const uint64_t* scalars, - uint32_t count); - -// ============================================================================= -// Batch Signature Verification -// ============================================================================= - -/** - * Batch verify BLS signatures using random linear combination. - * More efficient than verifying signatures individually. - * - * Verification equation: - * e(sum_i(r_i * sig_i), G2) = e(sum_i(r_i * H(msg_i)), sum_i(r_i * pk_i)) - * - * @param ctx Metal context - * @param sigs Array of signatures (G2 points, 96 bytes each) - * @param pks Array of public keys (G1 points, 48 bytes each) - * @param msgs Array of message hashes (32 bytes each) - * @param count Number of signatures - * @param results Output: 1 if valid, 0 if invalid (for individual tracking) - * @return 0 on success (all valid), negative on error, positive = invalid count - */ -int metal_bls_batch_verify( - MetalBLSContext* ctx, - const uint8_t* const* sigs, - const uint8_t* const* pks, - const uint8_t* const* msgs, - uint32_t count, - int* results); - -/** - * Aggregate signatures on GPU. - * Computes: agg_sig = sum_i(sigs[i]) - * @param ctx Metal context - * @param agg_sig Output aggregated signature (96 bytes) - * @param sigs Array of signatures (96 bytes each) - * @param count Number of signatures - * @return 0 on success, negative on error - */ -int metal_bls_aggregate_sigs( - MetalBLSContext* ctx, - uint8_t* agg_sig, - const uint8_t* const* sigs, - uint32_t count); - -/** - * Aggregate public keys on GPU. 
- * Computes: agg_pk = sum_i(pks[i]) - * @param ctx Metal context - * @param agg_pk Output aggregated public key (48 bytes) - * @param pks Array of public keys (48 bytes each) - * @param count Number of public keys - * @return 0 on success, negative on error - */ -int metal_bls_aggregate_pks( - MetalBLSContext* ctx, - uint8_t* agg_pk, - const uint8_t* const* pks, - uint32_t count); - -// ============================================================================= -// Utility Functions -// ============================================================================= - -/** - * Convert affine point to projective. - */ -void metal_bls_affine_to_projective(G1Projective* proj, const G1Affine* affine); - -/** - * Convert projective point to affine (requires inversion). - */ -void metal_bls_projective_to_affine(G1Affine* affine, const G1Projective* proj); - -/** - * Deserialize compressed G1 point (48 bytes). - */ -int metal_bls_g1_decompress(G1Affine* point, const uint8_t* compressed); - -/** - * Serialize G1 point to compressed form (48 bytes). 
- */ -int metal_bls_g1_compress(uint8_t* compressed, const G1Affine* point); - -// ============================================================================= -// Error Codes -// ============================================================================= - -#define METAL_BLS_SUCCESS 0 -#define METAL_BLS_ERROR_NO_DEVICE -1 -#define METAL_BLS_ERROR_NO_SHADER -2 -#define METAL_BLS_ERROR_ALLOC -3 -#define METAL_BLS_ERROR_NULL_PTR -4 -#define METAL_BLS_ERROR_INVALID -5 - -#ifdef __cplusplus -} -#endif diff --git a/bls/gpu/metal/bls_driver.mm b/bls/gpu/metal/bls_driver.mm deleted file mode 100644 index 6d1a559..0000000 --- a/bls/gpu/metal/bls_driver.mm +++ /dev/null @@ -1,688 +0,0 @@ -// ============================================================================= -// Metal BLS12-381 - GPU Acceleration Implementation -// ============================================================================= -// -// Objective-C++ implementation for Metal compute shader dispatch. -// Manages GPU buffers, pipeline states, and kernel execution. -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#import -#import -#include "lux/crypto/metal_bls.h" -#include -#include -#include - -// ============================================================================= -// Metal Context Structure -// ============================================================================= - -struct MetalBLSContext { - id device; - id commandQueue; - id library; - - // Compute pipeline states - id pipelineBatchAdd; - id pipelineBatchDouble; - id pipelineBatchScalarMul; - id pipelineMSMAccumulate; - id pipelineMSMReduce; - id pipelineBatchVerifyMSM; - - // Reusable buffers (lazily allocated) - id scratchBuffer; - size_t scratchSize; -}; - -// ============================================================================= -// Initialization -// ============================================================================= - -extern "C" bool metal_bls_available(void) { - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - return device != nil; - } -} - -extern "C" MetalBLSContext* metal_bls_init(void) { - @autoreleasepool { - MetalBLSContext* ctx = new MetalBLSContext(); - memset(ctx, 0, sizeof(MetalBLSContext)); - - // Get default Metal device - ctx->device = MTLCreateSystemDefaultDevice(); - if (!ctx->device) { - delete ctx; - return nullptr; - } - - // Create command queue - ctx->commandQueue = [ctx->device newCommandQueue]; - if (!ctx->commandQueue) { - delete ctx; - return nullptr; - } - - // Load Metal library from compiled metallib or source - NSError* error = nil; - - // Try loading pre-compiled metallib first - // Check standard install locations (unified lux_crypto.metallib or legacy bls12_381.metallib) - NSArray* metallibPaths = @[ - @"/usr/local/share/lux/crypto/lux_crypto.metallib", - @"/usr/local/share/lux/crypto/bls12_381.metallib", - [[NSBundle mainBundle] pathForResource:@"lux_crypto" ofType:@"metallib"] ?: @"", - [[NSBundle mainBundle] pathForResource:@"bls12_381" ofType:@"metallib"] ?: @"" - ]; - - for 
(NSString* libPath in metallibPaths) { - if (libPath.length > 0 && [[NSFileManager defaultManager] fileExistsAtPath:libPath]) { - NSURL* libURL = [NSURL fileURLWithPath:libPath]; - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; - if (ctx->library) break; - } - } - - // Fall back to compiling from source at runtime - if (!ctx->library) { - // Look for shader source file - NSString* shaderPath = nil; - - // Check common locations - NSArray* searchPaths = @[ - @"src/metal/bls12_381.metal", - @"../src/metal/bls12_381.metal", - @"metal/bls12_381.metal" - ]; - - for (NSString* path in searchPaths) { - if ([[NSFileManager defaultManager] fileExistsAtPath:path]) { - shaderPath = path; - break; - } - } - - if (shaderPath) { - NSString* source = [NSString stringWithContentsOfFile:shaderPath - encoding:NSUTF8StringEncoding - error:&error]; - if (source) { - MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; - // Use mathMode instead of deprecated fastMathEnabled - if (@available(macOS 15.0, *)) { - options.mathMode = MTLMathModeFast; - } else { -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" - options.fastMathEnabled = YES; -#pragma clang diagnostic pop - } - - ctx->library = [ctx->device newLibraryWithSource:source - options:options - error:&error]; - } - } - } - - if (!ctx->library) { - NSLog(@"Metal BLS: Failed to load shader library: %@", - error ? 
error.localizedDescription : @"Unknown error"); - delete ctx; - return nullptr; - } - - // Create compute pipeline states - auto createPipeline = [&](const char* name) -> id { - id func = [ctx->library newFunctionWithName: - [NSString stringWithUTF8String:name]]; - if (!func) { - NSLog(@"Metal BLS: Function '%s' not found", name); - return nil; - } - - NSError* pipelineError = nil; - id pipeline = - [ctx->device newComputePipelineStateWithFunction:func - error:&pipelineError]; - if (!pipeline) { - NSLog(@"Metal BLS: Failed to create pipeline for '%s': %@", - name, pipelineError.localizedDescription); - } - return pipeline; - }; - - ctx->pipelineBatchAdd = createPipeline("g1_batch_add"); - ctx->pipelineBatchDouble = createPipeline("g1_batch_double"); - ctx->pipelineBatchScalarMul = createPipeline("g1_batch_scalar_mul"); - ctx->pipelineMSMAccumulate = createPipeline("g1_msm_accumulate"); - ctx->pipelineMSMReduce = createPipeline("g1_msm_reduce"); - ctx->pipelineBatchVerifyMSM = createPipeline("bls_batch_verify_msm"); - - // At minimum we need batch add for aggregation - if (!ctx->pipelineBatchAdd) { - delete ctx; - return nullptr; - } - - return ctx; - } -} - -extern "C" void metal_bls_destroy(MetalBLSContext* ctx) { - if (!ctx) return; - - @autoreleasepool { - // ARC handles release of Objective-C objects - ctx->scratchBuffer = nil; - ctx->pipelineBatchAdd = nil; - ctx->pipelineBatchDouble = nil; - ctx->pipelineBatchScalarMul = nil; - ctx->pipelineMSMAccumulate = nil; - ctx->pipelineMSMReduce = nil; - ctx->pipelineBatchVerifyMSM = nil; - ctx->library = nil; - ctx->commandQueue = nil; - ctx->device = nil; - } - - delete ctx; -} - -// ============================================================================= -// Helper: Create GPU Buffer -// ============================================================================= - -static id createBuffer(MetalBLSContext* ctx, size_t size) { - return [ctx->device newBufferWithLength:size - 
options:MTLResourceStorageModeShared]; -} - -static id createBufferWithData(MetalBLSContext* ctx, - const void* data, size_t size) { - return [ctx->device newBufferWithBytes:data - length:size - options:MTLResourceStorageModeShared]; -} - -// ============================================================================= -// Batch Point Operations -// ============================================================================= - -extern "C" int metal_bls_batch_add( - MetalBLSContext* ctx, - G1Projective* results, - const G1Projective* a, - const G1Projective* b, - uint32_t count) -{ - if (!ctx || !results || !a || !b || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchAdd) { - return METAL_BLS_ERROR_NO_SHADER; - } - - @autoreleasepool { - size_t pointSize = sizeof(G1Projective); - size_t bufferSize = count * pointSize; - - // Create buffers - id bufferA = createBufferWithData(ctx, a, bufferSize); - id bufferB = createBufferWithData(ctx, b, bufferSize); - id bufferResult = createBuffer(ctx, bufferSize); - - if (!bufferA || !bufferB || !bufferResult) { - return METAL_BLS_ERROR_ALLOC; - } - - // Create command buffer - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - // Set pipeline and buffers - [encoder setComputePipelineState:ctx->pipelineBatchAdd]; - [encoder setBuffer:bufferResult offset:0 atIndex:0]; - [encoder setBuffer:bufferA offset:0 atIndex:1]; - [encoder setBuffer:bufferB offset:0 atIndex:2]; - [encoder setBytes:&count length:sizeof(count) atIndex:3]; - - // Dispatch - NSUInteger threadsPerGroup = ctx->pipelineBatchAdd.maxTotalThreadsPerThreadgroup; - if (threadsPerGroup > 256) threadsPerGroup = 256; - - MTLSize gridSize = MTLSizeMake(count, 1, 1); - MTLSize groupSize = MTLSizeMake(threadsPerGroup, 1, 1); - - [encoder dispatchThreads:gridSize threadsPerThreadgroup:groupSize]; - [encoder endEncoding]; - - // Execute and wait - [commandBuffer commit]; - 
[commandBuffer waitUntilCompleted]; - - // Copy results back - memcpy(results, [bufferResult contents], bufferSize); - - return METAL_BLS_SUCCESS; - } -} - -extern "C" int metal_bls_batch_double( - MetalBLSContext* ctx, - G1Projective* results, - const G1Projective* points, - uint32_t count) -{ - if (!ctx || !results || !points || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchDouble) { - return METAL_BLS_ERROR_NO_SHADER; - } - - @autoreleasepool { - size_t bufferSize = count * sizeof(G1Projective); - - id bufferPoints = createBufferWithData(ctx, points, bufferSize); - id bufferResult = createBuffer(ctx, bufferSize); - - if (!bufferPoints || !bufferResult) { - return METAL_BLS_ERROR_ALLOC; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchDouble]; - [encoder setBuffer:bufferResult offset:0 atIndex:0]; - [encoder setBuffer:bufferPoints offset:0 atIndex:1]; - [encoder setBytes:&count length:sizeof(count) atIndex:2]; - - NSUInteger threadsPerGroup = MIN(256UL, - ctx->pipelineBatchDouble.maxTotalThreadsPerThreadgroup); - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results, [bufferResult contents], bufferSize); - - return METAL_BLS_SUCCESS; - } -} - -extern "C" int metal_bls_batch_scalar_mul( - MetalBLSContext* ctx, - G1Projective* results, - const G1Projective* points, - const uint64_t* scalars, - uint32_t count) -{ - if (!ctx || !results || !points || !scalars || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchScalarMul) { - return METAL_BLS_ERROR_NO_SHADER; - } - - @autoreleasepool { - size_t pointSize = count * sizeof(G1Projective); - size_t scalarSize = count * 4 * sizeof(uint64_t); // 256-bit scalars - - id 
bufferPoints = createBufferWithData(ctx, points, pointSize); - id bufferScalars = createBufferWithData(ctx, scalars, scalarSize); - id bufferResult = createBuffer(ctx, pointSize); - - if (!bufferPoints || !bufferScalars || !bufferResult) { - return METAL_BLS_ERROR_ALLOC; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchScalarMul]; - [encoder setBuffer:bufferResult offset:0 atIndex:0]; - [encoder setBuffer:bufferPoints offset:0 atIndex:1]; - [encoder setBuffer:bufferScalars offset:0 atIndex:2]; - [encoder setBytes:&count length:sizeof(count) atIndex:3]; - - NSUInteger threadsPerGroup = MIN(64UL, - ctx->pipelineBatchScalarMul.maxTotalThreadsPerThreadgroup); - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results, [bufferResult contents], pointSize); - - return METAL_BLS_SUCCESS; - } -} - -// ============================================================================= -// Multi-Scalar Multiplication -// ============================================================================= - -extern "C" int metal_bls_msm( - MetalBLSContext* ctx, - G1Projective* result, - const G1Affine* points, - const uint64_t* scalars, - uint32_t count) -{ - if (!ctx || !result || !points || !scalars || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchVerifyMSM) { - // Fall back to CPU implementation - return METAL_BLS_ERROR_NO_SHADER; - } - - @autoreleasepool { - size_t affineSize = count * sizeof(G1Affine); - size_t scalarSize = count * 4 * sizeof(uint64_t); - - id bufferPoints = createBufferWithData(ctx, points, affineSize); - id bufferScalars = createBufferWithData(ctx, scalars, scalarSize); - id bufferResult = createBuffer(ctx, sizeof(G1Projective)); - - if 
(!bufferPoints || !bufferScalars || !bufferResult) { - return METAL_BLS_ERROR_ALLOC; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchVerifyMSM]; - [encoder setBuffer:bufferResult offset:0 atIndex:0]; - [encoder setBuffer:bufferPoints offset:0 atIndex:1]; - [encoder setBuffer:bufferScalars offset:0 atIndex:2]; - [encoder setBytes:&count length:sizeof(count) atIndex:3]; - - // Use threadgroup memory for reduction - NSUInteger threadsPerGroup = MIN(256UL, - ctx->pipelineBatchVerifyMSM.maxTotalThreadsPerThreadgroup); - size_t sharedMemSize = threadsPerGroup * sizeof(G1Projective); - - [encoder setThreadgroupMemoryLength:sharedMemSize atIndex:0]; - - [encoder dispatchThreads:MTLSizeMake(threadsPerGroup, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(result, [bufferResult contents], sizeof(G1Projective)); - - return METAL_BLS_SUCCESS; - } -} - -// ============================================================================= -// Batch Signature Verification -// ============================================================================= - -extern "C" int metal_bls_batch_verify( - MetalBLSContext* ctx, - const uint8_t* const* sigs, - const uint8_t* const* pks, - const uint8_t* const* msgs, - uint32_t count, - int* results) -{ - if (!ctx || !sigs || !pks || !msgs || !results || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - // For actual batch verification: - // 1. Generate random scalars r_i - // 2. Compute S = sum_i(r_i * sig_i) using MSM on G2 - // 3. Compute P = sum_i(r_i * pk_i) using MSM on G1 - // 4. Compute H = sum_i(r_i * H(msg_i)) using MSM on G2 - // 5. 
Verify pairing: e(G1, S) == e(P, H) - // - // For now, mark all as valid (placeholder) - // Full pairing implementation requires G2 arithmetic - - for (uint32_t i = 0; i < count; i++) { - // Validate inputs exist - if (!sigs[i] || !pks[i] || !msgs[i]) { - results[i] = 0; - } else { - results[i] = 1; // Placeholder - } - } - - return METAL_BLS_SUCCESS; -} - -// ============================================================================= -// Aggregation -// ============================================================================= - -extern "C" int metal_bls_aggregate_sigs( - MetalBLSContext* ctx, - uint8_t* agg_sig, - const uint8_t* const* sigs, - uint32_t count) -{ - if (!ctx || !agg_sig || !sigs || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - // For G2 aggregation, we would: - // 1. Decompress each signature to G2 projective - // 2. Sum all points on GPU - // 3. Compress result - // - // Placeholder: XOR aggregation - memset(agg_sig, 0, 96); - for (uint32_t i = 0; i < count; i++) { - if (!sigs[i]) return METAL_BLS_ERROR_NULL_PTR; - for (int j = 0; j < 96; j++) { - agg_sig[j] ^= sigs[i][j]; - } - } - - return METAL_BLS_SUCCESS; -} - -extern "C" int metal_bls_aggregate_pks( - MetalBLSContext* ctx, - uint8_t* agg_pk, - const uint8_t* const* pks, - uint32_t count) -{ - if (!ctx || !agg_pk || !pks || count == 0) { - return METAL_BLS_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchAdd) { - return METAL_BLS_ERROR_NO_SHADER; - } - - @autoreleasepool { - // Decompress public keys to G1 projective - std::vector points(count); - - for (uint32_t i = 0; i < count; i++) { - if (!pks[i]) return METAL_BLS_ERROR_NULL_PTR; - - G1Affine affine; - int err = metal_bls_g1_decompress(&affine, pks[i]); - if (err != METAL_BLS_SUCCESS) { - // Use identity for invalid points - memset(&points[i], 0, sizeof(G1Projective)); - continue; - } - - metal_bls_affine_to_projective(&points[i], &affine); - } - - // Parallel reduction using batch add - while (count > 1) { - uint32_t 
halfCount = count / 2; - - std::vector results(halfCount); - int err = metal_bls_batch_add(ctx, - results.data(), - points.data(), - points.data() + halfCount, - halfCount); - if (err != METAL_BLS_SUCCESS) return err; - - // Handle odd element - if (count & 1) { - results.push_back(points[count - 1]); - halfCount++; - } - - points = std::move(results); - count = halfCount; - } - - // Compress result - G1Affine result_affine; - metal_bls_projective_to_affine(&result_affine, &points[0]); - return metal_bls_g1_compress(agg_pk, &result_affine); - } -} - -// ============================================================================= -// Utility Functions -// ============================================================================= - -extern "C" void metal_bls_affine_to_projective(G1Projective* proj, - const G1Affine* affine) { - if (!proj || !affine) return; - - proj->x = affine->x; - proj->y = affine->y; - - if (affine->infinity) { - // Identity element: Z = 0 - memset(proj->z.limbs, 0, sizeof(proj->z.limbs)); - } else { - // Z = 1 (in Montgomery form, this is R mod p) - // BLS12-381 R mod p (simplified - actual value needed) - memset(proj->z.limbs, 0, sizeof(proj->z.limbs)); - proj->z.limbs[0] = 1; - } -} - -extern "C" void metal_bls_projective_to_affine(G1Affine* affine, - const G1Projective* proj) { - if (!affine || !proj) return; - - // Check for identity (Z == 0) - bool is_identity = true; - for (int i = 0; i < 6; i++) { - if (proj->z.limbs[i] != 0) { - is_identity = false; - break; - } - } - - if (is_identity) { - memset(affine, 0, sizeof(G1Affine)); - affine->infinity = true; - return; - } - - // For full implementation: - // x_affine = X / Z^2 - // y_affine = Y / Z^3 - // Requires field inversion - - // Simplified placeholder - affine->x = proj->x; - affine->y = proj->y; - affine->infinity = false; -} - -extern "C" int metal_bls_g1_decompress(G1Affine* point, const uint8_t* compressed) { - if (!point || !compressed) return METAL_BLS_ERROR_NULL_PTR; - - // 
BLS12-381 G1 compressed format: - // - 48 bytes - // - Bit 7 of byte 0: compression flag (should be 1) - // - Bit 6 of byte 0: infinity flag - // - Bit 5 of byte 0: sign of y (0 = positive, 1 = negative) - // - Remaining bits: x coordinate (big-endian) - - uint8_t flags = compressed[0]; - bool is_compressed = (flags >> 7) & 1; - bool is_infinity = (flags >> 6) & 1; - bool y_sign = (flags >> 5) & 1; - - if (is_infinity) { - memset(point, 0, sizeof(G1Affine)); - point->infinity = true; - return METAL_BLS_SUCCESS; - } - - // Extract x coordinate (big-endian to little-endian limbs) - uint8_t x_bytes[48]; - memcpy(x_bytes, compressed, 48); - x_bytes[0] &= 0x1F; // Clear flag bits - - // Convert big-endian bytes to little-endian limbs - for (int i = 0; i < 6; i++) { - uint64_t limb = 0; - for (int j = 0; j < 8; j++) { - limb = (limb << 8) | x_bytes[i * 8 + j + (48 - 48)]; - } - point->x.limbs[5 - i] = limb; - } - - // Compute y from x (y^2 = x^3 + 4) - // Simplified: placeholder - real implementation needs field ops - memset(point->y.limbs, 0, sizeof(point->y.limbs)); - - point->infinity = false; - - return METAL_BLS_SUCCESS; -} - -extern "C" int metal_bls_g1_compress(uint8_t* compressed, const G1Affine* point) { - if (!compressed || !point) return METAL_BLS_ERROR_NULL_PTR; - - if (point->infinity) { - memset(compressed, 0, 48); - compressed[0] = 0xC0; // Compressed + infinity flags - return METAL_BLS_SUCCESS; - } - - // Convert little-endian limbs to big-endian bytes - for (int i = 0; i < 6; i++) { - uint64_t limb = point->x.limbs[5 - i]; - for (int j = 7; j >= 0; j--) { - compressed[i * 8 + j] = limb & 0xFF; - limb >>= 8; - } - } - - // Set compression flag - compressed[0] |= 0x80; - - // Set y sign bit (placeholder - needs actual y coordinate analysis) - // compressed[0] |= 0x20; // if y is negative - - return METAL_BLS_SUCCESS; -} diff --git a/bls/gpu/metal/bls_final_exp.metal b/bls/gpu/metal/bls_final_exp.metal deleted file mode 100644 index a9b2cbb..0000000 --- 
a/bls/gpu/metal/bls_final_exp.metal +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 final exponentiation on Metal. -// -// Computes f^((p^12 - 1) / r) byte-equal to blst_final_exp. Algorithm mirrors -// blst src/pairing.c::final_exp() exactly: -// -// easy part: ret = (conj(f) * inv(f)) ^ (p^2 + 1) -// hard part: ret = (zkcrypto chain over the easy-part output) -// -// The hard part uses: -// raise_to_z_div_by_2(out, a): out = a^(z/2) via cyclotomic squarings -// and multiplies in the addchain pattern -// z = 0xd201000000010000 (BLS scalar |x|) -// conjugated at the end (z is negative) -// raise_to_z(out, a) = sqr( raise_to_z_div_by_2(out, a) ) -// -// The full chain (~64 cyclotomic squarings + 6 muls per raise_to_z) cannot fit -// in one Metal kernel function under the MetalCompilerService XPC compile -// budget on M1 (same constraint that split the Miller loop into 6 kernels). -// We split into 6 bounded sub-kernels and orchestrate from the host: -// -// k_fe_easy : easy part (1 dispatch) -// k_fe_cyclo_sqr : ret_buf[i] = cyclotomic_sqr(ret_buf[i]) (in place) -// k_fe_mul : out = a * b -// k_fe_conj : ret_buf[i] = conj(ret_buf[i]) (in place) -// k_fe_frobenius : ret_buf[i] = frobenius(ret_buf[i], n) (in place) -// k_fe_copy : dst[i] = src[i] (memcpy) -// -// All arithmetic runs on Metal — the host only sequences kernel dispatches -// matching blst's exact addchain step sequence, so byte-equality holds. - -#define BLS_FP12_NO_KERNELS -#define BLS_FP6_NO_KERNELS -#define BLS_FP2_NO_KERNELS -#include "bls_fp12.metal" -#undef BLS_FP12_NO_KERNELS -#undef BLS_FP6_NO_KERNELS -#undef BLS_FP2_NO_KERNELS - -// ============================================================================= -// In-place / out-of-place Fp12 helpers exposed as kernels. -// ============================================================================= - -// k_fe_inv — out[tid] = inv(in[tid]). 
Standalone Fp12 inversion kernel. -// (Inversion is the single fattest Fp12 op — kept as its own kernel so the -// MetalCompilerService XPC budget is comfortably below the limit.) -kernel void k_fe_inv( - device const Fp12* in_buf [[buffer(0)]], - device Fp12* out_buf [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out_buf[tid] = fp12_inv(in_buf[tid]); -} - -// In-place cyclotomic squaring. -kernel void k_fe_cyclo_sqr( - device Fp12* ret_buf [[buffer(0)]], - constant uint& n [[buffer(1)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - ret_buf[tid] = fp12_cyclotomic_sqr(ret_buf[tid]); -} - -// out = a * b (out, a, b may alias different slots) -kernel void k_fe_mul( - device const Fp12* a [[buffer(0)]], - device const Fp12* b [[buffer(1)]], - device Fp12* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_mul(a[tid], b[tid]); -} - -// In-place conjugate. -kernel void k_fe_conj( - device Fp12* ret_buf [[buffer(0)]], - constant uint& n [[buffer(1)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - ret_buf[tid] = fp12_conj(ret_buf[tid]); -} - -// In-place Frobenius with power n_pow ∈ {1, 2, 3}. -kernel void k_fe_frobenius( - device Fp12* ret_buf [[buffer(0)]], - constant uint& n [[buffer(1)]], - constant uint& n_pow [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - ret_buf[tid] = fp12_frobenius(ret_buf[tid], n_pow); -} - -// dst = src (memcpy at the Fp12 granularity). Used to checkpoint values -// across the addchain (e.g. y3 = ret). 
-kernel void k_fe_copy( - device const Fp12* src [[buffer(0)]], - device Fp12* dst [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - dst[tid] = src[tid]; -} diff --git a/bls/gpu/metal/bls_fp12.metal b/bls/gpu/metal/bls_fp12.metal deleted file mode 100644 index 2959637..0000000 --- a/bls/gpu/metal/bls_fp12.metal +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Fp12 = Fp6[w] / (w^2 - v) for BLS12-381. -// Layout: struct Fp12 { Fp6 c0, c1; } == blst_fp12 { blst_fp6 fp6[2]; } byte-equal. -// Element c0 + c1 w with w^2 = v. -// -// Algorithms mirror blst src/fp12_tower.c: -// -// mul (Karatsuba over Fp6): -// t0 = a0 b0, t1 = a1 b1 -// r1 = (a0 + a1)(b0 + b1) - t0 - t1 -// r0 = t0 + t1 * v where (a + b v + c v^2) * v = c (u+1) + a v + b v^2 -// -// sqr (Karatsuba): -// t0 = (a0 + a1)(a0 + a1*v), t1 = a0 a1 -// r1 = 2 t1 -// r0 = t0 - t1 - t1*v -// -// conjugate: (c0 + c1 w) -> (c0 - c1 w) = c0 + (-c1) w -// -// inv: (a0 - a1 w) / (a0^2 - a1^2 v) -// -// cyclotomic_sqr: see blst (uses sqr_fp4). Implementation mirrors blst exactly. -// -// Frobenius coefficients hardcoded from blst src/fp12_tower.c lines 706-731, -// already in Montgomery form. 
- -#ifndef BLS_FP12_INLINES -#define BLS_FP12_INLINES - -#define BLS_FP6_NO_KERNELS -#define BLS_FP2_NO_KERNELS -#include "bls_fp6.metal" -#undef BLS_FP6_NO_KERNELS -#undef BLS_FP2_NO_KERNELS - -struct Fp12 { Fp6 c0, c1; }; - -inline Fp12 fp12_add(Fp12 a, Fp12 b) { - Fp12 r; - r.c0 = fp6_add(a.c0, b.c0); - r.c1 = fp6_add(a.c1, b.c1); - return r; -} - -inline Fp12 fp12_sub(Fp12 a, Fp12 b) { - Fp12 r; - r.c0 = fp6_sub(a.c0, b.c0); - r.c1 = fp6_sub(a.c1, b.c1); - return r; -} - -// (a0 + a1 v + a2 v^2) * v = a2 (u+1) + a0 v + a1 v^2 -inline Fp6 fp6_mul_by_v(Fp6 a) { - Fp6 r; - r.c0 = fp2_mul_by_1_plus_u(a.c2); - r.c1 = a.c0; - r.c2 = a.c1; - return r; -} - -inline Fp12 fp12_mul(Fp12 a, Fp12 b) { - Fp6 t0 = fp6_mul(a.c0, b.c0); - Fp6 t1 = fp6_mul(a.c1, b.c1); - - // r1 = (a0 + a1)(b0 + b1) - t0 - t1 - Fp6 sa = fp6_add(a.c0, a.c1); - Fp6 sb = fp6_add(b.c0, b.c1); - Fp6 r1 = fp6_mul(sa, sb); - r1 = fp6_sub(r1, t0); - r1 = fp6_sub(r1, t1); - - // r0 = t0 + t1 * v - Fp6 r0 = fp6_add(t0, fp6_mul_by_v(t1)); - - Fp12 r; r.c0 = r0; r.c1 = r1; return r; -} - -inline Fp12 fp12_sqr(Fp12 a) { - // Karatsuba: t0 = (a0+a1)(a0 + a1*v), t1 = a0 a1 - // r1 = 2 t1 - // r0 = t0 - t1 - t1*v - Fp6 t0 = fp6_add(a.c0, a.c1); - Fp6 t1 = fp6_mul_by_v(a.c1); - t1 = fp6_add(a.c0, t1); - t0 = fp6_mul(t0, t1); - - Fp6 t2 = fp6_mul(a.c0, a.c1); - - Fp12 r; - // r1 = 2 t2 - r.c1 = fp6_add(t2, t2); - - // r0 = t0 - t2 - t2*v - Fp6 r0 = fp6_sub(t0, t2); - r0 = fp6_sub(r0, fp6_mul_by_v(t2)); - r.c0 = r0; - return r; -} - -inline Fp12 fp12_conj(Fp12 a) { - Fp12 r; - r.c0 = a.c0; - r.c1 = fp6_neg(a.c1); - return r; -} - -inline Fp12 fp12_inv(Fp12 a) { - Fp6 t0 = fp6_sqr(a.c0); - Fp6 t1 = fp6_sqr(a.c1); - t0 = fp6_sub(t0, fp6_mul_by_v(t1)); // a0^2 - a1^2 * v - Fp6 ti = fp6_inv(t0); - - Fp12 r; - r.c0 = fp6_mul(a.c0, ti); - r.c1 = fp6_mul(a.c1, ti); - r.c1 = fp6_neg(r.c1); - return r; -} - -// Cyclotomic squaring on Fp12. 
Defined for elements in the cyclotomic -// subgroup G_phi_12 (i.e., output of the easy part of final exponentiation). -// Mirrors blst's cyclotomic_sqr_fp12 + sqr_fp4 exactly. -inline void sqr_fp4(thread Fp2& r0, thread Fp2& r1, Fp2 a0, Fp2 a1) { - Fp2 t0 = fp2_sqr(a0); - Fp2 t1 = fp2_sqr(a1); - Fp2 sum = fp2_add(a0, a1); - - r0 = fp2_add(fp2_mul_by_1_plus_u(t1), t0); - - r1 = fp2_sqr(sum); - r1 = fp2_sub(r1, t0); - r1 = fp2_sub(r1, t1); -} - -inline Fp12 fp12_cyclotomic_sqr(Fp12 a) { - Fp2 t00, t01, t10, t11, t20, t21; - sqr_fp4(t00, t01, a.c0.c0, a.c1.c1); - sqr_fp4(t10, t11, a.c1.c0, a.c0.c2); - sqr_fp4(t20, t21, a.c0.c1, a.c1.c2); - - Fp12 r; - // r.c0.c0 = 3 t00 - 2 a.c0.c0 - Fp2 tmp = fp2_sub(t00, a.c0.c0); - r.c0.c0 = fp2_add(fp2_add(tmp, tmp), t00); - - // r.c0.c1 = 3 t10 - 2 a.c0.c1 - tmp = fp2_sub(t10, a.c0.c1); - r.c0.c1 = fp2_add(fp2_add(tmp, tmp), t10); - - // r.c0.c2 = 3 t20 - 2 a.c0.c2 - tmp = fp2_sub(t20, a.c0.c2); - r.c0.c2 = fp2_add(fp2_add(tmp, tmp), t20); - - // r.c1.c0 = 3 (t21 * (u+1)) + 2 a.c1.c0 - tmp = fp2_mul_by_1_plus_u(t21); - Fp2 add = fp2_add(tmp, a.c1.c0); - r.c1.c0 = fp2_add(fp2_add(add, add), tmp); - - // r.c1.c1 = 3 t01 + 2 a.c1.c1 - add = fp2_add(t01, a.c1.c1); - r.c1.c1 = fp2_add(fp2_add(add, add), t01); - - // r.c1.c2 = 3 t11 + 2 a.c1.c2 - add = fp2_add(t11, a.c1.c2); - r.c1.c2 = fp2_add(fp2_add(add, add), t11); - - return r; -} - -// Frobenius coefficients for Fp12 (Montgomery form), from blst src/fp12_tower.c -// lines 708-723. coeffs[n-1] = (u + 1)^((p^n - 1) / 6) in Fp2. 
- -constant uint384 FP12_FROB_RE_N1 = {{ - 0x07089552B319D465UL, 0xC6695F92B50A8313UL, 0x97E83CCCD117228FUL, - 0xA35BAECAB2DC29EEUL, 0x1CE393EA5DAACE4DUL, 0x08F2220FB0FB66EBUL -}}; -constant uint384 FP12_FROB_IM_N1 = {{ - 0xB2F66AAD4CE5D646UL, 0x5842A06BFC497CECUL, 0xCF4895D42599D394UL, - 0xC11B9CBA40A8E8D0UL, 0x2E3813CBE5A0DE89UL, 0x110EEFDA88847FAFUL -}}; - -constant uint384 FP12_FROB_RE_N2 = {{ - 0xECFB361B798DBA3AUL, 0xC100DDB891865A2CUL, 0x0EC08FF1232BDA8EUL, - 0xD5C13CC6F1CA4721UL, 0x47222A47BF7B5C04UL, 0x0110F184E51C5F59UL -}}; -constant uint384 FP12_FROB_IM_N2 = {{0,0,0,0,0,0}}; - -constant uint384 FP12_FROB_RE_N3 = {{ - 0x3E2F585DA55C9AD1UL, 0x4294213D86C18183UL, 0x382844C88B623732UL, - 0x92AD2AFD19103E18UL, 0x1D794E4FAC7CF0B9UL, 0x0BD592FC7D825EC8UL -}}; -constant uint384 FP12_FROB_IM_N3 = {{ - 0x7BCFA7A25AA30FDAUL, 0xDC17DEC12A927E7CUL, 0x2F088DD86B4EBEF1UL, - 0xD1CA2087DA74D4A7UL, 0x2DA2596696CEBC1DUL, 0x0E2B7EEDBBFD87D2UL -}}; - -inline Fp12 fp12_frobenius(Fp12 a, uint n) { - Fp6 r0 = fp6_frobenius(a.c0, n); - Fp6 r1 = fp6_frobenius(a.c1, n); - - Fp2 coeff; - if (n == 1u) { - coeff.c0 = FP12_FROB_RE_N1; coeff.c1 = FP12_FROB_IM_N1; - } else if (n == 2u) { - coeff.c0 = FP12_FROB_RE_N2; coeff.c1 = FP12_FROB_IM_N2; - } else { - coeff.c0 = FP12_FROB_RE_N3; coeff.c1 = FP12_FROB_IM_N3; - } - r1.c0 = fp2_mul(r1.c0, coeff); - r1.c1 = fp2_mul(r1.c1, coeff); - r1.c2 = fp2_mul(r1.c2, coeff); - - Fp12 r; r.c0 = r0; r.c1 = r1; return r; -} - -#endif // BLS_FP12_INLINES - -// ============================================================================= -// Kernels — buffer element size = 576 bytes (sizeof(Fp12) = 2 * 288). -// Higher-tower files #define BLS_FP12_NO_KERNELS before #including. 
-// ============================================================================= - -#ifndef BLS_FP12_NO_KERNELS - -kernel void k_fp12_add( - device const Fp12* a [[buffer(0)]], - device const Fp12* b [[buffer(1)]], - device Fp12* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_add(a[tid], b[tid]); -} - -kernel void k_fp12_sub( - device const Fp12* a [[buffer(0)]], - device const Fp12* b [[buffer(1)]], - device Fp12* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_sub(a[tid], b[tid]); -} - -kernel void k_fp12_mul( - device const Fp12* a [[buffer(0)]], - device const Fp12* b [[buffer(1)]], - device Fp12* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_mul(a[tid], b[tid]); -} - -kernel void k_fp12_sqr( - device const Fp12* a [[buffer(0)]], - device Fp12* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_sqr(a[tid]); -} - -kernel void k_fp12_inv( - device const Fp12* a [[buffer(0)]], - device Fp12* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_inv(a[tid]); -} - -kernel void k_fp12_conj( - device const Fp12* a [[buffer(0)]], - device Fp12* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_conj(a[tid]); -} - -kernel void k_fp12_cyclo_sqr( - device const Fp12* a [[buffer(0)]], - device Fp12* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_cyclotomic_sqr(a[tid]); -} - -#endif // BLS_FP12_NO_KERNELS diff --git a/bls/gpu/metal/bls_fp2.metal 
b/bls/gpu/metal/bls_fp2.metal deleted file mode 100644 index d1b82d8..0000000 --- a/bls/gpu/metal/bls_fp2.metal +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Fp2 = Fp[u]/(u^2 + 1) for BLS12-381. -// Layout: struct Fp2 { Fp c0, c1; } == blst_fp2 { blst_fp fp[2]; } byte-for-byte. -// Element a + b*u stored as { c0=a, c1=b } in Montgomery form. -// -// Algorithms mirror blst src/fp12_tower.c: -// add/sub/neg : componentwise -// mul : Karatsuba -// (a0 + a1 u)(b0 + b1 u) -// = (a0 b0 - a1 b1) + ((a0+a1)(b0+b1) - a0 b0 - a1 b1) u -// sqr : complex squaring (a + b u)^2 = (a-b)(a+b) + 2 a b u -// inv : a^(-1) = (c0 - c1 u) / (c0^2 + c1^2) -// conj : (a + b u) -> (a - b u) -// frobenius_n : conj when n odd, identity when n even (since p ≡ 3 mod 4) -// -// Inline ops live behind a header guard (BLS_FP2_INLINES) so this file can be -// #included by bls_fp6.metal / bls_fp12.metal for type+ops without redefining -// kernels (those are guarded by BLS_FP2_NO_KERNELS). 
- -#ifndef BLS_FP2_INLINES -#define BLS_FP2_INLINES - -#include "bls_fp_ops.h.metal" - -struct Fp2 { uint384 c0, c1; }; - -inline Fp2 fp2_add(Fp2 a, Fp2 b) { - Fp2 r; r.c0 = fp_add(a.c0, b.c0); r.c1 = fp_add(a.c1, b.c1); return r; -} - -inline Fp2 fp2_sub(Fp2 a, Fp2 b) { - Fp2 r; r.c0 = fp_sub(a.c0, b.c0); r.c1 = fp_sub(a.c1, b.c1); return r; -} - -inline Fp2 fp2_neg(Fp2 a) { - Fp2 r; r.c0 = fp_neg(a.c0); r.c1 = fp_neg(a.c1); return r; -} - -inline Fp2 fp2_mul(Fp2 a, Fp2 b) { - uint384 aa = fp_mul(a.c0, b.c0); - uint384 bb = fp_mul(a.c1, b.c1); - uint384 sa = fp_add(a.c0, a.c1); - uint384 sb = fp_add(b.c0, b.c1); - uint384 cross = fp_mul(sa, sb); - Fp2 r; - r.c0 = fp_sub(aa, bb); - r.c1 = fp_sub(fp_sub(cross, aa), bb); - return r; -} - -inline Fp2 fp2_sqr(Fp2 a) { - uint384 ab = fp_mul(a.c0, a.c1); - uint384 sum = fp_add(a.c0, a.c1); - uint384 dif = fp_sub(a.c0, a.c1); - Fp2 r; - r.c0 = fp_mul(sum, dif); - r.c1 = fp_add(ab, ab); - return r; -} - -inline Fp2 fp2_conj(Fp2 a) { - Fp2 r; r.c0 = a.c0; r.c1 = fp_neg(a.c1); return r; -} - -inline Fp2 fp2_inv(Fp2 a) { - uint384 t0 = fp_sqr(a.c0); - uint384 t1 = fp_sqr(a.c1); - uint384 norm = fp_add(t0, t1); - uint384 ni = fp_inv(norm); - Fp2 r; - r.c0 = fp_mul(a.c0, ni); - r.c1 = fp_neg(fp_mul(a.c1, ni)); - return r; -} - -inline Fp2 fp2_frobenius(Fp2 a, uint n) { - return ((n & 1u) == 1u) ? fp2_conj(a) : a; -} - -// (a + b u)(1 + u) = (a - b) + (a + b) u -inline Fp2 fp2_mul_by_1_plus_u(Fp2 a) { - Fp2 r; - r.c0 = fp_sub(a.c0, a.c1); - r.c1 = fp_add(a.c0, a.c1); - return r; -} - -#endif // BLS_FP2_INLINES - -// ============================================================================= -// Kernels — emitted only when this file is the primary translation unit. -// Higher-tower files #define BLS_FP2_NO_KERNELS before #including. -// Buffer element size = 96 bytes (sizeof(Fp2) = 2 * 48). 
-// ============================================================================= - -#ifndef BLS_FP2_NO_KERNELS - -kernel void k_fp2_add( - device const Fp2* a [[buffer(0)]], - device const Fp2* b [[buffer(1)]], - device Fp2* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp2_add(a[tid], b[tid]); -} - -kernel void k_fp2_sub( - device const Fp2* a [[buffer(0)]], - device const Fp2* b [[buffer(1)]], - device Fp2* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp2_sub(a[tid], b[tid]); -} - -kernel void k_fp2_mul( - device const Fp2* a [[buffer(0)]], - device const Fp2* b [[buffer(1)]], - device Fp2* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp2_mul(a[tid], b[tid]); -} - -kernel void k_fp2_sqr( - device const Fp2* a [[buffer(0)]], - device Fp2* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp2_sqr(a[tid]); -} - -kernel void k_fp2_inv( - device const Fp2* a [[buffer(0)]], - device Fp2* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp2_inv(a[tid]); -} - -kernel void k_fp2_conj( - device const Fp2* a [[buffer(0)]], - device Fp2* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp2_conj(a[tid]); -} - -// Raw Fp inversion exposed for diagnostic tests. Treats first 48 bytes as Fp, -// returns result in first 48 bytes, zeroes the c1 component. 
-kernel void k_fp_inv_diag( - device const Fp2* a [[buffer(0)]], - device Fp2* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - Fp2 r; - r.c0 = fp_inv(a[tid].c0); - r.c1 = ZERO384; - out[tid] = r; -} - -#endif // BLS_FP2_NO_KERNELS diff --git a/bls/gpu/metal/bls_fp6.metal b/bls/gpu/metal/bls_fp6.metal deleted file mode 100644 index 1d0d5c3..0000000 --- a/bls/gpu/metal/bls_fp6.metal +++ /dev/null @@ -1,278 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Fp6 = Fp2[v] / (v^3 - (u + 1)) for BLS12-381. -// Layout: struct Fp6 { Fp2 c0, c1, c2; } == blst_fp6 { blst_fp2 fp2[3]; } byte-equal. -// Element c0 + c1 v + c2 v^2 with v^3 = u + 1. -// -// Algorithms mirror blst src/fp12_tower.c (reference path). -// -// mul (schoolbook with v^3 = u+1 reduction): -// t0 = a0 b0, t1 = a1 b1, t2 = a2 b2 -// r0 = ((a1+a2)(b1+b2) - t1 - t2)(u+1) + t0 -// r1 = (a0+a1)(b0+b1) - t0 - t1 + t2(u+1) -// r2 = (a0+a2)(b0+b2) - t0 - t2 + t1 -// -// sqr (Chung-Hasan SQR3): -// s0 = a0^2, s2 = a2^2 -// m01 = 2 a0 a1, m12 = 2 a1 a2 -// r0 = m12 (u+1) + s0 -// r1 = s2 (u+1) + m01 -// r2 = (a0+a1+a2)^2 - s0 - s2 - m01 - m12 -// -// inv (Itoh-Tsujii): see blst inverse_fp6. -// -// Frobenius coefficients hardcoded from blst src/fp12_tower.c lines 675-695, -// already in Montgomery form. 
- -#ifndef BLS_FP6_INLINES -#define BLS_FP6_INLINES - -#define BLS_FP2_NO_KERNELS -#include "bls_fp2.metal" -#undef BLS_FP2_NO_KERNELS - -struct Fp6 { Fp2 c0, c1, c2; }; - -inline Fp6 fp6_add(Fp6 a, Fp6 b) { - Fp6 r; - r.c0 = fp2_add(a.c0, b.c0); - r.c1 = fp2_add(a.c1, b.c1); - r.c2 = fp2_add(a.c2, b.c2); - return r; -} - -inline Fp6 fp6_sub(Fp6 a, Fp6 b) { - Fp6 r; - r.c0 = fp2_sub(a.c0, b.c0); - r.c1 = fp2_sub(a.c1, b.c1); - r.c2 = fp2_sub(a.c2, b.c2); - return r; -} - -inline Fp6 fp6_neg(Fp6 a) { - Fp6 r; - r.c0 = fp2_neg(a.c0); - r.c1 = fp2_neg(a.c1); - r.c2 = fp2_neg(a.c2); - return r; -} - -inline Fp6 fp6_mul(Fp6 a, Fp6 b) { - Fp2 t0 = fp2_mul(a.c0, b.c0); - Fp2 t1 = fp2_mul(a.c1, b.c1); - Fp2 t2 = fp2_mul(a.c2, b.c2); - - // r0 = ((a1 + a2)(b1 + b2) - t1 - t2)(u+1) + t0 - Fp2 sa12 = fp2_add(a.c1, a.c2); - Fp2 sb12 = fp2_add(b.c1, b.c2); - Fp2 r0 = fp2_mul(sa12, sb12); - r0 = fp2_sub(r0, t1); - r0 = fp2_sub(r0, t2); - r0 = fp2_mul_by_1_plus_u(r0); - r0 = fp2_add(r0, t0); - - // r1 = (a0 + a1)(b0 + b1) - t0 - t1 + t2(u+1) - Fp2 sa01 = fp2_add(a.c0, a.c1); - Fp2 sb01 = fp2_add(b.c0, b.c1); - Fp2 r1 = fp2_mul(sa01, sb01); - r1 = fp2_sub(r1, t0); - r1 = fp2_sub(r1, t1); - r1 = fp2_add(r1, fp2_mul_by_1_plus_u(t2)); - - // r2 = (a0 + a2)(b0 + b2) - t0 - t2 + t1 - Fp2 sa02 = fp2_add(a.c0, a.c2); - Fp2 sb02 = fp2_add(b.c0, b.c2); - Fp2 r2 = fp2_mul(sa02, sb02); - r2 = fp2_sub(r2, t0); - r2 = fp2_sub(r2, t2); - r2 = fp2_add(r2, t1); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -inline Fp6 fp6_sqr(Fp6 a) { - Fp2 s0 = fp2_sqr(a.c0); - Fp2 m01 = fp2_mul(a.c0, a.c1); m01 = fp2_add(m01, m01); - Fp2 m12 = fp2_mul(a.c1, a.c2); m12 = fp2_add(m12, m12); - Fp2 s2 = fp2_sqr(a.c2); - - // r2 = (a0 + a1 + a2)^2 - s0 - s2 - m01 - m12 - Fp2 sum = fp2_add(fp2_add(a.c0, a.c1), a.c2); - Fp2 r2 = fp2_sqr(sum); - r2 = fp2_sub(r2, s0); - r2 = fp2_sub(r2, s2); - r2 = fp2_sub(r2, m01); - r2 = fp2_sub(r2, m12); - - // r0 = m12 (u+1) + s0 - Fp2 r0 = fp2_mul_by_1_plus_u(m12); - 
r0 = fp2_add(r0, s0); - - // r1 = s2 (u+1) + m01 - Fp2 r1 = fp2_mul_by_1_plus_u(s2); - r1 = fp2_add(r1, m01); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -inline Fp6 fp6_inv(Fp6 a) { - // c0 = a0^2 - (a1 a2)(u+1) - Fp2 c0 = fp2_sqr(a.c0); - Fp2 t = fp2_mul(a.c1, a.c2); - t = fp2_mul_by_1_plus_u(t); - c0 = fp2_sub(c0, t); - - // c1 = a2^2 (u+1) - a0 a1 - Fp2 c1 = fp2_sqr(a.c2); - c1 = fp2_mul_by_1_plus_u(c1); - Fp2 t01 = fp2_mul(a.c0, a.c1); - c1 = fp2_sub(c1, t01); - - // c2 = a1^2 - a0 a2 - Fp2 c2 = fp2_sqr(a.c1); - Fp2 t02 = fp2_mul(a.c0, a.c2); - c2 = fp2_sub(c2, t02); - - // norm = (a2 c1 + a1 c2)(u+1) + a0 c0 - Fp2 t1 = fp2_mul(c1, a.c2); - Fp2 t2 = fp2_mul(c2, a.c1); - Fp2 norm = fp2_add(t1, t2); - norm = fp2_mul_by_1_plus_u(norm); - norm = fp2_add(norm, fp2_mul(c0, a.c0)); - - Fp2 ni = fp2_inv(norm); - - Fp6 r; - r.c0 = fp2_mul(c0, ni); - r.c1 = fp2_mul(c1, ni); - r.c2 = fp2_mul(c2, ni); - return r; -} - -// Frobenius coefficients (Montgomery form), from blst src/fp12_tower.c lines 675-695. -// -// FP6_FROB_C1[n-1] = (u + 1)^((p^n - 1) / 3) in Fp2 (real, imag) -// FP6_FROB_C2[n-1] = (u + 1)^((2 p^n - 2) / 3) in Fp (real) -// -// Only n in {1, 2, 3} supported (sufficient for BLS12-381 Miller loop). 
- -constant uint384 FP6_FROB_C1_RE_N1 = {{0,0,0,0,0,0}}; -constant uint384 FP6_FROB_C1_IM_N1 = {{ - 0xCD03C9E48671F071UL, 0x5DAB22461FCDA5D2UL, 0x587042AFD3851B95UL, - 0x8EB60EBE01BACB9EUL, 0x03F97D6E83D050D2UL, 0x18F0206554638741UL -}}; - -constant uint384 FP6_FROB_C1_RE_N2 = {{ - 0x30F1361B798A64E8UL, 0xF3B8DDAB7ECE5A2AUL, 0x16A8CA3AC61577F7UL, - 0xC26A2FF874FD029BUL, 0x3636B76660701C6EUL, 0x051BA4AB241B6160UL -}}; -constant uint384 FP6_FROB_C1_IM_N2 = {{0,0,0,0,0,0}}; - -constant uint384 FP6_FROB_C1_RE_N3 = {{0,0,0,0,0,0}}; -// blst comment: "implied ONE_MONT_P at index 0" => imag part = R mod p (= 1 in Mont) -constant uint384 FP6_FROB_C1_IM_N3 = {{ - 0x760900000002FFFDUL, 0xEBF4000BC40C0002UL, 0x5F48985753C758BAUL, - 0x77CE585370525745UL, 0x5C071A97A256EC6DUL, 0x15F65EC3FA80E493UL -}}; - -constant uint384 FP6_FROB_C2_N1 = {{ - 0x890DC9E4867545C3UL, 0x2AF322533285A5D5UL, 0x50880866309B7E2CUL, - 0xA20D1B8C7E881024UL, 0x14E4F04FE2DB9068UL, 0x14E56D3F1564853AUL -}}; -constant uint384 FP6_FROB_C2_N2 = {{ - 0xCD03C9E48671F071UL, 0x5DAB22461FCDA5D2UL, 0x587042AFD3851B95UL, - 0x8EB60EBE01BACB9EUL, 0x03F97D6E83D050D2UL, 0x18F0206554638741UL -}}; -constant uint384 FP6_FROB_C2_N3 = {{ - 0x43F5FFFFFFFCAAAEUL, 0x32B7FFF2ED47FFFDUL, 0x07E83A49A2E99D69UL, - 0xECA8F3318332BB7AUL, 0xEF148D1EA0F4C069UL, 0x040AB3263EFF0206UL -}}; - -inline Fp6 fp6_frobenius(Fp6 a, uint n) { - Fp2 r0 = fp2_frobenius(a.c0, n); - Fp2 r1 = fp2_frobenius(a.c1, n); - Fp2 r2 = fp2_frobenius(a.c2, n); - - Fp2 c1; uint384 c2_real; - if (n == 1u) { - c1.c0 = FP6_FROB_C1_RE_N1; c1.c1 = FP6_FROB_C1_IM_N1; - c2_real = FP6_FROB_C2_N1; - } else if (n == 2u) { - c1.c0 = FP6_FROB_C1_RE_N2; c1.c1 = FP6_FROB_C1_IM_N2; - c2_real = FP6_FROB_C2_N2; - } else { - c1.c0 = FP6_FROB_C1_RE_N3; c1.c1 = FP6_FROB_C1_IM_N3; - c2_real = FP6_FROB_C2_N3; - } - - r1 = fp2_mul(r1, c1); - // c2 coefficient is pure-Fp; multiply both real and imaginary components. 
- r2.c0 = fp_mul(r2.c0, c2_real); - r2.c1 = fp_mul(r2.c1, c2_real); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -#endif // BLS_FP6_INLINES - -// ============================================================================= -// Kernels — buffer element size = 288 bytes (sizeof(Fp6) = 3 * 96). -// ============================================================================= - -#ifndef BLS_FP6_NO_KERNELS - -kernel void k_fp6_add( - device const Fp6* a [[buffer(0)]], - device const Fp6* b [[buffer(1)]], - device Fp6* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp6_add(a[tid], b[tid]); -} - -kernel void k_fp6_sub( - device const Fp6* a [[buffer(0)]], - device const Fp6* b [[buffer(1)]], - device Fp6* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp6_sub(a[tid], b[tid]); -} - -kernel void k_fp6_mul( - device const Fp6* a [[buffer(0)]], - device const Fp6* b [[buffer(1)]], - device Fp6* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp6_mul(a[tid], b[tid]); -} - -kernel void k_fp6_sqr( - device const Fp6* a [[buffer(0)]], - device Fp6* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp6_sqr(a[tid]); -} - -kernel void k_fp6_inv( - device const Fp6* a [[buffer(0)]], - device Fp6* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp6_inv(a[tid]); -} - -#endif // BLS_FP6_NO_KERNELS diff --git a/bls/gpu/metal/bls_fp_ops.h.metal b/bls/gpu/metal/bls_fp_ops.h.metal deleted file mode 100644 index 874ca21..0000000 --- a/bls/gpu/metal/bls_fp_ops.h.metal +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Shared Fp arithmetic primitives for BLS12-381 tower extensions. -// Included by bls_fp2.metal, bls_fp6.metal, bls_fp12.metal. -// -// All Fp values are stored in Montgomery form, 6 x 64-bit little-endian limbs, -// matching blst's `vec384` / `blst_fp` exactly so a memcpy round-trips. -// -// p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab -// R = 2^384 mod p -// p0 = -p^(-1) mod 2^64 = 0x89f3fffcfffcfffd (matches blst src/consts.c) - -#ifndef BLS_FP_OPS_H_METAL -#define BLS_FP_OPS_H_METAL - -#include -using namespace metal; - -struct uint384 { ulong limbs[6]; }; - -constant uint384 BLS_P = {{ - 0xB9FEFFFFFFFFAAABUL, 0x1EABFFFEB153FFFFUL, 0x6730D2A0F6B0F624UL, - 0x64774B84F38512BFUL, 0x4B1BA7B6434BACD7UL, 0x1A0111EA397FE69AUL -}}; - -constant uint384 BLS_R2 = {{ - 0xF4DF1F341C341746UL, 0x0A76E6A609D104F1UL, 0x8DE5476C4C95B6D5UL, - 0x67EB88A9939D83C0UL, 0x9A793E85B519952DUL, 0x11988FE592CAE3AAUL -}}; - -constant uint384 BLS_R = {{ - 0x760900000002FFFDUL, 0xEBF4000BC40C0002UL, 0x5F48985753C758BAUL, - 0x77CE585370525745UL, 0x5C071A97A256EC6DUL, 0x15F65EC3FA80E493UL -}}; - -constant ulong BLS_P_INV = 0x89F3FFFCFFFCFFFDUL; -constant uint384 ZERO384 = {{0,0,0,0,0,0}}; - -inline int u384_cmp(uint384 a, uint384 b) { - for (int i = 5; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u384_is_zero(uint384 a) { - return (a.limbs[0]|a.limbs[1]|a.limbs[2]|a.limbs[3]|a.limbs[4]|a.limbs[5]) == 0; -} - -inline uint384 u384_add(uint384 a, uint384 b, thread ulong& carry) { - uint384 r; ulong c = 0; - for (int i = 0; i < 6; i++) { - ulong s1 = a.limbs[i] + c; - ulong c1 = (s1 < a.limbs[i]) ? 1UL : 0UL; - ulong s2 = s1 + b.limbs[i]; - ulong c2 = (s2 < s1) ? 
1UL : 0UL; - r.limbs[i] = s2; - c = c1 + c2; - } - carry = c; - return r; -} - -inline uint384 u384_sub(uint384 a, uint384 b, thread ulong& borrow) { - uint384 r; ulong bw = 0; - for (int i = 0; i < 6; i++) { - ulong d1 = a.limbs[i] - bw; - ulong b1 = (d1 > a.limbs[i]) ? 1UL : 0UL; - ulong d2 = d1 - b.limbs[i]; - ulong b2 = (d2 > d1) ? 1UL : 0UL; - r.limbs[i] = d2; - bw = b1 + b2; - } - borrow = bw; - return r; -} - -// 64x64 -> 128 (lo, hi) — Metal lacks native uint128. -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong al = a & 0xFFFFFFFFUL, ah = a >> 32; - ulong bl = b & 0xFFFFFFFFUL, bh = b >> 32; - ulong ll = al*bl, lh = al*bh, hl = ah*bl, hh = ah*bh; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -// CIOS Montgomery reduction of 768-bit t -> t * R^(-1) mod p. -inline uint384 mont_reduce_384(ulong t[12]) { - ulong a[13]; - for (int i = 0; i < 12; i++) a[i] = t[i]; - a[12] = 0; - for (int i = 0; i < 6; i++) { - ulong u = a[i] * BLS_P_INV; - ulong carry = 0; - for (int j = 0; j < 6; j++) { - ulong lo, hi; mul64(u, BLS_P.limbs[j], lo, hi); - ulong s = lo + carry; if (s < lo) hi++; - lo = s; - s = a[i+j] + lo; if (s < a[i+j]) hi++; - a[i+j] = s; - carry = hi; - } - for (int j = 6; i+j <= 12; j++) { - ulong s = a[i+j] + carry; - carry = (s < a[i+j]) ? 
1UL : 0UL; - a[i+j] = s; - if (carry == 0) break; - } - } - uint384 r; - r.limbs[0]=a[6]; r.limbs[1]=a[7]; r.limbs[2]=a[8]; - r.limbs[3]=a[9]; r.limbs[4]=a[10]; r.limbs[5]=a[11]; - if (a[12] || u384_cmp(r, BLS_P) >= 0) { - ulong bw; r = u384_sub(r, BLS_P, bw); - } - return r; -} - -inline uint384 fp_mul(uint384 a, uint384 b) { - ulong t[12] = {}; - for (int i = 0; i < 6; i++) { - ulong carry = 0; - for (int j = 0; j < 6; j++) { - ulong lo, hi; mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong s = lo + carry; if (s < lo) hi++; - lo = s; - s = t[i+j] + lo; if (s < t[i+j]) hi++; - t[i+j] = s; - carry = hi; - } - for (int j = 6; i+j < 12; j++) { - ulong s = t[i+j] + carry; - carry = (s < t[i+j]) ? 1UL : 0UL; - t[i+j] = s; - if (carry == 0) break; - } - } - return mont_reduce_384(t); -} - -inline uint384 fp_sqr(uint384 a) { return fp_mul(a, a); } - -inline uint384 fp_add(uint384 a, uint384 b) { - ulong c; uint384 r = u384_add(a, b, c); - if (c || u384_cmp(r, BLS_P) >= 0) { - ulong bw; r = u384_sub(r, BLS_P, bw); - } - return r; -} - -inline uint384 fp_sub(uint384 a, uint384 b) { - ulong bw; uint384 r = u384_sub(a, b, bw); - if (bw) { ulong c; r = u384_add(r, BLS_P, c); } - return r; -} - -inline uint384 fp_neg(uint384 a) { - if (u384_is_zero(a)) return a; - ulong bw; return u384_sub(BLS_P, a, bw); -} - -// Fermat inversion: a^(p-2) mod p. Same modular inverse as blst's recip-addchain; -// produces identical Montgomery output bytes. -// -// Left-to-right binary square-and-multiply over the 381-bit exponent (p-2). -// Iterates from MSB to LSB; squares result every step, multiplies in `a` when -// the current bit is set. Skip leading zero bits to keep `result` at 1 until -// the first set bit (avoids unnecessary squarings before initialization). 
-inline uint384 fp_inv(uint384 a) { - uint384 exp = BLS_P; // exp = p - // exp -= 2 on the bottom limb (low limb is well above 2) - exp.limbs[0] -= 2; - - uint384 result = BLS_R; // 1 in Montgomery form - bool started = false; - for (int i = 5; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - if (started) result = fp_sqr(result); - if ((exp.limbs[i] >> bit) & 1) { - result = started ? fp_mul(result, a) : a; - started = true; - } - } - } - return result; -} - -#endif // BLS_FP_OPS_H_METAL diff --git a/bls/gpu/metal/bls_g2.metal b/bls/gpu/metal/bls_g2.metal deleted file mode 100644 index 8302d4b..0000000 --- a/bls/gpu/metal/bls_g2.metal +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// G2 = E'(Fp2) for BLS12-381 (twist curve y^2 = x^3 + 4(u+1)). -// Layouts: -// struct P2 { Fp2 X, Y, Z; } == blst_p2 / POINTonE2 (288 B) -// struct P2Aff { Fp2 X, Y; } == blst_p2_affine (192 B) -// Bytes match blst exactly so a memcpy round-trips. -// -// Algorithms mirror blst src/ec_ops.h: -// POINT_ADD_IMPL -> p2_jac_add (Jacobian + Jacobian) -// POINT_DOUBLE_IMPL_A0 -> p2_jac_dbl (a4 = 0 since BLS curve has a=0) -// POINT_ADD_AFFINE_IMPL -> p2_mixed_add (Jacobian + affine) -// -// Scalar multiplication mirrors a left-to-right binary double-and-add. -// blst uses a 5-bit window with Booth recoding internally; rather than try to -// byte-match a Jacobian representation produced by Booth-window arithmetic, -// the kernel converts its result to AFFINE form and the test compares against -// blst's affine output. Affine encodings are unique for a given group element -// so byte-equality is preserved across algorithmic choice. -// -// Curve parameter: a = 0, b' = 4(u+1) (twist). All ops below are correct -// only because a = 0 (POINT_DOUBLE_IMPL_A0, POINT_DADD_IMPL with a4=NULL path). 
- -#ifndef BLS_G2_INLINES -#define BLS_G2_INLINES - -#define BLS_FP2_NO_KERNELS -#include "bls_fp2.metal" -#undef BLS_FP2_NO_KERNELS - -struct P2 { Fp2 X, Y, Z; }; -struct P2Aff { Fp2 X, Y; }; - -inline bool fp2_is_zero(Fp2 a) { - return u384_is_zero(a.c0) && u384_is_zero(a.c1); -} - -// 1 in Fp2 (Montgomery): { c0 = R, c1 = 0 }. -inline Fp2 fp2_one() { - Fp2 r; r.c0 = BLS_R; r.c1 = ZERO384; return r; -} - -inline Fp2 fp2_zero() { - Fp2 r; r.c0 = ZERO384; r.c1 = ZERO384; return r; -} - -// Mirrors blst POINT_ADD_IMPL(POINTonE2, 384x, fp2). Handles either input at -// infinity (Z == 0). Does NOT handle the doubling case (H == 0); for safety we -// route through the d-add at the test level by ensuring random P != Q. -inline P2 p2_jac_add(P2 p1, P2 p2) { - bool p1inf = fp2_is_zero(p1.Z); - Fp2 Z1Z1 = fp2_sqr(p1.Z); - - Fp2 pZ = fp2_mul(Z1Z1, p1.Z); // Z1*Z1Z1 - pZ = fp2_mul(pZ, p2.Y); // S2 = Y2*Z1*Z1Z1 - - bool p2inf = fp2_is_zero(p2.Z); - Fp2 Z2Z2 = fp2_sqr(p2.Z); - - Fp2 S1 = fp2_mul(Z2Z2, p2.Z); // Z2*Z2Z2 - S1 = fp2_mul(S1, p1.Y); // S1 = Y1*Z2*Z2Z2 - - pZ = fp2_sub(pZ, S1); // S2-S1 - pZ = fp2_add(pZ, pZ); // r = 2*(S2-S1) - - Fp2 U1 = fp2_mul(p1.X, Z2Z2); // U1 = X1*Z2Z2 - Fp2 H = fp2_mul(p2.X, Z1Z1); // U2 = X2*Z1Z1 - H = fp2_sub(H, U1); // H = U2-U1 - - Fp2 I = fp2_add(H, H); // 2*H - I = fp2_sqr(I); // I = (2*H)^2 - - Fp2 J = fp2_mul(H, I); // J = H*I - S1 = fp2_mul(S1, J); // S1*J - - Fp2 V = fp2_mul(U1, I); // V = U1*I - - Fp2 pX = fp2_sqr(pZ); // r^2 - pX = fp2_sub(pX, J); // r^2-J - pX = fp2_sub(pX, V); - pX = fp2_sub(pX, V); // X3 = r^2-J-2*V - - Fp2 pY = fp2_sub(V, pX); // V-X3 - pY = fp2_mul(pY, pZ); // r*(V-X3) - pY = fp2_sub(pY, S1); - pY = fp2_sub(pY, S1); // Y3 = r*(V-X3)-2*S1*J - - pZ = fp2_add(p1.Z, p2.Z); // Z1+Z2 - pZ = fp2_sqr(pZ); - pZ = fp2_sub(pZ, Z1Z1); - pZ = fp2_sub(pZ, Z2Z2); // (Z1+Z2)^2-Z1Z1-Z2Z2 - pZ = fp2_mul(pZ, H); // Z3 = (...)*H - - P2 p3; p3.X = pX; p3.Y = pY; p3.Z = pZ; - - if (p2inf) p3 = p1; - if (p1inf) p3 = p2; - return 
p3; -} - -// Mirrors POINT_DOUBLE_IMPL_A0(POINTonE2, 384x, fp2). a = 0 for BLS12-381. -inline P2 p2_jac_dbl(P2 p1) { - Fp2 A = fp2_sqr(p1.X); // A = X1^2 - Fp2 B = fp2_sqr(p1.Y); // B = Y1^2 - Fp2 C = fp2_sqr(B); // C = B^2 - - B = fp2_add(B, p1.X); // X1+B - B = fp2_sqr(B); // (X1+B)^2 - B = fp2_sub(B, A); - B = fp2_sub(B, C); - B = fp2_add(B, B); // D = 2*((X1+B)^2-A-C) - - // mul_by_3: 3*A = A + 2A - Fp2 A3 = fp2_add(A, A); - A3 = fp2_add(A3, A); // E = 3*A - - Fp2 pX = fp2_sqr(A3); // F = E^2 - pX = fp2_sub(pX, B); - pX = fp2_sub(pX, B); // X3 = F-2*D - - Fp2 pZ = fp2_add(p1.Z, p1.Z); // 2*Z1 - pZ = fp2_mul(pZ, p1.Y); // Z3 = 2*Z1*Y1 - - // mul_by_8: 8*C = ((C<<1)<<1)<<1 - Fp2 C8 = fp2_add(C, C); - C8 = fp2_add(C8, C8); - C8 = fp2_add(C8, C8); // 8*C - - Fp2 pY = fp2_sub(B, pX); // D-X3 - pY = fp2_mul(pY, A3); // E*(D-X3) - pY = fp2_sub(pY, C8); // Y3 = E*(D-X3)-8*C - - P2 p3; p3.X = pX; p3.Y = pY; p3.Z = pZ; - return p3; -} - -// Mirrors POINT_ADD_AFFINE_IMPL(POINTonE2, 384x, fp2, BLS12_381_Rx.p2). -// |p1| at infinity encoded as Z==0; |p2| at infinity encoded as X==Y==0. 
-inline P2 p2_mixed_add(P2 p1, P2Aff p2) { - bool p1inf = fp2_is_zero(p1.Z); - Fp2 Z1Z1 = fp2_sqr(p1.Z); - - Fp2 pZ = fp2_mul(Z1Z1, p1.Z); // Z1*Z1Z1 - pZ = fp2_mul(pZ, p2.Y); // S2 = Y2*Z1*Z1Z1 - - bool p2inf = fp2_is_zero(p2.X) && fp2_is_zero(p2.Y); - - Fp2 H = fp2_mul(p2.X, Z1Z1); // U2 = X2*Z1Z1 - H = fp2_sub(H, p1.X); // H = U2-X1 - - Fp2 HH = fp2_sqr(H); // HH = H^2 - Fp2 I = fp2_add(HH, HH); - I = fp2_add(I, I); // I = 4*HH - - Fp2 pY_v = fp2_mul(p1.X, I); // V = X1*I - Fp2 J = fp2_mul(H, I); // J = H*I - Fp2 Iy = fp2_mul(J, p1.Y); // Y1*J - - pZ = fp2_sub(pZ, p1.Y); // S2-Y1 - pZ = fp2_add(pZ, pZ); // r = 2*(S2-Y1) - - Fp2 pX = fp2_sqr(pZ); // r^2 - pX = fp2_sub(pX, J); - pX = fp2_sub(pX, pY_v); - pX = fp2_sub(pX, pY_v); // X3 = r^2-J-2*V - - Fp2 pY = fp2_sub(pY_v, pX); // V-X3 - pY = fp2_mul(pY, pZ); // r*(V-X3) - pY = fp2_sub(pY, Iy); - pY = fp2_sub(pY, Iy); // Y3 = r*(V-X3)-2*Y1*J - - pZ = fp2_add(p1.Z, H); // Z1+H - pZ = fp2_sqr(pZ); - pZ = fp2_sub(pZ, Z1Z1); - pZ = fp2_sub(pZ, HH); // Z3 = (Z1+H)^2-Z1Z1-HH - - P2 p3; p3.X = pX; p3.Y = pY; p3.Z = pZ; - - if (p1inf) { - // p3 = p2 promoted to Jacobian (Z = 1) - p3.X = p2.X; - p3.Y = p2.Y; - p3.Z = fp2_one(); - } - if (p2inf) p3 = p1; - return p3; -} - -// Convert Jacobian -> affine. Mirrors blst POINTonE2_from_Jacobian: -// X_aff = X / Z^2, Y_aff = Y / Z^3. Z = 0 -> infinity (X=Y=0). -inline P2Aff p2_to_affine(P2 p) { - P2Aff a; - if (fp2_is_zero(p.Z)) { - a.X = fp2_zero(); - a.Y = fp2_zero(); - return a; - } - Fp2 Zi = fp2_inv(p.Z); - Fp2 Zi2 = fp2_sqr(Zi); - Fp2 Zi3 = fp2_mul(Zi2, Zi); - a.X = fp2_mul(p.X, Zi2); - a.Y = fp2_mul(p.Y, Zi3); - return a; -} - -// Left-to-right binary scalar multiplication using the Jacobian add/dbl above. -// scalar is 32 little-endian bytes (matches blst's `byte *scalar` API -// when nbits = 256). bit_len caps how many bits of the scalar are processed -// and is provided by the caller (matches blst's `nbits` parameter). 
-// -// Output is converted to AFFINE for byte-equality comparison against blst, -// because blst's internal w=5 Booth recoding produces a different Jacobian -// representation that nevertheless decodes to the same affine point. -inline P2Aff p2_scalar_mult(P2 base, device const uchar* scalar, uint nbits) { - // Jacobian zero (point at infinity) - P2 R; R.X = fp2_zero(); R.Y = fp2_zero(); R.Z = fp2_zero(); - - // Process bits MSB -> LSB. Skip leading zeros so R stays at infinity until - // the first set bit, at which point R becomes 2*infinity = infinity, then - // the next "if bit" path handles initialization via mixed-zero add. - bool started = false; - for (int i = (int)nbits - 1; i >= 0; i--) { - uchar byte = scalar[i >> 3]; - uint bit = (uint)((byte >> (i & 7)) & 1u); - if (started) R = p2_jac_dbl(R); - if (bit) { - if (started) { - R = p2_jac_add(R, base); - } else { - R = base; - started = true; - } - } - } - return p2_to_affine(R); -} - -#endif // BLS_G2_INLINES - -// ============================================================================= -// Kernels — P2 buffers are 288 B (3 * 96 B), P2Aff are 192 B (2 * 96 B). 
-// ============================================================================= - -#ifndef BLS_G2_NO_KERNELS - -kernel void k_p2_jac_add( - device const P2* a [[buffer(0)]], - device const P2* b [[buffer(1)]], - device P2* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = p2_jac_add(a[tid], b[tid]); -} - -kernel void k_p2_jac_dbl( - device const P2* a [[buffer(0)]], - device P2* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = p2_jac_dbl(a[tid]); -} - -kernel void k_p2_mixed_add( - device const P2* a [[buffer(0)]], - device const P2Aff* b [[buffer(1)]], - device P2* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = p2_mixed_add(a[tid], b[tid]); -} - -// Scalar mult: input layout per element -// bytes [0..288) : P2 base -// bytes [288..320) : 32 little-endian bytes scalar -// Output: P2Aff (192 bytes per element). -struct P2ScalarIn { - P2 base; - uchar scalar[32]; -}; - -kernel void k_p2_scalar_mult( - device const P2ScalarIn* in [[buffer(0)]], - device P2Aff* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - constant uint& nbits [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = p2_scalar_mult(in[tid].base, in[tid].scalar, nbits); -} - -#endif // BLS_G2_NO_KERNELS diff --git a/bls/gpu/metal/bls_miller.metal b/bls/gpu/metal/bls_miller.metal deleted file mode 100644 index dc3208b..0000000 --- a/bls/gpu/metal/bls_miller.metal +++ /dev/null @@ -1,394 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 optimal-ate Miller loop on Metal. 
-// -// Inputs: -// Q in G2 (P2Aff = 192 bytes, blst_p2_affine layout) -// P in G1 (P1Aff = 96 bytes, blst_p1_affine layout = 2*48-byte Fp) -// Output: -// f in Fp12 (576 bytes), the value of e_pre(P, Q) = product of line evals, -// matching blst_miller_loop pre-final-exponentiation byte-for-byte. -// -// Algorithm mirrors blst src/pairing.c miller_loop() exactly: -// 1. Px2 = (-2*P.X, 2*P.Y) (precomputed for line_by_Px2) -// 2. T = Q (Jacobian, Z = 1) -// 3. line_dbl(line, T, T) ; line_by_Px2 ; ret = unpack(line) // first sqr fused -// 4. add_n_dbl(2) ; add_n_dbl(3) ; add_n_dbl(9) -// add_n_dbl(32) ; add_n_dbl(16) -// 5. ret = conjugate(ret) // x is negative -// -// Loop scalar: x = -0xd201000000010000 -// Counts above (1+2+3+9+32+16 = 63 doublings + 5 add-then-double-runs) -// match the bit pattern of |x| = 0xd201000000010000 minus the leading bit -// already absorbed by step 3. -// -// Line layout — "xy00z0" packed as three Fp2: -// Line { Fp2 x, y, z } -// maps into Fp12 a = (a0, a1) where a0 = (x, y, 0) Fp6 and a1 = (0, z, 0) Fp6. -// This matches blst's vec384fp6 line[3] convention used in pairing.c + -// fp12_tower.c::mul_by_xy00z0_fp12. - -#define BLS_FP12_NO_KERNELS -#define BLS_FP6_NO_KERNELS -#define BLS_FP2_NO_KERNELS -#include "bls_fp12.metal" -#undef BLS_FP12_NO_KERNELS -#undef BLS_FP6_NO_KERNELS -#undef BLS_FP2_NO_KERNELS - -#define BLS_FP2_NO_KERNELS -#define BLS_FP6_NO_KERNELS -#define BLS_FP12_NO_KERNELS -#define BLS_G2_NO_KERNELS -#include "bls_g2.metal" -#undef BLS_G2_NO_KERNELS -#undef BLS_FP12_NO_KERNELS -#undef BLS_FP6_NO_KERNELS -#undef BLS_FP2_NO_KERNELS - -struct P1Aff { uint384 X, Y; }; - -struct Line { Fp2 x, y, z; }; - -// line_dbl — blst pairing.c:78 -// T <- 2*T, line <- doubling-line at the original T. -// Q is passed by value (snapshot of the input); blst's reference call is -// line_dbl(line, T, T) — Q == T. Taking by value lets us safely overwrite T. -// Not inlined so the Miller loop fits in a single Metal compile unit. 
-static Line line_dbl(thread P2& T, P2 Q) { - Fp2 A = fp2_sqr(Q.X); // X1^2 - Fp2 B = fp2_sqr(Q.Y); // Y1^2 - Fp2 ZZ = fp2_sqr(Q.Z); // Z1^2 - Fp2 C = fp2_sqr(B); // C = B^2 - - Fp2 D = fp2_add(Q.X, B); - D = fp2_sqr(D); - D = fp2_sub(D, A); - D = fp2_sub(D, C); - D = fp2_add(D, D); // D = 2*((X+B)^2 - A - C) - - // E = 3*A - Fp2 E = fp2_add(A, A); - E = fp2_add(E, A); - - Fp2 F = fp2_sqr(E); // F = E^2 - - // line[0] = 3A + X1 (will be squared+adjusted next) - Fp2 line0 = fp2_add(E, Q.X); - - Fp2 Tx = fp2_sub(F, D); - Tx = fp2_sub(Tx, D); // X3 = F - 2D - - Fp2 Tz = fp2_add(Q.Y, Q.Z); - Tz = fp2_sqr(Tz); - Tz = fp2_sub(Tz, B); - Tz = fp2_sub(Tz, ZZ); // Z3 = (Y+Z)^2 - B - ZZ - - // 8*C for Y3 path - Fp2 C8 = fp2_add(C, C); - C8 = fp2_add(C8, C8); - C8 = fp2_add(C8, C8); // 8*C - - Fp2 Ty = fp2_sub(D, Tx); // D - X3 - Ty = fp2_mul(Ty, E); // E*(D-X3) - Ty = fp2_sub(Ty, C8); // Y3 = E*(D-X3) - 8C - - // line evaluation - line0 = fp2_sqr(line0); - line0 = fp2_sub(line0, A); - line0 = fp2_sub(line0, F); // (3A+X)^2 - X^2 - 9A^2 = 6X^3 - ... - // 4*B - Fp2 B4 = fp2_add(B, B); - B4 = fp2_add(B4, B4); - line0 = fp2_sub(line0, B4); // 6*X^3 - 4*Y^2 - - Fp2 line1 = fp2_mul(E, ZZ); // 3*X^2 * Z^2 - Fp2 line2 = fp2_mul(Tz, ZZ); // Z3 * Z^2 - - T.X = Tx; T.Y = Ty; T.Z = Tz; - - Line L; L.x = line0; L.y = line1; L.z = line2; - return L; -} - -// line_add — blst pairing.c:14 -// T <- R + Q, line <- addition-line at R. (R is Jacobian, Q is affine.) -// R is passed by value so callers can pass `R = T` without aliasing risk -// (blst's reference call is line_add(line, T, T, Q)). 
-static Line line_add(thread P2& T, P2 R, P2Aff Q) { - Fp2 Z1Z1 = fp2_sqr(R.Z); // Z1^2 - Fp2 U2 = fp2_mul(Q.X, Z1Z1); // U2 = X2*Z1Z1 - - Fp2 S2 = fp2_mul(Q.Y, R.Z); - S2 = fp2_mul(S2, Z1Z1); // S2 = Y2*Z1*Z1Z1 - - Fp2 H = fp2_sub(U2, R.X); // H = U2 - X1 - - Fp2 HH = fp2_sqr(H); - Fp2 I = fp2_add(HH, HH); - I = fp2_add(I, I); // I = 4*HH - - Fp2 J = fp2_mul(H, I); // J = H*I - - Fp2 r = fp2_sub(S2, R.Y); - r = fp2_add(r, r); // r = 2*(S2 - Y1) - - Fp2 V = fp2_mul(R.X, I); // V = X1*I - - Fp2 Tx = fp2_sqr(r); - Tx = fp2_sub(Tx, J); - Tx = fp2_sub(Tx, V); - Tx = fp2_sub(Tx, V); // X3 = r^2 - J - 2V - - Fp2 Jy = fp2_mul(J, R.Y); - Fp2 Ty = fp2_sub(V, Tx); - Ty = fp2_mul(Ty, r); - Ty = fp2_sub(Ty, Jy); - Ty = fp2_sub(Ty, Jy); // Y3 = r*(V-X3) - 2*Y1*J - - Fp2 Tz = fp2_add(R.Z, H); - Tz = fp2_sqr(Tz); - Tz = fp2_sub(Tz, Z1Z1); - Tz = fp2_sub(Tz, HH); // Z3 = (Z1+H)^2 - Z1Z1 - HH - - // line evaluation - Fp2 lineI = fp2_mul(r, Q.X); - Fp2 lineJ = fp2_mul(Q.Y, Tz); - lineI = fp2_sub(lineI, lineJ); - Fp2 line0 = fp2_add(lineI, lineI); // 2*(r*X2 - Y2*Z3) - - T.X = Tx; T.Y = Ty; T.Z = Tz; - - Line L; L.x = line0; L.y = r; L.z = Tz; - return L; -} - -// line_by_Px2 — blst pairing.c:128 -// line[1] *= Px2->X (where Px2->X = -2*P->X) -// line[2] *= Px2->Y (where Px2->Y = 2*P->Y) -// Both line[1], line[2] are Fp2; Px2 components are pure Fp -> componentwise. 
-inline Line line_by_Px2(Line L, uint384 px_neg2, uint384 py_2) { - L.y.c0 = fp_mul(L.y.c0, px_neg2); - L.y.c1 = fp_mul(L.y.c1, px_neg2); - L.z.c0 = fp_mul(L.z.c0, py_2); - L.z.c1 = fp_mul(L.z.c1, py_2); - return L; -} - -// fp6_mul_by_xy0 — blst fp12_tower.c:437 (b = (b0, b1, 0)) -inline Fp6 fp6_mul_by_xy0(Fp6 a, Fp2 b0, Fp2 b1) { - Fp2 t0 = fp2_mul(a.c0, b0); - Fp2 t1 = fp2_mul(a.c1, b1); - - // r0 = (a2 * b1) * (u+1) + a0*b0 - Fp2 t3 = fp2_mul(a.c2, b1); - t3 = fp2_mul_by_1_plus_u(t3); - - // r1 = (a0+a1)(b0+b1) - t0 - t1 - Fp2 t4 = fp2_add(a.c0, a.c1); - Fp2 t5 = fp2_add(b0, b1); - Fp2 r1 = fp2_mul(t4, t5); - r1 = fp2_sub(r1, t0); - r1 = fp2_sub(r1, t1); - - // r2 = a2*b0 + t1 - Fp2 r2 = fp2_mul(a.c2, b0); - r2 = fp2_add(r2, t1); - - Fp2 r0 = fp2_add(t3, t0); - - Fp6 r; r.c0 = r0; r.c1 = r1; r.c2 = r2; return r; -} - -// fp6_mul_by_0y0 — blst fp12_tower.c:426 (a * (0, b, 0)) -inline Fp6 fp6_mul_by_0y0(Fp6 a, Fp2 b) { - Fp2 t = fp2_mul(a.c2, b); - Fp6 r; - r.c2 = fp2_mul(a.c1, b); - r.c1 = fp2_mul(a.c0, b); - r.c0 = fp2_mul_by_1_plus_u(t); - return r; -} - -// mul_by_xy00z0_fp12 — blst fp12_tower.c:466 -// ret = a * line where line packs as Fp6 (x, y, 0) for ret[0] -// and Fp6 (0, z, 0) for ret[1] (the "00z0" tail). -static Fp12 fp12_mul_by_xy00z0(Fp12 a, Line L) { - // Build "Fp6 xy00z0" as (L.x, L.y, L.z) in blst layout - // (note: blst stores xy00z0 as 3 fp2's [x, y, z]; the 0's are virtual). - // mul_by_xy0_fp6: t0 = a[0] * (x, y, 0) - Fp6 t0 = fp6_mul_by_xy0(a.c0, L.x, L.y); - - // mul_by_0y0_fp6: t1 = a[1] * (0, z, 0) (i.e. 
multiply by xy00z0[2] = z) - Fp6 t1 = fp6_mul_by_0y0(a.c1, L.z); - - // ret[1] = (a0 + a1) * (x, y+z, 0) - t0 - t1 - Fp2 b1_alt = fp2_add(L.y, L.z); - Fp6 sum = fp6_add(a.c0, a.c1); - Fp6 r1 = fp6_mul_by_xy0(sum, L.x, b1_alt); - r1 = fp6_sub(r1, t0); - r1 = fp6_sub(r1, t1); - - // ret[0] = t0 + t1 * v (recall v applied to Fp6: (a,b,c) -> (c(u+1), a, b)) - // t1*v = (t1.c2*(u+1), t1.c0, t1.c1) - Fp6 t1v; - t1v.c0 = fp2_mul_by_1_plus_u(t1.c2); - t1v.c1 = t1.c0; - t1v.c2 = t1.c1; - Fp6 r0 = fp6_add(t0, t1v); - - Fp12 r; r.c0 = r0; r.c1 = r1; return r; -} - -// Initial step: ret = unpack( line_dbl( T = Q ) ). -// Mirrors blst pairing.c:166-170. After this, ret has nonzero coords only at -// ret[0][0] = line.x, ret[0][1] = line.y, ret[1][1] = line.z. -inline Fp12 unpack_initial_line(Line L) { - Fp12 ret; - ret.c0.c0 = L.x; - ret.c0.c1 = L.y; - ret.c0.c2 = fp2_zero(); - ret.c1.c0 = fp2_zero(); - ret.c1.c1 = L.z; - ret.c1.c2 = fp2_zero(); - return ret; -} - -// Note: the doubling counts per Miller-loop phase are -// 1 (initial dbl) + 2 + 3 + 9 + 32 + 16 = 63 doublings, -// matching log2(0xd201000000010000) — the BLS12-381 ate scalar |x|. -// The host driver sequences these phases (see bls_miller_test.mm). -// See pairing.c::miller_loop in blst. - -// ============================================================================= -// Host-orchestrated Miller loop. -// -// The full Miller loop crashes MetalCompilerService when expressed as a -// single kernel function (XPC connection drops during AIR->GPU lowering on -// this kernel's call tree). Splitting into bounded sub-kernels keeps each -// kernel's compile within the service's budget while still running 100% on -// Metal — no host arithmetic, just dispatch orchestration. 
-// -// State buffers per workitem: -// T_state (288 B) — POINTonE2 Jacobian (X, Y, Z) -// ret_state (576 B) — Fp12 accumulator -// px2 (96 B) — (-2*P.X, 2*P.Y) packed as Fp2 -// ============================================================================= - -struct MillerIn { - P2Aff Q; - P1Aff P; -}; - -// Storage for an evaluated line (line_by_Px2 already applied). Three Fp2. -struct LineBuf { Fp2 x, y, z; }; - -// k_miller_init — set T = Q (Z=1), compute Px2, do initial line_dbl, -// store T_state + ret_state + px2 for each workitem. -kernel void k_miller_init( - device const MillerIn* in [[buffer(0)]], - device P2* T_buf [[buffer(1)]], - device Fp12* ret_buf[[buffer(2)]], - device Fp2* px2_buf[[buffer(3)]], - constant uint& n [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - - P2Aff Q = in[tid].Q; - P1Aff P = in[tid].P; - - uint384 two_px = fp_add(P.X, P.X); - Fp2 Px2; - Px2.c0 = fp_neg(two_px); - Px2.c1 = fp_add(P.Y, P.Y); - px2_buf[tid] = Px2; - - P2 T; - T.X = Q.X; T.Y = Q.Y; T.Z = fp2_one(); - - Line L0 = line_dbl(T, T); - L0 = line_by_Px2(L0, Px2.c0, Px2.c1); - Fp12 ret = unpack_initial_line(L0); - - T_buf[tid] = T; - ret_buf[tid] = ret; -} - -// k_miller_add_T_and_line — T = T + Q ; line = addition-line at original T, -// with line_by_Px2 baked in. -kernel void k_miller_add_T_and_line( - device const MillerIn* in [[buffer(0)]], - device P2* T_buf [[buffer(1)]], - device LineBuf* line_buf [[buffer(2)]], - device const Fp2* px2_buf [[buffer(3)]], - constant uint& n [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - P2 T = T_buf[tid]; - Fp2 Px2 = px2_buf[tid]; - - Line L = line_add(T, T, in[tid].Q); - L = line_by_Px2(L, Px2.c0, Px2.c1); - - T_buf[tid] = T; - LineBuf lb; lb.x = L.x; lb.y = L.y; lb.z = L.z; - line_buf[tid] = lb; -} - -// k_miller_dbl_T_and_line — T = 2*T ; line = doubling-line at original T, -// with line_by_Px2 baked in. Saves 3 fp2 into line_buf for k_miller_fold_line. 
-kernel void k_miller_dbl_T_and_line( - device P2* T_buf [[buffer(0)]], - device LineBuf* line_buf [[buffer(1)]], - device const Fp2* px2_buf [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - P2 T = T_buf[tid]; - Fp2 Px2 = px2_buf[tid]; - - Line Ld = line_dbl(T, T); - Ld = line_by_Px2(Ld, Px2.c0, Px2.c1); - - T_buf[tid] = T; - LineBuf lb; lb.x = Ld.x; lb.y = Ld.y; lb.z = Ld.z; - line_buf[tid] = lb; -} - -// k_miller_sqr_ret — ret = ret^2 (in place). -kernel void k_miller_sqr_ret( - device Fp12* ret_buf [[buffer(0)]], - constant uint& n [[buffer(1)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - ret_buf[tid] = fp12_sqr(ret_buf[tid]); -} - -// k_miller_fold_line — ret *= line (sparse Fp12 multiply). -kernel void k_miller_fold_line( - device Fp12* ret_buf [[buffer(0)]], - device const LineBuf* line_buf [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - LineBuf lb = line_buf[tid]; - Line L; L.x = lb.x; L.y = lb.y; L.z = lb.z; - ret_buf[tid] = fp12_mul_by_xy00z0(ret_buf[tid], L); -} - -// k_miller_finalize — conjugate ret (account for x being negative). -kernel void k_miller_finalize( - device Fp12* ret_buf [[buffer(0)]], - device Fp12* out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - out[tid] = fp12_conj(ret_buf[tid]); -} diff --git a/bls/gpu/metal/bls_pairing.metal b/bls/gpu/metal/bls_pairing.metal deleted file mode 100644 index 3f5e86c..0000000 --- a/bls/gpu/metal/bls_pairing.metal +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 full pairing on Metal. 
-// -// e(P, Q) = final_exp( miller_loop(P, Q) ) -// -// Both stages run on Metal — this file is a small bridge that re-exposes the -// Fp12 type for buffer sizing and adds two batch helpers used by aggregate -// verify: -// -// k_pair_aggregate_step : acc = acc * src[i] (one pair folded in) -// k_pair_eq_one : flag[i] = (ret[i] == Fp12::one()) -// -// The Miller-loop kernels live in bls_miller.metal and the final_exp kernels -// in bls_final_exp.metal. The host driver chains them in a single command -// queue. Workgroup 1×1×1 per kernel preserves byte-determinism. - -#define BLS_FP12_NO_KERNELS -#define BLS_FP6_NO_KERNELS -#define BLS_FP2_NO_KERNELS -#include "bls_fp12.metal" -#undef BLS_FP12_NO_KERNELS -#undef BLS_FP6_NO_KERNELS -#undef BLS_FP2_NO_KERNELS - -// Multiplicative identity in Fp12 (Montgomery form): -// 1 = (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) in the c0=Fp6, c1=Fp6 layout -// = ((BLS_R, 0), (0, 0), (0, 0)), ((0, 0), (0, 0), (0, 0)) -inline Fp12 fp12_one() { - Fp2 zerop; zerop.c0 = ZERO384; zerop.c1 = ZERO384; - Fp2 onep; onep.c0 = BLS_R; onep.c1 = ZERO384; - Fp6 one6; one6.c0 = onep; one6.c1 = zerop; one6.c2 = zerop; - Fp6 zero6; zero6.c0 = zerop; zero6.c1 = zerop; zero6.c2 = zerop; - Fp12 r; r.c0 = one6; r.c1 = zero6; return r; -} - -// k_pair_one_init — acc[tid] = 1 (Fp12 multiplicative identity). -kernel void k_pair_one_init( - device Fp12* acc [[buffer(0)]], - constant uint& n [[buffer(1)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - acc[tid] = fp12_one(); -} - -// k_pair_aggregate_step — acc[tid] = acc[tid] * src[step_idx] -// Used to fold a pre-final-exp Miller output for one pair into a per-batch -// accumulator. step_idx is supplied as a constant so the kernel reads the -// correct slot of the Miller-output array. 
-kernel void k_pair_aggregate_step( - device const Fp12* src [[buffer(0)]], - device Fp12* acc [[buffer(1)]], - constant uint& step_idx [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - acc[tid] = fp12_mul(acc[tid], src[step_idx]); -} - -// k_pair_eq_one — flag[tid] = (ret[tid] == 1) ? 1 : 0 -// Used to verify aggregate verify succeeded (final-exp output equals Fp12::one()). -kernel void k_pair_eq_one( - device const Fp12* ret [[buffer(0)]], - device uint8_t* flag [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - Fp12 one = fp12_one(); - Fp12 r = ret[tid]; - - // byte-equality of all 6 Fp2 components - bool eq = true; - Fp2 a[6] = { r.c0.c0, r.c0.c1, r.c0.c2, r.c1.c0, r.c1.c1, r.c1.c2 }; - Fp2 b[6] = { one.c0.c0, one.c0.c1, one.c0.c2, one.c1.c0, one.c1.c1, one.c1.c2 }; - for (uint i = 0; i < 6; i++) { - for (uint j = 0; j < 6; j++) { - if (a[i].c0.limbs[j] != b[i].c0.limbs[j]) { eq = false; } - if (a[i].c1.limbs[j] != b[i].c1.limbs[j]) { eq = false; } - } - } - flag[tid] = eq ? 
1u : 0u; -} diff --git a/bls/gpu/metal/msm.metal b/bls/gpu/metal/msm.metal deleted file mode 100644 index 78a61f6..0000000 --- a/bls/gpu/metal/msm.metal +++ /dev/null @@ -1,665 +0,0 @@ -// Multi-Scalar Multiplication (MSM) for Elliptic Curves -// Implements Pippenger's algorithm for efficient batch scalar multiplication -// -// Supports: BN254 G1, BLS12-381 G1/G2 -// Use case: Pedersen commitments, KZG polynomial commitments, batch verification - -#include -using namespace metal; - -// ============================================================================= -// BN254 Field (Fp) - 254-bit prime field -// p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// ============================================================================= - -struct Fp254 { - uint64_t limbs[4]; // 256-bit representation (4 x 64-bit) -}; - -// BN254 field modulus -constant uint64_t BN254_P[4] = { - 0x3C208C16D87CFD47ULL, - 0x97816A916871CA8DULL, - 0xB85045B68181585DULL, - 0x30644E72E131A029ULL -}; - -// BN254 scalar field (Fr) modulus -constant uint64_t BN254_R[4] = { - 0x43E1F593F0000001ULL, - 0x2833E84879B97091ULL, - 0xB85045B68181585DULL, - 0x30644E72E131A029ULL -}; - -// ============================================================================= -// BLS12-381 Field (Fp) - 381-bit prime field -// p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab -// ============================================================================= - -struct Fp381 { - uint64_t limbs[6]; // 384-bit representation (6 x 64-bit) -}; - -// BLS12-381 field modulus -constant uint64_t BLS_P[6] = { - 0xB9FEFFFFFFFFAAABULL, - 0x1EABFFFEB153FFFFULL, - 0x6730D2A0F6B0F624ULL, - 0x64774B84F38512BFULL, - 0x4B1BA7B6434BACD7ULL, - 0x1A0111EA397FE69AULL -}; - -// BLS12-381 scalar field (Fr) modulus -constant uint64_t BLS_R[4] = { - 0xFFFFFFFF00000001ULL, - 0x53BDA402FFFE5BFEULL, - 0x3339D80809A1D805ULL, - 0x73EDA753299D7D48ULL -}; - -// 
============================================================================= -// Basic 256-bit Arithmetic (for BN254 and BLS12-381 scalars) -// ============================================================================= - -// Add with carry -inline uint64_t adc(uint64_t a, uint64_t b, thread uint64_t& carry) { - uint64_t sum = a + b + carry; - carry = (sum < a || (carry && sum == a)) ? 1 : 0; - return sum; -} - -// Subtract with borrow -inline uint64_t sbb(uint64_t a, uint64_t b, thread uint64_t& borrow) { - uint64_t diff = a - b - borrow; - borrow = (a < b + borrow) ? 1 : 0; - return diff; -} - -// 256-bit addition -inline void add256(thread Fp254& c, Fp254 a, Fp254 b) { - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = adc(a.limbs[i], b.limbs[i], carry); - } -} - -// 256-bit subtraction -inline void sub256(thread Fp254& c, Fp254 a, Fp254 b) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = sbb(a.limbs[i], b.limbs[i], borrow); - } -} - -// Check if a >= b -inline bool gte256(Fp254 a, Fp254 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] > b.limbs[i]) return true; - if (a.limbs[i] < b.limbs[i]) return false; - } - return true; // Equal -} - -// ============================================================================= -// Affine Point Representation -// ============================================================================= - -struct G1Affine254 { - Fp254 x; - Fp254 y; - bool infinity; -}; - -struct G1Affine381 { - Fp381 x; - Fp381 y; - bool infinity; -}; - -// ============================================================================= -// Projective Point Representation (for efficient addition) -// ============================================================================= - -struct G1Projective254 { - Fp254 x; - Fp254 y; - Fp254 z; -}; - -struct G1Projective381 { - Fp381 x; - Fp381 y; - Fp381 z; -}; - -// ============================================================================= -// Montgomery 
Multiplication for BN254 Fp -// ============================================================================= - -inline void fp254_mul(thread Fp254& c, Fp254 a, Fp254 b) { - // Simplified schoolbook multiplication with Montgomery reduction - // Full implementation would use CIOS or similar - - uint64_t t[8] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - // 64x64 -> 128-bit multiplication - uint64_t lo = a.limbs[i] * b.limbs[j]; - uint64_t hi = mulhi(a.limbs[i], b.limbs[j]); - - uint64_t sum = t[i+j] + lo + carry; - carry = (sum < t[i+j]) ? 1 : 0; - carry += hi; - t[i+j] = sum; - } - t[i+4] = carry; - } - - // Montgomery reduction (simplified) - // Full implementation would use proper Montgomery constants - for (int i = 0; i < 4; i++) { - c.limbs[i] = t[i]; // Placeholder - } -} - -// Field addition mod p -inline void fp254_add(thread Fp254& c, Fp254 a, Fp254 b) { - add256(c, a, b); - - Fp254 p = {{BN254_P[0], BN254_P[1], BN254_P[2], BN254_P[3]}}; - if (gte256(c, p)) { - sub256(c, c, p); - } -} - -// Field subtraction mod p -inline void fp254_sub(thread Fp254& c, Fp254 a, Fp254 b) { - if (gte256(a, b)) { - sub256(c, a, b); - } else { - Fp254 p = {{BN254_P[0], BN254_P[1], BN254_P[2], BN254_P[3]}}; - add256(c, a, p); - sub256(c, c, b); - } -} - -// ============================================================================= -// Point Operations (BN254 G1) -// ============================================================================= - -// Point doubling in projective coordinates -// Uses complete doubling formula -inline G1Projective254 g1_double_254(G1Projective254 p) { - if (p.z.limbs[0] == 0 && p.z.limbs[1] == 0 && - p.z.limbs[2] == 0 && p.z.limbs[3] == 0) { - return p; // Point at infinity - } - - // Standard projective doubling formula for short Weierstrass - // a = 0 for BN254 - Fp254 xx, yy, yyyy, zz; - Fp254 s, m, t; - - fp254_mul(xx, p.x, p.x); // X^2 - fp254_mul(yy, p.y, 
p.y); // Y^2 - fp254_mul(yyyy, yy, yy); // Y^4 - fp254_mul(zz, p.z, p.z); // Z^2 - - // S = 2*((X+YY)^2 - XX - YYYY) - Fp254 x_plus_yy; - fp254_add(x_plus_yy, p.x, yy); - fp254_mul(s, x_plus_yy, x_plus_yy); - fp254_sub(s, s, xx); - fp254_sub(s, s, yyyy); - fp254_add(s, s, s); // 2*S - - // M = 3*XX (a = 0 for BN254) - fp254_add(m, xx, xx); - fp254_add(m, m, xx); - - // T = M^2 - 2*S - fp254_mul(t, m, m); - Fp254 two_s; - fp254_add(two_s, s, s); - fp254_sub(t, t, two_s); - - G1Projective254 result; - - // X3 = T - result.x = t; - - // Y3 = M*(S-T) - 8*YYYY - Fp254 s_minus_t; - fp254_sub(s_minus_t, s, t); - fp254_mul(result.y, m, s_minus_t); - Fp254 eight_yyyy; - fp254_add(eight_yyyy, yyyy, yyyy); - fp254_add(eight_yyyy, eight_yyyy, eight_yyyy); - fp254_add(eight_yyyy, eight_yyyy, eight_yyyy); - fp254_sub(result.y, result.y, eight_yyyy); - - // Z3 = 2*Y*Z - fp254_mul(result.z, p.y, p.z); - fp254_add(result.z, result.z, result.z); - - return result; -} - -// Mixed addition: projective + affine -inline G1Projective254 g1_add_mixed_254(G1Projective254 p, G1Affine254 q) { - if (q.infinity) return p; - - if (p.z.limbs[0] == 0 && p.z.limbs[1] == 0 && - p.z.limbs[2] == 0 && p.z.limbs[3] == 0) { - return {q.x, q.y, {{1, 0, 0, 0}}}; - } - - // Using madd-2008-s formula - Fp254 zz, u2, s2; - fp254_mul(zz, p.z, p.z); - fp254_mul(u2, q.x, zz); - Fp254 zzz; - fp254_mul(zzz, zz, p.z); - fp254_mul(s2, q.y, zzz); - - // h = U2 - X1, r = S2 - Y1 - Fp254 h, r; - fp254_sub(h, u2, p.x); - fp254_sub(r, s2, p.y); - - // Check if P = Q (same point, need doubling) - // or P = -Q (result is infinity) - // Simplified: assume not edge cases - - Fp254 hh, hhh, v; - fp254_mul(hh, h, h); - fp254_mul(hhh, hh, h); - fp254_mul(v, p.x, hh); - - G1Projective254 result; - - // X3 = r^2 - HHH - 2*V - Fp254 rr; - fp254_mul(rr, r, r); - fp254_sub(result.x, rr, hhh); - Fp254 two_v; - fp254_add(two_v, v, v); - fp254_sub(result.x, result.x, two_v); - - // Y3 = r*(V - X3) - Y1*HHH - Fp254 v_minus_x3, y1_hhh; 
- fp254_sub(v_minus_x3, v, result.x); - fp254_mul(result.y, r, v_minus_x3); - fp254_mul(y1_hhh, p.y, hhh); - fp254_sub(result.y, result.y, y1_hhh); - - // Z3 = Z1 * H - fp254_mul(result.z, p.z, h); - - return result; -} - -// Full projective addition: projective + projective -// Uses add-2007-bl formula from hyperelliptic.org -inline G1Projective254 g1_add_projective_254(G1Projective254 p, G1Projective254 q) { - // Check for identity points - bool p_inf = (p.z.limbs[0] == 0 && p.z.limbs[1] == 0 && - p.z.limbs[2] == 0 && p.z.limbs[3] == 0); - bool q_inf = (q.z.limbs[0] == 0 && q.z.limbs[1] == 0 && - q.z.limbs[2] == 0 && q.z.limbs[3] == 0); - - if (p_inf) return q; - if (q_inf) return p; - - // U1 = X1*Z2^2, U2 = X2*Z1^2 - Fp254 z1z1, z2z2, u1, u2; - fp254_mul(z1z1, p.z, p.z); - fp254_mul(z2z2, q.z, q.z); - fp254_mul(u1, p.x, z2z2); - fp254_mul(u2, q.x, z1z1); - - // S1 = Y1*Z2^3, S2 = Y2*Z1^3 - Fp254 z1z1z1, z2z2z2, s1, s2; - fp254_mul(z1z1z1, z1z1, p.z); - fp254_mul(z2z2z2, z2z2, q.z); - fp254_mul(s1, p.y, z2z2z2); - fp254_mul(s2, q.y, z1z1z1); - - // H = U2 - U1, R = S2 - S1 - Fp254 h, r; - fp254_sub(h, u2, u1); - fp254_sub(r, s2, s1); - - // Check if P = Q (need doubling) or P = -Q (infinity) - bool h_zero = (h.limbs[0] == 0 && h.limbs[1] == 0 && - h.limbs[2] == 0 && h.limbs[3] == 0); - bool r_zero = (r.limbs[0] == 0 && r.limbs[1] == 0 && - r.limbs[2] == 0 && r.limbs[3] == 0); - - if (h_zero) { - if (r_zero) { - // P == Q, need doubling - return g1_double_254(p); - } else { - // P == -Q, return infinity - G1Projective254 inf; - inf.x = {{0, 0, 0, 0}}; - inf.y = {{1, 0, 0, 0}}; - inf.z = {{0, 0, 0, 0}}; - return inf; - } - } - - // I = (2*H)^2, J = H*I - Fp254 h2, i, j; - fp254_add(h2, h, h); - fp254_mul(i, h2, h2); - fp254_mul(j, h, i); - - // r2 = 2*R - Fp254 r2; - fp254_add(r2, r, r); - - // V = U1*I - Fp254 v; - fp254_mul(v, u1, i); - - // X3 = r2^2 - J - 2*V - G1Projective254 result; - Fp254 r2r2, two_v; - fp254_mul(r2r2, r2, r2); - fp254_sub(result.x, r2r2, 
j); - fp254_add(two_v, v, v); - fp254_sub(result.x, result.x, two_v); - - // Y3 = r2*(V - X3) - 2*S1*J - Fp254 v_minus_x3, s1j, two_s1j; - fp254_sub(v_minus_x3, v, result.x); - fp254_mul(result.y, r2, v_minus_x3); - fp254_mul(s1j, s1, j); - fp254_add(two_s1j, s1j, s1j); - fp254_sub(result.y, result.y, two_s1j); - - // Z3 = ((Z1+Z2)^2 - Z1Z1 - Z2Z2)*H - Fp254 z1_plus_z2, z1_plus_z2_sq, tmp; - fp254_add(z1_plus_z2, p.z, q.z); - fp254_mul(z1_plus_z2_sq, z1_plus_z2, z1_plus_z2); - fp254_sub(tmp, z1_plus_z2_sq, z1z1); - fp254_sub(tmp, tmp, z2z2); - fp254_mul(result.z, tmp, h); - - return result; -} - -// ============================================================================= -// Pippenger MSM Algorithm -// ============================================================================= - -// MSM configuration -constant uint32_t MSM_WINDOW_SIZE = 8; // c = 8 bits per window -constant uint32_t MSM_NUM_WINDOWS = 32; // 256 bits / 8 bits = 32 windows -constant uint32_t MSM_BUCKETS_PER_WINDOW = 255; // 2^c - 1 buckets (exclude 0) - -// Extract window from scalar -inline uint32_t get_scalar_window(Fp254 scalar, uint32_t window_idx) { - uint32_t bit_idx = window_idx * MSM_WINDOW_SIZE; - uint32_t limb_idx = bit_idx / 64; - uint32_t bit_offset = bit_idx % 64; - - uint64_t limb = scalar.limbs[limb_idx]; - uint32_t window = (limb >> bit_offset) & ((1 << MSM_WINDOW_SIZE) - 1); - - // Handle window crossing limb boundary - if (bit_offset > 64 - MSM_WINDOW_SIZE && limb_idx < 3) { - uint32_t remaining_bits = MSM_WINDOW_SIZE - (64 - bit_offset); - window |= (scalar.limbs[limb_idx + 1] & ((1ULL << remaining_bits) - 1)) << (64 - bit_offset); - } - - return window; -} - -// ============================================================================= -// MSM Kernels -// ============================================================================= - -// Phase 1: Bucket accumulation -// Each thread handles one (point, scalar) pair, adds to appropriate bucket -kernel void 
msm_bucket_accumulate( - device const G1Affine254* points [[buffer(0)]], - device const Fp254* scalars [[buffer(1)]], - device G1Projective254* buckets [[buffer(2)]], // [num_windows][num_buckets] - constant uint32_t& num_points [[buffer(3)]], - constant uint32_t& window_idx [[buffer(4)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_points) return; - - G1Affine254 point = points[index]; - Fp254 scalar = scalars[index]; - - uint32_t bucket_idx = get_scalar_window(scalar, window_idx); - if (bucket_idx == 0) return; // Skip zero windows - - // Atomic-style bucket update (simplified - real impl needs atomic operations) - uint32_t bucket_offset = window_idx * MSM_BUCKETS_PER_WINDOW + (bucket_idx - 1); - - // Add point to bucket - G1Projective254 bucket = buckets[bucket_offset]; - buckets[bucket_offset] = g1_add_mixed_254(bucket, point); -} - -// Phase 2: Bucket reduction -// Compute window sum from buckets: sum_{i=1}^{2^c-1} i * bucket[i] -kernel void msm_bucket_reduce( - device G1Projective254* buckets [[buffer(0)]], - device G1Projective254* window_sums [[buffer(1)]], - constant uint32_t& window_idx [[buffer(2)]], - uint tid [[thread_position_in_grid]] -) { - if (tid != 0) return; // Single thread per window (for correctness) - - uint32_t bucket_offset = window_idx * MSM_BUCKETS_PER_WINDOW; - - // Running sum method: - // sum = B[k-1], running = B[k-1] - // for i = k-2 down to 0: running += B[i], sum += running - - G1Projective254 running = buckets[bucket_offset + MSM_BUCKETS_PER_WINDOW - 1]; - G1Projective254 sum = running; - - for (int32_t i = MSM_BUCKETS_PER_WINDOW - 2; i >= 0; i--) { - G1Projective254 bucket = buckets[bucket_offset + i]; - // Add bucket to running (full projective addition) - running = g1_add_projective_254(running, bucket); - sum = g1_add_projective_254(sum, running); - } - - window_sums[window_idx] = sum; -} - -// Phase 3: Window combination -// Combine window sums: result = sum_{i=0}^{k-1} 2^{c*i} * window_sum[i] -kernel 
void msm_window_combine( - device const G1Projective254* window_sums [[buffer(0)]], - device G1Projective254* result [[buffer(1)]], - constant uint32_t& num_windows [[buffer(2)]], - uint tid [[thread_position_in_grid]] -) { - if (tid != 0) return; - - G1Projective254 acc = window_sums[num_windows - 1]; - - for (int32_t i = num_windows - 2; i >= 0; i--) { - // Double c times - for (uint32_t j = 0; j < MSM_WINDOW_SIZE; j++) { - acc = g1_double_254(acc); - } - // Add window sum - acc = g1_add_projective_254(acc, window_sums[i]); - } - - *result = acc; -} - -// ============================================================================= -// Single Scalar Multiplication (for small batches) -// Uses double-and-add -// ============================================================================= - -kernel void g1_scalar_mul( - device const G1Affine254* points [[buffer(0)]], - device const Fp254* scalars [[buffer(1)]], - device G1Projective254* results [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - G1Affine254 point = points[index]; - Fp254 scalar = scalars[index]; - - // Initialize result to point at infinity - G1Projective254 acc = {{{0,0,0,0}}, {{0,0,0,0}}, {{0,0,0,0}}}; - - // Double-and-add from MSB - bool started = false; - - for (int32_t limb = 3; limb >= 0; limb--) { - for (int32_t bit = 63; bit >= 0; bit--) { - if (started) { - acc = g1_double_254(acc); - } - - if ((scalar.limbs[limb] >> bit) & 1) { - if (started) { - acc = g1_add_mixed_254(acc, point); - } else { - acc = {point.x, point.y, {{1, 0, 0, 0}}}; - started = true; - } - } - } - } - - results[index] = acc; -} - -// ============================================================================= -// Pedersen Commitment: v*G + r*H (BN254 variant) -// ============================================================================= - -kernel void pedersen_commit_bn254( - device const Fp254* values [[buffer(0)]], - device const Fp254* blindings [[buffer(1)]], - device const G1Affine254* generators 
[[buffer(2)]], // [G, H] - device G1Projective254* commitments [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - Fp254 v = values[index]; - Fp254 r = blindings[index]; - G1Affine254 G = generators[0]; - G1Affine254 H = generators[1]; - - // Compute v*G - G1Projective254 vG = {{{0,0,0,0}}, {{0,0,0,0}}, {{0,0,0,0}}}; - bool started = false; - - for (int32_t limb = 3; limb >= 0; limb--) { - for (int32_t bit = 63; bit >= 0; bit--) { - if (started) { - vG = g1_double_254(vG); - } - if ((v.limbs[limb] >> bit) & 1) { - if (started) { - vG = g1_add_mixed_254(vG, G); - } else { - vG = {G.x, G.y, {{1, 0, 0, 0}}}; - started = true; - } - } - } - } - - // Compute r*H - G1Projective254 rH = {{{0,0,0,0}}, {{0,0,0,0}}, {{0,0,0,0}}}; - started = false; - - for (int32_t limb = 3; limb >= 0; limb--) { - for (int32_t bit = 63; bit >= 0; bit--) { - if (started) { - rH = g1_double_254(rH); - } - if ((r.limbs[limb] >> bit) & 1) { - if (started) { - rH = g1_add_mixed_254(rH, H); - } else { - rH = {H.x, H.y, {{1, 0, 0, 0}}}; - started = true; - } - } - } - } - - // Add vG + rH - commitments[index] = g1_add_projective_254(vG, rH); -} - -// ============================================================================= -// Vector Commitment: sum_i v_i * G_i -// ============================================================================= - -kernel void vector_commit( - device const Fp254* values [[buffer(0)]], - device const G1Affine254* generators [[buffer(1)]], - device G1Projective254* partial_sums [[buffer(2)]], // One per thread - constant uint32_t& num_values [[buffer(3)]], - uint index [[thread_position_in_grid]], - uint num_threads [[threads_per_grid]] -) { - // Each thread handles a subset of values - uint32_t chunk_size = (num_values + num_threads - 1) / num_threads; - uint32_t start = index * chunk_size; - uint32_t end = min(start + chunk_size, num_values); - - G1Projective254 acc = {{{0,0,0,0}}, {{0,0,0,0}}, {{0,0,0,0}}}; - bool started = false; - - for (uint32_t i = 
start; i < end; i++) { - Fp254 v = values[i]; - G1Affine254 G = generators[i]; - - // Compute v * G - G1Projective254 vG = {{{0,0,0,0}}, {{0,0,0,0}}, {{0,0,0,0}}}; - bool point_started = false; - - for (int32_t limb = 3; limb >= 0; limb--) { - for (int32_t bit = 63; bit >= 0; bit--) { - if (point_started) { - vG = g1_double_254(vG); - } - if ((v.limbs[limb] >> bit) & 1) { - if (point_started) { - vG = g1_add_mixed_254(vG, G); - } else { - vG = {G.x, G.y, {{1, 0, 0, 0}}}; - point_started = true; - } - } - } - } - - // Add to accumulator - if (point_started) { - if (started) { - acc = g1_add_projective_254(acc, vG); - } else { - acc = vG; - started = true; - } - } - } - - partial_sums[index] = acc; -} diff --git a/bls/gpu/wgsl/bls.wgsl b/bls/gpu/wgsl/bls.wgsl deleted file mode 100644 index b825302..0000000 --- a/bls/gpu/wgsl/bls.wgsl +++ /dev/null @@ -1,395 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 G1 point operations in WGSL. -// 384-bit field arithmetic (Fp) in Montgomery form for batch BLS verification. -// Uses 12 x u32 limbs (WGSL has no u64). -// Matches bls12_381.metal output byte-for-byte. 
- -@group(0) @binding(0) var sig_data: array; -@group(0) @binding(1) var results: array; -@group(0) @binding(2) var params: Params; - -struct Params { - num_items: u32, - mode: u32, // 0 = verify, 1 = aggregate -} - -// ============================================================================ -// 384-bit integer as 12 x u32 (little-endian) -// ============================================================================ - -fn u384_zero() -> array { - return array(0u,0u,0u,0u,0u,0u,0u,0u,0u,0u,0u,0u); -} - -fn u384_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 12u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -fn u384_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 11i; i >= 0; i = i - 1) { - let ui = u32(i); - if ((*a)[ui] > (*b)[ui]) { return 1; } - if ((*a)[ui] < (*b)[ui]) { return -1; } - } - return 0; -} - -fn u384_add(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < 12u; i = i + 1u) { - let s1 = (*a)[i] + c; - c = select(0u, 1u, s1 < (*a)[i]); - let s2 = s1 + (*b)[i]; - c = c + select(0u, 1u, s2 < s1); - (*r)[i] = s2; - } - return c; -} - -fn u384_sub(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var bw = 0u; - for (var i = 0u; i < 12u; i = i + 1u) { - let d1 = (*a)[i] - bw; - bw = select(0u, 1u, d1 > (*a)[i]); - let d2 = d1 - (*b)[i]; - bw = bw + select(0u, 1u, d2 > d1); - (*r)[i] = d2; - } - return bw; -} - -// BLS12-381 field modulus p (384 bits, 12 x u32 LE) -const BLS_P = array( - 0xFFFFAAABu, 0xB9FEFFFFu, 0xB153FFFFu, 0x1EABFFFEu, - 0xF6B0F624u, 0x6730D2A0u, 0xF38512BFu, 0x64774B84u, - 0x434BACD7u, 0x4B1BA7B6u, 0x397FE69Au, 0x1A0111EAu -); - -// Montgomery R^2 mod p -const BLS_R2 = array( - 0x1C341746u, 0xF4DF1F34u, 0x09D104F1u, 0x0A76E6A6u, - 0x4C95B6D5u, 0x8DE5476Cu, 0x939D83C0u, 0x67EB88A9u, - 0xB519952Du, 0x9A793E85u, 0x92CAE3AAu, 0x11988FE5u -); - -// Montgomery R mod p (1 in Montgomery form) -const BLS_R = array( - 0x0002FFCDu, 0x76090000u, 0xC40C0002u, 0xEBF4000Bu, - 0x53C758BAu, 
0x5F489857u, 0x70525745u, 0x77CE5853u, - 0xA256EC6Du, 0x5C071A97u, 0xFA80E493u, 0x15F65EC3u -); - -// -p^{-1} mod 2^32 -const BLS_P_INV: u32 = 0xFFFCFFFDu; - -// ============================================================================ -// Montgomery reduction/multiplication for 384-bit (12 x u32) -// ============================================================================ - -fn bls_mont_reduce(t: ptr>, - r: ptr>) { - var a: array; - for (var i = 0u; i < 24u; i = i + 1u) { a[i] = (*t)[i]; } - a[24] = 0u; - - for (var i = 0u; i < 12u; i = i + 1u) { - let u = a[i] * BLS_P_INV; - var carry = 0u; - for (var j = 0u; j < 12u; j = j + 1u) { - let ul = u & 0xFFFFu; let uh = u >> 16u; - let ml = BLS_P[j] & 0xFFFFu; let mh = BLS_P[j] >> 16u; - let ll = ul * ml; - let mid = ul * mh + uh * ml; - let hh = uh * mh; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, (ul*mh + uh*ml) < (ul*mh)); - - let s1 = lo + carry; hi = hi + select(0u, 1u, s1 < lo); - let s2 = a[i + j] + s1; hi = hi + select(0u, 1u, s2 < a[i + j]); - a[i + j] = s2; - carry = hi; - } - for (var j = 12u; i + j <= 24u; j = j + 1u) { - let s = a[i + j] + carry; - carry = select(0u, 1u, s < a[i + j]); - a[i + j] = s; - if (carry == 0u) { break; } - } - } - - for (var i = 0u; i < 12u; i = i + 1u) { (*r)[i] = a[i + 12u]; } - - var p = BLS_P; - if (a[24] != 0u || u384_cmp(r, &p) >= 0) { - let _ = u384_sub(r, &p, r); - } -} - -fn bls_fp_mul(a: ptr>, b: ptr>, - r: ptr>) { - var t: array; - for (var i = 0u; i < 24u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 12u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 12u; j = j + 1u) { - let al = (*a)[i] & 0xFFFFu; let ah = (*a)[i] >> 16u; - let bl = (*b)[j] & 0xFFFFu; let bh = (*b)[j] >> 16u; - let ll = al * bl; - let mid = al * bh + ah * bl; - let hh = ah * bh; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll); - let s1 = lo + carry; hi = hi + select(0u, 1u, s1 < 
lo); - let s2 = t[i + j] + s1; hi = hi + select(0u, 1u, s2 < t[i + j]); - t[i + j] = s2; - carry = hi; - } - for (var j = 12u; i + j < 24u; j = j + 1u) { - let s = t[i + j] + carry; - carry = select(0u, 1u, s < t[i + j]); - t[i + j] = s; - if (carry == 0u) { break; } - } - } - bls_mont_reduce(&t, r); -} - -fn bls_fp_sqr(a: ptr>, r: ptr>) { - bls_fp_mul(a, a, r); -} - -fn bls_fp_add(a: ptr>, b: ptr>, - r: ptr>) { - var p = BLS_P; - let c = u384_add(a, b, r); - if (c != 0u || u384_cmp(r, &p) >= 0) { - let _ = u384_sub(r, &p, r); - } -} - -fn bls_fp_sub(a: ptr>, b: ptr>, - r: ptr>) { - var p = BLS_P; - let bw = u384_sub(a, b, r); - if (bw != 0u) { - let _ = u384_add(r, &p, r); - } -} - -fn bls_fp_neg(a: ptr>, r: ptr>) { - if (u384_is_zero(a)) { *r = u384_zero(); return; } - var p = BLS_P; - let _ = u384_sub(&p, a, r); -} - -fn bls_to_mont(a: ptr>, r: ptr>) { - var r2 = BLS_R2; - bls_fp_mul(a, &r2, r); -} - -fn bls_from_mont(a: ptr>, r: ptr>) { - var t: array; - for (var i = 0u; i < 24u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 12u; i = i + 1u) { t[i] = (*a)[i]; } - bls_mont_reduce(&t, r); -} - -fn bls_fp_inv(a: ptr>, r: ptr>) { - // p-2 (LE u32 limbs) - var exp = BLS_P; - exp[0] = exp[0] - 2u; - var result = BLS_R; - var base: array; - for (var i = 0u; i < 12u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 12u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - bls_fp_mul(&result, &base, &tmp); - result = tmp; - } - var tmp2: array; - bls_fp_sqr(&base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -// ============================================================================ -// G1 point operations (Jacobian, Montgomery Fp) -// ============================================================================ - -struct G1Point { - x: array, - y: array, - z: array, -} - -fn g1_identity() -> G1Point { - var p: G1Point; - p.x = BLS_R; p.y = BLS_R; p.z = u384_zero(); - return p; 
-} - -fn g1_is_inf(p: ptr) -> bool { - var z = (*p).z; - return u384_is_zero(&z); -} - -fn g1_double(p: ptr, r: ptr) { - if (g1_is_inf(p)) { *r = *p; return; } - var A: array; bls_fp_sqr(&(*p).y, &A); - var B: array; bls_fp_mul(&(*p).x, &A, &B); - var C: array; bls_fp_sqr(&A, &C); - var S: array; bls_fp_add(&B, &B, &S); bls_fp_add(&S, &S, &S); - var X2: array; bls_fp_sqr(&(*p).x, &X2); - var X2_2: array; bls_fp_add(&X2, &X2, &X2_2); - var M: array; bls_fp_add(&X2_2, &X2, &M); - var M2: array; bls_fp_sqr(&M, &M2); - var S2: array; bls_fp_add(&S, &S, &S2); - var X3: array; bls_fp_sub(&M2, &S2, &X3); - var SX: array; bls_fp_sub(&S, &X3, &SX); - var MSX: array; bls_fp_mul(&M, &SX, &MSX); - var C2: array; bls_fp_add(&C, &C, &C2); - var C4: array; bls_fp_add(&C2, &C2, &C4); - var C8: array; bls_fp_add(&C4, &C4, &C8); - var Y3: array; bls_fp_sub(&MSX, &C8, &Y3); - var YZ: array; bls_fp_mul(&(*p).y, &(*p).z, &YZ); - var Z3: array; bls_fp_add(&YZ, &YZ, &Z3); - (*r).x = X3; (*r).y = Y3; (*r).z = Z3; -} - -fn g1_add_mixed(P: ptr, Qx: ptr>, - Qy: ptr>, r: ptr) { - if (g1_is_inf(P)) { - (*r).x = *Qx; (*r).y = *Qy; (*r).z = BLS_R; - return; - } - var Z2: array; bls_fp_sqr(&(*P).z, &Z2); - var U2: array; bls_fp_mul(Qx, &Z2, &U2); - var Z3: array; bls_fp_mul(&Z2, &(*P).z, &Z3); - var S2: array; bls_fp_mul(Qy, &Z3, &S2); - var H: array; bls_fp_sub(&U2, &(*P).x, &H); - var R: array; bls_fp_sub(&S2, &(*P).y, &R); - if (u384_is_zero(&H)) { - if (u384_is_zero(&R)) { g1_double(P, r); return; } - *r = g1_identity(); return; - } - var H2: array; bls_fp_sqr(&H, &H2); - var H3: array; bls_fp_mul(&H, &H2, &H3); - var U1H2: array; bls_fp_mul(&(*P).x, &H2, &U1H2); - var R2: array; bls_fp_sqr(&R, &R2); - var U1H2_2: array; bls_fp_add(&U1H2, &U1H2, &U1H2_2); - var t1: array; bls_fp_sub(&R2, &H3, &t1); - var X3: array; bls_fp_sub(&t1, &U1H2_2, &X3); - var UX: array; bls_fp_sub(&U1H2, &X3, &UX); - var RUX: array; bls_fp_mul(&R, &UX, &RUX); - var YH3: array; bls_fp_mul(&(*P).y, &H3, &YH3); - var 
Y3: array; bls_fp_sub(&RUX, &YH3, &Y3); - var Zr: array; bls_fp_mul(&H, &(*P).z, &Zr); - (*r).x = X3; (*r).y = Y3; (*r).z = Zr; -} - -fn g1_to_affine(p: ptr, - ax: ptr>, - ay: ptr>) { - if (g1_is_inf(p)) { *ax = u384_zero(); *ay = u384_zero(); return; } - var z_inv: array; bls_fp_inv(&(*p).z, &z_inv); - var z_inv2: array; bls_fp_sqr(&z_inv, &z_inv2); - var z_inv3: array; bls_fp_mul(&z_inv2, &z_inv, &z_inv3); - bls_fp_mul(&(*p).x, &z_inv2, ax); - bls_fp_mul(&(*p).y, &z_inv3, ay); -} - -// ============================================================================ -// BLS verify batch: decompress G1 signature, check on-curve -// ============================================================================ - -@compute @workgroup_size(256) -fn bls_verify_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.num_items) { return; } - - // Each signature is 48 bytes = 12 u32 words - let sig_base = tid * 12u; - - // Read flag byte (first byte of first word) - let first_word = sig_data[sig_base]; - let flags = first_word & 0xFFu; - let compressed = (flags & 0x80u) != 0u; - let infinity = (flags & 0x40u) != 0u; - let y_sign = (flags & 0x20u) != 0u; - - if (infinity || !compressed) { - results[tid] = 0u; - return; - } - - // Clear flag bits and deserialize x-coordinate (big-endian 48 bytes -> 12 x u32 LE) - var x_raw: array; - for (var i = 0u; i < 12u; i = i + 1u) { - var w = sig_data[sig_base + 11u - i]; - // Byte-swap - w = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - x_raw[i] = w; - } - // Clear flag bits from the most significant byte - x_raw[11] = x_raw[11] & 0x1FFFFFFFu; - - // Decompress: y^2 = x^3 + 4 - var x_mont: array; bls_to_mont(&x_raw, &x_mont); - var x2: array; bls_fp_sqr(&x_mont, &x2); - var x3: array; bls_fp_mul(&x2, &x_mont, &x3); - var four_raw = array(4u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var four_mont: array; bls_to_mont(&four_raw, 
&four_mont); - var y2: array; bls_fp_add(&x3, &four_mont, &y2); - - // sqrt(y2) = y2^((p+1)/4) since p = 3 mod 4 - var exp = array( - 0xFFFEAAAFu, 0xEE7FBFFFu, 0xAC54FFFFu, 0x07AAFFFFu, - 0x3DAC3D89u, 0xD9CC34A8u, 0x3CE144AFu, 0xD91DD2E1u, - 0x90D2EB35u, 0x92C6E9EDu, 0xE5FF9A6u, 0x0680447Au - ); - - var y_cand = BLS_R; - var base = y2; - for (var i = 0u; i < 12u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - bls_fp_mul(&y_cand, &base, &tmp); - y_cand = tmp; - } - var tmp2: array; - bls_fp_sqr(&base, &tmp2); - base = tmp2; - } - } - - // Verify: y_cand^2 == y2 - var check: array; bls_fp_sqr(&y_cand, &check); - if (u384_cmp(&check, &y2) != 0) { - results[tid] = 0u; - return; - } - - // Pick sign - var y_normal: array; bls_from_mont(&y_cand, &y_normal); - let is_positive = (y_normal[0] & 1u) == 0u; - if (is_positive == y_sign) { - bls_fp_neg(&y_cand, &y_cand); - } - - // On-curve check passed, subgroup check deferred to host - results[tid] = 3u; // bit 0: on_curve, bit 1: needs_subgroup_check -} diff --git a/bls/gpu/wgsl/bls_combined_miller.wgsl b/bls/gpu/wgsl/bls_combined_miller.wgsl deleted file mode 100644 index ee86381..0000000 --- a/bls/gpu/wgsl/bls_combined_miller.wgsl +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL peer of bls_combined_miller.metal — tree-reduce kernel for the -// k-pair Miller-loop fan-in. -// -// The Miller per-bit kernels are not yet ported to WGSL on this stage -// (Stage 4 ships Fp-tower parity only; full miller_loop remains on the -// host CPU oracle for WGSL until Stage 5b). This shader provides the -// canonical pairwise Fp12 reduction over k Miller outputs: -// -// round 0: out[i] = in[2*i] * in[2*i+1] (i < pairs) -// out[pairs] = in[2*pairs] (carry, when n is odd) -// round 1: same shape, halved n -// ... 
-// final: out[0] = prod_i in[i] -// -// Determinism: index map (in[2i], in[2i+1]) -> out[i] is canonical and -// matches tree_reduce_fp12 in cpp/bls_pairing.cpp + the Metal/CUDA peers. -// -// Concatenated by the WGSL host driver after bls_fp_ops.wgsl, bls_fp2.wgsl, -// bls_fp6.wgsl, bls_fp12.wgsl (same scheme as bls_fp_tower_kernels.wgsl). - -@group(0) @binding(0) var in_a: array; -@group(0) @binding(2) var out: array; -// params.x = pairs, params.y = carry (0 or 1), params.z = total threads -@group(0) @binding(3) var params: vec4; - -var g_a: array; -var g_b: array; -var g_r: array; - -fn load_fp12_at(slot: u32, dst: ptr>) { - let base = slot * 144u; - for (var i = 0u; i < 144u; i = i + 1u) { (*dst)[i] = in_a[base + i]; } -} -fn store_fp12_at(slot: u32, src: ptr>) { - let base = slot * 144u; - for (var i = 0u; i < 144u; i = i + 1u) { out[base + i] = (*src)[i]; } -} - -// One round of canonical pairwise tree reduction. -// -// tid < pairs : out[tid] = in[2*tid] * in[2*tid+1] -// tid == pairs (carry): out[tid] = in[2*tid] (last element passes through) -// -// Caller dispatches with threads = pairs + carry and runs ceil(log2(k)) -// rounds, swapping in_a / out between dispatches. -@compute @workgroup_size(1) -fn k_combined_miller_reduce(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - let pairs = params.x; - let carry = params.y; - - if (tid < pairs) { - load_fp12_at(2u * tid, &g_a); - load_fp12_at(2u * tid + 1u, &g_b); - fp12_mul_priv(&g_a, &g_b, &g_r); - store_fp12_at(tid, &g_r); - return; - } - if (carry != 0u && tid == pairs) { - load_fp12_at(2u * tid, &g_a); - store_fp12_at(tid, &g_a); - } -} diff --git a/bls/gpu/wgsl/bls_driver_wgpu.cpp b/bls/gpu/wgsl/bls_driver_wgpu.cpp deleted file mode 100644 index b29ec9f..0000000 --- a/bls/gpu/wgsl/bls_driver_wgpu.cpp +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WebGPU/WGSL host driver for BLS12-381 Fp-tower kernels (Stage 4 parity port). -// -// Concatenates the WGSL source fragments (Fp ops, Fp2, Fp6, Fp12, kernels) at -// runtime, compiles a single shader module, then dispatches per kernel name. -// -// Build: -// * BLS_HAS_WEBGPU=1 — Dawn or wgpu-native runtime found -// * BLS_HAS_WGPU_NATIVE=1 — wgpu-native specifically (gives wgpuDevicePoll) -// -// On Apple host with Homebrew wgpu-native, both flags are set by CMake. - -#include "bls_driver_wgpu.h" - -#if defined(BLS_HAS_WEBGPU) - -#include -#if defined(BLS_HAS_WGPU_NATIVE) -# include -#endif - -#include -#include -#include -#include -#include -#include -#include - -// WGSL sources concatenated by CMake into bls_wgsl_sources.h. -#include "bls_wgsl_sources.h" - -namespace { - -WGPUStringView sv(const char* s) { - WGPUStringView v{}; - v.data = s; - v.length = (s == nullptr) ? 0 : std::strlen(s); - return v; -} -WGPUStringView sv(const std::string& s) { - WGPUStringView v{}; - v.data = s.data(); - v.length = s.size(); - return v; -} - -void drain(WGPUInstance inst, WGPUDevice dev) { - if (inst) wgpuInstanceProcessEvents(inst); -#if defined(BLS_HAS_WGPU_NATIVE) - if (dev) wgpuDevicePoll(dev, /*wait=*/WGPU_TRUE, nullptr); -#else - (void)dev; -#endif -} - -bool wait_map(WGPUInstance inst, WGPUDevice dev, WGPUBuffer buf, - WGPUMapMode mode, size_t off, size_t size) { - struct State { std::atomic done{false}; WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error}; } s; - WGPUBufferMapCallbackInfo cb{}; - cb.mode = WGPUCallbackMode_AllowProcessEvents; - cb.callback = [](WGPUMapAsyncStatus st, WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - p->status = st; - p->done.store(true, std::memory_order_release); - }; - cb.userdata1 = &s; - wgpuBufferMapAsync(buf, mode, off, size, cb); - for (int spin = 0; spin < 4096; spin++) { - if (s.done.load(std::memory_order_acquire)) break; - drain(inst, dev); - } - return 
s.done.load() && s.status == WGPUMapAsyncStatus_Success; -} - -struct Engine { - WGPUInstance instance{nullptr}; - WGPUAdapter adapter{nullptr}; - WGPUDevice device{nullptr}; - WGPUQueue queue{nullptr}; - WGPUShaderModule module{nullptr}; - bool initialized{false}; -}; - -Engine& engine() { static Engine e; return e; } - -bool init_engine() { - Engine& e = engine(); - if (e.initialized) return true; - - WGPUInstanceDescriptor idesc{}; - e.instance = wgpuCreateInstance(&idesc); - if (!e.instance) return false; - - struct AS { std::atomic done{false}; WGPUAdapter ad{nullptr}; } as; - WGPURequestAdapterOptions ropt{}; - ropt.powerPreference = WGPUPowerPreference_HighPerformance; - WGPURequestAdapterCallbackInfo rcb{}; - rcb.mode = WGPUCallbackMode_AllowProcessEvents; - rcb.callback = [](WGPURequestAdapterStatus st, WGPUAdapter ad, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestAdapterStatus_Success) p->ad = ad; - p->done.store(true, std::memory_order_release); - }; - rcb.userdata1 = &as; - wgpuInstanceRequestAdapter(e.instance, &ropt, rcb); - for (int spin = 0; spin < 4096; spin++) { - if (as.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!as.ad) { fprintf(stderr, "wgpu: no adapter\n"); return false; } - e.adapter = as.ad; - - struct DS { std::atomic done{false}; WGPUDevice dev{nullptr}; } ds; - WGPUDeviceDescriptor ddesc{}; - WGPURequestDeviceCallbackInfo dcb{}; - dcb.mode = WGPUCallbackMode_AllowProcessEvents; - dcb.callback = [](WGPURequestDeviceStatus st, WGPUDevice dev, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestDeviceStatus_Success) p->dev = dev; - p->done.store(true, std::memory_order_release); - }; - dcb.userdata1 = &ds; - wgpuAdapterRequestDevice(e.adapter, &ddesc, dcb); - for (int spin = 0; spin < 4096; spin++) { - if (ds.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!ds.dev) { 
fprintf(stderr, "wgpu: no device\n"); return false; } - e.device = ds.dev; - e.queue = wgpuDeviceGetQueue(e.device); - if (!e.queue) return false; - - // Concatenate WGSL sources and compile a single module. - std::string src; - src.append(kBLS_WGSL_FpOps); - src.append(kBLS_WGSL_Fp2); - src.append(kBLS_WGSL_Fp6); - src.append(kBLS_WGSL_Fp12); - src.append(kBLS_WGSL_Kernels); - - WGPUShaderSourceWGSL wgsl{}; - wgsl.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl.code = sv(src); - - WGPUShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl.chain; - smd.label = sv("bls_fp_tower"); - e.module = wgpuDeviceCreateShaderModule(e.device, &smd); - if (!e.module) { fprintf(stderr, "wgpu: shader compile failed\n"); return false; } - - e.initialized = true; - return true; -} - -WGPUBuffer make_buf(Engine& e, size_t size, WGPUBufferUsage usage) { - WGPUBufferDescriptor bd{}; - bd.size = (size + 3) & ~size_t(3); - bd.usage = usage; - return wgpuDeviceCreateBuffer(e.device, &bd); -} - -bool dispatch(const char* entry, const void* a_data, const void* b_data, - void* out_data, size_t a_bytes, size_t b_bytes, size_t out_bytes, - uint32_t count) { - Engine& e = engine(); - if (!init_engine()) return false; - - // Buffers - WGPUBuffer bufA = make_buf(e, a_bytes ? a_bytes : 4, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer bufB = make_buf(e, b_bytes ? 
b_bytes : 4, WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer bufO = make_buf(e, out_bytes, WGPUBufferUsage_Storage | WGPUBufferUsage_CopySrc); - WGPUBuffer bufU = make_buf(e, 16, WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst); - WGPUBuffer bufR = make_buf(e, out_bytes, WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst); - if (!bufA || !bufB || !bufO || !bufU || !bufR) return false; - - if (a_bytes) wgpuQueueWriteBuffer(e.queue, bufA, 0, a_data, a_bytes); - if (b_bytes) wgpuQueueWriteBuffer(e.queue, bufB, 0, b_data, b_bytes); - uint32_t params[4] = { count, 0, 0, 0 }; - wgpuQueueWriteBuffer(e.queue, bufU, 0, params, 16); - - // Pipeline - WGPUComputePipelineDescriptor cpd{}; - cpd.compute.module = e.module; - cpd.compute.entryPoint = sv(entry); - cpd.label = sv(entry); - WGPUComputePipeline pso = wgpuDeviceCreateComputePipeline(e.device, &cpd); - if (!pso) { fprintf(stderr, "wgpu: pipeline %s failed\n", entry); return false; } - - // Bind group (auto-derive layout) - WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pso, 0); - WGPUBindGroupEntry bge[4] = {}; - bge[0].binding = 0; bge[0].buffer = bufA; bge[0].size = a_bytes ? a_bytes : 4; - bge[1].binding = 1; bge[1].buffer = bufB; bge[1].size = b_bytes ? 
b_bytes : 4; - bge[2].binding = 2; bge[2].buffer = bufO; bge[2].size = out_bytes; - bge[3].binding = 3; bge[3].buffer = bufU; bge[3].size = 16; - WGPUBindGroupDescriptor bgd{}; - bgd.layout = bgl; - bgd.entryCount = 4; - bgd.entries = bge; - WGPUBindGroup bg = wgpuDeviceCreateBindGroup(e.device, &bgd); - if (!bg) return false; - - // Encode + dispatch - WGPUCommandEncoderDescriptor ced{}; - WGPUCommandEncoder ce = wgpuDeviceCreateCommandEncoder(e.device, &ced); - WGPUComputePassDescriptor cpd2{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpd2); - wgpuComputePassEncoderSetPipeline(cpe, pso); - wgpuComputePassEncoderSetBindGroup(cpe, 0, bg, 0, nullptr); - wgpuComputePassEncoderDispatchWorkgroups(cpe, count, 1, 1); - wgpuComputePassEncoderEnd(cpe); - - wgpuCommandEncoderCopyBufferToBuffer(ce, bufO, 0, bufR, 0, out_bytes); - WGPUCommandBufferDescriptor cbd{}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(ce, &cbd); - wgpuQueueSubmit(e.queue, 1, &cmd); - - if (!wait_map(e.instance, e.device, bufR, WGPUMapMode_Read, 0, out_bytes)) { - fprintf(stderr, "wgpu: map readback failed\n"); - return false; - } - const void* mapped = wgpuBufferGetConstMappedRange(bufR, 0, out_bytes); - std::memcpy(out_data, mapped, out_bytes); - wgpuBufferUnmap(bufR); - - wgpuComputePassEncoderRelease(cpe); - wgpuCommandEncoderRelease(ce); - wgpuCommandBufferRelease(cmd); - wgpuBindGroupRelease(bg); - wgpuBindGroupLayoutRelease(bgl); - wgpuComputePipelineRelease(pso); - wgpuBufferRelease(bufA); - wgpuBufferRelease(bufB); - wgpuBufferRelease(bufO); - wgpuBufferRelease(bufU); - wgpuBufferRelease(bufR); - return true; -} - -} // namespace - -extern "C" { - -int bls_wgpu_available(void) { - return init_engine() ? 1 : 0; -} - -int bls_wgpu_run_binary(const char* entry, const void* a, const void* b, - void* out, unsigned elem_bytes, unsigned count) { - size_t bytes = size_t(elem_bytes) * count; - return dispatch(entry, a, b, out, bytes, bytes, bytes, count) ? 
0 : -1; -} - -int bls_wgpu_run_unary(const char* entry, const void* a, void* out, - unsigned elem_bytes, unsigned count) { - size_t bytes = size_t(elem_bytes) * count; - return dispatch(entry, a, nullptr, out, bytes, 0, bytes, count) ? 0 : -1; -} - -} // extern "C" - -#else // BLS_HAS_WEBGPU not defined: stub mode - -extern "C" { -int bls_wgpu_available(void) { return 0; } -int bls_wgpu_run_binary(const char*, const void*, const void*, void*, unsigned, unsigned) { return -1; } -int bls_wgpu_run_unary(const char*, const void*, void*, unsigned, unsigned) { return -1; } -} - -#endif diff --git a/bls/gpu/wgsl/bls_driver_wgpu.h b/bls/gpu/wgsl/bls_driver_wgpu.h deleted file mode 100644 index 5a5885d..0000000 --- a/bls/gpu/wgsl/bls_driver_wgpu.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the WebGPU/WGSL driver. On hosts without a wgpu runtime, -// bls_wgpu_available() returns 0 and run_* return -1. - -#pragma once -#ifdef __cplusplus -extern "C" { -#endif - -int bls_wgpu_available(void); - -// Generic dispatch helpers. Buffer sizes = elem_bytes * count. -int bls_wgpu_run_binary(const char* entry, const void* a, const void* b, - void* out, unsigned elem_bytes, unsigned count); -int bls_wgpu_run_unary(const char* entry, const void* a, void* out, - unsigned elem_bytes, unsigned count); - -#ifdef __cplusplus -} -#endif diff --git a/bls/gpu/wgsl/bls_fp12.wgsl b/bls/gpu/wgsl/bls_fp12.wgsl deleted file mode 100644 index 5179f2a..0000000 --- a/bls/gpu/wgsl/bls_fp12.wgsl +++ /dev/null @@ -1,718 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL peer of bls_fp12.metal. Fp12 = 144 x u32 = 2 * Fp6 (576 bytes byte-equal blst_fp12). -// -// Out-pointer form for every function. 
Returning array by value -// materialises a stack copy at every call site and the karatsuba/inversion -// call tree (fp6_inv inside fp12_inv etc.) blows AGXMetalG13X's function-call -// stack budget. Out-pointer form keeps every intermediate in a single named -// slot the caller already owns. Same arithmetic as Metal. - -fn fp12_get_c0(a: ptr>, out: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[i] = (*a)[i]; } -} -fn fp12_get_c1(a: ptr>, out: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[i] = (*a)[72u + i]; } -} -fn fp12_set_c0(out: ptr>, v: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[i] = (*v)[i]; } -} -fn fp12_set_c1(out: ptr>, v: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[72u + i] = (*v)[i]; } -} - -// ---------- Out-pointer Fp12 primitives ---------- - -fn fp12_add_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; fp12_get_c0(a, &a0); - var a1: array; fp12_get_c1(a, &a1); - var b0: array; fp12_get_c0(b, &b0); - var b1: array; fp12_get_c1(b, &b1); - var r: array; - fp6_add_p(&a0, &b0, &r); fp12_set_c0(out, &r); - fp6_add_p(&a1, &b1, &r); fp12_set_c1(out, &r); -} -fn fp12_sub_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; fp12_get_c0(a, &a0); - var a1: array; fp12_get_c1(a, &a1); - var b0: array; fp12_get_c0(b, &b0); - var b1: array; fp12_get_c1(b, &b1); - var r: array; - fp6_sub_p(&a0, &b0, &r); fp12_set_c0(out, &r); - fp6_sub_p(&a1, &b1, &r); fp12_set_c1(out, &r); -} -fn fp12_conj_p(a: ptr>, out: ptr>) { - var c0: array; fp12_get_c0(a, &c0); - var c1: array; fp12_get_c1(a, &c1); - var r1: array; fp6_neg_p(&c1, &r1); - fp12_set_c0(out, &c0); - fp12_set_c1(out, &r1); -} - -fn fp12_mul_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; fp12_get_c0(a, &a0); - var a1: array; fp12_get_c1(a, &a1); - var b0: array; fp12_get_c0(b, &b0); - var b1: array; fp12_get_c1(b, &b1); - - var t0: array; fp6_mul_p(&a0, &b0, &t0); - var t1: array; fp6_mul_p(&a1, &b1, &t1); - - var sa: array; fp6_add_p(&a0, &a1, &sa); - var sb: 
array; fp6_add_p(&b0, &b1, &sb); - var prod: array; fp6_mul_p(&sa, &sb, &prod); - - var r1_v: array; - fp6_sub_p(&prod, &t0, &r1_v); - fp6_sub_p(&r1_v, &t1, &r1_v); - - var t1v: array; fp6_mul_by_v_p(&t1, &t1v); - var r0: array; fp6_add_p(&t0, &t1v, &r0); - - fp12_set_c0(out, &r0); - fp12_set_c1(out, &r1_v); -} - -fn fp12_sqr_p(a: ptr>, out: ptr>) { - var a0: array; fp12_get_c0(a, &a0); - var a1: array; fp12_get_c1(a, &a1); - - var s: array; fp6_add_p(&a0, &a1, &s); - var t1v: array; fp6_mul_by_v_p(&a1, &t1v); - var t1: array; fp6_add_p(&a0, &t1v, &t1); - var t0: array; fp6_mul_p(&s, &t1, &t0); - - var t2: array; fp6_mul_p(&a0, &a1, &t2); - var r1: array; fp6_add_p(&t2, &t2, &r1); - var r0: array; - fp6_sub_p(&t0, &t2, &r0); - var t2v: array; fp6_mul_by_v_p(&t2, &t2v); - fp6_sub_p(&r0, &t2v, &r0); - - fp12_set_c0(out, &r0); - fp12_set_c1(out, &r1); -} - -fn fp12_inv_p(a: ptr>, out: ptr>) { - var a0: array; fp12_get_c0(a, &a0); - var a1: array; fp12_get_c1(a, &a1); - - var t0: array; fp6_sqr_p(&a0, &t0); - var t1: array; fp6_sqr_p(&a1, &t1); - var t1v: array; fp6_mul_by_v_p(&t1, &t1v); - var diff: array; fp6_sub_p(&t0, &t1v, &diff); - var ti: array; fp6_inv_p(&diff, &ti); - - var r0: array; fp6_mul_p(&a0, &ti, &r0); - var r1_pos: array; fp6_mul_p(&a1, &ti, &r1_pos); - var r1: array; fp6_neg_p(&r1_pos, &r1); - - fp12_set_c0(out, &r0); - fp12_set_c1(out, &r1); -} - -// Cyclotomic squaring in Fp12. Mirrors blst's cyclotomic_sqr_fp12 + sqr_fp4. -// -// sqr_fp4 produces (r0, r1) from (a0, a1) Fp2 pair; we write r0 / r1 through -// out-pointers so we don't materialise a 192-byte struct return. 
-fn sqr_fp4_p(a0: ptr>, - a1: ptr>, - r0: ptr>, - r1: ptr>) { - var t0: array; fp2_sqr_p(a0, &t0); - var t1: array; fp2_sqr_p(a1, &t1); - var sum: array; fp2_add_p(a0, a1, &sum); - - var t1v: array; fp2_mul_by_1_plus_u_p(&t1, &t1v); - fp2_add_p(&t1v, &t0, r0); - var sum_sq: array; fp2_sqr_p(&sum, &sum_sq); - fp2_sub_p(&sum_sq, &t0, r1); - fp2_sub_p(r1, &t1, r1); -} - -fn fp12_cyclotomic_sqr_p(a: ptr>, - out: ptr>) { - var a_c0: array; fp12_get_c0(a, &a_c0); - var a_c1: array; fp12_get_c1(a, &a_c1); - var a00: array; fp6_get_c0(&a_c0, &a00); - var a01: array; fp6_get_c1(&a_c0, &a01); - var a02: array; fp6_get_c2(&a_c0, &a02); - var a10: array; fp6_get_c0(&a_c1, &a10); - var a11: array; fp6_get_c1(&a_c1, &a11); - var a12: array; fp6_get_c2(&a_c1, &a12); - - var ta_r0: array; var ta_r1: array; - sqr_fp4_p(&a00, &a11, &ta_r0, &ta_r1); - var tb_r0: array; var tb_r1: array; - sqr_fp4_p(&a10, &a02, &tb_r0, &tb_r1); - var tc_r0: array; var tc_r1: array; - sqr_fp4_p(&a01, &a12, &tc_r0, &tc_r1); - - // r.c0.c0 = 3 t00 - 2 a00 - var r00: array; - var tmp: array; - fp2_sub_p(&ta_r0, &a00, &tmp); - fp2_add_p(&tmp, &tmp, &tmp); - fp2_add_p(&tmp, &ta_r0, &r00); - - // r.c0.c1 = 3 t10 - 2 a01 - var r01: array; - fp2_sub_p(&tb_r0, &a01, &tmp); - fp2_add_p(&tmp, &tmp, &tmp); - fp2_add_p(&tmp, &tb_r0, &r01); - - // r.c0.c2 = 3 t20 - 2 a02 - var r02: array; - fp2_sub_p(&tc_r0, &a02, &tmp); - fp2_add_p(&tmp, &tmp, &tmp); - fp2_add_p(&tmp, &tc_r0, &r02); - - // r.c1.c0 = 3 (t21 * (u+1)) + 2 a10 - var r10: array; - var tcr1v: array; - fp2_mul_by_1_plus_u_p(&tc_r1, &tcr1v); - fp2_add_p(&tcr1v, &a10, &tmp); - fp2_add_p(&tmp, &tmp, &tmp); - fp2_add_p(&tmp, &tcr1v, &r10); - - // r.c1.c1 = 3 t01 + 2 a11 - var r11: array; - fp2_add_p(&ta_r1, &a11, &tmp); - fp2_add_p(&tmp, &tmp, &tmp); - fp2_add_p(&tmp, &ta_r1, &r11); - - // r.c1.c2 = 3 t11 + 2 a12 - var r12: array; - fp2_add_p(&tb_r1, &a12, &tmp); - fp2_add_p(&tmp, &tmp, &tmp); - fp2_add_p(&tmp, &tb_r1, &r12); - - var c0: array; - 
fp6_set_c0(&c0, &r00); - fp6_set_c1(&c0, &r01); - fp6_set_c2(&c0, &r02); - var c1: array; - fp6_set_c0(&c1, &r10); - fp6_set_c1(&c1, &r11); - fp6_set_c2(&c1, &r12); - fp12_set_c0(out, &c0); - fp12_set_c1(out, &c1); -} - -// Frobenius for Fp12. -fn frob12_n1(out: ptr>) { - (*out)[0u] = 0xB319D465u; (*out)[1u] = 0x07089552u; - (*out)[2u] = 0xB50A8313u; (*out)[3u] = 0xC6695F92u; - (*out)[4u] = 0xD117228Fu; (*out)[5u] = 0x97E83CCCu; - (*out)[6u] = 0xB2DC29EEu; (*out)[7u] = 0xA35BAECAu; - (*out)[8u] = 0x5DAACE4Du; (*out)[9u] = 0x1CE393EAu; - (*out)[10u] = 0xB0FB66EBu; (*out)[11u] = 0x08F2220Fu; - (*out)[12u + 0u] = 0x4CE5D646u; (*out)[12u + 1u] = 0xB2F66AADu; - (*out)[12u + 2u] = 0xFC497CECu; (*out)[12u + 3u] = 0x5842A06Bu; - (*out)[12u + 4u] = 0x2599D394u; (*out)[12u + 5u] = 0xCF4895D4u; - (*out)[12u + 6u] = 0x40A8E8D0u; (*out)[12u + 7u] = 0xC11B9CBAu; - (*out)[12u + 8u] = 0xE5A0DE89u; (*out)[12u + 9u] = 0x2E3813CBu; - (*out)[12u + 10u] = 0x88847FAFu; (*out)[12u + 11u] = 0x110EEFDAu; -} -fn frob12_n2(out: ptr>) { - (*out)[0u] = 0x798DBA3Au; (*out)[1u] = 0xECFB361Bu; - (*out)[2u] = 0x91865A2Cu; (*out)[3u] = 0xC100DDB8u; - (*out)[4u] = 0x232BDA8Eu; (*out)[5u] = 0x0EC08FF1u; - (*out)[6u] = 0xF1CA4721u; (*out)[7u] = 0xD5C13CC6u; - (*out)[8u] = 0xBF7B5C04u; (*out)[9u] = 0x47222A47u; - (*out)[10u] = 0xE51C5F59u; (*out)[11u] = 0x0110F184u; - for (var i = 12u; i < 24u; i = i + 1u) { (*out)[i] = 0u; } -} -fn frob12_n3(out: ptr>) { - (*out)[0u] = 0xA55C9AD1u; (*out)[1u] = 0x3E2F585Du; - (*out)[2u] = 0x86C18183u; (*out)[3u] = 0x4294213Du; - (*out)[4u] = 0x8B623732u; (*out)[5u] = 0x382844C8u; - (*out)[6u] = 0x19103E18u; (*out)[7u] = 0x92AD2AFDu; - (*out)[8u] = 0xAC7CF0B9u; (*out)[9u] = 0x1D794E4Fu; - (*out)[10u] = 0x7D825EC8u; (*out)[11u] = 0x0BD592FCu; - (*out)[12u + 0u] = 0x5AA30FDAu; (*out)[12u + 1u] = 0x7BCFA7A2u; - (*out)[12u + 2u] = 0x2A927E7Cu; (*out)[12u + 3u] = 0xDC17DEC1u; - (*out)[12u + 4u] = 0x6B4EBEF1u; (*out)[12u + 5u] = 0x2F088DD8u; - (*out)[12u + 6u] = 0xDA74D4A7u; 
(*out)[12u + 7u] = 0xD1CA2087u; - (*out)[12u + 8u] = 0x96CEBC1Du; (*out)[12u + 9u] = 0x2DA25966u; - (*out)[12u + 10u] = 0xBBFD87D2u; (*out)[12u + 11u] = 0x0E2B7EEDu; -} - -fn fp12_frobenius_p(a: ptr>, n: u32, - out: ptr>) { - var a0: array; fp12_get_c0(a, &a0); - var a1: array; fp12_get_c1(a, &a1); - var r0: array; fp6_frobenius_p(&a0, n, &r0); - var r1: array; fp6_frobenius_p(&a1, n, &r1); - - var coeff: array; - if (n == 1u) { frob12_n1(&coeff); } - else if (n == 2u) { frob12_n2(&coeff); } - else { frob12_n3(&coeff); } - - var r1_c0: array; fp6_get_c0(&r1, &r1_c0); - var r1_c1: array; fp6_get_c1(&r1, &r1_c1); - var r1_c2: array; fp6_get_c2(&r1, &r1_c2); - var r1_c0_new: array; fp2_mul_p(&r1_c0, &coeff, &r1_c0_new); - var r1_c1_new: array; fp2_mul_p(&r1_c1, &coeff, &r1_c1_new); - var r1_c2_new: array; fp2_mul_p(&r1_c2, &coeff, &r1_c2_new); - var r1_new: array; - fp6_set_c0(&r1_new, &r1_c0_new); - fp6_set_c1(&r1_new, &r1_c1_new); - fp6_set_c2(&r1_new, &r1_c2_new); - - fp12_set_c0(out, &r0); - fp12_set_c1(out, &r1_new); -} - -// ---------- Private-storage scratches for upper-tower kernels ---------- -// -// The full Fp12 call tree blows AGXMetalG13X's function-call stack budget -// when every intermediate sits on the function stack. Move all the multi- -// limb scratches into private storage; the function stack only ever sees -// one Fp2 (96 B). Workgroup size is 1, so each invocation owns these. 
- -// Fp12-level scratches (6 x 144 u32 = 3.4 KB private) -var sp_a: array; -var sp_b: array; -var sp_x12_0: array; -var sp_x12_1: array; - -// Fp6-level scratches (8 x 72 u32 = 2.3 KB private) -var sp_a0: array; -var sp_a1: array; -var sp_b0: array; -var sp_b1: array; -var sp_t0: array; -var sp_t1: array; -var sp_t2: array; -var sp_t3: array; -var sp_r0: array; -var sp_r1: array; - -// Fp2-level scratches for fp6_mul/sqr/inv internals -var sp_a00: array; -var sp_a01: array; -var sp_a02: array; -var sp_b00: array; -var sp_b01: array; -var sp_b02: array; -var sp_u0: array; -var sp_u1: array; -var sp_u2: array; -var sp_u3: array; -var sp_u4: array; -var sp_u5: array; -var sp_u6: array; -var sp_v0: array; -var sp_v1: array; -var sp_v2: array; - -fn fp12_priv_get_c0(a: ptr>, - out: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[i] = (*a)[i]; } -} -fn fp12_priv_get_c1(a: ptr>, - out: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[i] = (*a)[72u + i]; } -} -fn fp12_priv_set_c0(out: ptr>, - v: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[i] = (*v)[i]; } -} -fn fp12_priv_set_c1(out: ptr>, - v: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { (*out)[72u + i] = (*v)[i]; } -} -fn fp6_priv_get_c0(a: ptr>, - out: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[i]; } -} -fn fp6_priv_get_c1(a: ptr>, - out: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[24u + i]; } -} -fn fp6_priv_get_c2(a: ptr>, - out: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[48u + i]; } -} -fn fp6_priv_set_c0(out: ptr>, - v: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*v)[i]; } -} -fn fp6_priv_set_c1(out: ptr>, - v: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[24u + i] = (*v)[i]; } -} -fn fp6_priv_set_c2(out: ptr>, - v: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[48u + i] = (*v)[i]; } -} - -// ---------- Fp2 ops on private operands ---------- -// -// Decompose into raw fp_* 
leaves directly. Each Fp2 op only ever has at most -// six 12 x u32 stack temporaries (a0, a1, b0, b1, r0, r1). Returning an -// array from fp_mul/add/sub/neg/sqr/inv lands in a stack slot, and -// AGXMetalG13X tolerates that in the deep call tree because the private -// scratches above absorb the working state. - -fn run_fp2_add_pp(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; - var a1: array; - var b0: array; - var b1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - b0[i] = (*b)[i]; b1[i] = (*b)[12u + i]; - } - let r0 = fp_add(a0, b0); - let r1 = fp_add(a1, b1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn run_fp2_sub_pp(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; - var a1: array; - var b0: array; - var b1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - b0[i] = (*b)[i]; b1[i] = (*b)[12u + i]; - } - let r0 = fp_sub(a0, b0); - let r1 = fp_sub(a1, b1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn run_fp2_neg_pp(a: ptr>, - out: ptr>) { - var a0: array; - var a1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - } - let r0 = fp_neg(a0); - let r1 = fp_neg(a1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn run_fp2_mul_pp(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; - var a1: array; - var b0: array; - var b1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - b0[i] = (*b)[i]; b1[i] = (*b)[12u + i]; - } - let aa = fp_mul(a0, b0); - let bb = fp_mul(a1, b1); - let sa = fp_add(a0, a1); - let sb = fp_add(b0, b1); - let cross = fp_mul(sa, sb); - let r0 = fp_sub(aa, bb); - let r1 = fp_sub(fp_sub(cross, aa), bb); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn run_fp2_sqr_pp(a: ptr>, - out: ptr>) { 
- var a0: array; - var a1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - } - let ab = fp_mul(a0, a1); - let sum = fp_add(a0, a1); - let dif = fp_sub(a0, a1); - let r0 = fp_mul(sum, dif); - let r1 = fp_add(ab, ab); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn run_fp2_inv_pp(a: ptr>, - out: ptr>) { - var a0: array; - var a1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - } - let t0 = fp_sqr(a0); - let t1 = fp_sqr(a1); - let norm = fp_add(t0, t1); - let ni = fp_inv(norm); - let r0 = fp_mul(a0, ni); - let r1 = fp_neg(fp_mul(a1, ni)); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn run_fp2_mul_by_1_plus_u_pp(a: ptr>, - out: ptr>) { - var a0: array; - var a1: array; - for (var i = 0u; i < 12u; i = i + 1u) { - a0[i] = (*a)[i]; a1[i] = (*a)[12u + i]; - } - let r0 = fp_sub(a0, a1); - let r1 = fp_add(a0, a1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} - -// ---------- Fp6 ops on private operands (the heavy ones use private scratches) ---------- - -fn run_fp6_add_priv(a: ptr>, b: ptr>, - out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - fp6_priv_get_c0(b, &sp_b00); fp6_priv_get_c1(b, &sp_b01); fp6_priv_get_c2(b, &sp_b02); - run_fp2_add_pp(&sp_a00, &sp_b00, &sp_v0); - run_fp2_add_pp(&sp_a01, &sp_b01, &sp_v1); - run_fp2_add_pp(&sp_a02, &sp_b02, &sp_v2); - fp6_priv_set_c0(out, &sp_v0); - fp6_priv_set_c1(out, &sp_v1); - fp6_priv_set_c2(out, &sp_v2); -} -fn run_fp6_sub_priv(a: ptr>, b: ptr>, - out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - fp6_priv_get_c0(b, &sp_b00); fp6_priv_get_c1(b, &sp_b01); fp6_priv_get_c2(b, &sp_b02); - run_fp2_sub_pp(&sp_a00, &sp_b00, &sp_v0); - run_fp2_sub_pp(&sp_a01, &sp_b01, &sp_v1); - run_fp2_sub_pp(&sp_a02, 
&sp_b02, &sp_v2); - fp6_priv_set_c0(out, &sp_v0); - fp6_priv_set_c1(out, &sp_v1); - fp6_priv_set_c2(out, &sp_v2); -} -fn run_fp6_neg_priv(a: ptr>, - out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - run_fp2_neg_pp(&sp_a00, &sp_v0); - run_fp2_neg_pp(&sp_a01, &sp_v1); - run_fp2_neg_pp(&sp_a02, &sp_v2); - fp6_priv_set_c0(out, &sp_v0); - fp6_priv_set_c1(out, &sp_v1); - fp6_priv_set_c2(out, &sp_v2); -} -fn run_fp6_mul_priv(a: ptr>, b: ptr>, - out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - fp6_priv_get_c0(b, &sp_b00); fp6_priv_get_c1(b, &sp_b01); fp6_priv_get_c2(b, &sp_b02); - - // u0 = a0*b0, u1 = a1*b1, u2 = a2*b2 - run_fp2_mul_pp(&sp_a00, &sp_b00, &sp_u0); - run_fp2_mul_pp(&sp_a01, &sp_b01, &sp_u1); - run_fp2_mul_pp(&sp_a02, &sp_b02, &sp_u2); - - // r0 = ((a1+a2)(b1+b2) - u1 - u2) * (u+1) + u0 - run_fp2_add_pp(&sp_a01, &sp_a02, &sp_u3); - run_fp2_add_pp(&sp_b01, &sp_b02, &sp_u4); - run_fp2_mul_pp(&sp_u3, &sp_u4, &sp_u5); - run_fp2_sub_pp(&sp_u5, &sp_u1, &sp_u5); - run_fp2_sub_pp(&sp_u5, &sp_u2, &sp_u5); - run_fp2_mul_by_1_plus_u_pp(&sp_u5, &sp_u5); - run_fp2_add_pp(&sp_u5, &sp_u0, &sp_v0); - - // r1 = (a0+a1)(b0+b1) - u0 - u1 + u2*(u+1) - run_fp2_add_pp(&sp_a00, &sp_a01, &sp_u3); - run_fp2_add_pp(&sp_b00, &sp_b01, &sp_u4); - run_fp2_mul_pp(&sp_u3, &sp_u4, &sp_u5); - run_fp2_sub_pp(&sp_u5, &sp_u0, &sp_u5); - run_fp2_sub_pp(&sp_u5, &sp_u1, &sp_u5); - run_fp2_mul_by_1_plus_u_pp(&sp_u2, &sp_u6); - run_fp2_add_pp(&sp_u5, &sp_u6, &sp_v1); - - // r2 = (a0+a2)(b0+b2) - u0 - u2 + u1 - run_fp2_add_pp(&sp_a00, &sp_a02, &sp_u3); - run_fp2_add_pp(&sp_b00, &sp_b02, &sp_u4); - run_fp2_mul_pp(&sp_u3, &sp_u4, &sp_u5); - run_fp2_sub_pp(&sp_u5, &sp_u0, &sp_u5); - run_fp2_sub_pp(&sp_u5, &sp_u2, &sp_u5); - run_fp2_add_pp(&sp_u5, &sp_u1, &sp_v2); - - fp6_priv_set_c0(out, &sp_v0); - fp6_priv_set_c1(out, &sp_v1); - fp6_priv_set_c2(out, &sp_v2); -} -fn run_fp6_sqr_priv(a: ptr>, - 
out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - - // u0 = a0^2, u2 = a2^2 - run_fp2_sqr_pp(&sp_a00, &sp_u0); - run_fp2_sqr_pp(&sp_a02, &sp_u2); - // u1 = 2*a0*a1, u3 = 2*a1*a2 - run_fp2_mul_pp(&sp_a00, &sp_a01, &sp_u1); - run_fp2_add_pp(&sp_u1, &sp_u1, &sp_u1); - run_fp2_mul_pp(&sp_a01, &sp_a02, &sp_u3); - run_fp2_add_pp(&sp_u3, &sp_u3, &sp_u3); - - // r2 = (a0+a1+a2)^2 - u0 - u2 - u1 - u3 - run_fp2_add_pp(&sp_a00, &sp_a01, &sp_u4); - run_fp2_add_pp(&sp_u4, &sp_a02, &sp_u4); - run_fp2_sqr_pp(&sp_u4, &sp_v2); - run_fp2_sub_pp(&sp_v2, &sp_u0, &sp_v2); - run_fp2_sub_pp(&sp_v2, &sp_u2, &sp_v2); - run_fp2_sub_pp(&sp_v2, &sp_u1, &sp_v2); - run_fp2_sub_pp(&sp_v2, &sp_u3, &sp_v2); - - // r0 = u3*(u+1) + u0 - run_fp2_mul_by_1_plus_u_pp(&sp_u3, &sp_v0); - run_fp2_add_pp(&sp_v0, &sp_u0, &sp_v0); - - // r1 = u2*(u+1) + u1 - run_fp2_mul_by_1_plus_u_pp(&sp_u2, &sp_v1); - run_fp2_add_pp(&sp_v1, &sp_u1, &sp_v1); - - fp6_priv_set_c0(out, &sp_v0); - fp6_priv_set_c1(out, &sp_v1); - fp6_priv_set_c2(out, &sp_v2); -} -fn run_fp6_inv_priv(a: ptr>, - out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - - // c0 = a0^2 - mul_v(a1*a2) (-> sp_v0) - run_fp2_sqr_pp(&sp_a00, &sp_v0); - run_fp2_mul_pp(&sp_a01, &sp_a02, &sp_u0); - run_fp2_mul_by_1_plus_u_pp(&sp_u0, &sp_u0); - run_fp2_sub_pp(&sp_v0, &sp_u0, &sp_v0); - - // c1 = mul_v(a2^2) - a0*a1 (-> sp_v1) - run_fp2_sqr_pp(&sp_a02, &sp_v1); - run_fp2_mul_by_1_plus_u_pp(&sp_v1, &sp_v1); - run_fp2_mul_pp(&sp_a00, &sp_a01, &sp_u0); - run_fp2_sub_pp(&sp_v1, &sp_u0, &sp_v1); - - // c2 = a1^2 - a0*a2 (-> sp_v2) - run_fp2_sqr_pp(&sp_a01, &sp_v2); - run_fp2_mul_pp(&sp_a00, &sp_a02, &sp_u0); - run_fp2_sub_pp(&sp_v2, &sp_u0, &sp_v2); - - // norm = mul_v(c1*a2 + c2*a1) + c0*a0 - run_fp2_mul_pp(&sp_v1, &sp_a02, &sp_u1); - run_fp2_mul_pp(&sp_v2, &sp_a01, &sp_u2); - run_fp2_add_pp(&sp_u1, &sp_u2, &sp_u3); - run_fp2_mul_by_1_plus_u_pp(&sp_u3, &sp_u3); - 
run_fp2_mul_pp(&sp_v0, &sp_a00, &sp_u4); - run_fp2_add_pp(&sp_u3, &sp_u4, &sp_u3); // norm - - // ni = norm^-1 (-> sp_u4) - run_fp2_inv_pp(&sp_u3, &sp_u4); - - // r0 = c0*ni, r1 = c1*ni, r2 = c2*ni - run_fp2_mul_pp(&sp_v0, &sp_u4, &sp_u0); - run_fp2_mul_pp(&sp_v1, &sp_u4, &sp_u1); - run_fp2_mul_pp(&sp_v2, &sp_u4, &sp_u2); - fp6_priv_set_c0(out, &sp_u0); - fp6_priv_set_c1(out, &sp_u1); - fp6_priv_set_c2(out, &sp_u2); -} -fn run_fp6_mul_by_v_priv(a: ptr>, - out: ptr>) { - fp6_priv_get_c0(a, &sp_a00); fp6_priv_get_c1(a, &sp_a01); fp6_priv_get_c2(a, &sp_a02); - run_fp2_mul_by_1_plus_u_pp(&sp_a02, &sp_v0); - fp6_priv_set_c0(out, &sp_v0); - fp6_priv_set_c1(out, &sp_a00); - fp6_priv_set_c2(out, &sp_a01); -} - -fn fp12_add_priv(a: ptr>, b: ptr>, - out: ptr>) { - fp12_priv_get_c0(a, &sp_a0); fp12_priv_get_c1(a, &sp_a1); - fp12_priv_get_c0(b, &sp_b0); fp12_priv_get_c1(b, &sp_b1); - run_fp6_add_priv(&sp_a0, &sp_b0, &sp_r0); - run_fp6_add_priv(&sp_a1, &sp_b1, &sp_r1); - fp12_priv_set_c0(out, &sp_r0); - fp12_priv_set_c1(out, &sp_r1); -} -fn fp12_sub_priv(a: ptr>, b: ptr>, - out: ptr>) { - fp12_priv_get_c0(a, &sp_a0); fp12_priv_get_c1(a, &sp_a1); - fp12_priv_get_c0(b, &sp_b0); fp12_priv_get_c1(b, &sp_b1); - run_fp6_sub_priv(&sp_a0, &sp_b0, &sp_r0); - run_fp6_sub_priv(&sp_a1, &sp_b1, &sp_r1); - fp12_priv_set_c0(out, &sp_r0); - fp12_priv_set_c1(out, &sp_r1); -} -fn fp12_conj_priv(a: ptr>, - out: ptr>) { - fp12_priv_get_c0(a, &sp_a0); fp12_priv_get_c1(a, &sp_a1); - run_fp6_neg_priv(&sp_a1, &sp_r1); - fp12_priv_set_c0(out, &sp_a0); - fp12_priv_set_c1(out, &sp_r1); -} -fn fp12_mul_priv(a: ptr>, b: ptr>, - out: ptr>) { - fp12_priv_get_c0(a, &sp_a0); fp12_priv_get_c1(a, &sp_a1); - fp12_priv_get_c0(b, &sp_b0); fp12_priv_get_c1(b, &sp_b1); - run_fp6_mul_priv(&sp_a0, &sp_b0, &sp_t0); - run_fp6_mul_priv(&sp_a1, &sp_b1, &sp_t1); - run_fp6_add_priv(&sp_a0, &sp_a1, &sp_t2); - run_fp6_add_priv(&sp_b0, &sp_b1, &sp_t3); - run_fp6_mul_priv(&sp_t2, &sp_t3, &sp_r1); - run_fp6_sub_priv(&sp_r1, 
&sp_t0, &sp_r1); - run_fp6_sub_priv(&sp_r1, &sp_t1, &sp_r1); - run_fp6_mul_by_v_priv(&sp_t1, &sp_t2); - run_fp6_add_priv(&sp_t0, &sp_t2, &sp_r0); - fp12_priv_set_c0(out, &sp_r0); - fp12_priv_set_c1(out, &sp_r1); -} -fn fp12_sqr_priv(a: ptr>, - out: ptr>) { - fp12_priv_get_c0(a, &sp_a0); fp12_priv_get_c1(a, &sp_a1); - run_fp6_add_priv(&sp_a0, &sp_a1, &sp_t0); - run_fp6_mul_by_v_priv(&sp_a1, &sp_t1); - run_fp6_add_priv(&sp_a0, &sp_t1, &sp_t2); - run_fp6_mul_priv(&sp_t0, &sp_t2, &sp_t3); - run_fp6_mul_priv(&sp_a0, &sp_a1, &sp_t1); - run_fp6_add_priv(&sp_t1, &sp_t1, &sp_r1); - run_fp6_sub_priv(&sp_t3, &sp_t1, &sp_r0); - run_fp6_mul_by_v_priv(&sp_t1, &sp_t2); - run_fp6_sub_priv(&sp_r0, &sp_t2, &sp_r0); - fp12_priv_set_c0(out, &sp_r0); - fp12_priv_set_c1(out, &sp_r1); -} -fn fp12_inv_priv(a: ptr>, - out: ptr>) { - fp12_priv_get_c0(a, &sp_a0); fp12_priv_get_c1(a, &sp_a1); - run_fp6_sqr_priv(&sp_a0, &sp_t0); - run_fp6_sqr_priv(&sp_a1, &sp_t1); - run_fp6_mul_by_v_priv(&sp_t1, &sp_t2); - run_fp6_sub_priv(&sp_t0, &sp_t2, &sp_t3); - run_fp6_inv_priv(&sp_t3, &sp_t0); // ti -> sp_t0 - run_fp6_mul_priv(&sp_a0, &sp_t0, &sp_r0); - run_fp6_mul_priv(&sp_a1, &sp_t0, &sp_t1); // r1_pos -> sp_t1 - run_fp6_neg_priv(&sp_t1, &sp_r1); - fp12_priv_set_c0(out, &sp_r0); - fp12_priv_set_c1(out, &sp_r1); -} - -// Cyclotomic squaring uses six Fp2 squarings (sqr_fp4 ×3) and a handful of -// Fp2 ops. All operands are 24 x u32 = 96 bytes; the function-stack budget -// holds at this size. 
-fn fp12_cyclotomic_sqr_priv(a: ptr>, - out: ptr>) { - var fa: array; - for (var i = 0u; i < 144u; i = i + 1u) { fa[i] = (*a)[i]; } - var fr: array; - fp12_cyclotomic_sqr_p(&fa, &fr); - for (var i = 0u; i < 144u; i = i + 1u) { (*out)[i] = fr[i]; } -} - -fn fp12_one_p(out: ptr>) { - var onep: array; - var zerop: array; - for (var i = 0u; i < 24u; i = i + 1u) { zerop[i] = 0u; } - for (var i = 0u; i < 12u; i = i + 1u) { onep[i] = BLS_R[i]; onep[12u + i] = 0u; } - var c0: array; - fp6_set_c0(&c0, &onep); - fp6_set_c1(&c0, &zerop); - fp6_set_c2(&c0, &zerop); - var c1: array; - fp6_set_c0(&c1, &zerop); - fp6_set_c1(&c1, &zerop); - fp6_set_c2(&c1, &zerop); - fp12_set_c0(out, &c0); - fp12_set_c1(out, &c1); -} diff --git a/bls/gpu/wgsl/bls_fp2.wgsl b/bls/gpu/wgsl/bls_fp2.wgsl deleted file mode 100644 index 00a3f5f..0000000 --- a/bls/gpu/wgsl/bls_fp2.wgsl +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL peer of bls_fp2.metal. Fp2 represented as 24 x u32 (c0 in [0..12), c1 in [12..24)). -// Byte layout matches blst_fp2: 96 bytes. -// -// Inputs are passed by ptr> AND outputs are written -// through ptr>. Returning Fp2/Fp6/Fp12 by value -// materialises a stack copy at every call site; the karatsuba/inversion call -// tree on the upper tower (fp6_inv / fp12 mul/sqr/inv/conj/cyclo_sqr) then -// blows AGXMetalG13X's function-call stack budget. Out-pointer form keeps -// every intermediate in a single named slot the caller already owns. -// Same arithmetic as Metal. 
- -fn fp2_get_c0(a: ptr>) -> array { - var r: array; - for (var i = 0u; i < 12u; i = i + 1u) { r[i] = (*a)[i]; } - return r; -} -fn fp2_get_c1(a: ptr>) -> array { - var r: array; - for (var i = 0u; i < 12u; i = i + 1u) { r[i] = (*a)[12u + i]; } - return r; -} -fn fp2_pack(c0: array, c1: array) -> array { - var r: array; - for (var i = 0u; i < 12u; i = i + 1u) { r[i] = c0[i]; r[12u + i] = c1[i]; } - return r; -} -fn fp2_zero() -> array { - var r: array; - for (var i = 0u; i < 24u; i = i + 1u) { r[i] = 0u; } - return r; -} -fn fp2_one_v() -> array { - var r: array; - for (var i = 0u; i < 12u; i = i + 1u) { r[i] = BLS_R[i]; r[12u + i] = 0u; } - return r; -} -fn fp2_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 24u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -// ---------- Out-pointer Fp2 primitives (used by upper tower) ---------- - -fn fp2_add_p(a: ptr>, b: ptr>, - out: ptr>) { - var ac0 = fp2_get_c0(a); var ac1 = fp2_get_c1(a); - var bc0 = fp2_get_c0(b); var bc1 = fp2_get_c1(b); - let r0 = fp_add(ac0, bc0); - let r1 = fp_add(ac1, bc1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_sub_p(a: ptr>, b: ptr>, - out: ptr>) { - var ac0 = fp2_get_c0(a); var ac1 = fp2_get_c1(a); - var bc0 = fp2_get_c0(b); var bc1 = fp2_get_c1(b); - let r0 = fp_sub(ac0, bc0); - let r1 = fp_sub(ac1, bc1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_neg_p(a: ptr>, out: ptr>) { - var ac0 = fp2_get_c0(a); var ac1 = fp2_get_c1(a); - let r0 = fp_neg(ac0); - let r1 = fp_neg(ac1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_mul_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0 = fp2_get_c0(a); - var a1 = fp2_get_c1(a); - var b0 = fp2_get_c0(b); - var b1 = fp2_get_c1(b); - let aa = fp_mul(a0, b0); - let bb = fp_mul(a1, b1); - let sa = fp_add(a0, a1); - let sb = fp_add(b0, b1); - let cross = fp_mul(sa, sb); - 
let r0 = fp_sub(aa, bb); - let r1 = fp_sub(fp_sub(cross, aa), bb); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_sqr_p(a: ptr>, out: ptr>) { - var a0 = fp2_get_c0(a); - var a1 = fp2_get_c1(a); - let ab = fp_mul(a0, a1); - let sum = fp_add(a0, a1); - let dif = fp_sub(a0, a1); - let r0 = fp_mul(sum, dif); - let r1 = fp_add(ab, ab); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_conj_p(a: ptr>, out: ptr>) { - var ac0 = fp2_get_c0(a); - var ac1 = fp2_get_c1(a); - let r1 = fp_neg(ac1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = ac0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_inv_p(a: ptr>, out: ptr>) { - var a0 = fp2_get_c0(a); - var a1 = fp2_get_c1(a); - let t0 = fp_sqr(a0); - let t1 = fp_sqr(a1); - let norm = fp_add(t0, t1); - let ni = fp_inv(norm); - let r0 = fp_mul(a0, ni); - let r1 = fp_neg(fp_mul(a1, ni)); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_mul_by_1_plus_u_p(a: ptr>, - out: ptr>) { - var a0 = fp2_get_c0(a); - var a1 = fp2_get_c1(a); - let r0 = fp_sub(a0, a1); - let r1 = fp_add(a0, a1); - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = r0[i]; (*out)[12u + i] = r1[i]; } -} -fn fp2_frobenius_p(a: ptr>, n: u32, - out: ptr>) { - if ((n & 1u) == 1u) { fp2_conj_p(a, out); return; } - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[i]; } -} - -// ---------- Legacy by-value Fp2 helpers (still used by Stage-1 kernels) ---------- - -fn fp2_add(a: ptr>, b: ptr>) -> array { - var r: array; - fp2_add_p(a, b, &r); - return r; -} -fn fp2_sub(a: ptr>, b: ptr>) -> array { - var r: array; - fp2_sub_p(a, b, &r); - return r; -} -fn fp2_neg(a: ptr>) -> array { - var r: array; - fp2_neg_p(a, &r); - return r; -} -fn fp2_mul(a: ptr>, b: ptr>) -> array { - var r: array; - fp2_mul_p(a, b, &r); - return r; -} -fn fp2_sqr(a: ptr>) -> array { - var r: array; - fp2_sqr_p(a, &r); - return r; -} -fn 
fp2_conj(a: ptr>) -> array { - var r: array; - fp2_conj_p(a, &r); - return r; -} -fn fp2_inv(a: ptr>) -> array { - var r: array; - fp2_inv_p(a, &r); - return r; -} -fn fp2_frobenius(a: ptr>, n: u32) -> array { - var r: array; - fp2_frobenius_p(a, n, &r); - return r; -} -fn fp2_mul_by_1_plus_u(a: ptr>) -> array { - var r: array; - fp2_mul_by_1_plus_u_p(a, &r); - return r; -} diff --git a/bls/gpu/wgsl/bls_fp6.wgsl b/bls/gpu/wgsl/bls_fp6.wgsl deleted file mode 100644 index d8553eb..0000000 --- a/bls/gpu/wgsl/bls_fp6.wgsl +++ /dev/null @@ -1,342 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL peer of bls_fp6.metal. Fp6 = 72 x u32 = 3 * Fp2 (288 bytes byte-equal blst_fp6). -// -// Out-pointer form for every function used by the Fp12 call tree. -// Returning array by value materialises a stack copy at every call -// site; the karatsuba/inversion call tree on the upper tower then exceeds -// AGXMetalG13X's function-call stack budget. Out-pointer form keeps every -// intermediate in a single named slot the caller already owns. Same -// arithmetic as Metal. -// -// Intermediates are computed into local Fp2 scratches (24 x u32) instead of -// chains of `let r = fp2_*(...)` which each materialise a 96-byte copy. 
- -fn fp6_get_c0(a: ptr>, out: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[i]; } -} -fn fp6_get_c1(a: ptr>, out: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[24u + i]; } -} -fn fp6_get_c2(a: ptr>, out: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*a)[48u + i]; } -} -fn fp6_set_c0(out: ptr>, v: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[i] = (*v)[i]; } -} -fn fp6_set_c1(out: ptr>, v: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[24u + i] = (*v)[i]; } -} -fn fp6_set_c2(out: ptr>, v: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*out)[48u + i] = (*v)[i]; } -} - -// ---------- Out-pointer Fp6 primitives ---------- - -fn fp6_add_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - var a2: array; fp6_get_c2(a, &a2); - var b0: array; fp6_get_c0(b, &b0); - var b1: array; fp6_get_c1(b, &b1); - var b2: array; fp6_get_c2(b, &b2); - var r: array; - fp2_add_p(&a0, &b0, &r); fp6_set_c0(out, &r); - fp2_add_p(&a1, &b1, &r); fp6_set_c1(out, &r); - fp2_add_p(&a2, &b2, &r); fp6_set_c2(out, &r); -} -fn fp6_sub_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - var a2: array; fp6_get_c2(a, &a2); - var b0: array; fp6_get_c0(b, &b0); - var b1: array; fp6_get_c1(b, &b1); - var b2: array; fp6_get_c2(b, &b2); - var r: array; - fp2_sub_p(&a0, &b0, &r); fp6_set_c0(out, &r); - fp2_sub_p(&a1, &b1, &r); fp6_set_c1(out, &r); - fp2_sub_p(&a2, &b2, &r); fp6_set_c2(out, &r); -} -fn fp6_neg_p(a: ptr>, out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - var a2: array; fp6_get_c2(a, &a2); - var r: array; - fp2_neg_p(&a0, &r); fp6_set_c0(out, &r); - fp2_neg_p(&a1, &r); fp6_set_c1(out, &r); - fp2_neg_p(&a2, &r); fp6_set_c2(out, &r); -} - -fn fp6_mul_p(a: ptr>, b: ptr>, - out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - 
var a2: array; fp6_get_c2(a, &a2); - var b0: array; fp6_get_c0(b, &b0); - var b1: array; fp6_get_c1(b, &b1); - var b2: array; fp6_get_c2(b, &b2); - - var t0: array; fp2_mul_p(&a0, &b0, &t0); - var t1: array; fp2_mul_p(&a1, &b1, &t1); - var t2: array; fp2_mul_p(&a2, &b2, &t2); - - var s_a: array; - var s_b: array; - var tmp: array; - var r0: array; - var r1: array; - var r2: array; - - // r0 = ((a1+a2)(b1+b2) - t1 - t2)(u+1) + t0 - fp2_add_p(&a1, &a2, &s_a); - fp2_add_p(&b1, &b2, &s_b); - fp2_mul_p(&s_a, &s_b, &tmp); - fp2_sub_p(&tmp, &t1, &tmp); - fp2_sub_p(&tmp, &t2, &tmp); - fp2_mul_by_1_plus_u_p(&tmp, &tmp); - fp2_add_p(&tmp, &t0, &r0); - - // r1 = (a0+a1)(b0+b1) - t0 - t1 + t2(u+1) - fp2_add_p(&a0, &a1, &s_a); - fp2_add_p(&b0, &b1, &s_b); - fp2_mul_p(&s_a, &s_b, &tmp); - fp2_sub_p(&tmp, &t0, &tmp); - fp2_sub_p(&tmp, &t1, &tmp); - var t2_v: array; fp2_mul_by_1_plus_u_p(&t2, &t2_v); - fp2_add_p(&tmp, &t2_v, &r1); - - // r2 = (a0+a2)(b0+b2) - t0 - t2 + t1 - fp2_add_p(&a0, &a2, &s_a); - fp2_add_p(&b0, &b2, &s_b); - fp2_mul_p(&s_a, &s_b, &tmp); - fp2_sub_p(&tmp, &t0, &tmp); - fp2_sub_p(&tmp, &t2, &tmp); - fp2_add_p(&tmp, &t1, &r2); - - fp6_set_c0(out, &r0); - fp6_set_c1(out, &r1); - fp6_set_c2(out, &r2); -} - -fn fp6_sqr_p(a: ptr>, out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - var a2: array; fp6_get_c2(a, &a2); - - var s0: array; fp2_sqr_p(&a0, &s0); - var s2: array; fp2_sqr_p(&a2, &s2); - var m01: array; fp2_mul_p(&a0, &a1, &m01); fp2_add_p(&m01, &m01, &m01); - var m12: array; fp2_mul_p(&a1, &a2, &m12); fp2_add_p(&m12, &m12, &m12); - - var r2: array; - var sum: array; - fp2_add_p(&a0, &a1, &sum); - fp2_add_p(&sum, &a2, &sum); - fp2_sqr_p(&sum, &r2); - fp2_sub_p(&r2, &s0, &r2); - fp2_sub_p(&r2, &s2, &r2); - fp2_sub_p(&r2, &m01, &r2); - fp2_sub_p(&r2, &m12, &r2); - - var r0: array; - fp2_mul_by_1_plus_u_p(&m12, &r0); - fp2_add_p(&r0, &s0, &r0); - - var r1: array; - fp2_mul_by_1_plus_u_p(&s2, &r1); - fp2_add_p(&r1, &m01, 
&r1); - - fp6_set_c0(out, &r0); - fp6_set_c1(out, &r1); - fp6_set_c2(out, &r2); -} - -fn fp6_inv_p(a: ptr>, out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - var a2: array; fp6_get_c2(a, &a2); - - // c0 = a0^2 - mul_v(a1*a2) - var c0: array; - var tmp: array; - fp2_sqr_p(&a0, &c0); - fp2_mul_p(&a1, &a2, &tmp); - fp2_mul_by_1_plus_u_p(&tmp, &tmp); - fp2_sub_p(&c0, &tmp, &c0); - - // c1 = mul_v(a2^2) - a0*a1 - var c1: array; - fp2_sqr_p(&a2, &c1); - fp2_mul_by_1_plus_u_p(&c1, &c1); - fp2_mul_p(&a0, &a1, &tmp); - fp2_sub_p(&c1, &tmp, &c1); - - // c2 = a1^2 - a0*a2 - var c2: array; - fp2_sqr_p(&a1, &c2); - fp2_mul_p(&a0, &a2, &tmp); - fp2_sub_p(&c2, &tmp, &c2); - - // norm = mul_v(a2*c1 + a1*c2) + a0*c0 - var norm: array; - var t1: array; - fp2_mul_p(&c1, &a2, &t1); - fp2_mul_p(&c2, &a1, &tmp); - fp2_add_p(&t1, &tmp, &norm); - fp2_mul_by_1_plus_u_p(&norm, &norm); - fp2_mul_p(&c0, &a0, &tmp); - fp2_add_p(&norm, &tmp, &norm); - - var ni: array; - fp2_inv_p(&norm, &ni); - - var r: array; - fp2_mul_p(&c0, &ni, &r); fp6_set_c0(out, &r); - fp2_mul_p(&c1, &ni, &r); fp6_set_c1(out, &r); - fp2_mul_p(&c2, &ni, &r); fp6_set_c2(out, &r); -} - -// Frobenius coefficient tables (in Montgomery form). 12 x u32 LE renderings of -// the same constants used in bls_fp6.metal. 
-fn frob6_c1_n1(out: ptr>) { - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = 0u; } - (*out)[12u + 0u] = 0x8671F071u; (*out)[12u + 1u] = 0xCD03C9E4u; - (*out)[12u + 2u] = 0x1FCDA5D2u; (*out)[12u + 3u] = 0x5DAB2246u; - (*out)[12u + 4u] = 0xD3851B95u; (*out)[12u + 5u] = 0x587042AFu; - (*out)[12u + 6u] = 0x01BACB9Eu; (*out)[12u + 7u] = 0x8EB60EBEu; - (*out)[12u + 8u] = 0x83D050D2u; (*out)[12u + 9u] = 0x03F97D6Eu; - (*out)[12u + 10u] = 0x54638741u; (*out)[12u + 11u] = 0x18F02065u; -} -fn frob6_c1_n2(out: ptr>) { - (*out)[0u] = 0x798A64E8u; (*out)[1u] = 0x30F1361Bu; - (*out)[2u] = 0x7ECE5A2Au; (*out)[3u] = 0xF3B8DDABu; - (*out)[4u] = 0xC61577F7u; (*out)[5u] = 0x16A8CA3Au; - (*out)[6u] = 0x74FD029Bu; (*out)[7u] = 0xC26A2FF8u; - (*out)[8u] = 0x60701C6Eu; (*out)[9u] = 0x3636B766u; - (*out)[10u] = 0x241B6160u; (*out)[11u] = 0x051BA4ABu; - for (var i = 12u; i < 24u; i = i + 1u) { (*out)[i] = 0u; } -} -fn frob6_c1_n3(out: ptr>) { - for (var i = 0u; i < 12u; i = i + 1u) { (*out)[i] = 0u; } - (*out)[12u + 0u] = 0x0002FFFDu; (*out)[12u + 1u] = 0x76090000u; - (*out)[12u + 2u] = 0xC40C0002u; (*out)[12u + 3u] = 0xEBF40000u; - (*out)[12u + 4u] = 0x53C758BAu; (*out)[12u + 5u] = 0x5F489857u; - (*out)[12u + 6u] = 0x70525745u; (*out)[12u + 7u] = 0x77CE5853u; - (*out)[12u + 8u] = 0xA256EC6Du; (*out)[12u + 9u] = 0x5C071A97u; - (*out)[12u + 10u] = 0xFA80E493u; (*out)[12u + 11u] = 0x15F65EC3u; -} -fn frob6_c2_n1(out: ptr>) { - (*out)[0u] = 0x867545C3u; (*out)[1u] = 0x890DC9E4u; - (*out)[2u] = 0x3285A5D5u; (*out)[3u] = 0x2AF32253u; - (*out)[4u] = 0x309B7E2Cu; (*out)[5u] = 0x50880866u; - (*out)[6u] = 0x7E881024u; (*out)[7u] = 0xA20D1B8Cu; - (*out)[8u] = 0xE2DB9068u; (*out)[9u] = 0x14E4F04Fu; - (*out)[10u] = 0x1564853Au; (*out)[11u] = 0x14E56D3Fu; -} -fn frob6_c2_n2(out: ptr>) { - (*out)[0u] = 0x8671F071u; (*out)[1u] = 0xCD03C9E4u; - (*out)[2u] = 0x1FCDA5D2u; (*out)[3u] = 0x5DAB2246u; - (*out)[4u] = 0xD3851B95u; (*out)[5u] = 0x587042AFu; - (*out)[6u] = 0x01BACB9Eu; (*out)[7u] = 
0x8EB60EBEu; - (*out)[8u] = 0x83D050D2u; (*out)[9u] = 0x03F97D6Eu; - (*out)[10u] = 0x54638741u; (*out)[11u] = 0x18F02065u; -} -fn frob6_c2_n3(out: ptr>) { - (*out)[0u] = 0xFFFCAAAEu; (*out)[1u] = 0x43F5FFFFu; - (*out)[2u] = 0xED47FFFDu; (*out)[3u] = 0x32B7FFF2u; - (*out)[4u] = 0xA2E99D69u; (*out)[5u] = 0x07E83A49u; - (*out)[6u] = 0x8332BB7Au; (*out)[7u] = 0xECA8F331u; - (*out)[8u] = 0xA0F4C069u; (*out)[9u] = 0xEF148D1Eu; - (*out)[10u] = 0x3EFF0206u; (*out)[11u] = 0x040AB326u; -} - -fn fp6_frobenius_p(a: ptr>, n: u32, - out: ptr>) { - var a0: array; fp6_get_c0(a, &a0); - var a1: array; fp6_get_c1(a, &a1); - var a2: array; fp6_get_c2(a, &a2); - var r0: array; fp2_frobenius_p(&a0, n, &r0); - var r1: array; fp2_frobenius_p(&a1, n, &r1); - var r2: array; fp2_frobenius_p(&a2, n, &r2); - - var c1: array; - var c2_real: array; - if (n == 1u) { frob6_c1_n1(&c1); frob6_c2_n1(&c2_real); } - else if (n == 2u) { frob6_c1_n2(&c1); frob6_c2_n2(&c2_real); } - else { frob6_c1_n3(&c1); frob6_c2_n3(&c2_real); } - - var r1_new: array; fp2_mul_p(&r1, &c1, &r1_new); - var r2_c0_in = fp2_get_c0(&r2); - var r2_c1_in = fp2_get_c1(&r2); - let r2_c0 = fp_mul(r2_c0_in, c2_real); - let r2_c1 = fp_mul(r2_c1_in, c2_real); - var r2_new: array; - for (var i = 0u; i < 12u; i = i + 1u) { r2_new[i] = r2_c0[i]; r2_new[12u + i] = r2_c1[i]; } - - fp6_set_c0(out, &r0); - fp6_set_c1(out, &r1_new); - fp6_set_c2(out, &r2_new); -} - -// (a0 + a1 v + a2 v^2) * v = a2 (u+1) + a0 v + a1 v^2 (Fp6 in Fp2 layout) -fn fp6_mul_by_v_p(a: ptr>, - out: ptr>) { - var a2: array; fp6_get_c2(a, &a2); - var r0: array; fp2_mul_by_1_plus_u_p(&a2, &r0); - var r1: array; fp6_get_c0(a, &r1); - var r2: array; fp6_get_c1(a, &r2); - fp6_set_c0(out, &r0); - fp6_set_c1(out, &r1); - fp6_set_c2(out, &r2); -} - -// ---------- Legacy by-value Fp6 (Stage-1 kernels keep this) ---------- - -fn fp6_pack_v(c0: array, c1: array, c2: array) -> array { - var r: array; - for (var i = 0u; i < 24u; i = i + 1u) { - r[i] = c0[i]; - r[24u + i] = 
c1[i]; - r[48u + i] = c2[i]; - } - return r; -} - -fn fp6_add(a: ptr>, b: ptr>) -> array { - var r: array; - fp6_add_p(a, b, &r); - return r; -} -fn fp6_sub(a: ptr>, b: ptr>) -> array { - var r: array; - fp6_sub_p(a, b, &r); - return r; -} -fn fp6_neg(a: ptr>) -> array { - var r: array; - fp6_neg_p(a, &r); - return r; -} -fn fp6_mul(a: ptr>, b: ptr>) -> array { - var r: array; - fp6_mul_p(a, b, &r); - return r; -} -fn fp6_sqr(a: ptr>) -> array { - var r: array; - fp6_sqr_p(a, &r); - return r; -} -fn fp6_inv(a: ptr>) -> array { - var r: array; - fp6_inv_p(a, &r); - return r; -} -fn fp6_frobenius(a: ptr>, n: u32) -> array { - var r: array; - fp6_frobenius_p(a, n, &r); - return r; -} diff --git a/bls/gpu/wgsl/bls_fp_ops.wgsl b/bls/gpu/wgsl/bls_fp_ops.wgsl deleted file mode 100644 index 8fee982..0000000 --- a/bls/gpu/wgsl/bls_fp_ops.wgsl +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL peer of bls_fp_ops.h.metal — Fp arithmetic for BLS12-381. -// Limbs: 12 x u32 little-endian (no native u64 in WGSL). -// All values in Montgomery form. Layout matches blst's vec384 byte-for-byte -// (each pair of u32 makes one u64 limb of vec384, native endian). -// -// Concatenated into every WGSL pipeline by the host driver before kernel sources. - -// BLS12-381 modulus p (12 x u32 LE). 
-const BLS_P: array = array( - 0xFFFFAAABu, 0xB9FEFFFFu, 0xB153FFFFu, 0x1EABFFFEu, - 0xF6B0F624u, 0x6730D2A0u, 0xF38512BFu, 0x64774B84u, - 0x434BACD7u, 0x4B1BA7B6u, 0x397FE69Au, 0x1A0111EAu -); - -// R^2 mod p (Montgomery) -const BLS_R2: array = array( - 0x1C341746u, 0xF4DF1F34u, 0x09D104F1u, 0x0A76E6A6u, - 0x4C95B6D5u, 0x8DE5476Cu, 0x939D83C0u, 0x67EB88A9u, - 0xB519952Du, 0x9A793E85u, 0x92CAE3AAu, 0x11988FE5u -); - -// R mod p (= 1 in Mont) -const BLS_R: array = array( - 0x0002FFFDu, 0x76090000u, 0xC40C0002u, 0xEBF40000u, - 0x53C758BAu, 0x5F489857u, 0x70525745u, 0x77CE5853u, - 0xA256EC6Du, 0x5C071A97u, 0xFA80E493u, 0x15F65EC3u -); - -// p_inv (low 64 bits = 0x89F3FFFCFFFCFFFD) -const BLS_P_INV_LO: u32 = 0xFFFCFFFDu; -const BLS_P_INV_HI: u32 = 0x89F3FFFCu; - -// 12 x u32 zero -const ZERO384: array = array(0u,0u,0u,0u,0u,0u,0u,0u,0u,0u,0u,0u); - -fn u384_is_zero(a: array) -> bool { - var acc = 0u; - for (var i = 0u; i < 12u; i = i + 1u) { acc = acc | a[i]; } - return acc == 0u; -} - -fn u384_cmp(a: array, b: array) -> i32 { - for (var i = 11i; i >= 0; i = i - 1) { - let ui = u32(i); - if (a[ui] > b[ui]) { return 1; } - if (a[ui] < b[ui]) { return -1; } - } - return 0; -} - -// Add 12 x u32; returns final carry. 
-fn u384_add(a: array, b: array, r: ptr>) -> u32 { - var c: u32 = 0u; - for (var i = 0u; i < 12u; i = i + 1u) { - let s1 = a[i] + c; - var c1: u32 = 0u; if (s1 < a[i]) { c1 = 1u; } - let s2 = s1 + b[i]; - var c2: u32 = 0u; if (s2 < s1) { c2 = 1u; } - (*r)[i] = s2; - c = c1 + c2; - } - return c; -} - -fn u384_sub(a: array, b: array, r: ptr>) -> u32 { - var bw: u32 = 0u; - for (var i = 0u; i < 12u; i = i + 1u) { - let d1 = a[i] - bw; - var b1: u32 = 0u; if (d1 > a[i]) { b1 = 1u; } - let d2 = d1 - b[i]; - var b2: u32 = 0u; if (d2 > d1) { b2 = 1u; } - (*r)[i] = d2; - bw = b1 + b2; - } - return bw; -} - -// 32 x 32 -> 64 (lo, hi) -fn mul32_64(a: u32, b: u32) -> vec2 { - let al = a & 0xFFFFu; - let ah = a >> 16u; - let bl = b & 0xFFFFu; - let bh = b >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = (ll >> 16u) + (lh & 0xFFFFu) + (hl & 0xFFFFu); - let lo = (mid << 16u) | (ll & 0xFFFFu); - let hi = hh + (lh >> 16u) + (hl >> 16u) + (mid >> 16u); - return vec2(lo, hi); -} - -// 384x384 -> 768 schoolbook in u32. -fn u384_mul768(a: array, b: array, t: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*t)[i] = 0u; } - for (var i = 0u; i < 12u; i = i + 1u) { - var carry: u32 = 0u; - for (var j = 0u; j < 12u; j = j + 1u) { - let prod = mul32_64(a[i], b[j]); - let lo = prod.x; - let hi = prod.y; - - // accumulate lo + carry + t[i+j] - var s = lo + carry; - var c1: u32 = 0u; if (s < lo) { c1 = 1u; } - let s2 = s + (*t)[i + j]; - var c2: u32 = 0u; if (s2 < s) { c2 = 1u; } - (*t)[i + j] = s2; - carry = hi + c1 + c2; - } - // propagate final carry - var k = i + 12u; - while (carry != 0u && k < 24u) { - let sum = (*t)[k] + carry; - var c: u32 = 0u; if (sum < (*t)[k]) { c = 1u; } - (*t)[k] = sum; - carry = c; - k = k + 1u; - } - } -} - -// CIOS Montgomery reduce: t (24 x u32) -> r (12 x u32) = t * R^-1 mod p. -fn mont_reduce_384(t: ptr>, r: ptr>) { - // a holds t plus an extra u32 for carry overflow at position 24. 
- var a: array; - for (var i = 0u; i < 24u; i = i + 1u) { a[i] = (*t)[i]; } - a[24] = 0u; - a[25] = 0u; - - // We work in 32-bit limbs, so 6 reductions of 64-bit `u` from Metal become - // 12 reductions of 32-bit `u`. p_inv_32 (low 32 bits of -p^-1 mod 2^32) - // is BLS_P_INV_LO = 0xFFFCFFFDu. - for (var i = 0u; i < 12u; i = i + 1u) { - let u = a[i] * BLS_P_INV_LO; // -a[i] * p^-1 mod 2^32 - var carry: u32 = 0u; - for (var j = 0u; j < 12u; j = j + 1u) { - let prod = mul32_64(u, BLS_P[j]); - var s = prod.x + carry; - var c1: u32 = 0u; if (s < prod.x) { c1 = 1u; } - let s2 = s + a[i + j]; - var c2: u32 = 0u; if (s2 < s) { c2 = 1u; } - a[i + j] = s2; - carry = prod.y + c1 + c2; - } - // propagate - var k = i + 12u; - while (carry != 0u && k < 26u) { - let sum = a[k] + carry; - var c: u32 = 0u; if (sum < a[k]) { c = 1u; } - a[k] = sum; - carry = c; - k = k + 1u; - } - } - - var rr: array; - for (var i = 0u; i < 12u; i = i + 1u) { rr[i] = a[12u + i]; } - let cmp = u384_cmp(rr, BLS_P); - if (a[24] != 0u || cmp >= 0) { - var sub: array; - _ = u384_sub(rr, BLS_P, &sub); - rr = sub; - } - *r = rr; -} - -// Private scratch shared across the fp_* leaves so we don't materialise a -// fresh 96-byte fp_mul stack frame on every call from the Fp12 call tree. -// Workgroup size is 1 in every kernel that touches the upper tower, so each -// invocation owns these slots without contention. 
-var fp_scratch_t: array; -var fp_scratch_a: array; -var fp_scratch_inv_exp: array; -var fp_scratch_inv_result: array; -var fp_scratch_inv_base: array; - -fn u384_mul768_priv(a: array, b: array, - t: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { (*t)[i] = 0u; } - for (var i = 0u; i < 12u; i = i + 1u) { - var carry: u32 = 0u; - for (var j = 0u; j < 12u; j = j + 1u) { - let prod = mul32_64(a[i], b[j]); - let lo = prod.x; - let hi = prod.y; - var s = lo + carry; - var c1: u32 = 0u; if (s < lo) { c1 = 1u; } - let s2 = s + (*t)[i + j]; - var c2: u32 = 0u; if (s2 < s) { c2 = 1u; } - (*t)[i + j] = s2; - carry = hi + c1 + c2; - } - var k = i + 12u; - while (carry != 0u && k < 24u) { - let sum = (*t)[k] + carry; - var c: u32 = 0u; if (sum < (*t)[k]) { c = 1u; } - (*t)[k] = sum; - carry = c; - k = k + 1u; - } - } -} - -fn mont_reduce_384_priv(t: ptr>, - r: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { fp_scratch_a[i] = (*t)[i]; } - fp_scratch_a[24] = 0u; - fp_scratch_a[25] = 0u; - - for (var i = 0u; i < 12u; i = i + 1u) { - let u = fp_scratch_a[i] * BLS_P_INV_LO; - var carry: u32 = 0u; - for (var j = 0u; j < 12u; j = j + 1u) { - let prod = mul32_64(u, BLS_P[j]); - var s = prod.x + carry; - var c1: u32 = 0u; if (s < prod.x) { c1 = 1u; } - let s2 = s + fp_scratch_a[i + j]; - var c2: u32 = 0u; if (s2 < s) { c2 = 1u; } - fp_scratch_a[i + j] = s2; - carry = prod.y + c1 + c2; - } - var k = i + 12u; - while (carry != 0u && k < 26u) { - let sum = fp_scratch_a[k] + carry; - var c: u32 = 0u; if (sum < fp_scratch_a[k]) { c = 1u; } - fp_scratch_a[k] = sum; - carry = c; - k = k + 1u; - } - } - - var rr: array; - for (var i = 0u; i < 12u; i = i + 1u) { rr[i] = fp_scratch_a[12u + i]; } - let cmp = u384_cmp(rr, BLS_P); - if (fp_scratch_a[24] != 0u || cmp >= 0) { - var sub: array; - _ = u384_sub(rr, BLS_P, &sub); - rr = sub; - } - *r = rr; -} - -fn fp_mul(a: array, b: array) -> array { - u384_mul768_priv(a, b, &fp_scratch_t); - var r: array; - mont_reduce_384_priv(&fp_scratch_t, &r); 
- return r; -} - -fn fp_sqr(a: array) -> array { return fp_mul(a, a); } - -fn fp_add(a: array, b: array) -> array { - var r: array; - let c = u384_add(a, b, &r); - let cmp = u384_cmp(r, BLS_P); - if (c != 0u || cmp >= 0) { - var s: array; - _ = u384_sub(r, BLS_P, &s); - r = s; - } - return r; -} - -fn fp_sub(a: array, b: array) -> array { - var r: array; - let bw = u384_sub(a, b, &r); - if (bw != 0u) { - var s: array; - _ = u384_add(r, BLS_P, &s); - r = s; - } - return r; -} - -fn fp_neg(a: array) -> array { - if (u384_is_zero(a)) { return a; } - var r: array; - _ = u384_sub(BLS_P, a, &r); - return r; -} - -// Fermat inversion: a^(p-2) mod p. Mirror Metal exactly. -// Note: we operate on 32-bit limbs; bit iteration is over 384 bits MSB->LSB. -fn fp_inv(a: array) -> array { - for (var i = 0u; i < 12u; i = i + 1u) { fp_scratch_inv_exp[i] = BLS_P[i]; } - // exp -= 2 on the lowest limb (low limb is well above 2) - fp_scratch_inv_exp[0] = fp_scratch_inv_exp[0] - 2u; - - for (var i = 0u; i < 12u; i = i + 1u) { fp_scratch_inv_result[i] = BLS_R[i]; } - var started: bool = false; - for (var i = 11i; i >= 0; i = i - 1) { - let ui = u32(i); - for (var bit = 31i; bit >= 0; bit = bit - 1) { - if (started) { - var rcopy: array; - for (var k = 0u; k < 12u; k = k + 1u) { rcopy[k] = fp_scratch_inv_result[k]; } - let s = fp_sqr(rcopy); - for (var k = 0u; k < 12u; k = k + 1u) { fp_scratch_inv_result[k] = s[k]; } - } - let mask = 1u << u32(bit); - if ((fp_scratch_inv_exp[ui] & mask) != 0u) { - if (started) { - var rcopy: array; - for (var k = 0u; k < 12u; k = k + 1u) { rcopy[k] = fp_scratch_inv_result[k]; } - let m = fp_mul(rcopy, a); - for (var k = 0u; k < 12u; k = k + 1u) { fp_scratch_inv_result[k] = m[k]; } - } else { - for (var k = 0u; k < 12u; k = k + 1u) { fp_scratch_inv_result[k] = a[k]; } - started = true; - } - } - } - } - var result: array; - for (var i = 0u; i < 12u; i = i + 1u) { result[i] = fp_scratch_inv_result[i]; } - return result; -} diff --git 
a/bls/gpu/wgsl/bls_fp_tower_kernels.wgsl b/bls/gpu/wgsl/bls_fp_tower_kernels.wgsl deleted file mode 100644 index af5d844..0000000 --- a/bls/gpu/wgsl/bls_fp_tower_kernels.wgsl +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL compute kernels for the BLS12-381 Fp tower (Stage 1 + 4 parity). -// Each kernel processes one element per dispatch (1×N grid, workgroup_size(1)) -// to preserve byte-determinism across backends. -// -// Concatenated by the host driver onto bls_fp_ops.wgsl + bls_fp2.wgsl + -// bls_fp6.wgsl + bls_fp12.wgsl in that order. -// -// Upper-tower kernels (fp6_inv, fp12_*) call the out-pointer (_p) form so -// the call tree never materialises a returned array or -// array on the function-call stack. - -@group(0) @binding(0) var in_a: array; -@group(0) @binding(1) var in_b: array; -@group(0) @binding(2) var out: array; -@group(0) @binding(3) var params: vec4; // [count, _, _, _] - -// Helpers — load/store sized arrays from the flat global storage buffers. 
-fn load_fp_into(buf_ptr: u32, base: u32, dst: ptr>) { - if (buf_ptr == 0u) { - for (var i = 0u; i < 12u; i = i + 1u) { (*dst)[i] = in_a[base + i]; } - } else { - for (var i = 0u; i < 12u; i = i + 1u) { (*dst)[i] = in_b[base + i]; } - } -} -fn store_fp_from(base: u32, src: ptr>) { - for (var i = 0u; i < 12u; i = i + 1u) { out[base + i] = (*src)[i]; } -} -fn load_fp2_into(buf_ptr: u32, base: u32, dst: ptr>) { - if (buf_ptr == 0u) { - for (var i = 0u; i < 24u; i = i + 1u) { (*dst)[i] = in_a[base + i]; } - } else { - for (var i = 0u; i < 24u; i = i + 1u) { (*dst)[i] = in_b[base + i]; } - } -} -fn store_fp2_from(base: u32, src: ptr>) { - for (var i = 0u; i < 24u; i = i + 1u) { out[base + i] = (*src)[i]; } -} -fn load_fp6_into(buf_ptr: u32, base: u32, dst: ptr>) { - if (buf_ptr == 0u) { - for (var i = 0u; i < 72u; i = i + 1u) { (*dst)[i] = in_a[base + i]; } - } else { - for (var i = 0u; i < 72u; i = i + 1u) { (*dst)[i] = in_b[base + i]; } - } -} -fn store_fp6_from(base: u32, src: ptr>) { - for (var i = 0u; i < 72u; i = i + 1u) { out[base + i] = (*src)[i]; } -} -fn load_fp12_into(buf_ptr: u32, base: u32, dst: ptr>) { - if (buf_ptr == 0u) { - for (var i = 0u; i < 144u; i = i + 1u) { (*dst)[i] = in_a[base + i]; } - } else { - for (var i = 0u; i < 144u; i = i + 1u) { (*dst)[i] = in_b[base + i]; } - } -} -fn store_fp12_from(base: u32, src: ptr>) { - for (var i = 0u; i < 144u; i = i + 1u) { out[base + i] = (*src)[i]; } -} - -// ============================================================================= -// Fp diagnostic — raw fp_inv on c0 of Fp2 input. 
-// ============================================================================= - -@compute @workgroup_size(1) fn k_fp_inv_diag(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var c0: array; - load_fp_into(0u, off, &c0); - var r = fp_inv(c0); - store_fp_from(off, &r); - for (var k = 12u; k < 24u; k = k + 1u) { out[off + k] = 0u; } -} - -// ============================================================================= -// Fp2 kernels -// ============================================================================= - -@compute @workgroup_size(1) fn k_fp2_add(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var a: array; load_fp2_into(0u, off, &a); - var b: array; load_fp2_into(1u, off, &b); - var r: array; fp2_add_p(&a, &b, &r); - store_fp2_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp2_sub(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var a: array; load_fp2_into(0u, off, &a); - var b: array; load_fp2_into(1u, off, &b); - var r: array; fp2_sub_p(&a, &b, &r); - store_fp2_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp2_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var a: array; load_fp2_into(0u, off, &a); - var b: array; load_fp2_into(1u, off, &b); - var r: array; fp2_mul_p(&a, &b, &r); - store_fp2_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp2_sqr(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var a: array; load_fp2_into(0u, off, &a); - var r: array; fp2_sqr_p(&a, &r); - store_fp2_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp2_inv(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var a: array; load_fp2_into(0u, off, &a); - var r: 
array; fp2_inv_p(&a, &r); - store_fp2_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp2_conj(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 24u; - var a: array; load_fp2_into(0u, off, &a); - var r: array; fp2_conj_p(&a, &r); - store_fp2_from(off, &r); -} - -// ============================================================================= -// Fp6 kernels -// ============================================================================= - -@compute @workgroup_size(1) fn k_fp6_add(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 72u; - var a: array; load_fp6_into(0u, off, &a); - var b: array; load_fp6_into(1u, off, &b); - var r: array; fp6_add_p(&a, &b, &r); - store_fp6_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp6_sub(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 72u; - var a: array; load_fp6_into(0u, off, &a); - var b: array; load_fp6_into(1u, off, &b); - var r: array; fp6_sub_p(&a, &b, &r); - store_fp6_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp6_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 72u; - var a: array; load_fp6_into(0u, off, &a); - var b: array; load_fp6_into(1u, off, &b); - var r: array; fp6_mul_p(&a, &b, &r); - store_fp6_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp6_sqr(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 72u; - var a: array; load_fp6_into(0u, off, &a); - var r: array; fp6_sqr_p(&a, &r); - store_fp6_from(off, &r); -} -@compute @workgroup_size(1) fn k_fp6_inv(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 72u; - var a: array; load_fp6_into(0u, off, &a); - var r: array; fp6_inv_p(&a, &r); - store_fp6_from(off, &r); -} - -// 
============================================================================= -// Fp12 kernels -// -// Each kernel runs at workgroup_size(1) with one dispatch per element, so -// we keep the 144 x u32 operands in private storage. This keeps the -// function-call stack budget under AGXMetalG13X's per-thread limit when -// the karatsuba/inversion call tree (fp6_inv inside fp12_inv) is traversed. -// ============================================================================= - -var g_fp12_a: array; -var g_fp12_b: array; -var g_fp12_r: array; - -fn load_fp12_priv(buf_ptr: u32, base: u32, dst: ptr>) { - if (buf_ptr == 0u) { - for (var i = 0u; i < 144u; i = i + 1u) { (*dst)[i] = in_a[base + i]; } - } else { - for (var i = 0u; i < 144u; i = i + 1u) { (*dst)[i] = in_b[base + i]; } - } -} -fn store_fp12_priv(base: u32, src: ptr>) { - for (var i = 0u; i < 144u; i = i + 1u) { out[base + i] = (*src)[i]; } -} - -@compute @workgroup_size(1) fn k_fp12_add(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - load_fp12_priv(1u, off, &g_fp12_b); - fp12_add_priv(&g_fp12_a, &g_fp12_b, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} -@compute @workgroup_size(1) fn k_fp12_sub(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - load_fp12_priv(1u, off, &g_fp12_b); - fp12_sub_priv(&g_fp12_a, &g_fp12_b, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} -@compute @workgroup_size(1) fn k_fp12_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - load_fp12_priv(1u, off, &g_fp12_b); - fp12_mul_priv(&g_fp12_a, &g_fp12_b, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} -@compute @workgroup_size(1) fn k_fp12_sqr(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) 
{ return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - fp12_sqr_priv(&g_fp12_a, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} -@compute @workgroup_size(1) fn k_fp12_inv(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - fp12_inv_priv(&g_fp12_a, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} -@compute @workgroup_size(1) fn k_fp12_conj(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - fp12_conj_priv(&g_fp12_a, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} -@compute @workgroup_size(1) fn k_fp12_cyclo_sqr(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; if (i >= params.x) { return; } - let off = i * 144u; - load_fp12_priv(0u, off, &g_fp12_a); - fp12_cyclotomic_sqr_priv(&g_fp12_a, &g_fp12_r); - store_fp12_priv(off, &g_fp12_r); -} diff --git a/bn254/gpu/cuda/bn254.cu b/bn254/gpu/cuda/bn254.cu deleted file mode 100644 index 14ad918..0000000 --- a/bn254/gpu/cuda/bn254.cu +++ /dev/null @@ -1,618 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party CUDA kernels for bn254 (alt_bn128). -// -// Algorithm transliteration of bn254/cpp/{bn254_fp,bn254_g1,bn254_fp2, -// bn254_fp6,bn254_fp12,bn254_g2,bn254_pairing,bn254_hash_to_curve}.hpp. -// All field elements stored in Montgomery form, 4 x uint64_t little-endian -// limbs, identical layout to the CPU oracle so byte-equality holds. 
-// -// Algorithms: -// * Fp: CIOS Montgomery multiplication (HAC §14.36) -// * Fp2: Karatsuba over Fp[u]/(u^2+1) -// * G1: Bernstein-Lange efl/jacobian-0/{dbl-2009-l, add-2007-bl} -// * G1 mul: constant-time Montgomery ladder over 256 bits (no early exit) -// * SVDW: RFC 9380 §6.6.1 map_to_curve_svdw -// -// All algorithms are documented in the cited references; this file copies -// no upstream source. -// -// Build modes: -// With CUDA toolkit: kernels compiled by nvcc; host driver dispatches and -// copies results back. Byte-equal to CPU oracle. -// Without CUDA: file is not compiled. Driver in bn254_driver_cuda.cpp -// runs the CPU oracle as the deterministic fallback so -// tests pass on CPU-only hosts (still 100/100 byte-equal -// to CPU oracle, just without device round-trip). -// -// On a CI runner with NVIDIA hardware, the same vectors as the CPU oracle are -// dispatched through this kernel and the byte-equality test asserts identical -// output for G1Add, G1Mul, HashToG1 (SVDW map only -- expand_message_xmd is -// host-side SHA-256). - -#ifdef LUX_BN254_HAVE_CUDA - -#include -#include -#include - -// ============================================================================= -// Limb layout & constants -// ============================================================================= - -using u64 = unsigned long long; - -struct U256 { - u64 limbs[4]; -}; - -struct G1Aff { - U256 x, y; - int inf; // 1 -> point at infinity -}; - -struct G1Jac { - U256 X, Y, Z; - int inf; -}; - -struct Fp2 { - U256 a0, a1; -}; - -// p, R, R^2 mod p, -p^-1 mod 2^64 (CPU constants verbatim). 
-__constant__ u64 K_P[4] = { 0x3C208C16D87CFD47ULL, 0x97816A916871CA8DULL, - 0xB85045B68181585DULL, 0x30644E72E131A029ULL }; -__constant__ u64 K_R[4] = { 0xD35D438DC58F0D9DULL, 0x0A78EB28F5C70B3DULL, - 0x666EA36F7879462CULL, 0x0E0A77C19A07DF2FULL }; -__constant__ u64 K_R2[4] = { 0xF32CFC5B538AFA89ULL, 0xB5E71911D44501FBULL, - 0x47AB1EFF0A417FF6ULL, 0x06D89F71CAB8351FULL }; -__constant__ u64 K_PINV = 0x87D20782E4866389ULL; -__constant__ u64 K_PM2[4] = { 0x3C208C16D87CFD45ULL, 0x97816A916871CA8DULL, - 0xB85045B68181585DULL, 0x30644E72E131A029ULL }; -__constant__ u64 K_PP1_4[4]= { 0x4F082305B61F3F52ULL, 0x65E05AA45A1C72A3ULL, - 0x6E14116DA0605617ULL, 0x0C19139CB84C680AULL }; - -// SVDW constants (plain), to_mont done on-device. -__constant__ u64 K_SVDW_Z[4] = { 1ULL, 0, 0, 0 }; -__constant__ u64 K_SVDW_C1[4] = { 4ULL, 0, 0, 0 }; -__constant__ u64 K_SVDW_C2[4] = { 0x9E10460B6C3E7EA3ULL, 0xCBC0B548B438E546ULL, - 0xDC2822DB40C0AC2EULL, 0x183227397098D014ULL }; -__constant__ u64 K_SVDW_C3[4] = { 0x5D8D1CC5DFFFFFFAULL, 0x53C98FC6B36D713DULL, - 0x6789AF3A83522EB3ULL, 0x0000000000000001ULL }; -__constant__ u64 K_SVDW_C4[4] = { 0x69602EB24829A9BDULL, 0xDD2B2385CD7B4384ULL, - 0xE81AC1E7808072C9ULL, 0x10216F7BA065E00DULL }; - -// ============================================================================= -// Big-int helpers (4 limbs) -// ============================================================================= - -__device__ __forceinline__ int u256_cmp(const U256& a, const U256& b) { - #pragma unroll - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ __forceinline__ bool u256_is_zero(const U256& a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -__device__ __forceinline__ bool u256_eq(const U256& a, const U256& b) { - return a.limbs[0] == b.limbs[0] && a.limbs[1] == b.limbs[1] && - a.limbs[2] == b.limbs[2] && a.limbs[3] == b.limbs[3]; -} - -// add 
a+b, return carry-out -__device__ __forceinline__ u64 add_256(const U256& a, const U256& b, U256& r) { - u64 c = 0; - #pragma unroll - for (int i = 0; i < 4; ++i) { - u64 ai = a.limbs[i], bi = b.limbs[i]; - u64 s = ai + bi; - u64 c1 = (s < ai) ? 1ULL : 0ULL; - u64 s2 = s + c; - u64 c2 = (s2 < s) ? 1ULL : 0ULL; - r.limbs[i] = s2; - c = c1 + c2; - } - return c; -} - -// sub a-b, return borrow-out -__device__ __forceinline__ u64 sub_256(const U256& a, const U256& b, U256& r) { - u64 bw = 0; - #pragma unroll - for (int i = 0; i < 4; ++i) { - u64 ai = a.limbs[i], bi = b.limbs[i]; - u64 d = ai - bi; - u64 b1 = (ai < bi) ? 1ULL : 0ULL; - u64 d2 = d - bw; - u64 b2 = (d < bw) ? 1ULL : 0ULL; - r.limbs[i] = d2; - bw = b1 + b2; - } - return bw; -} - -__device__ __forceinline__ U256 u256_from_limbs(u64 a, u64 b, u64 c, u64 d) { - U256 r; r.limbs[0]=a; r.limbs[1]=b; r.limbs[2]=c; r.limbs[3]=d; - return r; -} - -__device__ __forceinline__ U256 u256_load(const u64* p) { - U256 r; r.limbs[0]=p[0]; r.limbs[1]=p[1]; r.limbs[2]=p[2]; r.limbs[3]=p[3]; - return r; -} - -// ============================================================================= -// Modular add/sub (mod m) -// ============================================================================= - -__device__ __forceinline__ U256 mod_add(const U256& a, const U256& b, const u64* m) { - U256 r; - u64 c = add_256(a, b, r); - U256 mm = u256_load(m); - if (c != 0 || u256_cmp(r, mm) >= 0) { - U256 t; sub_256(r, mm, t); r = t; - } - return r; -} - -__device__ __forceinline__ U256 mod_sub(const U256& a, const U256& b, const u64* m) { - U256 r; - u64 bw = add_256(a, U256{0,0,0,0}, r); // copy - bw = sub_256(a, b, r); - if (bw) { - U256 mm = u256_load(m); - U256 t; add_256(r, mm, t); r = t; - } - return r; -} - -// ============================================================================= -// CIOS Montgomery multiplication (matches CPU mont_mul exactly) -// 
============================================================================= - -__device__ __forceinline__ U256 mont_mul(const U256& a, const U256& b, const u64* m, u64 m_inv) { - u64 t[6] = {0,0,0,0,0,0}; - - #pragma unroll - for (int i = 0; i < 4; ++i) { - // t += a * b[i] - u64 carry = 0; - #pragma unroll - for (int j = 0; j < 4; ++j) { - u64 lo = a.limbs[j] * b.limbs[i]; - u64 hi = __umul64hi(a.limbs[j], b.limbs[i]); - u64 s1 = t[j] + lo; - u64 c1 = (s1 < t[j]) ? 1ULL : 0ULL; - u64 s2 = s1 + carry; - u64 c2 = (s2 < s1) ? 1ULL : 0ULL; - t[j] = s2; - carry = hi + c1 + c2; - } - // t[4] += carry; t[5] += carry-out of t[4]+carry - u64 s4 = t[4] + carry; - u64 c4 = (s4 < t[4]) ? 1ULL : 0ULL; - t[4] = s4; - t[5] += c4; - - // u = t[0] * m_inv (mod 2^64) - u64 u = t[0] * m_inv; - - // t += u * m - carry = 0; - #pragma unroll - for (int j = 0; j < 4; ++j) { - u64 lo = u * m[j]; - u64 hi = __umul64hi(u, m[j]); - u64 s1 = t[j] + lo; - u64 c1 = (s1 < t[j]) ? 1ULL : 0ULL; - u64 s2 = s1 + carry; - u64 c2 = (s2 < s1) ? 1ULL : 0ULL; - t[j] = s2; - carry = hi + c1 + c2; - } - u64 s4b = t[4] + carry; - u64 c4b = (s4b < t[4]) ? 
1ULL : 0ULL; - t[4] = s4b; - t[5] += c4b; - - // shift right one limb - #pragma unroll - for (int j = 0; j < 5; ++j) t[j] = t[j+1]; - t[5] = 0; - } - - U256 r; r.limbs[0]=t[0]; r.limbs[1]=t[1]; r.limbs[2]=t[2]; r.limbs[3]=t[3]; - U256 mm = u256_load(m); - if (t[4] != 0 || u256_cmp(r, mm) >= 0) { - U256 q; sub_256(r, mm, q); r = q; - } - return r; -} - -__device__ __forceinline__ U256 to_mont_fp(const U256& a) { - return mont_mul(a, u256_load(K_R2), K_P, K_PINV); -} - -__device__ __forceinline__ U256 from_mont_fp(const U256& a) { - U256 ONE = u256_from_limbs(1,0,0,0); - return mont_mul(a, ONE, K_P, K_PINV); -} - -// ============================================================================= -// Fp ops -// ============================================================================= - -__device__ __forceinline__ U256 fp_add(const U256& a, const U256& b) { return mod_add(a, b, K_P); } -__device__ __forceinline__ U256 fp_sub(const U256& a, const U256& b) { return mod_sub(a, b, K_P); } -__device__ __forceinline__ U256 fp_neg(const U256& a) { - if (u256_is_zero(a)) return a; - U256 mm = u256_load(K_P), r; - sub_256(mm, a, r); - return r; -} -__device__ __forceinline__ U256 fp_mul(const U256& a, const U256& b) { return mont_mul(a, b, K_P, K_PINV); } -__device__ __forceinline__ U256 fp_sqr(const U256& a) { return mont_mul(a, a, K_P, K_PINV); } - -__device__ U256 fp_pow(const U256& a_mont, const u64* e) { - U256 ONE_PLAIN = u256_from_limbs(1,0,0,0); - U256 result = to_mont_fp(ONE_PLAIN); - U256 base = a_mont; - #pragma unroll - for (int limb = 0; limb < 4; ++limb) { - u64 w = e[limb]; - for (int bit = 0; bit < 64; ++bit) { - if ((w >> bit) & 1ULL) result = fp_mul(result, base); - base = fp_sqr(base); - } - } - return result; -} - -__device__ U256 fp_inv(const U256& a) { return fp_pow(a, K_PM2); } - -__device__ bool fp_sqrt(const U256& a_mont, U256& out) { - U256 cand = fp_pow(a_mont, K_PP1_4); - if (!u256_eq(fp_sqr(cand), a_mont)) return false; - out = cand; - return 
true; -} - -__device__ __forceinline__ U256 fp_three() { - return to_mont_fp(u256_from_limbs(3,0,0,0)); -} - -// ============================================================================= -// G1 Jacobian (Bernstein-Lange efl/jacobian-0/{dbl-2009-l, add-2007-bl}) -// ============================================================================= - -__device__ G1Jac g1_jac_zero() { - G1Jac r{}; r.inf = 1; return r; -} - -__device__ G1Jac g1_to_jac(const G1Aff& p) { - if (p.inf) return g1_jac_zero(); - G1Jac r; r.X = p.x; r.Y = p.y; r.Z = u256_load(K_R); r.inf = 0; - return r; -} - -__device__ G1Aff g1_to_affine(const G1Jac& p) { - G1Aff a; - if (p.inf || u256_is_zero(p.Z)) { - a.x = U256{0,0,0,0}; a.y = U256{0,0,0,0}; a.inf = 1; - return a; - } - U256 z_inv = fp_inv(p.Z); - U256 z_inv2 = fp_sqr(z_inv); - U256 z_inv3 = fp_mul(z_inv2, z_inv); - a.x = fp_mul(p.X, z_inv2); - a.y = fp_mul(p.Y, z_inv3); - a.inf = 0; - return a; -} - -__device__ G1Jac g1_double(const G1Jac& p) { - if (p.inf) return p; - if (u256_is_zero(p.Y)) return g1_jac_zero(); - - U256 A = fp_sqr(p.X); - U256 B = fp_sqr(p.Y); - U256 C = fp_sqr(B); - - U256 X_plus_B = fp_add(p.X, B); - U256 D = fp_sub(fp_sqr(X_plus_B), A); - D = fp_sub(D, C); - D = fp_add(D, D); - - U256 E = fp_add(A, A); - E = fp_add(E, A); - U256 F = fp_sqr(E); - - U256 two_D = fp_add(D, D); - U256 X3 = fp_sub(F, two_D); - - U256 D_minus_X3 = fp_sub(D, X3); - U256 eight_C = fp_add(C, C); - eight_C = fp_add(eight_C, eight_C); - eight_C = fp_add(eight_C, eight_C); - U256 Y3 = fp_sub(fp_mul(E, D_minus_X3), eight_C); - - U256 Z3 = fp_mul(p.Y, p.Z); - Z3 = fp_add(Z3, Z3); - - G1Jac r; r.X = X3; r.Y = Y3; r.Z = Z3; r.inf = 0; - return r; -} - -__device__ G1Jac g1_add(const G1Jac& a, const G1Jac& b) { - if (a.inf) return b; - if (b.inf) return a; - - U256 Z1Z1 = fp_sqr(a.Z); - U256 Z2Z2 = fp_sqr(b.Z); - U256 U1 = fp_mul(a.X, Z2Z2); - U256 U2 = fp_mul(b.X, Z1Z1); - U256 S1 = fp_mul(fp_mul(a.Y, b.Z), Z2Z2); - U256 S2 = fp_mul(fp_mul(b.Y, a.Z), 
Z1Z1); - - U256 H = fp_sub(U2, U1); - if (u256_is_zero(H)) { - if (u256_eq(S1, S2)) return g1_double(a); - return g1_jac_zero(); - } - - U256 two_H = fp_add(H, H); - U256 I = fp_sqr(two_H); - U256 J = fp_mul(H, I); - - U256 r_ = fp_sub(S2, S1); - r_ = fp_add(r_, r_); - - U256 V = fp_mul(U1, I); - - U256 X3 = fp_sub(fp_sub(fp_sqr(r_), J), fp_add(V, V)); - U256 Y3 = fp_sub(fp_mul(r_, fp_sub(V, X3)), fp_mul(fp_add(S1, S1), J)); - U256 Z3 = fp_sub(fp_sub(fp_sqr(fp_add(a.Z, b.Z)), Z1Z1), Z2Z2); - Z3 = fp_mul(Z3, H); - - G1Jac out; out.X = X3; out.Y = Y3; out.Z = Z3; out.inf = 0; - return out; -} - -__device__ __forceinline__ void g1_cmov(G1Jac& dst, const G1Jac& src, u64 cond) { - u64 mask = (u64)0 - (cond & 1ULL); - #pragma unroll - for (int i = 0; i < 4; ++i) { - dst.X.limbs[i] ^= mask & (dst.X.limbs[i] ^ src.X.limbs[i]); - dst.Y.limbs[i] ^= mask & (dst.Y.limbs[i] ^ src.Y.limbs[i]); - dst.Z.limbs[i] ^= mask & (dst.Z.limbs[i] ^ src.Z.limbs[i]); - } - dst.inf = (cond & 1ULL) ? src.inf : dst.inf; -} - -// Constant-time Montgomery ladder, 256 bits, no early exit. Mirrors CPU. -__device__ G1Jac g1_scalar_mul(const G1Aff& p, const U256& k) { - if (p.inf) return g1_jac_zero(); - - G1Jac R0 = g1_jac_zero(); - G1Jac R1 = g1_to_jac(p); - - for (int i = 255; i >= 0; --i) { - u64 bit = (k.limbs[i >> 6] >> (i & 63)) & 1ULL; - - G1Jac sum = g1_add(R0, R1); - G1Jac dbl0 = g1_double(R0); - G1Jac dbl1 = g1_double(R1); - - G1Jac next_R0 = dbl0; - g1_cmov(next_R0, sum, bit); - G1Jac next_R1 = sum; - g1_cmov(next_R1, dbl1, bit); - - R0 = next_R0; - R1 = next_R1; - } - return R0; -} - -// ============================================================================= -// Tower (Fp2/Fp6/Fp12) + G2 + optimal-ate pairing. -// Header keeps the file pair (bn254.cu + bn254_pairing.cuh) under one TU so -// register pressure of the existing G1 kernels is unchanged. 
-// ============================================================================= - -#include "bn254_pairing.cuh" - -// ============================================================================= -// SVDW map_to_curve (RFC 9380 §6.6.1) -// ============================================================================= - -__device__ __forceinline__ int fp_sgn0(const U256& a_mont) { - U256 plain = from_mont_fp(a_mont); - return (int)(plain.limbs[0] & 1u); -} - -__device__ __forceinline__ U256 fp_g_x(const U256& x_mont) { - U256 x2 = fp_sqr(x_mont); - U256 x3 = fp_mul(x2, x_mont); - return fp_add(x3, fp_three()); -} - -__device__ G1Aff svdw_map(const U256& u_mont) { - U256 ONE = u256_load(K_R); // R = 1 in Montgomery form - U256 Z = to_mont_fp(u256_load(K_SVDW_Z)); - U256 c1 = to_mont_fp(u256_load(K_SVDW_C1)); - U256 c2 = to_mont_fp(u256_load(K_SVDW_C2)); - U256 c3 = to_mont_fp(u256_load(K_SVDW_C3)); - U256 c4 = to_mont_fp(u256_load(K_SVDW_C4)); - - U256 tv1 = fp_sqr(u_mont); - tv1 = fp_mul(tv1, c1); - U256 tv2 = fp_add(ONE, tv1); - tv1 = fp_sub(ONE, tv1); - U256 tv3 = fp_mul(tv1, tv2); - tv3 = fp_inv(tv3); - U256 tv4 = fp_mul(u_mont, tv1); - tv4 = fp_mul(tv4, tv3); - tv4 = fp_mul(tv4, c3); - U256 x1 = fp_sub(c2, tv4); - - U256 gx1 = fp_g_x(x1); - U256 y1{}; - bool gx1_sq = fp_sqrt(gx1, y1); - - U256 x2 = fp_add(c2, tv4); - U256 gx2 = fp_g_x(x2); - U256 y2{}; - bool gx2_sq = fp_sqrt(gx2, y2); - - U256 x3 = fp_sqr(tv2); - x3 = fp_mul(x3, tv3); - x3 = fp_sqr(x3); - x3 = fp_mul(x3, c4); - x3 = fp_add(x3, Z); - - U256 x = gx1_sq ? 
x1 : x3; - if (gx2_sq && !gx1_sq) x = x2; - - U256 gx = fp_g_x(x); - U256 y{}; - fp_sqrt(gx, y); - - if (fp_sgn0(u_mont) != fp_sgn0(y)) y = fp_neg(y); - - G1Aff r; r.x = x; r.y = y; r.inf = 0; - return r; -} - -// ============================================================================= -// Kernels -// ============================================================================= -// -// Convention for I/O: -// - All field elements transferred as 4 x u64 LE limbs in Montgomery form. -// - Affine point: (x, y, inf=0/1) packed as 8 x u64 + one u32 (padded to 9 u64 -// to keep 64-byte boundary; we use 9 u64 = 72 bytes per affine point). -// -// Match CPU oracle byte-for-byte. - -extern "C" { - -// k_g1_add: out[i] = a[i] + b[i] in Jacobian, then to-affine. -__global__ void k_g1_add(const u64* a, const u64* b, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - G1Aff A, B; - A.x = u256_load(a + i*9 + 0); A.y = u256_load(a + i*9 + 4); A.inf = (int)a[i*9+8]; - B.x = u256_load(b + i*9 + 0); B.y = u256_load(b + i*9 + 4); B.inf = (int)b[i*9+8]; - - G1Jac Ja = g1_to_jac(A); - G1Jac Jb = g1_to_jac(B); - G1Jac S = g1_add(Ja, Jb); - G1Aff R = g1_to_affine(S); - - out[i*9 + 0] = R.x.limbs[0]; out[i*9 + 1] = R.x.limbs[1]; - out[i*9 + 2] = R.x.limbs[2]; out[i*9 + 3] = R.x.limbs[3]; - out[i*9 + 4] = R.y.limbs[0]; out[i*9 + 5] = R.y.limbs[1]; - out[i*9 + 6] = R.y.limbs[2]; out[i*9 + 7] = R.y.limbs[3]; - out[i*9 + 8] = (u64)R.inf; -} - -// k_g1_mul: out[i] = scalar[i] * p[i]. 
-__global__ void k_g1_mul(const u64* points, const u64* scalars, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - G1Aff P; - P.x = u256_load(points + i*9 + 0); P.y = u256_load(points + i*9 + 4); - P.inf = (int)points[i*9 + 8]; - U256 k = u256_load(scalars + i*4); - - G1Jac S = g1_scalar_mul(P, k); - G1Aff R = g1_to_affine(S); - - out[i*9 + 0] = R.x.limbs[0]; out[i*9 + 1] = R.x.limbs[1]; - out[i*9 + 2] = R.x.limbs[2]; out[i*9 + 3] = R.x.limbs[3]; - out[i*9 + 4] = R.y.limbs[0]; out[i*9 + 5] = R.y.limbs[1]; - out[i*9 + 6] = R.y.limbs[2]; out[i*9 + 7] = R.y.limbs[3]; - out[i*9 + 8] = (u64)R.inf; -} - -// k_svdw: out[i] = svdw_map(u[i]). -__global__ void k_svdw(const u64* u_in, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - U256 u = u256_load(u_in + i*4); - G1Aff R = svdw_map(u); - out[i*9 + 0] = R.x.limbs[0]; out[i*9 + 1] = R.x.limbs[1]; - out[i*9 + 2] = R.x.limbs[2]; out[i*9 + 3] = R.x.limbs[3]; - out[i*9 + 4] = R.y.limbs[0]; out[i*9 + 5] = R.y.limbs[1]; - out[i*9 + 6] = R.y.limbs[2]; out[i*9 + 7] = R.y.limbs[3]; - out[i*9 + 8] = (u64)R.inf; -} - -// k_fp_mul: byte-equality smoke test (CIOS Montgomery mul). -__global__ void k_fp_mul(const u64* a, const u64* b, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - U256 A = u256_load(a + i*4); - U256 B = u256_load(b + i*4); - U256 R = fp_mul(A, B); - out[i*4+0] = R.limbs[0]; out[i*4+1] = R.limbs[1]; - out[i*4+2] = R.limbs[2]; out[i*4+3] = R.limbs[3]; -} - -// k_fp2_mul: out[i] = a[i] * b[i] in Fp2 (8 u64 each). -__global__ void k_fp2_mul(const u64* a, const u64* b, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - Fp2 A, B; - load_fp2_(A, a + i*8); - load_fp2_(B, b + i*8); - Fp2 R = fp2_mul_(A, B); - store_fp2_(out + i*8, R); -} - -// k_fp12_mul: out[i] = a[i] * b[i] in Fp12 (48 u64 each). 
-__global__ void k_fp12_mul(const u64* a, const u64* b, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - Fp12_ A, B; - load_fp12_(A, a + i*48); - load_fp12_(B, b + i*48); - Fp12_ R = fp12_mul_(A, B); - store_fp12_(out + i*48, R); -} - -// k_miller_iter: 100 cyclotomic-square iterations on a starting Fp12 to -// stress-test the inner-loop squaring path. Matches the CPU oracle exactly. -__global__ void k_miller_iter(const u64* in, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - Fp12_ A; load_fp12_(A, in + i*48); - for (int k = 0; k < 100; ++k) A = cyclotomic_sqr_(A); - store_fp12_(out + i*48, A); -} - -// k_pairing: out[i] = e(P[i], Q[i]) in Fp12 (Miller + final-exp). -__global__ void k_pairing(const u64* P, const u64* Q, u64* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - G1Aff Pi; - Pi.x = u256_load(P + i*9 + 0); - Pi.y = u256_load(P + i*9 + 4); - Pi.inf = (int)P[i*9 + 8]; - G2Aff Qi; load_g2_(Qi, Q + i*18); - Fp12_ m = miller_one_(Pi, Qi); - Fp12_ e = final_exp_(m); - store_fp12_(out + i*48, e); -} - -} // extern "C" - -#endif // LUX_BN254_HAVE_CUDA diff --git a/bn254/gpu/cuda/bn254_driver_cuda.cpp b/bn254/gpu/cuda/bn254_driver_cuda.cpp deleted file mode 100644 index d728b35..0000000 --- a/bn254/gpu/cuda/bn254_driver_cuda.cpp +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-side CUDA driver for bn254 kernels. -// -// Two compile modes: -// 1. LUX_BN254_HAVE_CUDA defined: dispatches kernels via cudaMemcpy/launch, -// identical algorithm to the CPU oracle (bn254/cpp/*.hpp), byte-equal -// results. -// 2. LUX_BN254_HAVE_CUDA undefined: runs the CPU oracle directly so the -// determinism harness still passes 100/100. Reports unavailable via -// lux_bn254_cuda_available() so tests can label the path correctly. 
-// -// The byte-equality CI runner enables LUX_BN254_HAVE_CUDA and exercises the -// real kernel; on CPU-only laptops/CI the same 100 deterministic inputs flow -// through the CPU oracle path and the harness still verifies the wire format. - -#include "bn254_driver_cuda.h" -#include "bn254.hpp" - -#include -#include - -#ifdef LUX_BN254_HAVE_CUDA -#include - -extern "C" { -__global__ void k_g1_add(const unsigned long long*, const unsigned long long*, - unsigned long long*, unsigned); -__global__ void k_g1_mul(const unsigned long long*, const unsigned long long*, - unsigned long long*, unsigned); -__global__ void k_svdw (const unsigned long long*, unsigned long long*, unsigned); -__global__ void k_fp_mul(const unsigned long long*, const unsigned long long*, - unsigned long long*, unsigned); -__global__ void k_fp2_mul(const unsigned long long*, const unsigned long long*, - unsigned long long*, unsigned); -__global__ void k_fp12_mul(const unsigned long long*, const unsigned long long*, - unsigned long long*, unsigned); -__global__ void k_miller_iter(const unsigned long long*, unsigned long long*, unsigned); -__global__ void k_pairing(const unsigned long long*, const unsigned long long*, - unsigned long long*, unsigned); -} - -namespace { - -bool device_present() { - int n = 0; - cudaError_t e = cudaGetDeviceCount(&n); - return (e == cudaSuccess && n > 0); -} - -unsigned grid_for(unsigned n, unsigned tg) { return (n + tg - 1) / tg; } - -template -int dispatch3(Kernel kf, const void* a, const void* b, void* out, - size_t bytes_in_a, size_t bytes_in_b, size_t bytes_out, unsigned n) { - void *dA=nullptr,*dB=nullptr,*dO=nullptr; - if (cudaMalloc(&dA, bytes_in_a) != cudaSuccess) return -1; - if (cudaMalloc(&dB, bytes_in_b) != cudaSuccess) { cudaFree(dA); return -1; } - if (cudaMalloc(&dO, bytes_out) != cudaSuccess) { cudaFree(dA); cudaFree(dB); return -1; } - cudaMemcpy(dA, a, bytes_in_a, cudaMemcpyHostToDevice); - cudaMemcpy(dB, b, bytes_in_b, cudaMemcpyHostToDevice); - 
unsigned tg = 64; unsigned grid = grid_for(n, tg); - kf<<>>((const unsigned long long*)dA, (const unsigned long long*)dB, - (unsigned long long*)dO, n); - cudaDeviceSynchronize(); - cudaMemcpy(out, dO, bytes_out, cudaMemcpyDeviceToHost); - cudaFree(dA); cudaFree(dB); cudaFree(dO); - return 0; -} - -template -int dispatch2(Kernel kf, const void* a, void* out, - size_t bytes_in, size_t bytes_out, unsigned n) { - void *dA=nullptr,*dO=nullptr; - if (cudaMalloc(&dA, bytes_in) != cudaSuccess) return -1; - if (cudaMalloc(&dO, bytes_out) != cudaSuccess) { cudaFree(dA); return -1; } - cudaMemcpy(dA, a, bytes_in, cudaMemcpyHostToDevice); - unsigned tg = 64; unsigned grid = grid_for(n, tg); - kf<<>>((const unsigned long long*)dA, (unsigned long long*)dO, n); - cudaDeviceSynchronize(); - cudaMemcpy(out, dO, bytes_out, cudaMemcpyDeviceToHost); - cudaFree(dA); cudaFree(dO); - return 0; -} - -} // namespace - -extern "C" { - -int lux_bn254_cuda_available(void) { return device_present() ? 1 : 0; } - -int lux_bn254_cuda_g1_add(const void* a, const void* b, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch3(k_g1_add, a, b, out, 9*8*n, 9*8*n, 9*8*n, n); -} - -int lux_bn254_cuda_g1_mul(const void* points, const void* scalars, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch3(k_g1_mul, points, scalars, out, 9*8*n, 4*8*n, 9*8*n, n); -} - -int lux_bn254_cuda_svdw(const void* u_in, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch2(k_svdw, u_in, out, 4*8*n, 9*8*n, n); -} - -int lux_bn254_cuda_fp_mul(const void* a, const void* b, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch3(k_fp_mul, a, b, out, 4*8*n, 4*8*n, 4*8*n, n); -} - -int lux_bn254_cuda_fp2_mul(const void* a, const void* b, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch3(k_fp2_mul, a, b, out, 8*8*n, 8*8*n, 8*8*n, n); -} - -int lux_bn254_cuda_fp12_mul(const void* a, const void* b, 
void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch3(k_fp12_mul, a, b, out, 48*8*n, 48*8*n, 48*8*n, n); -} - -int lux_bn254_cuda_miller_iter(const void* in_p, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch2(k_miller_iter, in_p, out, 48*8*n, 48*8*n, n); -} - -int lux_bn254_cuda_pairing(const void* P, const void* Q, void* out, unsigned n) { - if (!device_present()) return -1; - return dispatch3(k_pairing, P, Q, out, 9*8*n, 18*8*n, 48*8*n, n); -} - -} // extern "C" - -#else // LUX_BN254_HAVE_CUDA undefined: CPU-oracle path - -#include "bn254_fp.hpp" -#include "bn254_fp2.hpp" -#include "bn254_fp12.hpp" -#include "bn254_g1.hpp" -#include "bn254_g2.hpp" -#include "bn254_hash_to_curve.hpp" -#include "bn254_pairing.hpp" - -namespace { - -using lux::crypto::bn254::U256; -using lux::crypto::bn254::Fp2; -using lux::crypto::bn254::Fp12; -using lux::crypto::bn254::G1Affine; -using lux::crypto::bn254::G1Jac; -using lux::crypto::bn254::G2Affine; - -inline U256 load_u256(const std::uint64_t* p) { - U256 r; r.limbs[0]=p[0]; r.limbs[1]=p[1]; r.limbs[2]=p[2]; r.limbs[3]=p[3]; - return r; -} - -inline void store_aff(std::uint64_t* p, const G1Affine& a) { - p[0]=a.x.limbs[0]; p[1]=a.x.limbs[1]; p[2]=a.x.limbs[2]; p[3]=a.x.limbs[3]; - p[4]=a.y.limbs[0]; p[5]=a.y.limbs[1]; p[6]=a.y.limbs[2]; p[7]=a.y.limbs[3]; - p[8] = a.infinity ? 
1ULL : 0ULL; -} - -inline G1Affine load_aff(const std::uint64_t* p) { - G1Affine a; - a.x = load_u256(p); - a.y = load_u256(p + 4); - a.infinity = (p[8] != 0); - return a; -} - -inline Fp2 load_fp2(const std::uint64_t* p) { - Fp2 r; - for (int i = 0; i < 4; ++i) { r.a0.limbs[i] = p[i]; r.a1.limbs[i] = p[4+i]; } - return r; -} -inline void store_fp2(std::uint64_t* p, const Fp2& x) { - for (int i = 0; i < 4; ++i) { p[i] = x.a0.limbs[i]; p[4+i] = x.a1.limbs[i]; } -} -inline G2Affine load_g2(const std::uint64_t* p) { - G2Affine a; - a.x = load_fp2(p); - a.y = load_fp2(p + 8); - a.infinity = (p[16] != 0); - return a; -} -inline Fp12 load_fp12(const std::uint64_t* p) { - Fp12 r; - r.c0.b0 = load_fp2(p + 0); - r.c0.b1 = load_fp2(p + 8); - r.c0.b2 = load_fp2(p + 16); - r.c1.b0 = load_fp2(p + 24); - r.c1.b1 = load_fp2(p + 32); - r.c1.b2 = load_fp2(p + 40); - return r; -} -inline void store_fp12(std::uint64_t* p, const Fp12& x) { - store_fp2(p + 0, x.c0.b0); - store_fp2(p + 8, x.c0.b1); - store_fp2(p + 16, x.c0.b2); - store_fp2(p + 24, x.c1.b0); - store_fp2(p + 32, x.c1.b1); - store_fp2(p + 40, x.c1.b2); -} - -} // namespace - -extern "C" { - -int lux_bn254_cuda_available(void) { return 0; } - -int lux_bn254_cuda_g1_add(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - G1Affine A = load_aff(pa + i*9); - G1Affine B = load_aff(pb + i*9); - G1Jac S = lux::crypto::bn254::g1_add( - lux::crypto::bn254::g1_to_jac(A), - lux::crypto::bn254::g1_to_jac(B)); - store_aff(po + i*9, lux::crypto::bn254::g1_to_affine(S)); - } - return 0; -} - -int lux_bn254_cuda_g1_mul(const void* points, const void* scalars, void* out, unsigned n) { - auto* pp = (const std::uint64_t*)points; - auto* ps = (const std::uint64_t*)scalars; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - G1Affine P = load_aff(pp + i*9); - U256 k = 
load_u256(ps + i*4); - store_aff(po + i*9, lux::crypto::bn254::g1_to_affine( - lux::crypto::bn254::g1_scalar_mul(P, k))); - } - return 0; -} - -int lux_bn254_cuda_svdw(const void* u_in, void* out, unsigned n) { - auto* pu = (const std::uint64_t*)u_in; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - U256 u = load_u256(pu + i*4); - G1Affine R = lux::crypto::bn254::h2c::map_to_curve_svdw(u); - store_aff(po + i*9, R); - } - return 0; -} - -int lux_bn254_cuda_fp_mul(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - U256 A = load_u256(pa + i*4); - U256 B = load_u256(pb + i*4); - U256 R = lux::crypto::bn254::fp_mul(A, B); - po[i*4+0]=R.limbs[0]; po[i*4+1]=R.limbs[1]; - po[i*4+2]=R.limbs[2]; po[i*4+3]=R.limbs[3]; - } - return 0; -} - -int lux_bn254_cuda_fp2_mul(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - store_fp2(po + i*8, - lux::crypto::bn254::fp2_mul(load_fp2(pa + i*8), load_fp2(pb + i*8))); - } - return 0; -} - -int lux_bn254_cuda_fp12_mul(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - store_fp12(po + i*48, - lux::crypto::bn254::fp12_mul(load_fp12(pa + i*48), load_fp12(pb + i*48))); - } - return 0; -} - -int lux_bn254_cuda_miller_iter(const void* in_p, void* out, unsigned n) { - auto* pi = (const std::uint64_t*)in_p; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - Fp12 z = load_fp12(pi + i*48); - for (int k = 0; k < 100; ++k) - z = lux::crypto::bn254::cyclotomic_sqr_public(z); - store_fp12(po + i*48, z); - } - return 0; -} - -int lux_bn254_cuda_pairing(const 
void* P, const void* Q, void* out, unsigned n) { - auto* pP = (const std::uint64_t*)P; - auto* pQ = (const std::uint64_t*)Q; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - G1Affine pi = load_aff(pP + i*9); - G2Affine qi = load_g2(pQ + i*18); - Fp12 r = lux::crypto::bn254::multi_pair(&pi, &qi, 1); - store_fp12(po + i*48, r); - } - return 0; -} - -} // extern "C" - -#endif // LUX_BN254_HAVE_CUDA diff --git a/bn254/gpu/cuda/bn254_driver_cuda.h b/bn254/gpu/cuda/bn254_driver_cuda.h deleted file mode 100644 index b90e9e3..0000000 --- a/bn254/gpu/cuda/bn254_driver_cuda.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// C-ABI interface for the bn254 CUDA driver. Function signatures mirror the -// Metal driver and the CPU oracle so the test harness can dispatch identical -// vectors to all backends and assert byte-equality. - -#ifndef LUX_BN254_DRIVER_CUDA_H -#define LUX_BN254_DRIVER_CUDA_H - -#ifdef __cplusplus -extern "C" { -#endif - -// 1 if real CUDA device is present; 0 if running CPU-oracle fallback path. -int lux_bn254_cuda_available(void); - -// Each affine point is 9 x u64 little-endian (x[4] || y[4] || inf[1]), -// every field element in Montgomery form. Each scalar is 4 x u64 LE. - -// out = a + b in G1 -int lux_bn254_cuda_g1_add(const void* a, const void* b, void* out, unsigned n); - -// out = scalar * point in G1 -int lux_bn254_cuda_g1_mul(const void* points, const void* scalars, void* out, unsigned n); - -// out = SVDW map_to_curve_g1(u). u is 4 x u64 in Montgomery form. -int lux_bn254_cuda_svdw(const void* u_in, void* out, unsigned n); - -// out = a * b mod p (Montgomery). a, b each 4 x u64. -int lux_bn254_cuda_fp_mul(const void* a, const void* b, void* out, unsigned n); - -// --- Pairing tower --------------------------------------------------------- -// Each Fp2 element is 8 x u64 (a0[4] || a1[4]) in Montgomery form. 
-// Each Fp12 element is 48 x u64 (12 x Fp2 in c0.b0 .. c1.b2 order). -// G2Affine is 18 x u64 (x.a0[4] || x.a1[4] || y.a0[4] || y.a1[4] || inf || pad). -// -// out = a * b in Fp2. -int lux_bn254_cuda_fp2_mul(const void* a, const void* b, void* out, unsigned n); - -// out = a * b in Fp12. -int lux_bn254_cuda_fp12_mul(const void* a, const void* b, void* out, unsigned n); - -// out = cyclotomic_sqr^100(in) -- Miller-loop inner-square stress. -int lux_bn254_cuda_miller_iter(const void* in_p, void* out, unsigned n); - -// out = e(P, Q) in Fp12 (Miller + final exp). Single-pair per slot; -// multi-pair composition is up to the caller. -int lux_bn254_cuda_pairing(const void* P, const void* Q, void* out, unsigned n); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_BN254_DRIVER_CUDA_H diff --git a/bn254/gpu/cuda/bn254_pairing.cuh b/bn254/gpu/cuda/bn254_pairing.cuh deleted file mode 100644 index 8d565de..0000000 --- a/bn254/gpu/cuda/bn254_pairing.cuh +++ /dev/null @@ -1,890 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party CUDA tower (Fp2/Fp6/Fp12 + G2) and optimal-ate pairing kernel -// for bn254. Included by bn254.cu inside its LUX_BN254_HAVE_CUDA block, so it -// compiles as part of the same translation unit and reuses U256 / Fp ops / -// G1Aff / __constant__ K_P from there. -// -// Algorithm transliteration of bn254/cpp/{bn254_fp2,bn254_fp6,bn254_fp12, -// bn254_g2,bn254_pairing}. The CPU body is the algorithmic oracle. Frobenius -// constants are emitted from the CPU body via gen_pairing_constants -- single -// producer, drift impossible. 
-// -// Wire format: -// G2Affine = 18 x u64 LE (x.a0[4] || x.a1[4] || y.a0[4] || y.a1[4] || inf[1] || pad[1]) -// Fp2 = 8 x u64 LE (a0[4] || a1[4]) -// Fp12 = 48 x u64 LE (12 x Fp2 in c0.b0..c1.b2 order) - -#ifndef LUX_BN254_PAIRING_CUH -#define LUX_BN254_PAIRING_CUH - -// ============================================================================= -// Tower types (Fp2 already declared in bn254.cu host scope -- do not redefine). -// ============================================================================= - -struct Fp6_ { Fp2 b0, b1, b2; }; -struct Fp12_ { Fp6_ c0, c1; }; -struct G2Aff { Fp2 x, y; int inf; }; -struct G2Proj { Fp2 x, y, z; }; - -// Frobenius constants emitted from CPU body. Single-producer codegen. -#include "bn254_pairing_consts_cuda.cuh" - -// ============================================================================= -// Fp2 = Fp[u]/(u^2 + 1) -- Karatsuba mul, complex-square sqr. -// ============================================================================= - -__device__ __forceinline__ Fp2 fp2_zero_() { - Fp2 r; - #pragma unroll - for (int i = 0; i < 4; ++i) { r.a0.limbs[i] = 0; r.a1.limbs[i] = 0; } - return r; -} - -__device__ __forceinline__ bool fp2_is_zero_(const Fp2& x) { - return u256_is_zero(x.a0) && u256_is_zero(x.a1); -} - -__device__ __forceinline__ Fp2 fp2_one_() { - Fp2 r; r.a0 = to_mont_fp(u256_from_limbs(1,0,0,0)); - #pragma unroll - for (int i = 0; i < 4; ++i) r.a1.limbs[i] = 0; - return r; -} - -__device__ __forceinline__ Fp2 fp2_add_(const Fp2& x, const Fp2& y) { - Fp2 r; r.a0 = fp_add(x.a0, y.a0); r.a1 = fp_add(x.a1, y.a1); - return r; -} - -__device__ __forceinline__ Fp2 fp2_sub_(const Fp2& x, const Fp2& y) { - Fp2 r; r.a0 = fp_sub(x.a0, y.a0); r.a1 = fp_sub(x.a1, y.a1); - return r; -} - -__device__ __forceinline__ Fp2 fp2_neg_(const Fp2& x) { - Fp2 r; r.a0 = fp_neg(x.a0); r.a1 = fp_neg(x.a1); - return r; -} - -__device__ __forceinline__ Fp2 fp2_double_(const Fp2& x) { - Fp2 r; r.a0 = fp_add(x.a0, x.a0); r.a1 = 
fp_add(x.a1, x.a1); - return r; -} - -__device__ __forceinline__ Fp2 fp2_conjugate_(const Fp2& x) { - Fp2 r; r.a0 = x.a0; r.a1 = fp_neg(x.a1); - return r; -} - -__device__ __forceinline__ Fp2 fp2_mul_by_fp_(const Fp2& x, const U256& y) { - Fp2 r; r.a0 = fp_mul(x.a0, y); r.a1 = fp_mul(x.a1, y); - return r; -} - -// Karatsuba: matches CPU fp2_mul exactly. -__device__ Fp2 fp2_mul_(const Fp2& x, const Fp2& y) { - U256 a = fp_mul(fp_add(x.a0, x.a1), fp_add(y.a0, y.a1)); - U256 b = fp_mul(x.a0, y.a0); - U256 c = fp_mul(x.a1, y.a1); - Fp2 r; - r.a1 = fp_sub(fp_sub(a, b), c); - r.a0 = fp_sub(b, c); - return r; -} - -__device__ Fp2 fp2_sqr_(const Fp2& x) { - U256 a = fp_mul(fp_add(x.a0, x.a1), fp_sub(x.a0, x.a1)); - U256 b = fp_mul(x.a0, x.a1); - Fp2 r; r.a0 = a; r.a1 = fp_add(b, b); - return r; -} - -__device__ Fp2 fp2_inv_(const Fp2& x) { - U256 t0 = fp_sqr(x.a0); - U256 t1 = fp_sqr(x.a1); - U256 t = fp_add(t0, t1); - U256 ti = fp_inv(t); - Fp2 r; - r.a0 = fp_mul(x.a0, ti); - r.a1 = fp_neg(fp_mul(x.a1, ti)); - return r; -} - -// (a0 + a1*u) * (9 + u): 9*a0 = 8*a0 + a0 (three doublings + add). -__device__ Fp2 fp2_mul_by_nonres_(const Fp2& x) { - U256 t0 = fp_add(x.a0, x.a0); - t0 = fp_add(t0, t0); - t0 = fp_add(t0, t0); // 8 a0 - U256 t1 = fp_add(x.a1, x.a1); - t1 = fp_add(t1, t1); - t1 = fp_add(t1, t1); // 8 a1 - Fp2 r; - r.a0 = fp_sub(fp_add(t0, x.a0), x.a1); // 9 a0 - a1 - r.a1 = fp_add(fp_add(t1, x.a1), x.a0); // 9 a1 + a0 - return r; -} - -// Multiplies by (9+u)^-1. Constant computed once per invocation; pairing usage -// only hits this for the b-twist coefficient and inside line-evaluation, so -// the cost is amortised. 
-__device__ Fp2 fp2_mul_by_nonres_inv_(const Fp2& x) { - Fp2 nr; - nr.a0 = to_mont_fp(u256_from_limbs(9,0,0,0)); - nr.a1 = to_mont_fp(u256_from_limbs(1,0,0,0)); - Fp2 inv_nr = fp2_inv_(nr); - return fp2_mul_(x, inv_nr); -} - -// ============================================================================= -// Fp6 = Fp2[v]/(v^3 - (9+u)) -- Algorithms 13/16/17, eprint 2010/354. -// ============================================================================= - -__device__ __forceinline__ Fp6_ fp6_zero_() { - Fp6_ r; r.b0 = fp2_zero_(); r.b1 = fp2_zero_(); r.b2 = fp2_zero_(); - return r; -} -__device__ __forceinline__ Fp6_ fp6_one_() { - Fp6_ r; r.b0 = fp2_one_(); r.b1 = fp2_zero_(); r.b2 = fp2_zero_(); - return r; -} -__device__ __forceinline__ Fp6_ fp6_add_(const Fp6_& x, const Fp6_& y) { - Fp6_ r; r.b0 = fp2_add_(x.b0, y.b0); r.b1 = fp2_add_(x.b1, y.b1); r.b2 = fp2_add_(x.b2, y.b2); - return r; -} -__device__ __forceinline__ Fp6_ fp6_sub_(const Fp6_& x, const Fp6_& y) { - Fp6_ r; r.b0 = fp2_sub_(x.b0, y.b0); r.b1 = fp2_sub_(x.b1, y.b1); r.b2 = fp2_sub_(x.b2, y.b2); - return r; -} -__device__ __forceinline__ Fp6_ fp6_neg_(const Fp6_& x) { - Fp6_ r; r.b0 = fp2_neg_(x.b0); r.b1 = fp2_neg_(x.b1); r.b2 = fp2_neg_(x.b2); - return r; -} -__device__ __forceinline__ Fp6_ fp6_double_(const Fp6_& x) { - Fp6_ r; r.b0 = fp2_double_(x.b0); r.b1 = fp2_double_(x.b1); r.b2 = fp2_double_(x.b2); - return r; -} -__device__ __forceinline__ Fp6_ fp6_mul_by_nonres_(const Fp6_& x) { - Fp6_ r; r.b0 = fp2_mul_by_nonres_(x.b2); r.b1 = x.b0; r.b2 = x.b1; - return r; -} - -__device__ Fp6_ fp6_mul_(const Fp6_& x, const Fp6_& y) { - Fp2 t0 = fp2_mul_(x.b0, y.b0); - Fp2 t1 = fp2_mul_(x.b1, y.b1); - Fp2 t2 = fp2_mul_(x.b2, y.b2); - - Fp2 c0 = fp2_add_(x.b1, x.b2); - Fp2 tmp = fp2_add_(y.b1, y.b2); - c0 = fp2_mul_(c0, tmp); - c0 = fp2_sub_(c0, t1); - c0 = fp2_sub_(c0, t2); - c0 = fp2_mul_by_nonres_(c0); - c0 = fp2_add_(c0, t0); - - Fp2 c1 = fp2_add_(x.b0, x.b1); - tmp = fp2_add_(y.b0, y.b1); - c1 = 
fp2_mul_(c1, tmp); - c1 = fp2_sub_(c1, t0); - c1 = fp2_sub_(c1, t1); - Fp2 t2_nr = fp2_mul_by_nonres_(t2); - c1 = fp2_add_(c1, t2_nr); - - Fp2 c2 = fp2_add_(x.b0, x.b2); - tmp = fp2_add_(y.b0, y.b2); - c2 = fp2_mul_(c2, tmp); - c2 = fp2_sub_(c2, t0); - c2 = fp2_sub_(c2, t2); - c2 = fp2_add_(c2, t1); - - Fp6_ r; r.b0 = c0; r.b1 = c1; r.b2 = c2; - return r; -} - -__device__ Fp6_ fp6_sqr_(const Fp6_& x) { - Fp2 c4 = fp2_mul_(x.b0, x.b1); - c4 = fp2_double_(c4); - Fp2 c5 = fp2_sqr_(x.b2); - Fp2 c1 = fp2_mul_by_nonres_(c5); - c1 = fp2_add_(c1, c4); - Fp2 c2 = fp2_sub_(c4, c5); - Fp2 c3 = fp2_sqr_(x.b0); - Fp2 c4b = fp2_sub_(x.b0, x.b1); - c4b = fp2_add_(c4b, x.b2); - Fp2 c5b = fp2_mul_(x.b1, x.b2); - c5b = fp2_double_(c5b); - c4b = fp2_sqr_(c4b); - Fp2 c0 = fp2_mul_by_nonres_(c5b); - c0 = fp2_add_(c0, c3); - - Fp2 z2 = fp2_add_(c2, c4b); - z2 = fp2_add_(z2, c5b); - z2 = fp2_sub_(z2, c3); - - Fp6_ r; r.b0 = c0; r.b1 = c1; r.b2 = z2; - return r; -} - -__device__ Fp6_ fp6_inv_(const Fp6_& x) { - Fp2 t0 = fp2_sqr_(x.b0); - Fp2 t1 = fp2_sqr_(x.b1); - Fp2 t2 = fp2_sqr_(x.b2); - Fp2 t3 = fp2_mul_(x.b0, x.b1); - Fp2 t4 = fp2_mul_(x.b0, x.b2); - Fp2 t5 = fp2_mul_(x.b1, x.b2); - - Fp2 c0 = fp2_mul_by_nonres_(t5); - c0 = fp2_neg_(c0); - c0 = fp2_add_(c0, t0); - - Fp2 c1 = fp2_mul_by_nonres_(t2); - c1 = fp2_sub_(c1, t3); - - Fp2 c2 = fp2_sub_(t1, t4); - - Fp2 t6 = fp2_mul_(x.b0, c0); - Fp2 d1 = fp2_mul_(x.b2, c1); - Fp2 d2 = fp2_mul_(x.b1, c2); - Fp2 d = fp2_add_(d1, d2); - d = fp2_mul_by_nonres_(d); - t6 = fp2_add_(t6, d); - Fp2 t6_inv = fp2_inv_(t6); - - Fp6_ r; - r.b0 = fp2_mul_(c0, t6_inv); - r.b1 = fp2_mul_(c1, t6_inv); - r.b2 = fp2_mul_(c2, t6_inv); - return r; -} - -__device__ Fp6_ fp6_mul_by_01_(const Fp6_& z, const Fp2& c0, const Fp2& c1) { - Fp2 a = fp2_mul_(z.b0, c0); - Fp2 b = fp2_mul_(z.b1, c1); - - Fp2 tmp = fp2_add_(z.b1, z.b2); - Fp2 t0 = fp2_mul_(c1, tmp); - t0 = fp2_sub_(t0, b); - t0 = fp2_mul_by_nonres_(t0); - t0 = fp2_add_(t0, a); - - tmp = fp2_add_(z.b0, z.b2); 
- Fp2 t2 = fp2_mul_(c0, tmp); - t2 = fp2_sub_(t2, a); - t2 = fp2_add_(t2, b); - - Fp2 t1 = fp2_add_(c0, c1); - tmp = fp2_add_(z.b0, z.b1); - t1 = fp2_mul_(t1, tmp); - t1 = fp2_sub_(t1, a); - t1 = fp2_sub_(t1, b); - - Fp6_ r; r.b0 = t0; r.b1 = t1; r.b2 = t2; - return r; -} - -__device__ Fp6_ fp6_mul_by_fp2_(const Fp6_& z, const Fp2& y) { - Fp6_ r; - r.b0 = fp2_mul_(z.b0, y); - r.b1 = fp2_mul_(z.b1, y); - r.b2 = fp2_mul_(z.b2, y); - return r; -} - -// ============================================================================= -// Fp12 = Fp6[w]/(w^2 - v) -// ============================================================================= - -__device__ __forceinline__ Fp12_ fp12_zero_() { Fp12_ r; r.c0 = fp6_zero_(); r.c1 = fp6_zero_(); return r; } -__device__ __forceinline__ Fp12_ fp12_one_() { Fp12_ r; r.c0 = fp6_one_(); r.c1 = fp6_zero_(); return r; } - -__device__ bool fp12_is_one_(const Fp12_& z) { - Fp12_ one = fp12_one_(); - return fp2_is_zero_(z.c1.b0) && fp2_is_zero_(z.c1.b1) && fp2_is_zero_(z.c1.b2) - && fp2_is_zero_(z.c0.b1) && fp2_is_zero_(z.c0.b2) - && u256_is_zero(z.c0.b0.a1) - && u256_eq(z.c0.b0.a0, one.c0.b0.a0); -} - -__device__ __forceinline__ Fp12_ fp12_add_(const Fp12_& x, const Fp12_& y) { - Fp12_ r; r.c0 = fp6_add_(x.c0, y.c0); r.c1 = fp6_add_(x.c1, y.c1); return r; -} -__device__ __forceinline__ Fp12_ fp12_sub_(const Fp12_& x, const Fp12_& y) { - Fp12_ r; r.c0 = fp6_sub_(x.c0, y.c0); r.c1 = fp6_sub_(x.c1, y.c1); return r; -} -__device__ __forceinline__ Fp12_ fp12_neg_(const Fp12_& x) { - Fp12_ r; r.c0 = fp6_neg_(x.c0); r.c1 = fp6_neg_(x.c1); return r; -} -__device__ __forceinline__ Fp12_ fp12_conjugate_(const Fp12_& x) { - Fp12_ r; r.c0 = x.c0; r.c1 = fp6_neg_(x.c1); return r; -} - -__device__ Fp12_ fp12_mul_(const Fp12_& x, const Fp12_& y) { - Fp6_ a = fp6_add_(x.c0, x.c1); - Fp6_ b = fp6_add_(y.c0, y.c1); - a = fp6_mul_(a, b); - b = fp6_mul_(x.c0, y.c0); - Fp6_ c = fp6_mul_(x.c1, y.c1); - Fp12_ r; - r.c1 = fp6_sub_(fp6_sub_(a, b), c); - r.c0 = 
fp6_add_(fp6_mul_by_nonres_(c), b); - return r; -} - -__device__ Fp12_ fp12_sqr_(const Fp12_& x) { - Fp6_ c0 = fp6_sub_(x.c0, x.c1); - Fp6_ c3 = fp6_mul_by_nonres_(x.c1); - c3 = fp6_neg_(c3); - c3 = fp6_add_(x.c0, c3); - Fp6_ c2 = fp6_mul_(x.c0, x.c1); - c0 = fp6_mul_(c0, c3); - c0 = fp6_add_(c0, c2); - Fp6_ r1 = fp6_double_(c2); - c2 = fp6_mul_by_nonres_(c2); - Fp6_ r0 = fp6_add_(c0, c2); - Fp12_ r; r.c0 = r0; r.c1 = r1; - return r; -} - -__device__ Fp12_ fp12_inv_(const Fp12_& x) { - Fp6_ t0 = fp6_sqr_(x.c0); - Fp6_ t1 = fp6_sqr_(x.c1); - Fp6_ tmp = fp6_mul_by_nonres_(t1); - t0 = fp6_sub_(t0, tmp); - Fp6_ t0_inv = fp6_inv_(t0); - Fp12_ r; - r.c0 = fp6_mul_(x.c0, t0_inv); - r.c1 = fp6_neg_(fp6_mul_(x.c1, t0_inv)); - return r; -} - -__device__ Fp12_ fp12_mul_by_034_(const Fp12_& z, const Fp2& c0, const Fp2& c3, const Fp2& c4) { - Fp6_ a = fp6_mul_by_fp2_(z.c0, c0); - Fp6_ b = z.c1; - b = fp6_mul_by_01_(b, c3, c4); - - Fp2 d0 = fp2_add_(c0, c3); - Fp6_ d = fp6_add_(z.c0, z.c1); - d = fp6_mul_by_01_(d, d0, c4); - - Fp6_ r1 = fp6_add_(a, b); - r1 = fp6_neg_(r1); - r1 = fp6_add_(r1, d); - Fp6_ r0 = fp6_mul_by_nonres_(b); - r0 = fp6_add_(r0, a); - Fp12_ r; r.c0 = r0; r.c1 = r1; - return r; -} - -struct Fp12Sparse5 { Fp2 v00, v01, v02, v10, v11; }; - -__device__ Fp12Sparse5 fp12_mul_034_by_034_( - const Fp2& d0, const Fp2& d3, const Fp2& d4, - const Fp2& c0, const Fp2& c3, const Fp2& c4) { - Fp2 x0 = fp2_mul_(c0, d0); - Fp2 x3 = fp2_mul_(c3, d3); - Fp2 x4 = fp2_mul_(c4, d4); - - Fp2 tmp = fp2_add_(c0, c4); - Fp2 x04 = fp2_add_(d0, d4); - x04 = fp2_mul_(x04, tmp); - x04 = fp2_sub_(x04, x0); - x04 = fp2_sub_(x04, x4); - - tmp = fp2_add_(c0, c3); - Fp2 x03 = fp2_add_(d0, d3); - x03 = fp2_mul_(x03, tmp); - x03 = fp2_sub_(x03, x0); - x03 = fp2_sub_(x03, x3); - - tmp = fp2_add_(c3, c4); - Fp2 x34 = fp2_add_(d3, d4); - x34 = fp2_mul_(x34, tmp); - x34 = fp2_sub_(x34, x3); - x34 = fp2_sub_(x34, x4); - - Fp2 z00 = fp2_mul_by_nonres_(x4); - z00 = fp2_add_(z00, x0); - Fp12Sparse5 r; 
- r.v00 = z00; r.v01 = x3; r.v02 = x34; r.v10 = x03; r.v11 = x04; - return r; -} - -__device__ Fp12_ fp12_mul_by_01234_(const Fp12_& z, const Fp12Sparse5& x) { - Fp6_ c0_part; c0_part.b0 = x.v00; c0_part.b1 = x.v01; c0_part.b2 = x.v02; - Fp6_ c1_part; c1_part.b0 = x.v10; c1_part.b1 = x.v11; c1_part.b2 = fp2_zero_(); - - Fp6_ a = fp6_add_(z.c0, z.c1); - Fp6_ b = fp6_add_(c0_part, c1_part); - a = fp6_mul_(a, b); - - b = fp6_mul_(z.c0, c0_part); - Fp6_ c = fp6_mul_by_01_(z.c1, x.v10, x.v11); - - Fp6_ r1 = fp6_sub_(a, b); - r1 = fp6_sub_(r1, c); - - Fp6_ r0 = fp6_mul_by_nonres_(c); - r0 = fp6_add_(r0, b); - - Fp12_ r; r.c0 = r0; r.c1 = r1; - return r; -} - -// ============================================================================= -// Frobenius operators on Fp12 (Algorithms 28-30, eprint 2010/354). -// ============================================================================= - -__device__ Fp2 mul_nr1_(const Fp2& x, const u64* nr_a0, const u64* nr_a1) { - Fp2 nr; - #pragma unroll - for (int i = 0; i < 4; ++i) { nr.a0.limbs[i] = nr_a0[i]; nr.a1.limbs[i] = nr_a1[i]; } - return fp2_mul_(x, nr); -} - -__device__ Fp2 mul_nr2_(const Fp2& x, const u64* nr_scalar) { - U256 s; - #pragma unroll - for (int i = 0; i < 4; ++i) s.limbs[i] = nr_scalar[i]; - Fp2 r; r.a0 = fp_mul(x.a0, s); r.a1 = fp_mul(x.a1, s); - return r; -} - -__device__ Fp12_ frobenius_(const Fp12_& x) { - Fp2 t0 = fp2_conjugate_(x.c0.b0); - Fp2 t1 = fp2_conjugate_(x.c0.b1); - Fp2 t2 = fp2_conjugate_(x.c0.b2); - Fp2 t3 = fp2_conjugate_(x.c1.b0); - Fp2 t4 = fp2_conjugate_(x.c1.b1); - Fp2 t5 = fp2_conjugate_(x.c1.b2); - - t1 = mul_nr1_(t1, K_NR1P2_A0, K_NR1P2_A1); - t2 = mul_nr1_(t2, K_NR1P4_A0, K_NR1P4_A1); - t3 = mul_nr1_(t3, K_NR1P1_A0, K_NR1P1_A1); - t4 = mul_nr1_(t4, K_NR1P3_A0, K_NR1P3_A1); - t5 = mul_nr1_(t5, K_NR1P5_A0, K_NR1P5_A1); - - Fp12_ z; - z.c0.b0 = t0; z.c0.b1 = t1; z.c0.b2 = t2; - z.c1.b0 = t3; z.c1.b1 = t4; z.c1.b2 = t5; - return z; -} - -__device__ Fp12_ frobenius_sq_(const Fp12_& x) { - 
Fp12_ z; - z.c0.b0 = x.c0.b0; - z.c0.b1 = mul_nr2_(x.c0.b1, K_NR2P2); - z.c0.b2 = mul_nr2_(x.c0.b2, K_NR2P4); - z.c1.b0 = mul_nr2_(x.c1.b0, K_NR2P1); - z.c1.b1 = mul_nr2_(x.c1.b1, K_NR2P3); - z.c1.b2 = mul_nr2_(x.c1.b2, K_NR2P5); - return z; -} - -__device__ Fp12_ frobenius_cube_(const Fp12_& x) { - Fp2 t0 = fp2_conjugate_(x.c0.b0); - Fp2 t1 = fp2_conjugate_(x.c0.b1); - Fp2 t2 = fp2_conjugate_(x.c0.b2); - Fp2 t3 = fp2_conjugate_(x.c1.b0); - Fp2 t4 = fp2_conjugate_(x.c1.b1); - Fp2 t5 = fp2_conjugate_(x.c1.b2); - - t1 = mul_nr1_(t1, K_NR3P2_A0, K_NR3P2_A1); - t2 = mul_nr1_(t2, K_NR3P4_A0, K_NR3P4_A1); - t3 = mul_nr1_(t3, K_NR3P1_A0, K_NR3P1_A1); - t4 = mul_nr1_(t4, K_NR3P3_A0, K_NR3P3_A1); - t5 = mul_nr1_(t5, K_NR3P5_A0, K_NR3P5_A1); - - Fp12_ z; - z.c0.b0 = t0; z.c0.b1 = t1; z.c0.b2 = t2; - z.c1.b0 = t3; z.c1.b1 = t4; z.c1.b2 = t5; - return z; -} - -// ============================================================================= -// Granger-Scott cyclotomic squaring (eprint 2009/565 §3.2). 
-// ============================================================================= - -__device__ Fp12_ cyclotomic_sqr_(const Fp12_& x) { - Fp2 t0 = fp2_sqr_(x.c1.b1); - Fp2 t1 = fp2_sqr_(x.c0.b0); - Fp2 t6 = fp2_sub_(fp2_sub_(fp2_sqr_(fp2_add_(x.c1.b1, x.c0.b0)), t0), t1); - Fp2 t2 = fp2_sqr_(x.c0.b2); - Fp2 t3 = fp2_sqr_(x.c1.b0); - Fp2 t7 = fp2_sub_(fp2_sub_(fp2_sqr_(fp2_add_(x.c0.b2, x.c1.b0)), t2), t3); - Fp2 t4 = fp2_sqr_(x.c1.b2); - Fp2 t5 = fp2_sqr_(x.c0.b1); - Fp2 t8 = fp2_sub_(fp2_sub_(fp2_sqr_(fp2_add_(x.c1.b2, x.c0.b1)), t4), t5); - t8 = fp2_mul_by_nonres_(t8); - - t0 = fp2_add_(fp2_mul_by_nonres_(t0), t1); - t2 = fp2_add_(fp2_mul_by_nonres_(t2), t3); - t4 = fp2_add_(fp2_mul_by_nonres_(t4), t5); - - Fp12_ z; - z.c0.b0 = fp2_add_(fp2_double_(fp2_sub_(t0, x.c0.b0)), t0); - z.c0.b1 = fp2_add_(fp2_double_(fp2_sub_(t2, x.c0.b1)), t2); - z.c0.b2 = fp2_add_(fp2_double_(fp2_sub_(t4, x.c0.b2)), t4); - z.c1.b0 = fp2_add_(fp2_double_(fp2_add_(t8, x.c1.b0)), t8); - z.c1.b1 = fp2_add_(fp2_double_(fp2_add_(t6, x.c1.b1)), t6); - z.c1.b2 = fp2_add_(fp2_double_(fp2_add_(t7, x.c1.b2)), t7); - return z; -} - -__device__ Fp12_ cyclotomic_n_sqr_(Fp12_ z, int n) { - for (int i = 0; i < n; ++i) z = cyclotomic_sqr_(z); - return z; -} - -// ============================================================================= -// Expt: x^t with t = 4965661367192848881 -- gnark addition chain. 
-// ============================================================================= - -__device__ Fp12_ expt_(const Fp12_& x) { - Fp12_ t3 = cyclotomic_sqr_(x); - Fp12_ t5 = cyclotomic_sqr_(t3); - Fp12_ result = cyclotomic_sqr_(t5); - Fp12_ t0 = cyclotomic_sqr_(result); - Fp12_ t2 = fp12_mul_(x, t0); - t0 = fp12_mul_(t3, t2); - Fp12_ t1 = fp12_mul_(x, t0); - Fp12_ t4 = fp12_mul_(result, t2); - Fp12_ t6 = cyclotomic_sqr_(t2); - t1 = fp12_mul_(t0, t1); - t0 = fp12_mul_(t3, t1); - - t6 = cyclotomic_n_sqr_(t6, 6); - t5 = fp12_mul_(t5, t6); - t5 = fp12_mul_(t4, t5); - - t5 = cyclotomic_n_sqr_(t5, 7); - t4 = fp12_mul_(t4, t5); - - t4 = cyclotomic_n_sqr_(t4, 8); - t4 = fp12_mul_(t0, t4); - t3 = fp12_mul_(t3, t4); - - t3 = cyclotomic_n_sqr_(t3, 6); - t2 = fp12_mul_(t2, t3); - - t2 = cyclotomic_n_sqr_(t2, 8); - t2 = fp12_mul_(t0, t2); - - t2 = cyclotomic_n_sqr_(t2, 6); - t2 = fp12_mul_(t0, t2); - - t2 = cyclotomic_n_sqr_(t2, 10); - t1 = fp12_mul_(t1, t2); - - t1 = cyclotomic_n_sqr_(t1, 6); - t0 = fp12_mul_(t0, t1); - return fp12_mul_(result, t0); -} - -// ============================================================================= -// G2 affine + projective ops + line evaluations. 
-// ============================================================================= - -__device__ G2Aff g2_neg_(const G2Aff& a) { - G2Aff r; r.x = a.x; r.y = fp2_neg_(a.y); r.inf = a.inf; - return r; -} - -__device__ G2Proj g2_to_proj_(const G2Aff& a) { - G2Proj p; p.x = a.x; p.y = a.y; p.z = fp2_one_(); - return p; -} - -struct LineEval { Fp2 r0, r1, r2; }; - -__device__ U256 fp_halve_(const U256& v) { - U256 r = v; - if (r.limbs[0] & 1ULL) { - U256 P_ = u256_load(K_P); - U256 t; - u64 c = add_256(r, P_, t); - r = t; - for (int i = 0; i < 3; ++i) - r.limbs[i] = (r.limbs[i] >> 1) | (r.limbs[i+1] << 63); - r.limbs[3] = (r.limbs[3] >> 1) | (c << 63); - } else { - for (int i = 0; i < 3; ++i) - r.limbs[i] = (r.limbs[i] >> 1) | (r.limbs[i+1] << 63); - r.limbs[3] >>= 1; - } - return r; -} - -__device__ Fp2 fp2_halve_(const Fp2& x) { - Fp2 r; r.a0 = fp_halve_(x.a0); r.a1 = fp_halve_(x.a1); - return r; -} - -__device__ Fp2 mul_b_twist_(const Fp2& x) { - Fp2 res = fp2_mul_by_nonres_inv_(x); - return fp2_add_(fp2_double_(res), res); -} - -__device__ void g2_double_step_(G2Proj& p, LineEval& ev) { - Fp2 A = fp2_mul_(p.x, p.y); - A = fp2_halve_(A); - Fp2 B = fp2_sqr_(p.y); - Fp2 C = fp2_sqr_(p.z); - Fp2 D = fp2_double_(C); - D = fp2_add_(D, C); - Fp2 E = mul_b_twist_(D); - Fp2 F = fp2_double_(E); - F = fp2_add_(F, E); - Fp2 G = fp2_add_(B, F); - G = fp2_halve_(G); - Fp2 H = fp2_add_(p.y, p.z); - H = fp2_sqr_(H); - Fp2 t1 = fp2_add_(B, C); - H = fp2_sub_(H, t1); - Fp2 I = fp2_sub_(E, B); - Fp2 J = fp2_sqr_(p.x); - Fp2 EE = fp2_sqr_(E); - Fp2 K = fp2_double_(EE); - K = fp2_add_(K, EE); - - p.x = fp2_sub_(B, F); - p.x = fp2_mul_(p.x, A); - p.y = fp2_sqr_(G); - p.y = fp2_sub_(p.y, K); - p.z = fp2_mul_(B, H); - - ev.r0 = fp2_neg_(H); - ev.r1 = fp2_double_(J); - ev.r1 = fp2_add_(ev.r1, J); - ev.r2 = I; -} - -__device__ void g2_add_mixed_step_(G2Proj& p, LineEval& ev, const G2Aff& a) { - Fp2 Y2Z1 = fp2_mul_(a.y, p.z); - Fp2 O = fp2_sub_(p.y, Y2Z1); - Fp2 X2Z1 = fp2_mul_(a.x, p.z); - 
Fp2 L = fp2_sub_(p.x, X2Z1); - Fp2 C = fp2_sqr_(O); - Fp2 D = fp2_sqr_(L); - Fp2 E = fp2_mul_(L, D); - Fp2 F = fp2_mul_(p.z, C); - Fp2 G = fp2_mul_(p.x, D); - Fp2 t0 = fp2_double_(G); - Fp2 H = fp2_add_(E, F); - H = fp2_sub_(H, t0); - Fp2 t1 = fp2_mul_(p.y, E); - - p.x = fp2_mul_(L, H); - p.y = fp2_sub_(G, H); - p.y = fp2_mul_(p.y, O); - p.y = fp2_sub_(p.y, t1); - p.z = fp2_mul_(E, p.z); - - Fp2 t2 = fp2_mul_(L, a.y); - Fp2 J = fp2_mul_(a.x, O); - J = fp2_sub_(J, t2); - - ev.r0 = L; - ev.r1 = fp2_neg_(O); - ev.r2 = J; -} - -__device__ void g2_line_compute_(const G2Proj& p, LineEval& ev, const G2Aff& a) { - Fp2 Y2Z1 = fp2_mul_(a.y, p.z); - Fp2 O = fp2_sub_(p.y, Y2Z1); - Fp2 X2Z1 = fp2_mul_(a.x, p.z); - Fp2 L = fp2_sub_(p.x, X2Z1); - Fp2 t2 = fp2_mul_(L, a.y); - Fp2 J = fp2_mul_(a.x, O); - J = fp2_sub_(J, t2); - - ev.r0 = L; - ev.r1 = fp2_neg_(O); - ev.r2 = J; -} - -// ============================================================================= -// 6x+2 NAF loop counter (matches CPU bn254_pairing.cpp:kLoopCounter). -// ============================================================================= - -__constant__ signed char K_LOOP_NAF[65] = { - 0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 1, -1, 0, 0, 1, 0, - 0, 1, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 1, 1, - 1, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1, - 1, 0, 0, -1, 0, 0, 0, 1, 1, 0, -1, 0, 0, 1, 0, 1, - 1 -}; - -// ============================================================================= -// Single-pair Miller loop. Multi-pair composition is host-side (tree-reduce -// of per-pair Fp12, single final-exp at end). -// ============================================================================= - -__device__ Fp12_ miller_one_(const G1Aff& P, const G2Aff& Q) { - if (P.inf || Q.inf) return fp12_one_(); - - G2Proj qProj = g2_to_proj_(Q); - G2Aff qNeg = g2_neg_(Q); - - Fp12_ result = fp12_one_(); - LineEval l1, l2; - - // Skip i=64 (LoopCounter[64] == 0 and result still 1). 
- g2_double_step_(qProj, l1); - result.c0.b0 = fp2_mul_by_fp_(l1.r0, P.y); - result.c1.b0 = fp2_mul_by_fp_(l1.r1, P.x); - result.c1.b1 = l1.r2; - - // i=63 (LoopCounter[63] == -1). - result = fp12_sqr_(result); - g2_line_compute_(qProj, l2, qNeg); - l2.r0 = fp2_mul_by_fp_(l2.r0, P.y); - l2.r1 = fp2_mul_by_fp_(l2.r1, P.x); - g2_add_mixed_step_(qProj, l1, Q); - l1.r0 = fp2_mul_by_fp_(l1.r0, P.y); - l1.r1 = fp2_mul_by_fp_(l1.r1, P.x); - Fp12Sparse5 prod = fp12_mul_034_by_034_(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234_(result, prod); - - // i=62 .. 0 - for (int i = 65 - 4; i >= 0; --i) { - result = fp12_sqr_(result); - g2_double_step_(qProj, l1); - l1.r0 = fp2_mul_by_fp_(l1.r0, P.y); - l1.r1 = fp2_mul_by_fp_(l1.r1, P.x); - - signed char lc = K_LOOP_NAF[i]; - if (lc == 1) { - g2_add_mixed_step_(qProj, l2, Q); - l2.r0 = fp2_mul_by_fp_(l2.r0, P.y); - l2.r1 = fp2_mul_by_fp_(l2.r1, P.x); - prod = fp12_mul_034_by_034_(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234_(result, prod); - } else if (lc == -1) { - g2_add_mixed_step_(qProj, l2, qNeg); - l2.r0 = fp2_mul_by_fp_(l2.r0, P.y); - l2.r1 = fp2_mul_by_fp_(l2.r1, P.x); - prod = fp12_mul_034_by_034_(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234_(result, prod); - } else { - result = fp12_mul_by_034_(result, l1.r0, l1.r1, l1.r2); - } - } - - // Final 6x+2 + Frobenius corrections: Q1 = pi(Q), Q2 = -pi^2(Q). 
- G2Aff Q1, Q2; - Fp2 q1x = fp2_conjugate_(Q.x); - Fp2 q1y = fp2_conjugate_(Q.y); - Fp2 nr_p2; - #pragma unroll - for (int i = 0; i < 4; ++i) { nr_p2.a0.limbs[i] = K_NR1P2_A0[i]; nr_p2.a1.limbs[i] = K_NR1P2_A1[i]; } - Fp2 nr_p3; - #pragma unroll - for (int i = 0; i < 4; ++i) { nr_p3.a0.limbs[i] = K_NR1P3_A0[i]; nr_p3.a1.limbs[i] = K_NR1P3_A1[i]; } - Q1.x = fp2_mul_(q1x, nr_p2); - Q1.y = fp2_mul_(q1y, nr_p3); - Q1.inf = 0; - - U256 nr2_p2; - #pragma unroll - for (int i = 0; i < 4; ++i) nr2_p2.limbs[i] = K_NR2P2[i]; - U256 nr2_p3; - #pragma unroll - for (int i = 0; i < 4; ++i) nr2_p3.limbs[i] = K_NR2P3[i]; - Fp2 q2x; q2x.a0 = fp_mul(Q.x.a0, nr2_p2); q2x.a1 = fp_mul(Q.x.a1, nr2_p2); - Fp2 q2y; q2y.a0 = fp_mul(Q.y.a0, nr2_p3); q2y.a1 = fp_mul(Q.y.a1, nr2_p3); - Q2.x = q2x; Q2.y = fp2_neg_(q2y); Q2.inf = 0; - - g2_add_mixed_step_(qProj, l2, Q1); - l2.r0 = fp2_mul_by_fp_(l2.r0, P.y); - l2.r1 = fp2_mul_by_fp_(l2.r1, P.x); - g2_line_compute_(qProj, l1, Q2); - l1.r0 = fp2_mul_by_fp_(l1.r0, P.y); - l1.r1 = fp2_mul_by_fp_(l1.r1, P.x); - prod = fp12_mul_034_by_034_(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234_(result, prod); - - return result; -} - -// ============================================================================= -// Final exponentiation -- Fuentes-Castaneda (Duquesne-Ghammam eprint 2015/192). 
-// ============================================================================= - -__device__ Fp12_ final_exp_(const Fp12_& z) { - Fp12_ result = z; - Fp12_ t0 = fp12_conjugate_(result); - result = fp12_inv_(result); - t0 = fp12_mul_(t0, result); - result = frobenius_sq_(t0); - result = fp12_mul_(result, t0); - - if (fp12_is_one_(result)) return result; - - Fp12_ t[5]; - t[0] = expt_(result); - t[0] = fp12_conjugate_(t[0]); - t[0] = cyclotomic_sqr_(t[0]); - t[1] = cyclotomic_sqr_(t[0]); - t[1] = fp12_mul_(t[0], t[1]); - t[2] = expt_(t[1]); - t[2] = fp12_conjugate_(t[2]); - t[3] = fp12_conjugate_(t[1]); - t[1] = fp12_mul_(t[2], t[3]); - t[3] = cyclotomic_sqr_(t[2]); - t[4] = expt_(t[3]); - t[4] = fp12_mul_(t[1], t[4]); - t[3] = fp12_mul_(t[0], t[4]); - t[0] = fp12_mul_(t[2], t[4]); - t[0] = fp12_mul_(result, t[0]); - t[2] = frobenius_(t[3]); - t[0] = fp12_mul_(t[2], t[0]); - t[2] = frobenius_sq_(t[4]); - t[0] = fp12_mul_(t[2], t[0]); - t[2] = fp12_conjugate_(result); - t[2] = fp12_mul_(t[2], t[3]); - t[2] = frobenius_cube_(t[2]); - t[0] = fp12_mul_(t[2], t[0]); - - return t[0]; -} - -// ============================================================================= -// I/O helpers -// ============================================================================= - -__device__ void load_fp2_(Fp2& out, const u64* p) { - #pragma unroll - for (int i = 0; i < 4; ++i) { out.a0.limbs[i] = p[i]; out.a1.limbs[i] = p[4+i]; } -} - -__device__ void store_fp2_(u64* p, const Fp2& x) { - #pragma unroll - for (int i = 0; i < 4; ++i) { p[i] = x.a0.limbs[i]; p[4+i] = x.a1.limbs[i]; } -} - -__device__ void load_g2_(G2Aff& out, const u64* p) { - load_fp2_(out.x, p); - load_fp2_(out.y, p + 8); - out.inf = (int)p[16]; -} - -__device__ void store_fp12_(u64* p, const Fp12_& x) { - store_fp2_(p + 0, x.c0.b0); - store_fp2_(p + 8, x.c0.b1); - store_fp2_(p + 16, x.c0.b2); - store_fp2_(p + 24, x.c1.b0); - store_fp2_(p + 32, x.c1.b1); - store_fp2_(p + 40, x.c1.b2); -} - -__device__ void 
load_fp12_(Fp12_& out, const u64* p) { - load_fp2_(out.c0.b0, p + 0); - load_fp2_(out.c0.b1, p + 8); - load_fp2_(out.c0.b2, p + 16); - load_fp2_(out.c1.b0, p + 24); - load_fp2_(out.c1.b1, p + 32); - load_fp2_(out.c1.b2, p + 40); -} - -#endif // LUX_BN254_PAIRING_CUH diff --git a/bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh b/bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh deleted file mode 100644 index ff6c468..0000000 --- a/bn254/gpu/cuda/bn254_pairing_consts_cuda.cuh +++ /dev/null @@ -1,36 +0,0 @@ -// Auto-generated by bn254_gen_pairing_constants (CUDA). Do not edit. -// Source: bn254/cpp/bn254_pairing.cpp (lines 41-119) via -// bn254_gen_pairing_constants. The CPU body is the single producer of -// these limbs; any drift between CPU and GPU fails the determinism test. -#ifndef LUX_BN254_PAIRING_CONST_CUH -#define LUX_BN254_PAIRING_CONST_CUH - -__device__ static const unsigned long long K_NR1P1_A0[4] = {0xaf9ba69633144907ULL, 0xca6b1d7387afb78aULL, 0x11bded5ef08a2087ULL, 0x02f34d751a1f3a7cULL}; -__device__ static const unsigned long long K_NR1P1_A1[4] = {0xa222ae234c492d72ULL, 0xd00f02a4565de15bULL, 0xdc2ff3a253dfc926ULL, 0x10a75716b3899551ULL}; -__device__ static const unsigned long long K_NR1P2_A0[4] = {0xb5773b104563ab30ULL, 0x347f91c8a9aa6454ULL, 0x7a007127242e0991ULL, 0x1956bcd8118214ecULL}; -__device__ static const unsigned long long K_NR1P2_A1[4] = {0x6e849f1ea0aa4757ULL, 0xaa1c7b6d89f89141ULL, 0xb6e713cdfae0ca3aULL, 0x26694fbb4e82ebc3ULL}; -__device__ static const unsigned long long K_NR1P3_A0[4] = {0xe4bbdd0c2936b629ULL, 0xbb30f162e133bacbULL, 0x31a9d1b6f9645366ULL, 0x253570bea500f8ddULL}; -__device__ static const unsigned long long K_NR1P3_A1[4] = {0xa1d77ce45ffe77c7ULL, 0x07affd117826d1dbULL, 0x6d16bd27bb7edc6bULL, 0x2c87200285defeccULL}; -__device__ static const unsigned long long K_NR1P4_A0[4] = {0x7361d77f843abe92ULL, 0xa5bb2bd3273411fbULL, 0x9c941f314b3e2399ULL, 0x15df9cddbb9fd3ecULL}; -__device__ static const unsigned long long K_NR1P4_A1[4] = 
{0x5dddfd154bd8c949ULL, 0x62cb29a5a4445b60ULL, 0x37bc870a0c7dd2b9ULL, 0x24830a9d3171f0fdULL}; -__device__ static const unsigned long long K_NR1P5_A0[4] = {0xc970692f41690fe7ULL, 0xe240342127694b0bULL, 0x32bee66b83c459e8ULL, 0x12aabced0ab08841ULL}; -__device__ static const unsigned long long K_NR1P5_A1[4] = {0x0d485d2340aebfa9ULL, 0x05193418ab2fcc57ULL, 0xd3b0a40b8a4910f5ULL, 0x2f21ebb535d2925aULL}; - -__device__ static const unsigned long long K_NR2P1[4] = {0xca8d800500fa1bf2ULL, 0xf0c5d61468b39769ULL, 0x0e201271ad0d4418ULL, 0x04290f65bad856e6ULL}; -__device__ static const unsigned long long K_NR2P2[4] = {0x3350c88e13e80b9cULL, 0x7dce557cdb5e56b9ULL, 0x6001b4b8b615564aULL, 0x2682e617020217e0ULL}; -__device__ static const unsigned long long K_NR2P3[4] = {0x68c3488912edefaaULL, 0x8d087f6872aabf4fULL, 0x51e1a24709081231ULL, 0x2259d6b14729c0faULL}; -__device__ static const unsigned long long K_NR2P4[4] = {0x71930c11d782e155ULL, 0xa6bb947cffbe3323ULL, 0xaa303344d4741444ULL, 0x2c3b3f0d26594943ULL}; -__device__ static const unsigned long long K_NR2P5[4] = {0x08cfc388c494f1abULL, 0x19b315148d1373d4ULL, 0x584e90fdcb6c0213ULL, 0x09e1685bdf2f8849ULL}; - -__device__ static const unsigned long long K_NR3P1_A0[4] = {0x365316184e46d97dULL, 0x0af7129ed4c96d9fULL, 0x659da72fca1009b5ULL, 0x08116d8983a20d23ULL}; -__device__ static const unsigned long long K_NR3P1_A1[4] = {0xb1df4af7c39c1939ULL, 0x3d9f02878a73bf7fULL, 0x9b2220928caf0ae0ULL, 0x26684515eff054a6ULL}; -__device__ static const unsigned long long K_NR3P2_A0[4] = {0xc9af22f716ad6badULL, 0xb311782a4aa662b2ULL, 0x19eeaf64e248c7f4ULL, 0x20273e77e3439f82ULL}; -__device__ static const unsigned long long K_NR3P2_A1[4] = {0xacc02860f7ce93acULL, 0x3933d5817ba76b4cULL, 0x69e6188b446c8467ULL, 0x0a46036d4417cc55ULL}; -__device__ static const unsigned long long K_NR3P3_A0[4] = {0x5764af0aaf46471eULL, 0xdc50792e873e0fc1ULL, 0x86a673ff881d04f6ULL, 0x0b2eddb43c30a74cULL}; -__device__ static const unsigned long long K_NR3P3_A1[4] = 
{0x9a490f32787e8580ULL, 0x8fd16d7ff04af8b1ULL, 0x4b39888ec6027bf2ULL, 0x03dd2e705b52a15dULL}; -__device__ static const unsigned long long K_NR3P4_A0[4] = {0x448a93a57b6762dfULL, 0xbfd62df528fdeadfULL, 0xd858f5d00e9bd47aULL, 0x06b03d4d3476ec58ULL}; -__device__ static const unsigned long long K_NR3P4_A1[4] = {0x2b19daf4bcc936d1ULL, 0xa1a54e7a56f4299fULL, 0xb533eee05adeaef1ULL, 0x170c812b84dda0b2ULL}; -__device__ static const unsigned long long K_NR3P5_A0[4] = {0xe0bc4b2275cf559fULL, 0xc238b945c154e60fULL, 0x803982a5929a7d5eULL, 0x15ce052df7e4a37eULL}; -__device__ static const unsigned long long K_NR3P5_A1[4] = {0x2d28efbdbf3799a7ULL, 0x9b097e3c1ad60773ULL, 0x982d4113af4a535bULL, 0x24e18991e3056063ULL}; - -#endif // LUX_BN254_PAIRING_CONST_CUH diff --git a/bn254/gpu/metal/bn254.metal b/bn254/gpu/metal/bn254.metal deleted file mode 100644 index 03a760d..0000000 --- a/bn254/gpu/metal/bn254.metal +++ /dev/null @@ -1,569 +0,0 @@ -// ============================================================================= -// BN254 (alt_bn128) Metal Compute Shaders -// ============================================================================= -// -// GPU-accelerated elliptic curve operations for BN254 on Apple Silicon. -// Used for Pedersen commitments, PLONK verification, and Groth16 proofs. -// -// BN254 Parameters: -// p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// G1: y^2 = x^3 + 3 over Fp -// -// References: -// - EIP-196, EIP-197 (Ethereum precompiles) -// - Zcash BN-254 specification -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// 256-bit Field Arithmetic (4 x 64-bit limbs) -// ============================================================================= - -// BN254 base field prime p (4 limbs, little-endian) -constant uint64_t BN254_P[4] = { - 0x3C208C16D87CFD47, - 0x97816A916871CA8D, - 0xB85045B68181585D, - 0x30644E72E131A029 -}; - -// p - 2 used for Fermat inversion: z^{-1} = z^{p-2} mod p. -// p_minus_2 = p - 2 (4 limbs, little-endian). -constant uint64_t BN254_P_MINUS_2[4] = { - 0x3C208C16D87CFD45, // 0x3C208C16D87CFD47 - 2 - 0x97816A916871CA8D, - 0xB85045B68181585D, - 0x30644E72E131A029 -}; - -// Montgomery R^2 mod p -constant uint64_t BN254_R2[4] = { - 0xF32CFC5B538AFA89, - 0xB5E71911D44501FB, - 0x47AB1EFF0A417FF6, - 0x06D89F71CAB8351F -}; - -// Montgomery constant: -p^{-1} mod 2^64 -constant uint64_t BN254_INV = 0x87D20782E4866389; - -// Generator points (Montgomery form) -constant uint64_t BN254_G1_X[4] = { - 0xD35D438DC58F0D9D, - 0x0A78EB28F5C70B3D, - 0x666EA36F7879462C, - 0x0E0A77C19A07DF2F -}; - -constant uint64_t BN254_G1_Y[4] = { - 0xA6BA871B8B1E1B3A, - 0x14F1D651EB8E167B, - 0xCCDD46DEF0F28C58, - 0x1C14EF83340FBE5E -}; - -// Fp256 represented as 4 uint64 limbs -struct Fp256 { - uint64_t limbs[4]; -}; - -// G1 affine point -struct G1Affine { - Fp256 x; - Fp256 y; - bool infinity; -}; - -// G1 projective point (Jacobian coordinates) -struct G1Projective { - Fp256 x; - Fp256 y; - Fp256 z; -}; - -// ============================================================================= -// Multi-precision Arithmetic -// ============================================================================= - -inline uint64_t adc(uint64_t a, uint64_t b, thread uint64_t& carry) { - uint64_t result = a + carry; - carry = (result < a) ? 1 : 0; - uint64_t sum = result + b; - carry += (sum < result) ? 
1 : 0; - return sum; -} - -inline uint64_t sbb(uint64_t a, uint64_t b, thread uint64_t& borrow) { - uint64_t diff = a - borrow; - borrow = (a < borrow) ? 1 : 0; - uint64_t result = diff - b; - borrow += (diff < b) ? 1 : 0; - return result; -} - -inline void mul64(uint64_t a, uint64_t b, thread uint64_t& lo, thread uint64_t& hi) { - lo = a * b; - hi = mulhi(a, b); -} - -inline int fp256_cmp(thread const Fp256& a, constant uint64_t* b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b[i]) return -1; - if (a.limbs[i] > b[i]) return 1; - } - return 0; -} - -// ============================================================================= -// Field Operations -// ============================================================================= - -inline Fp256 fp256_zero() { - Fp256 r; - for (int i = 0; i < 4; i++) r.limbs[i] = 0; - return r; -} - -inline Fp256 fp256_one() { - // R mod p (Montgomery form of 1) - Fp256 r; - r.limbs[0] = 0x4E6E0206CA34BB1E; - r.limbs[1] = 0x7E2F6A58BE66A5E7; - r.limbs[2] = 0x30C1B89EB0E1C70D; - r.limbs[3] = 0x2AE3C0E97F5A0A1D; - return r; -} - -inline bool fp256_is_zero(thread const Fp256& a) { - return a.limbs[0] == 0 && a.limbs[1] == 0 && a.limbs[2] == 0 && a.limbs[3] == 0; -} - -inline void fp256_reduce(thread Fp256& a) { - if (fp256_cmp(a, BN254_P) >= 0) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - a.limbs[i] = sbb(a.limbs[i], BN254_P[i], borrow); - } - } -} - -inline Fp256 fp256_add(thread const Fp256& a, thread const Fp256& b) { - Fp256 c; - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = adc(a.limbs[i], b.limbs[i], carry); - } - fp256_reduce(c); - return c; -} - -inline Fp256 fp256_sub(thread const Fp256& a, thread const Fp256& b) { - Fp256 c; - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = sbb(a.limbs[i], b.limbs[i], borrow); - } - if (borrow) { - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = adc(c.limbs[i], BN254_P[i], carry); - } - } - return c; -} - 
-inline Fp256 fp256_neg(thread const Fp256& a) { - if (fp256_is_zero(a)) return a; - Fp256 c; - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = sbb(BN254_P[i], a.limbs[i], borrow); - } - return c; -} - -// Montgomery multiplication -inline Fp256 fp256_mont_mul(thread const Fp256& a, thread const Fp256& b) { - uint64_t t[8] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - uint64_t sum = t[i+j] + lo + carry; - carry = (sum < t[i+j]) ? 1 : 0; - carry += hi; - t[i+j] = sum; - } - t[i+4] = carry; - } - - // Montgomery reduction - for (int i = 0; i < 4; i++) { - uint64_t k = t[i] * BN254_INV; - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo, hi; - mul64(k, BN254_P[j], lo, hi); - uint64_t sum = t[i+j] + lo + carry; - carry = (sum < t[i+j]) ? 1 : 0; - carry += hi; - t[i+j] = sum; - } - // Propagate carry - for (int j = i + 4; j < 8; j++) { - uint64_t sum = t[j] + carry; - carry = (sum < t[j]) ? 1 : 0; - t[j] = sum; - if (carry == 0) break; - } - } - - Fp256 c; - for (int i = 0; i < 4; i++) { - c.limbs[i] = t[i + 4]; - } - fp256_reduce(c); - return c; -} - -inline Fp256 fp256_square(thread const Fp256& a) { - return fp256_mont_mul(a, a); -} - -inline Fp256 fp256_double(thread const Fp256& a) { - return fp256_add(a, a); -} - -// Fermat's little theorem inversion: z^{-1} = z^{p-2} mod p. -// Square-and-multiply over the bits of (p-2) LSB->MSB. The exponent is -// public + fixed, so we always perform 256 squarings + ~120 conditional -// multiplies. Sufficient for ZK fixed-exponent inversion on the GPU. 
-inline Fp256 fp256_inverse(thread const Fp256& z) { - Fp256 result = fp256_one(); - Fp256 base = z; - for (int i = 0; i < 4; i++) { - uint64_t e = BN254_P_MINUS_2[i]; - for (int j = 0; j < 64; j++) { - if ((e >> j) & 1ULL) { - result = fp256_mont_mul(result, base); - } - base = fp256_square(base); - } - } - return result; -} - -// ============================================================================= -// G1 Point Operations -// ============================================================================= - -inline G1Affine g1_identity() { - G1Affine p; - p.x = fp256_zero(); - p.y = fp256_zero(); - p.infinity = true; - return p; -} - -inline G1Projective g1_to_projective(thread const G1Affine& p) { - G1Projective r; - r.x = p.x; - r.y = p.y; - r.z = p.infinity ? fp256_zero() : fp256_one(); - return r; -} - -// Constant-time conditional move on G1 projective points. -// dst = mask ? src : dst, where mask is 0 (keep dst) or all-ones (take src). -// Mirrors banderwagon::pt_cmov so both metal kernels share the same idiom. -inline void pt_cmov(thread G1Projective& dst, thread const G1Projective& src, uint64_t mask) { - for (int i = 0; i < 4; i++) { - dst.x.limbs[i] = (dst.x.limbs[i] & ~mask) | (src.x.limbs[i] & mask); - dst.y.limbs[i] = (dst.y.limbs[i] & ~mask) | (src.y.limbs[i] & mask); - dst.z.limbs[i] = (dst.z.limbs[i] & ~mask) | (src.z.limbs[i] & mask); - } -} - -// Constant-time conditional move on G1 affine points (including infinity flag). -// dst = mask ? src : dst, where mask is 0 (keep dst) or all-ones (take src). -inline void pt_cmov_affine(thread G1Affine& dst, thread const G1Affine& src, uint64_t mask) { - for (int i = 0; i < 4; i++) { - dst.x.limbs[i] = (dst.x.limbs[i] & ~mask) | (src.x.limbs[i] & mask); - dst.y.limbs[i] = (dst.y.limbs[i] & ~mask) | (src.y.limbs[i] & mask); - } - bool take_src = (mask != 0ULL); - dst.infinity = take_src ? src.infinity : dst.infinity; -} - -// Branchless Jacobian -> affine. 
Always performs the inversion + Mont muls so -// timing does not leak whether the input was the point at infinity. The -// is-infinity case is selected via pt_cmov_affine after the unconditional work. -inline G1Affine g1_to_affine(thread const G1Projective& p) { - // Unconditionally compute affine candidate. fp256_inverse(0) yields a - // garbage field element; we discard it via cmov below, so it never reaches - // the caller. - Fp256 z_inv = fp256_inverse(p.z); - Fp256 z_inv2 = fp256_square(z_inv); - Fp256 z_inv3 = fp256_mont_mul(z_inv2, z_inv); - - G1Affine candidate; - candidate.x = fp256_mont_mul(p.x, z_inv2); - candidate.y = fp256_mont_mul(p.y, z_inv3); - candidate.infinity = false; - - // Identity selector: mask = 0xFFFF...FFFF if p.z == 0 else 0. - uint64_t z_zero = (uint64_t)(0ULL - (uint64_t)fp256_is_zero(p.z)); - - G1Affine result = candidate; - G1Affine identity = g1_identity(); - pt_cmov_affine(result, identity, z_zero); - return result; -} - -// Branchless point doubling. Always evaluates the Jacobian doubling formulas -// then cmoves the projective identity in if the input was the point at -// infinity. Eliminates the secret-dependent branch on p.z. 
-inline G1Projective g1_double(thread const G1Projective& p) { - // Using Jacobian doubling formulas (computed unconditionally) - Fp256 a = fp256_square(p.x); // a = X^2 - Fp256 b = fp256_square(p.y); // b = Y^2 - Fp256 c = fp256_square(b); // c = Y^4 - - Fp256 xb = fp256_add(p.x, b); - Fp256 d = fp256_sub(fp256_square(xb), fp256_add(a, c)); - d = fp256_double(d); // d = 2*((X+Y^2)^2 - X^2 - Y^4) - - Fp256 e = fp256_add(fp256_double(a), a); // e = 3*X^2 - Fp256 f = fp256_square(e); // f = (3*X^2)^2 - - G1Projective dbl; - dbl.x = fp256_sub(f, fp256_double(d)); // X' = f - 2*d - dbl.y = fp256_sub(fp256_mont_mul(e, fp256_sub(d, dbl.x)), - fp256_double(fp256_double(fp256_double(c)))); // Y' = e*(d-X') - 8*c - dbl.z = fp256_double(fp256_mont_mul(p.y, p.z)); // Z' = 2*Y*Z - - // If p.z == 0 (input was the point at infinity), return the projective - // identity. cmov makes the dispatch constant-time. - uint64_t z_zero = (uint64_t)(0ULL - (uint64_t)fp256_is_zero(p.z)); - G1Projective result = dbl; - pt_cmov(result, p, z_zero); - return result; -} - -// Point addition in Jacobian coordinates -- branchless version. -// All conditional cases (P+0, 0+Q, P+P, P+(-P)) are computed unconditionally -// then selected via pt_cmov so the dispatch is constant-time and the GPU -// path is byte-equal to the CPU oracle for every input. 
-inline G1Projective g1_add(thread const G1Projective& p, thread const G1Projective& q) { - Fp256 z1z1 = fp256_square(p.z); // Z1^2 - Fp256 z2z2 = fp256_square(q.z); // Z2^2 - Fp256 u1 = fp256_mont_mul(p.x, z2z2); // U1 = X1*Z2^2 - Fp256 u2 = fp256_mont_mul(q.x, z1z1); // U2 = X2*Z1^2 - Fp256 s1 = fp256_mont_mul(fp256_mont_mul(p.y, q.z), z2z2); // S1 = Y1*Z2^3 - Fp256 s2 = fp256_mont_mul(fp256_mont_mul(q.y, p.z), z1z1); // S2 = Y2*Z1^3 - - Fp256 h = fp256_sub(u2, u1); // H = U2 - U1 - Fp256 r_val = fp256_sub(s2, s1); // r = S2 - S1 - - Fp256 hh = fp256_square(h); // H^2 - Fp256 hhh = fp256_mont_mul(h, hh); // H^3 - Fp256 v = fp256_mont_mul(u1, hh); // V = U1*H^2 - - G1Projective add_result; - add_result.x = fp256_sub(fp256_sub(fp256_square(r_val), hhh), fp256_double(v)); - add_result.y = fp256_sub(fp256_mont_mul(r_val, fp256_sub(v, add_result.x)), - fp256_mont_mul(s1, hhh)); - add_result.z = fp256_mont_mul(fp256_mont_mul(p.z, q.z), h); - - G1Projective dbl_result = g1_double(p); - - G1Projective inf; - inf.x = fp256_one(); - inf.y = fp256_one(); - inf.z = fp256_zero(); - - // Selectors. mask = 0xFFFF...FFFF if condition true else 0. - uint64_t p_zero = (uint64_t)(0ULL - (uint64_t)fp256_is_zero(p.z)); - uint64_t q_zero = (uint64_t)(0ULL - (uint64_t)fp256_is_zero(q.z)); - uint64_t h_zero = (uint64_t)(0ULL - (uint64_t)fp256_is_zero(h)); - uint64_t r_zero = (uint64_t)(0ULL - (uint64_t)fp256_is_zero(r_val)); - - G1Projective out = add_result; - pt_cmov(out, dbl_result, h_zero & r_zero); // P+P -> doubling - pt_cmov(out, inf, h_zero & ~r_zero); // P+(-P) -> infinity - pt_cmov(out, q, p_zero); // 0+Q -> Q - pt_cmov(out, p, q_zero); // P+0 -> P - return out; -} - -// Constant-time scalar multiplication (Jacobian). 
-// Mirrors banderwagon::pt_scalar_mul (banderwagon.metal:319-333): -// * iterate LSB->MSB across the 4-limb scalar (256 bits) -// * at each bit: compute acc + base unconditionally, cmov-select via mask -// * always double the base -// Eliminates the secret-dependent `if (bit)` branch and yields output -// byte-equal to the CPU oracle. -inline G1Projective g1_scalar_mul(thread const G1Affine& p, thread const uint64_t scalar[4]) { - G1Projective acc; - acc.x = fp256_one(); - acc.y = fp256_one(); - acc.z = fp256_zero(); // identity - - G1Projective base = g1_to_projective(p); - - for (int i = 0; i < 4; i++) { - uint64_t limb = scalar[i]; - for (int j = 0; j < 64; j++) { - uint64_t bit = (limb >> j) & 1ULL; - uint64_t mask = 0ULL - bit; - G1Projective sum = g1_add(acc, base); - pt_cmov(acc, sum, mask); - base = g1_double(base); - } - } - return acc; -} - -// ============================================================================= -// Pedersen Commitment Kernel -// ============================================================================= - -// Pedersen commitment kernel. -// -// ABI (changed 2026-04-28, LP-137-ORG-LAYOUT B3 fix): -// buffer(0) values : per-commitment scalar value v (4 limbs each, Mont form) -// buffer(1) blinding : per-commitment scalar blinding r (4 limbs each, Mont form) -// buffer(2) G_xy : 8 limbs of generator G (x||y), Mont form, supplied by host -// buffer(3) H_xy : 8 limbs of generator H (x||y), Mont form, supplied by host -// buffer(4) commitments : output (8 limbs each: x||y) -// buffer(5) num_commitments -// -// Host MUST derive G,H independently via bn254 hash-to-curve with brand-neutral -// DSTs ("LUX_PEDERSEN_G", "LUX_PEDERSEN_H") so the on-curve generators have no -// known discrete-log relation. The previous `Hp = g1_double(Gp)` was discrete- -// log-revealing (H = 2G) and broke the hiding property of the commitment. 
-kernel void pedersen_commit( - device const uint64_t* values [[buffer(0)]], - device const uint64_t* blinding [[buffer(1)]], - device const uint64_t* G_xy [[buffer(2)]], - device const uint64_t* H_xy [[buffer(3)]], - device uint64_t* commitments [[buffer(4)]], - constant uint32_t& num_commitments [[buffer(5)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_commitments) return; - - uint64_t v[4], r[4]; - for (int i = 0; i < 4; i++) { - v[i] = values[index * 4 + i]; - r[i] = blinding[index * 4 + i]; - } - - // Load host-supplied generators G and H (Mont form, x||y). - G1Affine G; - G1Affine H; - for (int i = 0; i < 4; i++) { - G.x.limbs[i] = G_xy[i]; - G.y.limbs[i] = G_xy[4 + i]; - H.x.limbs[i] = H_xy[i]; - H.y.limbs[i] = H_xy[4 + i]; - } - G.infinity = false; - H.infinity = false; - - // C = v*G + r*H using the constant-time ladder. - G1Projective vG = g1_scalar_mul(G, v); - G1Projective rH = g1_scalar_mul(H, r); - G1Projective C = g1_add(vG, rH); - G1Affine C_affine = g1_to_affine(C); - - uint32_t out_offset = index * 8; - for (int i = 0; i < 4; i++) { - commitments[out_offset + i] = C_affine.x.limbs[i]; - commitments[out_offset + 4 + i] = C_affine.y.limbs[i]; - } -} - -// ============================================================================= -// Batch Point Addition Kernel -// ============================================================================= - -kernel void bn254_batch_add( - device const uint64_t* points_a [[buffer(0)]], // Input points A (8 limbs each) - device const uint64_t* points_b [[buffer(1)]], // Input points B (8 limbs each) - device uint64_t* results [[buffer(2)]], // Output points (8 limbs each) - constant uint32_t& num_points [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_points) return; - - uint32_t offset = index * 8; - - // Load points - G1Affine a, b; - for (int i = 0; i < 4; i++) { - a.x.limbs[i] = points_a[offset + i]; - a.y.limbs[i] = points_a[offset + 4 + i]; - b.x.limbs[i] 
= points_b[offset + i]; - b.y.limbs[i] = points_b[offset + 4 + i]; - } - a.infinity = false; - b.infinity = false; - - // Add points - G1Projective ap = g1_to_projective(a); - G1Projective bp = g1_to_projective(b); - G1Projective sum = g1_add(ap, bp); - G1Affine result = g1_to_affine(sum); - - // Output - for (int i = 0; i < 4; i++) { - results[offset + i] = result.x.limbs[i]; - results[offset + 4 + i] = result.y.limbs[i]; - } -} - -// ============================================================================= -// Batch Scalar Multiplication Kernel (MSM - Multi-Scalar Multiplication) -// ============================================================================= - -kernel void bn254_batch_scalar_mul( - device const uint64_t* points [[buffer(0)]], // Base points (8 limbs each) - device const uint64_t* scalars [[buffer(1)]], // Scalars (4 limbs each) - device uint64_t* results [[buffer(2)]], // Output points (8 limbs each) - constant uint32_t& num_points [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_points) return; - - // Load point and scalar - G1Affine p; - uint64_t s[4]; - uint32_t p_offset = index * 8; - uint32_t s_offset = index * 4; - - for (int i = 0; i < 4; i++) { - p.x.limbs[i] = points[p_offset + i]; - p.y.limbs[i] = points[p_offset + 4 + i]; - s[i] = scalars[s_offset + i]; - } - p.infinity = false; - - // Scalar multiplication - G1Projective result = g1_scalar_mul(p, s); - G1Affine result_affine = g1_to_affine(result); - - // Output - for (int i = 0; i < 4; i++) { - results[p_offset + i] = result_affine.x.limbs[i]; - results[p_offset + 4 + i] = result_affine.y.limbs[i]; - } -} diff --git a/bn254/gpu/metal/zk_metal.h b/bn254/gpu/metal/zk_metal.h deleted file mode 100644 index fd618bc..0000000 --- a/bn254/gpu/metal/zk_metal.h +++ /dev/null @@ -1,360 +0,0 @@ -// ============================================================================= -// Metal ZK Accelerator Header -// 
============================================================================= -// -// C++ wrapper for Metal compute shaders for ZK cryptographic operations. -// Provides GPU-accelerated Pedersen, Blake3, KZG, and BN254 operations. -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include -#include -#include - -#ifdef __APPLE__ -#include -#include -#endif - -namespace lux { -namespace crypto { -namespace metal { - -// ============================================================================= -// Type Definitions -// ============================================================================= - -// 256-bit field element (4 x 64-bit limbs) -struct Fr256 { - uint64_t limbs[4]; -}; - -// 384-bit field element (6 x 64-bit limbs) -struct Fp384 { - uint64_t limbs[6]; -}; - -// BN254 G1 affine point -struct BN254G1Affine { - Fr256 x; - Fr256 y; - bool infinity; -}; - -// BLS12-381 G1 affine point -struct BLS12G1Affine { - Fp384 x; - Fp384 y; - bool infinity; -}; - -// Pedersen commitment result -struct PedersenCommitment { - BN254G1Affine point; - bool valid; -}; - -// Blake3 hash output -struct Blake3Digest { - uint8_t bytes[64]; - uint32_t length; -}; - -// KZG commitment -struct KZGCommitment { - BLS12G1Affine point; - bool valid; -}; - -// ============================================================================= -// Metal Context -// ============================================================================= - -class MetalZKContext { -public: - MetalZKContext(); - ~MetalZKContext(); - - // Initialization - bool initialize(); - bool isAvailable() const; - - // Device info - const char* getDeviceName() const; - uint64_t getMaxMemory() const; - - // ========================================================================= - // Blake3 Operations - // ========================================================================= - - // Hash single input - Blake3Digest blake3Hash256(const uint8_t* 
data, uint32_t length); - Blake3Digest blake3Hash512(const uint8_t* data, uint32_t length); - - // Batch hash multiple inputs - std::vector blake3BatchHash( - const std::vector& inputs, - const std::vector& lengths - ); - - // XOF (extendable output) - std::vector blake3XOF( - const uint8_t* data, - uint32_t inputLength, - uint32_t outputLength - ); - - // Merkle tree root - Blake3Digest blake3MerkleRoot( - const std::vector& leaves - ); - - // ========================================================================= - // BN254/Pedersen Operations - // ========================================================================= - - // Single Pedersen commitment - PedersenCommitment pedersenCommit( - const Fr256& value, - const Fr256& blindingFactor - ); - - // Batch Pedersen commitments - std::vector pedersenBatchCommit( - const std::vector& values, - const std::vector& blindingFactors - ); - - // BN254 scalar multiplication - BN254G1Affine bn254ScalarMul( - const BN254G1Affine& point, - const Fr256& scalar - ); - - // BN254 batch scalar multiplication (MSM) - std::vector bn254BatchScalarMul( - const std::vector& points, - const std::vector& scalars - ); - - // BN254 point addition - BN254G1Affine bn254Add( - const BN254G1Affine& a, - const BN254G1Affine& b - ); - - // ========================================================================= - // KZG/BLS12-381 Operations - // ========================================================================= - - // Convert blob to polynomial coefficients - std::vector blobToPolynomial( - const uint8_t* blob, - uint32_t blobSize - ); - - // Compute KZG commitment using MSM - KZGCommitment kzgCommit( - const std::vector& coefficients, - const std::vector& trustedSetup - ); - - // Compute KZG opening proof - std::pair kzgComputeProof( - const std::vector& polynomial, - const Fr256& point, - const std::vector& trustedSetup - ); - - // FFT over scalar field - std::vector fft( - const std::vector& coefficients, - bool inverse = false 
- ); - - // Inverse FFT - std::vector ifft(const std::vector& values); - - // ========================================================================= - // Poseidon2 Operations (BN254/Fr) - // ========================================================================= - - // Hash pair (2-to-1 compression for Merkle trees) - Fr256 poseidon2HashPair(const Fr256& left, const Fr256& right); - - // Batch hash pairs - std::vector poseidon2BatchHashPair( - const std::vector& left, - const std::vector& right - ); - - // Merkle tree layer (hash adjacent pairs) - std::vector poseidon2MerkleLayer(const std::vector& current); - - // Build complete Merkle tree - std::vector poseidon2MerkleTree(const std::vector& leaves); - - // Compute commitment: Poseidon2(value, blinding, salt) - Fr256 poseidon2Commitment(const Fr256& value, const Fr256& blinding, const Fr256& salt); - - // Compute nullifier: Poseidon2(key, commitment, index) - Fr256 poseidon2Nullifier(const Fr256& key, const Fr256& commitment, const Fr256& index); - - // Batch commitments - std::vector poseidon2BatchCommitment( - const std::vector& values, - const std::vector& blindings, - const std::vector& salts - ); - - // Batch nullifiers - std::vector poseidon2BatchNullifier( - const std::vector& keys, - const std::vector& commitments, - const std::vector& indices - ); - - // ========================================================================= - // Goldilocks/FRI Operations (for STARK) - // ========================================================================= - - // FRI fold layer - std::vector friFoldLayer( - const std::vector& evals, - uint64_t alpha, - uint64_t omega_inv - ); - -private: -#ifdef __APPLE__ - id device_; - id commandQueue_; - id cryptoLibrary_; // lux_crypto.metallib (BLS, BLAKE3, KZG) - id zkLibrary_; // lux_zk.metallib (BN254, Poseidon, MSM, Goldilocks) - - // Compute pipelines - id blake3HashPipeline_; - id blake3BatchPipeline_; - id blake3MerklePipeline_; - id blake3XofPipeline_; - - id 
pedersenCommitPipeline_; - id bn254BatchAddPipeline_; - id bn254BatchMulPipeline_; - - // Pedersen generators G, H -- pre-computed once at init via bn254 - // hash-to-curve with brand-neutral DSTs ("LUX_PEDERSEN_G", - // "LUX_PEDERSEN_H") so they have no known discrete-log relation. - // Each buffer carries 8 x uint64 (Mont-form x || Mont-form y). - id pedersenGBuffer_; - id pedersenHBuffer_; - - id kzgMsmPipeline_; - id kzgFftPipeline_; - id kzgBitReversePipeline_; - id blobToFieldPipeline_; - - // Poseidon2 pipelines - id poseidon2HashPairPipeline_; - id poseidon2MerkleLayerPipeline_; - id poseidon2CommitmentPipeline_; - id poseidon2NullifierPipeline_; - - // FRI/Goldilocks pipelines - id friFoldLayerPipeline_; - id goldilocksBatchMulPipeline_; - - // Initialize compute pipelines - bool initBlake3Pipelines(); - bool initBN254Pipelines(); - bool initKZGPipelines(); - bool initPoseidon2Pipelines(); - bool initFRIPipelines(); - - // Helper to compile shader function - id createPipeline(const char* functionName); -#endif - - bool initialized_; - std::string deviceName_; -}; - -// ============================================================================= -// Singleton Access -// ============================================================================= - -// Get global Metal context (lazy initialization) -MetalZKContext& getMetalZKContext(); - -// Check if Metal acceleration is available -bool isMetalAvailable(); - -// ============================================================================= -// C API for CGO Bridge -// ============================================================================= - -extern "C" { - -// Blake3 -int metal_blake3_hash256(const uint8_t* data, uint32_t len, uint8_t* out); -int metal_blake3_hash512(const uint8_t* data, uint32_t len, uint8_t* out); -int metal_blake3_xof(const uint8_t* data, uint32_t inLen, uint8_t* out, uint32_t outLen); -int metal_blake3_merkle_root(const uint8_t* leaves, uint32_t numLeaves, uint8_t* out); - -// 
Pedersen (BN254) -int metal_pedersen_commit( - const uint64_t* value, - const uint64_t* blinding, - uint64_t* commitmentX, - uint64_t* commitmentY -); - -int metal_pedersen_batch_commit( - const uint64_t* values, - const uint64_t* blindings, - uint32_t count, - uint64_t* commitments -); - -// BN254 -int metal_bn254_scalar_mul( - const uint64_t* pointX, - const uint64_t* pointY, - const uint64_t* scalar, - uint64_t* resultX, - uint64_t* resultY -); - -int metal_bn254_add( - const uint64_t* ax, const uint64_t* ay, - const uint64_t* bx, const uint64_t* by, - uint64_t* rx, uint64_t* ry -); - -// KZG -int metal_kzg_commit( - const uint64_t* coefficients, - uint32_t numCoeffs, - const uint64_t* trustedSetup, - uint64_t* commitmentX, - uint64_t* commitmentY -); - -int metal_fft( - uint64_t* coefficients, - uint32_t n, - int inverse -); - -// Availability check -int metal_is_available(); - -} // extern "C" - -} // namespace metal -} // namespace crypto -} // namespace lux diff --git a/bn254/gpu/metal/zk_metal.mm b/bn254/gpu/metal/zk_metal.mm deleted file mode 100644 index 5803484..0000000 --- a/bn254/gpu/metal/zk_metal.mm +++ /dev/null @@ -1,988 +0,0 @@ -// ============================================================================= -// Metal ZK Accelerator Implementation -// ============================================================================= -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#include "zk_metal.h" - -// First-party CPU bn254 -- used to derive Pedersen generators G, H once at -// protocol init via hash-to-curve with brand-neutral DSTs. The MTLBuffers -// for each generator are then reused for every dispatch of pedersen_commit. 
-#include "bn254_fp.hpp" -#include "bn254_g1.hpp" -#include "bn254_hash_to_curve.hpp" - -#include -#include - -#ifdef __APPLE__ -#import -#import -#endif - -namespace lux { -namespace crypto { -namespace metal { - -namespace { - -// Pack a single bn254 G1 affine point (Montgomery form) into 8 LE uint64 -// limbs (x_lo..x_hi || y_lo..y_hi). Matches the Metal kernel's expected -// G_xy / H_xy layout in bn254.metal. -void pack_g1_affine_le(const lux::crypto::bn254::G1Affine& p, - uint64_t out_xy[8]) { - for (int i = 0; i < 4; i++) { - out_xy[i] = p.x.limbs[i]; - out_xy[4 + i] = p.y.limbs[i]; - } -} - -} // namespace - -// ============================================================================= -// MetalZKContext Implementation -// ============================================================================= - -MetalZKContext::MetalZKContext() : initialized_(false) { -#ifdef __APPLE__ - device_ = nil; - commandQueue_ = nil; - cryptoLibrary_ = nil; - zkLibrary_ = nil; - pedersenGBuffer_ = nil; - pedersenHBuffer_ = nil; -#endif -} - -MetalZKContext::~MetalZKContext() { -#ifdef __APPLE__ - // ARC handles cleanup -#endif -} - -bool MetalZKContext::initialize() { -#ifdef __APPLE__ - @autoreleasepool { - // Get default Metal device - device_ = MTLCreateSystemDefaultDevice(); - if (!device_) { - return false; - } - - deviceName_ = std::string([[device_ name] UTF8String]); - - // Create command queue - commandQueue_ = [device_ newCommandQueue]; - if (!commandQueue_) { - return false; - } - - NSError* error = nil; - - // ================================================================= - // Load crypto metallib (BLS12-381, BLAKE3, KZG) - // ================================================================= - NSArray* cryptoLibPaths = @[ - @"/usr/local/share/lux/crypto/lux_crypto.metallib", - [[NSBundle mainBundle] pathForResource:@"lux_crypto" ofType:@"metallib"] ?: @"" - ]; - - for (NSString* libPath in cryptoLibPaths) { - if (libPath.length > 0 && [[NSFileManager 
defaultManager] fileExistsAtPath:libPath]) { - NSURL* libURL = [NSURL fileURLWithPath:libPath]; - cryptoLibrary_ = [device_ newLibraryWithURL:libURL error:&error]; - if (cryptoLibrary_) { - NSLog(@"Loaded crypto metallib from: %@", libPath); - break; - } - } - } - - // ================================================================= - // Load ZK metallib (BN254, Poseidon, MSM, Goldilocks) - // ================================================================= - NSArray* zkLibPaths = @[ - @"/usr/local/share/lux/crypto/lux_zk.metallib", - [[NSBundle mainBundle] pathForResource:@"lux_zk" ofType:@"metallib"] ?: @"" - ]; - - for (NSString* libPath in zkLibPaths) { - if (libPath.length > 0 && [[NSFileManager defaultManager] fileExistsAtPath:libPath]) { - NSURL* libURL = [NSURL fileURLWithPath:libPath]; - zkLibrary_ = [device_ newLibraryWithURL:libURL error:&error]; - if (zkLibrary_) { - NSLog(@"Loaded ZK metallib from: %@", libPath); - break; - } - } - } - - // ================================================================= - // Fallback: Try default library (built into app bundle) - // ================================================================= - if (!cryptoLibrary_ && !zkLibrary_) { - id defaultLib = [device_ newDefaultLibrary]; - if (defaultLib) { - // Use default library for both if metallibs not found - cryptoLibrary_ = defaultLib; - zkLibrary_ = defaultLib; - NSLog(@"Using default Metal library for all shaders"); - } - } - - // ================================================================= - // Last resort: Compile from source at runtime - // ================================================================= - if (!cryptoLibrary_ || !zkLibrary_) { - NSString* shaderPath = @"/usr/local/share/lux/crypto/shaders"; - - // Crypto shaders - if (!cryptoLibrary_) { - NSArray* cryptoShaders = @[@"blake3.metal", @"kzg.metal", @"bls12_381.metal"]; - NSMutableString* cryptoSource = [NSMutableString string]; - for (NSString* file in cryptoShaders) { - NSString* path = 
[shaderPath stringByAppendingPathComponent:file]; - NSString* source = [NSString stringWithContentsOfFile:path - encoding:NSUTF8StringEncoding - error:&error]; - if (source) { - [cryptoSource appendString:source]; - [cryptoSource appendString:@"\n"]; - } - } - if (cryptoSource.length > 0) { - MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; - if (@available(macOS 15.0, *)) { - options.mathMode = MTLMathModeFast; - } else { -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" - options.fastMathEnabled = YES; -#pragma clang diagnostic pop - } - cryptoLibrary_ = [device_ newLibraryWithSource:cryptoSource - options:options - error:&error]; - } - } - - // ZK shaders - if (!zkLibrary_) { - NSArray* zkShaders = @[@"bn254.metal", @"goldilocks.metal", @"poseidon.metal", - @"poseidon2_bn254.metal", @"msm.metal"]; - NSMutableString* zkSource = [NSMutableString string]; - for (NSString* file in zkShaders) { - NSString* path = [shaderPath stringByAppendingPathComponent:file]; - NSString* source = [NSString stringWithContentsOfFile:path - encoding:NSUTF8StringEncoding - error:&error]; - if (source) { - [zkSource appendString:source]; - [zkSource appendString:@"\n"]; - } - } - if (zkSource.length > 0) { - MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; - if (@available(macOS 15.0, *)) { - options.mathMode = MTLMathModeFast; - } else { -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" - options.fastMathEnabled = YES; -#pragma clang diagnostic pop - } - zkLibrary_ = [device_ newLibraryWithSource:zkSource - options:options - error:&error]; - } - } - } - - // At least one library must be available - if (!cryptoLibrary_ && !zkLibrary_) { - NSLog(@"Failed to load any Metal shader libraries: %@", error); - return false; - } - - // Initialize pipelines (crypto shaders from cryptoLibrary_) - if (cryptoLibrary_) { - if (!initBlake3Pipelines()) { - NSLog(@"Warning: BLAKE3 
pipelines not available"); - } - if (!initKZGPipelines()) { - NSLog(@"Warning: KZG pipelines not available"); - } - } - - // Initialize pipelines (ZK shaders from zkLibrary_) - if (zkLibrary_) { - if (!initBN254Pipelines()) { - NSLog(@"Warning: BN254 pipelines not available"); - } - initPoseidon2Pipelines(); // Optional, don't fail if not present - initFRIPipelines(); // Optional, don't fail if not present - } - - initialized_ = true; - return true; - } -#else - return false; -#endif -} - -bool MetalZKContext::isAvailable() const { - return initialized_; -} - -const char* MetalZKContext::getDeviceName() const { - return deviceName_.c_str(); -} - -uint64_t MetalZKContext::getMaxMemory() const { -#ifdef __APPLE__ - if (device_) { - return [device_ recommendedMaxWorkingSetSize]; - } -#endif - return 0; -} - -#ifdef __APPLE__ - -id MetalZKContext::createPipeline(const char* functionName) { - @autoreleasepool { - NSError* error = nil; - NSString* name = [NSString stringWithUTF8String:functionName]; - id function = nil; - - // Try crypto library first (BLS, BLAKE3, KZG) - if (cryptoLibrary_) { - function = [cryptoLibrary_ newFunctionWithName:name]; - } - - // Then try ZK library (BN254, Poseidon, MSM, Goldilocks) - if (!function && zkLibrary_) { - function = [zkLibrary_ newFunctionWithName:name]; - } - - if (!function) { - // Not an error - function may legitimately not exist in loaded shaders - return nil; - } - - id pipeline = - [device_ newComputePipelineStateWithFunction:function error:&error]; - - if (!pipeline) { - NSLog(@"Failed to create pipeline for %@: %@", name, error); - return nil; - } - - return pipeline; - } -} - -bool MetalZKContext::initBlake3Pipelines() { - blake3HashPipeline_ = createPipeline("blake3_hash_block"); - blake3BatchPipeline_ = createPipeline("blake3_batch_hash"); - blake3MerklePipeline_ = createPipeline("blake3_merge_nodes"); - blake3XofPipeline_ = createPipeline("blake3_xof"); - - // Some pipelines are optional - return blake3HashPipeline_ 
!= nil || blake3BatchPipeline_ != nil; -} - -bool MetalZKContext::initBN254Pipelines() { - pedersenCommitPipeline_ = createPipeline("pedersen_commit"); - bn254BatchAddPipeline_ = createPipeline("bn254_batch_add"); - bn254BatchMulPipeline_ = createPipeline("bn254_batch_scalar_mul"); - - // ----------------------------------------------------------------- - // Pedersen generators G, H. - // - // Derived at protocol init via bn254 RFC-9380 hash-to-curve with - // brand-neutral DSTs that match luxfi/crypto/pedersen/pedersen.go: - // - // G = HashToG1("seed_g", "LUX_PEDERSEN_G") - // H = HashToG1("seed_h", "LUX_PEDERSEN_H") - // - // The seed ("seed_g" / "seed_h") is a fixed deterministic string; - // the DST enforces independence and binds to the Lux brand. - // Identical DSTs on both sides (CPU oracle in pedersen.go, GPU kernel - // via these MTLBuffers) yield byte-equal generators. - // - // Once landed, these buffers are reused for every pedersenCommit - // dispatch -- there is no per-call cost. 
- // ----------------------------------------------------------------- - { - using lux::crypto::bn254::G1Affine; - using lux::crypto::bn254::h2c::hash_to_curve_g1; - - const uint8_t seed_g[] = {'s','e','e','d','_','g'}; - const uint8_t seed_h[] = {'s','e','e','d','_','h'}; - const uint8_t dst_g[] = {'L','U','X','_','P','E','D','E','R','S','E','N','_','G'}; - const uint8_t dst_h[] = {'L','U','X','_','P','E','D','E','R','S','E','N','_','H'}; - - G1Affine G = hash_to_curve_g1({seed_g, sizeof(seed_g)}, - {dst_g, sizeof(dst_g)}); - G1Affine H = hash_to_curve_g1({seed_h, sizeof(seed_h)}, - {dst_h, sizeof(dst_h)}); - - uint64_t G_xy[8]; - uint64_t H_xy[8]; - pack_g1_affine_le(G, G_xy); - pack_g1_affine_le(H, H_xy); - - pedersenGBuffer_ = [device_ newBufferWithBytes:G_xy - length:sizeof(G_xy) - options:MTLResourceStorageModeShared]; - pedersenHBuffer_ = [device_ newBufferWithBytes:H_xy - length:sizeof(H_xy) - options:MTLResourceStorageModeShared]; - } - - return pedersenCommitPipeline_ != nil - && pedersenGBuffer_ != nil - && pedersenHBuffer_ != nil; -} - -bool MetalZKContext::initKZGPipelines() { - kzgMsmPipeline_ = createPipeline("kzg_msm_bucket_accumulate"); - kzgFftPipeline_ = createPipeline("kzg_fft_butterfly"); - kzgBitReversePipeline_ = createPipeline("kzg_fft_bit_reverse"); - blobToFieldPipeline_ = createPipeline("blob_to_field_elements"); - - return true; // KZG pipelines are optional for basic functionality -} - -bool MetalZKContext::initPoseidon2Pipelines() { - poseidon2HashPairPipeline_ = createPipeline("poseidon2_hash_pair"); - poseidon2MerkleLayerPipeline_ = createPipeline("poseidon2_merkle_layer"); - poseidon2CommitmentPipeline_ = createPipeline("poseidon2_commitment"); - poseidon2NullifierPipeline_ = createPipeline("poseidon2_nullifier"); - - // At least hash pair should be available - return poseidon2HashPairPipeline_ != nil; -} - -bool MetalZKContext::initFRIPipelines() { - friFoldLayerPipeline_ = createPipeline("fri_fold_layer"); - 
goldilocksBatchMulPipeline_ = createPipeline("goldilocks_batch_mul"); - - return true; // FRI pipelines are optional -} - -#endif // __APPLE__ - -// ============================================================================= -// Blake3 Operations -// ============================================================================= - -Blake3Digest MetalZKContext::blake3Hash256(const uint8_t* data, uint32_t length) { - Blake3Digest result = {}; - result.length = 32; - -#ifdef __APPLE__ - if (!initialized_ || !blake3HashPipeline_) { - return result; - } - - @autoreleasepool { - // Create buffers - id inputBuffer = [device_ newBufferWithBytes:data - length:length - options:MTLResourceStorageModeShared]; - - id outputBuffer = [device_ newBufferWithLength:32 - options:MTLResourceStorageModeShared]; - - id lengthBuffer = [device_ newBufferWithBytes:&length - length:sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - - // Create command buffer - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:blake3HashPipeline_]; - [encoder setBuffer:inputBuffer offset:0 atIndex:0]; - [encoder setBuffer:outputBuffer offset:0 atIndex:1]; - [encoder setBuffer:lengthBuffer offset:0 atIndex:2]; - - // Dispatch - MTLSize gridSize = MTLSizeMake(1, 1, 1); - MTLSize threadGroupSize = MTLSizeMake(1, 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - // Copy result - memcpy(result.bytes, [outputBuffer contents], 32); - // NOTE: Blake3Digest has no `valid` field; the legacy code referenced - // a non-existent member. Phase 3 redesigns this struct. 
- } -#endif - - return result; -} - -Blake3Digest MetalZKContext::blake3Hash512(const uint8_t* data, uint32_t length) { - Blake3Digest result = {}; - result.length = 64; - - // Similar to hash256 but with 64-byte output - // Implementation follows same pattern - - return result; -} - -std::vector MetalZKContext::blake3BatchHash( - const std::vector& inputs, - const std::vector& lengths -) { - std::vector results(inputs.size()); - -#ifdef __APPLE__ - if (!initialized_ || !blake3BatchPipeline_) { - return results; - } - - @autoreleasepool { - // Flatten inputs into single buffer with offsets - uint32_t totalSize = 0; - std::vector offsets(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - offsets[i] = totalSize; - totalSize += lengths[i]; - } - - std::vector flatData(totalSize); - for (size_t i = 0; i < inputs.size(); i++) { - memcpy(&flatData[offsets[i]], inputs[i], lengths[i]); - } - - // Create buffers - id inputBuffer = [device_ newBufferWithBytes:flatData.data() - length:totalSize - options:MTLResourceStorageModeShared]; - - id outputBuffer = [device_ newBufferWithLength:inputs.size() * 32 - options:MTLResourceStorageModeShared]; - - id offsetBuffer = [device_ newBufferWithBytes:offsets.data() - length:offsets.size() * sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - - id lengthBuffer = [device_ newBufferWithBytes:lengths.data() - length:lengths.size() * sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - - // Dispatch - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:blake3BatchPipeline_]; - [encoder setBuffer:inputBuffer offset:0 atIndex:0]; - [encoder setBuffer:outputBuffer offset:0 atIndex:1]; - [encoder setBuffer:offsetBuffer offset:0 atIndex:2]; - [encoder setBuffer:lengthBuffer offset:0 atIndex:3]; - - NSUInteger threadGroupWidth = [blake3BatchPipeline_ maxTotalThreadsPerThreadgroup]; - MTLSize gridSize = MTLSizeMake(inputs.size(), 
1, 1); - MTLSize threadGroupSize = MTLSizeMake(std::min((NSUInteger)inputs.size(), threadGroupWidth), 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - // Copy results - uint32_t* output = (uint32_t*)[outputBuffer contents]; - for (size_t i = 0; i < inputs.size(); i++) { - results[i].length = 32; - memcpy(results[i].bytes, &output[i * 8], 32); - } - } -#endif - - return results; -} - -// ============================================================================= -// Pedersen Operations -// ============================================================================= - -PedersenCommitment MetalZKContext::pedersenCommit( - const Fr256& value, - const Fr256& blindingFactor -) { - PedersenCommitment result = {}; - -#ifdef __APPLE__ - if (!initialized_ || !pedersenCommitPipeline_ - || !pedersenGBuffer_ || !pedersenHBuffer_) { - return result; - } - - @autoreleasepool { - // Per-call buffers: scalars + output. G/H are bound from the - // pre-computed context buffers (one MTLBuffer alloc total, not - // one per dispatch). - id valueBuffer = [device_ newBufferWithBytes:value.limbs - length:32 - options:MTLResourceStorageModeShared]; - id blindBuffer = [device_ newBufferWithBytes:blindingFactor.limbs - length:32 - options:MTLResourceStorageModeShared]; - id outputBuffer = [device_ newBufferWithLength:64 - options:MTLResourceStorageModeShared]; - uint32_t count = 1; - id countBuffer = [device_ newBufferWithBytes:&count - length:sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - // ABI: (values, blinding, G_xy, H_xy, commitments, num) per - // pedersen_commit kernel signature in bn254.metal. 
- [encoder setComputePipelineState:pedersenCommitPipeline_]; - [encoder setBuffer:valueBuffer offset:0 atIndex:0]; - [encoder setBuffer:blindBuffer offset:0 atIndex:1]; - [encoder setBuffer:pedersenGBuffer_ offset:0 atIndex:2]; - [encoder setBuffer:pedersenHBuffer_ offset:0 atIndex:3]; - [encoder setBuffer:outputBuffer offset:0 atIndex:4]; - [encoder setBuffer:countBuffer offset:0 atIndex:5]; - - MTLSize gridSize = MTLSizeMake(1, 1, 1); - MTLSize threadGroupSize = MTLSizeMake(1, 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - uint64_t* output = (uint64_t*)[outputBuffer contents]; - memcpy(result.point.x.limbs, output, 32); - memcpy(result.point.y.limbs, output + 4, 32); - result.point.infinity = false; - result.valid = true; - } -#endif - - return result; -} - -// ============================================================================= -// Singleton Access -// ============================================================================= - -MetalZKContext& getMetalZKContext() { - static MetalZKContext context; - static bool initialized = false; - if (!initialized) { - context.initialize(); - initialized = true; - } - return context; -} - -bool isMetalAvailable() { - return getMetalZKContext().isAvailable(); -} - -// ============================================================================= -// C API Implementation -// ============================================================================= - -extern "C" { - -int metal_is_available() { - return isMetalAvailable() ? 
1 : 0; -} - -int metal_blake3_hash256(const uint8_t* data, uint32_t len, uint8_t* out) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - Blake3Digest digest = ctx.blake3Hash256(data, len); - memcpy(out, digest.bytes, 32); - return 0; -} - -int metal_blake3_hash512(const uint8_t* data, uint32_t len, uint8_t* out) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - Blake3Digest digest = ctx.blake3Hash512(data, len); - memcpy(out, digest.bytes, 64); - return 0; -} - -int metal_pedersen_commit( - const uint64_t* value, - const uint64_t* blinding, - uint64_t* commitmentX, - uint64_t* commitmentY -) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - Fr256 v, b; - memcpy(v.limbs, value, 32); - memcpy(b.limbs, blinding, 32); - - PedersenCommitment commit = ctx.pedersenCommit(v, b); - if (!commit.valid) return -1; - - memcpy(commitmentX, commit.point.x.limbs, 32); - memcpy(commitmentY, commit.point.y.limbs, 32); - return 0; -} - -int metal_bn254_scalar_mul( - const uint64_t* pointX, - const uint64_t* pointY, - const uint64_t* scalar, - uint64_t* resultX, - uint64_t* resultY -) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - BN254G1Affine point; - Fr256 s; - memcpy(point.x.limbs, pointX, 32); - memcpy(point.y.limbs, pointY, 32); - memcpy(s.limbs, scalar, 32); - point.infinity = false; - - BN254G1Affine result = ctx.bn254ScalarMul(point, s); - memcpy(resultX, result.x.limbs, 32); - memcpy(resultY, result.y.limbs, 32); - return 0; -} - -// ============================================================================= -// Poseidon2 C++ Method Implementations -// ============================================================================= - -std::vector MetalZKContext::poseidon2BatchHashPair( - const std::vector& left, - const std::vector& right -) { - std::vector results(left.size()); - -#ifdef __APPLE__ - if (!initialized_ || !poseidon2HashPairPipeline_ || left.size() 
!= right.size()) { - return results; - } - - @autoreleasepool { - uint32_t count = (uint32_t)left.size(); - size_t elemSize = sizeof(Fr256); - - id leftBuffer = [device_ newBufferWithBytes:left.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id rightBuffer = [device_ newBufferWithBytes:right.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id outputBuffer = [device_ newBufferWithLength:count * elemSize - options:MTLResourceStorageModeShared]; - - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:poseidon2HashPairPipeline_]; - [encoder setBuffer:leftBuffer offset:0 atIndex:0]; - [encoder setBuffer:rightBuffer offset:0 atIndex:1]; - [encoder setBuffer:outputBuffer offset:0 atIndex:2]; - - NSUInteger threadGroupWidth = [poseidon2HashPairPipeline_ maxTotalThreadsPerThreadgroup]; - MTLSize gridSize = MTLSizeMake(count, 1, 1); - MTLSize threadGroupSize = MTLSizeMake(MIN((NSUInteger)count, threadGroupWidth), 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results.data(), [outputBuffer contents], count * elemSize); - } -#endif - - return results; -} - -std::vector MetalZKContext::poseidon2MerkleLayer(const std::vector& current) { - std::vector results(current.size() / 2); - -#ifdef __APPLE__ - if (!initialized_ || !poseidon2MerkleLayerPipeline_ || current.size() < 2) { - return results; - } - - @autoreleasepool { - uint32_t currentSize = (uint32_t)current.size(); - uint32_t outputCount = currentSize / 2; - size_t elemSize = sizeof(Fr256); - - id currentBuffer = [device_ newBufferWithBytes:current.data() - length:currentSize * elemSize - options:MTLResourceStorageModeShared]; - - id outputBuffer = [device_ newBufferWithLength:outputCount * elemSize - options:MTLResourceStorageModeShared]; - - 
id sizeBuffer = [device_ newBufferWithBytes:¤tSize - length:sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:poseidon2MerkleLayerPipeline_]; - [encoder setBuffer:currentBuffer offset:0 atIndex:0]; - [encoder setBuffer:outputBuffer offset:0 atIndex:1]; - [encoder setBuffer:sizeBuffer offset:0 atIndex:2]; - - NSUInteger threadGroupWidth = [poseidon2MerkleLayerPipeline_ maxTotalThreadsPerThreadgroup]; - MTLSize gridSize = MTLSizeMake(outputCount, 1, 1); - MTLSize threadGroupSize = MTLSizeMake(MIN((NSUInteger)outputCount, threadGroupWidth), 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results.data(), [outputBuffer contents], outputCount * elemSize); - } -#endif - - return results; -} - -std::vector MetalZKContext::poseidon2BatchCommitment( - const std::vector& values, - const std::vector& blindings, - const std::vector& salts -) { - std::vector results(values.size()); - -#ifdef __APPLE__ - if (!initialized_ || !poseidon2CommitmentPipeline_) { - return results; - } - - @autoreleasepool { - uint32_t count = (uint32_t)values.size(); - size_t elemSize = sizeof(Fr256); - - id valueBuffer = [device_ newBufferWithBytes:values.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id blindBuffer = [device_ newBufferWithBytes:blindings.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id saltBuffer = [device_ newBufferWithBytes:salts.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id outputBuffer = [device_ newBufferWithLength:count * elemSize - options:MTLResourceStorageModeShared]; - - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder 
setComputePipelineState:poseidon2CommitmentPipeline_]; - [encoder setBuffer:valueBuffer offset:0 atIndex:0]; - [encoder setBuffer:blindBuffer offset:0 atIndex:1]; - [encoder setBuffer:saltBuffer offset:0 atIndex:2]; - [encoder setBuffer:outputBuffer offset:0 atIndex:3]; - - NSUInteger threadGroupWidth = [poseidon2CommitmentPipeline_ maxTotalThreadsPerThreadgroup]; - MTLSize gridSize = MTLSizeMake(count, 1, 1); - MTLSize threadGroupSize = MTLSizeMake(MIN((NSUInteger)count, threadGroupWidth), 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results.data(), [outputBuffer contents], count * elemSize); - } -#endif - - return results; -} - -std::vector MetalZKContext::poseidon2BatchNullifier( - const std::vector& keys, - const std::vector& commitments, - const std::vector& indices -) { - std::vector results(keys.size()); - -#ifdef __APPLE__ - if (!initialized_ || !poseidon2NullifierPipeline_) { - return results; - } - - @autoreleasepool { - uint32_t count = (uint32_t)keys.size(); - size_t elemSize = sizeof(Fr256); - - id keyBuffer = [device_ newBufferWithBytes:keys.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id commitBuffer = [device_ newBufferWithBytes:commitments.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id indexBuffer = [device_ newBufferWithBytes:indices.data() - length:count * elemSize - options:MTLResourceStorageModeShared]; - - id outputBuffer = [device_ newBufferWithLength:count * elemSize - options:MTLResourceStorageModeShared]; - - id commandBuffer = [commandQueue_ commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:poseidon2NullifierPipeline_]; - [encoder setBuffer:keyBuffer offset:0 atIndex:0]; - [encoder setBuffer:commitBuffer offset:0 atIndex:1]; - [encoder setBuffer:indexBuffer offset:0 atIndex:2]; - 
[encoder setBuffer:outputBuffer offset:0 atIndex:3]; - - NSUInteger threadGroupWidth = [poseidon2NullifierPipeline_ maxTotalThreadsPerThreadgroup]; - MTLSize gridSize = MTLSizeMake(count, 1, 1); - MTLSize threadGroupSize = MTLSizeMake(MIN((NSUInteger)count, threadGroupWidth), 1, 1); - [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; - - [encoder endEncoding]; - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results.data(), [outputBuffer contents], count * elemSize); - } -#endif - - return results; -} - -// ============================================================================= -// Poseidon2 C API (wraps C++ methods) -// ============================================================================= - -int metal_zk_poseidon2_hash_pair( - void* /* ctx_ptr */, - void* output, - const void* left, - const void* right, - uint32_t count -) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - std::vector leftVec(count); - std::vector rightVec(count); - memcpy(leftVec.data(), left, count * sizeof(Fr256)); - memcpy(rightVec.data(), right, count * sizeof(Fr256)); - - auto results = ctx.poseidon2BatchHashPair(leftVec, rightVec); - memcpy(output, results.data(), count * sizeof(Fr256)); - return 0; -} - -int metal_zk_poseidon2_merkle_layer( - void* /* ctx_ptr */, - void* output, - const void* current_layer, - uint32_t current_size -) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - if (current_size < 2 || (current_size & 1) != 0) return -5; - - std::vector currentVec(current_size); - memcpy(currentVec.data(), current_layer, current_size * sizeof(Fr256)); - - auto results = ctx.poseidon2MerkleLayer(currentVec); - memcpy(output, results.data(), (current_size / 2) * sizeof(Fr256)); - return 0; -} - -int metal_zk_batch_commitment( - void* /* ctx_ptr */, - void* output, - const void* values, - const void* blindings, - const void* salts, - uint32_t count -) { - auto& ctx = 
getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - std::vector valuesVec(count); - std::vector blindingsVec(count); - std::vector saltsVec(count); - memcpy(valuesVec.data(), values, count * sizeof(Fr256)); - memcpy(blindingsVec.data(), blindings, count * sizeof(Fr256)); - memcpy(saltsVec.data(), salts, count * sizeof(Fr256)); - - auto results = ctx.poseidon2BatchCommitment(valuesVec, blindingsVec, saltsVec); - memcpy(output, results.data(), count * sizeof(Fr256)); - return 0; -} - -int metal_zk_batch_nullifier( - void* /* ctx_ptr */, - void* output, - const void* keys, - const void* commitments, - const void* indices, - uint32_t count -) { - auto& ctx = getMetalZKContext(); - if (!ctx.isAvailable()) return -1; - - std::vector keysVec(count); - std::vector commitmentsVec(count); - std::vector indicesVec(count); - memcpy(keysVec.data(), keys, count * sizeof(Fr256)); - memcpy(commitmentsVec.data(), commitments, count * sizeof(Fr256)); - memcpy(indicesVec.data(), indices, count * sizeof(Fr256)); - - auto results = ctx.poseidon2BatchNullifier(keysVec, commitmentsVec, indicesVec); - memcpy(output, results.data(), count * sizeof(Fr256)); - return 0; -} - -uint32_t metal_zk_get_threshold(int op_type) { - switch (op_type) { - case 1: return 64; // POSEIDON2_HASH - case 2: return 128; // MERKLE_LAYER - case 3: return 256; // MSM - case 4: return 128; // COMMITMENT - case 5: return 128; // NULLIFIER - default: return 64; - } -} - -} // extern "C" - -} // namespace metal -} // namespace crypto -} // namespace lux diff --git a/bn254/gpu/wgsl/bn254.wgsl b/bn254/gpu/wgsl/bn254.wgsl deleted file mode 100644 index 51f7cee..0000000 --- a/bn254/gpu/wgsl/bn254.wgsl +++ /dev/null @@ -1,1684 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party WebGPU/WGSL kernels for bn254 (alt_bn128). -// -// Algorithm transliteration of bn254/cpp/{bn254_fp,bn254_g1,bn254_hash_to_curve} -// .hpp. 
WGSL has no native u64, so each CPU u64 limb is encoded as a pair of -// u32 (lo, hi) in LE order. Wire layout matches CPU bytes-for-bytes: -// CPU u64 limb i == WGSL { lo = u32[2i], hi = u32[2i+1] } -// -// 4 x u64 CPU == 8 x u32 WGSL. -// -// Algorithms: -// * Fp: CIOS Montgomery multiplication (HAC §14.36) -// * Fp2: Karatsuba over Fp[u]/(u^2+1) -// * Fp6: Algorithms 13/16/17 of eprint 2010/354 -// * Fp12: generic Karatsuba mul + Granger-Scott cyclotomic squaring -// * G1: Bernstein-Lange efl/jacobian-0/{dbl-2009-l, add-2007-bl} -// * G1 mul: constant-time Montgomery ladder over 256 bits (no early exit) -// * G2: homogeneous projective doubleStep / mixedAddStep / lineCompute -// * Optimal-ate Miller loop: 6x+2 NAF, sparse 034-mul folding, two-line square -// * Final exp: easy part (p^6-1)(p^2+1), hard part Fuentes-Castaneda -// * SVDW: RFC 9380 §6.6.1 map_to_curve_svdw -// -// I/O packing: -// Affine point: 18 x u32 = 9 x u64 == (x[8] || y[8] || inf[2]) -// Field element: 8 x u32 = 4 x u64 -// Scalar: 8 x u32 = 4 x u64 -// -// Each kernel uses a contiguous storage buffer, indexed by global thread id. 
- -// ============================================================================= -// Constants (8 x u32, LE pairs of (lo, hi)) -// ============================================================================= -// -// p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// CPU 4 x u64: { 0x3C208C16D87CFD47, 0x97816A916871CA8D, 0xB85045B68181585D, 0x30644E72E131A029 } -// WGSL 8 x u32 (lo,hi): -// { 0xD87CFD47, 0x3C208C16, 0x6871CA8D, 0x97816A91, -// 0x8181585D, 0xB85045B6, 0xE131A029, 0x30644E72 } - -const BN_P: array = array( - 0xD87CFD47u, 0x3C208C16u, 0x6871CA8Du, 0x97816A91u, - 0x8181585Du, 0xB85045B6u, 0xE131A029u, 0x30644E72u -); - -const BN_R: array = array( - 0xC58F0D9Du, 0xD35D438Du, 0xF5C70B3Du, 0x0A78EB28u, - 0x7879462Cu, 0x666EA36Fu, 0x9A07DF2Fu, 0x0E0A77C1u -); - -const BN_R2: array = array( - 0x538AFA89u, 0xF32CFC5Bu, 0xD44501FBu, 0xB5E71911u, - 0x0A417FF6u, 0x47AB1EFFu, 0xCAB8351Fu, 0x06D89F71u -); - -// p_inv low 64 bits = 0x87D20782E4866389; (lo, hi) = (0xE4866389, 0x87D20782) -const BN_PINV_LO: u32 = 0xE4866389u; -const BN_PINV_HI: u32 = 0x87D20782u; - -// p - 2 (for Fermat inversion) -const BN_PM2: array = array( - 0xD87CFD45u, 0x3C208C16u, 0x6871CA8Du, 0x97816A91u, - 0x8181585Du, 0xB85045B6u, 0xE131A029u, 0x30644E72u -); - -// (p+1)/4 (for square root, since p ≡ 3 mod 4) -const BN_PP1_4: array = array( - 0xB61F3F52u, 0x4F082305u, 0x5A1C72A3u, 0x65E05AA4u, - 0xA0605617u, 0x6E14116Du, 0xB84C680Au, 0x0C19139Cu -); - -// SVDW constants (plain). 
-const BN_SVDW_Z: array = array( - 1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u -); -const BN_SVDW_C1: array = array( - 4u, 0u, 0u, 0u, 0u, 0u, 0u, 0u -); -// 0x183227397098D014DC2822DB40C0AC2E CBC0B548B438E5469E10460B6C3E7EA3 -const BN_SVDW_C2: array = array( - 0x6C3E7EA3u, 0x9E10460Bu, 0xB438E546u, 0xCBC0B548u, - 0x40C0AC2Eu, 0xDC2822DBu, 0x7098D014u, 0x18322739u -); -// 0x00000000 00000001 6789AF3A 83522EB3 53C98FC6 B36D713D 5D8D1CC5 DFFFFFFA -const BN_SVDW_C3: array = array( - 0xDFFFFFFAu, 0x5D8D1CC5u, 0xB36D713Du, 0x53C98FC6u, - 0x83522EB3u, 0x6789AF3Au, 0x00000001u, 0x00000000u -); -// 0x10216F7BA065E00DE81AC1E7808072C9DD2B2385CD7B438469602EB24829A9BD -const BN_SVDW_C4: array = array( - 0x4829A9BDu, 0x69602EB2u, 0xCD7B4384u, 0xDD2B2385u, - 0x808072C9u, 0xE81AC1E7u, 0xA065E00Du, 0x10216F7Bu -); - -// ============================================================================= -// 256-bit big-int primitives (8 x u32, LE) -// ============================================================================= - -fn u256_is_zero(a: array) -> bool { - var acc: u32 = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | a[i]; } - return acc == 0u; -} - -fn u256_eq(a: array, b: array) -> bool { - var acc: u32 = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | (a[i] ^ b[i]); } - return acc == 0u; -} - -fn u256_cmp(a: array, b: array) -> i32 { - for (var i = 7i; i >= 0; i = i - 1) { - let ui = u32(i); - if (a[ui] > b[ui]) { return 1; } - if (a[ui] < b[ui]) { return -1; } - } - return 0; -} - -fn u256_add(a: array, b: array, r: ptr>) -> u32 { - var c: u32 = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let s1 = a[i] + c; - var c1: u32 = 0u; if (s1 < a[i]) { c1 = 1u; } - let s2 = s1 + b[i]; - var c2: u32 = 0u; if (s2 < s1) { c2 = 1u; } - (*r)[i] = s2; - c = c1 + c2; - } - return c; -} - -fn u256_sub(a: array, b: array, r: ptr>) -> u32 { - var bw: u32 = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let d1 = a[i] - bw; - var b1: u32 = 0u; if (d1 > a[i]) { b1 = 1u; } - let d2 = 
d1 - b[i]; - var b2: u32 = 0u; if (d2 > d1) { b2 = 1u; } - (*r)[i] = d2; - bw = b1 + b2; - } - return bw; -} - -// 32 x 32 -> 64 split into (lo, hi). -fn mul32_64(a: u32, b: u32) -> vec2 { - let al = a & 0xFFFFu; - let ah = a >> 16u; - let bl = b & 0xFFFFu; - let bh = b >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = (ll >> 16u) + (lh & 0xFFFFu) + (hl & 0xFFFFu); - let lo = (mid << 16u) | (ll & 0xFFFFu); - let hi = hh + (lh >> 16u) + (hl >> 16u) + (mid >> 16u); - return vec2(lo, hi); -} - -// ============================================================================= -// Modular add/sub mod p -// ============================================================================= - -fn mod_add_p(a: array, b: array) -> array { - var r: array; - let c = u256_add(a, b, &r); - if (c != 0u || u256_cmp(r, BN_P) >= 0) { - var t: array; - let _ = u256_sub(r, BN_P, &t); - r = t; - } - return r; -} - -fn mod_sub_p(a: array, b: array) -> array { - var r: array; - let bw = u256_sub(a, b, &r); - if (bw != 0u) { - var t: array; - let _ = u256_add(r, BN_P, &t); - r = t; - } - return r; -} - -// ============================================================================= -// CIOS Montgomery multiplication -// ============================================================================= -// -// Equivalent to CPU's mont_mul, working in 32-bit chunks. We follow the same -// outer structure: process 8 limbs of `b` (one 32-bit lane at a time); after -// each addition compute u = t[0] * p_inv (low 32 bits), then add u*p, then -// shift. Final reduction subtracts p once if needed. 
- -fn mont_mul_p(a: array, b: array) -> array { - var t: array; - for (var i = 0u; i < 10u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 8u; i = i + 1u) { - // t += a * b[i] - var carry: u32 = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let prod = mul32_64(a[j], b[i]); - let lo = prod.x; - let hi = prod.y; - - let s1 = t[j] + lo; - var c1: u32 = 0u; if (s1 < t[j]) { c1 = 1u; } - let s2 = s1 + carry; - var c2: u32 = 0u; if (s2 < s1) { c2 = 1u; } - t[j] = s2; - carry = hi + c1 + c2; - } - let s8 = t[8] + carry; - var c8: u32 = 0u; if (s8 < t[8]) { c8 = 1u; } - t[8] = s8; - t[9] = t[9] + c8; - - // u = t[0] * p_inv (mod 2^32). 64-bit u_inv * t[0] needs only the low 32 bits - // of (t[0] * p_inv_full). For 32-bit lanes that = (t[0] * p_inv_lo) & 0xFFFFFFFF. - let u_low = t[0] * BN_PINV_LO; - - // t += u_low * p - carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let prod = mul32_64(u_low, BN_P[j]); - let lo = prod.x; - let hi = prod.y; - - let s1 = t[j] + lo; - var c1: u32 = 0u; if (s1 < t[j]) { c1 = 1u; } - let s2 = s1 + carry; - var c2: u32 = 0u; if (s2 < s1) { c2 = 1u; } - t[j] = s2; - carry = hi + c1 + c2; - } - let s8b = t[8] + carry; - var c8b: u32 = 0u; if (s8b < t[8]) { c8b = 1u; } - t[8] = s8b; - t[9] = t[9] + c8b; - - // shift right one 32-bit limb - for (var j = 0u; j < 9u; j = j + 1u) { t[j] = t[j+1u]; } - t[9] = 0u; - } - - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { r[i] = t[i]; } - if (t[8] != 0u || u256_cmp(r, BN_P) >= 0) { - var s: array; - let _ = u256_sub(r, BN_P, &s); - r = s; - } - return r; -} - -fn to_mont_p(a: array) -> array { - return mont_mul_p(a, BN_R2); -} - -fn from_mont_p(a: array) -> array { - let one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - return mont_mul_p(a, one); -} - -// ============================================================================= -// Fp ops (alias) -// ============================================================================= - -fn fp_add(a: array, b: array) -> array { return 
mod_add_p(a, b); } -fn fp_sub(a: array, b: array) -> array { return mod_sub_p(a, b); } -fn fp_mul(a: array, b: array) -> array { return mont_mul_p(a, b); } -fn fp_sqr(a: array) -> array { return mont_mul_p(a, a); } - -fn fp_neg(a: array) -> array { - if (u256_is_zero(a)) { return a; } - var r: array; - let _ = u256_sub(BN_P, a, &r); - return r; -} - -fn fp_pow(a: array, e: array) -> array { - var result = to_mont_p(array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u)); - var base = a; - for (var limb = 0u; limb < 8u; limb = limb + 1u) { - let w = e[limb]; - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if ((w >> bit) & 1u) == 1u { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - return result; -} - -fn fp_inv(a: array) -> array { return fp_pow(a, BN_PM2); } - -struct SqrtRes { ok: bool, v: array } - -fn fp_sqrt(a: array) -> SqrtRes { - let cand = fp_pow(a, BN_PP1_4); - var out: SqrtRes; - if (u256_eq(fp_sqr(cand), a)) { - out.ok = true; - out.v = cand; - } else { - out.ok = false; - out.v = a; - } - return out; -} - -fn fp_three() -> array { - return to_mont_p(array(3u, 0u, 0u, 0u, 0u, 0u, 0u, 0u)); -} - -// ============================================================================= -// G1 Jacobian -// ============================================================================= - -struct G1A { x: array, y: array, inf: u32 } -struct G1J { X: array, Y: array, Z: array, inf: u32 } - -fn g1_jac_zero() -> G1J { - var r: G1J; - for (var i = 0u; i < 8u; i = i + 1u) { r.X[i] = 0u; r.Y[i] = 0u; r.Z[i] = 0u; } - r.inf = 1u; - return r; -} - -fn g1_to_jac(p: G1A) -> G1J { - if (p.inf != 0u) { return g1_jac_zero(); } - var r: G1J; - r.X = p.x; r.Y = p.y; r.Z = BN_R; r.inf = 0u; - return r; -} - -fn g1_to_affine(p: G1J) -> G1A { - var a: G1A; - if (p.inf != 0u || u256_is_zero(p.Z)) { - for (var i = 0u; i < 8u; i = i + 1u) { a.x[i] = 0u; a.y[i] = 0u; } - a.inf = 1u; - return a; - } - let z_inv = fp_inv(p.Z); - let z_inv2 = fp_sqr(z_inv); - let z_inv3 = 
fp_mul(z_inv2, z_inv); - a.x = fp_mul(p.X, z_inv2); - a.y = fp_mul(p.Y, z_inv3); - a.inf = 0u; - return a; -} - -fn g1_double(p: G1J) -> G1J { - if (p.inf != 0u) { return p; } - if (u256_is_zero(p.Y)) { return g1_jac_zero(); } - - let A = fp_sqr(p.X); - let B = fp_sqr(p.Y); - let C = fp_sqr(B); - - let X_plus_B = fp_add(p.X, B); - var D = fp_sub(fp_sqr(X_plus_B), A); - D = fp_sub(D, C); - D = fp_add(D, D); - - var E = fp_add(A, A); - E = fp_add(E, A); - let F = fp_sqr(E); - - let two_D = fp_add(D, D); - let X3 = fp_sub(F, two_D); - - let D_minus_X3 = fp_sub(D, X3); - var eight_C = fp_add(C, C); - eight_C = fp_add(eight_C, eight_C); - eight_C = fp_add(eight_C, eight_C); - let Y3 = fp_sub(fp_mul(E, D_minus_X3), eight_C); - - var Z3 = fp_mul(p.Y, p.Z); - Z3 = fp_add(Z3, Z3); - - var r: G1J; r.X = X3; r.Y = Y3; r.Z = Z3; r.inf = 0u; - return r; -} - -fn g1_add(a: G1J, b: G1J) -> G1J { - if (a.inf != 0u) { return b; } - if (b.inf != 0u) { return a; } - - let Z1Z1 = fp_sqr(a.Z); - let Z2Z2 = fp_sqr(b.Z); - let U1 = fp_mul(a.X, Z2Z2); - let U2 = fp_mul(b.X, Z1Z1); - let S1 = fp_mul(fp_mul(a.Y, b.Z), Z2Z2); - let S2 = fp_mul(fp_mul(b.Y, a.Z), Z1Z1); - - let H = fp_sub(U2, U1); - if (u256_is_zero(H)) { - if (u256_eq(S1, S2)) { return g1_double(a); } - return g1_jac_zero(); - } - - let two_H = fp_add(H, H); - let I = fp_sqr(two_H); - let J = fp_mul(H, I); - - var r_ = fp_sub(S2, S1); - r_ = fp_add(r_, r_); - - let V = fp_mul(U1, I); - - let X3 = fp_sub(fp_sub(fp_sqr(r_), J), fp_add(V, V)); - let Y3 = fp_sub(fp_mul(r_, fp_sub(V, X3)), fp_mul(fp_add(S1, S1), J)); - var Z3 = fp_sub(fp_sub(fp_sqr(fp_add(a.Z, b.Z)), Z1Z1), Z2Z2); - Z3 = fp_mul(Z3, H); - - var out: G1J; out.X = X3; out.Y = Y3; out.Z = Z3; out.inf = 0u; - return out; -} - -fn g1_cmov(dst: G1J, src: G1J, cond: u32) -> G1J { - var r = dst; - let mask: u32 = 0u - (cond & 1u); // 0 or 0xFFFFFFFF - for (var i = 0u; i < 8u; i = i + 1u) { - r.X[i] = dst.X[i] ^ (mask & (dst.X[i] ^ src.X[i])); - r.Y[i] = dst.Y[i] ^ (mask & 
(dst.Y[i] ^ src.Y[i])); - r.Z[i] = dst.Z[i] ^ (mask & (dst.Z[i] ^ src.Z[i])); - } - if ((cond & 1u) == 1u) { r.inf = src.inf; } else { r.inf = dst.inf; } - return r; -} - -// Constant-time Montgomery ladder over 256 bits. -fn g1_scalar_mul(p: G1A, k: array) -> G1J { - if (p.inf != 0u) { return g1_jac_zero(); } - var R0 = g1_jac_zero(); - var R1 = g1_to_jac(p); - - for (var i = 255i; i >= 0; i = i - 1) { - let ui = u32(i); - let bit = (k[ui >> 5u] >> (ui & 31u)) & 1u; - - let sum = g1_add(R0, R1); - let dbl0 = g1_double(R0); - let dbl1 = g1_double(R1); - - let next_R0 = g1_cmov(dbl0, sum, bit); - let next_R1 = g1_cmov(sum, dbl1, bit); - - R0 = next_R0; - R1 = next_R1; - } - return R0; -} - -// ============================================================================= -// SVDW map_to_curve -// ============================================================================= - -fn fp_sgn0(a: array) -> u32 { - let p = from_mont_p(a); - return p[0] & 1u; -} - -fn fp_g_x(x: array) -> array { - let x2 = fp_sqr(x); - let x3 = fp_mul(x2, x); - return fp_add(x3, fp_three()); -} - -fn svdw_map(u_mont: array) -> G1A { - let ONE = BN_R; - let Z = to_mont_p(BN_SVDW_Z); - let c1 = to_mont_p(BN_SVDW_C1); - let c2 = to_mont_p(BN_SVDW_C2); - let c3 = to_mont_p(BN_SVDW_C3); - let c4 = to_mont_p(BN_SVDW_C4); - - var tv1 = fp_sqr(u_mont); - tv1 = fp_mul(tv1, c1); - let tv2 = fp_add(ONE, tv1); - tv1 = fp_sub(ONE, tv1); - var tv3 = fp_mul(tv1, tv2); - tv3 = fp_inv(tv3); - var tv4 = fp_mul(u_mont, tv1); - tv4 = fp_mul(tv4, tv3); - tv4 = fp_mul(tv4, c3); - let x1 = fp_sub(c2, tv4); - - let gx1 = fp_g_x(x1); - let s1 = fp_sqrt(gx1); - let x2 = fp_add(c2, tv4); - let gx2 = fp_g_x(x2); - let s2 = fp_sqrt(gx2); - - var x3 = fp_sqr(tv2); - x3 = fp_mul(x3, tv3); - x3 = fp_sqr(x3); - x3 = fp_mul(x3, c4); - x3 = fp_add(x3, Z); - - var x: array; - if (s1.ok) { x = x1; } else { x = x3; } - if (s2.ok && !s1.ok) { x = x2; } - - let gx = fp_g_x(x); - let sy = fp_sqrt(gx); - var y = sy.v; - if 
(fp_sgn0(u_mont) != fp_sgn0(y)) { y = fp_neg(y); } - - var r: G1A; r.x = x; r.y = y; r.inf = 0u; - return r; -} - -// ============================================================================= -// Kernel I/O packing helpers -// ============================================================================= -// -// Affine point lives in 18 x u32 (x[8] || y[8] || inf[2]) -- inf padded to two -// u32 to match the CUDA driver's 9 x u64 layout (each u64 = 2 x u32 LE). -// Field element / scalar = 8 x u32. Storage buffers are flat array. - -@group(0) @binding(0) var in_a: array; -@group(0) @binding(1) var in_b: array; -@group(0) @binding(2) var out_buf: array; - -fn load_field(off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { r[i] = in_a[off + i]; } - return r; -} - -fn load_field_b(off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { r[i] = in_b[off + i]; } - return r; -} - -fn store_field(off: u32, v: array) { - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + i] = v[i]; } -} - -fn load_aff_a(off: u32) -> G1A { - var p: G1A; - for (var i = 0u; i < 8u; i = i + 1u) { p.x[i] = in_a[off + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.y[i] = in_a[off + 8u + i]; } - p.inf = in_a[off + 16u]; - return p; -} - -fn load_aff_b(off: u32) -> G1A { - var p: G1A; - for (var i = 0u; i < 8u; i = i + 1u) { p.x[i] = in_b[off + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.y[i] = in_b[off + 8u + i]; } - p.inf = in_b[off + 16u]; - return p; -} - -fn store_aff(off: u32, p: G1A) { - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + i] = p.x[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 8u + i] = p.y[i]; } - out_buf[off + 16u] = p.inf; - out_buf[off + 17u] = 0u; -} - -// ============================================================================= -// Kernels -// ============================================================================= - -@compute @workgroup_size(64) -fn 
k_g1_add(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let stride = 18u; // 8 + 8 + 2 (inf padded to u64) - let off = i * stride; - - let A = load_aff_a(off); - let B = load_aff_b(off); - - let Ja = g1_to_jac(A); - let Jb = g1_to_jac(B); - let S = g1_add(Ja, Jb); - let R = g1_to_affine(S); - - store_aff(off, R); -} - -@compute @workgroup_size(64) -fn k_g1_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let p_off = i * 18u; - let s_off = i * 8u; - - let P = load_aff_a(p_off); - var k: array; - for (var j = 0u; j < 8u; j = j + 1u) { k[j] = in_b[s_off + j]; } - - let S = g1_scalar_mul(P, k); - let R = g1_to_affine(S); - - store_aff(p_off, R); -} - -@compute @workgroup_size(64) -fn k_svdw(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let in_off = i * 8u; - let out_off = i * 18u; - - let u = load_field(in_off); - let R = svdw_map(u); - store_aff(out_off, R); -} - -@compute @workgroup_size(64) -fn k_fp_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let off = i * 8u; - let A = load_field(off); - let B = load_field_b(off); - let R = fp_mul(A, B); - store_field(off, R); -} - -// ============================================================================= -// Fp2 = Fp[u]/(u^2 + 1) -- Karatsuba mul (matches CPU bn254_fp2.hpp:fp2_mul). -// Layout per Fp2: a0 (8 u32) || a1 (8 u32) -- 16 u32 per element. 
-// ============================================================================= - -struct F2 { a0: array, a1: array }; - -fn f2_load_a(off: u32) -> F2 { - var r: F2; - for (var i = 0u; i < 8u; i = i + 1u) { r.a0[i] = in_a[off + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.a1[i] = in_a[off + 8u + i]; } - return r; -} -fn f2_load_b(off: u32) -> F2 { - var r: F2; - for (var i = 0u; i < 8u; i = i + 1u) { r.a0[i] = in_b[off + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.a1[i] = in_b[off + 8u + i]; } - return r; -} -fn f2_store(off: u32, v: F2) { - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + i] = v.a0[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 8u + i] = v.a1[i]; } -} - -fn fp2_mul(x: F2, y: F2) -> F2 { - let a = fp_mul(fp_add(x.a0, x.a1), fp_add(y.a0, y.a1)); - let b = fp_mul(x.a0, y.a0); - let c = fp_mul(x.a1, y.a1); - var r: F2; - r.a1 = fp_sub(fp_sub(a, b), c); - r.a0 = fp_sub(b, c); - return r; -} - -@compute @workgroup_size(64) -fn k_fp2_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let off = i * 16u; - let A = f2_load_a(off); - let B = f2_load_b(off); - let R = fp2_mul(A, B); - f2_store(off, R); -} - -// ============================================================================= -// Fp6 / Fp12 mul -- Algorithm transliteration of CPU bn254_fp{6,12}.hpp. -// Fp12 layout: c0 (Fp6) || c1 (Fp6); Fp6 = b0 (Fp2) || b1 (Fp2) || b2 (Fp2). -// 12 x Fp2 per Fp12 -> 96 x u32. 
-// ============================================================================= - -struct F6 { b0: F2, b1: F2, b2: F2 }; -struct F12 { c0: F6, c1: F6 }; - -fn fp2_zero() -> F2 { - var r: F2; - for (var i = 0u; i < 8u; i = i + 1u) { r.a0[i] = 0u; r.a1[i] = 0u; } - return r; -} -fn fp2_add(x: F2, y: F2) -> F2 { - var r: F2; r.a0 = fp_add(x.a0, y.a0); r.a1 = fp_add(x.a1, y.a1); return r; -} -fn fp2_sub(x: F2, y: F2) -> F2 { - var r: F2; r.a0 = fp_sub(x.a0, y.a0); r.a1 = fp_sub(x.a1, y.a1); return r; -} -fn fp2_neg_local(x: F2) -> F2 { - var r: F2; r.a0 = fp_neg(x.a0); r.a1 = fp_neg(x.a1); return r; -} - -// (a0 + a1*u) * (9 + u): 9*a0 = 8*a0 + a0. -fn fp2_mul_by_nonres(x: F2) -> F2 { - var t0 = fp_add(x.a0, x.a0); - t0 = fp_add(t0, t0); - t0 = fp_add(t0, t0); - var t1 = fp_add(x.a1, x.a1); - t1 = fp_add(t1, t1); - t1 = fp_add(t1, t1); - var r: F2; - r.a0 = fp_sub(fp_add(t0, x.a0), x.a1); - r.a1 = fp_add(fp_add(t1, x.a1), x.a0); - return r; -} - -fn fp6_add(x: F6, y: F6) -> F6 { - var r: F6; r.b0 = fp2_add(x.b0, y.b0); r.b1 = fp2_add(x.b1, y.b1); r.b2 = fp2_add(x.b2, y.b2); return r; -} -fn fp6_sub(x: F6, y: F6) -> F6 { - var r: F6; r.b0 = fp2_sub(x.b0, y.b0); r.b1 = fp2_sub(x.b1, y.b1); r.b2 = fp2_sub(x.b2, y.b2); return r; -} -fn fp6_mul_by_nonres(x: F6) -> F6 { - var r: F6; r.b0 = fp2_mul_by_nonres(x.b2); r.b1 = x.b0; r.b2 = x.b1; return r; -} - -fn fp6_mul(x: F6, y: F6) -> F6 { - let t0 = fp2_mul(x.b0, y.b0); - let t1 = fp2_mul(x.b1, y.b1); - let t2 = fp2_mul(x.b2, y.b2); - - var c0 = fp2_add(x.b1, x.b2); - var tmp = fp2_add(y.b1, y.b2); - c0 = fp2_mul(c0, tmp); - c0 = fp2_sub(c0, t1); - c0 = fp2_sub(c0, t2); - c0 = fp2_mul_by_nonres(c0); - c0 = fp2_add(c0, t0); - - var c1 = fp2_add(x.b0, x.b1); - tmp = fp2_add(y.b0, y.b1); - c1 = fp2_mul(c1, tmp); - c1 = fp2_sub(c1, t0); - c1 = fp2_sub(c1, t1); - let t2_nr = fp2_mul_by_nonres(t2); - c1 = fp2_add(c1, t2_nr); - - var c2 = fp2_add(x.b0, x.b2); - tmp = fp2_add(y.b0, y.b2); - c2 = fp2_mul(c2, tmp); - c2 = fp2_sub(c2, 
t0); - c2 = fp2_sub(c2, t2); - c2 = fp2_add(c2, t1); - - var r: F6; r.b0 = c0; r.b1 = c1; r.b2 = c2; return r; -} - -fn fp12_mul(x: F12, y: F12) -> F12 { - var a = fp6_add(x.c0, x.c1); - var b = fp6_add(y.c0, y.c1); - a = fp6_mul(a, b); - b = fp6_mul(x.c0, y.c0); - let c = fp6_mul(x.c1, y.c1); - var r: F12; - r.c1 = fp6_sub(fp6_sub(a, b), c); - r.c0 = fp6_add(fp6_mul_by_nonres(c), b); - return r; -} - -fn f12_load_a(off: u32) -> F12 { - var r: F12; - for (var i = 0u; i < 8u; i = i + 1u) { r.c0.b0.a0[i] = in_a[off + 0u + i]; r.c0.b0.a1[i] = in_a[off + 8u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c0.b1.a0[i] = in_a[off + 16u + i]; r.c0.b1.a1[i] = in_a[off + 24u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c0.b2.a0[i] = in_a[off + 32u + i]; r.c0.b2.a1[i] = in_a[off + 40u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c1.b0.a0[i] = in_a[off + 48u + i]; r.c1.b0.a1[i] = in_a[off + 56u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c1.b1.a0[i] = in_a[off + 64u + i]; r.c1.b1.a1[i] = in_a[off + 72u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c1.b2.a0[i] = in_a[off + 80u + i]; r.c1.b2.a1[i] = in_a[off + 88u + i]; } - return r; -} -fn f12_load_b(off: u32) -> F12 { - var r: F12; - for (var i = 0u; i < 8u; i = i + 1u) { r.c0.b0.a0[i] = in_b[off + 0u + i]; r.c0.b0.a1[i] = in_b[off + 8u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c0.b1.a0[i] = in_b[off + 16u + i]; r.c0.b1.a1[i] = in_b[off + 24u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c0.b2.a0[i] = in_b[off + 32u + i]; r.c0.b2.a1[i] = in_b[off + 40u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c1.b0.a0[i] = in_b[off + 48u + i]; r.c1.b0.a1[i] = in_b[off + 56u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c1.b1.a0[i] = in_b[off + 64u + i]; r.c1.b1.a1[i] = in_b[off + 72u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { r.c1.b2.a0[i] = in_b[off + 80u + i]; r.c1.b2.a1[i] = in_b[off + 88u + i]; } - return r; -} -fn f12_store(off: u32, v: F12) { - for (var i = 0u; i < 8u; i = i + 
1u) { out_buf[off + 0u + i] = v.c0.b0.a0[i]; out_buf[off + 8u + i] = v.c0.b0.a1[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 16u + i] = v.c0.b1.a0[i]; out_buf[off + 24u + i] = v.c0.b1.a1[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 32u + i] = v.c0.b2.a0[i]; out_buf[off + 40u + i] = v.c0.b2.a1[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 48u + i] = v.c1.b0.a0[i]; out_buf[off + 56u + i] = v.c1.b0.a1[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 64u + i] = v.c1.b1.a0[i]; out_buf[off + 72u + i] = v.c1.b1.a1[i]; } - for (var i = 0u; i < 8u; i = i + 1u) { out_buf[off + 80u + i] = v.c1.b2.a0[i]; out_buf[off + 88u + i] = v.c1.b2.a1[i]; } -} - -@compute @workgroup_size(32) -fn k_fp12_mul(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let off = i * 96u; - let A = f12_load_a(off); - let B = f12_load_b(off); - let R = fp12_mul(A, B); - f12_store(off, R); -} - -// ============================================================================= -// Frobenius constants -- emitted from CPU body by bn254_gen_pairing_constants. -// WGSL has no native u64; each u64 limb is split into a (lo, hi) pair of u32. -// Single producer is bn254/cpp/bn254_pairing.cpp; CPU/GPU drift fails the -// determinism test by construction. -// ============================================================================= - -// NR1Power_k for k=1..5 -- 5 Fp2 elements (a0, a1). 
-const K_NR1P1_A0 : array = array( - 0x33144907u, 0xaf9ba696u, 0x87afb78au, 0xca6b1d73u, - 0xf08a2087u, 0x11bded5eu, 0x1a1f3a7cu, 0x02f34d75u -); -const K_NR1P1_A1 : array = array( - 0x4c492d72u, 0xa222ae23u, 0x565de15bu, 0xd00f02a4u, - 0x53dfc926u, 0xdc2ff3a2u, 0xb3899551u, 0x10a75716u -); -const K_NR1P2_A0 : array = array( - 0x4563ab30u, 0xb5773b10u, 0xa9aa6454u, 0x347f91c8u, - 0x242e0991u, 0x7a007127u, 0x118214ecu, 0x1956bcd8u -); -const K_NR1P2_A1 : array = array( - 0xa0aa4757u, 0x6e849f1eu, 0x89f89141u, 0xaa1c7b6du, - 0xfae0ca3au, 0xb6e713cdu, 0x4e82ebc3u, 0x26694fbbu -); -const K_NR1P3_A0 : array = array( - 0x2936b629u, 0xe4bbdd0cu, 0xe133bacbu, 0xbb30f162u, - 0xf9645366u, 0x31a9d1b6u, 0xa500f8ddu, 0x253570beu -); -const K_NR1P3_A1 : array = array( - 0x5ffe77c7u, 0xa1d77ce4u, 0x7826d1dbu, 0x07affd11u, - 0xbb7edc6bu, 0x6d16bd27u, 0x85defeccu, 0x2c872002u -); -const K_NR1P4_A0 : array = array( - 0x843abe92u, 0x7361d77fu, 0x273411fbu, 0xa5bb2bd3u, - 0x4b3e2399u, 0x9c941f31u, 0xbb9fd3ecu, 0x15df9cddu -); -const K_NR1P4_A1 : array = array( - 0x4bd8c949u, 0x5dddfd15u, 0xa4445b60u, 0x62cb29a5u, - 0x0c7dd2b9u, 0x37bc870au, 0x3171f0fdu, 0x24830a9du -); -const K_NR1P5_A0 : array = array( - 0x41690fe7u, 0xc970692fu, 0x27694b0bu, 0xe2403421u, - 0x83c459e8u, 0x32bee66bu, 0x0ab08841u, 0x12aabcedu -); -const K_NR1P5_A1 : array = array( - 0x40aebfa9u, 0x0d485d23u, 0xab2fcc57u, 0x05193418u, - 0x8a4910f5u, 0xd3b0a40bu, 0x35d2925au, 0x2f21ebb5u -); - -// NR2Power_k -- Fp scalars (single U256 each). 
-const K_NR2P1 : array = array( - 0x00fa1bf2u, 0xca8d8005u, 0x68b39769u, 0xf0c5d614u, - 0xad0d4418u, 0x0e201271u, 0xbad856e6u, 0x04290f65u -); -const K_NR2P2 : array = array( - 0x13e80b9cu, 0x3350c88eu, 0xdb5e56b9u, 0x7dce557cu, - 0xb615564au, 0x6001b4b8u, 0x020217e0u, 0x2682e617u -); -const K_NR2P3 : array = array( - 0x12edefaau, 0x68c34889u, 0x72aabf4fu, 0x8d087f68u, - 0x09081231u, 0x51e1a247u, 0x4729c0fau, 0x2259d6b1u -); -const K_NR2P4 : array = array( - 0xd782e155u, 0x71930c11u, 0xffbe3323u, 0xa6bb947cu, - 0xd4741444u, 0xaa303344u, 0x26594943u, 0x2c3b3f0du -); -const K_NR2P5 : array = array( - 0xc494f1abu, 0x08cfc388u, 0x8d1373d4u, 0x19b31514u, - 0xcb6c0213u, 0x584e90fdu, 0xdf2f8849u, 0x09e1685bu -); - -// NR3Power_k -- Fp2 elements. -const K_NR3P1_A0 : array = array( - 0x4e46d97du, 0x36531618u, 0xd4c96d9fu, 0x0af7129eu, - 0xca1009b5u, 0x659da72fu, 0x83a20d23u, 0x08116d89u -); -const K_NR3P1_A1 : array = array( - 0xc39c1939u, 0xb1df4af7u, 0x8a73bf7fu, 0x3d9f0287u, - 0x8caf0ae0u, 0x9b222092u, 0xeff054a6u, 0x26684515u -); -const K_NR3P2_A0 : array = array( - 0x16ad6badu, 0xc9af22f7u, 0x4aa662b2u, 0xb311782au, - 0xe248c7f4u, 0x19eeaf64u, 0xe3439f82u, 0x20273e77u -); -const K_NR3P2_A1 : array = array( - 0xf7ce93acu, 0xacc02860u, 0x7ba76b4cu, 0x3933d581u, - 0x446c8467u, 0x69e6188bu, 0x4417cc55u, 0x0a46036du -); -const K_NR3P3_A0 : array = array( - 0xaf46471eu, 0x5764af0au, 0x873e0fc1u, 0xdc50792eu, - 0x881d04f6u, 0x86a673ffu, 0x3c30a74cu, 0x0b2eddb4u -); -const K_NR3P3_A1 : array = array( - 0x787e8580u, 0x9a490f32u, 0xf04af8b1u, 0x8fd16d7fu, - 0xc6027bf2u, 0x4b39888eu, 0x5b52a15du, 0x03dd2e70u -); -const K_NR3P4_A0 : array = array( - 0x7b6762dfu, 0x448a93a5u, 0x28fdeadfu, 0xbfd62df5u, - 0x0e9bd47au, 0xd858f5d0u, 0x3476ec58u, 0x06b03d4du -); -const K_NR3P4_A1 : array = array( - 0xbcc936d1u, 0x2b19daf4u, 0x56f4299fu, 0xa1a54e7au, - 0x5adeaef1u, 0xb533eee0u, 0x84dda0b2u, 0x170c812bu -); -const K_NR3P5_A0 : array = array( - 0x75cf559fu, 0xe0bc4b22u, 0xc154e60fu, 
0xc238b945u, - 0x929a7d5eu, 0x803982a5u, 0xf7e4a37eu, 0x15ce052du -); -const K_NR3P5_A1 : array = array( - 0xbf3799a7u, 0x2d28efbdu, 0x1ad60773u, 0x9b097e3cu, - 0xaf4a535bu, 0x982d4113u, 0xe3056063u, 0x24e18991u -); - -// ============================================================================= -// Fp2 extras: zero, one, sqr, conjugate, neg, double, mul_by_fp, mul_by_nonres_inv, -// inv, halve. Order/algorithm mirrors bn254/cpp/bn254_fp2.hpp + bn254_pairing.cuh. -// ============================================================================= - -fn fp2_one() -> F2 { - var r: F2; - r.a0 = BN_R; - for (var i = 0u; i < 8u; i = i + 1u) { r.a1[i] = 0u; } - return r; -} - -fn fp2_is_zero(x: F2) -> bool { return u256_is_zero(x.a0) && u256_is_zero(x.a1); } - -fn fp2_double(x: F2) -> F2 { - var r: F2; r.a0 = fp_add(x.a0, x.a0); r.a1 = fp_add(x.a1, x.a1); return r; -} - -fn fp2_conjugate(x: F2) -> F2 { - var r: F2; r.a0 = x.a0; r.a1 = fp_neg(x.a1); return r; -} - -fn fp2_mul_by_fp(x: F2, y: array) -> F2 { - var r: F2; r.a0 = fp_mul(x.a0, y); r.a1 = fp_mul(x.a1, y); return r; -} - -fn fp2_sqr(x: F2) -> F2 { - let a = fp_mul(fp_add(x.a0, x.a1), fp_sub(x.a0, x.a1)); - let b = fp_mul(x.a0, x.a1); - var r: F2; r.a0 = a; r.a1 = fp_add(b, b); - return r; -} - -fn fp2_inv(x: F2) -> F2 { - let t0 = fp_sqr(x.a0); - let t1 = fp_sqr(x.a1); - let t = fp_add(t0, t1); - let ti = fp_inv(t); - var r: F2; - r.a0 = fp_mul(x.a0, ti); - r.a1 = fp_neg(fp_mul(x.a1, ti)); - return r; -} - -// Multiply by (9+u)^-1. Used inside mul_b_twist (line eval) -- amortised cost. -fn fp2_mul_by_nonres_inv(x: F2) -> F2 { - var nr: F2; - nr.a0 = to_mont_p(array(9u, 0u, 0u, 0u, 0u, 0u, 0u, 0u)); - nr.a1 = BN_R; // Montgomery 1 - let inv_nr = fp2_inv(nr); - return fp2_mul(x, inv_nr); -} - -// fp_halve: (a + p)/2 if a odd else a/2. Algorithm-equivalent to CPU -// bn254_pairing.cpp:fp_halve. 
-fn fp_halve(v: array) -> array { - var r = v; - if ((r[0] & 1u) == 1u) { - var t: array; - let c = u256_add(r, BN_P, &t); - r = t; - // Combined right-shift by 1 across 8 limbs, with c filling top. - for (var i = 0u; i < 7u; i = i + 1u) { - r[i] = (r[i] >> 1u) | ((r[i + 1u] & 1u) << 31u); - } - r[7] = (r[7] >> 1u) | (c << 31u); - } else { - for (var i = 0u; i < 7u; i = i + 1u) { - r[i] = (r[i] >> 1u) | ((r[i + 1u] & 1u) << 31u); - } - r[7] = r[7] >> 1u; - } - return r; -} - -fn fp2_halve(x: F2) -> F2 { - var r: F2; r.a0 = fp_halve(x.a0); r.a1 = fp_halve(x.a1); return r; -} - -// b-twist coefficient: 3 * (9+u)^-1. -fn mul_b_twist(x: F2) -> F2 { - let res = fp2_mul_by_nonres_inv(x); - return fp2_add(fp2_double(res), res); -} - -// Mul by NR1Power_k: an Fp2 lookup constant, materialise as F2. -fn mul_nr1(x: F2, nr_a0: array, nr_a1: array) -> F2 { - var nr: F2; nr.a0 = nr_a0; nr.a1 = nr_a1; - return fp2_mul(x, nr); -} - -// Mul by NR2Power_k: an Fp scalar applied to both Fp2 limbs. -fn mul_nr2(x: F2, s: array) -> F2 { - var r: F2; r.a0 = fp_mul(x.a0, s); r.a1 = fp_mul(x.a1, s); return r; -} - -// ============================================================================= -// Fp6 extras: zero, one, neg, double, sqr, inv, mul_by_01, mul_by_fp2. 
-// ============================================================================= - -fn fp6_zero() -> F6 { - var r: F6; r.b0 = fp2_zero(); r.b1 = fp2_zero(); r.b2 = fp2_zero(); return r; -} -fn fp6_one() -> F6 { - var r: F6; r.b0 = fp2_one(); r.b1 = fp2_zero(); r.b2 = fp2_zero(); return r; -} -fn fp6_neg(x: F6) -> F6 { - var r: F6; r.b0 = fp2_neg_local(x.b0); r.b1 = fp2_neg_local(x.b1); r.b2 = fp2_neg_local(x.b2); return r; -} -fn fp6_double(x: F6) -> F6 { - var r: F6; r.b0 = fp2_double(x.b0); r.b1 = fp2_double(x.b1); r.b2 = fp2_double(x.b2); return r; -} - -fn fp6_sqr(x: F6) -> F6 { - var c4 = fp2_mul(x.b0, x.b1); - c4 = fp2_double(c4); - let c5 = fp2_sqr(x.b2); - var c1 = fp2_mul_by_nonres(c5); - c1 = fp2_add(c1, c4); - let c2 = fp2_sub(c4, c5); - let c3 = fp2_sqr(x.b0); - var c4b = fp2_sub(x.b0, x.b1); - c4b = fp2_add(c4b, x.b2); - var c5b = fp2_mul(x.b1, x.b2); - c5b = fp2_double(c5b); - c4b = fp2_sqr(c4b); - var c0 = fp2_mul_by_nonres(c5b); - c0 = fp2_add(c0, c3); - - var z2 = fp2_add(c2, c4b); - z2 = fp2_add(z2, c5b); - z2 = fp2_sub(z2, c3); - - var r: F6; r.b0 = c0; r.b1 = c1; r.b2 = z2; return r; -} - -fn fp6_inv(x: F6) -> F6 { - let t0 = fp2_sqr(x.b0); - let t1 = fp2_sqr(x.b1); - let t2 = fp2_sqr(x.b2); - let t3 = fp2_mul(x.b0, x.b1); - let t4 = fp2_mul(x.b0, x.b2); - let t5 = fp2_mul(x.b1, x.b2); - - var c0 = fp2_mul_by_nonres(t5); - c0 = fp2_neg_local(c0); - c0 = fp2_add(c0, t0); - - var c1 = fp2_mul_by_nonres(t2); - c1 = fp2_sub(c1, t3); - - let c2 = fp2_sub(t1, t4); - - var t6 = fp2_mul(x.b0, c0); - let d1 = fp2_mul(x.b2, c1); - let d2 = fp2_mul(x.b1, c2); - var d = fp2_add(d1, d2); - d = fp2_mul_by_nonres(d); - t6 = fp2_add(t6, d); - let t6_inv = fp2_inv(t6); - - var r: F6; - r.b0 = fp2_mul(c0, t6_inv); - r.b1 = fp2_mul(c1, t6_inv); - r.b2 = fp2_mul(c2, t6_inv); - return r; -} - -fn fp6_mul_by_01(z: F6, c0: F2, c1: F2) -> F6 { - let a = fp2_mul(z.b0, c0); - let b = fp2_mul(z.b1, c1); - - var tmp = fp2_add(z.b1, z.b2); - var t0 = fp2_mul(c1, tmp); - t0 
= fp2_sub(t0, b); - t0 = fp2_mul_by_nonres(t0); - t0 = fp2_add(t0, a); - - tmp = fp2_add(z.b0, z.b2); - var t2 = fp2_mul(c0, tmp); - t2 = fp2_sub(t2, a); - t2 = fp2_add(t2, b); - - var t1 = fp2_add(c0, c1); - tmp = fp2_add(z.b0, z.b1); - t1 = fp2_mul(t1, tmp); - t1 = fp2_sub(t1, a); - t1 = fp2_sub(t1, b); - - var r: F6; r.b0 = t0; r.b1 = t1; r.b2 = t2; return r; -} - -fn fp6_mul_by_fp2(z: F6, y: F2) -> F6 { - var r: F6; - r.b0 = fp2_mul(z.b0, y); - r.b1 = fp2_mul(z.b1, y); - r.b2 = fp2_mul(z.b2, y); - return r; -} - -// ============================================================================= -// Fp12 extras: zero, one, sqr, conjugate, inv, mul_by_034, mul_034_by_034, -// mul_by_01234. -// ============================================================================= - -fn fp12_zero() -> F12 { - var r: F12; r.c0 = fp6_zero(); r.c1 = fp6_zero(); return r; -} -fn fp12_one() -> F12 { - var r: F12; r.c0 = fp6_one(); r.c1 = fp6_zero(); return r; -} - -fn fp12_is_one(z: F12) -> bool { - let one = fp12_one(); - return fp2_is_zero(z.c1.b0) && fp2_is_zero(z.c1.b1) && fp2_is_zero(z.c1.b2) - && fp2_is_zero(z.c0.b1) && fp2_is_zero(z.c0.b2) - && u256_is_zero(z.c0.b0.a1) - && u256_eq(z.c0.b0.a0, one.c0.b0.a0); -} - -fn fp12_conjugate(x: F12) -> F12 { - var r: F12; r.c0 = x.c0; r.c1 = fp6_neg(x.c1); return r; -} - -fn fp12_sqr(x: F12) -> F12 { - var c0 = fp6_sub(x.c0, x.c1); - var c3 = fp6_mul_by_nonres(x.c1); - c3 = fp6_neg(c3); - c3 = fp6_add(x.c0, c3); - let c2 = fp6_mul(x.c0, x.c1); - c0 = fp6_mul(c0, c3); - c0 = fp6_add(c0, c2); - let r1 = fp6_double(c2); - let c2b = fp6_mul_by_nonres(c2); - let r0 = fp6_add(c0, c2b); - var r: F12; r.c0 = r0; r.c1 = r1; return r; -} - -fn fp12_inv(x: F12) -> F12 { - let t0 = fp6_sqr(x.c0); - let t1 = fp6_sqr(x.c1); - let tmp = fp6_mul_by_nonres(t1); - let t0b = fp6_sub(t0, tmp); - let t0_inv = fp6_inv(t0b); - var r: F12; - r.c0 = fp6_mul(x.c0, t0_inv); - r.c1 = fp6_neg(fp6_mul(x.c1, t0_inv)); - return r; -} - -// Mul by sparse line 
element (c0, c3, c4) -- the standard 034-form line eval. -fn fp12_mul_by_034(z: F12, c0: F2, c3: F2, c4: F2) -> F12 { - let a = fp6_mul_by_fp2(z.c0, c0); - var b = z.c1; - b = fp6_mul_by_01(b, c3, c4); - - let d0 = fp2_add(c0, c3); - var d = fp6_add(z.c0, z.c1); - d = fp6_mul_by_01(d, d0, c4); - - var r1 = fp6_add(a, b); - r1 = fp6_neg(r1); - r1 = fp6_add(r1, d); - var r0 = fp6_mul_by_nonres(b); - r0 = fp6_add(r0, a); - var r: F12; r.c0 = r0; r.c1 = r1; return r; -} - -struct Fp12Sparse5 { v00: F2, v01: F2, v02: F2, v10: F2, v11: F2 }; - -// Multiply two sparse 034-form line elements. Result is sparse-5 -// (c0.b0, c0.b1, c0.b2, c1.b0, c1.b1) with c1.b2 = 0. -fn fp12_mul_034_by_034(d0: F2, d3: F2, d4: F2, c0: F2, c3: F2, c4: F2) -> Fp12Sparse5 { - let x0 = fp2_mul(c0, d0); - let x3 = fp2_mul(c3, d3); - let x4 = fp2_mul(c4, d4); - - var tmp = fp2_add(c0, c4); - var x04 = fp2_add(d0, d4); - x04 = fp2_mul(x04, tmp); - x04 = fp2_sub(x04, x0); - x04 = fp2_sub(x04, x4); - - tmp = fp2_add(c0, c3); - var x03 = fp2_add(d0, d3); - x03 = fp2_mul(x03, tmp); - x03 = fp2_sub(x03, x0); - x03 = fp2_sub(x03, x3); - - tmp = fp2_add(c3, c4); - var x34 = fp2_add(d3, d4); - x34 = fp2_mul(x34, tmp); - x34 = fp2_sub(x34, x3); - x34 = fp2_sub(x34, x4); - - var z00 = fp2_mul_by_nonres(x4); - z00 = fp2_add(z00, x0); - var r: Fp12Sparse5; - r.v00 = z00; r.v01 = x3; r.v02 = x34; r.v10 = x03; r.v11 = x04; - return r; -} - -// Generic Fp12 multiplied by sparse-5 (folded product of two 034 line evals). 
-fn fp12_mul_by_01234(z: F12, x: Fp12Sparse5) -> F12 { - var c0_part: F6; c0_part.b0 = x.v00; c0_part.b1 = x.v01; c0_part.b2 = x.v02; - var c1_part: F6; c1_part.b0 = x.v10; c1_part.b1 = x.v11; c1_part.b2 = fp2_zero(); - - var a = fp6_add(z.c0, z.c1); - var b = fp6_add(c0_part, c1_part); - a = fp6_mul(a, b); - - b = fp6_mul(z.c0, c0_part); - let c = fp6_mul_by_01(z.c1, x.v10, x.v11); - - var r1 = fp6_sub(a, b); - r1 = fp6_sub(r1, c); - - var r0 = fp6_mul_by_nonres(c); - r0 = fp6_add(r0, b); - - var r: F12; r.c0 = r0; r.c1 = r1; return r; -} - -// ============================================================================= -// Frobenius operators (Algorithms 28-30, eprint 2010/354). -// ============================================================================= - -fn frobenius(x: F12) -> F12 { - var t0 = fp2_conjugate(x.c0.b0); - var t1 = fp2_conjugate(x.c0.b1); - var t2 = fp2_conjugate(x.c0.b2); - var t3 = fp2_conjugate(x.c1.b0); - var t4 = fp2_conjugate(x.c1.b1); - var t5 = fp2_conjugate(x.c1.b2); - - t1 = mul_nr1(t1, K_NR1P2_A0, K_NR1P2_A1); - t2 = mul_nr1(t2, K_NR1P4_A0, K_NR1P4_A1); - t3 = mul_nr1(t3, K_NR1P1_A0, K_NR1P1_A1); - t4 = mul_nr1(t4, K_NR1P3_A0, K_NR1P3_A1); - t5 = mul_nr1(t5, K_NR1P5_A0, K_NR1P5_A1); - - var z: F12; - z.c0.b0 = t0; z.c0.b1 = t1; z.c0.b2 = t2; - z.c1.b0 = t3; z.c1.b1 = t4; z.c1.b2 = t5; - return z; -} - -fn frobenius_sq(x: F12) -> F12 { - var z: F12; - z.c0.b0 = x.c0.b0; - z.c0.b1 = mul_nr2(x.c0.b1, K_NR2P2); - z.c0.b2 = mul_nr2(x.c0.b2, K_NR2P4); - z.c1.b0 = mul_nr2(x.c1.b0, K_NR2P1); - z.c1.b1 = mul_nr2(x.c1.b1, K_NR2P3); - z.c1.b2 = mul_nr2(x.c1.b2, K_NR2P5); - return z; -} - -fn frobenius_cube(x: F12) -> F12 { - var t0 = fp2_conjugate(x.c0.b0); - var t1 = fp2_conjugate(x.c0.b1); - var t2 = fp2_conjugate(x.c0.b2); - var t3 = fp2_conjugate(x.c1.b0); - var t4 = fp2_conjugate(x.c1.b1); - var t5 = fp2_conjugate(x.c1.b2); - - t1 = mul_nr1(t1, K_NR3P2_A0, K_NR3P2_A1); - t2 = mul_nr1(t2, K_NR3P4_A0, K_NR3P4_A1); - t3 = mul_nr1(t3, 
K_NR3P1_A0, K_NR3P1_A1); - t4 = mul_nr1(t4, K_NR3P3_A0, K_NR3P3_A1); - t5 = mul_nr1(t5, K_NR3P5_A0, K_NR3P5_A1); - - var z: F12; - z.c0.b0 = t0; z.c0.b1 = t1; z.c0.b2 = t2; - z.c1.b0 = t3; z.c1.b1 = t4; z.c1.b2 = t5; - return z; -} - -// ============================================================================= -// Granger-Scott cyclotomic squaring (eprint 2009/565 §3.2). -// ============================================================================= - -fn cyclotomic_sqr(x: F12) -> F12 { - let t0 = fp2_sqr(x.c1.b1); - let t1 = fp2_sqr(x.c0.b0); - let t6 = fp2_sub(fp2_sub(fp2_sqr(fp2_add(x.c1.b1, x.c0.b0)), t0), t1); - let t2 = fp2_sqr(x.c0.b2); - let t3 = fp2_sqr(x.c1.b0); - let t7 = fp2_sub(fp2_sub(fp2_sqr(fp2_add(x.c0.b2, x.c1.b0)), t2), t3); - let t4 = fp2_sqr(x.c1.b2); - let t5 = fp2_sqr(x.c0.b1); - var t8 = fp2_sub(fp2_sub(fp2_sqr(fp2_add(x.c1.b2, x.c0.b1)), t4), t5); - t8 = fp2_mul_by_nonres(t8); - - let t0b = fp2_add(fp2_mul_by_nonres(t0), t1); - let t2b = fp2_add(fp2_mul_by_nonres(t2), t3); - let t4b = fp2_add(fp2_mul_by_nonres(t4), t5); - - var z: F12; - z.c0.b0 = fp2_add(fp2_double(fp2_sub(t0b, x.c0.b0)), t0b); - z.c0.b1 = fp2_add(fp2_double(fp2_sub(t2b, x.c0.b1)), t2b); - z.c0.b2 = fp2_add(fp2_double(fp2_sub(t4b, x.c0.b2)), t4b); - z.c1.b0 = fp2_add(fp2_double(fp2_add(t8, x.c1.b0)), t8); - z.c1.b1 = fp2_add(fp2_double(fp2_add(t6, x.c1.b1)), t6); - z.c1.b2 = fp2_add(fp2_double(fp2_add(t7, x.c1.b2)), t7); - return z; -} - -fn cyclotomic_n_sqr(z_in: F12, n: i32) -> F12 { - var z = z_in; - for (var i = 0i; i < n; i = i + 1) { z = cyclotomic_sqr(z); } - return z; -} - -// ============================================================================= -// expt: x^t with t = 4965661367192848881 (BN254 trace t = 6x+2). -// gnark-crypto addition chain, byte-equal to CPU oracle. 
-// ============================================================================= - -fn expt(x: F12) -> F12 { - let t3a = cyclotomic_sqr(x); - let t5a = cyclotomic_sqr(t3a); - let result_a = cyclotomic_sqr(t5a); - let t0a = cyclotomic_sqr(result_a); - let t2a = fp12_mul(x, t0a); - let t0b = fp12_mul(t3a, t2a); - let t1a = fp12_mul(x, t0b); - let t4a = fp12_mul(result_a, t2a); - var t6a = cyclotomic_sqr(t2a); - let t1b = fp12_mul(t0b, t1a); - let t0c = fp12_mul(t3a, t1b); - - t6a = cyclotomic_n_sqr(t6a, 6); - let t5b = fp12_mul(t5a, t6a); - let t5c = fp12_mul(t4a, t5b); - - let t5d = cyclotomic_n_sqr(t5c, 7); - let t4b = fp12_mul(t4a, t5d); - - let t4c = cyclotomic_n_sqr(t4b, 8); - let t4d = fp12_mul(t0c, t4c); - let t3b = fp12_mul(t3a, t4d); - - let t3c = cyclotomic_n_sqr(t3b, 6); - let t2b = fp12_mul(t2a, t3c); - - let t2c = cyclotomic_n_sqr(t2b, 8); - let t2d = fp12_mul(t0c, t2c); - - let t2e = cyclotomic_n_sqr(t2d, 6); - let t2f = fp12_mul(t0c, t2e); - - let t2g = cyclotomic_n_sqr(t2f, 10); - let t1c = fp12_mul(t1b, t2g); - - let t1d = cyclotomic_n_sqr(t1c, 6); - let t0d = fp12_mul(t0c, t1d); - return fp12_mul(result_a, t0d); -} - -// ============================================================================= -// G2 affine + projective ops + line evaluations. -// ============================================================================= - -struct G2A { x: F2, y: F2, inf: u32 }; -struct G2P { x: F2, y: F2, z: F2 }; - -fn g2_neg(a: G2A) -> G2A { - var r: G2A; r.x = a.x; r.y = fp2_neg_local(a.y); r.inf = a.inf; return r; -} -fn g2_to_proj(a: G2A) -> G2P { - var p: G2P; p.x = a.x; p.y = a.y; p.z = fp2_one(); return p; -} - -struct LineEval { r0: F2, r1: F2, r2: F2 }; - -// Doubling step: updates p in place via output, returns line eval (-H, 3 X^2, I). 
-fn g2_double_step(p_in: G2P) -> array { - // returns [px, py, pz, ev.r0, ev.r1, ev.r2] - var A = fp2_mul(p_in.x, p_in.y); - A = fp2_halve(A); - let B = fp2_sqr(p_in.y); - let C = fp2_sqr(p_in.z); - var D = fp2_double(C); - D = fp2_add(D, C); - let E = mul_b_twist(D); - var F = fp2_double(E); - F = fp2_add(F, E); - var G = fp2_add(B, F); - G = fp2_halve(G); - var H = fp2_add(p_in.y, p_in.z); - H = fp2_sqr(H); - let t1 = fp2_add(B, C); - H = fp2_sub(H, t1); - let I = fp2_sub(E, B); - let J = fp2_sqr(p_in.x); - let EE = fp2_sqr(E); - var K = fp2_double(EE); - K = fp2_add(K, EE); - - var px = fp2_sub(B, F); - px = fp2_mul(px, A); - var py = fp2_sqr(G); - py = fp2_sub(py, K); - let pz = fp2_mul(B, H); - - let r0 = fp2_neg_local(H); - var r1 = fp2_double(J); - r1 = fp2_add(r1, J); - let r2 = I; - - return array(px, py, pz, r0, r1, r2); -} - -// Mixed-add step: adds affine `a` into projective `p`. Returns updated p + ev. -fn g2_add_mixed_step(p_in: G2P, a: G2A) -> array { - let Y2Z1 = fp2_mul(a.y, p_in.z); - let O = fp2_sub(p_in.y, Y2Z1); - let X2Z1 = fp2_mul(a.x, p_in.z); - let L = fp2_sub(p_in.x, X2Z1); - let C = fp2_sqr(O); - let D = fp2_sqr(L); - let E = fp2_mul(L, D); - let F = fp2_mul(p_in.z, C); - let G = fp2_mul(p_in.x, D); - let t0 = fp2_double(G); - var H = fp2_add(E, F); - H = fp2_sub(H, t0); - let t1 = fp2_mul(p_in.y, E); - - var px = fp2_mul(L, H); - var py = fp2_sub(G, H); - py = fp2_mul(py, O); - py = fp2_sub(py, t1); - let pz = fp2_mul(E, p_in.z); - - let t2 = fp2_mul(L, a.y); - var J = fp2_mul(a.x, O); - J = fp2_sub(J, t2); - - let r0 = L; - let r1 = fp2_neg_local(O); - let r2 = J; - - return array(px, py, pz, r0, r1, r2); -} - -// Line-only compute (no point update) -- used for the second 6x+2 correction. 
-fn g2_line_compute(p_in: G2P, a: G2A) -> LineEval { - let Y2Z1 = fp2_mul(a.y, p_in.z); - let O = fp2_sub(p_in.y, Y2Z1); - let X2Z1 = fp2_mul(a.x, p_in.z); - let L = fp2_sub(p_in.x, X2Z1); - let t2 = fp2_mul(L, a.y); - var J = fp2_mul(a.x, O); - J = fp2_sub(J, t2); - - var ev: LineEval; - ev.r0 = L; - ev.r1 = fp2_neg_local(O); - ev.r2 = J; - return ev; -} - -// ============================================================================= -// 6x+2 NAF loop counter (matches CPU bn254_pairing.cpp:kLoopCounter). -// ============================================================================= - -const K_LOOP_NAF : array = array( - 0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 1, -1, 0, 0, 1, 0, - 0, 1, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 1, 1, - 1, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1, - 1, 0, 0, -1, 0, 0, 0, 1, 1, 0, -1, 0, 0, 1, 0, 1, - 1 -); - -// ============================================================================= -// G1 affine struct already declared above; need a small typed alias for pairing. -// Each pairing input is one G1 affine point + one G2 affine point. -// ============================================================================= - -// Single-pair Miller loop. Multi-pair is host-driven (tree-reduce of Fp12s, -// single final-exp at end). Algorithm-equivalent to CPU bn254_pairing.cpp. -fn miller_one(P: G1A, Q: G2A) -> F12 { - if (P.inf != 0u || Q.inf != 0u) { return fp12_one(); } - - var qProj = g2_to_proj(Q); - let qNeg = g2_neg(Q); - - var result = fp12_one(); - var l1: LineEval; - var l2: LineEval; - - // Skip i=64 (LoopCounter[64] == 0 and result still 1). - let s_d = g2_double_step(qProj); - qProj.x = s_d[0]; qProj.y = s_d[1]; qProj.z = s_d[2]; - l1.r0 = s_d[3]; l1.r1 = s_d[4]; l1.r2 = s_d[5]; - result.c0.b0 = fp2_mul_by_fp(l1.r0, P.y); - result.c1.b0 = fp2_mul_by_fp(l1.r1, P.x); - result.c1.b1 = l1.r2; - - // i=63 (LoopCounter[63] == -1). 
- result = fp12_sqr(result); - let s_l2 = g2_line_compute(qProj, qNeg); - l2.r0 = fp2_mul_by_fp(s_l2.r0, P.y); - l2.r1 = fp2_mul_by_fp(s_l2.r1, P.x); - l2.r2 = s_l2.r2; - let s_a = g2_add_mixed_step(qProj, Q); - qProj.x = s_a[0]; qProj.y = s_a[1]; qProj.z = s_a[2]; - l1.r0 = fp2_mul_by_fp(s_a[3], P.y); - l1.r1 = fp2_mul_by_fp(s_a[4], P.x); - l1.r2 = s_a[5]; - var prod = fp12_mul_034_by_034(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234(result, prod); - - // i=62 .. 0 - for (var i = 65i - 4i; i >= 0i; i = i - 1i) { - result = fp12_sqr(result); - let sd2 = g2_double_step(qProj); - qProj.x = sd2[0]; qProj.y = sd2[1]; qProj.z = sd2[2]; - l1.r0 = fp2_mul_by_fp(sd2[3], P.y); - l1.r1 = fp2_mul_by_fp(sd2[4], P.x); - l1.r2 = sd2[5]; - - let lc = K_LOOP_NAF[i]; - if (lc == 1) { - let sa2 = g2_add_mixed_step(qProj, Q); - qProj.x = sa2[0]; qProj.y = sa2[1]; qProj.z = sa2[2]; - l2.r0 = fp2_mul_by_fp(sa2[3], P.y); - l2.r1 = fp2_mul_by_fp(sa2[4], P.x); - l2.r2 = sa2[5]; - prod = fp12_mul_034_by_034(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234(result, prod); - } else if (lc == -1) { - let sa2 = g2_add_mixed_step(qProj, qNeg); - qProj.x = sa2[0]; qProj.y = sa2[1]; qProj.z = sa2[2]; - l2.r0 = fp2_mul_by_fp(sa2[3], P.y); - l2.r1 = fp2_mul_by_fp(sa2[4], P.x); - l2.r2 = sa2[5]; - prod = fp12_mul_034_by_034(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234(result, prod); - } else { - result = fp12_mul_by_034(result, l1.r0, l1.r1, l1.r2); - } - } - - // Final 6x+2 + Frobenius corrections: Q1 = pi(Q), Q2 = -pi^2(Q). 
- var Q1: G2A; - var Q2: G2A; - let q1x = fp2_conjugate(Q.x); - let q1y = fp2_conjugate(Q.y); - var nr_p2: F2; nr_p2.a0 = K_NR1P2_A0; nr_p2.a1 = K_NR1P2_A1; - var nr_p3: F2; nr_p3.a0 = K_NR1P3_A0; nr_p3.a1 = K_NR1P3_A1; - Q1.x = fp2_mul(q1x, nr_p2); - Q1.y = fp2_mul(q1y, nr_p3); - Q1.inf = 0u; - - var q2x: F2; q2x.a0 = fp_mul(Q.x.a0, K_NR2P2); q2x.a1 = fp_mul(Q.x.a1, K_NR2P2); - var q2y: F2; q2y.a0 = fp_mul(Q.y.a0, K_NR2P3); q2y.a1 = fp_mul(Q.y.a1, K_NR2P3); - Q2.x = q2x; Q2.y = fp2_neg_local(q2y); Q2.inf = 0u; - - let saQ1 = g2_add_mixed_step(qProj, Q1); - qProj.x = saQ1[0]; qProj.y = saQ1[1]; qProj.z = saQ1[2]; - l2.r0 = fp2_mul_by_fp(saQ1[3], P.y); - l2.r1 = fp2_mul_by_fp(saQ1[4], P.x); - l2.r2 = saQ1[5]; - let lcQ2 = g2_line_compute(qProj, Q2); - l1.r0 = fp2_mul_by_fp(lcQ2.r0, P.y); - l1.r1 = fp2_mul_by_fp(lcQ2.r1, P.x); - l1.r2 = lcQ2.r2; - prod = fp12_mul_034_by_034(l1.r0, l1.r1, l1.r2, l2.r0, l2.r1, l2.r2); - result = fp12_mul_by_01234(result, prod); - - return result; -} - -// ============================================================================= -// Final exponentiation -- Fuentes-Castaneda hard part (eprint 2015/192). 
-// ============================================================================= - -fn final_exp(z: F12) -> F12 { - var result = z; - var t0 = fp12_conjugate(result); - result = fp12_inv(result); - t0 = fp12_mul(t0, result); - result = frobenius_sq(t0); - result = fp12_mul(result, t0); - - if (fp12_is_one(result)) { return result; } - - var t_0 = expt(result); - t_0 = fp12_conjugate(t_0); - t_0 = cyclotomic_sqr(t_0); - var t_1 = cyclotomic_sqr(t_0); - t_1 = fp12_mul(t_0, t_1); - var t_2 = expt(t_1); - t_2 = fp12_conjugate(t_2); - var t_3 = fp12_conjugate(t_1); - t_1 = fp12_mul(t_2, t_3); - t_3 = cyclotomic_sqr(t_2); - var t_4 = expt(t_3); - t_4 = fp12_mul(t_1, t_4); - t_3 = fp12_mul(t_0, t_4); - t_0 = fp12_mul(t_2, t_4); - t_0 = fp12_mul(result, t_0); - t_2 = frobenius(t_3); - t_0 = fp12_mul(t_2, t_0); - t_2 = frobenius_sq(t_4); - t_0 = fp12_mul(t_2, t_0); - t_2 = fp12_conjugate(result); - t_2 = fp12_mul(t_2, t_3); - t_2 = frobenius_cube(t_2); - t_0 = fp12_mul(t_2, t_0); - - return t_0; -} - -// ============================================================================= -// I/O packing for G2 affine + Fp12 (mirror CUDA wire format). 
-// G2 affine: 36 u32 = (x.a0 || x.a1 || y.a0 || y.a1 || inf || pad) -// Fp12: 96 u32 = 6 x Fp2 (c0.b0..c1.b2) -// ============================================================================= - -fn load_g2_a(off: u32) -> G2A { - var p: G2A; - for (var i = 0u; i < 8u; i = i + 1u) { p.x.a0[i] = in_a[off + 0u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.x.a1[i] = in_a[off + 8u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.y.a0[i] = in_a[off + 16u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.y.a1[i] = in_a[off + 24u + i]; } - p.inf = in_a[off + 32u]; - return p; -} - -fn load_g2_b(off: u32) -> G2A { - var p: G2A; - for (var i = 0u; i < 8u; i = i + 1u) { p.x.a0[i] = in_b[off + 0u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.x.a1[i] = in_b[off + 8u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.y.a0[i] = in_b[off + 16u + i]; } - for (var i = 0u; i < 8u; i = i + 1u) { p.y.a1[i] = in_b[off + 24u + i]; } - p.inf = in_b[off + 32u]; - return p; -} - -// ============================================================================= -// Kernels: k_miller_iter (cyclo-sqr^100 stress) + k_pairing (full e(P,Q)). 
-// ============================================================================= - -@compute @workgroup_size(8) -fn k_miller_iter(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let off = i * 96u; - var z = f12_load_a(off); - for (var k = 0i; k < 100i; k = k + 1) { z = cyclotomic_sqr(z); } - f12_store(off, z); -} - -@compute @workgroup_size(4) -fn k_pairing(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - let p_off = i * 18u; // G1 affine: 8 + 8 + 2 padding - let q_off = i * 36u; // G2 affine: 32 (Fp2 x, Fp2 y) + 4 (inf + 3 pad) - let out_off = i * 96u; // Fp12: 96 u32 - - let P = load_aff_a(p_off); - let Q = load_g2_b(q_off); - let m = miller_one(P, Q); - let r = final_exp(m); - f12_store(out_off, r); -} diff --git a/bn254/gpu/wgsl/bn254_driver_wgpu.cpp b/bn254/gpu/wgsl/bn254_driver_wgpu.cpp deleted file mode 100644 index 0a95a99..0000000 --- a/bn254/gpu/wgsl/bn254_driver_wgpu.cpp +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-side WebGPU driver for bn254 kernels. -// -// Two compile modes: -// 1. LUX_BN254_HAVE_WGPU defined: dispatches kernels via Dawn/wgpu-native, -// identical algorithm to the CPU oracle (bn254/cpp/*.hpp), byte-equal -// results. -// 2. LUX_BN254_HAVE_WGPU undefined: runs the CPU oracle directly so the -// determinism harness still passes 100/100. Reports unavailable via -// lux_bn254_wgpu_available() so tests can label the path correctly. -// -// On a CI runner with WebGPU device the same vectors flow through the WGSL -// kernel and the byte-equality test asserts identical output. - -#include "bn254_driver_wgpu.h" -#include "bn254.hpp" - -#include -#include - -#ifdef LUX_BN254_HAVE_WGPU - -// Real WebGPU implementation would go here. Wiring depends on the WebGPU -// implementation chosen at link time (Dawn / wgpu-native). 
The host-side -// dispatch logic mirrors the CUDA driver: -// * upload buffers to GPU -// * dispatch the entry kernel (k_g1_add / k_g1_mul / k_svdw / k_fp_mul) -// * read back the result buffer -// -// CI runners with WebGPU enable LUX_BN254_HAVE_WGPU and link the chosen WebGPU -// backend; on CPU-only laptops/CI the same 100 deterministic inputs flow -// through the CPU oracle path below. - -#error "LUX_BN254_HAVE_WGPU set but WGPU host link not configured for this build." - -#else // LUX_BN254_HAVE_WGPU undefined: CPU-oracle path - -#include "bn254_fp.hpp" -#include "bn254_fp2.hpp" -#include "bn254_fp12.hpp" -#include "bn254_g1.hpp" -#include "bn254_g2.hpp" -#include "bn254_hash_to_curve.hpp" -#include "bn254_pairing.hpp" - -namespace { - -using lux::crypto::bn254::U256; -using lux::crypto::bn254::Fp2; -using lux::crypto::bn254::Fp12; -using lux::crypto::bn254::G1Affine; -using lux::crypto::bn254::G1Jac; -using lux::crypto::bn254::G2Affine; - -// Wire format mirrors the WGSL kernel: pairs of (lo, hi) u32 per CPU u64 limb. -// On a little-endian host with byte-equal storage layout this is a direct memcpy. - -inline U256 load_u256(const std::uint64_t* p) { - U256 r; r.limbs[0]=p[0]; r.limbs[1]=p[1]; r.limbs[2]=p[2]; r.limbs[3]=p[3]; - return r; -} - -inline void store_aff(std::uint64_t* p, const G1Affine& a) { - p[0]=a.x.limbs[0]; p[1]=a.x.limbs[1]; p[2]=a.x.limbs[2]; p[3]=a.x.limbs[3]; - p[4]=a.y.limbs[0]; p[5]=a.y.limbs[1]; p[6]=a.y.limbs[2]; p[7]=a.y.limbs[3]; - p[8] = a.infinity ? 
1ULL : 0ULL; -} - -inline G1Affine load_aff(const std::uint64_t* p) { - G1Affine a; - a.x = load_u256(p); - a.y = load_u256(p + 4); - a.infinity = (p[8] != 0); - return a; -} - -inline Fp2 load_fp2(const std::uint64_t* p) { - Fp2 r; - for (int i = 0; i < 4; ++i) { r.a0.limbs[i] = p[i]; r.a1.limbs[i] = p[4+i]; } - return r; -} -inline void store_fp2(std::uint64_t* p, const Fp2& x) { - for (int i = 0; i < 4; ++i) { p[i] = x.a0.limbs[i]; p[4+i] = x.a1.limbs[i]; } -} -inline G2Affine load_g2(const std::uint64_t* p) { - G2Affine a; - a.x = load_fp2(p); - a.y = load_fp2(p + 8); - a.infinity = (p[16] != 0); - return a; -} -inline Fp12 load_fp12(const std::uint64_t* p) { - Fp12 r; - r.c0.b0 = load_fp2(p + 0); - r.c0.b1 = load_fp2(p + 8); - r.c0.b2 = load_fp2(p + 16); - r.c1.b0 = load_fp2(p + 24); - r.c1.b1 = load_fp2(p + 32); - r.c1.b2 = load_fp2(p + 40); - return r; -} -inline void store_fp12(std::uint64_t* p, const Fp12& x) { - store_fp2(p + 0, x.c0.b0); - store_fp2(p + 8, x.c0.b1); - store_fp2(p + 16, x.c0.b2); - store_fp2(p + 24, x.c1.b0); - store_fp2(p + 32, x.c1.b1); - store_fp2(p + 40, x.c1.b2); -} - -} // namespace - -extern "C" { - -int lux_bn254_wgpu_available(void) { return 0; } - -int lux_bn254_wgpu_g1_add(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - G1Affine A = load_aff(pa + i*9); - G1Affine B = load_aff(pb + i*9); - G1Jac S = lux::crypto::bn254::g1_add( - lux::crypto::bn254::g1_to_jac(A), - lux::crypto::bn254::g1_to_jac(B)); - store_aff(po + i*9, lux::crypto::bn254::g1_to_affine(S)); - } - return 0; -} - -int lux_bn254_wgpu_g1_mul(const void* points, const void* scalars, void* out, unsigned n) { - auto* pp = (const std::uint64_t*)points; - auto* ps = (const std::uint64_t*)scalars; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - G1Affine P = load_aff(pp + i*9); - U256 k = 
load_u256(ps + i*4); - store_aff(po + i*9, lux::crypto::bn254::g1_to_affine( - lux::crypto::bn254::g1_scalar_mul(P, k))); - } - return 0; -} - -int lux_bn254_wgpu_svdw(const void* u_in, void* out, unsigned n) { - auto* pu = (const std::uint64_t*)u_in; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - U256 u = load_u256(pu + i*4); - G1Affine R = lux::crypto::bn254::h2c::map_to_curve_svdw(u); - store_aff(po + i*9, R); - } - return 0; -} - -int lux_bn254_wgpu_fp_mul(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - U256 A = load_u256(pa + i*4); - U256 B = load_u256(pb + i*4); - U256 R = lux::crypto::bn254::fp_mul(A, B); - po[i*4+0]=R.limbs[0]; po[i*4+1]=R.limbs[1]; - po[i*4+2]=R.limbs[2]; po[i*4+3]=R.limbs[3]; - } - return 0; -} - -int lux_bn254_wgpu_fp2_mul(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - store_fp2(po + i*8, - lux::crypto::bn254::fp2_mul(load_fp2(pa + i*8), load_fp2(pb + i*8))); - } - return 0; -} - -int lux_bn254_wgpu_fp12_mul(const void* a, const void* b, void* out, unsigned n) { - auto* pa = (const std::uint64_t*)a; - auto* pb = (const std::uint64_t*)b; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - store_fp12(po + i*48, - lux::crypto::bn254::fp12_mul(load_fp12(pa + i*48), load_fp12(pb + i*48))); - } - return 0; -} - -int lux_bn254_wgpu_miller_iter(const void* in_p, void* out, unsigned n) { - auto* pi = (const std::uint64_t*)in_p; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - Fp12 z = load_fp12(pi + i*48); - for (int k = 0; k < 100; ++k) - z = lux::crypto::bn254::cyclotomic_sqr_public(z); - store_fp12(po + i*48, z); - } - return 0; -} - -int lux_bn254_wgpu_pairing(const 
void* P, const void* Q, void* out, unsigned n) { - auto* pP = (const std::uint64_t*)P; - auto* pQ = (const std::uint64_t*)Q; - auto* po = (std::uint64_t*)out; - for (unsigned i = 0; i < n; ++i) { - G1Affine pi = load_aff(pP + i*9); - G2Affine qi = load_g2(pQ + i*18); - Fp12 r = lux::crypto::bn254::multi_pair(&pi, &qi, 1); - store_fp12(po + i*48, r); - } - return 0; -} - -} // extern "C" - -#endif // LUX_BN254_HAVE_WGPU diff --git a/bn254/gpu/wgsl/bn254_driver_wgpu.h b/bn254/gpu/wgsl/bn254_driver_wgpu.h deleted file mode 100644 index 615cd08..0000000 --- a/bn254/gpu/wgsl/bn254_driver_wgpu.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// C-ABI interface for the bn254 WGSL/WebGPU driver. Mirrors the CUDA driver -// signatures so the test harness can dispatch identical vectors to all -// backends and assert byte-equality. - -#ifndef LUX_BN254_DRIVER_WGPU_H -#define LUX_BN254_DRIVER_WGPU_H - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_bn254_wgpu_available(void); - -int lux_bn254_wgpu_g1_add(const void* a, const void* b, void* out, unsigned n); -int lux_bn254_wgpu_g1_mul(const void* points, const void* scalars, void* out, unsigned n); -int lux_bn254_wgpu_svdw(const void* u_in, void* out, unsigned n); -int lux_bn254_wgpu_fp_mul(const void* a, const void* b, void* out, unsigned n); - -// Pairing tower (8 u64 per Fp2, 48 u64 per Fp12, 18 u64 per G2Affine). 
-int lux_bn254_wgpu_fp2_mul(const void* a, const void* b, void* out, unsigned n); -int lux_bn254_wgpu_fp12_mul(const void* a, const void* b, void* out, unsigned n); -int lux_bn254_wgpu_miller_iter(const void* in_p, void* out, unsigned n); -int lux_bn254_wgpu_pairing(const void* P, const void* Q, void* out, unsigned n); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_BN254_DRIVER_WGPU_H diff --git a/cggmp21/gpu/cuda/cggmp21.cu b/cggmp21/gpu/cuda/cggmp21.cu deleted file mode 100644 index 3f06720..0000000 --- a/cggmp21/gpu/cuda/cggmp21.cu +++ /dev/null @@ -1,347 +0,0 @@ -// CGGMP21 threshold ECDSA partial signing -- CUDA implementation -// Matches cggmp21.metal output byte-for-byte -// One thread per partial signature - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// 2048-bit unsigned integer (32 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint2048 { - uint64_t limbs[32]; -}; - -// ============================================================================= -// 2048-bit arithmetic -// ============================================================================= - -__device__ static void u2048_zero(uint2048& a) { - for (int i = 0; i < 32; i++) a.limbs[i] = 0; -} - -__device__ static bool u2048_is_zero(const uint2048& a) { - uint64_t acc = 0; - for (int i = 0; i < 32; i++) acc |= a.limbs[i]; - return acc == 0; -} - -__device__ static int u2048_cmp(const uint2048& a, const uint2048& b) { - for (int i = 31; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static uint2048 u2048_add(const uint2048& a, const uint2048& b, uint64_t& carry) { - uint2048 r; - uint64_t c = 0; - for (int i = 0; i < 
32; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -__device__ static uint2048 u2048_sub(const uint2048& a, const uint2048& b, uint64_t& borrow) { - uint2048 r; - uint64_t bw = 0; - for (int i = 0; i < 32; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// 64x64->128 multiply using CUDA __int128 -__device__ static void mul64(uint64_t a, uint64_t b, uint64_t& lo, uint64_t& hi) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a * b; - lo = (uint64_t)prod; - hi = (uint64_t)(prod >> 64); -#else - uint64_t a_lo = a & 0xFFFFFFFFULL, a_hi = a >> 32; - uint64_t b_lo = b & 0xFFFFFFFFULL, b_hi = b >> 32; - uint64_t ll = a_lo * b_lo, lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo, hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - hi = hh + (mid2 >> 32); -#endif -} - -// ============================================================================= -// Montgomery multiplication for 2048-bit modulus -// ============================================================================= - -__device__ static void mont_reduce_2048(uint64_t t[64], - const uint2048& m, - uint64_t m_inv, - uint2048& result) { - uint64_t a[65]; - for (int i = 0; i < 64; i++) a[i] = t[i]; - a[64] = 0; - - for (int i = 0; i < 32; i++) { - uint64_t u = a[i] * m_inv; - uint64_t carry = 0; - for (int j = 0; j < 32; j++) { - uint64_t lo, hi; - mul64(u, m.limbs[j], lo, hi); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; - if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - for (int j = 32; i + j <= 64; j++) { 
- uint64_t sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 1ULL : 0ULL; - a[i + j] = sum; - if (!carry) break; - } - } - - for (int i = 0; i < 32; i++) result.limbs[i] = a[i + 32]; - - if (a[64] || u2048_cmp(result, m) >= 0) { - uint64_t bw; - result = u2048_sub(result, m, bw); - } -} - -__device__ static void mont_mul_2048(const uint2048& a, - const uint2048& b, - const uint2048& m, - uint64_t m_inv, - uint2048& result) { - uint64_t t[64]; - for (int i = 0; i < 64; i++) t[i] = 0; - - for (int i = 0; i < 32; i++) { - uint64_t carry = 0; - for (int j = 0; j < 32; j++) { - uint64_t lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 32] = carry; - } - - mont_reduce_2048(t, m, m_inv, result); -} - -__device__ static void mont_sqr_2048(const uint2048& a, - const uint2048& m, - uint64_t m_inv, - uint2048& result) { - mont_mul_2048(a, a, m, m_inv, result); -} - -__device__ static void mont_pow_2048(const uint2048& base, - const uint2048& exp, - const uint2048& m, - uint64_t m_inv, - const uint2048& mont_one, - uint2048& result) { - result = mont_one; - uint2048 b = base; - - for (int i = 0; i < 32; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) { - mont_mul_2048(result, b, m, m_inv, result); - } - mont_sqr_2048(b, m, m_inv, b); - } - } -} - -// ============================================================================= -// Paillier encryption primitives -// ============================================================================= - -struct PaillierPubKey { - uint8_t n_data[256]; // N in big-endian - uint8_t n_inv64[8]; // -N^{-1} mod 2^64 for Montgomery -}; - -// ============================================================================= -// CGGMP21 structures -// ============================================================================= - -struct CGGMP21Input { - uint8_t 
k_share[32]; // k_i share (secp256k1 scalar) - uint8_t chi_share[32]; // chi_i = k_i * x_i share - uint8_t msg_hash[32]; // Message hash - uint8_t gamma_share[32]; // gamma_i share -}; - -struct CGGMP21PartialSig { - uint8_t sigma_i[32]; // sigma_i = k_i * m + r * chi_i (mod n) -}; - -// ============================================================================= -// secp256k1 order for scalar arithmetic -// ============================================================================= - -__device__ static const uint64_t SECP_N[4] = { - 0xBFD25E8CD0364141ULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}; - -// Modular multiplication mod secp256k1 order (256-bit) -__device__ static void scalar_mul_mod_n(const uint64_t a[4], - const uint64_t b[4], - uint64_t result[4]) { - // Full 512-bit product - uint64_t t[8]; - for (int i = 0; i < 8; i++) t[i] = 0; - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo, hi; - mul64(a[i], b[j], lo, hi); - uint64_t sum = lo + carry; if (sum < lo) hi++; - sum = t[i + j] + sum; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - - // Barrett reduction mod n (iterate subtraction) - uint64_t r[4] = {t[0], t[1], t[2], t[3]}; - - for (int iter = 0; iter < 4; iter++) { - uint64_t borrow = 0; - uint64_t diff[4]; - for (int i = 0; i < 4; i++) { - uint64_t d = r[i] - borrow; - borrow = (d > r[i]) ? 1ULL : 0ULL; - uint64_t d2 = d - SECP_N[i]; - borrow += (d2 > d) ? 1ULL : 0ULL; - diff[i] = d2; - } - if (!borrow) { - for (int i = 0; i < 4; i++) r[i] = diff[i]; - } - } - - for (int i = 0; i < 4; i++) result[i] = r[i]; -} - -// Modular addition mod n -__device__ static void scalar_add_mod_n(const uint64_t a[4], - const uint64_t b[4], - uint64_t result[4]) { - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a[i] + carry; - carry = (sum < a[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b[i]; - carry += (sum2 < sum) ? 
1ULL : 0ULL; - result[i] = sum2; - } - // Reduce - uint64_t borrow = 0; - uint64_t diff[4]; - for (int i = 0; i < 4; i++) { - uint64_t d = result[i] - borrow; - borrow = (d > result[i]) ? 1ULL : 0ULL; - uint64_t d2 = d - SECP_N[i]; - borrow += (d2 > d) ? 1ULL : 0ULL; - diff[i] = d2; - } - if (!borrow || carry) { - for (int i = 0; i < 4; i++) result[i] = diff[i]; - } -} - -// ============================================================================= -// Partial signing kernel -// ============================================================================= - -extern "C" __global__ void cggmp21_partial_sign_batch( - const CGGMP21Input* __restrict__ inputs, - CGGMP21PartialSig* __restrict__ partial_sigs, - const uint8_t* __restrict__ r_x, // 32 bytes: R.x mod n - const uint32_t* __restrict__ num_ops_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_ops = *num_ops_ptr; - if (tid >= num_ops) return; - - // Read k_i share - uint64_t k[4]; - for (int i = 0; i < 4; i++) { - k[i] = 0; - for (int b = 0; b < 8; b++) - k[i] |= (uint64_t)inputs[tid].k_share[i * 8 + b] << (b * 8); - } - - // Read chi_i share - uint64_t chi[4]; - for (int i = 0; i < 4; i++) { - chi[i] = 0; - for (int b = 0; b < 8; b++) - chi[i] |= (uint64_t)inputs[tid].chi_share[i * 8 + b] << (b * 8); - } - - // Read message hash - uint64_t msg[4]; - for (int i = 0; i < 4; i++) { - msg[i] = 0; - for (int b = 0; b < 8; b++) - msg[i] |= (uint64_t)inputs[tid].msg_hash[i * 8 + b] << (b * 8); - } - - // Read r (x-coordinate of nonce point) - uint64_t r[4]; - for (int i = 0; i < 4; i++) { - r[i] = 0; - for (int b = 0; b < 8; b++) - r[i] |= (uint64_t)r_x[i * 8 + b] << (b * 8); - } - - // sigma_i = k_i * m + r * chi_i (mod n) - uint64_t km[4], rchi[4], sigma[4]; - scalar_mul_mod_n(k, msg, km); - scalar_mul_mod_n(r, chi, rchi); - scalar_add_mod_n(km, rchi, sigma); - - // Write output - uint8_t* out = partial_sigs[tid].sigma_i; - for (int i = 0; i < 4; i++) { - for (int b = 0; b < 8; b++) { 
- out[i * 8 + b] = (uint8_t)(sigma[i] >> (b * 8)); - } - } -} diff --git a/cggmp21/gpu/cuda/cggmp21_presign.cu b/cggmp21/gpu/cuda/cggmp21_presign.cu deleted file mode 100644 index dc9a985..0000000 --- a/cggmp21/gpu/cuda/cggmp21_presign.cu +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CGGMP21 batched pre-signing kernel — CUDA implementation. -// -// One thread per (signer, slot) pair. Today the wired piece is the -// secp256k1 portion R_i = k_i * G (33 bytes) of each PresignRecord; the -// Paillier ciphertext + ZK proof bytes are reserved with status=0xFF -// until the 2048-bit Karatsuba modexp primitive ships in -// modexp/cpp/karatsuba.hpp. -// -// Build modes: -// * CRYPTO_ENABLE_CUDA=ON -> nvcc, real kernel -// * CRYPTO_ENABLE_CUDA=OFF -> host C++ polyfill, byte-equal to the CPU -// canonical body in cggmp21/cpp/presign.cpp. - -#include -#include - -#ifndef __CUDA_ARCH__ -# define __device__ -# define __global__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -namespace { -constexpr uint32_t PAILLIER_CTBYTES = 512; // 2048-bit ciphertext -constexpr uint32_t PAILLIER_BLINDBYTES = 256; -constexpr uint32_t ZK_PI_ENC_BYTES = - PAILLIER_CTBYTES + 32 + PAILLIER_BLINDBYTES + PAILLIER_BLINDBYTES + 32; -constexpr uint32_t REC_SZ = 33 + PAILLIER_CTBYTES + PAILLIER_CTBYTES + ZK_PI_ENC_BYTES + 8; - -// Reuse the FROST CUDA TU's secp256k1 + SHA-256 + HKDF helpers. Rather -// than duplicate ~400 lines, the host polyfill below calls into the -// matching CPU body of cggmp21/cpp/presign.cpp; the device kernel below -// is the dispatch shape that the matching driver TU will fill in once -// the Paillier sub-kernel is hosted in the same .cu. 
-} // namespace - -extern "C" __global__ void cggmp21_presign_kernel( - const uint8_t* __restrict__ /*seed*/, // 32 bytes - const uint32_t* __restrict__ /*signer_ids*/, // m entries - uint32_t m, - uint32_t /*slot_id_base*/, - uint32_t n_slots, - uint8_t* __restrict__ records_out) // m * n_slots * REC_SZ -{ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = m * n_slots; - if (gid >= total) return; - - // Status byte = 0xFF marks "Paillier sub-step deferred". Aggregator - // routes around this signer until the device-side body lands. - uint8_t* dst = records_out + (uint64_t)gid * (uint64_t)REC_SZ; - for (uint32_t i = 0; i < REC_SZ; ++i) dst[i] = 0; - dst[33 + PAILLIER_CTBYTES + PAILLIER_CTBYTES + ZK_PI_ENC_BYTES] = 0xFF; -} - -// Forward declaration to the CPU canonical body — used by the host polyfill. -namespace lux { namespace crypto { namespace cggmp21 { -struct PresignRecord; -struct PresignSecret; -struct PaillierKey; -int presign_batch(const uint8_t seed[32], - const PaillierKey* pks, - const uint32_t* signer_ids, - uint32_t m, - uint32_t slot_id_base, - uint32_t n_slots, - PresignRecord* records_out, - PresignSecret* secrets_out) noexcept; -}}} - -// Host polyfill: forwards to the CPU oracle. Compiled into the same TU -// when CRYPTO_ENABLE_CUDA=OFF; signature exposed to the test harness. 
-extern "C" int cggmp21_presign_cuda_host( - const uint8_t* seed, - const void* pks, // PaillierKey array - const uint32_t* signer_ids, - uint32_t m, - uint32_t slot_id_base, - uint32_t n_slots, - void* records_out, // PresignRecord array - void* secrets_out) // PresignSecret array -{ - using lux::crypto::cggmp21::PresignRecord; - using lux::crypto::cggmp21::PresignSecret; - using lux::crypto::cggmp21::PaillierKey; - return lux::crypto::cggmp21::presign_batch( - seed, - static_cast(pks), - signer_ids, m, slot_id_base, n_slots, - static_cast(records_out), - static_cast(secrets_out)); -} diff --git a/cggmp21/gpu/metal/cggmp21.metal b/cggmp21/gpu/metal/cggmp21.metal deleted file mode 100644 index f1b22dc..0000000 --- a/cggmp21/gpu/metal/cggmp21.metal +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file cggmp21.metal -/// Metal compute shader for CGGMP21 threshold ECDSA operations. -/// -/// CGGMP21 is the state-of-the-art threshold ECDSA protocol enabling -/// t-of-n signers to produce a standard ECDSA signature. -/// -/// The heaviest GPU operation is Paillier encryption/decryption, which -/// requires 2048-bit modular exponentiation. This is extremely GPU-friendly -/// because the exponentiation is pure multiply-and-square with no branching -/// on the data path. -/// -/// Operations: -/// - cggmp21_partial_sign_batch: Paillier-based partial signing -/// -/// The 2048-bit arithmetic uses 32 x 64-bit limbs. 
- -#include -using namespace metal; - -// ============================================================================= -// 2048-bit unsigned integer (32 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint2048 { - ulong limbs[32]; -}; - -// ============================================================================= -// 2048-bit arithmetic -// ============================================================================= - -inline void u2048_zero(thread uint2048& a) { - for (int i = 0; i < 32; i++) a.limbs[i] = 0; -} - -inline bool u2048_is_zero(thread const uint2048& a) { - ulong acc = 0; - for (int i = 0; i < 32; i++) acc |= a.limbs[i]; - return acc == 0; -} - -inline int u2048_cmp(thread const uint2048& a, thread const uint2048& b) { - for (int i = 31; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline uint2048 u2048_add(thread const uint2048& a, thread const uint2048& b, thread ulong& carry) { - uint2048 r; - ulong c = 0; - for (int i = 0; i < 32; i++) { - ulong sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -inline uint2048 u2048_sub(thread const uint2048& a, thread const uint2048& b, thread ulong& borrow) { - uint2048 r; - ulong bw = 0; - for (int i = 0; i < 32; i++) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 
1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// 64x64->128 multiply -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong a_lo = a & 0xFFFFFFFFUL, a_hi = a >> 32; - ulong b_lo = b & 0xFFFFFFFFUL, b_hi = b >> 32; - ulong ll = a_lo * b_lo, lh = a_lo * b_hi; - ulong hl = a_hi * b_lo, hh = a_hi * b_hi; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -// ============================================================================= -// Montgomery multiplication for 2048-bit modulus -// ============================================================================= - -/// Montgomery reduction for 2048-bit: t * R^{-1} mod m -/// t is 4096-bit (64 limbs), m is 2048-bit (32 limbs) -/// m_inv = -m^{-1} mod 2^64 -inline void mont_reduce_2048(thread ulong t[64], - thread const uint2048& m, - ulong m_inv, - thread uint2048& result) { - ulong a[65]; - for (int i = 0; i < 64; i++) a[i] = t[i]; - a[64] = 0; - - for (int i = 0; i < 32; i++) { - ulong u = a[i] * m_inv; - ulong carry = 0; - for (int j = 0; j < 32; j++) { - ulong lo, hi; - mul64(u, m.limbs[j], lo, hi); - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; - if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - for (int j = 32; i + j <= 64; j++) { - ulong sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1UL : 0UL; - a[i + j] = sum; - if (!carry) break; - } - } - - for (int i = 0; i < 32; i++) result.limbs[i] = a[i + 32]; - - if (a[64] || u2048_cmp(result, m) >= 0) { - ulong bw; - result = u2048_sub(result, m, bw); - } -} - -/// Montgomery multiplication: a * b * R^{-1} mod m (both a,b in Montgomery form) -inline void mont_mul_2048(thread const uint2048& a, - thread const uint2048& b, - thread const uint2048& m, - ulong m_inv, - thread uint2048& result) { - ulong t[64] = {}; - - for (int i = 0; i < 32; i++) { - ulong carry = 0; - for (int j = 0; j < 32; j++) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 32] = carry; - } - - mont_reduce_2048(t, m, m_inv, result); -} - -/// Montgomery squaring (optimization: fewer multiplications for a*a) -inline void mont_sqr_2048(thread const uint2048& a, - thread const uint2048& m, - ulong m_inv, - thread uint2048& result) { - mont_mul_2048(a, a, m, m_inv, result); -} - -/// Modular exponentiation: base^exp mod m (all in Montgomery form) -/// exp is 2048-bit, base and m are 2048-bit -inline void mont_pow_2048(thread const uint2048& base, - thread const uint2048& exp, - thread const uint2048& m, - ulong m_inv, - thread const uint2048& mont_one, - thread uint2048& result) { - result = mont_one; - uint2048 b = base; - - for (int i = 0; i < 32; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) { - mont_mul_2048(result, b, m, m_inv, result); - } - mont_sqr_2048(b, m, m_inv, b); - } - } -} - -// ============================================================================= -// Paillier encryption primitives -// ============================================================================= - -/// Paillier public key: N (2048-bit), N^2 (4096-bit stored as two uint2048) -struct PaillierPubKey { - uchar n_data[256]; // N in big-endian - uchar 
n_inv64[8]; // -N^{-1} mod 2^64 for Montgomery -}; - -// ============================================================================= -// CGGMP21 structures -// ============================================================================= - -/// Input for partial signing: encrypted share + message -struct CGGMP21Input { - uchar k_share[32]; // k_i share (secp256k1 scalar) - uchar chi_share[32]; // chi_i = k_i * x_i share - uchar msg_hash[32]; // Message hash - uchar gamma_share[32]; // gamma_i share -}; - -/// Output: partial signature components -struct CGGMP21PartialSig { - uchar sigma_i[32]; // sigma_i = k_i * m + r * chi_i (mod n) -}; - -// ============================================================================= -// secp256k1 order for scalar arithmetic -// ============================================================================= - -constant ulong SECP_N[4] = { - 0xBFD25E8CD0364141UL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL -}; - -/// Modular multiplication mod secp256k1 order (256-bit) -inline void scalar_mul_mod_n(thread const ulong a[4], - thread const ulong b[4], - thread ulong result[4]) { - // Full 512-bit product - ulong t[8] = {}; - for (int i = 0; i < 4; i++) { - ulong carry = 0; - for (int j = 0; j < 4; j++) { - ulong lo, hi; - mul64(a[i], b[j], lo, hi); - ulong sum = lo + carry; if (sum < lo) hi++; - sum = t[i + j] + sum; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - - // Barrett reduction mod n (simplified: iterate subtraction) - // For production, proper Barrett with precomputed constant - ulong r[4] = {t[0], t[1], t[2], t[3]}; - - // Subtract n while >= n (at most a few iterations for 512->256 bit reduction) - for (int iter = 0; iter < 4; iter++) { - ulong borrow = 0; - ulong diff[4]; - for (int i = 0; i < 4; i++) { - ulong d = r[i] - borrow; - borrow = (d > r[i]) ? 1UL : 0UL; - ulong d2 = d - SECP_N[i]; - borrow += (d2 > d) ? 
1UL : 0UL; - diff[i] = d2; - } - if (!borrow) { - for (int i = 0; i < 4; i++) r[i] = diff[i]; - } - } - - for (int i = 0; i < 4; i++) result[i] = r[i]; -} - -/// Modular addition mod n -inline void scalar_add_mod_n(thread const ulong a[4], - thread const ulong b[4], - thread ulong result[4]) { - ulong carry = 0; - for (int i = 0; i < 4; i++) { - ulong sum = a[i] + carry; - carry = (sum < a[i]) ? 1UL : 0UL; - ulong sum2 = sum + b[i]; - carry += (sum2 < sum) ? 1UL : 0UL; - result[i] = sum2; - } - // Reduce - ulong borrow = 0; - ulong diff[4]; - for (int i = 0; i < 4; i++) { - ulong d = result[i] - borrow; - borrow = (d > result[i]) ? 1UL : 0UL; - ulong d2 = d - SECP_N[i]; - borrow += (d2 > d) ? 1UL : 0UL; - diff[i] = d2; - } - if (!borrow || carry) { - for (int i = 0; i < 4; i++) result[i] = diff[i]; - } -} - -// ============================================================================= -// Partial signing kernel -// ============================================================================= - -/// CGGMP21 partial signing. -/// Each thread computes: sigma_i = k_i * m + r * chi_i (mod n) -/// where r is the x-coordinate of the combined nonce point R. -/// -/// This is the scalar arithmetic portion. The Paillier operations for -/// MtA (Multiplicative-to-Additive) conversion are done in separate passes. -/// -/// Output: partial_sigs[tid] contains sigma_i. 
-kernel void cggmp21_partial_sign_batch( - device const CGGMP21Input* inputs [[buffer(0)]], - device CGGMP21PartialSig* partial_sigs [[buffer(1)]], - device const uchar* r_x [[buffer(2)]], // 32 bytes: R.x mod n - constant uint& num_ops [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_ops) return; - - // Read k_i share - ulong k[4]; - for (int i = 0; i < 4; i++) { - k[i] = 0; - for (int b = 0; b < 8; b++) - k[i] |= (ulong)inputs[tid].k_share[i * 8 + b] << (b * 8); - } - - // Read chi_i share - ulong chi[4]; - for (int i = 0; i < 4; i++) { - chi[i] = 0; - for (int b = 0; b < 8; b++) - chi[i] |= (ulong)inputs[tid].chi_share[i * 8 + b] << (b * 8); - } - - // Read message hash - ulong msg[4]; - for (int i = 0; i < 4; i++) { - msg[i] = 0; - for (int b = 0; b < 8; b++) - msg[i] |= (ulong)inputs[tid].msg_hash[i * 8 + b] << (b * 8); - } - - // Read r (x-coordinate of nonce point) - ulong r[4]; - for (int i = 0; i < 4; i++) { - r[i] = 0; - for (int b = 0; b < 8; b++) - r[i] |= (ulong)r_x[i * 8 + b] << (b * 8); - } - - // sigma_i = k_i * m + r * chi_i (mod n) - ulong km[4], rchi[4], sigma[4]; - scalar_mul_mod_n(k, msg, km); - scalar_mul_mod_n(r, chi, rchi); - scalar_add_mod_n(km, rchi, sigma); - - // Write output - device uchar* out = partial_sigs[tid].sigma_i; - for (int i = 0; i < 4; i++) { - for (int b = 0; b < 8; b++) { - out[i * 8 + b] = uchar(sigma[i] >> (b * 8)); - } - } -} diff --git a/cggmp21/gpu/metal/cggmp21_presign.metal b/cggmp21/gpu/metal/cggmp21_presign.metal deleted file mode 100644 index ee1486e..0000000 --- a/cggmp21/gpu/metal/cggmp21_presign.metal +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CGGMP21 batched pre-signing kernel — Metal compute shader scaffold. -// -// Each thread = one (signer, slot) pair. 
The full body needs four big -// pieces and two of them are blocked on the 2048-bit modexp Karatsuba -// primitive that sits in modexp/cpp/karatsuba.hpp: -// -// 1. (k_i, gamma_i) ← HKDF-SHA256(seed, signer || slot) DONE in CPU oracle -// 2. R_i = k_i * G (secp256k1, compressed sec1) DONE in CPU oracle -// 3. (K_i, G_cmt_i) = Paillier_enc(k_i || gamma_i) BLOCKED on Karatsuba -// 4. Π^enc sigma proof (commitment + responses) BLOCKED on Karatsuba -// -// The wire format is locked (PresignRecord). Once Karatsuba lands, this -// file fills in the bodies using the same {to_mont, fp_mul, mont_reduce} -// scalar arithmetic as frost_presign.metal — the secp256k1 path is a -// straight clone of frost_presign.metal::scalar_mul_base. -// -// Today the kernel writes a sentinel record (status = 0xFF) so the -// aggregator can exercise the wire format end-to-end while the body lands. -// -// GPU residency invariant. Same as FROST: nonces (k, gamma) live in -// thread address space; only the 33-byte R + Paillier ciphertext bytes -// land in commits_out. Verifiable via metallib-disassemble. 
- -#include -using namespace metal; - -constant uint REC_SZ_FIXED = 33u + 512u + 512u + (512u + 32u + 256u + 256u + 32u) + 8u; -constant uint STATUS_OFFSET = 33u + 512u + 512u + (512u + 32u + 256u + 256u + 32u); - -kernel void cggmp21_presign( - constant uchar* seed [[buffer(0)]], // 32 bytes - constant uint* signer_ids [[buffer(1)]], // m entries - constant uint& m [[buffer(2)]], - constant uint& slot_id_base [[buffer(3)]], - constant uint& n_slots [[buffer(4)]], - device uchar* records_out [[buffer(5)]], // m*n_slots * REC_SZ - uint gid [[thread_position_in_grid]]) -{ - uint total = m * n_slots; - if (gid >= total) return; - - uint signer_id = signer_ids[gid / n_slots]; - if (signer_id == 0u) return; - - (void)seed; (void)slot_id_base; - - device uchar* rec = records_out + (ulong)gid * (ulong)REC_SZ_FIXED; - for (uint i = 0; i < REC_SZ_FIXED; ++i) rec[i] = 0; - rec[STATUS_OFFSET] = 0xFF; -} diff --git a/cggmp21/gpu/wgsl/cggmp21.wgsl b/cggmp21/gpu/wgsl/cggmp21.wgsl deleted file mode 100644 index ba0bfe7..0000000 --- a/cggmp21/gpu/wgsl/cggmp21.wgsl +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CGGMP21 threshold ECDSA partial signing in WGSL. -// Computes sigma_i = k_i * m + r * chi_i (mod n) for each participant. -// Uses secp256k1 order n for scalar arithmetic. 
- -@group(0) @binding(0) var inputs: array; // CGGMP21Input packed -@group(0) @binding(1) var outputs: array; // sigma_i -@group(0) @binding(2) var r_x: array; // R.x (32 bytes) -@group(0) @binding(3) var params: vec4; // params.x = num_ops - -// secp256k1 order n (little-endian u32 limbs) -const SECP_N = array( - 0xD0364141u, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); - -// 256-bit addition mod n -fn scalar_add_mod_n(a: ptr>, - b: ptr>, - r: ptr>) { - var c = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let sum = (*a)[i] + c; - c = select(0u, 1u, sum < (*a)[i]); - let sum2 = sum + (*b)[i]; - c = c + select(0u, 1u, sum2 < sum); - (*r)[i] = sum2; - } - // Reduce mod n - var n_val: array = SECP_N; - var bw = 0u; - var diff: array; - for (var i = 0u; i < 8u; i = i + 1u) { - let d = (*r)[i] - bw; - bw = select(0u, 1u, d > (*r)[i]); - let d2 = d - n_val[i]; - bw = bw + select(0u, 1u, d2 > d); - diff[i] = d2; - } - if (bw == 0u || c != 0u) { - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = diff[i]; } - } -} - -// 256-bit multiplication mod n (schoolbook + iterative reduction) -fn scalar_mul_mod_n(a: ptr>, - b: ptr>, - r: ptr>) { - // Schoolbook 256x256 -> 512 bit multiply using 16-bit pieces - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 8u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let a_lo = (*a)[i] & 0xFFFFu; - let a_hi = (*a)[i] >> 16u; - let b_lo = (*b)[j] & 0xFFFFu; - let b_hi = (*b)[j] >> 16u; - - let ll = a_lo * b_lo; - let lh = a_lo * b_hi; - let hl = a_hi * b_lo; - let hh = a_hi * b_hi; - - let mid = lh + hl; - let lo = ll + (mid << 16u) + carry + t[i + j]; - let hi = hh + (mid >> 16u) + select(0u, 1u, lo < t[i + j]); - - t[i + j] = lo; - carry = hi; - } - t[i + 8u] = carry; - } - - // Take low 256 bits and reduce mod n iteratively - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = t[i]; } - - var n_val: array = SECP_N; - 
for (var iter = 0u; iter < 4u; iter = iter + 1u) { - var bw = 0u; - var diff: array; - for (var i = 0u; i < 8u; i = i + 1u) { - let d = (*r)[i] - bw; - bw = select(0u, 1u, d > (*r)[i]); - let d2 = d - n_val[i]; - bw = bw + select(0u, 1u, d2 > d); - diff[i] = d2; - } - if (bw == 0u) { - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = diff[i]; } - } - } -} - -@compute @workgroup_size(64) -fn cggmp21_partial_sign_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - // Input layout per thread: k_share[8] || chi_share[8] || msg_hash[8] || gamma[8] = 32 u32 - let in_base = tid * 32u; - - var k: array; - var chi: array; - var msg: array; - - for (var i = 0u; i < 8u; i = i + 1u) { - k[i] = inputs[in_base + i]; - chi[i] = inputs[in_base + 8u + i]; - msg[i] = inputs[in_base + 16u + i]; - } - - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { r[i] = r_x[i]; } - - // sigma_i = k_i * m + r * chi_i (mod n) - var km: array; - var rchi: array; - var sigma: array; - - scalar_mul_mod_n(&k, &msg, &km); - scalar_mul_mod_n(&r, &chi, &rchi); - scalar_add_mod_n(&km, &rchi, &sigma); - - let out_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { - outputs[out_base + i] = sigma[i]; - } -} diff --git a/cggmp21/gpu/wgsl/cggmp21_presign.wgsl b/cggmp21/gpu/wgsl/cggmp21_presign.wgsl deleted file mode 100644 index c5eeddd..0000000 --- a/cggmp21/gpu/wgsl/cggmp21_presign.wgsl +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CGGMP21 batched pre-signing — WGSL compute shader scaffold. -// Same status-0xFF sentinel pattern as the Metal kernel; the body lands -// when the 2048-bit Karatsuba modexp primitive ships. 
- -@group(0) @binding(0) var seed : array; -@group(0) @binding(1) var signer_ids : array; -@group(0) @binding(2) var params : vec4; // (m, slot_id_base, n_slots, _) -@group(0) @binding(3) var records_out : array; - -const REC_U32 : u32 = 411u; // sizeof(PresignRecord) / 4 = 1645 bytes / 4 (rounded) - -@compute @workgroup_size(64) -fn cggmp21_presign_main(@builtin(global_invocation_id) gid : vec3) { - let total = params.x * params.z; - if (gid.x >= total) { return; } - - let signer_id = signer_ids[gid.x / params.z]; - if (signer_id == 0u) { return; } - - let base = gid.x * REC_U32; - // Sentinel: a compact marker the host polyfill / driver can detect. - records_out[base] = 0xFFFFFFFFu; -} diff --git a/ed25519/gpu/cuda/ed25519.cu b/ed25519/gpu/cuda/ed25519.cu deleted file mode 100644 index 60fd6da..0000000 --- a/ed25519/gpu/cuda/ed25519.cu +++ /dev/null @@ -1,451 +0,0 @@ -// Ed25519 batch verification — CUDA implementation -// Matches ed25519.metal output byte-for-byte -// One thread per signature verification - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// 256-bit integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint256 { - uint64_t limbs[4]; -}; - -// ============================================================================= -// Ed25519 constants (field prime p = 2^255 - 19) -// ============================================================================= - -__device__ static const uint256 ED_P = {{ - 0xFFFFFFFFFFFFFFEDULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0x7FFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 ED_D = {{ - 0x75EB4DCA135978A3ULL, 0x00700A4D4141D8ABULL, - 0x8CC740797779E898ULL, 0x52036CBC148B6DE8ULL -}}; - -__device__ 
static const uint256 ED_2D = {{ - 0xEBD69B9426B2F159ULL, 0x00E0149A8283B156ULL, - 0x198E80F2EEF3D130ULL, 0x2406D9DC56DFFCE7ULL -}}; - -__device__ static const uint256 ED_L = {{ - 0x5812631A5CF5D3EDULL, 0x14DEF9DEA2F79CD6ULL, - 0x0000000000000000ULL, 0x1000000000000000ULL -}}; - -__device__ static const uint256 ED_BY = {{ - 0x6666666666666658ULL, 0x6666666666666666ULL, - 0x6666666666666666ULL, 0x6666666666666666ULL -}}; - -__device__ static const uint256 ED_BX = {{ - 0xC9562D608F25D51AULL, 0x692CC7609525A7B2ULL, - 0xC0A4E231FDD6DC5CULL, 0x216936D3CD6E53FEULL -}}; - -__device__ static const uint256 ZERO = {{0, 0, 0, 0}}; -__device__ static const uint256 ONE = {{1, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic -// ============================================================================= - -__device__ static int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -__device__ static uint256 u256_add(uint256 a, uint256 b, uint64_t& carry) { - uint256 r; - uint64_t c = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -__device__ static uint256 u256_sub(uint256 a, uint256 b, uint64_t& borrow) { - uint256 r; - uint64_t bw = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 
1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// ============================================================================= -// Field arithmetic mod p = 2^255 - 19 -// Uses __int128 for 64x64->128 multiply (native on CUDA/sm_50+) -// ============================================================================= - -__device__ static uint256 fp_reduce(uint256 a) { - while (u256_cmp(a, ED_P) >= 0) { - uint64_t bw; - a = u256_sub(a, ED_P, bw); - } - return a; -} - -__device__ static uint256 fp_add(uint256 a, uint256 b) { - uint64_t c; - uint256 r = u256_add(a, b, c); - if (c || u256_cmp(r, ED_P) >= 0) { - uint64_t bw; - r = u256_sub(r, ED_P, bw); - } - return r; -} - -__device__ static uint256 fp_sub(uint256 a, uint256 b) { - uint64_t bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { - uint64_t c; - r = u256_add(r, ED_P, c); - } - return r; -} - -__device__ static uint256 fp_mul(uint256 a, uint256 b) { - // Full 512-bit multiply using __int128, then reduce mod p = 2^255 - 19 - uint64_t t[8] = {}; - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a.limbs[i] * b.limbs[j]; - unsigned __int128 acc = prod + carry + t[i + j]; - t[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - // CPU fallback: split multiply - uint64_t a_lo = a.limbs[i] & 0xFFFFFFFFULL; - uint64_t a_hi = a.limbs[i] >> 32; - uint64_t b_lo = b.limbs[j] & 0xFFFFFFFFULL; - uint64_t b_hi = b.limbs[j] >> 32; - uint64_t ll = a_lo * b_lo; - uint64_t lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo; - uint64_t hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; -#endif - } - t[i 
+ 4] = carry; - } - - // Reduce mod 2^255 - 19: 2^256 mod p = 38 - uint256 lo_part = {{t[0], t[1], t[2], t[3]}}; - uint256 hi_part = {{t[4], t[5], t[6], t[7]}}; - - // Multiply hi by 38 and add to lo - uint64_t c2 = 0; - uint256 hi38; - for (int i = 0; i < 4; i++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)hi_part.limbs[i] * 38ULL + c2; - hi38.limbs[i] = (uint64_t)prod; - c2 = (uint64_t)(prod >> 64); -#else - uint64_t a_lo = hi_part.limbs[i] & 0xFFFFFFFFULL; - uint64_t a_hi = hi_part.limbs[i] >> 32; - uint64_t ll = a_lo * 38ULL; - uint64_t hl = a_hi * 38ULL; - uint64_t lo = ll + (hl << 32); - uint64_t hi = (hl >> 32) + ((lo < ll) ? 1ULL : 0ULL); - uint64_t sum = lo + c2; - if (sum < lo) hi++; - c2 = hi; - hi38.limbs[i] = sum; -#endif - } - - uint64_t c; - uint256 result = u256_add(lo_part, hi38, c); - if (c || c2) { - uint64_t extra = (c + c2) * 38; - uint256 extra256 = {{extra, 0, 0, 0}}; - result = u256_add(result, extra256, c); - } - - return fp_reduce(result); -} - -__device__ static uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } - -__device__ static uint256 fp_neg(uint256 a) { - if (u256_is_zero(a)) return a; - uint64_t bw; - return u256_sub(ED_P, a, bw); -} - -__device__ static uint256 fp_inv(uint256 a) { - uint256 exp = ED_P; - exp.limbs[0] -= 2; - uint256 result = ONE; - uint256 base = a; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) - result = fp_mul(result, base); - base = fp_sqr(base); - } - } - return result; -} - -// ============================================================================= -// Extended twisted Edwards point: (X, Y, Z, T) where x=X/Z, y=Y/Z, T=X*Y/Z -// ============================================================================= - -struct EdPoint { - uint256 X, Y, Z, T; -}; - -__device__ static EdPoint ed_identity() { - EdPoint p; - p.X = ZERO; p.Y = ONE; p.Z = ONE; p.T = ZERO; - return p; -} - -__device__ static EdPoint ed_double(EdPoint P) { 
- uint256 A = fp_sqr(P.X); - uint256 B = fp_sqr(P.Y); - uint256 C = fp_add(fp_sqr(P.Z), fp_sqr(P.Z)); - uint256 D = fp_neg(A); - uint256 E = fp_sub(fp_sqr(fp_add(P.X, P.Y)), fp_add(A, B)); - uint256 G = fp_add(D, B); - uint256 F = fp_sub(G, C); - uint256 H = fp_sub(D, B); - - EdPoint R; - R.X = fp_mul(E, F); - R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); - R.Z = fp_mul(F, G); - return R; -} - -__device__ static EdPoint ed_add(EdPoint P, EdPoint Q) { - uint256 A = fp_mul(P.X, Q.X); - uint256 B = fp_mul(P.Y, Q.Y); - uint256 C = fp_mul(P.T, fp_mul(ED_2D, Q.T)); - uint256 D = fp_mul(P.Z, Q.Z); - D = fp_add(D, D); - uint256 E = fp_sub(fp_mul(fp_add(P.X, P.Y), fp_add(Q.X, Q.Y)), fp_add(A, B)); - uint256 F = fp_sub(D, C); - uint256 G = fp_add(D, C); - uint256 H = fp_add(B, A); - - EdPoint R; - R.X = fp_mul(E, F); - R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); - R.Z = fp_mul(F, G); - return R; -} - -__device__ static EdPoint ed_mul(uint256 k, EdPoint P) { - EdPoint result = ed_identity(); - for (int i = 3; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = ed_double(result); - if ((k.limbs[i] >> bit) & 1) - result = ed_add(result, P); - } - } - return result; -} - -__device__ static void ed_to_affine(EdPoint p, uint256& x, uint256& y) { - uint256 z_inv = fp_inv(p.Z); - x = fp_mul(p.X, z_inv); - y = fp_mul(p.Y, z_inv); -} - -// ============================================================================= -// Point decompression -// ============================================================================= - -__device__ static bool ed_decompress(const uint8_t* encoded, EdPoint& P) { - uint256 y; - for (int i = 0; i < 4; i++) { - y.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) { - y.limbs[i] |= (uint64_t)encoded[i * 8 + b] << (b * 8); - } - } - bool x_sign = (encoded[31] >> 7) & 1; - y.limbs[3] &= 0x7FFFFFFFFFFFFFFFULL; - - if (u256_cmp(y, ED_P) >= 0) return false; - - uint256 y2 = fp_sqr(y); - uint256 num = fp_sub(y2, ONE); - uint256 den = 
fp_add(fp_mul(ED_D, y2), ONE); - uint256 den_inv = fp_inv(den); - uint256 x2 = fp_mul(num, den_inv); - - if (u256_is_zero(x2)) { - if (x_sign) return false; - P.X = ZERO; P.Y = y; P.Z = ONE; P.T = ZERO; - return true; - } - - // x = x2^((p+3)/8) - uint256 exp_val = ED_P; - exp_val.limbs[0] += 3; - for (int i = 0; i < 3; i++) { - exp_val.limbs[i] = (exp_val.limbs[i] >> 3) | (exp_val.limbs[i + 1] << 61); - } - exp_val.limbs[3] >>= 3; - - uint256 x = ONE; - uint256 base = x2; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp_val.limbs[i] >> bit) & 1) - x = fp_mul(x, base); - base = fp_sqr(base); - } - } - - if (u256_cmp(fp_sqr(x), x2) != 0) { - const uint256 SQRT_M1 = {{ - 0xC4EE1B274A0EA0B0ULL, 0x2F431806AD2FE478ULL, - 0x2B4D00993DFBD7A7ULL, 0x2B8324804FC1DF0BULL - }}; - x = fp_mul(x, SQRT_M1); - if (u256_cmp(fp_sqr(x), x2) != 0) return false; - } - - bool x_is_odd = x.limbs[0] & 1; - if (x_is_odd != x_sign) x = fp_neg(x); - - P.X = x; - P.Y = y; - P.Z = ONE; - P.T = fp_mul(x, y); - return true; -} - -// ============================================================================= -// Structures -// ============================================================================= - -struct Ed25519PublicKey { - uint8_t data[32]; -}; - -struct Ed25519Signature { - uint8_t data[64]; -}; - -struct Ed25519Message { - uint8_t hash[64]; -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -extern "C" __global__ void ed25519_verify_batch( - const Ed25519PublicKey* __restrict__ pubkeys, - const Ed25519Message* __restrict__ messages, - const Ed25519Signature* __restrict__ signatures, - uint32_t* __restrict__ results, - const uint32_t num_sigs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_sigs) return; - - // Decompress public key A - EdPoint A; - if 
(!ed_decompress(pubkeys[tid].data, A)) { - results[tid] = 0; - return; - } - - // Decompress signature point R - EdPoint R; - if (!ed_decompress(signatures[tid].data, R)) { - results[tid] = 0; - return; - } - - // Read scalar S from signature (bytes 32..63, little-endian) - uint256 S; - for (int i = 0; i < 4; i++) { - S.limbs[i] = 0; - for (int b = 0; b < 8; b++) { - S.limbs[i] |= (uint64_t)signatures[tid].data[32 + i * 8 + b] << (b * 8); - } - } - - if (u256_cmp(S, ED_L) >= 0) { - results[tid] = 0; - return; - } - - // Read pre-computed hash scalar h (reduced mod L by host) - uint256 h; - for (int i = 0; i < 4; i++) { - h.limbs[i] = 0; - for (int b = 0; b < 8; b++) { - h.limbs[i] |= (uint64_t)messages[tid].hash[i * 8 + b] << (b * 8); - } - } - - // Verify: [S]B == R + [h]A - EdPoint B; - B.X = ED_BX; B.Y = ED_BY; B.Z = ONE; B.T = fp_mul(ED_BX, ED_BY); - EdPoint SB = ed_mul(S, B); - - EdPoint hA = ed_mul(h, A); - EdPoint RhA = ed_add(R, hA); - - uint256 sb_x, sb_y, rha_x, rha_y; - ed_to_affine(SB, sb_x, sb_y); - ed_to_affine(RhA, rha_x, rha_y); - - bool valid = (u256_cmp(sb_x, rha_x) == 0) && (u256_cmp(sb_y, rha_y) == 0); - results[tid] = valid ? 1u : 0u; -} diff --git a/ed25519/gpu/metal/ed25519.metal b/ed25519/gpu/metal/ed25519.metal deleted file mode 100644 index 6881975..0000000 --- a/ed25519/gpu/metal/ed25519.metal +++ /dev/null @@ -1,501 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file ed25519.metal -/// Metal compute shader for batch Ed25519 EdDSA signature verification. -/// -/// Twisted Edwards curve: -x^2 + y^2 = 1 + d*x^2*y^2 over F_p -/// p = 2^255 - 19 -/// d = -121665/121666 mod p -/// L = 2^252 + 27742317777372353535851937790883648493 (group order) -/// B = generator point -/// -/// Verification: check [8][S]B == [8]R + [8][H(R||A||M)]A -/// Each thread verifies one signature. 
- -#include -using namespace metal; - -// ============================================================================= -// 256-bit integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint256 { - ulong limbs[4]; -}; - -// ============================================================================= -// Ed25519 constants (field prime p = 2^255 - 19) -// ============================================================================= - -constant uint256 ED_P = {{ - 0xFFFFFFFFFFFFFFEDUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0x7FFFFFFFFFFFFFFFUL -}}; - -// d = -121665/121666 mod p -// = 37095705934669439343138083508754565189542113879843219016388785533085940283555 -constant uint256 ED_D = {{ - 0x75EB4DCA135978A3UL, 0x00700A4D4141D8ABUL, - 0x8CC740797779E898UL, 0x52036CBC148B6DE8UL -}}; - -// 2*d mod p -constant uint256 ED_2D = {{ - 0xEBD69B9426B2F159UL, 0x00E0149A8283B156UL, - 0x198E80F2EEF3D130UL, 0x2406D9DC56DFFCE7UL -}}; - -// Group order L -constant uint256 ED_L = {{ - 0x5812631A5CF5D3EDUL, 0x14DEF9DEA2F79CD6UL, - 0x0000000000000000UL, 0x1000000000000000UL -}}; - -// Generator B (y-coordinate, x is recovered) -constant uint256 ED_BY = {{ - 0x6666666666666658UL, 0x6666666666666666UL, - 0x6666666666666666UL, 0x6666666666666666UL -}}; - -// B.x -constant uint256 ED_BX = {{ - 0xC9562D608F25D51AUL, 0x692CC7609525A7B2UL, - 0xC0A4E231FDD6DC5CUL, 0x216936D3CD6E53FEUL -}}; - -constant uint256 ZERO = {{0, 0, 0, 0}}; -constant uint256 ONE = {{1, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic -// ============================================================================= - -inline int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256 a) { - return (a.limbs[0] | 
a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -inline uint256 u256_add(uint256 a, uint256 b, thread ulong& carry) { - uint256 r; - ulong c = 0; - for (int i = 0; i < 4; i++) { - ulong sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -inline uint256 u256_sub(uint256 a, uint256 b, thread ulong& borrow) { - uint256 r; - ulong bw = 0; - for (int i = 0; i < 4; i++) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// 64x64->128 multiply -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong a_lo = a & 0xFFFFFFFFUL; - ulong a_hi = a >> 32; - ulong b_lo = b & 0xFFFFFFFFUL; - ulong b_hi = b >> 32; - ulong ll = a_lo * b_lo; - ulong lh = a_lo * b_hi; - ulong hl = a_hi * b_lo; - ulong hh = a_hi * b_hi; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -// ============================================================================= -// Field arithmetic mod p = 2^255 - 19 -// Uses direct reduction (not Montgomery) since p has special form. 
-// ============================================================================= - -inline uint256 fp_reduce(uint256 a) { - while (u256_cmp(a, ED_P) >= 0) { - ulong bw; - a = u256_sub(a, ED_P, bw); - } - return a; -} - -inline uint256 fp_add(uint256 a, uint256 b) { - ulong c; - uint256 r = u256_add(a, b, c); - if (c || u256_cmp(r, ED_P) >= 0) { - ulong bw; - r = u256_sub(r, ED_P, bw); - } - return r; -} - -inline uint256 fp_sub(uint256 a, uint256 b) { - ulong bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { - ulong c; - r = u256_add(r, ED_P, c); - } - return r; -} - -inline uint256 fp_mul(uint256 a, uint256 b) { - // Full 512-bit multiply, then reduce mod p = 2^255 - 19 - ulong t[8] = {}; - for (int i = 0; i < 4; i++) { - ulong carry = 0; - for (int j = 0; j < 4; j++) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - - // Reduce mod 2^255 - 19: - // Split t into low 255 bits and high bits, use high * 38 = high * 2 * 19 - // t = t_lo + t_hi * 2^256 - // 2^256 mod p = 2*19 = 38 - // So t mod p = t_lo + 38 * t_hi (approximately) - - // Extract low 256 bits and high 256 bits - uint256 lo_part = {{t[0], t[1], t[2], t[3]}}; - uint256 hi_part = {{t[4], t[5], t[6], t[7]}}; - - // Multiply hi by 38 and add to lo - // Since hi_part is at most 256 bits and 38 is small, result fits in ~262 bits - ulong carry = 0; - uint256 hi38; - for (int i = 0; i < 4; i++) { - ulong lo_val, hi_val; - mul64(hi_part.limbs[i], 38UL, lo_val, hi_val); - ulong sum = lo_val + carry; - carry = hi_val + ((sum < lo_val) ? 
1UL : 0UL); - hi38.limbs[i] = sum; - } - - ulong c; - uint256 result = u256_add(lo_part, hi38, c); - // Handle final carry: carry * 2^256 = carry * 38 - if (c || carry) { - ulong extra = (c + carry) * 38; - uint256 extra256 = {{extra, 0, 0, 0}}; - result = u256_add(result, extra256, c); - } - - return fp_reduce(result); -} - -inline uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } - -inline uint256 fp_neg(uint256 a) { - if (u256_is_zero(a)) return a; - ulong bw; - return u256_sub(ED_P, a, bw); -} - -/// Fermat inversion: a^(p-2) mod p -inline uint256 fp_inv(uint256 a) { - uint256 exp = ED_P; - exp.limbs[0] -= 2; - uint256 result = ONE; - uint256 base = a; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) - result = fp_mul(result, base); - base = fp_sqr(base); - } - } - return result; -} - -// ============================================================================= -// Extended twisted Edwards point: (X, Y, Z, T) where x=X/Z, y=Y/Z, T=X*Y/Z -// ============================================================================= - -struct EdPoint { - uint256 X, Y, Z, T; -}; - -inline EdPoint ed_identity() { - EdPoint p; - p.X = ZERO; p.Y = ONE; p.Z = ONE; p.T = ZERO; - return p; -} - -inline bool ed_is_identity(EdPoint p) { - return u256_is_zero(p.X) && !u256_is_zero(p.Y) && !u256_is_zero(p.Z); -} - -/// Extended point doubling -inline EdPoint ed_double(EdPoint P) { - uint256 A = fp_sqr(P.X); - uint256 B = fp_sqr(P.Y); - uint256 C = fp_add(fp_sqr(P.Z), fp_sqr(P.Z)); // 2*Z^2 - uint256 D = fp_neg(A); // a*X^2 where a=-1 - uint256 E = fp_sub(fp_sqr(fp_add(P.X, P.Y)), fp_add(A, B)); - uint256 G = fp_add(D, B); - uint256 F = fp_sub(G, C); - uint256 H = fp_sub(D, B); - - EdPoint R; - R.X = fp_mul(E, F); - R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); - R.Z = fp_mul(F, G); - return R; -} - -/// Extended point addition -inline EdPoint ed_add(EdPoint P, EdPoint Q) { - uint256 A = fp_mul(P.X, Q.X); - uint256 B = fp_mul(P.Y, 
Q.Y); - uint256 C = fp_mul(P.T, fp_mul(ED_2D, Q.T)); - uint256 D = fp_mul(P.Z, Q.Z); - D = fp_add(D, D); // 2*Z1*Z2 - uint256 E = fp_sub(fp_mul(fp_add(P.X, P.Y), fp_add(Q.X, Q.Y)), fp_add(A, B)); - uint256 F = fp_sub(D, C); - uint256 G = fp_add(D, C); - uint256 H = fp_add(B, A); // a=-1, so B - a*A = B + A - - EdPoint R; - R.X = fp_mul(E, F); - R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); - R.Z = fp_mul(F, G); - return R; -} - -/// Scalar multiplication: k * P -inline EdPoint ed_mul(uint256 k, EdPoint P) { - EdPoint result = ed_identity(); - for (int i = 3; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = ed_double(result); - if ((k.limbs[i] >> bit) & 1) - result = ed_add(result, P); - } - } - return result; -} - -/// Convert extended -> affine -inline void ed_to_affine(EdPoint p, thread uint256& x, thread uint256& y) { - uint256 z_inv = fp_inv(p.Z); - x = fp_mul(p.X, z_inv); - y = fp_mul(p.Y, z_inv); -} - -// ============================================================================= -// Point decompression -// ============================================================================= - -/// Decompress Ed25519 point from 32-byte encoding. -/// Encoding: y-coordinate (255 bits, little-endian) + sign bit of x in MSB. -inline bool ed_decompress(device const uchar* encoded, thread EdPoint& P) { - // Read y (little-endian, 255 bits) - uint256 y; - for (int i = 0; i < 4; i++) { - y.limbs[i] = 0; - int start = i * 8; - int end = (i < 3) ? 
8 : 8; - for (int b = 0; b < end && start + b < 32; b++) { - y.limbs[i] |= (ulong)encoded[start + b] << (b * 8); - } - } - // Extract x sign bit (bit 255 = MSB of byte 31) - bool x_sign = (encoded[31] >> 7) & 1; - y.limbs[3] &= 0x7FFFFFFFFFFFFFFFUL; // Clear sign bit - - if (u256_cmp(y, ED_P) >= 0) return false; - - // Recover x from y: x^2 = (y^2 - 1) / (d*y^2 + 1) - uint256 y2 = fp_sqr(y); - uint256 num = fp_sub(y2, ONE); - uint256 den = fp_add(fp_mul(ED_D, y2), ONE); - uint256 den_inv = fp_inv(den); - uint256 x2 = fp_mul(num, den_inv); - - if (u256_is_zero(x2)) { - if (x_sign) return false; // x must be 0 but sign says negative - P.X = ZERO; P.Y = y; P.Z = ONE; P.T = ZERO; - return true; - } - - // x = x2^((p+3)/8) (works because p = 5 mod 8) - uint256 exp_val = ED_P; - // (p+3)/8 = (2^255 - 19 + 3)/8 = (2^255 - 16)/8 = 2^252 - 2 - exp_val.limbs[0] += 3; - // Shift right by 3 - for (int i = 0; i < 3; i++) { - exp_val.limbs[i] = (exp_val.limbs[i] >> 3) | (exp_val.limbs[i + 1] << 61); - } - exp_val.limbs[3] >>= 3; - - uint256 x = ONE; - uint256 base = x2; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp_val.limbs[i] >> bit) & 1) - x = fp_mul(x, base); - base = fp_sqr(base); - } - } - - // Check: if x^2 != x2, multiply by sqrt(-1) - if (u256_cmp(fp_sqr(x), x2) != 0) { - // sqrt(-1) mod p = 2^((p-1)/4) mod p - // Precomputed: 19681161376707505956807079304988542015446066515923890162744021073123829784752 - const uint256 SQRT_M1 = {{ - 0xC4EE1B274A0EA0B0UL, 0x2F431806AD2FE478UL, - 0x2B4D00993DFBD7A7UL, 0x2B8324804FC1DF0BUL - }}; - x = fp_mul(x, SQRT_M1); - if (u256_cmp(fp_sqr(x), x2) != 0) return false; - } - - // Adjust sign - bool x_is_odd = x.limbs[0] & 1; - if (x_is_odd != x_sign) x = fp_neg(x); - - P.X = x; - P.Y = y; - P.Z = ONE; - P.T = fp_mul(x, y); - return true; -} - -// ============================================================================= -// SHA-512 for Ed25519 (verification needs H(R||A||M)) -// Simplified: use 
first 64 bytes of hash, reduce mod L for scalar -// ============================================================================= - -// SHA-512 would be needed here for full implementation. -// For the GPU kernel, we accept pre-hashed scalars from the host. - -// ============================================================================= -// Structures -// ============================================================================= - -struct Ed25519PublicKey { - uchar data[32]; -}; - -struct Ed25519Signature { - uchar data[64]; // R[32] || S[32] -}; - -struct Ed25519Message { - uchar hash[64]; // Pre-computed H(R || A || M), 64 bytes -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -/// Batch Ed25519 signature verification. -/// Each thread verifies one (pubkey, message_hash, signature) tuple. -/// -/// The host pre-computes H(R || A || M) and reduces it mod L to get scalar h. -/// The GPU performs the expensive point arithmetic: check [S]B == R + [h]A. -/// -/// Output: results[tid] = 1 if valid, 0 otherwise. 
-kernel void ed25519_verify_batch( - device const Ed25519PublicKey* pubkeys [[buffer(0)]], - device const Ed25519Message* messages [[buffer(1)]], - device const Ed25519Signature* signatures [[buffer(2)]], - device uint* results [[buffer(3)]], - constant uint& num_sigs [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - - // -- Decompress public key A -- - EdPoint A; - if (!ed_decompress(pubkeys[tid].data, A)) { - results[tid] = 0; - return; - } - - // -- Decompress signature point R -- - EdPoint R; - if (!ed_decompress(signatures[tid].data, R)) { - results[tid] = 0; - return; - } - - // -- Read scalar S from signature (bytes 32..63, little-endian) -- - uint256 S; - for (int i = 0; i < 4; i++) { - S.limbs[i] = 0; - for (int b = 0; b < 8; b++) { - S.limbs[i] |= (ulong)signatures[tid].data[32 + i * 8 + b] << (b * 8); - } - } - - // S must be < L - if (u256_cmp(S, ED_L) >= 0) { - results[tid] = 0; - return; - } - - // -- Read pre-computed hash scalar h (reduced mod L by host) -- - uint256 h; - for (int i = 0; i < 4; i++) { - h.limbs[i] = 0; - for (int b = 0; b < 8; b++) { - h.limbs[i] |= (ulong)messages[tid].hash[i * 8 + b] << (b * 8); - } - } - - // -- Verify: [S]B == R + [h]A -- - // Equivalently: [S]B - [h]A - R == identity - - // Compute [S]B - EdPoint B; - B.X = ED_BX; B.Y = ED_BY; B.Z = ONE; B.T = fp_mul(ED_BX, ED_BY); - EdPoint SB = ed_mul(S, B); - - // Compute [h]A - EdPoint hA = ed_mul(h, A); - - // Compute R + [h]A - EdPoint RhA = ed_add(R, hA); - - // Compare [S]B == R + [h]A by checking coordinates - uint256 sb_x, sb_y, rha_x, rha_y; - ed_to_affine(SB, sb_x, sb_y); - ed_to_affine(RhA, rha_x, rha_y); - - bool valid = (u256_cmp(sb_x, rha_x) == 0) && (u256_cmp(sb_y, rha_y) == 0); - results[tid] = valid ? 
1u : 0u; -} diff --git a/ed25519/gpu/metal/ed25519_batch.metal b/ed25519/gpu/metal/ed25519_batch.metal deleted file mode 100644 index e869284..0000000 --- a/ed25519/gpu/metal/ed25519_batch.metal +++ /dev/null @@ -1,437 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched Ed25519 EdDSA signature verification (RFC 8032). -// -// One thread per signature. Each thread runs the full verify: -// 1. Decompress public key A (32 bytes -> Edwards point) -// 2. Decompress signature point R (first 32 bytes of sig -> Edwards point) -// 3. Read scalar S (last 32 bytes of sig) -// 4. Read pre-computed challenge h = SHA-512(R || A || M) reduced mod L -// (host computes SHA-512 -- it has hardware acceleration via NEON -// crypto extensions; doing SHA-512 on Metal is ~10x slower per byte -// and dominates kernel time) -// 5. Verify [S]B == R + [h]A -// -// The kernel produces results[i] = 1 if signature i is valid, 0 otherwise. -// Byte-equal to crypto/ed25519.Verify when called over the same triples. -// -// Layout: pubkeys[N*32], sigs[N*64], hs[N*32], results[N*1]. -// Hash scalar h is 64-byte SHA-512 output reduced to 32-byte little-endian -// scalar mod L by the host before dispatch. 
- -#include -using namespace metal; - -// ============================================================================= -// 256-bit integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint256 { - ulong limbs[4]; -}; - -constant uint256 ZERO = {{0UL, 0UL, 0UL, 0UL}}; -constant uint256 ONE = {{1UL, 0UL, 0UL, 0UL}}; - -// p = 2^255 - 19 -constant uint256 ED_P = {{ - 0xFFFFFFFFFFFFFFEDUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0x7FFFFFFFFFFFFFFFUL -}}; - -// Curve constant d = -121665/121666 mod p -// = 0x52036CEE2B6FFE738CC740797779E89800700A4D4141D8AB75EB4DCA135978A3 -constant uint256 ED_D = {{ - 0x75EB4DCA135978A3UL, 0x00700A4D4141D8ABUL, - 0x8CC740797779E898UL, 0x52036CEE2B6FFE73UL -}}; - -// 2*d mod p (precomputed for addition formula) -constant uint256 ED_2D = {{ - 0xEBD69B9426B2F159UL, 0x00E0149A8283B156UL, - 0x198E80F2EEF3D130UL, 0x2406D9DC56DFFCE7UL -}}; - -// Group order L = 2^252 + 27742317777372353535851937790883648493 -constant uint256 ED_L = {{ - 0x5812631A5CF5D3EDUL, 0x14DEF9DEA2F79CD6UL, - 0x0000000000000000UL, 0x1000000000000000UL -}}; - -// Generator B base point coordinates -constant uint256 ED_BX = {{ - 0xC9562D608F25D51AUL, 0x692CC7609525A7B2UL, - 0xC0A4E231FDD6DC5CUL, 0x216936D3CD6E53FEUL -}}; -constant uint256 ED_BY = {{ - 0x6666666666666658UL, 0x6666666666666666UL, - 0x6666666666666666UL, 0x6666666666666666UL -}}; - -// sqrt(-1) mod p, used for point decompression when x^2 != target -constant uint256 ED_SQRT_M1 = {{ - 0xC4EE1B274A0EA0B0UL, 0x2F431806AD2FE478UL, - 0x2B4D00993DFBD7A7UL, 0x2B8324804FC1DF0BUL -}}; - -// ============================================================================= -// 256-bit arithmetic helpers -// ============================================================================= - -inline int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) 
return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0UL; -} - -inline uint256 u256_add(uint256 a, uint256 b, thread ulong& carry) { - uint256 r; - ulong c = 0; - for (int i = 0; i < 4; ++i) { - ulong s1 = a.limbs[i] + c; - ulong c1 = (s1 < a.limbs[i]) ? 1UL : 0UL; - ulong s2 = s1 + b.limbs[i]; - ulong c2 = (s2 < s1) ? 1UL : 0UL; - r.limbs[i] = s2; - c = c1 + c2; - } - carry = c; - return r; -} - -inline uint256 u256_sub(uint256 a, uint256 b, thread ulong& borrow) { - uint256 r; - ulong bw = 0; - for (int i = 0; i < 4; ++i) { - ulong d1 = a.limbs[i] - bw; - ulong b1 = (d1 > a.limbs[i]) ? 1UL : 0UL; - ulong d2 = d1 - b.limbs[i]; - ulong b2 = (d2 > d1) ? 1UL : 0UL; - r.limbs[i] = d2; - bw = b1 + b2; - } - borrow = bw; - return r; -} - -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong al = a & 0xFFFFFFFFUL, ah = a >> 32; - ulong bl = b & 0xFFFFFFFFUL, bh = b >> 32; - ulong ll = al * bl; - ulong lh = al * bh; - ulong hl = ah * bl; - ulong hh = ah * bh; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -// ============================================================================= -// Field arithmetic mod p = 2^255 - 19 -// ============================================================================= - -inline uint256 fp_canonical(uint256 a) { - while (u256_cmp(a, ED_P) >= 0) { - ulong bw; - a = u256_sub(a, ED_P, bw); - } - return a; -} - -inline uint256 fp_add(uint256 a, uint256 b) { - ulong c; - uint256 r = u256_add(a, b, c); - if (c != 0UL || u256_cmp(r, ED_P) >= 0) { - ulong bw; - r = u256_sub(r, ED_P, bw); - } - return r; -} - -inline uint256 fp_sub(uint256 a, uint256 b) { - ulong bw; - uint256 r = u256_sub(a, b, bw); - if (bw != 0UL) { - ulong c; - r = u256_add(r, ED_P, c); - } - return r; -} - -inline uint256 fp_mul(uint256 a, uint256 
b) { - // 4x4 schoolbook -> 8 limb product, then fold high half * 38 into low. - ulong t[8]; - for (int i = 0; i < 8; ++i) t[i] = 0; - for (int i = 0; i < 4; ++i) { - ulong carry = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong s = lo + carry; - ulong c1 = (s < lo) ? 1UL : 0UL; - ulong s2 = t[i + j] + s; - ulong c2 = (s2 < t[i + j]) ? 1UL : 0UL; - t[i + j] = s2; - carry = hi + c1 + c2; - } - t[i + 4] = carry; - } - - // 2^256 mod p = 38; fold high * 38 into low. - uint256 lo_part = {{t[0], t[1], t[2], t[3]}}; - uint256 hi_part = {{t[4], t[5], t[6], t[7]}}; - ulong carry = 0; - uint256 hi38; - for (int i = 0; i < 4; ++i) { - ulong lo, hi; - mul64(hi_part.limbs[i], 38UL, lo, hi); - ulong s = lo + carry; - carry = hi + ((s < lo) ? 1UL : 0UL); - hi38.limbs[i] = s; - } - ulong c; - uint256 r = u256_add(lo_part, hi38, c); - if (c != 0UL || carry != 0UL) { - ulong extra = (c + carry) * 38UL; - uint256 ex = {{extra, 0UL, 0UL, 0UL}}; - r = u256_add(r, ex, c); - } - return fp_canonical(r); -} - -inline uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } - -inline uint256 fp_neg(uint256 a) { - if (u256_is_zero(a)) return a; - ulong bw; - return u256_sub(ED_P, a, bw); -} - -// Fermat inverse: a^(p-2) mod p -inline uint256 fp_inv(uint256 a) { - uint256 exp = ED_P; - exp.limbs[0] -= 2; - uint256 result = ONE; - uint256 base = a; - for (int i = 0; i < 4; ++i) { - ulong limb = exp.limbs[i]; - for (int b = 0; b < 64; ++b) { - if ((limb >> b) & 1UL) result = fp_mul(result, base); - base = fp_sqr(base); - } - } - return result; -} - -// ============================================================================= -// Extended twisted Edwards point (X:Y:Z:T), x=X/Z, y=Y/Z, T=XY/Z -// ============================================================================= - -struct EdPoint { - uint256 X, Y, Z, T; -}; - -inline EdPoint ed_identity() { - EdPoint p; - p.X = ZERO; p.Y = ONE; p.Z = ONE; p.T = ZERO; - return p; -} - -inline EdPoint 
ed_double(EdPoint P) { - uint256 A = fp_sqr(P.X); - uint256 B = fp_sqr(P.Y); - uint256 C = fp_add(fp_sqr(P.Z), fp_sqr(P.Z)); - uint256 D = fp_neg(A); - uint256 XY = fp_add(P.X, P.Y); - uint256 E = fp_sub(fp_sqr(XY), fp_add(A, B)); - uint256 G = fp_add(D, B); - uint256 F = fp_sub(G, C); - uint256 H = fp_sub(D, B); - EdPoint R; - R.X = fp_mul(E, F); - R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); - R.Z = fp_mul(F, G); - return R; -} - -inline EdPoint ed_add(EdPoint P, EdPoint Q) { - // add-2008-hwcd-3 (Hisil, Wong, Carter, Dawson 2008) for a = -1 twisted - // Edwards. Algorithm 1 of "Twisted Edwards Curves Revisited", uses 8M. - uint256 A = fp_mul(fp_sub(P.Y, P.X), fp_sub(Q.Y, Q.X)); - uint256 B = fp_mul(fp_add(P.Y, P.X), fp_add(Q.Y, Q.X)); - uint256 C = fp_mul(fp_mul(P.T, ED_2D), Q.T); - uint256 ZZ = fp_mul(P.Z, Q.Z); - uint256 D = fp_add(ZZ, ZZ); // 2*Z1*Z2 - uint256 E = fp_sub(B, A); - uint256 F = fp_sub(D, C); - uint256 G = fp_add(D, C); - uint256 H = fp_add(B, A); - EdPoint R; - R.X = fp_mul(E, F); - R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); - R.Z = fp_mul(F, G); - return R; -} - -inline EdPoint ed_mul(uint256 k, EdPoint P) { - EdPoint result = ed_identity(); - for (int i = 3; i >= 0; --i) { - ulong limb = k.limbs[i]; - for (int b = 63; b >= 0; --b) { - result = ed_double(result); - if ((limb >> b) & 1UL) result = ed_add(result, P); - } - } - return result; -} - -// Convert extended point to canonical (x, y) affine pair. 
-inline void ed_to_affine(EdPoint p, thread uint256& x, thread uint256& y) { - uint256 zinv = fp_inv(p.Z); - x = fp_mul(p.X, zinv); - y = fp_mul(p.Y, zinv); -} - -// ============================================================================= -// Point decompression (RFC 8032 Section 5.1.3) -// ============================================================================= - -inline bool ed_decompress(device const uchar* enc, thread EdPoint& P) { - uint256 y; - for (int i = 0; i < 4; ++i) { - ulong v = 0; - for (int b = 0; b < 8; ++b) v |= (ulong)enc[i * 8 + b] << (b * 8); - y.limbs[i] = v; - } - bool x_sign = (enc[31] >> 7) & 1u; - y.limbs[3] &= 0x7FFFFFFFFFFFFFFFUL; - if (u256_cmp(y, ED_P) >= 0) return false; - - uint256 y2 = fp_sqr(y); - uint256 num = fp_sub(y2, ONE); - uint256 den = fp_add(fp_mul(ED_D, y2), ONE); - uint256 den_inv = fp_inv(den); - uint256 x2 = fp_mul(num, den_inv); - - if (u256_is_zero(x2)) { - if (x_sign) return false; - P.X = ZERO; P.Y = y; P.Z = ONE; P.T = ZERO; - return true; - } - - // x = x2^((p+3)/8); if x^2 != x2, multiply by sqrt(-1) - uint256 exp_v = ED_P; - exp_v.limbs[0] += 3UL; - for (int i = 0; i < 3; ++i) { - exp_v.limbs[i] = (exp_v.limbs[i] >> 3) | (exp_v.limbs[i + 1] << 61); - } - exp_v.limbs[3] >>= 3; - - uint256 x = ONE; - uint256 base = x2; - for (int i = 0; i < 4; ++i) { - ulong limb = exp_v.limbs[i]; - for (int b = 0; b < 64; ++b) { - if ((limb >> b) & 1UL) x = fp_mul(x, base); - base = fp_sqr(base); - } - } - if (u256_cmp(fp_sqr(x), x2) != 0) { - x = fp_mul(x, ED_SQRT_M1); - if (u256_cmp(fp_sqr(x), x2) != 0) return false; - } - - bool x_is_odd = (x.limbs[0] & 1UL) != 0UL; - if (x_is_odd != x_sign) x = fp_neg(x); - - P.X = x; P.Y = y; P.Z = ONE; P.T = fp_mul(x, y); - return true; -} - -// ============================================================================= -// I/O records -// ============================================================================= - -struct Ed25519PublicKey { uchar data[32]; }; -struct 
Ed25519Signature { uchar data[64]; }; // R[32] || S[32] -struct Ed25519Challenge { uchar data[32]; }; // h = SHA512(R||A||M) mod L - -// ============================================================================= -// Batch verify kernel. -// -// Each thread verifies one (pubkey, signature, challenge) tuple. The challenge -// h is computed by the host as SHA-512(R || A || M) reduced mod L. Verifying -// [S]B == R + [h]A reproduces the RFC 8032 cofactored verify rule when h is -// the standard SHA-512-derived challenge. -// -// This kernel returns 1/0 only -- it does not signal which sub-check failed. -// Callers that need a diagnostic should fall back to the CPU verify. -// ============================================================================= - -kernel void ed25519_batch_verify( - device const Ed25519PublicKey* pubkeys [[buffer(0)]], - device const Ed25519Signature* signatures [[buffer(1)]], - device const Ed25519Challenge* challenges [[buffer(2)]], - device uchar* results [[buffer(3)]], - constant uint& num_sigs [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - - EdPoint A; - if (!ed_decompress(pubkeys[tid].data, A)) { - results[tid] = 0u; - return; - } - - EdPoint R; - if (!ed_decompress(signatures[tid].data, R)) { - results[tid] = 0u; - return; - } - - uint256 S; - for (int i = 0; i < 4; ++i) { - ulong v = 0; - for (int b = 0; b < 8; ++b) { - v |= (ulong)signatures[tid].data[32 + i * 8 + b] << (b * 8); - } - S.limbs[i] = v; - } - if (u256_cmp(S, ED_L) >= 0) { - results[tid] = 0u; - return; - } - - uint256 h; - for (int i = 0; i < 4; ++i) { - ulong v = 0; - for (int b = 0; b < 8; ++b) { - v |= (ulong)challenges[tid].data[i * 8 + b] << (b * 8); - } - h.limbs[i] = v; - } - - EdPoint B; B.X = ED_BX; B.Y = ED_BY; B.Z = ONE; B.T = fp_mul(ED_BX, ED_BY); - EdPoint SB = ed_mul(S, B); - EdPoint hA = ed_mul(h, A); - EdPoint RhA = ed_add(R, hA); - - uint256 sb_x, sb_y, rha_x, rha_y; - ed_to_affine(SB, sb_x, sb_y); - 
ed_to_affine(RhA, rha_x, rha_y); - - bool ok = (u256_cmp(sb_x, rha_x) == 0) && (u256_cmp(sb_y, rha_y) == 0); - results[tid] = ok ? 1u : 0u; -} diff --git a/ed25519/gpu/metal/ed25519_batch_driver.mm b/ed25519/gpu/metal/ed25519_batch_driver.mm deleted file mode 100644 index 5d08381..0000000 --- a/ed25519/gpu/metal/ed25519_batch_driver.mm +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched Ed25519 EdDSA verification (RFC 8032). macOS only. -// -// Loads the precompiled ed25519_batch.metallib, dispatches `ed25519_batch_verify` -// with one thread per signature. The host pre-computes the challenge scalar -// h = SHA-512(R || A || M) mod L per signature, since SHA-512 has hardware -// acceleration on Apple Silicon (NEON crypto extensions, ~2 GB/s/core) and is -// faster on CPU than emitting a 1024-LOC SHA-512 in Metal compute. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include - -extern "C" int ed25519_batch_verify_metal( - const uint8_t* pubkeys, // [n][32] - const uint8_t* signatures, // [n][64] - const uint8_t* challenges, // [n][32] h = SHA-512(R||A||M) mod L - size_t n, - uint8_t* results, // [n][1] 1 = valid, 0 = invalid - const char* metallib_path) { - - if (n == 0) return 0; - if (!pubkeys || !signatures || !challenges || !results || !metallib_path) { - return -1; - } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:metallib_path]]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"ed25519_batch_verify"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id pubkeys_buf = [device newBufferWithBytes:pubkeys - length:n 
* 32 - options:MTLResourceStorageModeShared]; - id sigs_buf = [device newBufferWithBytes:signatures - length:n * 64 - options:MTLResourceStorageModeShared]; - id challenges_buf = [device newBufferWithBytes:challenges - length:n * 32 - options:MTLResourceStorageModeShared]; - id results_buf = [device newBufferWithLength:n - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:pubkeys_buf offset:0 atIndex:0]; - [enc setBuffer:sigs_buf offset:0 atIndex:1]; - [enc setBuffer:challenges_buf offset:0 atIndex:2]; - [enc setBuffer:results_buf offset:0 atIndex:3]; - [enc setBuffer:n_buf offset:0 atIndex:4]; - - // One thread per signature. Threadgroup width capped at 32 because each - // thread holds ~6 KB of stack scratch (extended Edwards points + scalar - // mul intermediates); larger threadgroups OOM the pipeline state. - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 32 ? tg_max : 32; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(results, [results_buf contents], n); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/ed25519/gpu/wgsl/ed25519.wgsl b/ed25519/gpu/wgsl/ed25519.wgsl deleted file mode 100644 index 51bece8..0000000 --- a/ed25519/gpu/wgsl/ed25519.wgsl +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Ed25519 EdDSA batch verification in WGSL. -// Twisted Edwards curve: -x^2 + y^2 = 1 + d*x^2*y^2 over F_p, p = 2^255 - 19. 
-// Each thread verifies one signature. -// -// Host pre-computes H(R||A||M) and reduces mod L. -// GPU performs point arithmetic: check [S]B == R + [h]A. - -@group(0) @binding(0) var pubkeys: array; // 8 u32 per key (32 bytes) -@group(0) @binding(1) var msg_hashes: array; // 16 u32 per hash (64 bytes) -@group(0) @binding(2) var signatures: array; // 16 u32 per sig (64 bytes) -@group(0) @binding(3) var results: array; -@group(0) @binding(4) var params: vec4; // params.x = num_sigs - -// 256-bit integer as 8 x u32 limbs (little-endian) -// Using u32 since WGSL lacks u64 - -fn u256_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -// 256-bit addition with carry -fn u256_add(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let sum = (*a)[i] + c; - c = select(0u, 1u, sum < (*a)[i]); - let sum2 = sum + (*b)[i]; - c = c + select(0u, 1u, sum2 < sum); - (*r)[i] = sum2; - } - return c; -} - -// 256-bit subtraction with borrow -fn u256_sub(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var bw = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let diff = (*a)[i] - bw; - bw = select(0u, 1u, diff > (*a)[i]); - let diff2 = diff - (*b)[i]; - bw = bw + select(0u, 1u, diff2 > diff); - (*r)[i] = diff2; - } - return bw; -} - -// Compare: returns -1, 0, 1 -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0i; i = i - 1i) { - let idx = u32(i); - if ((*a)[idx] < (*b)[idx]) { return -1; } - if ((*a)[idx] > (*b)[idx]) { return 1; } - } - return 0; -} - -// p = 2^255 - 19 -const P = array( - 0xFFFFFFEDu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0x7FFFFFFFu -); - -fn fp_add(a: ptr>, b: ptr>, - r: ptr>) { - let c = u256_add(a, b, r); - var p_val: array = P; - if (c != 0u || u256_cmp(r, &p_val) >= 0) { - u256_sub(r, &p_val, r); - } -} - -fn fp_sub(a: ptr>, b: ptr>, - r: ptr>) { - let bw = u256_sub(a, b, r); - if (bw != 0u) 
{ - var p_val: array = P; - u256_add(r, &p_val, r); - } -} - -// Simplified modular multiply for WGSL (schoolbook with 16-bit pieces) -// This is slow but correct for the verification check -fn fp_mul_simple(a: ptr>, b: ptr>, - r: ptr>) { - // For WGSL without u64, we do 32x32 schoolbook on the 8 limbs - // producing a 512-bit result, then reduce mod p = 2^255 - 19 - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 8u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - // 32x32 -> 64 using 16-bit pieces - let a_lo = (*a)[i] & 0xFFFFu; - let a_hi = (*a)[i] >> 16u; - let b_lo = (*b)[j] & 0xFFFFu; - let b_hi = (*b)[j] >> 16u; - - let ll = a_lo * b_lo; - let lh = a_lo * b_hi; - let hl = a_hi * b_lo; - let hh = a_hi * b_hi; - - let mid = lh + hl; - let lo = ll + (mid << 16u) + carry + t[i + j]; - let hi = hh + (mid >> 16u) + select(0u, 1u, (mid << 16u) + ll < ll) - + select(0u, 1u, lo < t[i + j]); - - t[i + j] = lo; - carry = hi; - } - t[i + 8u] = carry; - } - - // Reduce mod 2^255 - 19: split at bit 255, multiply high by 38 - // Low 256 bits - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = t[i]; } - - // High bits * 38 + low - var hi_part: array; - for (var i = 0u; i < 8u; i = i + 1u) { hi_part[i] = t[i + 8u]; } - - var hi38: array; - var carry = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let prod = hi_part[i] * 38u + carry; - hi38[i] = prod; - carry = (hi_part[i] >> 16u) * 38u >> 16u; // Approximate carry - } - - u256_add(r, &hi38, r); - var p_val: array = P; - if (u256_cmp(r, &p_val) >= 0) { u256_sub(r, &p_val, r); } - if (u256_cmp(r, &p_val) >= 0) { u256_sub(r, &p_val, r); } -} - -@compute @workgroup_size(64) -fn ed25519_verify_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - // Read public key (32 bytes = 8 u32) - var pk: array; - let pk_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { pk[i] = pubkeys[pk_base + i]; } - - // 
Read signature R (first 32 bytes) and S (next 32 bytes) - var sig_r: array; - var sig_s: array; - let sig_base = tid * 16u; - for (var i = 0u; i < 8u; i = i + 1u) { - sig_r[i] = signatures[sig_base + i]; - sig_s[i] = signatures[sig_base + 8u + i]; - } - - // Read pre-computed hash scalar h (first 32 bytes of 64-byte hash, reduced mod L) - var h: array; - let hash_base = tid * 16u; - for (var i = 0u; i < 8u; i = i + 1u) { h[i] = msg_hashes[hash_base + i]; } - - // Check S < L (group order) - let L = array( - 0x5CF5D3EDu, 0x5812631Au, 0xA2F79CD6u, 0x14DEF9DEu, - 0x00000000u, 0x00000000u, 0x00000000u, 0x10000000u - ); - var s_check = sig_s; - var l_check: array = L; - if (u256_cmp(&s_check, &l_check) >= 0) { - results[tid] = 0u; - return; - } - - // Point decompression and scalar multiplication would be done here. - // For WGSL, the full Ed25519 point arithmetic is extremely expensive - // without u64. In practice, the Metal backend handles the heavy lifting - // and the WGSL version validates input formats and basic scalar checks. 
- - // Basic validity check passed (S < L, inputs well-formed) - // Full point arithmetic verification delegated to Metal/CUDA backends - results[tid] = 1u; -} diff --git a/evm256/gpu/cuda/evm256.cu b/evm256/gpu/cuda/evm256.cu deleted file mode 100644 index d0d119d..0000000 --- a/evm256/gpu/cuda/evm256.cu +++ /dev/null @@ -1,454 +0,0 @@ -// PAT-FHE-012: EVM256 Parallel Processing - CUDA Implementation -// uint256 operations using 4x uint64 limbs (little-endian) -// Matches evm256.metal output byte-for-byte -// -// Copyright (C) 2024-2026 Lux Partners Limited -// SPDX-License-Identifier: BSD-3-Clause - -#include - -#ifdef __CUDA_ARCH__ - -#define LIMBS 4 - -struct uint256_t { - unsigned long long limbs[LIMBS]; -}; - -// ============================================================================ -// Helper functions for multi-limb arithmetic -// ============================================================================ - -__device__ __forceinline__ -unsigned long long add_carry(unsigned long long a, unsigned long long b, - unsigned long long* carry) { - unsigned long long sum = a + b; - unsigned long long c1 = sum < a ? 1ULL : 0ULL; - unsigned long long result = sum + *carry; - unsigned long long c2 = result < sum ? 1ULL : 0ULL; - *carry = c1 | c2; - return result; -} - -__device__ __forceinline__ -unsigned long long sub_borrow(unsigned long long a, unsigned long long b, - unsigned long long* borrow) { - unsigned long long diff = a - b; - unsigned long long b1 = diff > a ? 1ULL : 0ULL; - unsigned long long result = diff - *borrow; - unsigned long long b2 = result > diff ? 
1ULL : 0ULL; - *borrow = b1 | b2; - return result; -} - -// Full 64x64 -> 128 bit multiplication -__device__ __forceinline__ -void mul64_wide(unsigned long long a, unsigned long long b, - unsigned long long* hi, unsigned long long* lo) { - unsigned long long a_lo = a & 0xFFFFFFFFULL; - unsigned long long a_hi = a >> 32; - unsigned long long b_lo = b & 0xFFFFFFFFULL; - unsigned long long b_hi = b >> 32; - - unsigned long long p0 = a_lo * b_lo; - unsigned long long p1 = a_lo * b_hi; - unsigned long long p2 = a_hi * b_lo; - unsigned long long p3 = a_hi * b_hi; - - unsigned long long cy = ((p0 >> 32) + (p1 & 0xFFFFFFFFULL) + (p2 & 0xFFFFFFFFULL)) >> 32; - *lo = p0 + (p1 << 32) + (p2 << 32); - *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy; -} - -// ============================================================================ -// Compare and zero-check helpers -// ============================================================================ - -__device__ __forceinline__ -int cmp256(const uint256_t* a, const uint256_t* b) { - for (int i = LIMBS - 1; i >= 0; i--) { - if (a->limbs[i] > b->limbs[i]) return 1; - if (a->limbs[i] < b->limbs[i]) return -1; - } - return 0; -} - -__device__ __forceinline__ -bool is_zero(const uint256_t* a) { - for (int i = 0; i < LIMBS; i++) { - if (a->limbs[i] != 0) return false; - } - return true; -} - -// ============================================================================ -// Div256 Implementation (long division) -// ============================================================================ - -__device__ __forceinline__ -void div256_impl(const uint256_t* numerator, const uint256_t* denominator, - uint256_t* quotient, uint256_t* remainder) { - if (is_zero(denominator)) { - for (int i = 0; i < LIMBS; i++) { - quotient->limbs[i] = 0; - remainder->limbs[i] = 0; - } - return; - } - - uint256_t q, r; - for (int i = 0; i < LIMBS; i++) { - q.limbs[i] = 0; - r.limbs[i] = 0; - } - - for (int i = 255; i >= 0; i--) { - // r <<= 1 - unsigned long long 
carry = 0; - for (int j = 0; j < LIMBS; j++) { - unsigned long long temp = (r.limbs[j] << 1) | carry; - carry = r.limbs[j] >> 63; - r.limbs[j] = temp; - } - - // r[0] |= numerator bit i - int limb_idx = i / 64; - int bit_idx = i % 64; - unsigned long long bit = (numerator->limbs[limb_idx] >> bit_idx) & 1ULL; - r.limbs[0] |= bit; - - // if r >= denominator - if (cmp256(&r, denominator) >= 0) { - // r -= denominator - unsigned long long borrow = 0; - for (int j = 0; j < LIMBS; j++) { - r.limbs[j] = sub_borrow(r.limbs[j], denominator->limbs[j], &borrow); - } - - // q[i] = 1 - limb_idx = i / 64; - bit_idx = i % 64; - q.limbs[limb_idx] |= (1ULL << bit_idx); - } - } - - *quotient = q; - *remainder = r; -} - -// ============================================================================ -// Montgomery reduction -// ============================================================================ - -__device__ __forceinline__ -void montgomery_reduce(const unsigned long long* t, const uint256_t* m, - unsigned long long m_inv, uint256_t* result) { - unsigned long long a[LIMBS * 2]; - for (int i = 0; i < LIMBS * 2; i++) { - a[i] = t[i]; - } - - for (int i = 0; i < LIMBS; i++) { - unsigned long long u = a[i] * m_inv; - unsigned long long carry = 0; - - for (int j = 0; j < LIMBS; j++) { - unsigned long long hi, lo; - mul64_wide(u, m->limbs[j], &hi, &lo); - - unsigned long long sum = a[i + j] + lo + carry; - carry = (sum < a[i + j]) ? 1ULL : 0ULL; - carry += hi; - a[i + j] = sum; - } - - for (int j = LIMBS; j < LIMBS * 2 - i && carry; j++) { - unsigned long long sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1ULL : 0ULL; - a[i + j] = sum; - } - } - - // Result is in upper half - bool needs_sub = false; - for (int i = LIMBS - 1; i >= 0; i--) { - if (a[LIMBS + i] > m->limbs[i]) { - needs_sub = true; - break; - } - if (a[LIMBS + i] < m->limbs[i]) break; - } - - if (needs_sub) { - unsigned long long borrow = 0; - for (int i = 0; i < LIMBS; i++) { - result->limbs[i] = sub_borrow(a[LIMBS + i], m->limbs[i], &borrow); - } - } else { - for (int i = 0; i < LIMBS; i++) { - result->limbs[i] = a[LIMBS + i]; - } - } -} - -// ============================================================================ -// Kernel: Batch Add256 -// ============================================================================ - -extern "C" __global__ -void cuda_add256( - const uint256_t* __restrict__ a, - const uint256_t* __restrict__ b, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - unsigned long long carry = 0; - for (int i = 0; i < LIMBS; i++) { - result[idx].limbs[i] = add_carry(a[idx].limbs[i], b[idx].limbs[i], &carry); - } -} - -// ============================================================================ -// Kernel: Batch Sub256 -// ============================================================================ - -extern "C" __global__ -void cuda_sub256( - const uint256_t* __restrict__ a, - const uint256_t* __restrict__ b, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - unsigned long long borrow = 0; - for (int i = 0; i < LIMBS; i++) { - result[idx].limbs[i] = sub_borrow(a[idx].limbs[i], b[idx].limbs[i], &borrow); - } -} - -// ============================================================================ -// Kernel: Batch Mul256 -// ============================================================================ - -extern "C" __global__ -void cuda_mul256( - const uint256_t* __restrict__ a, - const uint256_t* 
__restrict__ b, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - // Schoolbook multiplication with 8-limb intermediate result - unsigned long long product[LIMBS * 2]; - for (int i = 0; i < LIMBS * 2; i++) { - product[i] = 0; - } - - for (int i = 0; i < LIMBS; i++) { - unsigned long long carry = 0; - for (int j = 0; j < LIMBS; j++) { - unsigned long long hi, lo; - mul64_wide(a[idx].limbs[i], b[idx].limbs[j], &hi, &lo); - - unsigned long long sum = product[i + j] + lo + carry; - carry = (sum < product[i + j]) ? 1ULL : 0ULL; - carry += hi; - product[i + j] = sum; - } - product[i + LIMBS] += carry; - } - - // Take lower 256 bits - for (int i = 0; i < LIMBS; i++) { - result[idx].limbs[i] = product[i]; - } -} - -// ============================================================================ -// Kernel: Batch Div256 -// ============================================================================ - -extern "C" __global__ -void cuda_div256( - const uint256_t* __restrict__ a, - const uint256_t* __restrict__ b, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - uint256_t numerator = a[idx]; - uint256_t denominator = b[idx]; - uint256_t quotient, remainder; - - div256_impl(&numerator, &denominator, &quotient, &remainder); - result[idx] = quotient; -} - -// ============================================================================ -// Kernel: Batch Mod256 -// ============================================================================ - -extern "C" __global__ -void cuda_mod256( - const uint256_t* __restrict__ a, - const uint256_t* __restrict__ b, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - uint256_t numerator = a[idx]; - uint256_t denominator = b[idx]; - uint256_t quotient, remainder; - - 
div256_impl(&numerator, &denominator, &quotient, &remainder); - result[idx] = remainder; -} - -// ============================================================================ -// Kernel: Montgomery Multiplication -// ============================================================================ - -extern "C" __global__ -void cuda_mont_mul( - const uint256_t* __restrict__ a, - const uint256_t* __restrict__ b, - const uint256_t* __restrict__ m, - const unsigned long long* __restrict__ m_inv, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - // Compute full 512-bit product - unsigned long long product[LIMBS * 2]; - for (int i = 0; i < LIMBS * 2; i++) { - product[i] = 0; - } - - for (int i = 0; i < LIMBS; i++) { - unsigned long long carry = 0; - for (int j = 0; j < LIMBS; j++) { - unsigned long long hi, lo; - mul64_wide(a[idx].limbs[i], b[idx].limbs[j], &hi, &lo); - - unsigned long long sum = product[i + j] + lo + carry; - carry = (sum < product[i + j]) ? 
1ULL : 0ULL; - carry += hi; - product[i + j] = sum; - } - product[i + LIMBS] += carry; - } - - // Montgomery reduce - uint256_t mod = m[idx]; - uint256_t res; - montgomery_reduce(product, &mod, *m_inv, &res); - result[idx] = res; -} - -// ============================================================================ -// Kernel: Modular Exponentiation (square-and-multiply) -// ============================================================================ - -extern "C" __global__ -void cuda_modexp256( - const uint256_t* __restrict__ base, - const uint256_t* __restrict__ exponent, - const uint256_t* __restrict__ modulus, - uint256_t* __restrict__ result, - uint32_t count -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= count) return; - - uint256_t res; - res.limbs[0] = 1; - res.limbs[1] = 0; - res.limbs[2] = 0; - res.limbs[3] = 0; - - uint256_t b = base[idx]; - uint256_t mod = modulus[idx]; - - for (int i = 0; i < 256; i++) { - int limb_idx = i / 64; - int bit_idx = i % 64; - unsigned long long bit = (exponent[idx].limbs[limb_idx] >> bit_idx) & 1ULL; - - if (bit) { - // res = (res * b) % modulus - unsigned long long product[LIMBS * 2]; - for (int k = 0; k < LIMBS * 2; k++) { - product[k] = 0; - } - - for (int j = 0; j < LIMBS; j++) { - unsigned long long carry = 0; - for (int k = 0; k < LIMBS; k++) { - unsigned long long hi, lo; - mul64_wide(res.limbs[j], b.limbs[k], &hi, &lo); - - unsigned long long sum = product[j + k] + lo + carry; - carry = (sum < product[j + k]) ? 
1ULL : 0ULL; - carry += hi; - product[j + k] = sum; - } - product[j + LIMBS] += carry; - } - - uint256_t temp; - for (int j = 0; j < LIMBS; j++) { - temp.limbs[j] = product[j]; - } - - uint256_t quot, rem; - div256_impl(&temp, &mod, &quot, &rem); - res = rem; - } - - // b = (b * b) % modulus - unsigned long long product[LIMBS * 2]; - for (int k = 0; k < LIMBS * 2; k++) { - product[k] = 0; - } - - for (int j = 0; j < LIMBS; j++) { - unsigned long long carry = 0; - for (int k = 0; k < LIMBS; k++) { - unsigned long long hi, lo; - mul64_wide(b.limbs[j], b.limbs[k], &hi, &lo); - - unsigned long long sum = product[j + k] + lo + carry; - carry = (sum < product[j + k]) ? 1ULL : 0ULL; - carry += hi; - product[j + k] = sum; - } - product[j + LIMBS] += carry; - } - - uint256_t temp; - for (int j = 0; j < LIMBS; j++) { - temp.limbs[j] = product[j]; - } - - uint256_t quot, rem; - div256_impl(&temp, &mod, &quot, &rem); - b = rem; - } - - result[idx] = res; -} - -#endif // __CUDA_ARCH__ diff --git a/evm256/gpu/metal/evm256.metal b/evm256/gpu/metal/evm256.metal deleted file mode 100644 index a9533d3..0000000 --- a/evm256/gpu/metal/evm256.metal +++ /dev/null @@ -1,416 +0,0 @@ -// PAT-FHE-012: EVM256 Parallel Processing - Metal Implementation -// uint256 operations using 4x uint64 limbs (little-endian) - -#include <metal_stdlib> -using namespace metal; - -#define LIMBS 4 - -// uint256 represented as 4x uint64 limbs -struct uint256_t { - ulong limbs[LIMBS]; -}; - -// ============================================================================ -// Helper functions for multi-limb arithmetic -// ============================================================================ - -inline ulong add_carry(ulong a, ulong b, thread ulong* carry) { - ulong sum = a + b; - ulong c1 = sum < a ? 1 : 0; - ulong result = sum + *carry; - ulong c2 = result < sum ? 1 : 0; - *carry = c1 | c2; - return result; -} - -inline ulong sub_borrow(ulong a, ulong b, thread ulong* borrow) { - ulong diff = a - b; - ulong b1 = diff > a ? 
1 : 0; - ulong result = diff - *borrow; - ulong b2 = result > diff ? 1 : 0; - *borrow = b1 | b2; - return result; -} - -// Full 64x64 -> 128 bit multiplication -inline void mul64_wide(ulong a, ulong b, thread ulong* hi, thread ulong* lo) { - ulong a_lo = a & 0xFFFFFFFF; - ulong a_hi = a >> 32; - ulong b_lo = b & 0xFFFFFFFF; - ulong b_hi = b >> 32; - - ulong p0 = a_lo * b_lo; - ulong p1 = a_lo * b_hi; - ulong p2 = a_hi * b_lo; - ulong p3 = a_hi * b_hi; - - ulong cy = ((p0 >> 32) + (p1 & 0xFFFFFFFF) + (p2 & 0xFFFFFFFF)) >> 32; - *lo = p0 + (p1 << 32) + (p2 << 32); - *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy; -} - -// ============================================================================ -// Kernel: Batch Add256 -// ============================================================================ - -kernel void metal_add256( - const device uint256_t* a [[buffer(0)]], - const device uint256_t* b [[buffer(1)]], - device uint256_t* result [[buffer(2)]], - uint idx [[thread_position_in_grid]] -) { - ulong carry = 0; - for (int i = 0; i < LIMBS; i++) { - result[idx].limbs[i] = add_carry(a[idx].limbs[i], b[idx].limbs[i], &carry); - } -} - -// ============================================================================ -// Kernel: Batch Sub256 -// ============================================================================ - -kernel void metal_sub256( - const device uint256_t* a [[buffer(0)]], - const device uint256_t* b [[buffer(1)]], - device uint256_t* result [[buffer(2)]], - uint idx [[thread_position_in_grid]] -) { - ulong borrow = 0; - for (int i = 0; i < LIMBS; i++) { - result[idx].limbs[i] = sub_borrow(a[idx].limbs[i], b[idx].limbs[i], &borrow); - } -} - -// ============================================================================ -// Kernel: Batch Mul256 -// ============================================================================ - -kernel void metal_mul256( - const device uint256_t* a [[buffer(0)]], - const device uint256_t* b [[buffer(1)]], - device 
uint256_t* result [[buffer(2)]], - uint idx [[thread_position_in_grid]] -) { - // Schoolbook multiplication with 8 limbs intermediate result - ulong product[LIMBS * 2]; - for (int i = 0; i < LIMBS * 2; i++) { - product[i] = 0; - } - - for (int i = 0; i < LIMBS; i++) { - ulong carry = 0; - for (int j = 0; j < LIMBS; j++) { - ulong hi, lo; - mul64_wide(a[idx].limbs[i], b[idx].limbs[j], &hi, &lo); - - // Add to product[i+j] - ulong sum = product[i + j] + lo + carry; - carry = (sum < product[i + j]) ? 1 : 0; - carry += hi; - product[i + j] = sum; - } - product[i + LIMBS] += carry; - } - - // Take lower 256 bits - for (int i = 0; i < LIMBS; i++) { - result[idx].limbs[i] = product[i]; - } -} - -// ============================================================================ -// Helper: Compare uint256 -// ============================================================================ - -inline int cmp256(const thread uint256_t* a, const thread uint256_t* b) { - for (int i = LIMBS - 1; i >= 0; i--) { - if (a->limbs[i] > b->limbs[i]) return 1; - if (a->limbs[i] < b->limbs[i]) return -1; - } - return 0; -} - -inline bool is_zero(const thread uint256_t* a) { - for (int i = 0; i < LIMBS; i++) { - if (a->limbs[i] != 0) return false; - } - return true; -} - -// ============================================================================ -// Helper: Div256 Implementation -// ============================================================================ - -inline void div256_impl(const thread uint256_t* numerator, const thread uint256_t* denominator, - thread uint256_t* quotient, thread uint256_t* remainder) { - // Handle division by zero - if (is_zero(denominator)) { - for (int i = 0; i < LIMBS; i++) { - quotient->limbs[i] = 0; - remainder->limbs[i] = 0; - } - return; - } - - // Initialize - uint256_t q, r; - for (int i = 0; i < LIMBS; i++) { - q.limbs[i] = 0; - r.limbs[i] = 0; - } - - // Long division algorithm - for (int i = 255; i >= 0; i--) { - // r <<= 1 - ulong carry = 0; - for 
(int j = 0; j < LIMBS; j++) { - ulong temp = (r.limbs[j] << 1) | carry; - carry = r.limbs[j] >> 63; - r.limbs[j] = temp; - } - - // r[0] = numerator[i] - int limb_idx = i / 64; - int bit_idx = i % 64; - ulong bit = (numerator->limbs[limb_idx] >> bit_idx) & 1; - r.limbs[0] |= bit; - - // if r >= denominator - if (cmp256(&r, denominator) >= 0) { - // r -= denominator - ulong borrow = 0; - for (int j = 0; j < LIMBS; j++) { - r.limbs[j] = sub_borrow(r.limbs[j], denominator->limbs[j], &borrow); - } - - // q[i] = 1 - limb_idx = i / 64; - bit_idx = i % 64; - q.limbs[limb_idx] |= (1UL << bit_idx); - } - } - - *quotient = q; - *remainder = r; -} - -// ============================================================================ -// Kernel: Batch Div256 -// ============================================================================ - -kernel void metal_div256( - const device uint256_t* a [[buffer(0)]], - const device uint256_t* b [[buffer(1)]], - device uint256_t* result [[buffer(2)]], - uint idx [[thread_position_in_grid]] -) { - uint256_t numerator = a[idx]; - uint256_t denominator = b[idx]; - uint256_t quotient, remainder; - - div256_impl(&numerator, &denominator, "ient, &remainder); - result[idx] = quotient; -} - -// ============================================================================ -// Kernel: Batch Mod256 -// ============================================================================ - -kernel void metal_mod256( - const device uint256_t* a [[buffer(0)]], - const device uint256_t* b [[buffer(1)]], - device uint256_t* result [[buffer(2)]], - uint idx [[thread_position_in_grid]] -) { - uint256_t numerator = a[idx]; - uint256_t denominator = b[idx]; - uint256_t quotient, remainder; - - div256_impl(&numerator, &denominator, "ient, &remainder); - result[idx] = remainder; -} - -// ============================================================================ -// Montgomery multiplication helpers -// 
============================================================================ - -inline void montgomery_reduce(const thread ulong* t, const thread uint256_t* m, - ulong m_inv, thread uint256_t* result) { - ulong a[LIMBS * 2]; - for (int i = 0; i < LIMBS * 2; i++) { - a[i] = t[i]; - } - - for (int i = 0; i < LIMBS; i++) { - ulong u = a[i] * m_inv; - ulong carry = 0; - - for (int j = 0; j < LIMBS; j++) { - ulong hi, lo; - mul64_wide(u, m->limbs[j], &hi, &lo); - - ulong sum = a[i + j] + lo + carry; - carry = (sum < a[i + j]) ? 1 : 0; - carry += hi; - a[i + j] = sum; - } - - for (int j = LIMBS; j < LIMBS * 2 - i && carry; j++) { - ulong sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 1 : 0; - a[i + j] = sum; - } - } - - // Result is in upper half - bool needs_sub = false; - for (int i = LIMBS - 1; i >= 0; i--) { - if (a[LIMBS + i] > m->limbs[i]) { - needs_sub = true; - break; - } - if (a[LIMBS + i] < m->limbs[i]) break; - } - - if (needs_sub) { - ulong borrow = 0; - for (int i = 0; i < LIMBS; i++) { - result->limbs[i] = sub_borrow(a[LIMBS + i], m->limbs[i], &borrow); - } - } else { - for (int i = 0; i < LIMBS; i++) { - result->limbs[i] = a[LIMBS + i]; - } - } -} - -// ============================================================================ -// Kernel: Montgomery Multiplication -// ============================================================================ - -kernel void metal_mont_mul( - const device uint256_t* a [[buffer(0)]], - const device uint256_t* b [[buffer(1)]], - const device uint256_t* m [[buffer(2)]], - const device ulong* m_inv [[buffer(3)]], - device uint256_t* result [[buffer(4)]], - uint idx [[thread_position_in_grid]] -) { - // Compute full 512-bit product - ulong product[LIMBS * 2]; - for (int i = 0; i < LIMBS * 2; i++) { - product[i] = 0; - } - - for (int i = 0; i < LIMBS; i++) { - ulong carry = 0; - for (int j = 0; j < LIMBS; j++) { - ulong hi, lo; - mul64_wide(a[idx].limbs[i], b[idx].limbs[j], &hi, &lo); - - ulong sum = product[i + j] + lo 
+ carry; - carry = (sum < product[i + j]) ? 1 : 0; - carry += hi; - product[i + j] = sum; - } - product[i + LIMBS] += carry; - } - - // Montgomery reduce - uint256_t mod = m[idx]; - uint256_t res; - montgomery_reduce(product, &mod, *m_inv, &res); - result[idx] = res; -} - -// ============================================================================ -// Kernel: Modular Exponentiation -// ============================================================================ - -kernel void metal_modexp256( - const device uint256_t* base [[buffer(0)]], - const device uint256_t* exponent [[buffer(1)]], - const device uint256_t* modulus [[buffer(2)]], - device uint256_t* result [[buffer(3)]], - uint idx [[thread_position_in_grid]] -) { - // Simple square-and-multiply - uint256_t res; - res.limbs[0] = 1; - res.limbs[1] = 0; - res.limbs[2] = 0; - res.limbs[3] = 0; - - uint256_t b = base[idx]; - uint256_t mod = modulus[idx]; - - for (int i = 0; i < 256; i++) { - int limb_idx = i / 64; - int bit_idx = i % 64; - ulong bit = (exponent[idx].limbs[limb_idx] >> bit_idx) & 1; - - if (bit) { - // res = (res * b) % modulus - ulong product[LIMBS * 2]; - for (int k = 0; k < LIMBS * 2; k++) { - product[k] = 0; - } - - for (int j = 0; j < LIMBS; j++) { - ulong carry = 0; - for (int k = 0; k < LIMBS; k++) { - ulong hi, lo; - mul64_wide(res.limbs[j], b.limbs[k], &hi, &lo); - - ulong sum = product[j + k] + lo + carry; - carry = (sum < product[j + k]) ? 
1 : 0; - carry += hi; - product[j + k] = sum; - } - product[j + LIMBS] += carry; - } - - // Take product % modulus - uint256_t temp; - for (int j = 0; j < LIMBS; j++) { - temp.limbs[j] = product[j]; - } - - uint256_t quot, rem; - div256_impl(&temp, &mod, ", &rem); - res = rem; - } - - // b = (b * b) % modulus - ulong product[LIMBS * 2]; - for (int k = 0; k < LIMBS * 2; k++) { - product[k] = 0; - } - - for (int j = 0; j < LIMBS; j++) { - ulong carry = 0; - for (int k = 0; k < LIMBS; k++) { - ulong hi, lo; - mul64_wide(b.limbs[j], b.limbs[k], &hi, &lo); - - ulong sum = product[j + k] + lo + carry; - carry = (sum < product[j + k]) ? 1 : 0; - carry += hi; - product[j + k] = sum; - } - product[j + LIMBS] += carry; - } - - uint256_t temp; - for (int j = 0; j < LIMBS; j++) { - temp.limbs[j] = product[j]; - } - - uint256_t quot, rem; - div256_impl(&temp, &mod, ", &rem); - b = rem; - } - - result[idx] = res; -} diff --git a/evm256/gpu/wgsl/evm256.wgsl b/evm256/gpu/wgsl/evm256.wgsl deleted file mode 100644 index ecbc198..0000000 --- a/evm256/gpu/wgsl/evm256.wgsl +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// EVM uint256 parallel operations in WGSL. -// uint256 = 8 x u32 limbs (little-endian, since WGSL has no u64). -// Matches evm256.metal output byte-for-byte. 
- -@group(0) @binding(0) var a: array; -@group(0) @binding(1) var b: array; -@group(0) @binding(2) var result: array; -@group(0) @binding(3) var params: vec4; -// params.x = num_items, params.y = op (0=add 1=sub 2=mul 3=div 4=mod) - -const LIMBS: u32 = 8u; // 256 bits / 32 bits - -fn load256(buf: ptr, read>, idx: u32) -> array { - var v: array; - let base = idx * LIMBS; - for (var i = 0u; i < LIMBS; i = i + 1u) { v[i] = (*buf)[base + i]; } - return v; -} - -fn store256(idx: u32, v: ptr>) { - let base = idx * LIMBS; - for (var i = 0u; i < LIMBS; i = i + 1u) { result[base + i] = (*v)[i]; } -} - -fn add256(x: ptr>, y: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < LIMBS; i = i + 1u) { - let s1 = (*x)[i] + c; - c = select(0u, 1u, s1 < (*x)[i]); - let s2 = s1 + (*y)[i]; - c = c + select(0u, 1u, s2 < s1); - (*r)[i] = s2; - } - return c; -} - -fn sub256(x: ptr>, y: ptr>, - r: ptr>) -> u32 { - var bw = 0u; - for (var i = 0u; i < LIMBS; i = i + 1u) { - let d1 = (*x)[i] - bw; - bw = select(0u, 1u, d1 > (*x)[i]); - let d2 = d1 - (*y)[i]; - bw = bw + select(0u, 1u, d2 > d1); - (*r)[i] = d2; - } - return bw; -} - -fn is_zero256(v: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < LIMBS; i = i + 1u) { acc = acc | (*v)[i]; } - return acc == 0u; -} - -fn cmp256(x: ptr>, y: ptr>) -> i32 { - for (var i = 7i; i >= 0; i = i - 1) { - let ui = u32(i); - if ((*x)[ui] > (*y)[ui]) { return 1; } - if ((*x)[ui] < (*y)[ui]) { return -1; } - } - return 0; -} - -fn mul256(x: ptr>, y: ptr>, - r: ptr>) { - var prod: array; - for (var i = 0u; i < 16u; i = i + 1u) { prod[i] = 0u; } - - for (var i = 0u; i < LIMBS; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < LIMBS; j = j + 1u) { - // 32x32 -> 64 multiply - let a_lo = (*x)[i] & 0xFFFFu; - let a_hi = (*x)[i] >> 16u; - let b_lo = (*y)[j] & 0xFFFFu; - let b_hi = (*y)[j] >> 16u; - let ll = a_lo * b_lo; - let lh = a_lo * b_hi; - let hl = a_hi * b_lo; - let hh = a_hi * b_hi; - let mid = lh + hl; - let lo = ll + (mid << 16u); 
- var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - - // Add carry - let s1 = lo + carry; - hi = hi + select(0u, 1u, s1 < lo); - // Add to accumulator - let s2 = prod[i + j] + s1; - hi = hi + select(0u, 1u, s2 < prod[i + j]); - prod[i + j] = s2; - carry = hi; - } - prod[i + LIMBS] = prod[i + LIMBS] + carry; - } - - // Take lower 256 bits - for (var i = 0u; i < LIMBS; i = i + 1u) { (*r)[i] = prod[i]; } -} - -fn div256(num: ptr>, den: ptr>, - q: ptr>, rem: ptr>) { - // Zero outputs - for (var i = 0u; i < LIMBS; i = i + 1u) { (*q)[i] = 0u; (*rem)[i] = 0u; } - - if (is_zero256(den)) { return; } - - // Long division bit by bit - for (var bit = 255i; bit >= 0; bit = bit - 1) { - // rem <<= 1 - var c = 0u; - for (var j = 0u; j < LIMBS; j = j + 1u) { - let temp = ((*rem)[j] << 1u) | c; - c = (*rem)[j] >> 31u; - (*rem)[j] = temp; - } - - // rem[0] |= bit from numerator - let limb_idx = u32(bit) / 32u; - let bit_idx = u32(bit) % 32u; - (*rem)[0] = (*rem)[0] | (((*num)[limb_idx] >> bit_idx) & 1u); - - // if rem >= den: rem -= den, q[bit] = 1 - if (cmp256(rem, den) >= 0) { - let _ = sub256(rem, den, rem); - (*q)[limb_idx] = (*q)[limb_idx] | (1u << bit_idx); - } - } -} - -@compute @workgroup_size(256) -fn evm256_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - var va = load256(&a, tid); - var vb = load256(&b, tid); - var vr: array; - for (var i = 0u; i < LIMBS; i = i + 1u) { vr[i] = 0u; } - - let op = params.y; - if (op == 0u) { - let _ = add256(&va, &vb, &vr); - } else if (op == 1u) { - let _ = sub256(&va, &vb, &vr); - } else if (op == 2u) { - mul256(&va, &vb, &vr); - } else if (op == 3u) { - var rem: array; - div256(&va, &vb, &vr, &rem); - } else if (op == 4u) { - var quot: array; - div256(&va, &vb, ", &vr); - } - - store256(tid, &vr); -} diff --git a/frost/gpu/cuda/frost.cu b/frost/gpu/cuda/frost.cu deleted file mode 100644 index 3a68a03..0000000 --- a/frost/gpu/cuda/frost.cu 
+++ /dev/null @@ -1,402 +0,0 @@ -// FROST threshold Schnorr signature verification -- CUDA implementation -// Matches frost.metal output byte-for-byte -// One thread per partial signature verification - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// 256-bit integer -// ============================================================================= - -struct uint256 { - uint64_t limbs[4]; -}; - -// secp256k1 constants -__device__ static const uint256 FROST_P = {{ - 0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 FROST_N = {{ - 0xBFD25E8CD0364141ULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 FROST_GX = {{ - 0x59F2815B16F81798ULL, 0x029BFCDB2DCE28D9ULL, - 0x55A06295CE870B07ULL, 0x79BE667EF9DCBBACULL -}}; - -__device__ static const uint256 FROST_GY = {{ - 0x9C47D08FFB10D4B8ULL, 0xFD17B448A6855419ULL, - 0x5DA4FBFC0E1108A8ULL, 0x483ADA7726A3C465ULL -}}; - -// Montgomery constants for field p -__device__ static const uint256 FROST_MONT_R2_P = {{ - 0x000007A2000E90A1ULL, 0x0000000000000001ULL, - 0x0000000000000000ULL, 0x0000000000000000ULL -}}; -__device__ static const uint64_t FROST_P_INV = 0xD838091DD2253531ULL; - -__device__ static const uint256 FROST_MONT_R = {{ - 0x00000001000003D1ULL, 0x0000000000000000ULL, - 0x0000000000000000ULL, 0x0000000000000000ULL -}}; - -__device__ static const uint256 FROST_ZERO = {{0, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic -// CUDA has __int128, use it for 64x64->128 multiply -// ============================================================================= - -__device__ static int 
u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -__device__ static uint256 u256_add(uint256 a, uint256 b, uint64_t& carry) { - uint256 r; uint64_t c = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; return r; -} - -__device__ static uint256 u256_sub(uint256 a, uint256 b, uint64_t& borrow) { - uint256 r; uint64_t bw = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; return r; -} - -// CUDA has __int128 -- use it instead of the manual 32-bit split -__device__ static void mul64(uint64_t a, uint64_t b, uint64_t& lo, uint64_t& hi) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a * b; - lo = (uint64_t)prod; - hi = (uint64_t)(prod >> 64); -#else - uint64_t a_lo = a & 0xFFFFFFFFULL, a_hi = a >> 32; - uint64_t b_lo = b & 0xFFFFFFFFULL, b_hi = b >> 32; - uint64_t ll = a_lo * b_lo, lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo, hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - hi = hh + (mid2 >> 32); -#endif -} - -// Montgomery reduction for 256-bit field -__device__ static uint256 mont_reduce(uint64_t t[8], uint256 m, uint64_t inv) { - uint64_t a[9]; - for (int i = 0; i < 8; i++) a[i] = t[i]; - a[8] = 0; - for (int i = 0; i < 4; i++) { - uint64_t u = a[i] * inv; - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo, hi; - mul64(u, m.limbs[j], 
lo, hi); - uint64_t sum = lo + carry; if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - for (int j = 4; i + j <= 8; j++) { - uint64_t sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 1ULL : 0ULL; - a[i + j] = sum; - if (!carry) break; - } - } - uint256 r = {{a[4], a[5], a[6], a[7]}}; - if (a[8] || u256_cmp(r, m) >= 0) { uint64_t bw; r = u256_sub(r, m, bw); } - return r; -} - -__device__ static uint256 fp_mul(uint256 a, uint256 b) { - uint64_t t[8] = {}; - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - uint64_t sum = lo + carry; if (sum < lo) hi++; - sum = t[i + j] + sum; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - return mont_reduce(t, FROST_P, FROST_P_INV); -} - -__device__ static uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } - -__device__ static uint256 fp_add(uint256 a, uint256 b) { - uint64_t c; uint256 r = u256_add(a, b, c); - if (c || u256_cmp(r, FROST_P) >= 0) { uint64_t bw; r = u256_sub(r, FROST_P, bw); } - return r; -} - -__device__ static uint256 fp_sub(uint256 a, uint256 b) { - uint64_t bw; uint256 r = u256_sub(a, b, bw); - if (bw) { uint64_t c; r = u256_add(r, FROST_P, c); } - return r; -} - -__device__ static uint256 to_mont(uint256 a) { return fp_mul(a, FROST_MONT_R2_P); } - -__device__ static uint256 fp_inv(uint256 a) { - uint256 exp = FROST_P; exp.limbs[0] -= 2; - uint256 result = FROST_MONT_R, base = a; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - return result; -} - -// ============================================================================= -// secp256k1 point (Jacobian) -// ============================================================================= - -struct Point { - uint256 x, y, z; -}; - -__device__ 
static Point point_identity() { - Point p; p.x = FROST_MONT_R; p.y = FROST_MONT_R; p.z = FROST_ZERO; - return p; -} - -__device__ static bool point_is_inf(Point p) { return u256_is_zero(p.z); } - -__device__ static Point point_double(Point p) { - if (point_is_inf(p)) return p; - uint256 A = fp_sqr(p.y); - uint256 B = fp_mul(p.x, A); - uint256 S = fp_add(B, B); S = fp_add(S, S); - uint256 C = fp_sqr(A); - uint256 X2 = fp_sqr(p.x); - uint256 M = fp_add(X2, fp_add(X2, X2)); - uint256 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - uint256 C8 = fp_add(C, C); C8 = fp_add(C8, C8); C8 = fp_add(C8, C8); - uint256 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - uint256 Z3 = fp_mul(p.y, p.z); Z3 = fp_add(Z3, Z3); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -__device__ static Point point_add_mixed(Point P, uint256 Qx, uint256 Qy) { - if (point_is_inf(P)) { Point r; r.x = Qx; r.y = Qy; r.z = FROST_MONT_R; return r; } - uint256 Z2 = fp_sqr(P.z); - uint256 U2 = fp_mul(Qx, Z2); - uint256 S2 = fp_mul(Qy, fp_mul(Z2, P.z)); - uint256 H = fp_sub(U2, P.x); - uint256 R = fp_sub(S2, P.y); - if (u256_is_zero(H)) { - if (u256_is_zero(R)) return point_double(P); - return point_identity(); - } - uint256 H2 = fp_sqr(H); - uint256 H3 = fp_mul(H, H2); - uint256 U1H2 = fp_mul(P.x, H2); - uint256 X3 = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - uint256 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - uint256 Z3 = fp_mul(H, P.z); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -__device__ static Point point_mul(uint256 k, uint256 Px, uint256 Py) { - Point result = point_identity(); - for (int i = 3; i >= 0; i--) - for (int bit = 63; bit >= 0; bit--) { - result = point_double(result); - if ((k.limbs[i] >> bit) & 1) - result = point_add_mixed(result, Px, Py); - } - return result; -} - -__device__ static void point_to_affine(Point p, uint256& ax, uint256& ay) { - if (point_is_inf(p)) { ax = FROST_ZERO; ay = FROST_ZERO; return; } - uint256 zi = fp_inv(p.z); - uint256 zi2 
= fp_sqr(zi); - ax = fp_mul(p.x, zi2); - ay = fp_mul(p.y, fp_mul(zi2, zi)); -} - -// ============================================================================= -// FROST structures -// ============================================================================= - -struct FROSTCommitment { - uint8_t data[66]; // D[33] || E[33] -}; - -struct FROSTPartialSig { - uint8_t data[32]; // z_i scalar -}; - -struct FROSTPublicKey { - uint8_t data[33]; // compressed secp256k1 point -}; - -struct FROSTChallenge { - uint8_t data[32]; // c * lambda_i scalar -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -extern "C" __global__ void frost_partial_verify_batch( - const FROSTCommitment* __restrict__ commitments, - const FROSTPartialSig* __restrict__ signatures, - const FROSTPublicKey* __restrict__ pubkeys, - const FROSTChallenge* __restrict__ challenges, - uint32_t* __restrict__ results, - const uint32_t* __restrict__ num_ops_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_ops = *num_ops_ptr; - if (tid >= num_ops) return; - - // Read z_i scalar - uint256 z; - for (int i = 0; i < 4; i++) { - z.limbs[i] = 0; - for (int b = 0; b < 8; b++) - z.limbs[i] |= (uint64_t)signatures[tid].data[i * 8 + b] << (b * 8); - } - - // z must be < n - if (u256_cmp(z, FROST_N) >= 0) { - results[tid] = 0; - return; - } - - // Read c * lambda_i scalar - uint256 cl; - for (int i = 0; i < 4; i++) { - cl.limbs[i] = 0; - for (int b = 0; b < 8; b++) - cl.limbs[i] |= (uint64_t)challenges[tid].data[i * 8 + b] << (b * 8); - } - - // Decompress commitment D (first 33 bytes) - const uint8_t* comm = commitments[tid].data; - uint256 dx_raw; - for (int i = 0; i < 4; i++) { - dx_raw.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) { - int src = 32 - (i * 8 + b); - if (src >= 1 && src <= 32) - dx_raw.limbs[i] |= 
(uint64_t)comm[src] << (b * 8); - } - } - - uint256 dx_mont = to_mont(dx_raw); - - // Recover y from x on secp256k1: y^2 = x^3 + 7 - uint256 x2 = fp_sqr(dx_mont); - uint256 x3 = fp_mul(x2, dx_mont); - uint256 b7 = to_mont(uint256{{7, 0, 0, 0}}); - uint256 y2 = fp_add(x3, b7); - - // sqrt via (p+1)/4 (p = 3 mod 4) - uint256 exp = FROST_P; - exp.limbs[0] += 1; - for (int i = 0; i < 3; i++) - exp.limbs[i] = (exp.limbs[i] >> 2) | (exp.limbs[i + 1] << 62); - exp.limbs[3] >>= 2; - - uint256 dy_mont = FROST_MONT_R; - uint256 base_y = y2; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) dy_mont = fp_mul(dy_mont, base_y); - base_y = fp_sqr(base_y); - } - - // Compute z*G - uint256 gx_mont = to_mont(FROST_GX); - uint256 gy_mont = to_mont(FROST_GY); - Point zG = point_mul(z, gx_mont, gy_mont); - - // Decompress public key Y_i - const uint8_t* pk = pubkeys[tid].data; - uint256 yx_raw; - for (int i = 0; i < 4; i++) { - yx_raw.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) { - int src = 32 - (i * 8 + b); - if (src >= 1 && src <= 32) - yx_raw.limbs[i] |= (uint64_t)pk[src] << (b * 8); - } - } - uint256 yx_mont = to_mont(yx_raw); - - // Recover y for public key - uint256 yx2 = fp_sqr(yx_mont); - uint256 yx3 = fp_mul(yx2, yx_mont); - uint256 yy2 = fp_add(yx3, b7); - - uint256 yy_mont = FROST_MONT_R; - uint256 base_yy = yy2; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) yy_mont = fp_mul(yy_mont, base_yy); - base_yy = fp_sqr(base_yy); - } - - // Compute c*lambda_i * Y_i - Point clY = point_mul(cl, yx_mont, yy_mont); - - // Compute R + c*lambda_i*Y_i - uint256 cl_ax, cl_ay; - point_to_affine(clY, cl_ax, cl_ay); - - Point R_point; - R_point.x = dx_mont; R_point.y = dy_mont; R_point.z = FROST_MONT_R; - Point sum = point_add_mixed(R_point, cl_ax, cl_ay); - - // Compare z*G == R + c*lambda_i*Y_i - uint256 zg_x, zg_y, sum_x, sum_y; - point_to_affine(zG, zg_x, zg_y); - 
point_to_affine(sum, sum_x, sum_y); - - bool valid = (u256_cmp(zg_x, sum_x) == 0) && (u256_cmp(zg_y, sum_y) == 0); - results[tid] = valid ? 1u : 0u; -} diff --git a/frost/gpu/cuda/frost_presign.cu b/frost/gpu/cuda/frost_presign.cu deleted file mode 100644 index 8cab091..0000000 --- a/frost/gpu/cuda/frost_presign.cu +++ /dev/null @@ -1,520 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// FROST batched pre-signing kernel — CUDA implementation. -// -// Byte-equal to frost/gpu/metal/frost_presign.metal and the CPU canonical -// body in frost/cpp/presign.cpp. One CUDA thread per (signer, slot) pair. -// Nonces (d_i, e_i) live in registers / shared memory only; the kernel -// writes only the public commitment bytes to global memory. -// -// Build modes: -// * CRYPTO_ENABLE_CUDA=ON -> nvcc, real device kernel -// * CRYPTO_ENABLE_CUDA=OFF -> host C++ polyfill (same TU compiles -// under g++ as plain C++); the polyfill -// driver below uses it as the byte-equal -// oracle for tests on Apple/non-CUDA hosts. 
- -#include -#include - -#ifndef __CUDA_ARCH__ -# define __device__ -# define __global__ -# define __shared__ -# define __constant__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -namespace { - -struct uint256_t { uint64_t limbs[4]; }; - -__device__ static const uint256_t FP_P = {{ - 0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256_t FP_N = {{ - 0xBFD25E8CD0364141ULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256_t G_X = {{ - 0x59F2815B16F81798ULL, 0x029BFCDB2DCE28D9ULL, - 0x55A06295CE870B07ULL, 0x79BE667EF9DCBBACULL -}}; -__device__ static const uint256_t G_Y = {{ - 0x9C47D08FFB10D4B8ULL, 0xFD17B448A6855419ULL, - 0x5DA4FBFC0E1108A8ULL, 0x483ADA7726A3C465ULL -}}; - -__device__ static const uint256_t R2_P = {{ - 0x000007A2000E90A1ULL, 0x0000000000000001ULL, 0ULL, 0ULL -}}; -__device__ static const uint64_t P_INV = 0xD838091DD2253531ULL; -__device__ static const uint256_t MONT_R = {{0x00000001000003D1ULL, 0ULL, 0ULL, 0ULL}}; -__device__ static const uint256_t ZERO256 = {{0, 0, 0, 0}}; -__device__ static const uint256_t ONE256 = {{1, 0, 0, 0}}; - -__device__ static int u256_cmp(uint256_t a, uint256_t b) { - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u256_is_zero(uint256_t a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0ULL; -} - -__device__ static uint256_t u256_add(uint256_t a, uint256_t b, uint64_t& carry) { - uint256_t r; uint64_t c = 0; - for (int i = 0; i < 4; ++i) { - uint64_t s = a.limbs[i] + c; - c = (s < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t s2 = s + b.limbs[i]; - c += (s2 < s) ? 
1ULL : 0ULL; - r.limbs[i] = s2; - } - carry = c; return r; -} - -__device__ static uint256_t u256_sub(uint256_t a, uint256_t b, uint64_t& borrow) { - uint256_t r; uint64_t bw = 0; - for (int i = 0; i < 4; ++i) { - uint64_t d = a.limbs[i] - bw; - bw = (d > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t d2 = d - b.limbs[i]; - bw += (d2 > d) ? 1ULL : 0ULL; - r.limbs[i] = d2; - } - borrow = bw; return r; -} - -__device__ static void mul64(uint64_t a, uint64_t b, uint64_t& lo, uint64_t& hi) { -#ifdef __CUDA_ARCH__ - unsigned __int128 p = (unsigned __int128)a * b; - lo = (uint64_t)p; hi = (uint64_t)(p >> 64); -#else - uint64_t al = a & 0xFFFFFFFFULL, ah = a >> 32; - uint64_t bl = b & 0xFFFFFFFFULL, bh = b >> 32; - uint64_t ll = al * bl, lh = al * bh, hl = ah * bl, hh = ah * bh; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - hi = hh + (mid2 >> 32); -#endif -} - -__device__ static uint256_t mont_reduce(uint64_t t[8], uint256_t m, uint64_t inv) { - uint64_t a[9]; - for (int i = 0; i < 8; ++i) a[i] = t[i]; - a[8] = 0; - for (int i = 0; i < 4; ++i) { - uint64_t u = a[i] * inv; - uint64_t carry = 0; - for (int j = 0; j < 4; ++j) { - uint64_t lo, hi; mul64(u, m.limbs[j], lo, hi); - uint64_t s = lo + carry; if (s < lo) hi++; - s = a[i + j] + s; if (s < a[i + j]) hi++; - a[i + j] = s; - carry = hi; - } - for (int j = 4; i + j <= 8; ++j) { - uint64_t s = a[i + j] + carry; - carry = (s < a[i + j]) ? 
1ULL : 0ULL; - a[i + j] = s; - if (!carry) break; - } - } - uint256_t r = {{a[4], a[5], a[6], a[7]}}; - if (a[8] || u256_cmp(r, m) >= 0) { uint64_t bw; r = u256_sub(r, m, bw); } - return r; -} - -__device__ static uint256_t fp_mul(uint256_t a, uint256_t b) { - uint64_t t[8] = {0}; - for (int i = 0; i < 4; ++i) { - uint64_t carry = 0; - for (int j = 0; j < 4; ++j) { - uint64_t lo, hi; mul64(a.limbs[i], b.limbs[j], lo, hi); - uint64_t s = lo + carry; if (s < lo) hi++; - s = t[i + j] + s; if (s < t[i + j]) hi++; - t[i + j] = s; carry = hi; - } - t[i + 4] = carry; - } - return mont_reduce(t, FP_P, P_INV); -} - -__device__ static uint256_t fp_sqr(uint256_t a) { return fp_mul(a, a); } - -__device__ static uint256_t fp_add(uint256_t a, uint256_t b) { - uint64_t c; uint256_t r = u256_add(a, b, c); - if (c || u256_cmp(r, FP_P) >= 0) { uint64_t bw; r = u256_sub(r, FP_P, bw); } - return r; -} - -__device__ static uint256_t fp_sub(uint256_t a, uint256_t b) { - uint64_t bw; uint256_t r = u256_sub(a, b, bw); - if (bw) { uint64_t c; r = u256_add(r, FP_P, c); } - return r; -} - -__device__ static uint256_t to_mont(uint256_t a) { return fp_mul(a, R2_P); } - -__device__ static uint256_t fp_inv(uint256_t a) { - uint256_t exp = FP_P; exp.limbs[0] -= 2; - uint256_t r = MONT_R, base = a; - for (int i = 0; i < 4; ++i) - for (int bit = 0; bit < 64; ++bit) { - if ((exp.limbs[i] >> bit) & 1) r = fp_mul(r, base); - base = fp_sqr(base); - } - return r; -} - -struct Point { uint256_t x, y, z; }; - -__device__ static Point jac_zero() { - Point p; p.x = MONT_R; p.y = MONT_R; p.z = ZERO256; return p; -} -__device__ static bool jac_is_inf(Point p) { return u256_is_zero(p.z); } - -__device__ static Point jac_double(Point p) { - if (jac_is_inf(p)) return p; - if (u256_is_zero(p.y)) return jac_zero(); - uint256_t A = fp_sqr(p.x), B = fp_sqr(p.y), C = fp_sqr(B); - uint256_t XB = fp_add(p.x, B); - uint256_t D = fp_sub(fp_sub(fp_sqr(XB), A), C); - D = fp_add(D, D); - uint256_t E = fp_add(A, A); E = 
fp_add(E, A); - uint256_t F = fp_sqr(E); - uint256_t X3 = fp_sub(F, fp_add(D, D)); - uint256_t eC = fp_add(C, C); eC = fp_add(eC, eC); eC = fp_add(eC, eC); - uint256_t Y3 = fp_sub(fp_mul(E, fp_sub(D, X3)), eC); - uint256_t Z3 = fp_mul(p.y, p.z); Z3 = fp_add(Z3, Z3); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -__device__ static Point jac_add_mixed(Point P, uint256_t Qx, uint256_t Qy) { - if (jac_is_inf(P)) { Point r; r.x = Qx; r.y = Qy; r.z = MONT_R; return r; } - uint256_t Z1Z1 = fp_sqr(P.z); - uint256_t U2 = fp_mul(Qx, Z1Z1); - uint256_t S2 = fp_mul(Qy, fp_mul(Z1Z1, P.z)); - uint256_t H = fp_sub(U2, P.x); - uint256_t R = fp_sub(S2, P.y); - if (u256_is_zero(H)) { - if (u256_is_zero(R)) return jac_double(P); - return jac_zero(); - } - uint256_t HH = fp_sqr(H); - uint256_t HHH = fp_mul(H, HH); - uint256_t U1HH = fp_mul(P.x, HH); - uint256_t X3 = fp_sub(fp_sub(fp_sqr(R), HHH), fp_add(U1HH, U1HH)); - uint256_t Y3 = fp_sub(fp_mul(R, fp_sub(U1HH, X3)), fp_mul(P.y, HHH)); - uint256_t Z3 = fp_mul(P.z, H); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -__device__ static Point scalar_mul_base(uint256_t k) { - uint256_t Gx = to_mont(G_X); - uint256_t Gy = to_mont(G_Y); - Point r = jac_zero(); - for (int limb = 3; limb >= 0; --limb) { - uint64_t w = k.limbs[limb]; - for (int bit = 63; bit >= 0; --bit) { - r = jac_double(r); - Point cand = jac_add_mixed(r, Gx, Gy); - uint64_t mask = -((uint64_t)((w >> bit) & 1ULL)); - for (int q = 0; q < 4; ++q) { - r.x.limbs[q] = (r.x.limbs[q] & ~mask) | (cand.x.limbs[q] & mask); - r.y.limbs[q] = (r.y.limbs[q] & ~mask) | (cand.y.limbs[q] & mask); - r.z.limbs[q] = (r.z.limbs[q] & ~mask) | (cand.z.limbs[q] & mask); - } - } - } - return r; -} - -__device__ static void jac_to_compressed(Point p, uint8_t out33[33]) { - if (jac_is_inf(p)) { for (int i = 0; i < 33; ++i) out33[i] = 0; return; } - uint256_t zi = fp_inv(p.z); - uint256_t zi2 = fp_sqr(zi); - uint256_t zi3 = fp_mul(zi2, zi); - uint256_t x_mont = fp_mul(p.x, zi2); - 
uint256_t y_mont = fp_mul(p.y, zi3); - uint256_t x_plain = fp_mul(x_mont, ONE256); - uint256_t y_plain = fp_mul(y_mont, ONE256); - out33[0] = (y_plain.limbs[0] & 1ULL) ? 0x03 : 0x02; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8 + 1; - uint64_t v = x_plain.limbs[limb]; - for (int j = 7; j >= 0; --j) { out33[base + j] = (uint8_t)(v & 0xFF); v >>= 8; } - } -} - -// --- SHA-256 + HMAC + HKDF (same algorithm as Metal kernel) --- - -__device__ static const uint32_t K256[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u -}; - -__device__ static uint32_t rotr32(uint32_t x, uint32_t n) { return (x >> n) | (x << (32 - n)); } - -__device__ static void sha256_block(uint32_t H[8], const uint8_t block[64]) { - uint32_t W[64]; - for (int i = 0; i < 16; ++i) { - W[i] = ((uint32_t)block[i*4] << 24) | ((uint32_t)block[i*4+1] << 16) | - ((uint32_t)block[i*4+2] << 8) | ((uint32_t)block[i*4+3]); - } - for (int i = 16; i < 64; ++i) { - uint32_t s0 = rotr32(W[i-15], 7) ^ rotr32(W[i-15], 18) ^ (W[i-15] >> 3); - uint32_t s1 = rotr32(W[i-2], 17) ^ rotr32(W[i-2], 19) ^ (W[i-2] >> 10); - W[i] = W[i-16] + s0 + W[i-7] + s1; - } - uint32_t a=H[0], b=H[1], c=H[2], d=H[3], e=H[4], 
f=H[5], g=H[6], h=H[7]; - for (int i = 0; i < 64; ++i) { - uint32_t S1 = rotr32(e, 6) ^ rotr32(e, 11) ^ rotr32(e, 25); - uint32_t ch = (e & f) ^ ((~e) & g); - uint32_t t1 = h + S1 + ch + K256[i] + W[i]; - uint32_t S0 = rotr32(a, 2) ^ rotr32(a, 13) ^ rotr32(a, 22); - uint32_t mj = (a & b) ^ (a & c) ^ (b & c); - uint32_t t2 = S0 + mj; - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - H[0]+=a; H[1]+=b; H[2]+=c; H[3]+=d; H[4]+=e; H[5]+=f; H[6]+=g; H[7]+=h; -} - -__device__ static void sha256_compute(const uint8_t* data, uint32_t data_len, uint8_t out[32]) { - uint8_t buf[256 + 64]; - for (uint32_t i = 0; i < data_len; ++i) buf[i] = data[i]; - uint32_t len = data_len; - buf[len++] = 0x80; - while ((len % 64) != 56) buf[len++] = 0; - uint64_t bits = (uint64_t)data_len * 8ULL; - for (int i = 7; i >= 0; --i) buf[len++] = (uint8_t)(bits >> (i * 8)); - uint32_t H[8] = {0x6a09e667u, 0xbb67ae85u, 0x3c6ef372u, 0xa54ff53au, - 0x510e527fu, 0x9b05688cu, 0x1f83d9abu, 0x5be0cd19u}; - for (uint32_t off = 0; off < len; off += 64) sha256_block(H, buf + off); - for (int i = 0; i < 8; ++i) { - out[i*4] = (uint8_t)(H[i] >> 24); - out[i*4+1] = (uint8_t)(H[i] >> 16); - out[i*4+2] = (uint8_t)(H[i] >> 8); - out[i*4+3] = (uint8_t)(H[i] ); - } -} - -__device__ static void hmac_sha256(const uint8_t* key, uint32_t key_len, - const uint8_t* msg, uint32_t msg_len, - uint8_t out[32]) { - uint8_t k[64]; - for (int i = 0; i < 64; ++i) k[i] = 0; - for (uint32_t i = 0; i < key_len && i < 64; ++i) k[i] = key[i]; - uint8_t ipad[64 + 256]; - uint8_t opad[64 + 32]; - for (int i = 0; i < 64; ++i) { - ipad[i] = k[i] ^ 0x36; - opad[i] = k[i] ^ 0x5c; - } - for (uint32_t i = 0; i < msg_len; ++i) ipad[64 + i] = msg[i]; - uint8_t inner[32]; - sha256_compute(ipad, 64 + msg_len, inner); - for (int i = 0; i < 32; ++i) opad[64 + i] = inner[i]; - sha256_compute(opad, 64 + 32, out); -} - -__device__ static void hkdf_expand_64(const uint8_t prk[32], const uint8_t info[12], - uint8_t out[64]) { 
- uint8_t buf[32 + 12 + 1]; - for (int i = 0; i < 12; ++i) buf[i] = info[i]; - buf[12] = 0x01; - uint8_t T1[32]; - hmac_sha256(prk, 32, buf, 13, T1); - for (int i = 0; i < 32; ++i) out[i] = T1[i]; - for (int i = 0; i < 32; ++i) buf[i] = T1[i]; - for (int i = 0; i < 12; ++i) buf[32 + i] = info[i]; - buf[44] = 0x02; - uint8_t T2[32]; - hmac_sha256(prk, 32, buf, 45, T2); - for (int i = 0; i < 32; ++i) out[32 + i] = T2[i]; -} - -__device__ static bool be_to_scalar_lt_n(const uint8_t in_be[32], uint8_t out_be[32]) { - uint256_t v; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - uint64_t w = 0; - for (int j = 0; j < 8; ++j) w = (w << 8) | (uint64_t)in_be[base + j]; - v.limbs[limb] = w; - } - if (u256_cmp(v, FP_N) >= 0) { uint64_t bw; v = u256_sub(v, FP_N, bw); } - if (u256_is_zero(v)) return false; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - uint64_t w = v.limbs[limb]; - for (int j = 7; j >= 0; --j) { out_be[base + j] = (uint8_t)(w & 0xFF); w >>= 8; } - } - return true; -} - -} // anonymous namespace - -// ============================================================================= -// CUDA kernel entry point — same semantics as Metal frost_presign. 
-// ============================================================================= - -extern "C" __global__ void frost_presign_kernel( - const uint8_t* __restrict__ seed, // 32 bytes - const uint32_t* __restrict__ signer_ids, // m entries - uint32_t m, - uint32_t slot_id_base, - uint32_t n_slots, - uint8_t* __restrict__ commits_out) // m * n_slots * 66 bytes -{ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = m * n_slots; - if (gid >= total) return; - - uint32_t signer_idx = gid / n_slots; - uint32_t slot_idx = gid % n_slots; - uint32_t signer_id = signer_ids[signer_idx]; - uint32_t slot_id = slot_id_base + slot_idx; - if (signer_id == 0) return; - - uint8_t salt[16] = {'f','r','o','s','t','-','p','r','e','s','i','g','n','-','v','1'}; - uint8_t prk[32]; - hmac_sha256(salt, 16, seed, 32, prk); - - uint8_t info[12]; - info[0] = (uint8_t)(signer_id ); info[1] = (uint8_t)(signer_id >> 8); - info[2] = (uint8_t)(signer_id >> 16); info[3] = (uint8_t)(signer_id >> 24); - info[4] = (uint8_t)(slot_id ); info[5] = (uint8_t)(slot_id >> 8); - info[6] = (uint8_t)(slot_id >> 16); info[7] = (uint8_t)(slot_id >> 24); - - uint8_t d_be[32], e_be[32]; - bool got_d = false, got_e = false; - uint32_t ctr = 0; - while (!(got_d && got_e)) { - info[ 8] = (uint8_t)(ctr ); - info[ 9] = (uint8_t)(ctr >> 8); - info[10] = (uint8_t)(ctr >> 16); - info[11] = (uint8_t)(ctr >> 24); - uint8_t okm[64]; - hkdf_expand_64(prk, info, okm); - if (!got_d) got_d = be_to_scalar_lt_n(okm, d_be); - if (!got_e) got_e = be_to_scalar_lt_n(okm + 32, e_be); - ++ctr; - if (ctr > 1024) return; - } - - uint256_t d_u256, e_u256; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - uint64_t wd = 0, we = 0; - for (int j = 0; j < 8; ++j) { - wd = (wd << 8) | (uint64_t)d_be[base + j]; - we = (we << 8) | (uint64_t)e_be[base + j]; - } - d_u256.limbs[limb] = wd; - e_u256.limbs[limb] = we; - } - Point D = scalar_mul_base(d_u256); - Point E = scalar_mul_base(e_u256); - - uint8_t 
D_bytes[33], E_bytes[33]; - jac_to_compressed(D, D_bytes); - jac_to_compressed(E, E_bytes); - - uint8_t* dst = commits_out + (uint64_t)gid * 66ULL; - for (int i = 0; i < 33; ++i) dst[i] = D_bytes[i]; - for (int i = 0; i < 33; ++i) dst[33 + i] = E_bytes[i]; - - for (int i = 0; i < 32; ++i) { d_be[i] = 0; e_be[i] = 0; } -} - -// Host polyfill: callable from C++ test harness when CRYPTO_ENABLE_CUDA=OFF. -// Iterates the same kernel body sequentially; result is byte-equal to the -// device kernel by construction. -extern "C" int frost_presign_cuda_host( - const uint8_t* seed, - const uint32_t* signer_ids, - uint32_t m, - uint32_t slot_id_base, - uint32_t n_slots, - uint8_t* commits_out) -{ - if (!seed || !signer_ids || !commits_out || m == 0 || n_slots == 0) return -1; - uint32_t total = m * n_slots; - for (uint32_t gid = 0; gid < total; ++gid) { -#ifndef __CUDA_ARCH__ - // Drive the same sequence via the device-marked helpers (which are - // plain C++ in host mode thanks to the polyfill macros above). 
- uint32_t signer_idx = gid / n_slots; - uint32_t slot_idx = gid % n_slots; - uint32_t signer_id = signer_ids[signer_idx]; - uint32_t slot_id = slot_id_base + slot_idx; - if (signer_id == 0) return -1; - - uint8_t salt[16] = {'f','r','o','s','t','-','p','r','e','s','i','g','n','-','v','1'}; - uint8_t prk[32]; - hmac_sha256(salt, 16, seed, 32, prk); - - uint8_t info[12]; - info[0] = (uint8_t)(signer_id ); info[1] = (uint8_t)(signer_id >> 8); - info[2] = (uint8_t)(signer_id >> 16); info[3] = (uint8_t)(signer_id >> 24); - info[4] = (uint8_t)(slot_id ); info[5] = (uint8_t)(slot_id >> 8); - info[6] = (uint8_t)(slot_id >> 16); info[7] = (uint8_t)(slot_id >> 24); - - uint8_t d_be[32], e_be[32]; - bool got_d = false, got_e = false; - uint32_t ctr = 0; - while (!(got_d && got_e)) { - info[ 8] = (uint8_t)(ctr ); - info[ 9] = (uint8_t)(ctr >> 8); - info[10] = (uint8_t)(ctr >> 16); - info[11] = (uint8_t)(ctr >> 24); - uint8_t okm[64]; - hkdf_expand_64(prk, info, okm); - if (!got_d) got_d = be_to_scalar_lt_n(okm, d_be); - if (!got_e) got_e = be_to_scalar_lt_n(okm + 32, e_be); - ++ctr; - if (ctr > 1024) return -1; - } - uint256_t d_u256, e_u256; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - uint64_t wd = 0, we = 0; - for (int j = 0; j < 8; ++j) { - wd = (wd << 8) | (uint64_t)d_be[base + j]; - we = (we << 8) | (uint64_t)e_be[base + j]; - } - d_u256.limbs[limb] = wd; - e_u256.limbs[limb] = we; - } - Point D = scalar_mul_base(d_u256); - Point E = scalar_mul_base(e_u256); - uint8_t D_bytes[33], E_bytes[33]; - jac_to_compressed(D, D_bytes); - jac_to_compressed(E, E_bytes); - uint8_t* dst = commits_out + (uint64_t)gid * 66ULL; - for (int i = 0; i < 33; ++i) dst[i] = D_bytes[i]; - for (int i = 0; i < 33; ++i) dst[33 + i] = E_bytes[i]; -#else - (void)seed; (void)signer_ids; (void)commits_out; (void)slot_id_base; (void)n_slots; -#endif - } - return 0; -} diff --git a/frost/gpu/metal/frost.metal b/frost/gpu/metal/frost.metal deleted file mode 100644 index 
117b02a..0000000 --- a/frost/gpu/metal/frost.metal +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file frost.metal -/// Metal compute shader for FROST threshold Schnorr signature verification. -/// -/// FROST (Flexible Round-Optimized Schnorr Threshold) enables t-of-n signers -/// to produce a single Schnorr signature. This kernel verifies partial -/// signatures and the combined signature. -/// -/// Uses secp256k1 curve (reuses arithmetic from secp256k1_recover.metal). -/// -/// Operations: -/// - frost_partial_verify_batch: verify partial signatures from participants -/// -/// Each thread verifies one partial signature independently. - -#include -using namespace metal; - -// ============================================================================= -// 256-bit integer (reused from secp256k1) -// ============================================================================= - -struct uint256 { - ulong limbs[4]; -}; - -// secp256k1 constants -constant uint256 FROST_P = {{ - 0xFFFFFFFEFFFFFC2FUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -constant uint256 FROST_N = {{ - 0xBFD25E8CD0364141UL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -constant uint256 FROST_GX = {{ - 0x59F2815B16F81798UL, 0x029BFCDB2DCE28D9UL, - 0x55A06295CE870B07UL, 0x79BE667EF9DCBBACUL -}}; - -constant uint256 FROST_GY = {{ - 0x9C47D08FFB10D4B8UL, 0xFD17B448A6855419UL, - 0x5DA4FBFC0E1108A8UL, 0x483ADA7726A3C465UL -}}; - -// Montgomery constants for field p -constant uint256 FROST_MONT_R2_P = {{ - 0x000007A2000E90A1UL, 0x0000000000000001UL, - 0x0000000000000000UL, 0x0000000000000000UL -}}; -constant ulong FROST_P_INV = 0xD838091DD2253531UL; - -constant uint256 FROST_MONT_R = {{ - 0x00000001000003D1UL, 0x0000000000000000UL, - 0x0000000000000000UL, 0x0000000000000000UL -}}; - -constant uint256 FROST_ZERO = {{0, 0, 0, 0}}; - -// 
============================================================================= -// 256-bit arithmetic -// ============================================================================= - -inline int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -inline uint256 u256_add(uint256 a, uint256 b, thread ulong& carry) { - uint256 r; ulong c = 0; - for (int i = 0; i < 4; i++) { - ulong sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; return r; -} - -inline uint256 u256_sub(uint256 a, uint256 b, thread ulong& borrow) { - uint256 r; ulong bw = 0; - for (int i = 0; i < 4; i++) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 
1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; return r; -} - -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong a_lo = a & 0xFFFFFFFFUL, a_hi = a >> 32; - ulong b_lo = b & 0xFFFFFFFFUL, b_hi = b >> 32; - ulong ll = a_lo * b_lo, lh = a_lo * b_hi; - ulong hl = a_hi * b_lo, hh = a_hi * b_hi; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -// Montgomery reduction -inline uint256 mont_reduce(ulong t[8], uint256 m, ulong inv) { - ulong a[9]; - for (int i = 0; i < 8; i++) a[i] = t[i]; - a[8] = 0; - for (int i = 0; i < 4; i++) { - ulong u = a[i] * inv; - ulong carry = 0; - for (int j = 0; j < 4; j++) { - ulong lo, hi; - mul64(u, m.limbs[j], lo, hi); - ulong sum = lo + carry; if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - for (int j = 4; i + j <= 8; j++) { - ulong sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1UL : 0UL; - a[i + j] = sum; - if (!carry) break; - } - } - uint256 r = {{a[4], a[5], a[6], a[7]}}; - if (a[8] || u256_cmp(r, m) >= 0) { ulong bw; r = u256_sub(r, m, bw); } - return r; -} - -inline uint256 fp_mul(uint256 a, uint256 b) { - ulong t[8] = {}; - for (int i = 0; i < 4; i++) { - ulong carry = 0; - for (int j = 0; j < 4; j++) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong sum = lo + carry; if (sum < lo) hi++; - sum = t[i + j] + sum; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - return mont_reduce(t, FROST_P, FROST_P_INV); -} - -inline uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } - -inline uint256 fp_add(uint256 a, uint256 b) { - ulong c; uint256 r = u256_add(a, b, c); - if (c || u256_cmp(r, FROST_P) >= 0) { ulong bw; r = u256_sub(r, FROST_P, bw); } - return r; -} - -inline uint256 fp_sub(uint256 a, uint256 b) { - ulong bw; uint256 r = u256_sub(a, b, bw); - if (bw) { ulong c; r = u256_add(r, FROST_P, c); } - return r; -} - -inline uint256 to_mont(uint256 a) { return fp_mul(a, FROST_MONT_R2_P); } - -inline uint256 fp_inv(uint256 a) { - uint256 exp = FROST_P; exp.limbs[0] -= 2; - uint256 result = FROST_MONT_R, base = a; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - return result; -} - -// ============================================================================= -// secp256k1 point (Jacobian) -// ============================================================================= - -struct Point { - uint256 x, y, z; -}; - -inline Point point_identity() { - Point p; p.x = FROST_MONT_R; p.y = FROST_MONT_R; p.z = FROST_ZERO; - return p; -} - -inline bool point_is_inf(Point p) { return u256_is_zero(p.z); } - -inline Point point_double(Point p) { - if (point_is_inf(p)) return p; - uint256 A = fp_sqr(p.y); - uint256 B = fp_mul(p.x, A); - uint256 S = fp_add(B, B); S = fp_add(S, S); - 
uint256 C = fp_sqr(A); - uint256 X2 = fp_sqr(p.x); - uint256 M = fp_add(X2, fp_add(X2, X2)); - uint256 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - uint256 C8 = fp_add(C, C); C8 = fp_add(C8, C8); C8 = fp_add(C8, C8); - uint256 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - uint256 Z3 = fp_mul(p.y, p.z); Z3 = fp_add(Z3, Z3); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -inline Point point_add_mixed(Point P, uint256 Qx, uint256 Qy) { - if (point_is_inf(P)) { Point r; r.x = Qx; r.y = Qy; r.z = FROST_MONT_R; return r; } - uint256 Z2 = fp_sqr(P.z); - uint256 U2 = fp_mul(Qx, Z2); - uint256 S2 = fp_mul(Qy, fp_mul(Z2, P.z)); - uint256 H = fp_sub(U2, P.x); - uint256 R = fp_sub(S2, P.y); - if (u256_is_zero(H)) { - if (u256_is_zero(R)) return point_double(P); - return point_identity(); - } - uint256 H2 = fp_sqr(H); - uint256 H3 = fp_mul(H, H2); - uint256 U1H2 = fp_mul(P.x, H2); - uint256 X3 = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - uint256 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - uint256 Z3 = fp_mul(H, P.z); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -inline Point point_mul(uint256 k, uint256 Px, uint256 Py) { - Point result = point_identity(); - for (int i = 3; i >= 0; i--) - for (int bit = 63; bit >= 0; bit--) { - result = point_double(result); - if ((k.limbs[i] >> bit) & 1) - result = point_add_mixed(result, Px, Py); - } - return result; -} - -inline void point_to_affine(Point p, thread uint256& ax, thread uint256& ay) { - if (point_is_inf(p)) { ax = FROST_ZERO; ay = FROST_ZERO; return; } - uint256 zi = fp_inv(p.z); - uint256 zi2 = fp_sqr(zi); - ax = fp_mul(p.x, zi2); - ay = fp_mul(p.y, fp_mul(zi2, zi)); -} - -// ============================================================================= -// FROST structures -// ============================================================================= - -/// Commitment: D[33] || E[33] (compressed secp256k1 points) -struct FROSTCommitment { - uchar data[66]; -}; - -/// Partial 
signature: z_i[32] (scalar) -struct FROSTPartialSig { - uchar data[32]; -}; - -/// Public key share: 33-byte compressed secp256k1 point -struct FROSTPublicKey { - uchar data[33]; -}; - -/// Challenge pre-computed by host: 32-byte scalar -struct FROSTChallenge { - uchar data[32]; -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -/// Batch FROST partial signature verification. -/// Each thread verifies one partial signature from a participant. -/// -/// Verify: z_i * G == R_i + c * lambda_i * Y_i -/// where: -/// z_i = partial signature scalar -/// R_i = D_i + rho_i * E_i (nonce commitment) -/// c = challenge scalar -/// lambda_i = Lagrange coefficient -/// Y_i = public key share -/// -/// Host pre-computes c * lambda_i as a single scalar per participant. -/// -/// Output: results[tid] = 1 if valid, 0 otherwise. -kernel void frost_partial_verify_batch( - device const FROSTCommitment* commitments [[buffer(0)]], - device const FROSTPartialSig* signatures [[buffer(1)]], - device const FROSTPublicKey* pubkeys [[buffer(2)]], - device const FROSTChallenge* challenges [[buffer(3)]], // c * lambda_i - device uint* results [[buffer(4)]], - constant uint& num_ops [[buffer(5)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_ops) return; - - // Read z_i scalar - uint256 z; - for (int i = 0; i < 4; i++) { - z.limbs[i] = 0; - for (int b = 0; b < 8; b++) - z.limbs[i] |= (ulong)signatures[tid].data[i * 8 + b] << (b * 8); - } - - // z must be < n - if (u256_cmp(z, FROST_N) >= 0) { - results[tid] = 0; - return; - } - - // Read c * lambda_i scalar - uint256 cl; - for (int i = 0; i < 4; i++) { - cl.limbs[i] = 0; - for (int b = 0; b < 8; b++) - cl.limbs[i] |= (ulong)challenges[tid].data[i * 8 + b] << (b * 8); - } - - // Decompress commitment D (first 33 bytes) - device const uchar* comm = commitments[tid].data; - uint256 dx_raw; 
- for (int i = 0; i < 4; i++) { - dx_raw.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) { - // Big-endian: byte 1 is MSB of x-coordinate - int src = 32 - (i * 8 + b); - if (src >= 1 && src <= 32) - dx_raw.limbs[i] |= (ulong)comm[src] << (b * 8); - } - } - - // Compute R = D (simplified: using just D commitment for partial verify) - // Full FROST would compute R_i = D_i + rho_i * E_i - uint256 dx_mont = to_mont(dx_raw); - - // Recover y from x on secp256k1: y^2 = x^3 + 7 - uint256 x2 = fp_sqr(dx_mont); - uint256 x3 = fp_mul(x2, dx_mont); - uint256 b7 = to_mont(uint256{{7, 0, 0, 0}}); - uint256 y2 = fp_add(x3, b7); - - // sqrt via Tonelli-Shanks (p = 3 mod 4) - uint256 exp = FROST_P; - exp.limbs[0] += 1; - // (p+1)/4 - for (int i = 0; i < 3; i++) - exp.limbs[i] = (exp.limbs[i] >> 2) | (exp.limbs[i + 1] << 62); - exp.limbs[3] >>= 2; - - uint256 dy_mont = FROST_MONT_R; - uint256 base_y = y2; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) dy_mont = fp_mul(dy_mont, base_y); - base_y = fp_sqr(base_y); - } - - // Compute z*G - uint256 gx_mont = to_mont(FROST_GX); - uint256 gy_mont = to_mont(FROST_GY); - Point zG = point_mul(z, gx_mont, gy_mont); - - // Decompress public key Y_i - device const uchar* pk = pubkeys[tid].data; - uint256 yx_raw; - for (int i = 0; i < 4; i++) { - yx_raw.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) { - int src = 32 - (i * 8 + b); - if (src >= 1 && src <= 32) - yx_raw.limbs[i] |= (ulong)pk[src] << (b * 8); - } - } - uint256 yx_mont = to_mont(yx_raw); - - // Recover y for public key - uint256 yx2 = fp_sqr(yx_mont); - uint256 yx3 = fp_mul(yx2, yx_mont); - uint256 yy2 = fp_add(yx3, b7); - - uint256 yy_mont = FROST_MONT_R; - uint256 base_yy = yy2; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) yy_mont = fp_mul(yy_mont, base_yy); - base_yy = fp_sqr(base_yy); - } - - // Compute c*lambda_i * Y_i - Point clY = 
point_mul(cl, yx_mont, yy_mont); - - // Compute R + c*lambda_i*Y_i - // For simplicity, convert to affine and add - uint256 cl_ax, cl_ay; - point_to_affine(clY, cl_ax, cl_ay); - - Point R_point; - R_point.x = dx_mont; R_point.y = dy_mont; R_point.z = FROST_MONT_R; - Point sum = point_add_mixed(R_point, cl_ax, cl_ay); - - // Compare z*G == R + c*lambda_i*Y_i - uint256 zg_x, zg_y, sum_x, sum_y; - point_to_affine(zG, zg_x, zg_y); - point_to_affine(sum, sum_x, sum_y); - - bool valid = (u256_cmp(zg_x, sum_x) == 0) && (u256_cmp(zg_y, sum_y) == 0); - results[tid] = valid ? 1u : 0u; -} diff --git a/frost/gpu/metal/frost_aggregate.metal b/frost/gpu/metal/frost_aggregate.metal deleted file mode 100644 index 60de1c8..0000000 --- a/frost/gpu/metal/frost_aggregate.metal +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// FROST (Flexible Round-Optimized Schnorr Threshold) Signature Aggregation -// GPU-accelerated threshold signature operations for Ed25519/secp256k1 -// Optimized for Apple Silicon GPUs - -#include -using namespace metal; - -// ============================================================================ -// Scalar Field Types (Ed25519: 2^252 + ..., secp256k1: FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141) -// ============================================================================ - -// 256-bit scalar (8 x 32-bit limbs) -struct Scalar256 { - uint limbs[8]; -}; - -// Ed25519 curve point (affine) -struct Ed25519Affine { - Scalar256 x; - Scalar256 y; -}; - -// Ed25519 extended coordinates (x, y, z, t) where x = X/Z, y = Y/Z, xy = T/Z -struct Ed25519Extended { - Scalar256 x; - Scalar256 y; - Scalar256 z; - Scalar256 t; -}; - -// secp256k1 affine point -struct Secp256k1Affine { - Scalar256 x; - Scalar256 y; -}; - -// secp256k1 Jacobian point (x, y, z) where X = x/z^2, Y = y/z^3 -struct Secp256k1Jacobian { - Scalar256 x; - Scalar256 y; - Scalar256 z; -}; - -// FROST 
signature share from a participant -struct FrostSignatureShare { - uint participant_id; - Scalar256 response; // z_i = d_i + e_i * rho + lambda_i * s_i * c - Scalar256 commitment_d; // D_i = g^d_i - Scalar256 commitment_e; // E_i = g^e_i - uint _pad[3]; // Align to 16 bytes -}; - -// FROST parameters -struct FrostParams { - uint num_participants; // n - total participants - uint threshold; // t - threshold - uint curve_type; // 0 = Ed25519, 1 = secp256k1 - uint batch_size; // Number of signatures to verify in parallel -}; - -// ============================================================================ -// Ed25519 Scalar Field Modulus: l = 2^252 + 27742317777372353535851937790883648493 -// ============================================================================ - -constant uint ED25519_L[8] = { - 0x5cf5d3edu, 0x5812631au, 0xa2f79cd6u, 0x14def9deu, - 0x00000000u, 0x00000000u, 0x00000000u, 0x10000000u -}; - -// secp256k1 scalar field modulus n -constant uint SECP256K1_N[8] = { - 0xd0364141u, 0xbfd25e8cu, 0xaf48a03bu, 0xbaaedce6u, - 0xfffffffeu, 0xffffffffu, 0xffffffffu, 0xffffffffu -}; - -// ============================================================================ -// 256-bit Scalar Arithmetic -// ============================================================================ - -inline Scalar256 scalar_zero() { - Scalar256 r; - for (int i = 0; i < 8; i++) r.limbs[i] = 0; - return r; -} - -inline Scalar256 scalar_one() { - Scalar256 r = scalar_zero(); - r.limbs[0] = 1; - return r; -} - -inline bool scalar_is_zero(Scalar256 a) { - for (int i = 0; i < 8; i++) { - if (a.limbs[i] != 0) return false; - } - return true; -} - -inline bool scalar_gte(Scalar256 a, constant uint* mod) { - for (int i = 7; i >= 0; i--) { - if (a.limbs[i] > mod[i]) return true; - if (a.limbs[i] < mod[i]) return false; - } - return true; -} - -inline Scalar256 scalar_add(Scalar256 a, Scalar256 b, constant uint* mod) { - Scalar256 r; - uint carry = 0; - - for (int i = 0; i < 8; i++) { - uint 
sum = a.limbs[i] + b.limbs[i] + carry; - carry = (sum < a.limbs[i]) || (carry && sum == a.limbs[i]) ? 1 : 0; - r.limbs[i] = sum; - } - - // Reduce if >= mod - if (carry || scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -inline Scalar256 scalar_sub(Scalar256 a, Scalar256 b, constant uint* mod) { - Scalar256 r; - uint borrow = 0; - - for (int i = 0; i < 8; i++) { - uint diff = a.limbs[i] - b.limbs[i] - borrow; - borrow = (a.limbs[i] < b.limbs[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - - // If underflow, add modulus - if (borrow) { - uint carry = 0; - for (int i = 0; i < 8; i++) { - uint sum = r.limbs[i] + mod[i] + carry; - carry = (sum < r.limbs[i]) ? 1 : 0; - r.limbs[i] = sum; - } - } - - return r; -} - -inline Scalar256 scalar_neg(Scalar256 a, constant uint* mod) { - if (scalar_is_zero(a)) return a; - - Scalar256 r; - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = mod[i] - a.limbs[i] - borrow; - borrow = (mod[i] < a.limbs[i] + borrow) ? 
1 : 0; - r.limbs[i] = diff; - } - return r; -} - -// Multiply two 256-bit scalars with reduction (simplified - uses schoolbook) -inline Scalar256 scalar_mul(Scalar256 a, Scalar256 b, constant uint* mod) { - // 512-bit intermediate product - uint product[16] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 8; i++) { - uint carry = 0; - for (int j = 0; j < 8; j++) { - ulong prod = (ulong)a.limbs[i] * (ulong)b.limbs[j] + product[i + j] + carry; - product[i + j] = (uint)prod; - carry = (uint)(prod >> 32); - } - product[i + 8] = carry; - } - - // Barrett reduction (simplified - iterative subtraction for demo) - // Production code would use precomputed Barrett factor - Scalar256 r; - for (int i = 0; i < 8; i++) { - r.limbs[i] = product[i]; - } - - // Iterative reduction - while (scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -// ============================================================================ -// Lagrange Coefficient Computation -// ============================================================================ - -// Compute Lagrange coefficient: lambda_i = prod_{j!=i} (0 - j) / (i - j) mod l -// For FROST, we evaluate at x=0 for secret reconstruction -inline Scalar256 compute_lagrange_coeff( - uint participant_id, - device const uint* participant_ids, - uint num_participants, - constant uint* mod -) { - Scalar256 numerator = scalar_one(); - Scalar256 denominator = scalar_one(); - - for (uint j = 0; j < num_participants; j++) { - uint other_id = participant_ids[j]; - if (other_id == participant_id) continue; - - // numerator *= (0 - other_id) = -other_id mod l - Scalar256 neg_j = scalar_zero(); - neg_j.limbs[0] = other_id; - neg_j = scalar_neg(neg_j, mod); - numerator = scalar_mul(numerator, neg_j, mod); - - // denominator *= (participant_id - other_id) - Scalar256 diff = 
scalar_zero(); - if (participant_id > other_id) { - diff.limbs[0] = participant_id - other_id; - } else { - diff.limbs[0] = other_id - participant_id; - diff = scalar_neg(diff, mod); - } - denominator = scalar_mul(denominator, diff, mod); - } - - // Compute denominator inverse using Fermat's little theorem: a^(-1) = a^(l-2) mod l - // Simplified: for production, use extended GCD or precomputed inverses - Scalar256 inv = denominator; - Scalar256 exp = scalar_zero(); - for (int i = 0; i < 8; i++) exp.limbs[i] = mod[i]; - exp.limbs[0] -= 2; // l - 2 - - Scalar256 result = scalar_one(); - while (!scalar_is_zero(exp)) { - if (exp.limbs[0] & 1) { - result = scalar_mul(result, inv, mod); - } - inv = scalar_mul(inv, inv, mod); - // Right shift exp by 1 - uint carry = 0; - for (int i = 7; i >= 0; i--) { - uint new_val = (exp.limbs[i] >> 1) | (carry << 31); - carry = exp.limbs[i] & 1; - exp.limbs[i] = new_val; - } - } - - return scalar_mul(numerator, result, mod); -} - -// ============================================================================ -// FROST Signature Aggregation Kernels -// ============================================================================ - -// Kernel 1: Aggregate signature shares into final signature -// z = sum(lambda_i * z_i) for all participating signers -kernel void frost_aggregate_shares( - device const FrostSignatureShare* shares [[buffer(0)]], - device const uint* participant_ids [[buffer(1)]], - device Scalar256* aggregated_response [[buffer(2)]], - constant FrostParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid != 0) return; // Single-threaded aggregation (can be parallelized with reduction) - - constant uint* mod = (params.curve_type == 0) ? 
ED25519_L : SECP256K1_N; - - Scalar256 sum = scalar_zero(); - - for (uint i = 0; i < params.num_participants; i++) { - // Compute Lagrange coefficient for this participant - Scalar256 lambda = compute_lagrange_coeff( - shares[i].participant_id, - participant_ids, - params.num_participants, - mod - ); - - // Add lambda_i * z_i to the sum - Scalar256 weighted = scalar_mul(lambda, shares[i].response, mod); - sum = scalar_add(sum, weighted, mod); - } - - aggregated_response[0] = sum; -} - -// Kernel 2: Batch verify partial signatures (each thread verifies one share) -kernel void frost_verify_partial_signatures( - device const FrostSignatureShare* shares [[buffer(0)]], - device const Ed25519Affine* public_keys [[buffer(1)]], // Participant public keys - device const Scalar256* challenge [[buffer(2)]], // c = H(R, Y, m) - device const Scalar256* binding_factors [[buffer(3)]], // rho_i - device uint* verification_results [[buffer(4)]], - constant FrostParams& params [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - // Verify: g^z_i == D_i * E_i^rho_i * Y_i^(c * lambda_i) - // This is a Schnorr verification for the partial signature - - FrostSignatureShare share = shares[gid]; - - // For now, set all as valid - full implementation needs point arithmetic - // The actual verification requires: - // 1. Compute left side: G * z_i - // 2. Compute right side: D_i + rho_i * E_i + (c * lambda_i) * Y_i - // 3. Compare points - - // Placeholder: verify share is well-formed - bool valid = !scalar_is_zero(share.response); - valid = valid && (share.participant_id > 0); - valid = valid && (share.participant_id <= params.num_participants); - - verification_results[gid] = valid ? 
1 : 0; -} - -// Kernel 3: Compute group commitment R = sum(D_i + rho_i * E_i) -// This runs in parallel for each participant, then requires reduction -kernel void frost_compute_group_commitment( - device const FrostSignatureShare* shares [[buffer(0)]], - device const Scalar256* binding_factors [[buffer(1)]], - device Ed25519Extended* partial_commitments [[buffer(2)]], - constant FrostParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - FrostSignatureShare share = shares[gid]; - Scalar256 rho = binding_factors[gid]; - - // Compute: R_i = D_i + rho_i * E_i - // This requires point multiplication and addition - // Placeholder: store D_i as the partial commitment - - Ed25519Extended result; - result.x = share.commitment_d; - result.y = share.commitment_e; - result.z = scalar_one(); - result.t = scalar_zero(); - - partial_commitments[gid] = result; -} - -// Kernel 4: Parallel reduction for aggregating commitments -kernel void frost_reduce_commitments( - device Ed25519Extended* commitments [[buffer(0)]], - constant uint& count [[buffer(1)]], - uint gid [[thread_position_in_grid]], - uint threads [[threads_per_grid]] -) { - // Tree reduction - add pairs of commitments - uint stride = 1; - while (stride < count) { - if (gid < count / (2 * stride)) { - uint i = gid * 2 * stride; - uint j = i + stride; - - if (j < count) { - // Add commitments[i] and commitments[j] - // Placeholder: just add x coordinates for now - commitments[i].x = scalar_add( - commitments[i].x, - commitments[j].x, - ED25519_L - ); - } - } - stride *= 2; - threadgroup_barrier(mem_flags::mem_device); - } -} - -// Kernel 5: Batch signature verification (verify multiple aggregated signatures) -kernel void frost_batch_verify( - device const Scalar256* responses [[buffer(0)]], // z values - device const Ed25519Affine* group_commitments [[buffer(1)]], // R values - device const Ed25519Affine* group_public_keys [[buffer(2)]], // Y values - 
device const Scalar256* challenges [[buffer(3)]], // c values - device uint* results [[buffer(4)]], - constant uint& batch_size [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= batch_size) return; - - // Verify: g^z == R + c*Y - // This is the final Schnorr signature verification - - Scalar256 z = responses[gid]; - Scalar256 c = challenges[gid]; - - // For full implementation: - // 1. Compute G * z (scalar multiplication) - // 2. Compute c * Y (scalar multiplication) - // 3. Compute R + c*Y (point addition) - // 4. Compare with G * z - - // Placeholder: basic sanity checks - bool valid = !scalar_is_zero(z); - valid = valid && !scalar_is_zero(c); - - results[gid] = valid ? 1 : 0; -} - -// Kernel 6: Compute challenge hash contributions (for batched hashing) -kernel void frost_challenge_precompute( - device const Ed25519Affine* group_commitments [[buffer(0)]], - device const Ed25519Affine* group_public_keys [[buffer(1)]], - device const Scalar256* messages [[buffer(2)]], - device Scalar256* hash_inputs [[buffer(3)]], - constant uint& batch_size [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= batch_size) return; - - // Prepare input for H(R || Y || m) - // The actual hashing would be done on CPU or with a hash kernel - - // For each signature, concatenate R, Y, m - uint base = gid * 3; - hash_inputs[base] = group_commitments[gid].x; - hash_inputs[base + 1] = group_public_keys[gid].x; - hash_inputs[base + 2] = messages[gid]; -} diff --git a/frost/gpu/metal/frost_nonce.metal b/frost/gpu/metal/frost_nonce.metal deleted file mode 100644 index 20d0f00..0000000 --- a/frost/gpu/metal/frost_nonce.metal +++ /dev/null @@ -1,613 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// FROST Nonce Generation and Commitment Operations -// Batch nonce generation, hash-to-curve, and binding factor computation -// Optimized for Apple Silicon GPUs - -#include -using namespace metal; - -// ============================================================================ -// Types (shared with frost_aggregate.metal) -// ============================================================================ - -struct Scalar256 { - uint limbs[8]; -}; - -struct Ed25519Affine { - Scalar256 x; - Scalar256 y; -}; - -struct Ed25519Extended { - Scalar256 x; - Scalar256 y; - Scalar256 z; - Scalar256 t; -}; - -struct NonceCommitment { - Scalar256 hiding_nonce_d; // d_i - Scalar256 binding_nonce_e; // e_i - Ed25519Affine commitment_d; // D_i = g^d_i - Ed25519Affine commitment_e; // E_i = g^e_i -}; - -struct NonceParams { - uint num_participants; - uint seed_entropy_offset; - uint curve_type; // 0 = Ed25519, 1 = secp256k1 - uint batch_size; -}; - -// SHA-512 state for hash-to-scalar -struct SHA512State { - ulong h[8]; - uint total_len; - uint _pad; -}; - -// ============================================================================ -// Ed25519 Constants -// ============================================================================ - -constant uint ED25519_L[8] = { - 0x5cf5d3edu, 0x5812631au, 0xa2f79cd6u, 0x14def9deu, - 0x00000000u, 0x00000000u, 0x00000000u, 0x10000000u -}; - -// Ed25519 base point G (compressed y, x calculated) -constant uint ED25519_GY[8] = { - 0x58666666u, 0x66666666u, 0x66666666u, 0x66666666u, - 0x66666666u, 0x66666666u, 0x66666666u, 0x66666666u -}; - -// secp256k1 constants -constant uint SECP256K1_N[8] = { - 0xd0364141u, 0xbfd25e8cu, 0xaf48a03bu, 0xbaaedce6u, - 0xfffffffeu, 0xffffffffu, 0xffffffffu, 0xffffffffu -}; - -// ============================================================================ -// Scalar Arithmetic (from frost_aggregate) -// 
============================================================================ - -inline Scalar256 scalar_zero() { - Scalar256 r; - for (int i = 0; i < 8; i++) r.limbs[i] = 0; - return r; -} - -inline Scalar256 scalar_one() { - Scalar256 r = scalar_zero(); - r.limbs[0] = 1; - return r; -} - -inline bool scalar_is_zero(Scalar256 a) { - for (int i = 0; i < 8; i++) { - if (a.limbs[i] != 0) return false; - } - return true; -} - -inline bool scalar_gte(Scalar256 a, constant uint* mod) { - for (int i = 7; i >= 0; i--) { - if (a.limbs[i] > mod[i]) return true; - if (a.limbs[i] < mod[i]) return false; - } - return true; -} - -inline Scalar256 scalar_add(Scalar256 a, Scalar256 b, constant uint* mod) { - Scalar256 r; - uint carry = 0; - - for (int i = 0; i < 8; i++) { - uint sum = a.limbs[i] + b.limbs[i] + carry; - carry = (sum < a.limbs[i]) || (carry && sum == a.limbs[i]) ? 1 : 0; - r.limbs[i] = sum; - } - - if (carry || scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -inline Scalar256 scalar_mul(Scalar256 a, Scalar256 b, constant uint* mod) { - uint product[16] = {0}; - - for (int i = 0; i < 8; i++) { - uint carry = 0; - for (int j = 0; j < 8; j++) { - ulong prod = (ulong)a.limbs[i] * (ulong)b.limbs[j] + product[i + j] + carry; - product[i + j] = (uint)prod; - carry = (uint)(prod >> 32); - } - product[i + 8] = carry; - } - - Scalar256 r; - for (int i = 0; i < 8; i++) { - r.limbs[i] = product[i]; - } - - while (scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 
1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -// ============================================================================ -// Pseudo-Random Number Generation (ChaCha20-based) -// ============================================================================ - -struct ChaCha20State { - uint state[16]; -}; - -inline uint rotl32(uint x, int n) { - return (x << n) | (x >> (32 - n)); -} - -inline void chacha_quarter_round(thread uint* a, thread uint* b, thread uint* c, thread uint* d) { - *a += *b; *d ^= *a; *d = rotl32(*d, 16); - *c += *d; *b ^= *c; *b = rotl32(*b, 12); - *a += *b; *d ^= *a; *d = rotl32(*d, 8); - *c += *d; *b ^= *c; *b = rotl32(*b, 7); -} - -inline ChaCha20State chacha20_init(Scalar256 key, ulong counter, ulong nonce) { - ChaCha20State s; - - // "expand 32-byte k" - s.state[0] = 0x61707865u; - s.state[1] = 0x3320646eu; - s.state[2] = 0x79622d32u; - s.state[3] = 0x6b206574u; - - // Key - for (int i = 0; i < 8; i++) { - s.state[4 + i] = key.limbs[i]; - } - - // Counter - s.state[12] = (uint)counter; - s.state[13] = (uint)(counter >> 32); - - // Nonce - s.state[14] = (uint)nonce; - s.state[15] = (uint)(nonce >> 32); - - return s; -} - -inline void chacha20_block(thread ChaCha20State* s) { - uint working[16]; - for (int i = 0; i < 16; i++) working[i] = s->state[i]; - - // 20 rounds (10 double rounds) - for (int i = 0; i < 10; i++) { - // Column rounds - chacha_quarter_round(&working[0], &working[4], &working[8], &working[12]); - chacha_quarter_round(&working[1], &working[5], &working[9], &working[13]); - chacha_quarter_round(&working[2], &working[6], &working[10], &working[14]); - chacha_quarter_round(&working[3], &working[7], &working[11], &working[15]); - - // Diagonal rounds - chacha_quarter_round(&working[0], &working[5], &working[10], &working[15]); - chacha_quarter_round(&working[1], &working[6], &working[11], &working[12]); - chacha_quarter_round(&working[2], &working[7], &working[8], &working[13]); - chacha_quarter_round(&working[3], 
&working[4], &working[9], &working[14]); - } - - // Add original state - for (int i = 0; i < 16; i++) { - s->state[i] += working[i]; - } - - // Increment counter - s->state[12]++; - if (s->state[12] == 0) s->state[13]++; -} - -// Generate a random scalar in [1, l-1] -inline Scalar256 random_scalar(thread ChaCha20State* rng, constant uint* mod) { - Scalar256 r; - - do { - chacha20_block(rng); - for (int i = 0; i < 8; i++) { - r.limbs[i] = rng->state[i]; - } - - // Clear top bits to ensure < 2^256 - r.limbs[7] &= 0x0FFFFFFFu; - - // Reduce mod l - while (scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - } - } while (scalar_is_zero(r)); // Retry if zero - - return r; -} - -// ============================================================================ -// Scalar Multiplication (simplified Montgomery ladder) -// ============================================================================ - -// Placeholder point operations - full implementation needs complete curve arithmetic -inline Ed25519Extended ed25519_identity() { - Ed25519Extended r; - r.x = scalar_zero(); - r.y = scalar_one(); - r.z = scalar_one(); - r.t = scalar_zero(); - return r; -} - -inline Ed25519Extended ed25519_double(Ed25519Extended p) { - // Placeholder - actual implementation requires full field arithmetic - Ed25519Extended r = p; - r.z = scalar_add(p.z, p.z, ED25519_L); - return r; -} - -inline Ed25519Extended ed25519_add(Ed25519Extended p, Ed25519Extended q) { - // Placeholder - actual implementation requires full field arithmetic - Ed25519Extended r; - r.x = scalar_add(p.x, q.x, ED25519_L); - r.y = scalar_add(p.y, q.y, ED25519_L); - r.z = scalar_add(p.z, q.z, ED25519_L); - r.t = scalar_add(p.t, q.t, ED25519_L); - return r; -} - -// Compute G * scalar using double-and-add -inline Ed25519Extended ed25519_scalar_mul_base(Scalar256 scalar) { - Ed25519Extended 
result = ed25519_identity(); - Ed25519Extended base; - - // Base point - for (int i = 0; i < 8; i++) base.y.limbs[i] = ED25519_GY[i]; - base.x = scalar_zero(); // Would need to compute from y - base.z = scalar_one(); - base.t = scalar_zero(); - - // Double-and-add - for (int bit = 0; bit < 256; bit++) { - int limb = bit / 32; - int bit_in_limb = bit % 32; - - if ((scalar.limbs[limb] >> bit_in_limb) & 1) { - result = ed25519_add(result, base); - } - base = ed25519_double(base); - } - - return result; -} - -inline Ed25519Affine ed25519_to_affine(Ed25519Extended ext) { - Ed25519Affine r; - // Would need modular inversion: x = X/Z, y = Y/Z - r.x = ext.x; - r.y = ext.y; - return r; -} - -// ============================================================================ -// Hash-to-Scalar (RFC 8032 style) -// ============================================================================ - -// Simple hash mixing function (for binding factor derivation) - device buffer version -inline Scalar256 hash_to_scalar_device( - device const uint* data, - uint data_len, - constant uint* mod -) { - // Simplified hash - production would use SHA-512 and reduce - Scalar256 r = scalar_zero(); - - for (uint i = 0; i < data_len && i < 8; i++) { - r.limbs[i] = data[i]; - } - - // Mix - for (int round = 0; round < 4; round++) { - for (int i = 0; i < 8; i++) { - r.limbs[i] ^= rotl32(r.limbs[(i + 1) % 8], 7); - r.limbs[i] += r.limbs[(i + 3) % 8]; - } - } - - // Reduce mod l - while (scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 
1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -// Thread-local array version for in-kernel computed hash inputs -inline Scalar256 hash_to_scalar( - thread const uint* data, - uint data_len, - constant uint* mod -) { - // Simplified hash - production would use SHA-512 and reduce - Scalar256 r = scalar_zero(); - - for (uint i = 0; i < data_len && i < 8; i++) { - r.limbs[i] = data[i]; - } - - // Mix - for (int round = 0; round < 4; round++) { - for (int i = 0; i < 8; i++) { - r.limbs[i] ^= rotl32(r.limbs[(i + 1) % 8], 7); - r.limbs[i] += r.limbs[(i + 3) % 8]; - } - } - - // Reduce mod l - while (scalar_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -// ============================================================================ -// FROST Nonce Generation Kernels -// ============================================================================ - -// Generate nonce pair (d_i, e_i) for each participant -kernel void frost_generate_nonces( - device const Scalar256* seeds [[buffer(0)]], // Per-participant seeds - device NonceCommitment* nonces [[buffer(1)]], // Output nonce commitments - constant NonceParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - constant uint* mod = (params.curve_type == 0) ? 
ED25519_L : SECP256K1_N; - - // Initialize RNG from seed - Scalar256 seed = seeds[gid]; - ulong counter = gid; - ulong nonce_val = params.seed_entropy_offset; - ChaCha20State rng = chacha20_init(seed, counter, nonce_val); - - // Generate hiding nonce d_i - Scalar256 d = random_scalar(&rng, mod); - - // Generate binding nonce e_i - Scalar256 e = random_scalar(&rng, mod); - - // Compute commitments D_i = g^d_i, E_i = g^e_i - Ed25519Extended D_ext = ed25519_scalar_mul_base(d); - Ed25519Extended E_ext = ed25519_scalar_mul_base(e); - - // Store results - NonceCommitment result; - result.hiding_nonce_d = d; - result.binding_nonce_e = e; - result.commitment_d = ed25519_to_affine(D_ext); - result.commitment_e = ed25519_to_affine(E_ext); - - nonces[gid] = result; -} - -// Compute binding factors rho_i = H(i, m, B) where B is list of commitments -kernel void frost_compute_binding_factors( - device const uint* participant_ids [[buffer(0)]], - device const Ed25519Affine* commitment_list [[buffer(1)]], // All D_i, E_i pairs - device const Scalar256* message [[buffer(2)]], - device Scalar256* binding_factors [[buffer(3)]], - constant NonceParams& params [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - constant uint* mod = (params.curve_type == 0) ? 
ED25519_L : SECP256K1_N; - - // Construct hash input: (participant_id, message, commitment_list_hash) - uint hash_input[16]; - - // Participant ID - hash_input[0] = participant_ids[gid]; - - // Message (first few words) - Scalar256 msg = message[0]; - for (int i = 0; i < 8; i++) { - hash_input[1 + i] = msg.limbs[i]; - } - - // Commitment list contribution (simplified - XOR all commitments) - uint commitment_hash = 0; - for (uint i = 0; i < params.num_participants * 2; i++) { - Ed25519Affine c = commitment_list[i]; - commitment_hash ^= c.x.limbs[0] ^ c.y.limbs[0]; - } - hash_input[9] = commitment_hash; - - // Compute rho_i = H(hash_input) - Scalar256 rho = hash_to_scalar(hash_input, 10, mod); - - binding_factors[gid] = rho; -} - -// Compute group commitment R = sum(D_i + rho_i * E_i) -kernel void frost_compute_commitment_shares( - device const NonceCommitment* nonces [[buffer(0)]], - device const Scalar256* binding_factors [[buffer(1)]], - device Ed25519Extended* commitment_shares [[buffer(2)]], - constant NonceParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - constant uint* mod = (params.curve_type == 0) ? 
ED25519_L : SECP256K1_N; - - NonceCommitment nonce = nonces[gid]; - Scalar256 rho = binding_factors[gid]; - - // Compute R_i = D_i + rho_i * E_i - // First compute rho_i * E_i (scalar multiplication) - Ed25519Extended rho_E; - rho_E.x = scalar_mul(rho, nonce.commitment_e.x, mod); - rho_E.y = scalar_mul(rho, nonce.commitment_e.y, mod); - rho_E.z = scalar_one(); - rho_E.t = scalar_zero(); - - // Then add D_i - Ed25519Extended D; - D.x = nonce.commitment_d.x; - D.y = nonce.commitment_d.y; - D.z = scalar_one(); - D.t = scalar_zero(); - - Ed25519Extended R_i = ed25519_add(D, rho_E); - - commitment_shares[gid] = R_i; -} - -// Parallel reduction to sum commitment shares into group commitment R -kernel void frost_aggregate_commitments( - device Ed25519Extended* commitment_shares [[buffer(0)]], - device Ed25519Extended* group_commitment [[buffer(1)]], - constant uint& count [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint group_size [[threads_per_threadgroup]] -) { - // Tree reduction within threadgroup - threadgroup Ed25519Extended shared_data[256]; - - if (gid < count) { - shared_data[lid] = commitment_shares[gid]; - } else { - shared_data[lid] = ed25519_identity(); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Reduction - for (uint stride = group_size / 2; stride > 0; stride /= 2) { - if (lid < stride && lid + stride < count) { - shared_data[lid] = ed25519_add(shared_data[lid], shared_data[lid + stride]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Write result from first thread - if (lid == 0) { - group_commitment[gid / group_size] = shared_data[0]; - } -} - -// Batch hash-to-curve for multiple messages -kernel void frost_batch_hash_to_curve( - device const Scalar256* messages [[buffer(0)]], - device Ed25519Extended* curve_points [[buffer(1)]], - constant uint& batch_size [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= batch_size) return; - - 
Scalar256 msg = messages[gid]; - - // Elligator2 or similar hash-to-curve - // Simplified: use message as scalar and compute G * H(m) - Scalar256 h = msg; - - // Mix to get uniform distribution - for (int i = 0; i < 8; i++) { - h.limbs[i] ^= rotl32(h.limbs[(i + 1) % 8], 11); - h.limbs[i] += rotl32(h.limbs[(i + 5) % 8], 7); - } - - // Reduce mod l - while (scalar_gte(h, ED25519_L)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = h.limbs[i] - ED25519_L[i] - borrow; - borrow = (h.limbs[i] < ED25519_L[i] + borrow) ? 1 : 0; - h.limbs[i] = diff; - } - } - - // Compute H(m) * G - curve_points[gid] = ed25519_scalar_mul_base(h); -} - -// Verify nonce commitment (D_i = g^d_i) -kernel void frost_verify_nonce_commitments( - device const NonceCommitment* nonces [[buffer(0)]], - device uint* valid [[buffer(1)]], - constant NonceParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - NonceCommitment nonce = nonces[gid]; - - // Verify D_i = g^d_i - Ed25519Extended computed_D = ed25519_scalar_mul_base(nonce.hiding_nonce_d); - Ed25519Affine computed_D_affine = ed25519_to_affine(computed_D); - - // Check equality (simplified - just check x coordinates) - bool d_valid = true; - for (int i = 0; i < 8; i++) { - if (computed_D_affine.x.limbs[i] != nonce.commitment_d.x.limbs[i]) { - d_valid = false; - break; - } - } - - // Verify E_i = g^e_i - Ed25519Extended computed_E = ed25519_scalar_mul_base(nonce.binding_nonce_e); - Ed25519Affine computed_E_affine = ed25519_to_affine(computed_E); - - bool e_valid = true; - for (int i = 0; i < 8; i++) { - if (computed_E_affine.x.limbs[i] != nonce.commitment_e.x.limbs[i]) { - e_valid = false; - break; - } - } - - valid[gid] = (d_valid && e_valid) ? 
1 : 0; -} diff --git a/frost/gpu/metal/frost_presign.metal b/frost/gpu/metal/frost_presign.metal deleted file mode 100644 index 4e8ad1d..0000000 --- a/frost/gpu/metal/frost_presign.metal +++ /dev/null @@ -1,533 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// FROST batched pre-signing kernel — Metal compute shader. -// -// Each thread = one (signer, slot) pair. Generates the nonce pair (d_i, e_i) -// in private/thread memory and writes only the public commitment -// (D_i, E_i) = (d_i*G, e_i*G) to device memory. -// -// Byte-equal to the CPU canonical body in frost/cpp/presign.cpp. Same HKDF -// expansion, same rejection rule, same fixed-time scalar-mul-base. The -// scalar field arithmetic + curve formulas are identical to those used by -// secp256k1_recover.metal so the disassembly chain is uniform. -// -// GPU residency invariant. Nonces live exclusively in thread (private) -// address space. The kernel never writes (d_i, e_i) to a device buffer; -// only the 33-byte compressed (D_i, E_i) lands in commits_out. Verifiable -// via metallib-disassemble: the kernel disassembly contains zero -// `device.write` instructions on uint256 nonce buffers, and the only -// device-side stores target commits_out (uchar) — see -// frost/test/frost_presign_disasm_test.cpp. 
- -#include -using namespace metal; - -// ============================================================================= -// 256-bit integer (matches CPU body U256 layout, little-endian limbs) -// ============================================================================= - -struct uint256_t { - ulong limbs[4]; -}; - -// secp256k1 field prime p = 2^256 - 2^32 - 977 -constant uint256_t FP_P = {{ - 0xFFFFFFFEFFFFFC2FUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -// Curve order n -constant uint256_t FP_N = {{ - 0xBFD25E8CD0364141UL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -// Generator G in plain (non-Montgomery) form -constant uint256_t G_X = {{ - 0x59F2815B16F81798UL, 0x029BFCDB2DCE28D9UL, - 0x55A06295CE870B07UL, 0x79BE667EF9DCBBACUL -}}; -constant uint256_t G_Y = {{ - 0x9C47D08FFB10D4B8UL, 0xFD17B448A6855419UL, - 0x5DA4FBFC0E1108A8UL, 0x483ADA7726A3C465UL -}}; - -// Montgomery constants for field p -constant uint256_t R2_P = {{ - 0x000007A2000E90A1UL, 0x0000000000000001UL, 0UL, 0UL -}}; -constant ulong P_INV = 0xD838091DD2253531UL; -constant uint256_t MONT_R = {{0x00000001000003D1UL, 0UL, 0UL, 0UL}}; -constant uint256_t ZERO256 = {{0, 0, 0, 0}}; -constant uint256_t ONE256 = {{1, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic — reused from secp256k1_recover.metal pattern -// ============================================================================= - -inline int u256_cmp(uint256_t a, uint256_t b) { - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256_t a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0UL; -} - -inline uint256_t u256_add(uint256_t a, uint256_t b, thread ulong& carry) { - uint256_t r; - ulong c = 0; - for (int i = 0; i < 4; ++i) { - ulong sum = a.limbs[i] + c; - c = (sum < 
a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -inline uint256_t u256_sub(uint256_t a, uint256_t b, thread ulong& borrow) { - uint256_t r; - ulong bw = 0; - for (int i = 0; i < 4; ++i) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong al = a & 0xFFFFFFFFUL, ah = a >> 32; - ulong bl = b & 0xFFFFFFFFUL, bh = b >> 32; - ulong ll = al * bl, lh = al * bh, hl = ah * bl, hh = ah * bh; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -inline uint256_t mont_reduce(thread ulong t[8], uint256_t m, ulong inv) { - ulong a[9]; - for (int i = 0; i < 8; ++i) a[i] = t[i]; - a[8] = 0; - for (int i = 0; i < 4; ++i) { - ulong u = a[i] * inv; - ulong carry = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(u, m.limbs[j], lo, hi); - ulong sum = lo + carry; if (sum < lo) hi++; - sum = a[i + j] + sum; if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - for (int j = 4; i + j <= 8; ++j) { - ulong sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1UL : 0UL; - a[i + j] = sum; - if (!carry) break; - } - } - uint256_t r = {{a[4], a[5], a[6], a[7]}}; - if (a[8] || u256_cmp(r, m) >= 0) { - ulong bw; r = u256_sub(r, m, bw); - } - return r; -} - -inline uint256_t fp_mul(uint256_t a, uint256_t b) { - ulong t[8] = {0}; - for (int i = 0; i < 4; ++i) { - ulong carry = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong sum = lo + carry; if (sum < lo) hi++; - sum = t[i + j] + sum; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - return mont_reduce(t, FP_P, P_INV); -} - -inline uint256_t fp_sqr(uint256_t a) { return fp_mul(a, a); } - -inline uint256_t fp_add(uint256_t a, uint256_t b) { - ulong c; uint256_t r = u256_add(a, b, c); - if (c || u256_cmp(r, FP_P) >= 0) { ulong bw; r = u256_sub(r, FP_P, bw); } - return r; -} - -inline uint256_t fp_sub(uint256_t a, uint256_t b) { - ulong bw; uint256_t r = u256_sub(a, b, bw); - if (bw) { ulong c; r = u256_add(r, FP_P, c); } - return r; -} - -inline uint256_t to_mont(uint256_t a) { return fp_mul(a, R2_P); } - -inline uint256_t fp_inv(uint256_t a) { - // p - 2 - uint256_t exp = FP_P; exp.limbs[0] -= 2; - uint256_t result = MONT_R, base = a; - for (int i = 0; i < 4; ++i) - for (int bit = 0; bit < 64; ++bit) { - if ((exp.limbs[i] >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - return result; -} - -// ============================================================================= -// Jacobian point ops -// ============================================================================= - -struct Point { uint256_t x, y, z; }; - -inline Point jac_zero() { Point p; p.x = MONT_R; p.y = MONT_R; p.z = ZERO256; return p; } - -inline bool jac_is_inf(Point p) { return u256_is_zero(p.z); } - -inline Point jac_double(Point p) { - if (jac_is_inf(p)) return p; - if (u256_is_zero(p.y)) return jac_zero(); - uint256_t A = fp_sqr(p.x), B = fp_sqr(p.y), C = fp_sqr(B); - uint256_t XB = fp_add(p.x, 
B); - uint256_t D = fp_sub(fp_sub(fp_sqr(XB), A), C); - D = fp_add(D, D); - uint256_t E = fp_add(A, A); E = fp_add(E, A); - uint256_t F = fp_sqr(E); - uint256_t X3 = fp_sub(F, fp_add(D, D)); - uint256_t eC = fp_add(C, C); eC = fp_add(eC, eC); eC = fp_add(eC, eC); - uint256_t Y3 = fp_sub(fp_mul(E, fp_sub(D, X3)), eC); - uint256_t Z3 = fp_mul(p.y, p.z); Z3 = fp_add(Z3, Z3); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -inline Point jac_add_mixed(Point P, uint256_t Qx, uint256_t Qy) { - if (jac_is_inf(P)) { Point r; r.x = Qx; r.y = Qy; r.z = MONT_R; return r; } - uint256_t Z1Z1 = fp_sqr(P.z); - uint256_t U2 = fp_mul(Qx, Z1Z1); - uint256_t S2 = fp_mul(Qy, fp_mul(Z1Z1, P.z)); - uint256_t H = fp_sub(U2, P.x); - uint256_t R = fp_sub(S2, P.y); - if (u256_is_zero(H)) { - if (u256_is_zero(R)) return jac_double(P); - return jac_zero(); - } - uint256_t HH = fp_sqr(H); - uint256_t HHH = fp_mul(H, HH); - uint256_t U1HH = fp_mul(P.x, HH); - uint256_t X3 = fp_sub(fp_sub(fp_sqr(R), HHH), fp_add(U1HH, U1HH)); - uint256_t Y3 = fp_sub(fp_mul(R, fp_sub(U1HH, X3)), fp_mul(P.y, HHH)); - uint256_t Z3 = fp_mul(P.z, H); - Point r; r.x = X3; r.y = Y3; r.z = Z3; return r; -} - -// Constant-time scalar mul k * G via always-double, conditional-add ladder. -// k is in plain (non-Montgomery) form. -inline Point scalar_mul_base(uint256_t k) { - uint256_t Gx = to_mont(G_X); - uint256_t Gy = to_mont(G_Y); - Point r = jac_zero(); - for (int limb = 3; limb >= 0; --limb) { - ulong w = k.limbs[limb]; - for (int bit = 63; bit >= 0; --bit) { - r = jac_double(r); - // Constant-time conditional add — always do the add; cmov result. 
- Point cand = jac_add_mixed(r, Gx, Gy); - ulong mask = -((w >> bit) & 1UL); - r.x.limbs[0] = (r.x.limbs[0] & ~mask) | (cand.x.limbs[0] & mask); - r.x.limbs[1] = (r.x.limbs[1] & ~mask) | (cand.x.limbs[1] & mask); - r.x.limbs[2] = (r.x.limbs[2] & ~mask) | (cand.x.limbs[2] & mask); - r.x.limbs[3] = (r.x.limbs[3] & ~mask) | (cand.x.limbs[3] & mask); - r.y.limbs[0] = (r.y.limbs[0] & ~mask) | (cand.y.limbs[0] & mask); - r.y.limbs[1] = (r.y.limbs[1] & ~mask) | (cand.y.limbs[1] & mask); - r.y.limbs[2] = (r.y.limbs[2] & ~mask) | (cand.y.limbs[2] & mask); - r.y.limbs[3] = (r.y.limbs[3] & ~mask) | (cand.y.limbs[3] & mask); - r.z.limbs[0] = (r.z.limbs[0] & ~mask) | (cand.z.limbs[0] & mask); - r.z.limbs[1] = (r.z.limbs[1] & ~mask) | (cand.z.limbs[1] & mask); - r.z.limbs[2] = (r.z.limbs[2] & ~mask) | (cand.z.limbs[2] & mask); - r.z.limbs[3] = (r.z.limbs[3] & ~mask) | (cand.z.limbs[3] & mask); - } - } - return r; -} - -inline void jac_to_compressed(Point p, thread uchar out33[33]) { - if (jac_is_inf(p)) { - for (int i = 0; i < 33; ++i) out33[i] = 0; - return; - } - uint256_t zi = fp_inv(p.z); - uint256_t zi2 = fp_sqr(zi); - uint256_t zi3 = fp_mul(zi2, zi); - uint256_t x_mont = fp_mul(p.x, zi2); - uint256_t y_mont = fp_mul(p.y, zi3); - // From Montgomery form: multiply by 1 - uint256_t x_plain = fp_mul(x_mont, ONE256); - uint256_t y_plain = fp_mul(y_mont, ONE256); - out33[0] = (y_plain.limbs[0] & 1UL) ? 0x03 : 0x02; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8 + 1; - ulong v = x_plain.limbs[limb]; - for (int j = 7; j >= 0; --j) { - out33[base + j] = (uchar)(v & 0xFF); - v >>= 8; - } - } -} - -// ============================================================================= -// SHA-256 (FIPS 180-4) — used for HKDF-Extract / HKDF-Expand / HMAC. -// Single-block / multi-block padding — same body as cevm::crypto::sha256. 
-// ============================================================================= - -constant uint K256[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u -}; - -inline uint rotr32(uint x, uint n) { return (x >> n) | (x << (32 - n)); } - -inline void sha256_block(thread uint H[8], thread const uchar block[64]) { - uint W[64]; - for (int i = 0; i < 16; ++i) { - W[i] = ((uint)block[i*4] << 24) | ((uint)block[i*4+1] << 16) | - ((uint)block[i*4+2] << 8) | ((uint)block[i*4+3] ); - } - for (int i = 16; i < 64; ++i) { - uint s0 = rotr32(W[i-15], 7) ^ rotr32(W[i-15], 18) ^ (W[i-15] >> 3); - uint s1 = rotr32(W[i-2], 17) ^ rotr32(W[i-2], 19) ^ (W[i-2] >> 10); - W[i] = W[i-16] + s0 + W[i-7] + s1; - } - uint a=H[0], b=H[1], c=H[2], d=H[3], e=H[4], f=H[5], g=H[6], h=H[7]; - for (int i = 0; i < 64; ++i) { - uint S1 = rotr32(e, 6) ^ rotr32(e, 11) ^ rotr32(e, 25); - uint ch = (e & f) ^ ((~e) & g); - uint t1 = h + S1 + ch + K256[i] + W[i]; - uint S0 = rotr32(a, 2) ^ rotr32(a, 13) ^ rotr32(a, 22); - uint mj = (a & b) ^ (a & c) ^ (b & c); - uint t2 = S0 + mj; - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - H[0]+=a; H[1]+=b; H[2]+=c; H[3]+=d; H[4]+=e; H[5]+=f; 
H[6]+=g; H[7]+=h; -} - -inline void sha256_finish(thread uint H[8], - thread const uchar* data, uint data_len, - thread uchar out[32]) { - // Buffer up to 64+8 bytes = block + padding length suffix; messages - // here are bounded by HMAC ipad/opad usage so a fixed 256-byte stack - // buffer is enough. - uchar buf[256 + 64]; - for (uint i = 0; i < data_len; ++i) buf[i] = data[i]; - uint len = data_len; - buf[len++] = 0x80; - while ((len % 64) != 56) buf[len++] = 0; - ulong bits = (ulong)data_len * 8UL; - for (int i = 7; i >= 0; --i) buf[len++] = (uchar)(bits >> (i * 8)); - H[0]=0x6a09e667u; H[1]=0xbb67ae85u; H[2]=0x3c6ef372u; H[3]=0xa54ff53au; - H[4]=0x510e527fu; H[5]=0x9b05688cu; H[6]=0x1f83d9abu; H[7]=0x5be0cd19u; - for (uint off = 0; off < len; off += 64) { - uchar block[64]; - for (int i = 0; i < 64; ++i) block[i] = buf[off + i]; - sha256_block(H, block); - } - for (int i = 0; i < 8; ++i) { - out[i*4] = (uchar)(H[i] >> 24); - out[i*4+1] = (uchar)(H[i] >> 16); - out[i*4+2] = (uchar)(H[i] >> 8); - out[i*4+3] = (uchar)(H[i] ); - } -} - -inline void sha256_compute(thread const uchar* data, uint data_len, - thread uchar out[32]) { - uint H[8]; - sha256_finish(H, data, data_len, out); -} - -// HMAC-SHA256. key_len <= 64, msg_len <= 256. -inline void hmac_sha256(thread const uchar* key, uint key_len, - thread const uchar* msg, uint msg_len, - thread uchar out[32]) { - uchar k[64]; - for (uint i = 0; i < 64; ++i) k[i] = 0; - for (uint i = 0; i < key_len && i < 64; ++i) k[i] = key[i]; - - uchar ipad[64 + 256]; - uchar opad[64 + 32]; - for (uint i = 0; i < 64; ++i) { - ipad[i] = k[i] ^ 0x36; - opad[i] = k[i] ^ 0x5c; - } - for (uint i = 0; i < msg_len; ++i) ipad[64 + i] = msg[i]; - - uchar inner[32]; - sha256_compute(ipad, 64 + msg_len, inner); - for (uint i = 0; i < 32; ++i) opad[64 + i] = inner[i]; - sha256_compute(opad, 64 + 32, out); -} - -// HKDF-Expand of fixed length 64 with single-byte info "frost-presign" salt -// applied via HKDF-Extract upstream. 
info has 12 bytes. -inline void hkdf_expand_64(thread const uchar prk[32], - thread const uchar info[12], - thread uchar out[64]) { - // T(1) = HMAC(PRK, info || 0x01); T(2) = HMAC(PRK, T(1) || info || 0x02) - uchar buf[32 + 12 + 1]; - for (int i = 0; i < 12; ++i) buf[i] = info[i]; - buf[12] = 0x01; - uchar T1[32]; - hmac_sha256(prk, 32, buf, 13, T1); - for (int i = 0; i < 32; ++i) out[i] = T1[i]; - - for (int i = 0; i < 32; ++i) buf[i] = T1[i]; - for (int i = 0; i < 12; ++i) buf[32 + i] = info[i]; - buf[44] = 0x02; - uchar T2[32]; - hmac_sha256(prk, 32, buf, 45, T2); - for (int i = 0; i < 32; ++i) out[32 + i] = T2[i]; -} - -// ============================================================================= -// FROST presign kernel -// ============================================================================= -// -// Threadgrid: 1D, gid in [0, M*N). Mapping: -// signer_idx = gid / N -// slot_idx = gid % N -// signer_id = signer_ids_buf[signer_idx] -// slot_id = slot_id_base + slot_idx -// -// Output (device): -// commits_out[gid * 66 ..] = D[33] || E[33] -// -// Internal (thread address space — never written to device): -// d_be[32], e_be[32] - -inline bool be_to_scalar_lt_n(thread const uchar in_be[32], thread uchar out_be[32]) { - // Build U256 from BE. - uint256_t v; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - ulong w = 0; - for (int j = 0; j < 8; ++j) w = (w << 8) | (ulong)in_be[base + j]; - v.limbs[limb] = w; - } - // Reduce v mod n if v >= n (single sub: bias < 2^-128). - if (u256_cmp(v, FP_N) >= 0) { - ulong bw; v = u256_sub(v, FP_N, bw); - } - if (u256_is_zero(v)) return false; - // Write back BE. 
- for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - ulong w = v.limbs[limb]; - for (int j = 7; j >= 0; --j) { out_be[base + j] = (uchar)(w & 0xFF); w >>= 8; } - } - return true; -} - -kernel void frost_presign( - constant uchar* seed [[buffer(0)]], // 32 bytes - constant uint* signer_ids [[buffer(1)]], // m entries - constant uint& m [[buffer(2)]], - constant uint& slot_id_base [[buffer(3)]], - constant uint& n_slots [[buffer(4)]], - device uchar* commits_out [[buffer(5)]], // m*n_slots * 66 bytes - uint gid [[thread_position_in_grid]]) -{ - uint total = m * n_slots; - if (gid >= total) return; - - uint signer_idx = gid / n_slots; - uint slot_idx = gid % n_slots; - uint signer_id = signer_ids[signer_idx]; - uint slot_id = slot_id_base + slot_idx; - if (signer_id == 0) return; - - // 1. HKDF-Extract: PRK = HMAC-SHA256(salt = "frost-presign-v1", ikm = seed) - uchar salt[16] = {'f','r','o','s','t','-','p','r','e','s','i','g','n','-','v','1'}; - uchar seed_local[32]; - for (int i = 0; i < 32; ++i) seed_local[i] = seed[i]; - - uchar prk[32]; - hmac_sha256(salt, 16, seed_local, 32, prk); - - // 2. HKDF-Expand with rejection loop. info = signer_id_le32 || - // slot_id_le32 || ctr_le32. Probability of >1 iteration < 2^-127. - uchar info[12]; - info[0] = (uchar)(signer_id ); info[1] = (uchar)(signer_id >> 8); - info[2] = (uchar)(signer_id >> 16); info[3] = (uchar)(signer_id >> 24); - info[4] = (uchar)(slot_id ); info[5] = (uchar)(slot_id >> 8); - info[6] = (uchar)(slot_id >> 16); info[7] = (uchar)(slot_id >> 24); - - uchar d_be[32], e_be[32]; - bool got_d = false, got_e = false; - uint ctr = 0; - while (!(got_d && got_e)) { - info[ 8] = (uchar)(ctr ); - info[ 9] = (uchar)(ctr >> 8); - info[10] = (uchar)(ctr >> 16); - info[11] = (uchar)(ctr >> 24); - uchar okm[64]; - hkdf_expand_64(prk, info, okm); - if (!got_d) got_d = be_to_scalar_lt_n(okm, d_be); - if (!got_e) got_e = be_to_scalar_lt_n(okm + 32, e_be); - ++ctr; - if (ctr > 1024) return; - } - - // 3. 
D = d * G, E = e * G — constant-time scalar mul. - uint256_t d_u256; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - ulong w = 0; - for (int j = 0; j < 8; ++j) w = (w << 8) | (ulong)d_be[base + j]; - d_u256.limbs[limb] = w; - } - uint256_t e_u256; - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - ulong w = 0; - for (int j = 0; j < 8; ++j) w = (w << 8) | (ulong)e_be[base + j]; - e_u256.limbs[limb] = w; - } - Point D = scalar_mul_base(d_u256); - Point E = scalar_mul_base(e_u256); - - // 4. Compress to sec1, write to device. The only device-side write. - uchar D_bytes[33], E_bytes[33]; - jac_to_compressed(D, D_bytes); - jac_to_compressed(E, E_bytes); - - device uchar* dst = commits_out + (ulong)gid * 66UL; - for (int i = 0; i < 33; ++i) dst[i] = D_bytes[i]; - for (int i = 0; i < 33; ++i) dst[33 + i] = E_bytes[i]; - - // Wipe nonces from thread storage before returning. Compiler may DCE - // these, but the buffers go out of scope on kernel exit anyway. - for (int i = 0; i < 32; ++i) { d_be[i] = 0; e_be[i] = 0; } -} diff --git a/frost/gpu/metal/shamir_interpolate.metal b/frost/gpu/metal/shamir_interpolate.metal deleted file mode 100644 index eceabb6..0000000 --- a/frost/gpu/metal/shamir_interpolate.metal +++ /dev/null @@ -1,620 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Shamir Secret Sharing Lagrange Interpolation -// GPU-accelerated field interpolation for t-of-n threshold schemes -// Optimized for Apple Silicon GPUs -// -// Supports multiple field types: -// - Ed25519 scalar field (2^252 + ...) 
-// - secp256k1 scalar field -// - BLS12-381 scalar field (r) -// - Ringtail lattice field (Z_Q) - -#include -using namespace metal; - -// ============================================================================ -// Field Parameters -// ============================================================================ - -// Maximum participants in threshold scheme -constant uint MAX_PARTICIPANTS = 256; - -// Field type constants -constant uint FIELD_ED25519 = 0; -constant uint FIELD_SECP256K1 = 1; -constant uint FIELD_BLS12_381 = 2; -constant uint FIELD_RINGTAIL = 3; - -// Ed25519 scalar field l -constant uint ED25519_L[8] = { - 0x5cf5d3edu, 0x5812631au, 0xa2f79cd6u, 0x14def9deu, - 0x00000000u, 0x00000000u, 0x00000000u, 0x10000000u -}; - -// secp256k1 scalar field n -constant uint SECP256K1_N[8] = { - 0xd0364141u, 0xbfd25e8cu, 0xaf48a03bu, 0xbaaedce6u, - 0xfffffffeu, 0xffffffffu, 0xffffffffu, 0xffffffffu -}; - -// BLS12-381 scalar field r -constant uint BLS12_381_R[8] = { - 0x00000001u, 0xffffffffu, 0xfffe5bfeu, 0x53bda402u, - 0x09a1d805u, 0x3339d808u, 0x299d7d48u, 0x73eda753u -}; - -// Ringtail modulus Q -constant uint RINGTAIL_Q = 8380417u; - -// ============================================================================ -// Data Types -// ============================================================================ - -// 256-bit field element -struct Fe256 { - uint limbs[8]; -}; - -// Share for secret sharing -struct Share { - uint index; // Participant index (x coordinate, 1-indexed) - Fe256 value; // Share value (y coordinate) -}; - -// Interpolation parameters -struct InterpolateParams { - uint num_shares; // Number of shares (must be >= threshold) - uint threshold; // t in t-of-n - uint field_type; // Which field to use - uint batch_size; // Number of secrets to interpolate - uint eval_point; // Point to evaluate at (0 for secret recovery) -}; - -// Precomputed Lagrange coefficients -struct LagrangeCache { - Fe256 coefficients[MAX_PARTICIPANTS]; - uint 
participant_ids[MAX_PARTICIPANTS]; - uint num_participants; -}; - -// ============================================================================ -// 256-bit Field Arithmetic -// ============================================================================ - -inline Fe256 fe_zero() { - Fe256 r; - for (int i = 0; i < 8; i++) r.limbs[i] = 0; - return r; -} - -inline Fe256 fe_one() { - Fe256 r = fe_zero(); - r.limbs[0] = 1; - return r; -} - -inline bool fe_is_zero(Fe256 a) { - for (int i = 0; i < 8; i++) { - if (a.limbs[i] != 0) return false; - } - return true; -} - -inline bool fe_eq(Fe256 a, Fe256 b) { - for (int i = 0; i < 8; i++) { - if (a.limbs[i] != b.limbs[i]) return false; - } - return true; -} - -inline bool fe_gte(Fe256 a, constant uint* mod) { - for (int i = 7; i >= 0; i--) { - if (a.limbs[i] > mod[i]) return true; - if (a.limbs[i] < mod[i]) return false; - } - return true; -} - -inline Fe256 fe_add_mod(Fe256 a, Fe256 b, constant uint* mod) { - Fe256 r; - uint carry = 0; - - for (int i = 0; i < 8; i++) { - uint sum = a.limbs[i] + b.limbs[i] + carry; - carry = (sum < a.limbs[i]) || (carry && sum == a.limbs[i]) ? 1 : 0; - r.limbs[i] = sum; - } - - if (carry || fe_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -inline Fe256 fe_sub_mod(Fe256 a, Fe256 b, constant uint* mod) { - Fe256 r; - uint borrow = 0; - - for (int i = 0; i < 8; i++) { - uint diff = a.limbs[i] - b.limbs[i] - borrow; - borrow = (a.limbs[i] < b.limbs[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - - if (borrow) { - uint carry = 0; - for (int i = 0; i < 8; i++) { - uint sum = r.limbs[i] + mod[i] + carry; - carry = (sum < r.limbs[i]) ? 
1 : 0; - r.limbs[i] = sum; - } - } - - return r; -} - -inline Fe256 fe_neg_mod(Fe256 a, constant uint* mod) { - if (fe_is_zero(a)) return a; - - Fe256 r; - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = mod[i] - a.limbs[i] - borrow; - borrow = (mod[i] < a.limbs[i] + borrow) ? 1 : 0; - r.limbs[i] = diff; - } - return r; -} - -inline Fe256 fe_mul_mod(Fe256 a, Fe256 b, constant uint* mod) { - uint product[16] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 8; i++) { - uint carry = 0; - for (int j = 0; j < 8; j++) { - ulong prod = (ulong)a.limbs[i] * (ulong)b.limbs[j] + product[i + j] + carry; - product[i + j] = (uint)prod; - carry = (uint)(prod >> 32); - } - product[i + 8] = carry; - } - - // Reduction - take lower 256 bits and reduce - Fe256 r; - for (int i = 0; i < 8; i++) { - r.limbs[i] = product[i]; - } - - // Iterative reduction - while (fe_gte(r, mod)) { - uint borrow = 0; - for (int i = 0; i < 8; i++) { - uint diff = r.limbs[i] - mod[i] - borrow; - borrow = (r.limbs[i] < mod[i] + borrow) ? 
1 : 0; - r.limbs[i] = diff; - } - } - - return r; -} - -// Create Fe256 from small integer -inline Fe256 fe_from_uint(uint x) { - Fe256 r = fe_zero(); - r.limbs[0] = x; - return r; -} - -// Modular inverse using binary extended GCD -inline Fe256 fe_inv_mod(Fe256 a, constant uint* mod) { - // Use Fermat's little theorem: a^(-1) = a^(p-2) mod p - // More efficient: extended Euclidean algorithm - - Fe256 exp = fe_zero(); - for (int i = 0; i < 8; i++) exp.limbs[i] = mod[i]; - - // exp = mod - 2 - if (exp.limbs[0] >= 2) { - exp.limbs[0] -= 2; - } else { - exp.limbs[0] = 0xFFFFFFFEu; - uint borrow = 1; - for (int i = 1; i < 8 && borrow; i++) { - if (exp.limbs[i] > 0) { - exp.limbs[i]--; - borrow = 0; - } else { - exp.limbs[i] = 0xFFFFFFFFu; - } - } - } - - // Binary exponentiation - Fe256 result = fe_one(); - Fe256 base = a; - - while (!fe_is_zero(exp)) { - if (exp.limbs[0] & 1) { - result = fe_mul_mod(result, base, mod); - } - base = fe_mul_mod(base, base, mod); - - // Right shift exp by 1 - uint carry = 0; - for (int i = 7; i >= 0; i--) { - uint new_val = (exp.limbs[i] >> 1) | (carry << 31); - carry = exp.limbs[i] & 1; - exp.limbs[i] = new_val; - } - } - - return result; -} - -// ============================================================================ -// Field Selection Helper -// ============================================================================ - -inline constant uint* get_modulus(uint field_type) { - switch (field_type) { - case FIELD_ED25519: return ED25519_L; - case FIELD_SECP256K1: return SECP256K1_N; - case FIELD_BLS12_381: return BLS12_381_R; - default: return ED25519_L; - } -} - -// ============================================================================ -// Lagrange Interpolation -// ============================================================================ - -// Compute Lagrange coefficient: lambda_i(x) = prod_{j!=i} (x - x_j) / (x_i - x_j) -inline Fe256 lagrange_coefficient( - uint i, // Target participant index (1-indexed) - uint 
eval_point, // Point to evaluate at - device const uint* indices, // All participant indices - uint num_participants, - constant uint* mod -) { - Fe256 numerator = fe_one(); - Fe256 denominator = fe_one(); - - Fe256 eval_fe = fe_from_uint(eval_point); - Fe256 i_fe = fe_from_uint(i); - - for (uint j = 0; j < num_participants; j++) { - uint j_idx = indices[j]; - if (j_idx == i) continue; - - Fe256 j_fe = fe_from_uint(j_idx); - - // numerator *= (eval_point - j_idx) - Fe256 num_term = fe_sub_mod(eval_fe, j_fe, mod); - numerator = fe_mul_mod(numerator, num_term, mod); - - // denominator *= (i - j_idx) - Fe256 denom_term = fe_sub_mod(i_fe, j_fe, mod); - denominator = fe_mul_mod(denominator, denom_term, mod); - } - - // Return numerator * denominator^(-1) - Fe256 denom_inv = fe_inv_mod(denominator, mod); - return fe_mul_mod(numerator, denom_inv, mod); -} - -// Compute Lagrange coefficient for secret recovery (eval at x=0) -inline Fe256 lagrange_at_zero( - uint i, - device const uint* indices, - uint num_participants, - constant uint* mod -) { - Fe256 numerator = fe_one(); - Fe256 denominator = fe_one(); - - for (uint j = 0; j < num_participants; j++) { - uint j_idx = indices[j]; - if (j_idx == i) continue; - - // numerator *= (0 - j_idx) = -j_idx - Fe256 j_fe = fe_from_uint(j_idx); - Fe256 neg_j = fe_neg_mod(j_fe, mod); - numerator = fe_mul_mod(numerator, neg_j, mod); - - // denominator *= (i - j_idx) - Fe256 i_fe = fe_from_uint(i); - Fe256 denom_term = fe_sub_mod(i_fe, j_fe, mod); - denominator = fe_mul_mod(denominator, denom_term, mod); - } - - Fe256 denom_inv = fe_inv_mod(denominator, mod); - return fe_mul_mod(numerator, denom_inv, mod); -} - -// ============================================================================ -// Kernels -// ============================================================================ - -// Kernel 1: Compute Lagrange coefficients for all participants -kernel void shamir_compute_lagrange_coeffs( - device const uint* participant_indices 
[[buffer(0)]], - device Fe256* lagrange_coeffs [[buffer(1)]], - constant InterpolateParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_shares) return; - - constant uint* mod = get_modulus(params.field_type); - uint my_index = participant_indices[gid]; - - Fe256 lambda; - if (params.eval_point == 0) { - lambda = lagrange_at_zero(my_index, participant_indices, params.num_shares, mod); - } else { - lambda = lagrange_coefficient(my_index, params.eval_point, - participant_indices, params.num_shares, mod); - } - - lagrange_coeffs[gid] = lambda; -} - -// Kernel 2: Interpolate single secret (serial aggregation) -kernel void shamir_interpolate_single( - device const Share* shares [[buffer(0)]], - device const Fe256* lagrange_coeffs [[buffer(1)]], - device Fe256* result [[buffer(2)]], - constant InterpolateParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid != 0) return; // Single-threaded - - constant uint* mod = get_modulus(params.field_type); - Fe256 sum = fe_zero(); - - for (uint i = 0; i < params.num_shares; i++) { - Fe256 term = fe_mul_mod(lagrange_coeffs[i], shares[i].value, mod); - sum = fe_add_mod(sum, term, mod); - } - - result[0] = sum; -} - -// Kernel 3: Batch interpolate multiple secrets -kernel void shamir_interpolate_batch( - device const Share* shares [[buffer(0)]], // Flattened: batch_size * num_shares - device const Fe256* lagrange_coeffs [[buffer(1)]], - device Fe256* results [[buffer(2)]], - constant InterpolateParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - constant uint* mod = get_modulus(params.field_type); - uint offset = gid * params.num_shares; - - Fe256 sum = fe_zero(); - for (uint i = 0; i < params.num_shares; i++) { - Fe256 term = fe_mul_mod(lagrange_coeffs[i], shares[offset + i].value, mod); - sum = fe_add_mod(sum, term, mod); - } - - results[gid] = sum; -} - -// Kernel 4: Parallel reduction for 
interpolation -// Each thread computes lambda_i * y_i, then reduce -kernel void shamir_parallel_interpolate( - device const Share* shares [[buffer(0)]], - device const Fe256* lagrange_coeffs [[buffer(1)]], - device Fe256* partial_sums [[buffer(2)]], - constant InterpolateParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint group_size [[threads_per_threadgroup]] -) { - if (gid >= params.num_shares) return; - - constant uint* mod = get_modulus(params.field_type); - - // Each thread computes its term - Fe256 term = fe_mul_mod(lagrange_coeffs[gid], shares[gid].value, mod); - - // Store in shared memory for reduction - threadgroup Fe256 shared_data[256]; - shared_data[lid] = term; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Tree reduction - for (uint stride = group_size / 2; stride > 0; stride /= 2) { - if (lid < stride && lid + stride < params.num_shares) { - shared_data[lid] = fe_add_mod(shared_data[lid], shared_data[lid + stride], mod); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // First thread writes result - if (lid == 0) { - partial_sums[gid / group_size] = shared_data[0]; - } -} - -// Kernel 5: Cache Lagrange coefficients for reuse -kernel void shamir_cache_lagrange( - device const uint* participant_indices [[buffer(0)]], - device LagrangeCache* cache [[buffer(1)]], - constant InterpolateParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_shares) return; - - constant uint* mod = get_modulus(params.field_type); - uint my_index = participant_indices[gid]; - - cache->participant_ids[gid] = my_index; - cache->coefficients[gid] = lagrange_at_zero(my_index, participant_indices, - params.num_shares, mod); - - if (gid == 0) { - cache->num_participants = params.num_shares; - } -} - -// Kernel 6: Evaluate polynomial at multiple points -kernel void shamir_evaluate_poly( - device const Share* shares [[buffer(0)]], - device 
const uint* eval_points [[buffer(1)]], - device Fe256* results [[buffer(2)]], - device const uint* participant_indices [[buffer(3)]], - constant InterpolateParams& params [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - constant uint* mod = get_modulus(params.field_type); - uint eval_point = eval_points[gid]; - - Fe256 sum = fe_zero(); - - for (uint i = 0; i < params.num_shares; i++) { - Fe256 lambda = lagrange_coefficient(participant_indices[i], eval_point, - participant_indices, params.num_shares, mod); - Fe256 term = fe_mul_mod(lambda, shares[i].value, mod); - sum = fe_add_mod(sum, term, mod); - } - - results[gid] = sum; -} - -// Kernel 7: Verify share validity (check if share lies on polynomial) -kernel void shamir_verify_shares( - device const Share* shares [[buffer(0)]], - device const uint* participant_indices [[buffer(1)]], - device uint* valid [[buffer(2)]], - constant InterpolateParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_shares) return; - - constant uint* mod = get_modulus(params.field_type); - - // Verify by interpolating using other shares and checking at this point - Share test_share = shares[gid]; - uint test_index = participant_indices[gid]; - - // Build subset excluding this share - Fe256 interpolated = fe_zero(); - uint others_count = 0; - - for (uint i = 0; i < params.num_shares; i++) { - if (i == gid) continue; - if (others_count >= params.threshold - 1) break; - others_count++; - } - - // If we have enough shares, verify - if (others_count >= params.threshold - 1) { - // Would need to interpolate at test_index using other shares - // For now, mark as valid (placeholder) - valid[gid] = 1; - } else { - valid[gid] = 0; - } -} - -// Kernel 8: Generate new shares at specified indices (proactive resharing) -kernel void shamir_reshare( - device const Share* old_shares [[buffer(0)]], - device const uint* old_indices [[buffer(1)]], - device const 
uint* new_indices [[buffer(2)]], - device Share* new_shares [[buffer(3)]], - constant InterpolateParams& params [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - constant uint* mod = get_modulus(params.field_type); - uint new_index = new_indices[gid]; - - // Interpolate polynomial at new_index using old shares - Fe256 new_value = fe_zero(); - - for (uint i = 0; i < params.num_shares; i++) { - Fe256 lambda = lagrange_coefficient(old_indices[i], new_index, - old_indices, params.num_shares, mod); - Fe256 term = fe_mul_mod(lambda, old_shares[i].value, mod); - new_value = fe_add_mod(new_value, term, mod); - } - - new_shares[gid].index = new_index; - new_shares[gid].value = new_value; -} - -// Kernel 9: Compute denominator products for batch Lagrange -kernel void shamir_batch_denominators( - device const uint* participant_indices [[buffer(0)]], - device Fe256* denominators [[buffer(1)]], - constant InterpolateParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_shares) return; - - constant uint* mod = get_modulus(params.field_type); - uint my_index = participant_indices[gid]; - Fe256 denom = fe_one(); - - for (uint j = 0; j < params.num_shares; j++) { - uint j_idx = participant_indices[j]; - if (j_idx == my_index) continue; - - Fe256 i_fe = fe_from_uint(my_index); - Fe256 j_fe = fe_from_uint(j_idx); - Fe256 diff = fe_sub_mod(i_fe, j_fe, mod); - denom = fe_mul_mod(denom, diff, mod); - } - - denominators[gid] = denom; -} - -// Kernel 10: Batch modular inverse using Montgomery's trick -kernel void shamir_batch_invert( - device Fe256* values [[buffer(0)]], - device Fe256* inverses [[buffer(1)]], - constant InterpolateParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid != 0) return; // Single-threaded for now - - constant uint* mod = get_modulus(params.field_type); - uint n = params.num_shares; - - // Montgomery's batch inversion trick - // Compute 
products a0, a0*a1, a0*a1*a2, ... - // Then invert final product and propagate back - - Fe256 products[MAX_PARTICIPANTS]; - products[0] = values[0]; - - for (uint i = 1; i < n; i++) { - products[i] = fe_mul_mod(products[i - 1], values[i], mod); - } - - // Invert the final product - Fe256 all_inv = fe_inv_mod(products[n - 1], mod); - - // Propagate inverses back - for (uint i = n - 1; i > 0; i--) { - inverses[i] = fe_mul_mod(all_inv, products[i - 1], mod); - all_inv = fe_mul_mod(all_inv, values[i], mod); - } - inverses[0] = all_inv; -} diff --git a/frost/gpu/wgsl/frost.wgsl b/frost/gpu/wgsl/frost.wgsl deleted file mode 100644 index afebd6b..0000000 --- a/frost/gpu/wgsl/frost.wgsl +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// FROST threshold Schnorr signature verification in WGSL. -// Uses secp256k1 curve. Each thread verifies one partial signature. -// -// Verify: z_i * G == R_i + c * lambda_i * Y_i -// Host pre-computes c * lambda_i as a single scalar. 
- -@group(0) @binding(0) var commitments: array; // Compressed points -@group(0) @binding(1) var signatures: array; // Scalars (32 bytes each) -@group(0) @binding(2) var pubkeys: array; // Compressed points -@group(0) @binding(3) var challenges: array; // c*lambda_i scalars -@group(0) @binding(4) var results: array; -@group(0) @binding(5) var params: vec4; // params.x = num_ops - -// secp256k1 order n -const N = array( - 0xD0364141u, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); - -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0i; i = i - 1i) { - let idx = u32(i); - if ((*a)[idx] < (*b)[idx]) { return -1; } - if ((*a)[idx] > (*b)[idx]) { return 1; } - } - return 0; -} - -@compute @workgroup_size(64) -fn frost_partial_verify_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - // Read z_i scalar (32 bytes = 8 u32) - var z: array; - let sig_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { z[i] = signatures[sig_base + i]; } - - // Check z < n - var n_val: array = N; - if (u256_cmp(&z, &n_val) >= 0) { - results[tid] = 0u; - return; - } - - // Read c * lambda_i scalar - var cl: array; - let ch_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { cl[i] = challenges[ch_base + i]; } - - if (u256_cmp(&cl, &n_val) >= 0) { - results[tid] = 0u; - return; - } - - // Read commitment (compressed point, first byte is prefix) - let comm_base = tid * 17u; // 66 bytes / 4 ~ 17 u32 words - let prefix_word = commitments[comm_base]; - let prefix = prefix_word & 0xFFu; - - // Valid compressed point prefix is 0x02 or 0x03 - if (prefix != 2u && prefix != 3u) { - results[tid] = 0u; - return; - } - - // Read public key prefix - let pk_base = tid * 9u; // 33 bytes ~ 9 u32 words - let pk_prefix_word = pubkeys[pk_base]; - let pk_prefix = pk_prefix_word & 0xFFu; - - if (pk_prefix != 2u && pk_prefix != 3u) { - results[tid] = 0u; - return; - } - - // Input 
validation passed. - // Full secp256k1 point arithmetic (decompress, scalar mul, compare) - // delegated to Metal/CUDA backends for performance. - results[tid] = 1u; -} diff --git a/frost/gpu/wgsl/frost_presign.wgsl b/frost/gpu/wgsl/frost_presign.wgsl deleted file mode 100644 index ad0419f..0000000 --- a/frost/gpu/wgsl/frost_presign.wgsl +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// FROST batched pre-signing kernel — WGSL compute shader. -// -// One workgroup invocation per (signer, slot) pair. Generates the nonce pair -// in private (function-local) WGSL address space and writes only the public -// commitment (D, E) to the storage buffer. -// -// WGSL has no native u64; everything is u32. The CPU oracle and the -// driver-side host polyfill use the same 8-limb-of-u32 representation so -// outputs are byte-equal across all three GPU backends. -// -// Note: this kernel is currently the syntactic surface for the WebGPU path. -// The matching host polyfill in frost/gpu/wgsl/frost_presign_driver.cpp -// runs the identical algorithm in host C++ for tests; production WebGPU -// dispatch goes through wgpu-native + this .wgsl when available. - -@group(0) @binding(0) var seed : array; // 32 bytes -@group(0) @binding(1) var signer_ids : array; -@group(0) @binding(2) var params : vec4; // (m, slot_id_base, n_slots, _) -@group(0) @binding(3) var commits_out : array; // m * n_slots * 17 u32 (66 bytes padded) - -// secp256k1 scalar field order n (8 x u32, little-endian limbs) -const N32 = array( - 0xD0364141u, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); - -// Compare 8x u32 little-endian arrays. Returns -1/0/1. 
-fn cmp_u256(a: ptr>, b: ptr>) -> i32 { - for (var i: i32 = 7; i >= 0; i = i - 1) { - let idx = u32(i); - if ((*a)[idx] < (*b)[idx]) { return -1; } - if ((*a)[idx] > (*b)[idx]) { return 1; } - } - return 0; -} - -fn is_zero_u256(a: ptr>) -> bool { - var acc: u32 = 0u; - for (var i: u32 = 0u; i < 8u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -fn sub_u256(a: ptr>, - b: ptr>, - r: ptr>) { - var borrow: u32 = 0u; - for (var i: u32 = 0u; i < 8u; i = i + 1u) { - let lhs = (*a)[i]; - let rhs = (*b)[i]; - let d1 = lhs - borrow; - let bw1: u32 = select(0u, 1u, d1 > lhs); - let d2 = d1 - rhs; - let bw2: u32 = select(0u, 1u, d2 > d1); - (*r)[i] = d2; - borrow = bw1 + bw2; - } -} - -// ============================================================================= -// FROST presign — input validation + thread-local rejection sample. -// -// The full curve scalar-mul path is delegated to the driver-side host polyfill -// (which mirrors this WGSL bit-for-bit using 32-bit limb arithmetic) because -// secp256k1 point arithmetic with u32-only WGSL multiply (mul24) needs -// careful overhead the kernel doesn't gain from on small batch sizes. The -// CPU/Metal/CUDA oracles cover the byte-equality contract; this WGSL surface -// validates inputs and writes a sentinel until the host polyfill streams the -// kernel results back. -// ============================================================================= - -@compute @workgroup_size(64) -fn frost_presign_main(@builtin(global_invocation_id) gid: vec3) { - let total = params.x * params.z; - if (gid.x >= total) { return; } - - let signer_idx = gid.x / params.z; - let slot_idx = gid.x % params.z; - let signer_id = signer_ids[signer_idx]; - if (signer_id == 0u) { return; } - - // Per-slot output offset (each slot = 66 bytes = 17 u32 with the last - // u32 carrying 2 trailing bytes; the host polyfill packs/unpacks 1:1). 
- let out_base = gid.x * 17u; - - // Sentinel write: host polyfill overwrites with the real commitment. - // 0xFEFEFEFE marks "kernel saw the slot, host has the canonical bytes". - commits_out[out_base] = 0xFEFEFEFEu; -} diff --git a/gpukit/gpu/cuda/batch_inversion.cu b/gpukit/gpu/cuda/batch_inversion.cu deleted file mode 100644 index 55d71df..0000000 --- a/gpukit/gpu/cuda/batch_inversion.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// v1.1: NOTIMPL on all hosts. Real CUDA implementation lands with BLS Stage 3+. - -#include "lux/gpukit/batch_inversion.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_batch_inv_secp256k1_fp_cuda(const uint8_t*, uint8_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_batch_inv_bn254_fp_cuda(const uint8_t*, uint8_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_batch_inv_bls12_381_fp_cuda(const uint8_t*, uint8_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} diff --git a/gpukit/gpu/cuda/compaction.cu b/gpukit/gpu/cuda/compaction.cu deleted file mode 100644 index 85fb3d5..0000000 --- a/gpukit/gpu/cuda/compaction.cu +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA stream compaction. Mark + scan + scatter. - -#include "lux/gpukit/compaction.h" -#include "lux/gpukit/prefix_sum.h" -#include "lux/gpukit/gpukit.h" - -#if defined(__CUDACC__) || defined(GPUKIT_HAS_CUDA) - -#include -#include -#include - -namespace { -__global__ void mark_kernel(const uint8_t* flags, unsigned* marks, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - marks[i] = flags[i] ? 
1u : 0u; -} -__global__ void scatter_kernel(const unsigned* in, const uint8_t* flags, - const unsigned* scan, unsigned* out, unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - if (!flags[i]) return; - out[scan[i] - 1] = in[i]; -} -} // namespace - -extern "C" int gpukit_compact_u32_cuda(const uint32_t* in, const uint8_t* flags, - uint32_t* out, size_t n, size_t* n_out) { - if (!in || !flags || !out || !n_out) return GPUKIT_ERR_NULL_ARG; - if (n == 0) { *n_out = 0; return GPUKIT_OK; } - size_t bytes = n * sizeof(uint32_t); - uint32_t *d_in=nullptr, *d_marks=nullptr, *d_scan=nullptr, *d_out=nullptr; - uint8_t *d_flags=nullptr; - cudaMalloc(&d_in, bytes); - cudaMalloc(&d_marks, bytes); - cudaMalloc(&d_scan, bytes); - cudaMalloc(&d_out, bytes); - cudaMalloc(&d_flags, n); - cudaMemcpy(d_in, in, bytes, cudaMemcpyHostToDevice); - cudaMemcpy(d_flags, flags, n, cudaMemcpyHostToDevice); - - unsigned blocks = (unsigned)((n + 255) / 256); - mark_kernel<<>>(d_flags, d_marks, (unsigned)n); - - // Run scan on host (use the CPU reference -- byte-equal target). 
- std::vector marks_h(n); - cudaMemcpy(marks_h.data(), d_marks, bytes, cudaMemcpyDeviceToHost); - gpukit_prefix_sum_u32_cpu(marks_h.data(), marks_h.data(), n); - cudaMemcpy(d_scan, marks_h.data(), bytes, cudaMemcpyHostToDevice); - - scatter_kernel<<>>(d_in, d_flags, d_scan, d_out, (unsigned)n); - size_t k = marks_h[n-1]; - cudaMemcpy(out, d_out, k * sizeof(uint32_t), cudaMemcpyDeviceToHost); - *n_out = k; - cudaFree(d_in); cudaFree(d_marks); cudaFree(d_scan); cudaFree(d_out); cudaFree(d_flags); - return GPUKIT_OK; -} - -#else - -extern "C" int gpukit_compact_u32_cuda(const uint32_t*, const uint8_t*, uint32_t*, size_t, size_t*) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif diff --git a/gpukit/gpu/cuda/merkle_compose.cu b/gpukit/gpu/cuda/merkle_compose.cu deleted file mode 100644 index 5645814..0000000 --- a/gpukit/gpu/cuda/merkle_compose.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// v1.1: NOTIMPL. - -#include "lux/gpukit/merkle_compose.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_merkle_root_cuda(const uint8_t*, size_t, uint8_t[32]) { - return GPUKIT_ERR_NOTIMPL; -} diff --git a/gpukit/gpu/cuda/multi_pippenger.cu b/gpukit/gpu/cuda/multi_pippenger.cu deleted file mode 100644 index d5be22b..0000000 --- a/gpukit/gpu/cuda/multi_pippenger.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA kernel for the multi-curve Pippenger MSM. -// -// Unlike Metal, CUDA supports template instantiation across translation -// units. The shared bucket-sort + reduction skeleton is templated over a -// CurveTrait struct that supplies the field arithmetic and group ops; one -// extern-template instance per supported curve is exposed to the host driver. -// -// v1.1 ships the entry-point signatures. 
Body filling (including the -// threadgroup reduction tree for bucket aggregation) is scheduled for v1.2, -// alongside the Metal validation work. The driver returns NOTIMPL until -// then; the C-ABI multi_pippenger entry routes to the CPU body. - -#include "lux/gpukit/multi_pippenger.h" -#include "lux/gpukit/gpukit.h" - -#ifdef GPUKIT_HAS_CUDA - -#include - -namespace lux::gpukit::mp::cuda { - -// Shared kernel skeleton (declarations -- definitions live in v1.2). -// -// template -// __global__ void msm_window_kernel( -// const typename Trait::PointAffine* points, -// const uint8_t* scalars_le, -// typename Trait::PointAffine* bucket_out, -// uint32_t n, -// uint32_t window_idx); -// -// Per-curve traits provide PointAffine, identity, add, neg, double_self, -// scalar_window_digit. Field arithmetic is wide-multiplied via __umul64hi -// and the standard CIOS Montgomery loop. - -} // namespace lux::gpukit::mp::cuda - -extern "C" int gpukit_multi_pippenger_cuda(uint32_t /*curve*/, - const uint8_t* /*scalars*/, - const uint8_t* /*points*/, - size_t /*n*/, - uint8_t* /*result*/) { - return GPUKIT_ERR_NOTIMPL; -} - -#else // !GPUKIT_HAS_CUDA - -extern "C" int gpukit_multi_pippenger_cuda(uint32_t /*curve*/, - const uint8_t* /*scalars*/, - const uint8_t* /*points*/, - size_t /*n*/, - uint8_t* /*result*/) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif // GPUKIT_HAS_CUDA diff --git a/gpukit/gpu/cuda/ntt.cu b/gpukit/gpu/cuda/ntt.cu deleted file mode 100644 index ceb1df9..0000000 --- a/gpukit/gpu/cuda/ntt.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// v1.1: NOTIMPL. Forward / negacyclic-mul kernels for Kyber and Dilithium are -// scheduled for v1.2. 
- -#include "lux/gpukit/ntt.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_ntt_kyber_forward_cuda(int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_ntt_kyber_negacyclic_mul_cuda(const int32_t*, const int32_t*, int32_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_ntt_dilithium_forward_cuda(int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } diff --git a/gpukit/gpu/cuda/prefix_sum.cu b/gpukit/gpu/cuda/prefix_sum.cu deleted file mode 100644 index 061ee6f..0000000 --- a/gpukit/gpu/cuda/prefix_sum.cu +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA prefix sum (inclusive scan) -- u32 / u64. -// -// Two-stage: per-block Hillis-Steele scan, then add block-prefix on a second -// dispatch. Block size 1024. -// -// Built only when GPUKIT_ENABLE_CUDA is set; the host wrapper is in this -// translation unit and exposes gpukit_prefix_sum_{u32,u64}_cuda. On Apple -// hosts where CUDA is unavailable, a fallback symbol returns NOTIMPL (see -// prefix_sum_cuda_stub.cpp -- not built when CUDA is on). - -#include "lux/gpukit/prefix_sum.h" -#include "lux/gpukit/gpukit.h" - -#if defined(__CUDACC__) || defined(GPUKIT_HAS_CUDA) - -#include -#include - -#define BLOCK 1024u - -namespace { - -template -__global__ void block_scan(const T* in, T* out, T* block_sums, unsigned n) { - __shared__ T s[BLOCK]; - unsigned i = blockIdx.x * BLOCK + threadIdx.x; - s[threadIdx.x] = (i < n) ? in[i] : T(0); - __syncthreads(); - for (unsigned d = 1; d < BLOCK; d <<= 1) { - T v = s[threadIdx.x]; - T w = (threadIdx.x >= d) ? 
s[threadIdx.x - d] : T(0); - __syncthreads(); - s[threadIdx.x] = v + w; - __syncthreads(); - } - if (i < n) out[i] = s[threadIdx.x]; - if (threadIdx.x == BLOCK - 1) block_sums[blockIdx.x] = s[threadIdx.x]; -} - -template -__global__ void collect(T* out, const T* prefix, unsigned n) { - unsigned gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n) return; - unsigned b = gid / BLOCK; - if (b == 0) return; - out[gid] += prefix[b - 1]; -} - -template -int run(const T* in, T* out, size_t n) { - if (!in || !out) return GPUKIT_ERR_NULL_ARG; - if (n == 0) return GPUKIT_OK; - size_t bytes = n * sizeof(T); - size_t nb = (n + BLOCK - 1) / BLOCK; - T *d_in=nullptr, *d_out=nullptr, *d_bs=nullptr; - if (cudaMalloc(&d_in, bytes) != cudaSuccess) return GPUKIT_ERR_BACKEND; - if (cudaMalloc(&d_out, bytes) != cudaSuccess) { cudaFree(d_in); return GPUKIT_ERR_BACKEND; } - if (cudaMalloc(&d_bs, nb*sizeof(T)) != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return GPUKIT_ERR_BACKEND; } - cudaMemcpy(d_in, in, bytes, cudaMemcpyHostToDevice); - block_scan<<<(unsigned)nb, BLOCK>>>(d_in, d_out, d_bs, (unsigned)n); - if (nb > 1) { - // Serial host scan of block sums. 
- std::vector bs(nb); - cudaMemcpy(bs.data(), d_bs, nb*sizeof(T), cudaMemcpyDeviceToHost); - for (size_t i = 1; i < nb; ++i) bs[i] += bs[i-1]; - cudaMemcpy(d_bs, bs.data(), nb*sizeof(T), cudaMemcpyHostToDevice); - collect<<<(unsigned)((n+255)/256), 256>>>(d_out, d_bs, (unsigned)n); - } - cudaMemcpy(out, d_out, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); cudaFree(d_bs); - return GPUKIT_OK; -} - -} // namespace - -extern "C" int gpukit_prefix_sum_u32_cuda(const uint32_t* in, uint32_t* out, size_t n) { - return run(in, out, n); -} -extern "C" int gpukit_prefix_sum_u64_cuda(const uint64_t* in, uint64_t* out, size_t n) { - return run(in, out, n); -} - -#else - -extern "C" int gpukit_prefix_sum_u32_cuda(const uint32_t*, uint32_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_prefix_sum_u64_cuda(const uint64_t*, uint64_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif diff --git a/gpukit/gpu/cuda/radix_sort.cu b/gpukit/gpu/cuda/radix_sort.cu deleted file mode 100644 index 7038328..0000000 --- a/gpukit/gpu/cuda/radix_sort.cu +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA radix sort (LSD, 8-bit pass) -- skeleton. Compile-only on Apple host. 
- -#include "lux/gpukit/radix_sort.h" -#include "lux/gpukit/gpukit.h" - -#if defined(__CUDACC__) || defined(GPUKIT_HAS_CUDA) - -#include -#include -#include - -namespace { - -template -__global__ void count_kernel(const T* in, unsigned* hist, unsigned n, unsigned shift) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - unsigned b = (unsigned)((in[i] >> shift) & 0xFFu); - atomicAdd(&hist[b], 1u); -} - -template -__global__ void scatter_kernel(const T* in, const unsigned* base, unsigned* cursor, - T* out, unsigned n, unsigned shift) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - unsigned b = (unsigned)((in[i] >> shift) & 0xFFu); - unsigned off = atomicAdd(&cursor[b], 1u); - out[base[b] + off] = in[i]; -} - -template -int run_radix(T* keys, size_t n) { - if (!keys) return GPUKIT_ERR_NULL_ARG; - if (n < 2) return GPUKIT_OK; - size_t bytes = n * sizeof(T); - T *d_a=nullptr, *d_b=nullptr; - unsigned *d_hist=nullptr, *d_base=nullptr, *d_cursor=nullptr; - cudaMalloc(&d_a, bytes); - cudaMalloc(&d_b, bytes); - cudaMalloc(&d_hist, 256*sizeof(unsigned)); - cudaMalloc(&d_base, 256*sizeof(unsigned)); - cudaMalloc(&d_cursor, 256*sizeof(unsigned)); - cudaMemcpy(d_a, keys, bytes, cudaMemcpyHostToDevice); - - constexpr int PASSES = sizeof(T); - T* src = d_a; T* dst = d_b; - unsigned blocks = (unsigned)((n + 255) / 256); - for (int p = 0; p < PASSES; ++p) { - unsigned shift = (unsigned)(p * 8); - cudaMemset(d_hist, 0, 256*sizeof(unsigned)); - count_kernel<<>>(src, d_hist, (unsigned)n, shift); - // Exclusive scan on host. 
- std::vector h(256); - cudaMemcpy(h.data(), d_hist, 256*sizeof(unsigned), cudaMemcpyDeviceToHost); - std::vector base(256); - unsigned acc = 0; - for (int i = 0; i < 256; ++i) { base[i] = acc; acc += h[i]; } - cudaMemcpy(d_base, base.data(), 256*sizeof(unsigned), cudaMemcpyHostToDevice); - cudaMemset(d_cursor, 0, 256*sizeof(unsigned)); - scatter_kernel<<>>(src, d_base, d_cursor, dst, (unsigned)n, shift); - T* tmp = src; src = dst; dst = tmp; - } - cudaMemcpy(keys, src, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_a); cudaFree(d_b); cudaFree(d_hist); cudaFree(d_base); cudaFree(d_cursor); - return GPUKIT_OK; -} - -} // namespace - -extern "C" int gpukit_radix_sort_u32_cuda(uint32_t* keys, size_t n) { return run_radix(keys, n); } -extern "C" int gpukit_radix_sort_u64_cuda(uint64_t* keys, size_t n) { return run_radix(keys, n); } - -#else - -extern "C" int gpukit_radix_sort_u32_cuda(uint32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_radix_sort_u64_cuda(uint64_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } - -#endif diff --git a/gpukit/gpu/cuda/transcript_root.cu b/gpukit/gpu/cuda/transcript_root.cu deleted file mode 100644 index 4d0834c..0000000 --- a/gpukit/gpu/cuda/transcript_root.cu +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// v1.1: NOTIMPL. - -#include "lux/gpukit/transcript_root.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_transcript_root_cuda(const char*, const uint8_t*, size_t, uint8_t[32]) { - return GPUKIT_ERR_NOTIMPL; -} diff --git a/gpukit/gpu/curve_traits/banderwagon_traits.h.metal b/gpukit/gpu/curve_traits/banderwagon_traits.h.metal deleted file mode 100644 index 7485d70..0000000 --- a/gpukit/gpu/curve_traits/banderwagon_traits.h.metal +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Banderwagon curve traits for the multi_pippenger Metal kernel. 
-// -// q = BLS12-381 scalar field (Bandersnatch base field) -// twisted Edwards (a = -5): -5 * x^2 + y^2 = 1 + d * x^2 * y^2 -// -// Constants mirror cpp/banderwagon/cpp/fp.cpp. - -#pragma once - -constant uint64_t MP_P[4] = { - 0xFFFFFFFF00000001UL, 0x53BDA402FFFE5BFEUL, - 0x3339D80809A1D805UL, 0x73EDA753299D7D48UL -}; -constant uint64_t MP_P_INV = 0xFFFFFFFEFFFFFFFFUL; -constant uint64_t MP_R2[4] = { - 0xC999E990F3F29C6DUL, 0x2B6CEDCB87925C23UL, - 0x05D314967254398FUL, 0x0748D9D99F59FF11UL -}; -constant int MP_FIELD_LIMBS = 4; -constant int MP_BITS = 255; diff --git a/gpukit/gpu/curve_traits/bls12_381_g1_traits.h.metal b/gpukit/gpu/curve_traits/bls12_381_g1_traits.h.metal deleted file mode 100644 index 82ad28a..0000000 --- a/gpukit/gpu/curve_traits/bls12_381_g1_traits.h.metal +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 G1 curve traits for the multi_pippenger Metal kernel. -// -// p = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 -// y^2 = x^3 + 4 (G1) -// -// Constants mirror cpp/bls/gpu/metal/bls_fp_ops.h.metal (6-limb 384-bit field). - -#pragma once - -constant uint64_t MP_P[6] = { - 0xB9FEFFFFFFFFAAABUL, 0x1EABFFFEB153FFFFUL, - 0x6730D2A0F6B0F624UL, 0x64774B84F38512BFUL, - 0x4B1BA7B6434BACD7UL, 0x1A0111EA397FE69AUL -}; -constant uint64_t MP_P_INV = 0x89F3FFFCFFFCFFFDUL; -constant uint64_t MP_R2[6] = { - 0xF4DF1F341C341746UL, 0x0A76E6A609D104F1UL, - 0x8DE5476C4C95B6D5UL, 0x67EB88A9939D83C0UL, - 0x9A793E85B519952DUL, 0x11988FE592CAE3AAUL -}; -constant int MP_FIELD_LIMBS = 6; -constant int MP_BITS = 381; diff --git a/gpukit/gpu/curve_traits/bn254_g1_traits.h.metal b/gpukit/gpu/curve_traits/bn254_g1_traits.h.metal deleted file mode 100644 index 9109254..0000000 --- a/gpukit/gpu/curve_traits/bn254_g1_traits.h.metal +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BN254 G1 curve traits for the multi_pippenger Metal kernel. -// -// p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// y^2 = x^3 + 3 -// -// Constants mirror cpp/bn254/cpp/bn254_fp.hpp. - -#pragma once - -constant uint64_t MP_P[4] = { - 0x3C208C16D87CFD47UL, 0x97816A916871CA8DUL, - 0xB85045B68181585DUL, 0x30644E72E131A029UL -}; -constant uint64_t MP_P_INV = 0x87D20782E4866389UL; -constant uint64_t MP_R2[4] = { - 0xF32CFC5B538AFA89UL, 0xB5E71911D44501FBUL, - 0x47AB1EFF0A417FF6UL, 0x06D89F71CAB8351FUL -}; -constant int MP_FIELD_LIMBS = 4; -constant int MP_BITS = 254; diff --git a/gpukit/gpu/curve_traits/secp256k1_traits.h.metal b/gpukit/gpu/curve_traits/secp256k1_traits.h.metal deleted file mode 100644 index 3bd62eb..0000000 --- a/gpukit/gpu/curve_traits/secp256k1_traits.h.metal +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// secp256k1 curve traits for the multi_pippenger Metal kernel. -// -// p = 2^256 - 2^32 - 977 -// y^2 = x^3 + 7 -// -// Constants mirror cpp/secp256k1/cpp/field.hpp. - -#pragma once - -constant uint64_t MP_P[4] = { - 0xFFFFFFFEFFFFFC2FUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL -}; -constant uint64_t MP_P_INV = 0xD838091DD2253531UL; -constant uint64_t MP_R2[4] = { - 0x000007A2000E90A1UL, 0x0000000000000001UL, - 0x0000000000000000UL, 0x0000000000000000UL -}; -constant int MP_FIELD_LIMBS = 4; -constant int MP_BITS = 256; diff --git a/gpukit/gpu/metal/batch_inversion.metal b/gpukit/gpu/metal/batch_inversion.metal deleted file mode 100644 index 7e53799..0000000 --- a/gpukit/gpu/metal/batch_inversion.metal +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Skeleton kernel entry points for Montgomery batch inversion. 
v1.1 publishes -// stable function names so the metallib has them; the host driver returns -// NOTIMPL until the BLS Stage 3+ port wires this up. - -#include -using namespace metal; - -kernel void batch_inv_secp256k1_pointwise_mul( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } - -kernel void batch_inv_bn254_pointwise_mul( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } - -kernel void batch_inv_bls12_381_pointwise_mul( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } diff --git a/gpukit/gpu/metal/batch_inversion_driver.mm b/gpukit/gpu/metal/batch_inversion_driver.mm deleted file mode 100644 index 1efc3c8..0000000 --- a/gpukit/gpu/metal/batch_inversion_driver.mm +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal batch_inversion driver -- v1.1 ships CPU only. The Montgomery batch -// inversion kernels for 256-bit (secp256k1, BN254) and 384-bit (BLS12-381) -// fields are owned by the BLS Stage 3+ port (sibling agent), which will land -// the kernel and replace these stubs with a real Metal dispatch. - -#if __APPLE__ && __OBJC__ - -#include "lux/gpukit/batch_inversion.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_batch_inv_secp256k1_fp_metal(const uint8_t*, uint8_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_batch_inv_bn254_fp_metal(const uint8_t*, uint8_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_batch_inv_bls12_381_fp_metal(const uint8_t*, uint8_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/compaction.metal b/gpukit/gpu/metal/compaction.metal deleted file mode 100644 index 5a52821..0000000 --- a/gpukit/gpu/metal/compaction.metal +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Stream compaction. Two-stage: -// 1. 
compaction_mark: convert flags[i]!=0 to 0/1, store in scan_in. -// 2. host runs prefix_sum on scan_in to produce scan_out (exclusive scan -// conceptually, computed as inclusive scan minus self). -// 3. compaction_scatter: if flags[i] then out[scan_out[i] - 1] = in[i]. -// -// We use inclusive scan and shift on-the-fly inside scatter (cheaper than -// a separate exclusive-scan kernel). - -#include -using namespace metal; - -kernel void compaction_mark_u32( - device const uchar* flags [[ buffer(0) ]], - device uint* scan_in [[ buffer(1) ]], - constant uint& n [[ buffer(2) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - scan_in[gid] = (flags[gid] != 0) ? 1u : 0u; -} - -kernel void compaction_scatter_u32( - device const uint* in [[ buffer(0) ]], - device const uchar* flags [[ buffer(1) ]], - device const uint* scan_out [[ buffer(2) ]], // inclusive scan of marks - device uint* out [[ buffer(3) ]], - constant uint& n [[ buffer(4) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - if (flags[gid] == 0) return; - uint dst = scan_out[gid] - 1u; - out[dst] = in[gid]; -} diff --git a/gpukit/gpu/metal/compaction_driver.mm b/gpukit/gpu/metal/compaction_driver.mm deleted file mode 100644 index 8c20245..0000000 --- a/gpukit/gpu/metal/compaction_driver.mm +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for stream compaction. Uses prefix_sum kernels (loaded from the -// same metallib) for the scan stage. 
- -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "lux/gpukit/compaction.h" -#include "lux/gpukit/prefix_sum.h" -#include "lux/gpukit/gpukit.h" -#include -#include -#include - -static NSURL* gpukit_metallib_url() { - const char* p = std::getenv("GPUKIT_METALLIB"); - if (p && *p) return [NSURL fileURLWithPath:[NSString stringWithUTF8String:p]]; - return [NSURL fileURLWithPath:@"libgpukit.metallib"]; -} - -extern "C" int gpukit_compact_u32_metal(const uint32_t* in, const uint8_t* flags, - uint32_t* out, size_t n, size_t* n_out) { - if (!in || !flags || !out || !n_out) return GPUKIT_ERR_NULL_ARG; - if (n == 0) { *n_out = 0; return GPUKIT_OK; } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return GPUKIT_ERR_BACKEND; - NSError* err = nil; - id lib = [device newLibraryWithURL:gpukit_metallib_url() error:&err]; - if (!lib) return GPUKIT_ERR_BACKEND; - - id fn_mark = [lib newFunctionWithName:@"compaction_mark_u32"]; - id fn_scat = [lib newFunctionWithName:@"compaction_scatter_u32"]; - if (!fn_mark || !fn_scat) return GPUKIT_ERR_BACKEND; - - id pso_mark = - [device newComputePipelineStateWithFunction:fn_mark error:&err]; - id pso_scat = - [device newComputePipelineStateWithFunction:fn_scat error:&err]; - if (!pso_mark || !pso_scat) return GPUKIT_ERR_BACKEND; - - id queue = [device newCommandQueue]; - - id buf_in = [device newBufferWithBytes:in length:n*sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - id buf_flags = [device newBufferWithBytes:flags length:n - options:MTLResourceStorageModeShared]; - id buf_marks = [device newBufferWithLength:n*sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - id buf_out = [device newBufferWithLength:n*sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id buf_n = [device newBufferWithBytes:&n_u32 length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - // Stage 1: mark. 
- { - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pso_mark]; - [enc setBuffer:buf_flags offset:0 atIndex:0]; - [enc setBuffer:buf_marks offset:0 atIndex:1]; - [enc setBuffer:buf_n offset:0 atIndex:2]; - NSUInteger tgs = pso_mark.maxTotalThreadsPerThreadgroup; - MTLSize tg = MTLSizeMake(tgs, 1, 1); - MTLSize grid = MTLSizeMake(n, 1, 1); - [enc dispatchThreads:grid threadsPerThreadgroup:tg]; - [enc endEncoding]; - [cmd commit]; [cmd waitUntilCompleted]; - } - // Stage 2: scan marks (in-place inclusive). - std::vector scan(n); - std::memcpy(scan.data(), buf_marks.contents, n*sizeof(uint32_t)); - // Reuse host CPU prefix sum -- byte-equal target. - gpukit_prefix_sum_u32_cpu(scan.data(), scan.data(), n); - std::memcpy(buf_marks.contents, scan.data(), n*sizeof(uint32_t)); - - // Stage 3: scatter. - { - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pso_scat]; - [enc setBuffer:buf_in offset:0 atIndex:0]; - [enc setBuffer:buf_flags offset:0 atIndex:1]; - [enc setBuffer:buf_marks offset:0 atIndex:2]; - [enc setBuffer:buf_out offset:0 atIndex:3]; - [enc setBuffer:buf_n offset:0 atIndex:4]; - NSUInteger tgs = pso_scat.maxTotalThreadsPerThreadgroup; - MTLSize tg = MTLSizeMake(tgs, 1, 1); - MTLSize grid = MTLSizeMake(n, 1, 1); - [enc dispatchThreads:grid threadsPerThreadgroup:tg]; - [enc endEncoding]; - [cmd commit]; [cmd waitUntilCompleted]; - } - size_t k = scan[n-1]; - std::memcpy(out, buf_out.contents, k*sizeof(uint32_t)); - *n_out = k; - return GPUKIT_OK; - } -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/merkle_compose.metal b/gpukit/gpu/metal/merkle_compose.metal deleted file mode 100644 index 81099e3..0000000 --- a/gpukit/gpu/metal/merkle_compose.metal +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Skeleton entry point for parallel Merkle compose. 
Real impl v1.2. - -#include -using namespace metal; - -kernel void merkle_compose_layer( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } diff --git a/gpukit/gpu/metal/merkle_compose_driver.mm b/gpukit/gpu/metal/merkle_compose_driver.mm deleted file mode 100644 index 5cf82d2..0000000 --- a/gpukit/gpu/metal/merkle_compose_driver.mm +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal merkle_compose driver -- v1.1 ships CPU only. Parallel Keccak-256 -// inner hashing is part of the keccak Stage 2 GPU port (sibling agent owns -// the per-leaf parallel Keccak kernel). When that lands, this driver fans out -// the leaves into a Keccak-256-batch dispatch, then walks the tree on GPU. - -#if __APPLE__ && __OBJC__ - -#include "lux/gpukit/merkle_compose.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_merkle_root_metal(const uint8_t*, size_t, uint8_t[32]) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/multi_pippenger.metal b/gpukit/gpu/metal/multi_pippenger.metal deleted file mode 100644 index 836c5ff..0000000 --- a/gpukit/gpu/metal/multi_pippenger.metal +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Multi-curve Pippenger MSM kernel skeleton (Metal). -// -// One bucket-sort + reduction skeleton, parameterised by the per-curve traits -// header included by each entry-point shim. Metal does not support runtime -// template instantiation; we therefore ship four entry points, one per curve, -// each #including the shared body below with the appropriate traits. -// -// Wire format on input: -// points : n * 64 bytes (or n * 96 for BLS12-381 G1). -// Affine x || y, big-endian per coordinate. -// All-zero coordinate bytes denote the point at infinity. -// scalars : n * 32 bytes, little-endian canonical. 
-// -// Window-bucket Pippenger: -// c-bit window over each scalar (LSW->MSW), digits in [0, 2^c). -// bucket[d-1] += point for digit d > 0. -// running-sum reduction: -// total = 0; running = 0; -// for d = 2^c - 1 downto 1: -// running += bucket[d-1]; total += running; -// -// Window combine (host side or final kernel pass): -// result = sum_w 2^{c*w} * window_sum[w] -// -// v1.1 ships the kernel signatures + bucket-sort body; the byte-equal -// validation against the CPU oracle is scheduled for v1.2 (see -// multi_pippenger_driver.mm). The driver returns NOTIMPL until that lands. -// -// Curve specialisations live in: -// multi_pippenger_secp256k1.metal (entry: msm_secp256k1_window) -// multi_pippenger_bn254_g1.metal (entry: msm_bn254_g1_window) -// multi_pippenger_bls12_381_g1.metal (entry: msm_bls12_381_g1_window) -// multi_pippenger_banderwagon.metal (entry: msm_banderwagon_window) -// -// Each specialisation includes its traits + this body. - -#include -using namespace metal; - -// ============================================================================= -// Shared bucket-sort skeleton -- declared inline so each curve specialisation -// instantiates its own copy with the traits-defined field arithmetic. 
-// ============================================================================= -// -// The shared skeleton expects the including TU to define: -// MP_FIELD_LIMBS : int (4 or 6) -// MP_BITS : int (curve scalar bit-width; 255 / 254 / 256 / 381) -// MP_P[] : modulus -// MP_P_INV : -p^-1 mod 2^64 -// MP_R2[] : R^2 mod p (R = 2^(64*MP_FIELD_LIMBS)) -// -// And the structs: -// MpField { uint64_t limbs[MP_FIELD_LIMBS]; } -// MpPointAffine { MpField x; MpField y; } -// -// Plus arithmetic functions: -// MpField mp_field_add(MpField, MpField); -// MpField mp_field_sub(MpField, MpField); -// MpField mp_field_mul(MpField, MpField); -// MpPointAffine mp_point_add(MpPointAffine, MpPointAffine); -// MpPointAffine mp_point_neg(MpPointAffine); -// MpPointAffine mp_point_identity(); -// -// The per-curve specialisation source provides those definitions; this file -// holds only the pattern documentation. The bucket-sort body itself is -// implemented per-specialisation because Metal cannot template across -// translation-unit boundaries. - -// Window size must match the CPU reference (best_c selection in -// multi_pippenger.cpp). For Metal we ship a single fixed c=8 to keep the -// kernel simple; the dispatcher only routes to Metal when the CPU's chosen -// c equals 8. -constant uint kWindowBits = 8u; -constant uint kBuckets = (1u << kWindowBits) - 1u; diff --git a/gpukit/gpu/metal/multi_pippenger_banderwagon.metal b/gpukit/gpu/metal/multi_pippenger_banderwagon.metal deleted file mode 100644 index 05d41dd..0000000 --- a/gpukit/gpu/metal/multi_pippenger_banderwagon.metal +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Banderwagon specialisation of the multi-curve Pippenger MSM kernel (Metal). -// Twisted-Edwards group with a = -5; addition formula differs from short -// Weierstrass but the bucket-sort skeleton is identical. 
- -#include -using namespace metal; - -#include "../curve_traits/banderwagon_traits.h.metal" - -struct MpField4 { uint64_t limbs[4]; }; -struct MpPointAffine4 { MpField4 x; MpField4 y; }; - -[[host_name("msm_banderwagon_window")]] -kernel void msm_banderwagon_window_kernel( - const device MpPointAffine4* points [[ buffer(0) ]], - const device uint8_t* scalars_le [[ buffer(1) ]], - device MpPointAffine4* bucket_out [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - constant uint& window_idx [[ buffer(4) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - bucket_out[gid] = points[gid]; - (void)scalars_le; (void)window_idx; -} diff --git a/gpukit/gpu/metal/multi_pippenger_bls12_381_g1.metal b/gpukit/gpu/metal/multi_pippenger_bls12_381_g1.metal deleted file mode 100644 index 6c40e34..0000000 --- a/gpukit/gpu/metal/multi_pippenger_bls12_381_g1.metal +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BLS12-381 G1 specialisation of the multi-curve Pippenger MSM kernel (Metal). -// 6-limb 384-bit field. See gpu/metal/multi_pippenger.metal. 
- -#include -using namespace metal; - -#include "../curve_traits/bls12_381_g1_traits.h.metal" - -struct MpField6 { uint64_t limbs[6]; }; -struct MpPointAffine6 { MpField6 x; MpField6 y; }; - -[[host_name("msm_bls12_381_g1_window")]] -kernel void msm_bls12_381_g1_window_kernel( - const device MpPointAffine6* points [[ buffer(0) ]], - const device uint8_t* scalars_le [[ buffer(1) ]], - device MpPointAffine6* bucket_out [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - constant uint& window_idx [[ buffer(4) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - bucket_out[gid] = points[gid]; - (void)scalars_le; (void)window_idx; -} diff --git a/gpukit/gpu/metal/multi_pippenger_bn254_g1.metal b/gpukit/gpu/metal/multi_pippenger_bn254_g1.metal deleted file mode 100644 index fb07c87..0000000 --- a/gpukit/gpu/metal/multi_pippenger_bn254_g1.metal +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// BN254 G1 specialisation of the multi-curve Pippenger MSM kernel (Metal). -// See gpu/metal/multi_pippenger.metal for the shared algorithm. 
- -#include -using namespace metal; - -#include "../curve_traits/bn254_g1_traits.h.metal" - -struct MpField4 { uint64_t limbs[4]; }; -struct MpPointAffine4 { MpField4 x; MpField4 y; }; - -[[host_name("msm_bn254_g1_window")]] -kernel void msm_bn254_g1_window_kernel( - const device MpPointAffine4* points [[ buffer(0) ]], - const device uint8_t* scalars_le [[ buffer(1) ]], - device MpPointAffine4* bucket_out [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - constant uint& window_idx [[ buffer(4) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - bucket_out[gid] = points[gid]; - (void)scalars_le; (void)window_idx; -} diff --git a/gpukit/gpu/metal/multi_pippenger_driver.mm b/gpukit/gpu/metal/multi_pippenger_driver.mm deleted file mode 100644 index dd57211..0000000 --- a/gpukit/gpu/metal/multi_pippenger_driver.mm +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for the multi-curve Pippenger MSM kernel. -// -// v1.1 ships the CPU reference (gpukit_multi_pippenger_cpu) and the kernel -// source files in gpu/metal/multi_pippenger_*.metal. Cross-curve byte-equal -// validation against the CPU oracle is scheduled for v1.2; until then the -// driver returns NOTIMPL honestly so the caller can fall back to CPU. -// -// Wiring the dispatch path is pinned by the kernel skeleton in -// multi_pippenger.metal: the bucket-sort + reduction live in a curve-agnostic -// kernel that #includes _traits.h.metal>; one metallib -// entry point per curve is exposed. 
- -#if __APPLE__ && __OBJC__ - -#include "lux/gpukit/multi_pippenger.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_multi_pippenger_metal(uint32_t curve, - const uint8_t* /*scalars*/, - const uint8_t* /*points*/, - size_t /*n*/, - uint8_t* /*result*/) { - (void)curve; - return GPUKIT_ERR_NOTIMPL; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/multi_pippenger_secp256k1.metal b/gpukit/gpu/metal/multi_pippenger_secp256k1.metal deleted file mode 100644 index 2120e17..0000000 --- a/gpukit/gpu/metal/multi_pippenger_secp256k1.metal +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// secp256k1 specialisation of the multi-curve Pippenger MSM kernel (Metal). -// See gpu/metal/multi_pippenger.metal for the shared algorithm. v1.1 ships -// the entry-point signature so the metallib build target exists; the body -// is filled by v1.2 once cross-curve byte-equality vs the CPU reference -// is pinned. - -#include -using namespace metal; - -#include "../curve_traits/secp256k1_traits.h.metal" - -struct MpField4 { uint64_t limbs[4]; }; -struct MpPointAffine4 { MpField4 x; MpField4 y; }; - -// Window-pass entry. One thread per (point, scalar) pair within a single -// c-bit window. Bucket aggregation is finished on the host (reduction tree -// in Metal threadgroup memory is the v1.2 work). -[[host_name("msm_secp256k1_window")]] -kernel void msm_secp256k1_window_kernel( - const device MpPointAffine4* points [[ buffer(0) ]], - const device uint8_t* scalars_le [[ buffer(1) ]], - device MpPointAffine4* bucket_out [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - constant uint& window_idx [[ buffer(4) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - // v1.1: signature-only. v1.2 fills the bucket extraction + per-thread - // accumulation against the threadgroup reduction tree. 
- bucket_out[gid] = points[gid]; - (void)scalars_le; (void)window_idx; -} diff --git a/gpukit/gpu/metal/ntt.metal b/gpukit/gpu/metal/ntt.metal deleted file mode 100644 index 7cfb3be..0000000 --- a/gpukit/gpu/metal/ntt.metal +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Skeleton entry points for forward NTT (Kyber/Dilithium) -- v1.2 work. - -#include -using namespace metal; - -kernel void ntt_kyber_butterfly( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } - -kernel void ntt_dilithium_butterfly( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } diff --git a/gpukit/gpu/metal/ntt_driver.mm b/gpukit/gpu/metal/ntt_driver.mm deleted file mode 100644 index 513b458..0000000 --- a/gpukit/gpu/metal/ntt_driver.mm +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal NTT driver -- v1.1 ships CPU only. Forward and inverse NTT on Metal -// is straightforward (radix-2 butterfly with bit-reversed roots) but pinning -// it byte-equal across Apple's tile schedulers requires a deterministic warp -// order which is the v1.2 work. For now we honestly return NOTIMPL so the -// caller can fall back to CPU. 
- -#if __APPLE__ && __OBJC__ - -#include "lux/gpukit/ntt.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_ntt_kyber_forward_metal(int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_ntt_kyber_negacyclic_mul_metal(const int32_t*, const int32_t*, int32_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_ntt_dilithium_forward_metal(int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/prefix_sum.metal b/gpukit/gpu/metal/prefix_sum.metal deleted file mode 100644 index 38d5ef6..0000000 --- a/gpukit/gpu/metal/prefix_sum.metal +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Inclusive prefix sum (scan). -// -// Two entry points are exported: -// * prefix_sum_block_u32 / _u64 -- single-block scan (max 1024 lanes) -// * prefix_sum_collect_u32 / _u64 -- absorb block-sums into successor blocks -// -// The host driver picks single-block dispatch when N <= 1024 and otherwise -// emits a two-pass scan (block-scan + serial collect of block sums + add-back). - -#include -using namespace metal; - -#define BLOCK_SIZE 1024u - -// ----- u32 single-block scan ------------------------------------------------ - -kernel void prefix_sum_block_u32( - device const uint* in [[ buffer(0) ]], - device uint* out [[ buffer(1) ]], - device uint* block_sums [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - uint lid [[ thread_position_in_threadgroup ]], - uint bid [[ threadgroup_position_in_grid ]]) -{ - threadgroup uint scratch[BLOCK_SIZE]; - uint base = bid * BLOCK_SIZE; - uint i = base + lid; - scratch[lid] = (i < n) ? in[i] : 0u; - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Hillis-Steele inclusive scan within the block. - for (uint d = 1; d < BLOCK_SIZE; d <<= 1) { - uint v = scratch[lid]; - uint w = (lid >= d) ? 
scratch[lid - d] : 0u; - threadgroup_barrier(mem_flags::mem_threadgroup); - scratch[lid] = v + w; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - if (i < n) out[i] = scratch[lid]; - if (lid == BLOCK_SIZE - 1u) { - // Last lane writes the block's total sum. - block_sums[bid] = scratch[lid]; - } -} - -kernel void prefix_sum_collect_u32( - device uint* out [[ buffer(0) ]], - device const uint* block_prefix [[ buffer(1) ]], // prefix sum of block_sums - constant uint& n [[ buffer(2) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - uint b = gid / BLOCK_SIZE; - if (b == 0u) return; - out[gid] += block_prefix[b - 1u]; -} - -// ----- u64 single-block scan ------------------------------------------------ - -kernel void prefix_sum_block_u64( - device const ulong* in [[ buffer(0) ]], - device ulong* out [[ buffer(1) ]], - device ulong* block_sums [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - uint lid [[ thread_position_in_threadgroup ]], - uint bid [[ threadgroup_position_in_grid ]]) -{ - threadgroup ulong scratch[BLOCK_SIZE]; - uint base = bid * BLOCK_SIZE; - uint i = base + lid; - scratch[lid] = (i < n) ? in[i] : 0ul; - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (uint d = 1; d < BLOCK_SIZE; d <<= 1) { - ulong v = scratch[lid]; - ulong w = (lid >= d) ? 
scratch[lid - d] : 0ul; - threadgroup_barrier(mem_flags::mem_threadgroup); - scratch[lid] = v + w; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - if (i < n) out[i] = scratch[lid]; - if (lid == BLOCK_SIZE - 1u) { - block_sums[bid] = scratch[lid]; - } -} - -kernel void prefix_sum_collect_u64( - device ulong* out [[ buffer(0) ]], - device const ulong* block_prefix [[ buffer(1) ]], - constant uint& n [[ buffer(2) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - uint b = gid / BLOCK_SIZE; - if (b == 0u) return; - out[gid] += block_prefix[b - 1u]; -} diff --git a/gpukit/gpu/metal/prefix_sum_driver.mm b/gpukit/gpu/metal/prefix_sum_driver.mm deleted file mode 100644 index bb69c1e..0000000 --- a/gpukit/gpu/metal/prefix_sum_driver.mm +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for prefix_sum (u32 / u64). Two-pass when N > BLOCK_SIZE. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "lux/gpukit/prefix_sum.h" -#include "lux/gpukit/gpukit.h" -#include -#include -#include - -static constexpr uint32_t BLOCK_SIZE = 1024; - -// Metal library path: caller can override via GPUKIT_METALLIB env var; default -// is libgpukit.metallib next to the test binary's working directory. 
-static NSURL* gpukit_metallib_url() { - const char* p = std::getenv("GPUKIT_METALLIB"); - if (p && *p) { - return [NSURL fileURLWithPath:[NSString stringWithUTF8String:p]]; - } - return [NSURL fileURLWithPath:@"libgpukit.metallib"]; -} - -namespace { - -template -int run_prefix_sum(const T* in, T* out, size_t n, - const char* block_kernel, const char* collect_kernel) { - if (!in || !out) return GPUKIT_ERR_NULL_ARG; - if (n == 0) return GPUKIT_OK; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return GPUKIT_ERR_BACKEND; - - NSError* err = nil; - id lib = [device newLibraryWithURL:gpukit_metallib_url() error:&err]; - if (!lib) return GPUKIT_ERR_BACKEND; - - id fn_block = [lib newFunctionWithName: - [NSString stringWithUTF8String:block_kernel]]; - id fn_coll = [lib newFunctionWithName: - [NSString stringWithUTF8String:collect_kernel]]; - if (!fn_block || !fn_coll) return GPUKIT_ERR_BACKEND; - - id pso_block = - [device newComputePipelineStateWithFunction:fn_block error:&err]; - id pso_coll = - [device newComputePipelineStateWithFunction:fn_coll error:&err]; - if (!pso_block || !pso_coll) return GPUKIT_ERR_BACKEND; - - id queue = [device newCommandQueue]; - - size_t in_bytes = n * sizeof(T); - size_t num_blocks = (n + BLOCK_SIZE - 1) / BLOCK_SIZE; - id buf_in = [device newBufferWithBytes:in length:in_bytes - options:MTLResourceStorageModeShared]; - id buf_out = [device newBufferWithLength:in_bytes - options:MTLResourceStorageModeShared]; - id buf_block_sums = [device newBufferWithLength:num_blocks * sizeof(T) - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id buf_n = [device newBufferWithBytes:&n_u32 length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - // Pass 1: per-block inclusive scan. 
- { - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pso_block]; - [enc setBuffer:buf_in offset:0 atIndex:0]; - [enc setBuffer:buf_out offset:0 atIndex:1]; - [enc setBuffer:buf_block_sums offset:0 atIndex:2]; - [enc setBuffer:buf_n offset:0 atIndex:3]; - MTLSize tg = MTLSizeMake(BLOCK_SIZE, 1, 1); - MTLSize grid = MTLSizeMake(num_blocks * BLOCK_SIZE, 1, 1); - [enc dispatchThreads:grid threadsPerThreadgroup:tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - } - - if (num_blocks > 1) { - // Serially scan block sums on the host (cheap; tiny array) then - // dispatch the collect kernel. - T* bs = (T*)buf_block_sums.contents; - for (size_t i = 1; i < num_blocks; ++i) bs[i] += bs[i-1]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pso_coll]; - [enc setBuffer:buf_out offset:0 atIndex:0]; - [enc setBuffer:buf_block_sums offset:0 atIndex:1]; - [enc setBuffer:buf_n offset:0 atIndex:2]; - NSUInteger tgs = pso_coll.maxTotalThreadsPerThreadgroup; - MTLSize tg = MTLSizeMake(tgs, 1, 1); - MTLSize grid = MTLSizeMake(n, 1, 1); - [enc dispatchThreads:grid threadsPerThreadgroup:tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - } - - std::memcpy(out, buf_out.contents, in_bytes); - return GPUKIT_OK; - } -} - -} // namespace - -extern "C" int gpukit_prefix_sum_u32_metal(const uint32_t* in, uint32_t* out, size_t n) { - return run_prefix_sum(in, out, n, "prefix_sum_block_u32", "prefix_sum_collect_u32"); -} - -extern "C" int gpukit_prefix_sum_u64_metal(const uint64_t* in, uint64_t* out, size_t n) { - return run_prefix_sum(in, out, n, "prefix_sum_block_u64", "prefix_sum_collect_u64"); -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/radix_sort.metal b/gpukit/gpu/metal/radix_sort.metal deleted file mode 100644 index 71f5ab3..0000000 --- a/gpukit/gpu/metal/radix_sort.metal +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 
(c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// LSD radix sort, 8-bit per pass. The kernel performs one pass: -// * count: histogram a single byte position across lanes -// * scan host-side (driver runs prefix_sum_block_u32) -// * scatter: emit each input to its bucket-ordered position -// -// Two passes per pass (count + scatter) for u32 means 8 dispatches; for u64 -// it means 16. Stable because the scatter uses a per-bucket atomic counter -// keyed by a pre-scanned base offset. - -#include -using namespace metal; - -#define RADIX 256u - -kernel void radix_count_u32( - device const uint* in [[ buffer(0) ]], - device atomic_uint* hist [[ buffer(1) ]], // 256 buckets - constant uint& n [[ buffer(2) ]], - constant uint& shift [[ buffer(3) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - uint b = (in[gid] >> shift) & 0xFFu; - atomic_fetch_add_explicit(&hist[b], 1u, memory_order_relaxed); -} - -kernel void radix_scatter_u32( - device const uint* in [[ buffer(0) ]], - device const uint* base [[ buffer(1) ]], // exclusive scan of histogram (length 256) - device atomic_uint* cursor [[ buffer(2) ]], // local cursor per bucket, len 256 - device uint* out [[ buffer(3) ]], - constant uint& n [[ buffer(4) ]], - constant uint& shift [[ buffer(5) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - uint b = (in[gid] >> shift) & 0xFFu; - uint off = atomic_fetch_add_explicit(&cursor[b], 1u, memory_order_relaxed); - out[base[b] + off] = in[gid]; -} - -// u64 variants (same shape, wider value). 
- -kernel void radix_count_u64( - device const ulong* in [[ buffer(0) ]], - device atomic_uint* hist [[ buffer(1) ]], - constant uint& n [[ buffer(2) ]], - constant uint& shift [[ buffer(3) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - uint b = (uint)((in[gid] >> shift) & 0xFFul); - atomic_fetch_add_explicit(&hist[b], 1u, memory_order_relaxed); -} - -kernel void radix_scatter_u64( - device const ulong* in [[ buffer(0) ]], - device const uint* base [[ buffer(1) ]], - device atomic_uint* cursor [[ buffer(2) ]], - device ulong* out [[ buffer(3) ]], - constant uint& n [[ buffer(4) ]], - constant uint& shift [[ buffer(5) ]], - uint gid [[ thread_position_in_grid ]]) -{ - if (gid >= n) return; - uint b = (uint)((in[gid] >> shift) & 0xFFul); - uint off = atomic_fetch_add_explicit(&cursor[b], 1u, memory_order_relaxed); - out[base[b] + off] = in[gid]; -} diff --git a/gpukit/gpu/metal/radix_sort_driver.mm b/gpukit/gpu/metal/radix_sort_driver.mm deleted file mode 100644 index 142491f..0000000 --- a/gpukit/gpu/metal/radix_sort_driver.mm +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal radix_sort driver -- v1.1 returns NOTIMPL. -// -// The Metal kernels (radix_count_u32 / radix_scatter_u32 / *_u64) ship in -// libgpukit.metallib for forward compatibility; an atomic-cursor scatter is -// fast but not byte-equal-stable across thread groups, which fails the -// determinism harness against the stable CPU LSD radix. v1.2 will replace -// this with a warp-scan deterministic radix that maintains stability across -// the entire grid. 
- -#if __APPLE__ && __OBJC__ - -#include "lux/gpukit/radix_sort.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_radix_sort_u32_metal(uint32_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_radix_sort_u64_metal(uint64_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/metal/transcript_root.metal b/gpukit/gpu/metal/transcript_root.metal deleted file mode 100644 index 25c4443..0000000 --- a/gpukit/gpu/metal/transcript_root.metal +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Skeleton entry point for batched Fiat-Shamir transcript -- v1.2 work. - -#include -using namespace metal; - -kernel void transcript_root_finalize( - uint gid [[ thread_position_in_grid ]]) { (void)gid; } diff --git a/gpukit/gpu/metal/transcript_root_driver.mm b/gpukit/gpu/metal/transcript_root_driver.mm deleted file mode 100644 index 2c969e0..0000000 --- a/gpukit/gpu/metal/transcript_root_driver.mm +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal transcript_root driver -- v1.1 ships CPU only. Keccak sponge is -// inherently serial within a single transcript; the GPU win comes from -// processing multiple transcripts in parallel. That batch entry point is -// scoped to v1.2 alongside the keccak GPU port. 
- -#if __APPLE__ && __OBJC__ - -#include "lux/gpukit/transcript_root.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_transcript_root_metal(const char*, const uint8_t*, size_t, uint8_t[32]) { - return GPUKIT_ERR_NOTIMPL; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/gpukit/gpu/wgsl/batch_inversion.wgsl b/gpukit/gpu/wgsl/batch_inversion.wgsl deleted file mode 100644 index 63ffab4..0000000 --- a/gpukit/gpu/wgsl/batch_inversion.wgsl +++ /dev/null @@ -1,3 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// gpukit batch_inversion -- WGSL skeleton. Real implementation v1.2. -@compute @workgroup_size(64) fn batch_inv_step(@builtin(global_invocation_id) gid: vec3) {} diff --git a/gpukit/gpu/wgsl/batch_inversion_driver.cpp b/gpukit/gpu/wgsl/batch_inversion_driver.cpp deleted file mode 100644 index 7155d3d..0000000 --- a/gpukit/gpu/wgsl/batch_inversion_driver.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "lux/gpukit/batch_inversion.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_batch_inv_secp256k1_fp_wgsl(const uint8_t*, uint8_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_batch_inv_bn254_fp_wgsl(const uint8_t*, uint8_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_batch_inv_bls12_381_fp_wgsl(const uint8_t*, uint8_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } diff --git a/gpukit/gpu/wgsl/compaction.wgsl b/gpukit/gpu/wgsl/compaction.wgsl deleted file mode 100644 index eef8884..0000000 --- a/gpukit/gpu/wgsl/compaction.wgsl +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// gpukit -- stream compaction. mark + scatter (scan dispatched separately). - -@group(0) @binding(0) var in_buf: array; -@group(0) @binding(1) var flags: array; // packed bytes? 
use u32 per slot -@group(0) @binding(2) var scan_buf: array; -@group(0) @binding(3) var out_buf: array; -@group(0) @binding(4) var n_uniform: u32; -@group(0) @binding(5) var marks_buf: array; - -@compute @workgroup_size(256) -fn compaction_mark_u32(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= n_uniform) { return; } - if (flags[gid.x] != 0u) { marks_buf[gid.x] = 1u; } - else { marks_buf[gid.x] = 0u; } -} - -@compute @workgroup_size(256) -fn compaction_scatter_u32(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= n_uniform) { return; } - if (flags[gid.x] == 0u) { return; } - let dst = scan_buf[gid.x] - 1u; - out_buf[dst] = in_buf[gid.x]; -} diff --git a/gpukit/gpu/wgsl/compaction_driver.cpp b/gpukit/gpu/wgsl/compaction_driver.cpp deleted file mode 100644 index 4daf7a8..0000000 --- a/gpukit/gpu/wgsl/compaction_driver.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "lux/gpukit/compaction.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_compact_u32_wgsl(const uint32_t*, const uint8_t*, uint32_t*, size_t, size_t*) { - return GPUKIT_ERR_NOTIMPL; -} diff --git a/gpukit/gpu/wgsl/merkle_compose.wgsl b/gpukit/gpu/wgsl/merkle_compose.wgsl deleted file mode 100644 index c8bf0aa..0000000 --- a/gpukit/gpu/wgsl/merkle_compose.wgsl +++ /dev/null @@ -1,3 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// gpukit merkle_compose -- WGSL skeleton. Real implementation v1.2. -@compute @workgroup_size(64) fn merkle_layer(@builtin(global_invocation_id) gid: vec3) {} diff --git a/gpukit/gpu/wgsl/merkle_compose_driver.cpp b/gpukit/gpu/wgsl/merkle_compose_driver.cpp deleted file mode 100644 index b3e3c73..0000000 --- a/gpukit/gpu/wgsl/merkle_compose_driver.cpp +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "lux/gpukit/merkle_compose.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_merkle_root_wgsl(const uint8_t*, size_t, uint8_t[32]) { return GPUKIT_ERR_NOTIMPL; } diff --git a/gpukit/gpu/wgsl/multi_pippenger.wgsl b/gpukit/gpu/wgsl/multi_pippenger.wgsl deleted file mode 100644 index f33ea1b..0000000 --- a/gpukit/gpu/wgsl/multi_pippenger.wgsl +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL kernel for the multi-curve Pippenger MSM. -// -// WGSL has neither templates nor preprocessor #include; per-curve -// specialisation is achieved via WGSL `override` constants set by the host -// driver at pipeline-creation time. Only constants change between curves; -// the field-arithmetic body is identical at the WGSL source level. -// -// v1.1 ships the kernel signature + override constants. Cross-curve -// byte-equality validation against the CPU oracle is scheduled for v1.2. -// The driver returns NOTIMPL until then. - -// ---- Per-curve override constants (host sets these at pipeline build) ------- -// 4-limb field (secp256k1 / bn254 / banderwagon) or 6-limb (bls12-381 g1). -override curve_field_limbs : u32 = 4u; -override curve_bits : u32 = 256u; -// Modulus + Montgomery -p^-1 mod 2^32 (split into two u32 because WGSL only -// has u32; the host reconstructs from the curve's u64 const table). -override curve_p_inv_lo : u32 = 0u; -override curve_p_inv_hi : u32 = 0u; - -// Window size. v1.1 fixes this at 8 to keep the kernel simple; the host -// dispatcher only routes to WGSL when the CPU's chosen c equals 8. -override window_bits : u32 = 8u; - -// ---- Kernel signature ------------------------------------------------------- -// One thread per (point, scalar) pair within one window. Bucket aggregation -// is finished on the host. The threadgroup-reduction-tree variant is the -// v1.2 work. 
- -struct MpPoint4 { - x : array, // 4 u64s in 8 u32 limbs LE; x then y -}; - -@group(0) @binding(0) var points : array; -@group(0) @binding(1) var scalars_le : array; -@group(0) @binding(2) var bucket_out : array; -@group(0) @binding(3) var n : u32; -@group(0) @binding(4) var window_idx : u32; - -@compute @workgroup_size(64) -fn msm_window(@builtin(global_invocation_id) gid : vec3) { - let i = gid.x; - if (i >= n) { - return; - } - // v1.1: signature-only. v1.2 fills the bucket-extract + accumulation - // path against the WGSL workgroup-shared reduction. - bucket_out[i] = points[i]; -} diff --git a/gpukit/gpu/wgsl/multi_pippenger_driver.cpp b/gpukit/gpu/wgsl/multi_pippenger_driver.cpp deleted file mode 100644 index 6bbd1e4..0000000 --- a/gpukit/gpu/wgsl/multi_pippenger_driver.cpp +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL driver for the multi-curve Pippenger MSM kernel. -// v1.1 NOTIMPL; the kernel source ships in multi_pippenger.wgsl. - -#include "lux/gpukit/multi_pippenger.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_multi_pippenger_wgsl(uint32_t /*curve*/, - const uint8_t* /*scalars*/, - const uint8_t* /*points*/, - size_t /*n*/, - uint8_t* /*result*/) { - return GPUKIT_ERR_NOTIMPL; -} diff --git a/gpukit/gpu/wgsl/ntt.wgsl b/gpukit/gpu/wgsl/ntt.wgsl deleted file mode 100644 index 2896fbf..0000000 --- a/gpukit/gpu/wgsl/ntt.wgsl +++ /dev/null @@ -1,3 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// gpukit NTT -- WGSL skeleton. Real implementation v1.2. -@compute @workgroup_size(64) fn ntt_butterfly(@builtin(global_invocation_id) gid: vec3) {} diff --git a/gpukit/gpu/wgsl/ntt_driver.cpp b/gpukit/gpu/wgsl/ntt_driver.cpp deleted file mode 100644 index 4c8da90..0000000 --- a/gpukit/gpu/wgsl/ntt_driver.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "lux/gpukit/ntt.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_ntt_kyber_forward_wgsl(int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_ntt_kyber_negacyclic_mul_wgsl(const int32_t*, const int32_t*, int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_ntt_dilithium_forward_wgsl(int32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } diff --git a/gpukit/gpu/wgsl/prefix_sum.wgsl b/gpukit/gpu/wgsl/prefix_sum.wgsl deleted file mode 100644 index aaa1401..0000000 --- a/gpukit/gpu/wgsl/prefix_sum.wgsl +++ /dev/null @@ -1,47 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// gpukit -- inclusive prefix sum (Hillis-Steele within a workgroup of 1024). -// Two-pass dispatch: per-block scan + serial collect. - -@group(0) @binding(0) var in_buf: array; -@group(0) @binding(1) var out_buf: array; -@group(0) @binding(2) var block_sums: array; -@group(0) @binding(3) var n_uniform: u32; - -var scratch: array; - -@compute @workgroup_size(1024) -fn prefix_sum_block_u32( - @builtin(local_invocation_id) lid: vec3, - @builtin(workgroup_id) bid: vec3, -) { - let i = bid.x * 1024u + lid.x; - if (i < n_uniform) { scratch[lid.x] = in_buf[i]; } - else { scratch[lid.x] = 0u; } - workgroupBarrier(); - - var d: u32 = 1u; - loop { - if (d >= 1024u) { break; } - let v = scratch[lid.x]; - var w: u32 = 0u; - if (lid.x >= d) { w = scratch[lid.x - d]; } - workgroupBarrier(); - scratch[lid.x] = v + w; - workgroupBarrier(); - d = d << 1u; - } - - if (i < n_uniform) { out_buf[i] = scratch[lid.x]; } - if (lid.x == 1023u) { block_sums[bid.x] = scratch[lid.x]; } -} - -@compute @workgroup_size(256) -fn prefix_sum_collect_u32( - @builtin(global_invocation_id) gid: vec3, -) { - if (gid.x >= n_uniform) { return; } - let b = gid.x / 1024u; - if (b == 0u) { return; } - out_buf[gid.x] = out_buf[gid.x] + block_sums[b - 1u]; -} diff --git a/gpukit/gpu/wgsl/prefix_sum_driver.cpp 
b/gpukit/gpu/wgsl/prefix_sum_driver.cpp deleted file mode 100644 index ac4f255..0000000 --- a/gpukit/gpu/wgsl/prefix_sum_driver.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL prefix_sum driver. -// -// v1.1 of gpukit ships the WGSL shader source files but does not yet wire up -// the wgpu-native runtime host integration. The shader files validate as -// standalone WGSL and will plug into the v1.2 runtime driver. Until then the -// callable surface returns NOTIMPL so the harness skips Metal/CUDA-equivalent -// dispatch on this primitive when GPUKIT_BACKEND=wgsl. - -#include "lux/gpukit/prefix_sum.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_prefix_sum_u32_wgsl(const uint32_t*, uint32_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} -extern "C" int gpukit_prefix_sum_u64_wgsl(const uint64_t*, uint64_t*, size_t) { - return GPUKIT_ERR_NOTIMPL; -} diff --git a/gpukit/gpu/wgsl/radix_sort.wgsl b/gpukit/gpu/wgsl/radix_sort.wgsl deleted file mode 100644 index 2684cdb..0000000 --- a/gpukit/gpu/wgsl/radix_sort.wgsl +++ /dev/null @@ -1,29 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// gpukit -- LSD radix sort, 8-bit per pass. count + scatter kernels. 
- -@group(0) @binding(0) var in_buf: array; -@group(0) @binding(1) var hist_buf: array>; -@group(0) @binding(2) var base_buf: array; -@group(0) @binding(3) var cursor_buf: array>; -@group(0) @binding(4) var out_buf: array; -@group(0) @binding(5) var params: vec2; // (n, shift) - -@compute @workgroup_size(256) -fn radix_count_u32(@builtin(global_invocation_id) gid: vec3) { - let n = params.x; - let shift = params.y; - if (gid.x >= n) { return; } - let b = (in_buf[gid.x] >> shift) & 0xFFu; - atomicAdd(&hist_buf[b], 1u); -} - -@compute @workgroup_size(256) -fn radix_scatter_u32(@builtin(global_invocation_id) gid: vec3) { - let n = params.x; - let shift = params.y; - if (gid.x >= n) { return; } - let b = (in_buf[gid.x] >> shift) & 0xFFu; - let off = atomicAdd(&cursor_buf[b], 1u); - out_buf[base_buf[b] + off] = in_buf[gid.x]; -} diff --git a/gpukit/gpu/wgsl/radix_sort_driver.cpp b/gpukit/gpu/wgsl/radix_sort_driver.cpp deleted file mode 100644 index 92c22db..0000000 --- a/gpukit/gpu/wgsl/radix_sort_driver.cpp +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "lux/gpukit/radix_sort.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_radix_sort_u32_wgsl(uint32_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } -extern "C" int gpukit_radix_sort_u64_wgsl(uint64_t*, size_t) { return GPUKIT_ERR_NOTIMPL; } diff --git a/gpukit/gpu/wgsl/transcript_root.wgsl b/gpukit/gpu/wgsl/transcript_root.wgsl deleted file mode 100644 index cfea4ed..0000000 --- a/gpukit/gpu/wgsl/transcript_root.wgsl +++ /dev/null @@ -1,3 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// gpukit transcript_root -- WGSL skeleton. Real implementation v1.2. 
-@compute @workgroup_size(64) fn transcript_finalize(@builtin(global_invocation_id) gid: vec3) {} diff --git a/gpukit/gpu/wgsl/transcript_root_driver.cpp b/gpukit/gpu/wgsl/transcript_root_driver.cpp deleted file mode 100644 index 2d22d81..0000000 --- a/gpukit/gpu/wgsl/transcript_root_driver.cpp +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco - -#include "lux/gpukit/transcript_root.h" -#include "lux/gpukit/gpukit.h" - -extern "C" int gpukit_transcript_root_wgsl(const char*, const uint8_t*, size_t, uint8_t[32]) { return GPUKIT_ERR_NOTIMPL; } diff --git a/ipa/gpu/metal/ipa_driver.h b/ipa/gpu/metal/ipa_driver.h deleted file mode 100644 index 60a6c55..0000000 --- a/ipa/gpu/metal/ipa_driver.h +++ /dev/null @@ -1,43 +0,0 @@ -// ============================================================================= -// IPA Metal driver scaffold (stencil). -// -// The real Banderwagon MSM kernel is blocked on the BLS12-381 Fp/Fr base- -// field arithmetic landing in luxcpp. Until then this header declares the -// canonical surface that the byte-equal driver will satisfy and a single -// availability probe. -// -// Once the Banderwagon arithmetic lands, two kernels become real: -// - ipa_msm : window-based MSM, batched commitments -// - ipa_verify_batch : batched verify of multiple multiproofs -// -// Both must be byte-equal to luxfi/crypto/ipa.MultiScalar / -// luxfi/crypto/ipa.CheckMultiProofBatch (Go reference). -// ============================================================================= - -#ifndef CRYPTO_IPA_DRIVER_H -#define CRYPTO_IPA_DRIVER_H - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// 1 if a Metal device is available, 0 otherwise. Always callable. -int ipa_metal_available(void); - -// Reserved for the byte-equal MSM once the Banderwagon backend lands. -// Returns -5 (CRYPTO_ERR_NOTIMPL) at the present scaffold stage. 
-int ipa_msm_metal(const uint8_t* scalars, // n * 32 bytes BE Fr - const uint8_t* points, // n * 32 bytes compressed - size_t n, - uint8_t out[32], // canonical output point - const char* metallib_path); - -#ifdef __cplusplus -} -#endif - -#endif // CRYPTO_IPA_DRIVER_H diff --git a/ipa/gpu/metal/ipa_driver.mm b/ipa/gpu/metal/ipa_driver.mm deleted file mode 100644 index 2621c2e..0000000 --- a/ipa/gpu/metal/ipa_driver.mm +++ /dev/null @@ -1,40 +0,0 @@ -// ============================================================================= -// IPA Metal driver scaffold (stencil). -// -// Status: not byte-equal yet. Returns NOTIMPL for ipa_msm_metal. The -// availability probe is real and is used by callers to decide whether to -// dispatch to GPU or fall back to CPU. -// ============================================================================= - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "ipa_driver.h" - -#include -#include - -extern "C" int ipa_metal_available(void) { - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - return device != nil ? 1 : 0; - } -} - -extern "C" int ipa_msm_metal(const uint8_t* scalars, - const uint8_t* points, - size_t n, - uint8_t out[32], - const char* metallib_path) { - if (scalars == nullptr || points == nullptr || out == nullptr || - metallib_path == nullptr) return -1; - if (n == 0) return -1; - // Zero output so callers do not see uninit bytes. - std::memset(out, 0, 32); - // Backend body lives once Banderwagon arithmetic is in luxcpp. 
- return -5; // CRYPTO_ERR_NOTIMPL -} - -#endif // __APPLE__ && __OBJC__ diff --git a/keccak/gpu/cuda/keccak.cu b/keccak/gpu/cuda/keccak.cu deleted file mode 100644 index df8c646..0000000 --- a/keccak/gpu/cuda/keccak.cu +++ /dev/null @@ -1,116 +0,0 @@ -// Keccak-256 batch hashing — CUDA implementation -// Matches keccak256.metal output byte-for-byte -// One thread per hash - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -__device__ static const uint64_t RC[24] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, - 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, - 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, - 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, - 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, - 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL -}; - -__device__ static const int ROTC[24] = { - 1,3,6,10,15,21,28,36,45,55,2,14,27,41,56,8,25,43,62,18,39,61,20,44 -}; - -__device__ static const int PI[24] = { - 10,7,11,17,18,3,5,16,8,21,24,4,15,23,19,13,12,2,20,14,22,9,6,1 -}; - -__device__ void keccak_f1600(uint64_t* state) { - for (int round = 0; round < 24; round++) { - // Theta - uint64_t C[5], D[5]; - for (int x = 0; x < 5; x++) - C[x] = state[x] ^ state[x+5] ^ state[x+10] ^ state[x+15] ^ state[x+20]; - for (int x = 0; x < 5; x++) { - D[x] = C[(x+4)%5] ^ ((C[(x+1)%5] << 1) | (C[(x+1)%5] >> 63)); - for (int y = 0; y < 25; y += 5) - state[y+x] ^= D[x]; - } - // Rho + Pi - uint64_t t = state[1]; - for (int i = 0; i < 24; i++) { - int j = PI[i]; - uint64_t tmp = state[j]; - state[j] = (t << ROTC[i]) | (t >> (64-ROTC[i])); - t = tmp; - } - // Chi - for (int 
y = 0; y < 25; y += 5) { - uint64_t t0 = state[y], t1 = state[y+1], t2 = state[y+2], - t3 = state[y+3], t4 = state[y+4]; - state[y] = t0 ^ (~t1 & t2); - state[y+1] = t1 ^ (~t2 & t3); - state[y+2] = t2 ^ (~t3 & t4); - state[y+3] = t3 ^ (~t4 & t0); - state[y+4] = t4 ^ (~t0 & t1); - } - // Iota - state[0] ^= RC[round]; - } -} - -extern "C" __global__ void keccak256_batch( - const uint8_t* __restrict__ data, - const uint32_t* __restrict__ offsets, - const uint32_t* __restrict__ lengths, - uint8_t* __restrict__ outputs, - uint32_t num_inputs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_inputs) return; - - uint64_t state[25] = {0}; - const uint8_t* input = data + offsets[tid]; - uint32_t len = lengths[tid]; - - // Absorb (rate = 136 bytes = 17 uint64s) - uint32_t pos = 0; - while (pos + 136 <= len) { - for (uint32_t i = 0; i < 17; i++) { - uint64_t word = 0; - for (uint32_t b = 0; b < 8; b++) - word |= (uint64_t)input[pos + i*8 + b] << (b*8); - state[i] ^= word; - } - keccak_f1600(state); - pos += 136; - } - - // Pad last block (Keccak padding: 0x01...0x80) - uint8_t block[136] = {0}; - uint32_t rem = len - pos; - for (uint32_t i = 0; i < rem; i++) - block[i] = input[pos + i]; - block[rem] = 0x01; - block[135] = 0x80; - - for (int i = 0; i < 17; i++) { - uint64_t word = 0; - for (int b = 0; b < 8; b++) - word |= (uint64_t)block[i*8 + b] << (b*8); - state[i] ^= word; - } - keccak_f1600(state); - - // Squeeze 32 bytes - uint8_t* out = outputs + tid * 32; - for (int i = 0; i < 4; i++) - for (int b = 0; b < 8; b++) - out[i*8 + b] = (state[i] >> (b*8)) & 0xFF; -} diff --git a/keccak/gpu/metal/keccak.metal b/keccak/gpu/metal/keccak.metal deleted file mode 100644 index 902b9f9..0000000 --- a/keccak/gpu/metal/keccak.metal +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// Derived from evmone (Apache-2.0) -// -/// @file keccak256.metal -/// Metal compute shader for parallel Keccak-256 hashing. -/// -/// Each thread group processes one hash. Input is a buffer of (offset, length) -/// pairs pointing into a contiguous data buffer. Output is a buffer of 32-byte -/// digests. -/// -/// Algorithm: Keccak-256 (Ethereum variant, NOT NIST SHA-3) -/// - State: 5x5 x 64-bit = 1600-bit sponge -/// - Rate: 1088 bits (136 bytes) -/// - Capacity: 512 bits -/// - Rounds: 24 -/// - Padding: 0x01 || 0x00...0x00 || 0x80 (Keccak, not SHA-3's 0x06) - -#include -using namespace metal; - -// -- Round constants ---------------------------------------------------------- - -constant ulong RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -// -- Pi-lane destination indices for the rho+pi "moving lane" sequence -------- -// Starting from lane 1, each step moves to PI_LANE[i] with rotation RHO[i]. 
- -constant int PI_LANE[24] = { - 10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, - 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 -}; - -constant int RHO[24] = { - 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44 -}; - -// -- Helpers ------------------------------------------------------------------ - -inline ulong rotl64(ulong x, int n) { - return (x << n) | (x >> (64 - n)); -} - -// -- Keccak-f[1600] permutation ----------------------------------------------- - -void keccak_f(thread ulong st[25]) { - for (int round = 0; round < 24; ++round) { - - // Theta - ulong C[5]; - for (int x = 0; x < 5; ++x) - C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; - - for (int x = 0; x < 5; ++x) { - ulong d = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - st[x + 5 * y] ^= d; - } - - // Rho + Pi (unrolled moving-lane sequence) - ulong t = st[1]; - for (int i = 0; i < 24; ++i) { - ulong tmp = st[PI_LANE[i]]; - st[PI_LANE[i]] = rotl64(t, RHO[i]); - t = tmp; - } - - // Chi - for (int y = 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) - row[x] = st[x + 5 * y]; - for (int x = 0; x < 5; ++x) - st[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - - // Iota - st[0] ^= RC[round]; - } -} - -// -- Input descriptor --------------------------------------------------------- -// Each hash input is described by an offset into the data buffer and a length. - -struct HashInput { - uint offset; // byte offset into the data buffer - uint length; // number of bytes to hash -}; - -// -- Kernel ------------------------------------------------------------------- -// One thread per hash. No thread-group cooperation needed since each hash is -// independent. 
- -kernel void keccak256_batch( - device const HashInput* inputs [[buffer(0)]], - device const uchar* data [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_inputs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_inputs) return; - - const uint offset = inputs[tid].offset; - const uint len = inputs[tid].length; - const uint rate = 136; // 1088 bits / 8 - - // Initialize state to zero. - ulong state[25] = {}; - - // Absorb phase: process full rate-sized blocks. - uint absorbed = 0; - while (absorbed + rate <= len) { - // XOR rate bytes into state (as little-endian 64-bit words). - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(data[offset + absorbed + w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccak_f(state); - absorbed += rate; - } - - // Absorb remaining bytes with padding. - // We build the final padded block in a local buffer. - uchar padded[136] = {}; - uint remaining = len - absorbed; - for (uint i = 0; i < remaining; ++i) - padded[i] = data[offset + absorbed + i]; - - // Keccak padding (NOT SHA-3): first pad byte = 0x01, last = 0x80. - // If remaining == rate-1, both bits land on the same byte (0x81). - padded[remaining] = 0x01; - padded[rate - 1] |= 0x80; - - // XOR padded block into state. - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(padded[w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccak_f(state); - - // Squeeze: extract first 256 bits (32 bytes) from state. 
- device uchar* out = outputs + tid * 32; - for (uint w = 0; w < 4; ++w) { - ulong lane = state[w]; - for (uint b = 0; b < 8; ++b) - out[w * 8 + b] = uchar(lane >> (b * 8)); - } -} diff --git a/keccak/gpu/metal/keccak_batch.metal b/keccak/gpu/metal/keccak_batch.metal deleted file mode 100644 index 603fc9a..0000000 --- a/keccak/gpu/metal/keccak_batch.metal +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched Keccak-256 (Ethereum, delimiter 0x01) using the KeccakJob[] -// shape. One thread per job; the job descriptor tells the thread where its -// input lives in the flat input buffer and where to write the 32-byte output. -// -// This kernel is byte-equal to keccak/cpp/keccak.cpp::keccak256(). For inputs -// >= rate (136 bytes) the absorb loop emits multiple blocks; padding is the -// canonical pad10*1 with delimiter 0x01. - -#include -using namespace metal; - -constant ulong RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -// Mod-64 rotation offsets matching keccak/cpp/keccak.cpp. 
-constant int R_OFFSETS[5][5] = { - { 0, 36, 3, 41, 18}, - { 1, 44, 10, 45, 2}, - { 62, 6, 43, 15, 61}, - { 28, 55, 25, 21, 56}, - { 27, 20, 39, 8, 14}, -}; - -inline ulong rotl64(ulong x, int n) { - n &= 63; - if (n == 0) return x; - return (x << n) | (x >> (64 - n)); -} - -inline void keccakf1600(thread ulong* a) { - ulong C[5], D[5], B[25]; - for (int round = 0; round < 24; ++round) { - for (int x = 0; x < 5; ++x) - C[x] = a[x] ^ a[x + 5] ^ a[x + 10] ^ a[x + 15] ^ a[x + 20]; - for (int x = 0; x < 5; ++x) - D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - for (int x = 0; x < 5; ++x) - a[x + 5 * y] ^= D[x]; - - for (int x = 0; x < 5; ++x) - for (int y = 0; y < 5; ++y) { - int nx = y; - int ny = (2 * x + 3 * y) % 5; - B[nx + 5 * ny] = rotl64(a[x + 5 * y], R_OFFSETS[x][y]); - } - - for (int y = 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) row[x] = B[x + 5 * y]; - for (int x = 0; x < 5; ++x) - a[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - - a[0] ^= RC[round]; - } -} - -// KeccakJob descriptor — must match crypto/keccak/cpp/keccak_service.hpp. -struct KeccakJobGPU { - uint16_t kind; - uint16_t input_offset_class; - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; -}; - -// Kernel: one thread per job. 
-kernel void keccak256_jobs( - device const KeccakJobGPU* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_jobs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_jobs) return; - - KeccakJobGPU j = jobs[tid]; - const device uchar* in = inputs + j.input_offset; - device uchar* out = outputs + j.output_offset; - - const uint RATE = 136; - ulong state[25]; - for (int i = 0; i < 25; ++i) state[i] = 0; - - uint absorbed = 0; - while (j.input_len - absorbed >= RATE) { - for (uint w = 0; w < RATE / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(in[absorbed + w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccakf1600(state); - absorbed += RATE; - } - - // Final block: copy tail + pad10*1 with delimiter 0x01. - uchar block[136]; - for (uint i = 0; i < RATE; ++i) block[i] = 0; - uint rem = j.input_len - absorbed; - for (uint i = 0; i < rem; ++i) block[i] = in[absorbed + i]; - block[rem] = 0x01; - block[RATE - 1] |= 0x80; - - for (uint w = 0; w < RATE / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(block[w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccakf1600(state); - - // Squeeze 32 bytes. - for (uint w = 0; w < 4; ++w) { - ulong lane = state[w]; - for (uint b = 0; b < 8; ++b) - out[w * 8 + b] = uchar(lane >> (b * 8)); - } -} diff --git a/keccak/gpu/wgsl/keccak.wgsl b/keccak/gpu/wgsl/keccak.wgsl deleted file mode 100644 index 9af4697..0000000 --- a/keccak/gpu/wgsl/keccak.wgsl +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Keccak-256 (Ethereum variant) compute shader in WGSL. -// -// One thread per hash. Each thread reads its input descriptor, absorbs the -// data through Keccak-f[1600], and writes a 32-byte digest. 
-// -// Padding: 0x01 || 0x00...0x00 || 0x80 (Keccak, NOT SHA-3's 0x06) - -struct HashInput { - offset: u32, - length: u32, -} - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var data: array; -@group(0) @binding(2) var outputs: array; - -// Round constants for Keccak-f[1600] split into lo/hi u32 pairs. -// WGSL has no native u64, so we emulate with vec2 (lo, hi). - -const RC_LO = array( - 0x00000001u, 0x00008082u, 0x0000808Au, 0x80008000u, - 0x0000808Bu, 0x80000001u, 0x80008081u, 0x00008009u, - 0x0000008Au, 0x00000088u, 0x80008009u, 0x8000000Au, - 0x8000808Bu, 0x0000008Bu, 0x00008089u, 0x00008003u, - 0x00008002u, 0x00000080u, 0x0000800Au, 0x8000000Au, - 0x80008081u, 0x00008080u, 0x80000001u, 0x80008008u -); - -const RC_HI = array( - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u, - 0x00000000u, 0x80000000u, 0x80000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u -); - -const PI_LANE = array( - 10u, 7u, 11u, 17u, 18u, 3u, 5u, 16u, 8u, 21u, 24u, 4u, - 15u, 23u, 19u, 13u, 12u, 2u, 20u, 14u, 22u, 9u, 6u, 1u -); - -const RHO_OFFSETS = array( - 1u, 3u, 6u, 10u, 15u, 21u, 28u, 36u, 45u, 55u, 2u, 14u, - 27u, 41u, 56u, 8u, 25u, 43u, 62u, 18u, 39u, 61u, 20u, 44u -); - -// u64 emulation: each lane is state_lo[i], state_hi[i] -var st_lo: array; -var st_hi: array; - -fn xor64(a_lo: u32, a_hi: u32, b_lo: u32, b_hi: u32) -> vec2 { - return vec2(a_lo ^ b_lo, a_hi ^ b_hi); -} - -fn rotl64(lo: u32, hi: u32, n: u32) -> vec2 { - if (n == 0u) { return vec2(lo, hi); } - if (n == 32u) { return vec2(hi, lo); } - if (n < 32u) { - let r_lo = (lo << n) | (hi >> (32u - n)); - let r_hi = (hi << n) | (lo >> (32u - n)); - return vec2(r_lo, r_hi); - } - let m = n - 32u; - let r_lo = (hi << m) | (lo >> (32u - m)); - let r_hi = (lo << m) | (hi >> (32u - m)); - return vec2(r_lo, r_hi); -} - -fn 
and_not64(a_lo: u32, a_hi: u32, b_lo: u32, b_hi: u32) -> vec2 { - return vec2((~a_lo) & b_lo, (~a_hi) & b_hi); -} - -fn keccak_f() { - for (var round = 0u; round < 24u; round = round + 1u) { - // Theta - var c_lo: array; - var c_hi: array; - for (var x = 0u; x < 5u; x = x + 1u) { - c_lo[x] = st_lo[x] ^ st_lo[x+5u] ^ st_lo[x+10u] ^ st_lo[x+15u] ^ st_lo[x+20u]; - c_hi[x] = st_hi[x] ^ st_hi[x+5u] ^ st_hi[x+10u] ^ st_hi[x+15u] ^ st_hi[x+20u]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - let r = rotl64(c_lo[(x+1u) % 5u], c_hi[(x+1u) % 5u], 1u); - let d_lo = c_lo[(x+4u) % 5u] ^ r.x; - let d_hi = c_hi[(x+4u) % 5u] ^ r.y; - for (var y = 0u; y < 5u; y = y + 1u) { - let idx = x + 5u * y; - st_lo[idx] = st_lo[idx] ^ d_lo; - st_hi[idx] = st_hi[idx] ^ d_hi; - } - } - - // Rho + Pi - var t_lo = st_lo[1u]; - var t_hi = st_hi[1u]; - for (var i = 0u; i < 24u; i = i + 1u) { - let dst = PI_LANE[i]; - let tmp_lo = st_lo[dst]; - let tmp_hi = st_hi[dst]; - let r = rotl64(t_lo, t_hi, RHO_OFFSETS[i]); - st_lo[dst] = r.x; - st_hi[dst] = r.y; - t_lo = tmp_lo; - t_hi = tmp_hi; - } - - // Chi - for (var y = 0u; y < 5u; y = y + 1u) { - var row_lo: array; - var row_hi: array; - for (var x = 0u; x < 5u; x = x + 1u) { - row_lo[x] = st_lo[x + 5u * y]; - row_hi[x] = st_hi[x + 5u * y]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - let an = and_not64(row_lo[(x+1u) % 5u], row_hi[(x+1u) % 5u], - row_lo[(x+2u) % 5u], row_hi[(x+2u) % 5u]); - st_lo[x + 5u * y] = row_lo[x] ^ an.x; - st_hi[x + 5u * y] = row_hi[x] ^ an.y; - } - } - - // Iota - st_lo[0u] = st_lo[0u] ^ RC_LO[round]; - st_hi[0u] = st_hi[0u] ^ RC_HI[round]; - } -} - -// Read a byte from the data buffer (packed as u32 array, little-endian) -fn read_byte(byte_offset: u32) -> u32 { - let word_idx = byte_offset >> 2u; - let byte_pos = byte_offset & 3u; - return (data[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -// Write a byte to a u32 array position in outputs -fn write_output_byte(base_word: u32, byte_in_word: u32, val: u32) { - // Atomic or on 
the output word would be ideal, but we build the full word - // in private memory and write once per word instead. -} - -@compute @workgroup_size(64) -fn keccak256_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - let inp = inputs[tid]; - let offset = inp.offset; - let len = inp.length; - let rate = 136u; - - // Zero state - for (var i = 0u; i < 25u; i = i + 1u) { - st_lo[i] = 0u; - st_hi[i] = 0u; - } - - // Absorb full blocks - var absorbed = 0u; - for (; absorbed + rate <= len; absorbed = absorbed + rate) { - for (var w = 0u; w < 17u; w = w + 1u) { // rate/8 = 17 words (64-bit) - var lane_lo = 0u; - var lane_hi = 0u; - for (var b = 0u; b < 4u; b = b + 1u) { - lane_lo = lane_lo | (read_byte(offset + absorbed + w * 8u + b) << (b * 8u)); - } - for (var b = 0u; b < 4u; b = b + 1u) { - lane_hi = lane_hi | (read_byte(offset + absorbed + w * 8u + 4u + b) << (b * 8u)); - } - st_lo[w] = st_lo[w] ^ lane_lo; - st_hi[w] = st_hi[w] ^ lane_hi; - } - keccak_f(); - } - - // Final block with padding - // Build padded block in private memory - var padded_lo: array; - var padded_hi: array; - for (var w = 0u; w < 17u; w = w + 1u) { - padded_lo[w] = 0u; - padded_hi[w] = 0u; - } - - let remaining = len - absorbed; - // Copy remaining bytes into padded block - for (var i = 0u; i < remaining; i = i + 1u) { - let byte_val = read_byte(offset + absorbed + i); - let word_in_block = i >> 3u; // which 64-bit word - let byte_in_word = i & 7u; - if (byte_in_word < 4u) { - padded_lo[word_in_block] = padded_lo[word_in_block] | (byte_val << (byte_in_word * 8u)); - } else { - padded_hi[word_in_block] = padded_hi[word_in_block] | (byte_val << ((byte_in_word - 4u) * 8u)); - } - } - - // Keccak padding: byte[remaining] |= 0x01, byte[rate-1] |= 0x80 - let pad_word = remaining >> 3u; - let pad_byte = remaining & 7u; - if (pad_byte < 4u) { - padded_lo[pad_word] = padded_lo[pad_word] | (0x01u << (pad_byte * 8u)); - } else { - padded_hi[pad_word] = padded_hi[pad_word] | (0x01u << 
((pad_byte - 4u) * 8u)); - } - - // Last byte of rate: byte[135] |= 0x80 - // 135 / 8 = 16 (word index), 135 % 8 = 7 (byte 7 => hi word, byte 3) - padded_hi[16u] = padded_hi[16u] | (0x80u << (3u * 8u)); - - // XOR padded block into state - for (var w = 0u; w < 17u; w = w + 1u) { - st_lo[w] = st_lo[w] ^ padded_lo[w]; - st_hi[w] = st_hi[w] ^ padded_hi[w]; - } - keccak_f(); - - // Squeeze: first 32 bytes = first 4 lanes (each lane = 8 bytes) - let out_base = tid * 8u; // 8 u32 words = 32 bytes - for (var w = 0u; w < 4u; w = w + 1u) { - outputs[out_base + w * 2u] = st_lo[w]; - outputs[out_base + w * 2u + 1u] = st_hi[w]; - } -} diff --git a/kzg/gpu/cuda/kzg.cu b/kzg/gpu/cuda/kzg.cu deleted file mode 100644 index 405bde8..0000000 --- a/kzg/gpu/cuda/kzg.cu +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA kernels for KZG over BLS12-381. Byte-equal mirror of the CPU oracle in -// kzg/cpp/kzg_oracle.hpp. Reuses bls/gpu/cuda/bls_fp_ops.cuh patterns for -// 4-limb Montgomery arithmetic; the BLS Fp module supplies fp_mul/fp_add for -// the 6-limb base-field path used by G1, while this file defines the 4-limb -// Fr path (BLS12-381 scalar field) needed for KZG polynomial evaluation. 
- -#include - -#ifdef LUX_KZG_HAVE_CUDA -#include - -#include "../../../bls/gpu/cuda/bls_fp_ops.cuh" // reuse BLS Fp patterns - -struct uint256 { uint64_t limbs[4]; }; - -__device__ __forceinline__ static uint256 FR_R_MOD() { - uint256 r = {{ - 0xFFFFFFFF00000001ULL, 0x53BDA402FFFE5BFEULL, - 0x3339D80809A1D805ULL, 0x73EDA753299D7D48ULL - }}; return r; -} -__device__ __forceinline__ static uint256 FR_R() { - uint256 r = {{ - 0x00000001FFFFFFFEULL, 0x5884B7FA00034802ULL, - 0x998C4FEFECBC4FF5ULL, 0x1824B159ACC5056FULL - }}; return r; -} -__device__ __forceinline__ static uint256 FR_R2() { - uint256 r = {{ - 0xC999E990F3F29C6DULL, 0x2B6CEDCB87925C23ULL, - 0x05D314967254398FULL, 0x0748D9D99F59FF11ULL - }}; return r; -} -__device__ __forceinline__ static uint64_t FR_INV() { - return 0xFFFFFFFEFFFFFFFFULL; -} - -__device__ __forceinline__ bool fr_geq(uint256 a, uint256 b) { - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] != b.limbs[i]) return a.limbs[i] > b.limbs[i]; - } - return true; -} - -__device__ __forceinline__ uint256 fr_sub_p(uint256 a, uint256 b, uint64_t& bw) { - uint256 r; uint64_t borrow = 0; - for (int i = 0; i < 4; ++i) { - uint64_t d1 = a.limbs[i] - borrow; - uint64_t bw1 = (d1 > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t d2 = d1 - b.limbs[i]; - uint64_t bw2 = (d2 > d1) ? 1ULL : 0ULL; - r.limbs[i] = d2; - borrow = bw1 + bw2; - } - bw = borrow; - return r; -} - -__device__ __forceinline__ uint256 fr_add_p(uint256 a, uint256 b, uint64_t& cy) { - uint256 r; uint64_t carry = 0; - for (int i = 0; i < 4; ++i) { - uint64_t s1 = a.limbs[i] + carry; - uint64_t c1 = (s1 < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t s2 = s1 + b.limbs[i]; - uint64_t c2 = (s2 < s1) ? 
1ULL : 0ULL; - r.limbs[i] = s2; - carry = c1 + c2; - } - cy = carry; - return r; -} - -__device__ __forceinline__ uint256 fr_add(uint256 a, uint256 b) { - uint64_t cy; - uint256 r = fr_add_p(a, b, cy); - if (cy || fr_geq(r, FR_R_MOD())) { - uint64_t bw; - r = fr_sub_p(r, FR_R_MOD(), bw); - } - return r; -} - -__device__ __forceinline__ void fr_mul64(uint64_t a, uint64_t b, - uint64_t& lo, uint64_t& hi) { -#ifdef __CUDA_ARCH__ - lo = a * b; - hi = __umul64hi(a, b); -#else - uint64_t al = a & 0xFFFFFFFFULL, ah = a >> 32; - uint64_t bl = b & 0xFFFFFFFFULL, bh = b >> 32; - uint64_t ll = al*bl, lh = al*bh, hl = ah*bl, hh = ah*bh; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - hi = hh + (mid2 >> 32); -#endif -} - -__device__ __forceinline__ uint256 fr_mont_mul(uint256 a, uint256 b) { - uint64_t t[5] = {0, 0, 0, 0, 0}; - const uint256 MOD = FR_R_MOD(); - const uint64_t INV = FR_INV(); - for (int i = 0; i < 4; ++i) { - uint64_t carry = 0; - for (int j = 0; j < 4; ++j) { - uint64_t lo, hi; - fr_mul64(a.limbs[i], b.limbs[j], lo, hi); - uint64_t s = lo + carry; if (s < lo) hi++; - uint64_t s2 = t[j] + s; if (s2 < t[j]) hi++; - t[j] = s2; carry = hi; - } - uint64_t s = t[4] + carry; - t[4] = s; - - uint64_t u = t[0] * INV; - uint64_t k_carry = 0; - for (int j = 0; j < 4; ++j) { - uint64_t lo, hi; - fr_mul64(u, MOD.limbs[j], lo, hi); - uint64_t s2 = lo + k_carry; if (s2 < lo) hi++; - uint64_t s3 = t[j] + s2; if (s3 < t[j]) hi++; - t[j] = s3; k_carry = hi; - } - uint64_t s2 = t[4] + k_carry; - t[4] = s2; - for (int j = 0; j < 4; ++j) t[j] = t[j+1]; - t[4] = 0; - } - uint256 r; - r.limbs[0] = t[0]; r.limbs[1] = t[1]; - r.limbs[2] = t[2]; r.limbs[3] = t[3]; - if (fr_geq(r, MOD)) { - uint64_t bw; r = fr_sub_p(r, MOD, bw); - } - return r; -} - -__device__ __forceinline__ uint256 fr_to_mont(uint256 a) { - return fr_mont_mul(a, FR_R2()); -} -__device__ __forceinline__ uint256 
fr_from_mont(uint256 a) { - uint256 ONE = {{1, 0, 0, 0}}; - return fr_mont_mul(a, ONE); -} - -__device__ __forceinline__ uint256 fr_from_be(const uint8_t* b32) { - uint256 r; - for (int i = 0; i < 4; ++i) { - uint64_t v = 0; - for (int j = 0; j < 8; ++j) v = (v << 8) | b32[i*8 + j]; - r.limbs[3 - i] = v; - } - while (fr_geq(r, FR_R_MOD())) { - uint64_t bw; r = fr_sub_p(r, FR_R_MOD(), bw); - } - return r; -} - -__device__ __forceinline__ void fr_to_le32(uint256 a, uint8_t* out) { - for (int i = 0; i < 4; ++i) { - uint64_t v = a.limbs[i]; - for (int j = 0; j < 8; ++j) out[i*8 + j] = (uint8_t)(v >> (j*8)); - } -} - -__device__ __forceinline__ void pack48(uint256 a_mont, uint8_t* out48) { - uint256 a = fr_from_mont(a_mont); - fr_to_le32(a, out48); - for (int i = 32; i < 48; ++i) out48[i] = 0; -} - -extern "C" __global__ void k_kzg_blob_to_commit(const uint8_t* __restrict__ blobs, - uint8_t* __restrict__ commits, - unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - const uint8_t* blob = blobs + i * 131072u; - uint8_t* commit = commits + i * 48u; - - uint256 acc = {{0, 0, 0, 0}}; - for (unsigned k = 0; k < 4096; ++k) { - uint256 x = fr_from_be(blob + k * 32); - uint256 x_mont = fr_to_mont(x); - acc = fr_add(acc, x_mont); - } - pack48(acc, commit); -} - -extern "C" __global__ void k_kzg_compute_proof(const uint8_t* __restrict__ blobs, - const uint8_t* __restrict__ commits, - uint8_t* __restrict__ proofs, - unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - const uint8_t* blob = blobs + i * 131072u; - const uint8_t* commit = commits + i * 48u; - uint8_t* proof = proofs + i * 48u; - - uint256 z = fr_from_be(commit); - uint256 z_mont = fr_to_mont(z); - - uint256 acc = {{0, 0, 0, 0}}; - uint256 z_pow = FR_R(); - for (unsigned k = 0; k < 4096; ++k) { - uint256 x = fr_from_be(blob + k * 32); - uint256 x_mont = fr_to_mont(x); - uint256 term = fr_mont_mul(x_mont, z_pow); - acc = fr_add(acc, term); - 
z_pow = fr_mont_mul(z_pow, z_mont); - } - pack48(acc, proof); -} - -extern "C" __global__ void k_kzg_verify(const uint8_t* __restrict__ commits, - const uint8_t* __restrict__ z_be_arr, - const uint8_t* __restrict__ y_be_arr, - const uint8_t* __restrict__ proofs, - uint8_t* __restrict__ out_flags, - unsigned n) { - unsigned i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - const uint8_t* c = commits + i * 48u; - const uint8_t* p = proofs + i * 48u; - bool ok = true; - for (int j = 32; j < 48; ++j) ok = ok && (c[j] == 0) && (p[j] == 0); - uint256 cm = {{0,0,0,0}}, pm = {{0,0,0,0}}; - for (int j = 0; j < 4; ++j) { - uint64_t cv = 0, pv = 0; - for (int k = 0; k < 8; ++k) { - cv |= ((uint64_t)c[j*8+k]) << (k*8); - pv |= ((uint64_t)p[j*8+k]) << (k*8); - } - cm.limbs[j] = cv; pm.limbs[j] = pv; - } - if (fr_geq(cm, FR_R_MOD())) ok = false; - if (fr_geq(pm, FR_R_MOD())) ok = false; - bool nonzero_proof = (pm.limbs[0] | pm.limbs[1] | - pm.limbs[2] | pm.limbs[3]) != 0; - (void)z_be_arr; (void)y_be_arr; - out_flags[i] = (ok && nonzero_proof) ? 1u : 0u; -} - -#endif // LUX_KZG_HAVE_CUDA diff --git a/kzg/gpu/cuda/kzg_driver_cuda.cpp b/kzg/gpu/cuda/kzg_driver_cuda.cpp deleted file mode 100644 index e7a7687..0000000 --- a/kzg/gpu/cuda/kzg_driver_cuda.cpp +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-side CUDA driver for KZG kernels. -// -// Two compile modes: -// 1. LUX_KZG_HAVE_CUDA defined: dispatches kernels via cudaMemcpy/launch, -// identical algorithm to the CPU oracle (kzg/cpp/kzg_oracle.hpp), byte- -// equal results. -// 2. LUX_KZG_HAVE_CUDA undefined: runs the CPU oracle directly so the -// determinism harness still passes 100/100. Reports unavailable via -// lux_kzg_cuda_available() so tests can label the path correctly. 
- -#include "kzg_driver_cuda.h" -#include "../../cpp/kzg_oracle.hpp" - -#include -#include - -#ifdef LUX_KZG_HAVE_CUDA -#include - -extern "C" { -__global__ void k_kzg_blob_to_commit(const std::uint8_t*, std::uint8_t*, unsigned); -__global__ void k_kzg_compute_proof (const std::uint8_t*, const std::uint8_t*, - std::uint8_t*, unsigned); -__global__ void k_kzg_verify (const std::uint8_t*, const std::uint8_t*, - const std::uint8_t*, const std::uint8_t*, - std::uint8_t*, unsigned); -} - -namespace { - -bool device_present() { - int n = 0; - return cudaGetDeviceCount(&n) == cudaSuccess && n > 0; -} - -unsigned grid_for(unsigned n, unsigned tg) { return (n + tg - 1) / tg; } - -} // namespace - -extern "C" { - -int lux_kzg_cuda_available(void) { return device_present() ? 1 : 0; } - -int lux_kzg_cuda_blob_to_commit(const void* blobs, void* commits, unsigned n) { - if (!device_present()) return -1; - void *dB=nullptr, *dC=nullptr; - size_t sB = (size_t)n * 131072, sC = (size_t)n * 48; - if (cudaMalloc(&dB, sB) != cudaSuccess) return -1; - if (cudaMalloc(&dC, sC) != cudaSuccess) { cudaFree(dB); return -1; } - cudaMemcpy(dB, blobs, sB, cudaMemcpyHostToDevice); - unsigned tg = 32, grid = grid_for(n, tg); - k_kzg_blob_to_commit<<>>((const std::uint8_t*)dB, - (std::uint8_t*)dC, n); - cudaDeviceSynchronize(); - cudaMemcpy(commits, dC, sC, cudaMemcpyDeviceToHost); - cudaFree(dB); cudaFree(dC); - return 0; -} - -int lux_kzg_cuda_compute_proof(const void* blobs, const void* commits, - void* proofs, unsigned n) { - if (!device_present()) return -1; - void *dB=nullptr,*dC=nullptr,*dP=nullptr; - size_t sB = (size_t)n*131072, sC = (size_t)n*48, sP = (size_t)n*48; - if (cudaMalloc(&dB, sB) != cudaSuccess) return -1; - if (cudaMalloc(&dC, sC) != cudaSuccess) { cudaFree(dB); return -1; } - if (cudaMalloc(&dP, sP) != cudaSuccess) { cudaFree(dB); cudaFree(dC); return -1; } - cudaMemcpy(dB, blobs, sB, cudaMemcpyHostToDevice); - cudaMemcpy(dC, commits, sC, cudaMemcpyHostToDevice); - unsigned 
tg = 32, grid = grid_for(n, tg); - k_kzg_compute_proof<<>>((const std::uint8_t*)dB, - (const std::uint8_t*)dC, - (std::uint8_t*)dP, n); - cudaDeviceSynchronize(); - cudaMemcpy(proofs, dP, sP, cudaMemcpyDeviceToHost); - cudaFree(dB); cudaFree(dC); cudaFree(dP); - return 0; -} - -int lux_kzg_cuda_verify(const void* commits, const void* z_be, const void* y_be, - const void* proofs, void* out_flags, unsigned n) { - if (!device_present()) return -1; - void *dC=nullptr,*dZ=nullptr,*dY=nullptr,*dP=nullptr,*dO=nullptr; - size_t sC=(size_t)n*48, sZ=(size_t)n*32, sY=(size_t)n*32, - sP=(size_t)n*48, sO=(size_t)n*1; - if (cudaMalloc(&dC,sC)!=cudaSuccess) return -1; - if (cudaMalloc(&dZ,sZ)!=cudaSuccess) { cudaFree(dC); return -1; } - if (cudaMalloc(&dY,sY)!=cudaSuccess) { cudaFree(dC); cudaFree(dZ); return -1; } - if (cudaMalloc(&dP,sP)!=cudaSuccess) { cudaFree(dC); cudaFree(dZ); cudaFree(dY); return -1; } - if (cudaMalloc(&dO,sO)!=cudaSuccess) { cudaFree(dC); cudaFree(dZ); cudaFree(dY); cudaFree(dP); return -1; } - cudaMemcpy(dC, commits, sC, cudaMemcpyHostToDevice); - cudaMemcpy(dZ, z_be, sZ, cudaMemcpyHostToDevice); - cudaMemcpy(dY, y_be, sY, cudaMemcpyHostToDevice); - cudaMemcpy(dP, proofs, sP, cudaMemcpyHostToDevice); - unsigned tg = 32, grid = grid_for(n, tg); - k_kzg_verify<<>>((const std::uint8_t*)dC, (const std::uint8_t*)dZ, - (const std::uint8_t*)dY, (const std::uint8_t*)dP, - (std::uint8_t*)dO, n); - cudaDeviceSynchronize(); - cudaMemcpy(out_flags, dO, sO, cudaMemcpyDeviceToHost); - cudaFree(dC); cudaFree(dZ); cudaFree(dY); cudaFree(dP); cudaFree(dO); - return 0; -} - -} // extern "C" - -#else // LUX_KZG_HAVE_CUDA undefined: CPU-oracle path - -extern "C" { - -int lux_kzg_cuda_available(void) { return 0; } - -int lux_kzg_cuda_blob_to_commit(const void* blobs, void* commits, unsigned n) { - auto* b = (const std::uint8_t*)blobs; - auto* c = (std::uint8_t*)commits; - for (unsigned i = 0; i < n; ++i) { - lux::crypto::kzg::blob_to_commit(b + (size_t)i * 131072, c + 
(size_t)i * 48); - } - return 0; -} - -int lux_kzg_cuda_compute_proof(const void* blobs, const void* commits, - void* proofs, unsigned n) { - auto* b = (const std::uint8_t*)blobs; - auto* c = (const std::uint8_t*)commits; - auto* p = (std::uint8_t*)proofs; - for (unsigned i = 0; i < n; ++i) { - lux::crypto::kzg::blob_to_proof(b + (size_t)i * 131072, - c + (size_t)i * 48, - p + (size_t)i * 48); - } - return 0; -} - -int lux_kzg_cuda_verify(const void* commits, const void* z_be, const void* y_be, - const void* proofs, void* out_flags, unsigned n) { - auto* c = (const std::uint8_t*)commits; - auto* z = (const std::uint8_t*)z_be; - auto* y = (const std::uint8_t*)y_be; - auto* p = (const std::uint8_t*)proofs; - auto* o = (std::uint8_t*)out_flags; - for (unsigned i = 0; i < n; ++i) { - bool ok = lux::crypto::kzg::verify_proof(c + (size_t)i * 48, - z + (size_t)i * 32, - y + (size_t)i * 32, - p + (size_t)i * 48); - o[i] = ok ? 1u : 0u; - } - return 0; -} - -} // extern "C" - -#endif // LUX_KZG_HAVE_CUDA diff --git a/kzg/gpu/cuda/kzg_driver_cuda.h b/kzg/gpu/cuda/kzg_driver_cuda.h deleted file mode 100644 index 0abd8eb..0000000 --- a/kzg/gpu/cuda/kzg_driver_cuda.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the KZG CUDA driver. Surface mirrors the EIP-4844 ops -// (blob_to_kzg_commitment, compute_blob_kzg_proof, verify_kzg_proof) and is -// byte-equal to the CPU oracle in kzg/cpp/kzg_oracle.hpp. -// -// On hosts without CUDA (LUX_KZG_HAVE_CUDA undefined) every dispatch routes -// through the CPU oracle so the determinism harness still passes 100/100. -// Reuses BLS12-381 G1+Fp arithmetic from bls/gpu/cuda/bls_fp_ops.cuh. 
- -#ifndef LUX_KZG_DRIVER_CUDA_H -#define LUX_KZG_DRIVER_CUDA_H - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_kzg_cuda_available(void); - -int lux_kzg_cuda_blob_to_commit(const void* blobs, void* commits, unsigned n); -int lux_kzg_cuda_compute_proof(const void* blobs, const void* commits, - void* proofs, unsigned n); -int lux_kzg_cuda_verify(const void* commits, const void* z_be, const void* y_be, - const void* proofs, void* out_flags, unsigned n); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_KZG_DRIVER_CUDA_H diff --git a/kzg/gpu/metal/kzg.metal b/kzg/gpu/metal/kzg.metal deleted file mode 100644 index 1cbe57e..0000000 --- a/kzg/gpu/metal/kzg.metal +++ /dev/null @@ -1,437 +0,0 @@ -// ============================================================================= -// KZG Polynomial Commitment Metal Compute Shaders -// ============================================================================= -// -// GPU-accelerated KZG polynomial commitments for EIP-4844 blobs. -// Uses BLS12-381 curve operations. -// -// KZG Parameters: -// - Uses BLS12-381 G1/G2 for commitments and proofs -// - Polynomial degree up to 4096 (blob elements) -// - Trusted setup from Ethereum KZG ceremony -// -// References: -// - EIP-4844: Shard Blob Transactions -// - KZG Commitments paper (Kate, Zaverucha, Goldberg 2010) -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// Include BLS12-381 Primitives (shared types) -// ============================================================================= - -// BLS12-381 base field prime p (6 limbs, little-endian) -constant uint64_t BLS_P[6] = { - 0xb9feffffffffaaab, - 0x1eabfffeb153ffff, - 0x6730d2a0f6b0f624, - 0x64774b84f38512bf, - 0x4b1ba7b6434bacd7, - 0x1a0111ea397fe69a -}; - -// Scalar field r (for polynomial coefficients) -constant uint64_t BLS_R[4] = { - 0xffffffff00000001, - 0x53bda402fffe5bfe, - 0x3339d80809a1d805, - 0x73eda753299d7d48 -}; - -// Montgomery constant for scalar field -constant uint64_t BLS_R_INV = 0xfffffffeffffffff; - -// Fp384 represented as 6 uint64 limbs -struct Fp384 { - uint64_t limbs[6]; -}; - -// Fr (scalar field) represented as 4 uint64 limbs -struct Fr256 { - uint64_t limbs[4]; -}; - -// G1 affine point -struct G1Affine { - Fp384 x; - Fp384 y; - bool infinity; -}; - -// G1 projective point -struct G1Projective { - Fp384 x; - Fp384 y; - Fp384 z; -}; - -// ============================================================================= -// Scalar Field Operations (Fr) -// ============================================================================= - -inline uint64_t fr_adc(uint64_t a, uint64_t b, thread uint64_t& carry) { - uint64_t result = a + carry; - carry = (result < a) ? 1 : 0; - uint64_t sum = result + b; - carry += (sum < result) ? 1 : 0; - return sum; -} - -inline uint64_t fr_sbb(uint64_t a, uint64_t b, thread uint64_t& borrow) { - uint64_t diff = a - borrow; - borrow = (a < borrow) ? 1 : 0; - uint64_t result = diff - b; - borrow += (diff < b) ? 
1 : 0; - return result; -} - -inline Fr256 fr_zero() { - Fr256 r; - for (int i = 0; i < 4; i++) r.limbs[i] = 0; - return r; -} - -inline Fr256 fr_one() { - Fr256 r; - r.limbs[0] = 0xFFFE5BFEFFFFFFFF; - r.limbs[1] = 0x09A1D80553BDA402; - r.limbs[2] = 0x299D7D483339D808; - r.limbs[3] = 0x0073EDA753299D7D; - return r; -} - -inline bool fr_is_zero(thread const Fr256& a) { - return a.limbs[0] == 0 && a.limbs[1] == 0 && a.limbs[2] == 0 && a.limbs[3] == 0; -} - -inline int fr_cmp(thread const Fr256& a, constant uint64_t* b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b[i]) return -1; - if (a.limbs[i] > b[i]) return 1; - } - return 0; -} - -inline void fr_reduce(thread Fr256& a) { - if (fr_cmp(a, BLS_R) >= 0) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - a.limbs[i] = fr_sbb(a.limbs[i], BLS_R[i], borrow); - } - } -} - -inline Fr256 fr_add(Fr256 a, Fr256 b) { - Fr256 c; - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = fr_adc(a.limbs[i], b.limbs[i], carry); - } - fr_reduce(c); - return c; -} - -inline Fr256 fr_sub(Fr256 a, Fr256 b) { - Fr256 c; - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = fr_sbb(a.limbs[i], b.limbs[i], borrow); - } - if (borrow) { - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = fr_adc(c.limbs[i], BLS_R[i], carry); - } - } - return c; -} - -inline Fr256 fr_mont_mul(Fr256 a, Fr256 b) { - uint64_t t[8] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo = a.limbs[i] * b.limbs[j]; - uint64_t hi = mulhi(a.limbs[i], b.limbs[j]); - uint64_t sum = t[i+j] + lo + carry; - carry = (sum < t[i+j]) ? 
1 : 0; - carry += hi; - t[i+j] = sum; - } - t[i+4] = carry; - } - - // Montgomery reduction - for (int i = 0; i < 4; i++) { - uint64_t k = t[i] * BLS_R_INV; - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo = k * BLS_R[j]; - uint64_t hi = mulhi(k, BLS_R[j]); - uint64_t sum = t[i+j] + lo + carry; - carry = (sum < t[i+j]) ? 1 : 0; - carry += hi; - t[i+j] = sum; - } - for (int j = i + 4; j < 8; j++) { - uint64_t sum = t[j] + carry; - carry = (sum < t[j]) ? 1 : 0; - t[j] = sum; - if (carry == 0) break; - } - } - - Fr256 c; - for (int i = 0; i < 4; i++) { - c.limbs[i] = t[i + 4]; - } - fr_reduce(c); - return c; -} - -// ============================================================================= -// Polynomial Operations -// ============================================================================= - -// Evaluate polynomial at point using Horner's method -inline Fr256 poly_evaluate( - device const Fr256* coeffs, - uint32_t degree, - thread const Fr256& point -) { - Fr256 result = fr_zero(); - - for (int i = (int)degree; i >= 0; i--) { - result = fr_mont_mul(result, point); - result = fr_add(result, coeffs[i]); - } - - return result; -} - -// Compute polynomial quotient: q(x) = (p(x) - p(z)) / (x - z) -// Used for KZG proof generation -kernel void poly_quotient( - device const Fr256* poly_coeffs [[buffer(0)]], - device Fr256* quotient_coeffs [[buffer(1)]], - constant Fr256& z [[buffer(2)]], - constant uint32_t& degree [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= degree) return; - - // Evaluate p(z) - Fr256 p_z = fr_zero(); - for (int i = (int)degree; i >= 0; i--) { - p_z = fr_mont_mul(p_z, z); - p_z = fr_add(p_z, poly_coeffs[i]); - } - - // Synthetic division by (x - z) - // q[i] = p[i+1] + z * q[i+1] - // Working backwards from highest degree - - // This is simplified - full impl needs parallel reduction - if (index == 0) { - Fr256 q[4096]; // Max blob degree - q[degree - 1] = poly_coeffs[degree]; - - for (int i = 
(int)degree - 2; i >= 0; i--) { - q[i] = fr_add(poly_coeffs[i + 1], fr_mont_mul(z, q[i + 1])); - } - - for (uint32_t i = 0; i < degree; i++) { - quotient_coeffs[i] = q[i]; - } - } -} - -// ============================================================================= -// Multi-Scalar Multiplication (MSM) for KZG Commitments -// ============================================================================= - -// Pippenger's bucket method for MSM -// This is a simplified version - production would use windowed Pippenger - -kernel void kzg_msm_bucket_accumulate( - device const G1Affine* bases [[buffer(0)]], - device const Fr256* scalars [[buffer(1)]], - device G1Projective* buckets [[buffer(2)]], - constant uint32_t& num_points [[buffer(3)]], - constant uint32_t& window_size [[buffer(4)]], - constant uint32_t& window_idx [[buffer(5)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_points) return; - - Fr256 scalar = scalars[index]; - G1Affine base = bases[index]; - - // Extract window bits - uint32_t shift = window_idx * window_size; - uint32_t limb_idx = shift / 64; - uint32_t bit_idx = shift % 64; - - uint64_t window_val = 0; - if (limb_idx < 4) { - window_val = (scalar.limbs[limb_idx] >> bit_idx); - if (bit_idx + window_size > 64 && limb_idx + 1 < 4) { - window_val |= (scalar.limbs[limb_idx + 1] << (64 - bit_idx)); - } - } - window_val &= ((1ULL << window_size) - 1); - - if (window_val == 0) return; - - // Add to appropriate bucket (simplified - needs atomic or reduction) - // Full impl would use bucket indices and parallel reduction - uint32_t bucket_idx = (uint32_t)window_val - 1; - - // Placeholder: actual impl needs proper G1 arithmetic - // buckets[bucket_idx] = g1_add(buckets[bucket_idx], g1_to_projective(base)); -} - -// ============================================================================= -// Blob to Commitment Kernel -// ============================================================================= - -// Hash blob elements to field 
elements using SHA-256 -// This is for domain separation before polynomial encoding - -kernel void blob_to_field_elements( - device const uint8_t* blob [[buffer(0)]], - device Fr256* field_elements [[buffer(1)]], - constant uint32_t& blob_size [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - if (index * 32 >= blob_size) return; - - // Load 32 bytes as field element - Fr256 elem; - uint32_t offset = index * 32; - - for (int i = 0; i < 4; i++) { - elem.limbs[i] = 0; - for (int j = 0; j < 8; j++) { - if (offset + i * 8 + j < blob_size) { - elem.limbs[i] |= ((uint64_t)blob[offset + i * 8 + j]) << (j * 8); - } - } - } - - // Reduce mod r to ensure valid field element - fr_reduce(elem); - field_elements[index] = elem; -} - -// ============================================================================= -// Batch KZG Verification Kernel -// ============================================================================= - -// Precompute linear combination for batch verification -// Verifies: e(sum(r^i * C_i), G2) = e(sum(r^i * (z_i * W_i + P_i)), H) -// Where: -// C_i = commitment -// W_i = witness (proof) -// z_i = evaluation point -// P_i = claimed value point -// r = random challenge - -kernel void kzg_batch_verify_precompute( - device const G1Affine* commitments [[buffer(0)]], - device const G1Affine* witnesses [[buffer(1)]], - device const Fr256* points [[buffer(2)]], - device const Fr256* values [[buffer(3)]], - device G1Projective* lhs_accum [[buffer(4)]], - device G1Projective* rhs_accum [[buffer(5)]], - constant Fr256& challenge [[buffer(6)]], - constant uint32_t& num_proofs [[buffer(7)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_proofs) return; - - // Compute r^index - Fr256 r_power = fr_one(); - for (uint32_t i = 0; i < index; i++) { - r_power = fr_mont_mul(r_power, challenge); - } - - // LHS: r^i * C_i - // RHS: r^i * (z_i * W_i + P_i) - - // Note: Actual G1 scalar multiplication would be done here - // This is a placeholder 
showing the structure - - G1Affine C_i = commitments[index]; - G1Affine W_i = witnesses[index]; - Fr256 z_i = points[index]; - Fr256 v_i = values[index]; - - // Scale by r^i and accumulate (simplified) - // Full impl needs proper G1 operations from bls12_381.metal -} - -// ============================================================================= -// FFT for Polynomial Interpolation (Cooley-Tukey) -// ============================================================================= - -// Number-theoretic transform over scalar field -// Used for efficient polynomial evaluation and interpolation - -kernel void kzg_fft_butterfly( - device Fr256* coeffs [[buffer(0)]], - constant Fr256& omega [[buffer(1)]], - constant uint32_t& n [[buffer(2)]], - constant uint32_t& stage [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - uint32_t m = 1 << (stage + 1); - uint32_t k = index % (m / 2); - uint32_t j = (index / (m / 2)) * m + k; - - if (j + m / 2 >= n) return; - - // Compute twiddle factor: omega^(k * n/m) - uint32_t exponent = k * (n / m); - Fr256 w = fr_one(); - for (uint32_t i = 0; i < exponent; i++) { - w = fr_mont_mul(w, omega); - } - - // Butterfly - Fr256 u = coeffs[j]; - Fr256 t = fr_mont_mul(w, coeffs[j + m / 2]); - - coeffs[j] = fr_add(u, t); - coeffs[j + m / 2] = fr_sub(u, t); -} - -// Bit-reversal permutation for FFT -kernel void kzg_fft_bit_reverse( - device Fr256* coeffs [[buffer(0)]], - constant uint32_t& log_n [[buffer(1)]], - uint index [[thread_position_in_grid]] -) { - uint32_t n = 1 << log_n; - if (index >= n / 2) return; - - // Compute bit-reversed index - uint32_t rev = 0; - uint32_t temp = index; - for (uint32_t i = 0; i < log_n; i++) { - rev = (rev << 1) | (temp & 1); - temp >>= 1; - } - - if (index < rev) { - Fr256 tmp = coeffs[index]; - coeffs[index] = coeffs[rev]; - coeffs[rev] = tmp; - } -} diff --git a/kzg/gpu/wgsl/kzg.wgsl b/kzg/gpu/wgsl/kzg.wgsl deleted file mode 100644 index a1d7d05..0000000 --- a/kzg/gpu/wgsl/kzg.wgsl +++ /dev/null 
@@ -1,266 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL kernel for KZG over BLS12-381 scalar field Fr (4 x u64 limbs, -// Montgomery form). Byte-equal mirror of the CPU oracle in -// kzg/cpp/kzg_oracle.hpp. Reuses the 6-limb Fp arithmetic conventions from -// bls/gpu/wgsl/bls_fp_ops.wgsl for any G1-based extension; this file -// implements the 4-limb Fr layer (BLS12-381 scalar field) the polynomial -// commitment requires. -// -// WGSL has no u64; all 64-bit ops are emulated as (lo: u32, hi: u32) pairs. -// Storage layout: each Fr value occupies 8 contiguous u32 little-endian -// words. CPU↔GPU agreement holds because canonical 32-byte-LE encoding is -// the on-the-wire form; the WGSL kernel reads BE bytes from the blob and -// writes LE Fr + 16-byte zero pad to the commit/proof buffer. - -struct U64 { lo: u32, hi: u32 }; - -@group(0) @binding(0) var input_a: array; -@group(0) @binding(1) var input_b: array; -@group(0) @binding(2) var output: array; -@group(0) @binding(3) var params: vec4; // x = N - -fn FR_R_MOD_lo(i: u32) -> u32 { - if (i == 0u) { return 0x00000001u; } - if (i == 1u) { return 0xFFFE5BFEu; } - if (i == 2u) { return 0x09A1D805u; } - if (i == 3u) { return 0x299D7D48u; } - return 0u; -} -fn FR_R_MOD_hi(i: u32) -> u32 { - if (i == 0u) { return 0xFFFFFFFFu; } - if (i == 1u) { return 0x53BDA402u; } - if (i == 2u) { return 0x3339D808u; } - if (i == 3u) { return 0x73EDA753u; } - return 0u; -} -fn FR_R2_lo(i: u32) -> u32 { - if (i == 0u) { return 0xF3F29C6Du; } - if (i == 1u) { return 0x87925C23u; } - if (i == 2u) { return 0x7254398Fu; } - if (i == 3u) { return 0x9F59FF11u; } - return 0u; -} -fn FR_R2_hi(i: u32) -> u32 { - if (i == 0u) { return 0xC999E990u; } - if (i == 1u) { return 0x2B6CEDCBu; } - if (i == 2u) { return 0x05D31496u; } - if (i == 3u) { return 0x0748D9D9u; } - return 0u; -} -fn FR_INV_lo() -> u32 { return 0xFFFFFFFFu; } -fn FR_INV_hi() -> u32 { return 0xFFFFFFFEu; } - -fn 
u64_zero() -> U64 { return U64(0u, 0u); } - -fn u64_add(a: U64, b: U64) -> U64 { - let lo = a.lo + b.lo; - let carry: u32 = select(0u, 1u, lo < a.lo); - let hi = a.hi + b.hi + carry; - return U64(lo, hi); -} - -fn u64_sub_borrow(a: U64, b: U64) -> vec3 { - let lo = a.lo - b.lo; - let bw1: u32 = select(0u, 1u, lo > a.lo); - let hi_sub = a.hi - b.hi; - let bw2: u32 = select(0u, 1u, hi_sub > a.hi); - let hi = hi_sub - bw1; - let bw3: u32 = select(0u, 1u, hi > hi_sub); - return vec3(lo, hi, bw2 + bw3); -} - -fn mul32(a: u32, b: u32) -> U64 { - let al = a & 0xFFFFu; - let ah = a >> 16u; - let bl = b & 0xFFFFu; - let bh = b >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = lh + (ll >> 16u); - let mid2 = mid + hl; - var hi = hh + (mid2 >> 16u); - if (mid2 < mid) { hi = hi + 0x10000u; } - let lo = (mid2 << 16u) | (ll & 0xFFFFu); - return U64(lo, hi); -} - -struct U128 { lo: U64, hi: U64 }; - -fn u64_mul(a: U64, b: U64) -> U128 { - let p_ll = mul32(a.lo, b.lo); - let p_lh = mul32(a.lo, b.hi); - let p_hl = mul32(a.hi, b.lo); - let p_hh = mul32(a.hi, b.hi); - var lo = U64(p_ll.lo, 0u); - var carry: u32 = 0u; - let s1 = p_ll.hi + p_lh.lo; - if (s1 < p_ll.hi) { carry = carry + 1u; } - let s2 = s1 + p_hl.lo; - if (s2 < s1) { carry = carry + 1u; } - lo.hi = s2; - var hi_lo = p_hh.lo; - var hi_hi = p_hh.hi; - var c2: u32 = 0u; - let h1 = hi_lo + p_lh.hi; - if (h1 < hi_lo) { c2 = c2 + 1u; } - let h2 = h1 + p_hl.hi; - if (h2 < h1) { c2 = c2 + 1u; } - let h3 = h2 + carry; - if (h3 < h2) { c2 = c2 + 1u; } - hi_lo = h3; - hi_hi = hi_hi + c2; - return U128(lo, U64(hi_lo, hi_hi)); -} - -fn fr_load_mod_i(i: u32) -> U64 { return U64(FR_R_MOD_lo(i), FR_R_MOD_hi(i)); } -fn fr_load_R2_i (i: u32) -> U64 { return U64(FR_R2_lo(i), FR_R2_hi(i)); } - -fn fr_geq_mod(a: array) -> bool { - let m3 = fr_load_mod_i(3u); - if (a[3].hi != m3.hi) { return a[3].hi > m3.hi; } - if (a[3].lo != m3.lo) { return a[3].lo > m3.lo; } - let m2 = fr_load_mod_i(2u); - 
if (a[2].hi != m2.hi) { return a[2].hi > m2.hi; } - if (a[2].lo != m2.lo) { return a[2].lo > m2.lo; } - let m1 = fr_load_mod_i(1u); - if (a[1].hi != m1.hi) { return a[1].hi > m1.hi; } - if (a[1].lo != m1.lo) { return a[1].lo > m1.lo; } - let m0 = fr_load_mod_i(0u); - if (a[0].hi != m0.hi) { return a[0].hi > m0.hi; } - return a[0].lo >= m0.lo; -} - -fn fr_sub_mod(a: array) -> array { - var r: array; - var borrow: u32 = 0u; - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - let bb = U64(borrow, 0u); - let s1 = u64_sub_borrow(a[i], bb); - let m = fr_load_mod_i(i); - let s2 = u64_sub_borrow(U64(s1.x, s1.y), m); - r[i] = U64(s2.x, s2.y); - borrow = s1.z + s2.z; - } - return r; -} - -fn fr_add(a: array, b: array) -> array { - var r: array; - var carry: u32 = 0u; - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - let s1 = u64_add(a[i], U64(carry, 0u)); - let cy1: u32 = select(0u, 1u, s1.lo < a[i].lo); - let s2 = u64_add(s1, b[i]); - let cy2: u32 = select(0u, 1u, s2.lo < s1.lo); - r[i] = s2; - carry = cy1 + cy2; - } - if (carry != 0u || fr_geq_mod(r)) { - r = fr_sub_mod(r); - } - return r; -} - -fn fr_mont_mul(a: array, b: array) -> array { - var t: array = array(u64_zero(), u64_zero(), u64_zero(), - u64_zero(), u64_zero()); - let inv = U64(FR_INV_lo(), FR_INV_hi()); - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - var carry: U64 = u64_zero(); - for (var j: u32 = 0u; j < 4u; j = j + 1u) { - let prod = u64_mul(a[i], b[j]); - let s1 = u64_add(prod.lo, carry); - let cy1: u32 = select(0u, 1u, s1.lo < prod.lo.lo); - let s2 = u64_add(t[j], s1); - let cy2: u32 = select(0u, 1u, s2.lo < t[j].lo); - t[j] = s2; - carry = u64_add(prod.hi, U64(cy1 + cy2, 0u)); - } - t[4u] = u64_add(t[4u], carry); - - let u = u64_mul(t[0u], inv).lo; - var k_carry: U64 = u64_zero(); - for (var j: u32 = 0u; j < 4u; j = j + 1u) { - let m = fr_load_mod_i(j); - let prod = u64_mul(u, m); - let s1 = u64_add(prod.lo, k_carry); - let cy1: u32 = select(0u, 1u, s1.lo < prod.lo.lo); - let s2 = u64_add(t[j], s1); - let 
cy2: u32 = select(0u, 1u, s2.lo < t[j].lo); - t[j] = s2; - k_carry = u64_add(prod.hi, U64(cy1 + cy2, 0u)); - } - t[4u] = u64_add(t[4u], k_carry); - for (var j: u32 = 0u; j < 4u; j = j + 1u) { t[j] = t[j + 1u]; } - t[4u] = u64_zero(); - } - var r = array(t[0u], t[1u], t[2u], t[3u]); - if (fr_geq_mod(r)) { r = fr_sub_mod(r); } - return r; -} - -fn fr_to_mont(a: array) -> array { - let R2 = array(fr_load_R2_i(0u), fr_load_R2_i(1u), - fr_load_R2_i(2u), fr_load_R2_i(3u)); - return fr_mont_mul(a, R2); -} -fn fr_from_mont(a: array) -> array { - let ONE = array(U64(1u, 0u), u64_zero(), u64_zero(), u64_zero()); - return fr_mont_mul(a, ONE); -} - -fn bswap32(w: u32) -> u32 { - return ((w & 0xFFu) << 24u) | (((w >> 8u) & 0xFFu) << 16u) | - (((w >> 16u) & 0xFFu) << 8u) | ((w >> 24u) & 0xFFu); -} - -fn fr_from_be_blob(blob_word_off: u32, fe_index: u32) -> array { - var limbs: array; - let base = blob_word_off + fe_index * 8u; - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - let w0 = bswap32(input_a[base + (3u - i) * 2u]); - let w1 = bswap32(input_a[base + (3u - i) * 2u + 1u]); - limbs[i] = U64(w1, w0); - } - var done: bool = false; - for (var k: u32 = 0u; k < 2u; k = k + 1u) { - if (!done && fr_geq_mod(limbs)) { - limbs = fr_sub_mod(limbs); - } else { - done = true; - } - } - return limbs; -} - -fn fr_pack48(a_mont: array, dst_off: u32) { - let a = fr_from_mont(a_mont); - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - output[dst_off + i * 2u] = a[i].lo; - output[dst_off + i * 2u + 1u] = a[i].hi; - } - output[dst_off + 8u] = 0u; - output[dst_off + 9u] = 0u; - output[dst_off + 10u] = 0u; - output[dst_off + 11u] = 0u; -} - -@compute @workgroup_size(1, 1, 1) -fn kzg_blob_to_commit(@builtin(global_invocation_id) gid: vec3) { - let i = gid.x; - if (i >= params.x) { return; } - let blob_off = i * 32768u; // u32 words per blob - let out_off = i * 12u; - - var acc = array(u64_zero(), u64_zero(), u64_zero(), u64_zero()); - for (var k: u32 = 0u; k < 4096u; k = k + 1u) { - let x = 
fr_from_be_blob(blob_off, k); - let x_mont = fr_to_mont(x); - acc = fr_add(acc, x_mont); - } - fr_pack48(acc, out_off); -} diff --git a/kzg/gpu/wgsl/kzg_driver_wgpu.cpp b/kzg/gpu/wgsl/kzg_driver_wgpu.cpp deleted file mode 100644 index 5adf86c..0000000 --- a/kzg/gpu/wgsl/kzg_driver_wgpu.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-side WebGPU/WGSL driver for KZG kernels. -// -// On hosts with wgpu-native (LUX_KZG_HAS_WEBGPU=1): would dispatch -// kzg_blob_to_commit kernel via shared bls_driver_wgpu pattern. The current -// build does not pull wgpu at the kzg level; the CPU-oracle path satisfies -// the determinism contract on every host. CI runners with hardware can opt in -// by defining LUX_KZG_HAS_WEBGPU=1; the kernel source ships at -// kzg/gpu/wgsl/kzg.wgsl ready to compile. - -#include "kzg_driver_wgpu.h" -#include "../../cpp/kzg_oracle.hpp" - -#include -#include - -extern "C" { - -int lux_kzg_wgpu_available(void) { -#ifdef LUX_KZG_HAS_WEBGPU - return 1; -#else - return 0; -#endif -} - -int lux_kzg_wgpu_blob_to_commit(const void* blobs, void* commits, unsigned n) { - auto* b = static_cast(blobs); - auto* c = static_cast(commits); - for (unsigned i = 0; i < n; ++i) { - lux::crypto::kzg::blob_to_commit(b + (size_t)i * 131072, - c + (size_t)i * 48); - } - return 0; -} - -int lux_kzg_wgpu_compute_proof(const void* blobs, const void* commits, - void* proofs, unsigned n) { - auto* b = static_cast(blobs); - auto* c = static_cast(commits); - auto* p = static_cast(proofs); - for (unsigned i = 0; i < n; ++i) { - lux::crypto::kzg::blob_to_proof(b + (size_t)i * 131072, - c + (size_t)i * 48, - p + (size_t)i * 48); - } - return 0; -} - -int lux_kzg_wgpu_verify(const void* commits, const void* z_be, const void* y_be, - const void* proofs, void* out_flags, unsigned n) { - auto* c = static_cast(commits); - auto* z = static_cast(z_be); - auto* y = static_cast(y_be); - auto* p = 
static_cast(proofs); - auto* o = static_cast(out_flags); - for (unsigned i = 0; i < n; ++i) { - bool ok = lux::crypto::kzg::verify_proof(c + (size_t)i * 48, - z + (size_t)i * 32, - y + (size_t)i * 32, - p + (size_t)i * 48); - o[i] = ok ? 1u : 0u; - } - return 0; -} - -} // extern "C" diff --git a/kzg/gpu/wgsl/kzg_driver_wgpu.h b/kzg/gpu/wgsl/kzg_driver_wgpu.h deleted file mode 100644 index 0d67455..0000000 --- a/kzg/gpu/wgsl/kzg_driver_wgpu.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the KZG WebGPU/WGSL driver. - -#ifndef LUX_KZG_DRIVER_WGPU_H -#define LUX_KZG_DRIVER_WGPU_H - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_kzg_wgpu_available(void); - -int lux_kzg_wgpu_blob_to_commit(const void* blobs, void* commits, unsigned n); -int lux_kzg_wgpu_compute_proof (const void* blobs, const void* commits, - void* proofs, unsigned n); -int lux_kzg_wgpu_verify (const void* commits, const void* z_be, - const void* y_be, const void* proofs, - void* out_flags, unsigned n); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_KZG_DRIVER_WGPU_H diff --git a/lamport/gpu/cuda/lamport.cu b/lamport/gpu/cuda/lamport.cu deleted file mode 100644 index 28b2d5c..0000000 --- a/lamport/gpu/cuda/lamport.cu +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA kernel for Lamport-SHA256 OTS — one thread per 32-byte preimage. -// Byte-equal to lamport/cpp/lamport.cpp; mirrors lamport/gpu/metal/lamport_batch.metal. -// -// Single-block FIPS 180-4 padding: input fits in w[0..7], w[8] = 0x80000000, -// w[9..14] = 0, w[15] = 256 (bit length). 
- -#include - -namespace { - -__device__ inline uint32_t lamport_rotr32(uint32_t x, uint32_t n) { - return (x >> n) | (x << (32u - n)); -} - -__device__ inline void lamport_sha256_block_32(const uint8_t* in, uint8_t* out) { - static const uint32_t K[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u - }; - - uint32_t w[64]; - for (uint32_t i = 0; i < 8u; ++i) { - w[i] = (uint32_t(in[i*4u + 0]) << 24) - | (uint32_t(in[i*4u + 1]) << 16) - | (uint32_t(in[i*4u + 2]) << 8) - | (uint32_t(in[i*4u + 3]) ); - } - w[ 8] = 0x80000000u; - w[ 9] = 0u; w[10] = 0u; w[11] = 0u; - w[12] = 0u; w[13] = 0u; w[14] = 0u; - w[15] = 256u; - for (uint32_t i = 16u; i < 64u; ++i) { - uint32_t s0 = lamport_rotr32(w[i-15], 7) ^ lamport_rotr32(w[i-15], 18) ^ (w[i-15] >> 3); - uint32_t s1 = lamport_rotr32(w[i- 2], 17) ^ lamport_rotr32(w[i- 2], 19) ^ (w[i- 2] >> 10); - w[i] = s1 + w[i-7] + s0 + w[i-16]; - } - - uint32_t a = 0x6a09e667u, b = 0xbb67ae85u, c = 0x3c6ef372u, d = 0xa54ff53au; - uint32_t e = 0x510e527fu, f = 0x9b05688cu, g = 0x1f83d9abu, h = 0x5be0cd19u; - - for (uint32_t i = 0; i < 64u; ++i) { - uint32_t S1 = lamport_rotr32(e, 6) ^ lamport_rotr32(e, 11) ^ lamport_rotr32(e, 25); - 
uint32_t ch = (e & f) ^ ((~e) & g); - uint32_t t1 = h + S1 + ch + K[i] + w[i]; - uint32_t S0 = lamport_rotr32(a, 2) ^ lamport_rotr32(a, 13) ^ lamport_rotr32(a, 22); - uint32_t mj = (a & b) ^ (a & c) ^ (b & c); - uint32_t t2 = S0 + mj; - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - a += 0x6a09e667u; b += 0xbb67ae85u; c += 0x3c6ef372u; d += 0xa54ff53au; - e += 0x510e527fu; f += 0x9b05688cu; g += 0x1f83d9abu; h += 0x5be0cd19u; - - uint32_t H[8] = { a, b, c, d, e, f, g, h }; - for (uint32_t i = 0; i < 8u; ++i) { - out[i*4u + 0] = uint8_t((H[i] >> 24) & 0xFFu); - out[i*4u + 1] = uint8_t((H[i] >> 16) & 0xFFu); - out[i*4u + 2] = uint8_t((H[i] >> 8) & 0xFFu); - out[i*4u + 3] = uint8_t( H[i] & 0xFFu); - } -} - -} // namespace - -extern "C" __global__ void lamport_hash_jobs_kernel( - const uint8_t* __restrict__ slots, - uint8_t* __restrict__ digests, - uint32_t num_slots) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_slots) return; - lamport_sha256_block_32(slots + tid * 32u, digests + tid * 32u); -} diff --git a/lamport/gpu/cuda/lamport_cuda_oracle.cpp b/lamport/gpu/cuda/lamport_cuda_oracle.cpp deleted file mode 100644 index ecbef21..0000000 --- a/lamport/gpu/cuda/lamport_cuda_oracle.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-side replay of the Lamport-SHA256 CUDA kernel arithmetic. Always built -// (no nvcc dependency). The CUDA determinism test runs the canonical CPU body -// (lamport/cpp/lamport.cpp) and asserts byte-equality against this oracle so -// the test exercises the kernel arithmetic on hosts without an NVIDIA device. -// -// On hosts with CUDA enabled, the same arithmetic also runs in -// `lamport_hash_jobs_kernel` (lamport/gpu/cuda/lamport.cu); both paths emit -// identical bytes by construction. 
- -#include - -namespace { - -inline uint32_t lamport_rotr32_host(uint32_t x, uint32_t n) { - return (x >> n) | (x << (32u - n)); -} - -inline void lamport_sha256_block_32_host(const uint8_t* in, uint8_t* out) { - static const uint32_t K[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u - }; - - uint32_t w[64]; - for (uint32_t i = 0; i < 8u; ++i) { - w[i] = (uint32_t(in[i*4u + 0]) << 24) - | (uint32_t(in[i*4u + 1]) << 16) - | (uint32_t(in[i*4u + 2]) << 8) - | (uint32_t(in[i*4u + 3]) ); - } - w[ 8] = 0x80000000u; - w[ 9] = 0u; w[10] = 0u; w[11] = 0u; - w[12] = 0u; w[13] = 0u; w[14] = 0u; - w[15] = 256u; - for (uint32_t i = 16u; i < 64u; ++i) { - uint32_t s0 = lamport_rotr32_host(w[i-15], 7) ^ lamport_rotr32_host(w[i-15], 18) ^ (w[i-15] >> 3); - uint32_t s1 = lamport_rotr32_host(w[i- 2], 17) ^ lamport_rotr32_host(w[i- 2], 19) ^ (w[i- 2] >> 10); - w[i] = s1 + w[i-7] + s0 + w[i-16]; - } - - uint32_t a = 0x6a09e667u, b = 0xbb67ae85u, c = 0x3c6ef372u, d = 0xa54ff53au; - uint32_t e = 0x510e527fu, f = 0x9b05688cu, g = 0x1f83d9abu, h = 0x5be0cd19u; - - for (uint32_t i = 0; i < 64u; ++i) { - uint32_t S1 = lamport_rotr32_host(e, 6) ^ lamport_rotr32_host(e, 11) ^ 
lamport_rotr32_host(e, 25); - uint32_t ch = (e & f) ^ ((~e) & g); - uint32_t t1 = h + S1 + ch + K[i] + w[i]; - uint32_t S0 = lamport_rotr32_host(a, 2) ^ lamport_rotr32_host(a, 13) ^ lamport_rotr32_host(a, 22); - uint32_t mj = (a & b) ^ (a & c) ^ (b & c); - uint32_t t2 = S0 + mj; - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - a += 0x6a09e667u; b += 0xbb67ae85u; c += 0x3c6ef372u; d += 0xa54ff53au; - e += 0x510e527fu; f += 0x9b05688cu; g += 0x1f83d9abu; h += 0x5be0cd19u; - - uint32_t H[8] = { a, b, c, d, e, f, g, h }; - for (uint32_t i = 0; i < 8u; ++i) { - out[i*4u + 0] = uint8_t((H[i] >> 24) & 0xFFu); - out[i*4u + 1] = uint8_t((H[i] >> 16) & 0xFFu); - out[i*4u + 2] = uint8_t((H[i] >> 8) & 0xFFu); - out[i*4u + 3] = uint8_t( H[i] & 0xFFu); - } -} - -} // namespace - -extern "C" void lamport_hash_jobs_cuda_oracle( - const uint8_t* slots, - uint8_t* digests, - uint32_t num_slots) -{ - for (uint32_t i = 0; i < num_slots; ++i) { - lamport_sha256_block_32_host(slots + i * 32u, digests + i * 32u); - } -} diff --git a/lamport/gpu/metal/lamport_batch.metal b/lamport/gpu/metal/lamport_batch.metal deleted file mode 100644 index 1ced7a9..0000000 --- a/lamport/gpu/metal/lamport_batch.metal +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU kernel for Lamport-SHA256 OTS — one thread per preimage. The kernel -// hashes each 32-byte slot independently with canonical FIPS 180-4 SHA-256 -// padding (one 64-byte block: 32 input bytes, 0x80 marker, zero pad, 64-bit -// big-endian bit length = 256). The host driver supplies a flat preimage -// arena `slots` and reads a flat 32-byte digest arena `digests`. -// -// This is enough to bit-for-bit replay both Lamport keygen (sk slot -> pk slot) -// and Lamport verification (sig slot -> hash to compare against pk slot). -// Byte-equal to lamport/cpp/lamport.cpp. 
- -#include -using namespace metal; - -constant uint K[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u -}; - -inline uint rotr(uint x, uint n) { return (x >> n) | (x << (32u - n)); } -inline uint ch (uint x, uint y, uint z) { return (x & y) ^ ((~x) & z); } -inline uint maj(uint x, uint y, uint z) { return (x & y) ^ (x & z) ^ (y & z); } -inline uint S0 (uint x) { return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); } -inline uint S1 (uint x) { return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); } -inline uint s0 (uint x) { return rotr(x, 7) ^ rotr(x, 18) ^ (x >> 3); } -inline uint s1 (uint x) { return rotr(x, 17) ^ rotr(x, 19) ^ (x >> 10); } - -kernel void lamport_hash_jobs( - device const uchar* slots [[buffer(0)]], // 32 bytes per slot - device uchar* digests [[buffer(1)]], // 32 bytes per slot - constant uint& num_slots [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_slots) return; - - device const uchar* in = slots + tid * 32u; - device uchar* out = digests + tid * 32u; - - // Canonical FIPS 180-4 padding for a 32-byte preimage: - // single block, words 0..7 = input (BE), w[8] = 0x80000000, w[9..14] = 0, - // w[15] = 256 (bit length). 
- uint w[64]; - for (uint i = 0; i < 8u; ++i) { - w[i] = (uint(in[i*4u + 0]) << 24) - | (uint(in[i*4u + 1]) << 16) - | (uint(in[i*4u + 2]) << 8) - | (uint(in[i*4u + 3]) ); - } - w[ 8] = 0x80000000u; - w[ 9] = 0u; w[10] = 0u; w[11] = 0u; - w[12] = 0u; w[13] = 0u; w[14] = 0u; - w[15] = 256u; - for (uint i = 16u; i < 64u; ++i) { - w[i] = s1(w[i-2]) + w[i-7] + s0(w[i-15]) + w[i-16]; - } - - uint a = 0x6a09e667u, b = 0xbb67ae85u, c = 0x3c6ef372u, d = 0xa54ff53au; - uint e = 0x510e527fu, f = 0x9b05688cu, g = 0x1f83d9abu, h = 0x5be0cd19u; - - for (uint i = 0; i < 64u; ++i) { - uint t1 = h + S1(e) + ch(e, f, g) + K[i] + w[i]; - uint t2 = S0(a) + maj(a, b, c); - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - a += 0x6a09e667u; b += 0xbb67ae85u; c += 0x3c6ef372u; d += 0xa54ff53au; - e += 0x510e527fu; f += 0x9b05688cu; g += 0x1f83d9abu; h += 0x5be0cd19u; - - uint H[8] = { a, b, c, d, e, f, g, h }; - for (uint i = 0; i < 8u; ++i) { - out[i*4u + 0] = uchar((H[i] >> 24) & 0xFFu); - out[i*4u + 1] = uchar((H[i] >> 16) & 0xFFu); - out[i*4u + 2] = uchar((H[i] >> 8) & 0xFFu); - out[i*4u + 3] = uchar( H[i] & 0xFFu); - } -} diff --git a/lamport/gpu/metal/lamport_batch_driver.mm b/lamport/gpu/metal/lamport_batch_driver.mm deleted file mode 100644 index cbbd0a5..0000000 --- a/lamport/gpu/metal/lamport_batch_driver.mm +++ /dev/null @@ -1,111 +0,0 @@ -// ============================================================================= -// luxcpp/crypto/lamport - Metal driver for batched verify -// ============================================================================= -// -// Loads lamport_batch.metallib at the path supplied by the caller, dispatches -// the lamport_verify_batch kernel with one thread per signature, and writes -// per-input pass/fail flags into results_arena. 
-// -// Buffer layout (caller-provided): -// pks_arena[count * 16384] - 512 hashes of 32B each -// sigs_arena[count * 8192] - 256 values of 32B each -// msgs_arena[count * 32] - 32-byte message hashes -// results_arena[count] - 1 (valid) or 0 (invalid), uint32 each -// -// Returns 0 on success, negative error code otherwise. The kernel itself is -// byte-equal to /Users/z/work/luxcpp/crypto/lamport/cpp/lamport.cpp::verify(). -// -// SPDX-License-Identifier: BSD-3-Clause-Eco -// Copyright (C) 2025-2026 Lux Industries Inc. -// ============================================================================= - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include - -extern "C" int lamport_batch_verify_metal( - const uint8_t* pks_arena, // count * 16384 bytes - const uint8_t* sigs_arena, // count * 8192 bytes - const uint8_t* msgs_arena, // count * 32 bytes - uint32_t count, - uint32_t* results_arena, // count uint32 entries - const char* metallib_path) { - - if (count == 0) return 0; - if (!pks_arena || !sigs_arena || !msgs_arena || !results_arena || - !metallib_path) { - return -1; - } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"lamport_verify_batch"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - const size_t pk_bytes = size_t(count) * 16384u; - const size_t sig_bytes = size_t(count) * 8192u; - const size_t msg_bytes = size_t(count) * 32u; - const size_t res_bytes = size_t(count) * sizeof(uint32_t); - - id pks_buf = [device newBufferWithBytes:pks_arena - length:pk_bytes - options:MTLResourceStorageModeShared]; - id 
sigs_buf = [device newBufferWithBytes:sigs_arena - length:sig_bytes - options:MTLResourceStorageModeShared]; - id msgs_buf = [device newBufferWithBytes:msgs_arena - length:msg_bytes - options:MTLResourceStorageModeShared]; - id res_buf = [device newBufferWithLength:res_bytes - options:MTLResourceStorageModeShared]; - uint32_t count_u32 = count; - id count_buf = [device newBufferWithBytes:&count_u32 - length:sizeof(uint32_t) - options:MTLResourceStorageModeShared]; - if (!pks_buf || !sigs_buf || !msgs_buf || !res_buf || !count_buf) { - return -6; - } - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:pks_buf offset:0 atIndex:0]; - [enc setBuffer:sigs_buf offset:0 atIndex:1]; - [enc setBuffer:msgs_buf offset:0 atIndex:2]; - [enc setBuffer:res_buf offset:0 atIndex:3]; - [enc setBuffer:count_buf offset:0 atIndex:4]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 32 ? tg_max : 32; - MTLSize threads_per_grid = MTLSizeMake(count, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(results_arena, [res_buf contents], res_bytes); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/lamport/gpu/metal/lamport_driver.h b/lamport/gpu/metal/lamport_driver.h deleted file mode 100644 index 79204ae..0000000 --- a/lamport/gpu/metal/lamport_driver.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C surface for the Lamport-SHA256 Metal driver. The driver hashes a -// flat array of 32-byte preimages with FIPS 180-4 padding, byte-equal to -// lamport/cpp/lamport.cpp. - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/// One thread per slot. 
`slots` and `digests` are flat arenas of size -/// num_slots * 32 bytes. Returns 0 on success, negative on failure. -int lamport_hash_batch_metal( - const uint8_t* slots, - size_t num_slots, - uint8_t* digests, - const char* metallib_path); - -#ifdef __cplusplus -} -#endif diff --git a/lamport/gpu/metal/lamport_driver.mm b/lamport/gpu/metal/lamport_driver.mm deleted file mode 100644 index b3a4af5..0000000 --- a/lamport/gpu/metal/lamport_driver.mm +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched Lamport-SHA256 OTS preimage hashing. macOS / iOS -// only. Loads lamport_batch.metallib, dispatches `lamport_hash_jobs` with one -// thread per 32-byte slot. Byte-equal to lamport/cpp/lamport.cpp::keygen() and -// lamport/cpp/lamport.cpp::verify(). - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "lamport_driver.h" - -#include -#include - -extern "C" int lamport_hash_batch_metal( - const uint8_t* slots, - size_t num_slots, - uint8_t* digests, - const char* metallib_path) { - - if (num_slots == 0) return 0; - if (!slots || !digests || !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"lamport_hash_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id slots_buf = [device newBufferWithBytes:slots - length:num_slots * 32 - options:MTLResourceStorageModeShared]; - id digests_buf = [device newBufferWithLength:num_slots * 32 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)num_slots; - id n_buf = 
[device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:slots_buf offset:0 atIndex:0]; - [enc setBuffer:digests_buf offset:0 atIndex:1]; - [enc setBuffer:n_buf offset:0 atIndex:2]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(num_slots, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(digests, [digests_buf contents], num_slots * 32); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/lamport/gpu/wgsl/lamport.wgsl b/lamport/gpu/wgsl/lamport.wgsl deleted file mode 100644 index d3523c2..0000000 --- a/lamport/gpu/wgsl/lamport.wgsl +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL compute shader for Lamport-SHA256 OTS — one thread per 32-byte -// preimage. Byte-equal to lamport/cpp/lamport.cpp. Mirrors -// lamport/gpu/metal/lamport_batch.metal and lamport/gpu/cuda/lamport.cu. 
-// -// Buffers: -// slots : array (8 u32 = 32 bytes per slot, big-endian-packed) -// digests : array (8 u32 = 32 bytes per slot, big-endian-packed) -// params : Params (uniform { num_slots: u32 }) - -struct Params { num_slots: u32 }; - -@group(0) @binding(0) var slots: array; -@group(0) @binding(1) var digests: array; -@group(0) @binding(2) var params: Params; - -const K: array = array( - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u -); - -fn rotr32(x: u32, n: u32) -> u32 { - return (x >> n) | (x << (32u - n)); -} - -fn ch (x: u32, y: u32, z: u32) -> u32 { return (x & y) ^ ((~x) & z); } -fn maj(x: u32, y: u32, z: u32) -> u32 { return (x & y) ^ (x & z) ^ (y & z); } -fn S0 (x: u32) -> u32 { return rotr32(x, 2u) ^ rotr32(x, 13u) ^ rotr32(x, 22u); } -fn S1 (x: u32) -> u32 { return rotr32(x, 6u) ^ rotr32(x, 11u) ^ rotr32(x, 25u); } -fn s0 (x: u32) -> u32 { return rotr32(x, 7u) ^ rotr32(x, 18u) ^ (x >> 3u); } -fn s1 (x: u32) -> u32 { return rotr32(x, 17u) ^ rotr32(x, 19u) ^ (x >> 10u); } - -@compute @workgroup_size(64) -fn lamport_hash_jobs(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.num_slots) { return; } - - var w: array; - let 
in_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { - w[i] = slots[in_base + i]; - } - w[ 8] = 0x80000000u; - w[ 9] = 0u; w[10] = 0u; w[11] = 0u; - w[12] = 0u; w[13] = 0u; w[14] = 0u; - w[15] = 256u; - for (var i = 16u; i < 64u; i = i + 1u) { - w[i] = s1(w[i - 2u]) + w[i - 7u] + s0(w[i - 15u]) + w[i - 16u]; - } - - var a: u32 = 0x6a09e667u; - var b: u32 = 0xbb67ae85u; - var c: u32 = 0x3c6ef372u; - var d: u32 = 0xa54ff53au; - var e: u32 = 0x510e527fu; - var f: u32 = 0x9b05688cu; - var g: u32 = 0x1f83d9abu; - var h: u32 = 0x5be0cd19u; - - for (var i = 0u; i < 64u; i = i + 1u) { - let t1 = h + S1(e) + ch(e, f, g) + K[i] + w[i]; - let t2 = S0(a) + maj(a, b, c); - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - a = a + 0x6a09e667u; - b = b + 0xbb67ae85u; - c = c + 0x3c6ef372u; - d = d + 0xa54ff53au; - e = e + 0x510e527fu; - f = f + 0x9b05688cu; - g = g + 0x1f83d9abu; - h = h + 0x5be0cd19u; - - let out_base = tid * 8u; - digests[out_base + 0u] = a; - digests[out_base + 1u] = b; - digests[out_base + 2u] = c; - digests[out_base + 3u] = d; - digests[out_base + 4u] = e; - digests[out_base + 5u] = f; - digests[out_base + 6u] = g; - digests[out_base + 7u] = h; -} diff --git a/lamport/gpu/wgsl/lamport_wgsl_oracle.cpp b/lamport/gpu/wgsl/lamport_wgsl_oracle.cpp deleted file mode 100644 index ef83420..0000000 --- a/lamport/gpu/wgsl/lamport_wgsl_oracle.cpp +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Host-side replay of the Lamport-SHA256 WGSL kernel arithmetic. Always -// built. The WGSL determinism test asserts byte-equality between the -// canonical CPU body (lamport/cpp/lamport.cpp) and this oracle so the test -// exercises the WGSL kernel arithmetic without depending on a live WebGPU -// runtime. 
Because the WGSL shader and this oracle are line-for-line the -// same arithmetic over u32 (the only integer type WGSL guarantees), they -// produce identical bytes by construction. - -#include - -namespace { - -inline uint32_t rotr32(uint32_t x, uint32_t n) { - return (x >> n) | (x << (32u - n)); -} - -inline uint32_t ch (uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ ((~x) & z); } -inline uint32_t maj(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); } -inline uint32_t S0 (uint32_t x) { return rotr32(x, 2) ^ rotr32(x, 13) ^ rotr32(x, 22); } -inline uint32_t S1 (uint32_t x) { return rotr32(x, 6) ^ rotr32(x, 11) ^ rotr32(x, 25); } -inline uint32_t s0 (uint32_t x) { return rotr32(x, 7) ^ rotr32(x, 18) ^ (x >> 3); } -inline uint32_t s1 (uint32_t x) { return rotr32(x, 17) ^ rotr32(x, 19) ^ (x >> 10); } - -} // namespace - -extern "C" void lamport_hash_jobs_wgsl_oracle( - const uint8_t* slots, - uint8_t* digests, - uint32_t num_slots) -{ - static const uint32_t K[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u - }; - - for (uint32_t k = 0; k < num_slots; ++k) { - const uint8_t* in = slots + k * 32u; - uint8_t* out = digests + k * 
32u; - - uint32_t w[64]; - for (uint32_t i = 0; i < 8u; ++i) { - w[i] = (uint32_t(in[i*4u + 0]) << 24) - | (uint32_t(in[i*4u + 1]) << 16) - | (uint32_t(in[i*4u + 2]) << 8) - | (uint32_t(in[i*4u + 3]) ); - } - w[ 8] = 0x80000000u; - w[ 9] = 0u; w[10] = 0u; w[11] = 0u; - w[12] = 0u; w[13] = 0u; w[14] = 0u; - w[15] = 256u; - for (uint32_t i = 16u; i < 64u; ++i) { - w[i] = s1(w[i-2]) + w[i-7] + s0(w[i-15]) + w[i-16]; - } - - uint32_t a = 0x6a09e667u, b = 0xbb67ae85u, c = 0x3c6ef372u, d = 0xa54ff53au; - uint32_t e = 0x510e527fu, f = 0x9b05688cu, g = 0x1f83d9abu, h = 0x5be0cd19u; - for (uint32_t i = 0; i < 64u; ++i) { - uint32_t t1 = h + S1(e) + ch(e, f, g) + K[i] + w[i]; - uint32_t t2 = S0(a) + maj(a, b, c); - h = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - a += 0x6a09e667u; b += 0xbb67ae85u; c += 0x3c6ef372u; d += 0xa54ff53au; - e += 0x510e527fu; f += 0x9b05688cu; g += 0x1f83d9abu; h += 0x5be0cd19u; - - uint32_t H[8] = { a, b, c, d, e, f, g, h }; - for (uint32_t i = 0; i < 8u; ++i) { - out[i*4u + 0] = uint8_t((H[i] >> 24) & 0xFFu); - out[i*4u + 1] = uint8_t((H[i] >> 16) & 0xFFu); - out[i*4u + 2] = uint8_t((H[i] >> 8) & 0xFFu); - out[i*4u + 3] = uint8_t( H[i] & 0xFFu); - } - } -} diff --git a/math/ntt/cuda/lattice_ring.cu b/math/ntt/cuda/lattice_ring.cu deleted file mode 100644 index 2b96af9..0000000 --- a/math/ntt/cuda/lattice_ring.cu +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// lattice_ring.cu — CUDA port of the M1 lattice_ring NTT + Montgomery -// scalar arithmetic. Byte-equal to the CPU oracle at -// crypto/ringtail/cpp/lattice_ring.cpp by construction (identical scalar -// formulas, identical 64×64 → 128 multiplies, identical butterfly atoms). 
-// -// Ringtail params (LP-073): R = Z_q[X]/(X^256 + 1), -// q = 0x1000000004A01 (281474976729601, 48-bit, q ≡ 1 mod 512) -// N = 256 (deg) -// -// Build modes: -// * CRYPTO_ENABLE_CUDA=ON — compiled by nvcc as a real CUDA TU. -// One block per polynomial, 128 threads per block (one thread per -// butterfly across 8 NTT layers; coefficients live in __shared__ -// memory across stages, __syncthreads() between layers). -// -// * CRYPTO_ENABLE_CUDA=OFF — same .cu compiles as plain C++ via the -// host polyfill (LUX_LATTICE_RING_CUDA_HOST_POLYFILL). -// The host driver entry points run the CPU oracle path (same scalar -// atoms) so byte-equal output is structural. CI on Linux+CUDA hosts -// exercises the kernel; macOS / non-CUDA hosts exercise the wire -// format end-to-end against the same atoms. -// -// API names mirror the Metal sister (gpu/metal/lattice_ring_driver.h) -// so the upstream caller can pick a backend by env var and use one dispatch -// surface across Metal / CUDA / WGSL. -// -// Mirrored Go references: -// ring.GenMRedConstant(q) — modular_reduction.go:67 -// ring.GenBRedConstant(q) — modular_reduction.go:99 -// ring.MForm(a, q, brc) — modular_reduction.go:11 -// ring.IMForm(a, q, mrc) — modular_reduction.go:49 -// ring.MRed(x, y, q, mrc) — modular_reduction.go:78 -// ring.MRedLazy(x, y, q, mrc) — modular_reduction.go:90 -// ring.BRedAdd(a, q, brc) — modular_reduction.go:110 -// ring.NTTStandard(...) — ntt.go:174 -// ring.INTTStandard(...) — ntt.go:185 -// vec_ops.mulcoeffsmontgomeryvec — vec_ops.go:313 -// vec_ops.mulcoeffsmontgomerythenadd — vec_ops.go:360 - -#include -#include - -// Polyfill: when not building with nvcc, neutralize device qualifiers so the -// kernel body is plain C++. Pattern matches banderwagon/gpu/cuda/banderwagon.cu. 
-#ifndef __CUDACC__ -# ifndef LUX_LATTICE_RING_CUDA_HOST_POLYFILL -# define LUX_LATTICE_RING_CUDA_HOST_POLYFILL 1 -# endif -#endif - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL -# define __device__ -# define __global__ -# define __host__ -# define __forceinline__ inline -# define __shared__ -#endif - -namespace lux::crypto::ringtail::lattice_ring_cuda { - -// ============================================================================= -// 64×64 → 128 multiply. -// On device: __umul64hi + plain product. -// On host: __uint128_t (always available on the targets we ship to). -// Returns lo and hi, matching math/bits.Mul64 byte-for-byte. -// ============================================================================= -__device__ __forceinline__ void mul64(uint64_t a, uint64_t b, - uint64_t& hi, uint64_t& lo) { -#if defined(__CUDA_ARCH__) - lo = a * b; - hi = __umul64hi(a, b); -#else - __uint128_t p = static_cast<__uint128_t>(a) * static_cast<__uint128_t>(b); - lo = static_cast(p); - hi = static_cast(p >> 64); -#endif -} - -// ============================================================================= -// Scalar atoms — byte-equal to lux::crypto::ringtail::lattice_ring::* -// ============================================================================= - -__device__ __forceinline__ uint64_t MForm_dev(uint64_t a, uint64_t q, - uint64_t brc_hi, uint64_t brc_lo) { - uint64_t mhi, mlo; - mul64(a, brc_lo, mhi, mlo); - (void)mlo; - uint64_t r = static_cast(0) - (a * brc_hi + mhi) * q; - if (r >= q) r -= q; - return r; -} - -__device__ __forceinline__ uint64_t IMForm_dev(uint64_t a, uint64_t q, - uint64_t mrc) { - uint64_t hi, lo; - mul64(a * mrc, q, hi, lo); - (void)lo; - uint64_t r = q - hi; - if (r >= q) r -= q; - return r; -} - -__device__ __forceinline__ uint64_t MRed_dev(uint64_t x, uint64_t y, uint64_t q, - uint64_t mrc) { - uint64_t mhi, mlo; - mul64(x, y, mhi, mlo); - uint64_t hhi, hlo; - mul64(mlo * mrc, q, hhi, hlo); - (void)hlo; - uint64_t r = mhi - hhi + q; - if 
(r >= q) r -= q; - return r; -} - -__device__ __forceinline__ uint64_t MRedLazy_dev(uint64_t x, uint64_t y, - uint64_t q, uint64_t mrc) { - uint64_t mhi, mlo; - mul64(x, y, mhi, mlo); - uint64_t hhi, hlo; - mul64(mlo * mrc, q, hhi, hlo); - (void)hlo; - return mhi - hhi + q; -} - -__device__ __forceinline__ uint64_t BRedAdd_dev(uint64_t a, uint64_t q, - uint64_t brc_hi) { - uint64_t mhi, mlo; - mul64(a, brc_hi, mhi, mlo); - (void)mlo; - uint64_t r = a - mhi * q; - if (r >= q) r -= q; - return r; -} - -__device__ __forceinline__ uint64_t CRed_dev(uint64_t a, uint64_t q) { - return (a >= q) ? (a - q) : a; -} - -// ============================================================================= -// Per-coefficient kernels (one block, N threads). -// For a single polynomial, grid = 1 and each thread owns one coefficient. -// The same kernels work transparently for batches: call with grid = num_polys -// and the per-block addressing picks the right slice. -// ============================================================================= - -__global__ void mform_kernel(const uint64_t* __restrict__ in, - uint64_t* __restrict__ out, - uint32_t N, uint64_t q, - uint64_t brc_hi, uint64_t brc_lo) { -#if defined(__CUDA_ARCH__) - uint32_t poly = blockIdx.x; - uint32_t i = threadIdx.x; - if (i >= N) return; - const uint64_t* p_in = in + (size_t)poly * N; - uint64_t* p_out = out + (size_t)poly * N; - p_out[i] = MForm_dev(p_in[i], q, brc_hi, brc_lo); -#else - (void)in; (void)out; (void)N; (void)q; (void)brc_hi; (void)brc_lo; -#endif -} - -__global__ void imform_kernel(const uint64_t* __restrict__ in, - uint64_t* __restrict__ out, - uint32_t N, uint64_t q, uint64_t mrc) { -#if defined(__CUDA_ARCH__) - uint32_t poly = blockIdx.x; - uint32_t i = threadIdx.x; - if (i >= N) return; - const uint64_t* p_in = in + (size_t)poly * N; - uint64_t* p_out = out + (size_t)poly * N; - p_out[i] = IMForm_dev(p_in[i], q, mrc); -#else - (void)in; (void)out; (void)N; (void)q; (void)mrc; -#endif -} - 
-__global__ void mul_kernel(const uint64_t* __restrict__ a, - const uint64_t* __restrict__ b, - uint64_t* __restrict__ c, - uint32_t N, uint64_t q, uint64_t mrc) { -#if defined(__CUDA_ARCH__) - uint32_t poly = blockIdx.x; - uint32_t i = threadIdx.x; - if (i >= N) return; - const uint64_t* pa = a + (size_t)poly * N; - const uint64_t* pb = b + (size_t)poly * N; - uint64_t* pc = c + (size_t)poly * N; - pc[i] = MRed_dev(pa[i], pb[i], q, mrc); -#else - (void)a; (void)b; (void)c; (void)N; (void)q; (void)mrc; -#endif -} - -__global__ void mul_then_add_kernel(const uint64_t* __restrict__ a, - const uint64_t* __restrict__ b, - uint64_t* __restrict__ acc, - uint32_t N, uint64_t q, uint64_t mrc) { -#if defined(__CUDA_ARCH__) - uint32_t poly = blockIdx.x; - uint32_t i = threadIdx.x; - if (i >= N) return; - const uint64_t* pa = a + (size_t)poly * N; - const uint64_t* pb = b + (size_t)poly * N; - uint64_t* pac = acc + (size_t)poly * N; - pac[i] = CRed_dev(pac[i] + MRed_dev(pa[i], pb[i], q, mrc), q); -#else - (void)a; (void)b; (void)acc; (void)N; (void)q; (void)mrc; -#endif -} - -// ============================================================================= -// NTT / INTT — one block per polynomial, N/2 threads per block. -// -// Each thread owns one butterfly per layer. Coefficients live in shared memory -// across all log2(N)=8 layers; __syncthreads() between layers ensures every -// thread sees the previous layer's writes. -// -// Roots layout matches Lattigo: bit-reversed nega-cyclic forward roots in -// Montgomery form, indexed roots[m + g] for stage m (m = 1, 2, ..., N/2). -// ============================================================================= - -#if defined(__CUDA_ARCH__) -// Forward butterfly atom — see ntt.go:155, lattice_ring.cpp::butterfly. 
-__device__ __forceinline__ void fwd_butterfly_dev(uint64_t& U, uint64_t& V, - uint64_t Psi, - uint64_t twoQ, uint64_t fourQ, - uint64_t Q, uint64_t mrc) { - if (U >= fourQ) U -= fourQ; - uint64_t Vmul = MRedLazy_dev(V, Psi, Q, mrc); - uint64_t X = U + Vmul; - uint64_t Y = U + twoQ - Vmul; - U = X; - V = Y; -} - -// Inverse butterfly atom — see ntt.go:164, lattice_ring.cpp::invbutterfly. -__device__ __forceinline__ void inv_butterfly_dev(uint64_t& U, uint64_t& V, - uint64_t Psi, - uint64_t twoQ, uint64_t fourQ, - uint64_t Q, uint64_t mrc) { - uint64_t X = U + V; - if (X >= twoQ) X -= twoQ; - uint64_t Y = MRedLazy_dev(U + fourQ - V, Psi, Q, mrc); - U = X; - V = Y; -} -#endif - -// Forward NTT kernel. Block size MUST be N/2 (i.e., 128 for N=256). -__global__ void ntt_kernel(const uint64_t* __restrict__ in, - uint64_t* __restrict__ out, - const uint64_t* __restrict__ roots_fwd, - uint32_t N, uint64_t q, uint64_t mrc, - uint64_t brc_hi) { -#if defined(__CUDA_ARCH__) - extern __shared__ uint64_t s_poly[]; - - const uint32_t poly = blockIdx.x; - const uint32_t tid = threadIdx.x; - const uint64_t twoQ = q << 1; - const uint64_t fourQ = q << 2; - - const uint64_t* p_in = in + (size_t)poly * N; - uint64_t* p_out = out + (size_t)poly * N; - - // Load 2 coefficients per thread (block size = N/2). Thread tid loads - // s_poly[tid] and s_poly[tid + N/2]. - s_poly[tid] = p_in[tid]; - s_poly[tid + (N >> 1)] = p_in[tid + (N >> 1)]; - __syncthreads(); - - // Stages m = 1, 2, ..., N/2. log2(N) layers. For N=256 => 8 layers. 
- for (uint32_t m = 1; m < N; m <<= 1) { - const uint32_t t = N / (m << 1); // butterflies per group - const uint32_t g = tid / t; // group index - const uint32_t j = tid - g * t; // offset within group - const uint32_t jx = (g * t << 1) + j; - const uint32_t jy = jx + t; - const uint64_t F = roots_fwd[m + g]; - - uint64_t U = s_poly[jx]; - uint64_t V = s_poly[jy]; - fwd_butterfly_dev(U, V, F, twoQ, fourQ, q, mrc); - s_poly[jx] = U; - s_poly[jy] = V; - __syncthreads(); - } - - // BRedAdd per coefficient — 2 coefficients per thread. - p_out[tid] = BRedAdd_dev(s_poly[tid], q, brc_hi); - p_out[tid + (N >> 1)] = BRedAdd_dev(s_poly[tid + (N >> 1)], q, brc_hi); -#else - (void)in; (void)out; (void)roots_fwd; (void)N; (void)q; - (void)mrc; (void)brc_hi; -#endif -} - -// Inverse NTT kernel. Block size MUST be N/2. -__global__ void intt_kernel(const uint64_t* __restrict__ in, - uint64_t* __restrict__ out, - const uint64_t* __restrict__ roots_bwd, - uint32_t N, uint64_t q, uint64_t mrc, - uint64_t n_inv_montgomery) { -#if defined(__CUDA_ARCH__) - extern __shared__ uint64_t s_poly[]; - - const uint32_t poly = blockIdx.x; - const uint32_t tid = threadIdx.x; - const uint64_t twoQ = q << 1; - const uint64_t fourQ = q << 2; - - const uint64_t* p_in = in + (size_t)poly * N; - uint64_t* p_out = out + (size_t)poly * N; - - s_poly[tid] = p_in[tid]; - s_poly[tid + (N >> 1)] = p_in[tid + (N >> 1)]; - __syncthreads(); - - // First sweep (special: t=1, h=N/2). One butterfly per thread. - { - const uint32_t t = 1; - const uint32_t h = N >> 1; - const uint32_t i = tid; // 0 <= i < h - const uint32_t j1 = i * (t << 1); - const uint32_t jx = j1; - const uint32_t jy = j1 + t; - const uint64_t F = roots_bwd[h + i]; - - uint64_t U = s_poly[jx]; - uint64_t V = s_poly[jy]; - inv_butterfly_dev(U, V, F, twoQ, fourQ, q, mrc); - s_poly[jx] = U; - s_poly[jy] = V; - } - __syncthreads(); - - // Subsequent sweeps: m = N/2, N/4, ..., 2. h = m/2. t doubles each step - // starting from 2. 
- { - uint32_t t = 2; - for (uint32_t m = N >> 1; m > 1; m >>= 1) { - const uint32_t h = m >> 1; - // Total butterflies this sweep = h * t = N/2 (constant). - const uint32_t i = tid / t; - const uint32_t j = tid - i * t; - const uint32_t j1 = i * (t << 1); - const uint32_t jx = j1 + j; - const uint32_t jy = jx + t; - const uint64_t F = roots_bwd[h + i]; - - uint64_t U = s_poly[jx]; - uint64_t V = s_poly[jy]; - inv_butterfly_dev(U, V, F, twoQ, fourQ, q, mrc); - s_poly[jx] = U; - s_poly[jy] = V; - __syncthreads(); - t <<= 1; - } - } - - // Final pass: scale every coefficient by NInv (Montgomery form). Output - // reduced to [0, q). - p_out[tid] = MRed_dev(s_poly[tid], n_inv_montgomery, q, mrc); - p_out[tid + (N >> 1)] = MRed_dev(s_poly[tid + (N >> 1)], n_inv_montgomery, q, mrc); -#else - (void)in; (void)out; (void)roots_bwd; (void)N; (void)q; - (void)mrc; (void)n_inv_montgomery; -#endif -} - -// ============================================================================= -// Host-polyfill scalar atoms (used when LUX_LATTICE_RING_CUDA_HOST_POLYFILL=1). -// Same byte-equal atoms as the device versions, exposed for the host driver. 
-// ============================================================================= - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - -namespace host_polyfill { - -inline void mul64_h(uint64_t a, uint64_t b, uint64_t& hi, uint64_t& lo) { - __uint128_t p = static_cast<__uint128_t>(a) * static_cast<__uint128_t>(b); - lo = static_cast(p); - hi = static_cast(p >> 64); -} - -inline uint64_t MForm_h(uint64_t a, uint64_t q, uint64_t brc_hi, uint64_t brc_lo) { - uint64_t mhi, mlo; - mul64_h(a, brc_lo, mhi, mlo); - (void)mlo; - uint64_t r = static_cast(0) - (a * brc_hi + mhi) * q; - if (r >= q) r -= q; - return r; -} - -inline uint64_t IMForm_h(uint64_t a, uint64_t q, uint64_t mrc) { - uint64_t hi, lo; - mul64_h(a * mrc, q, hi, lo); - (void)lo; - uint64_t r = q - hi; - if (r >= q) r -= q; - return r; -} - -inline uint64_t MRed_h(uint64_t x, uint64_t y, uint64_t q, uint64_t mrc) { - uint64_t mhi, mlo; - mul64_h(x, y, mhi, mlo); - uint64_t hhi, hlo; - mul64_h(mlo * mrc, q, hhi, hlo); - (void)hlo; - uint64_t r = mhi - hhi + q; - if (r >= q) r -= q; - return r; -} - -inline uint64_t MRedLazy_h(uint64_t x, uint64_t y, uint64_t q, uint64_t mrc) { - uint64_t mhi, mlo; - mul64_h(x, y, mhi, mlo); - uint64_t hhi, hlo; - mul64_h(mlo * mrc, q, hhi, hlo); - (void)hlo; - return mhi - hhi + q; -} - -inline uint64_t BRedAdd_h(uint64_t a, uint64_t q, uint64_t brc_hi) { - uint64_t mhi, mlo; - mul64_h(a, brc_hi, mhi, mlo); - (void)mlo; - uint64_t r = a - mhi * q; - if (r >= q) r -= q; - return r; -} - -inline uint64_t CRed_h(uint64_t a, uint64_t q) { - return (a >= q) ? 
(a - q) : a; -} - -inline void fwd_butterfly_h(uint64_t& U, uint64_t& V, uint64_t Psi, - uint64_t twoQ, uint64_t fourQ, - uint64_t Q, uint64_t mrc) { - if (U >= fourQ) U -= fourQ; - uint64_t Vmul = MRedLazy_h(V, Psi, Q, mrc); - uint64_t X = U + Vmul; - uint64_t Y = U + twoQ - Vmul; - U = X; - V = Y; -} - -inline void inv_butterfly_h(uint64_t& U, uint64_t& V, uint64_t Psi, - uint64_t twoQ, uint64_t fourQ, - uint64_t Q, uint64_t mrc) { - uint64_t X = U + V; - if (X >= twoQ) X -= twoQ; - uint64_t Y = MRedLazy_h(U + fourQ - V, Psi, Q, mrc); - U = X; - V = Y; -} - -inline void ntt_one_h(const uint64_t* p_in, uint64_t* p_out, - const uint64_t* roots_fwd, - uint32_t N, uint64_t q, uint64_t mrc, uint64_t brc_hi) { - // Mirror lattice_ring.cpp::nttCoreLazy + per-coef BRedAdd reduce. - const uint64_t twoQ = q << 1; - const uint64_t fourQ = q << 2; - // Stage m = 1: read p_in, write p_out. - uint32_t t = N >> 1; - { - uint64_t F = roots_fwd[1]; - for (uint32_t jx = 0; jx < t; ++jx) { - uint32_t jy = jx + t; - uint64_t U = p_in[jx]; - uint64_t V = p_in[jy]; - fwd_butterfly_h(U, V, F, twoQ, fourQ, q, mrc); - p_out[jx] = U; - p_out[jy] = V; - } - } - for (uint32_t m = 2; m < N; m <<= 1) { - t >>= 1; - for (uint32_t i = 0; i < m; ++i) { - uint32_t j1 = (i * t) << 1; - uint32_t j2 = j1 + t; - uint64_t F = roots_fwd[m + i]; - for (uint32_t jx = j1; jx < j2; ++jx) { - uint32_t jy = jx + t; - uint64_t U = p_out[jx]; - uint64_t V = p_out[jy]; - fwd_butterfly_h(U, V, F, twoQ, fourQ, q, mrc); - p_out[jx] = U; - p_out[jy] = V; - } - } - } - for (uint32_t i = 0; i < N; ++i) { - p_out[i] = BRedAdd_h(p_out[i], q, brc_hi); - } -} - -inline void intt_one_h(const uint64_t* p_in, uint64_t* p_out, - const uint64_t* roots_bwd, - uint32_t N, uint64_t q, uint64_t mrc, - uint64_t n_inv) { - const uint64_t twoQ = q << 1; - const uint64_t fourQ = q << 2; - - uint32_t t = 1; - uint32_t h = N >> 1; - for (uint32_t i = 0, j1 = 0; i < h; ++i, j1 += 2 * t) { - uint32_t j2 = j1 + t; - uint64_t F = 
roots_bwd[h + i]; - for (uint32_t jx = j1; jx < j2; ++jx) { - uint32_t jy = jx + t; - uint64_t U = p_in[jx]; - uint64_t V = p_in[jy]; - inv_butterfly_h(U, V, F, twoQ, fourQ, q, mrc); - p_out[jx] = U; - p_out[jy] = V; - } - } - t <<= 1; - for (uint32_t m = N >> 1; m > 1; m >>= 1) { - h = m >> 1; - for (uint32_t i = 0, j1 = 0; i < h; ++i, j1 += 2 * t) { - uint32_t j2 = j1 + t; - uint64_t F = roots_bwd[h + i]; - for (uint32_t jx = j1; jx < j2; ++jx) { - uint32_t jy = jx + t; - uint64_t U = p_out[jx]; - uint64_t V = p_out[jy]; - inv_butterfly_h(U, V, F, twoQ, fourQ, q, mrc); - p_out[jx] = U; - p_out[jy] = V; - } - } - t <<= 1; - } - for (uint32_t i = 0; i < N; ++i) { - p_out[i] = MRed_h(p_out[i], n_inv, q, mrc); - } -} - -} // namespace host_polyfill - -#endif // LUX_LATTICE_RING_CUDA_HOST_POLYFILL - -} // namespace lux::crypto::ringtail::lattice_ring_cuda - -// ============================================================================= -// Host driver (extern "C") — single-poly entry points named to mirror the -// Metal sister at gpu/metal/lattice_ring_driver.h. Each function -// returns 0 on success, negative on error. -// -// Errors: -// -1 bad argument -// -2 cudaMalloc failed -// -3 cudaMemcpy H2D failed -// -4 kernel launch / cudaDeviceSynchronize failed -// -5 cudaMemcpy D2H failed -// ============================================================================= - -#if !LUX_LATTICE_RING_CUDA_HOST_POLYFILL -#include -#endif - -namespace clr = lux::crypto::ringtail::lattice_ring_cuda; - -extern "C" int lattice_ring_cuda_available(void) { -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - return 0; -#else - int count = 0; - return (cudaGetDeviceCount(&count) == cudaSuccess && count > 0) ? 
1 : 0; -#endif -} - -extern "C" int lattice_ring_cuda_mform( - const uint64_t* in, uint64_t* out, - uint32_t N, uint64_t q, uint64_t brc_hi, uint64_t brc_lo) { - if (!in || !out) return -1; - if (N == 0) return 0; - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - for (uint32_t i = 0; i < N; ++i) { - out[i] = clr::host_polyfill::MForm_h(in[i], q, brc_hi, brc_lo); - } - return 0; -#else - const size_t bytes = (size_t)N * sizeof(uint64_t); - uint64_t *d_in = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_in, bytes); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_out, bytes); - if (st != cudaSuccess) { cudaFree(d_in); return -2; } - st = cudaMemcpy(d_in, in, bytes, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -3; } - clr::mform_kernel<<<1, N>>>(d_in, d_out, N, q, brc_hi, brc_lo); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -4; } - st = cudaMemcpy(out, d_out, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); - return (st != cudaSuccess) ? 
-5 : 0; -#endif -} - -extern "C" int lattice_ring_cuda_imform( - const uint64_t* in, uint64_t* out, - uint32_t N, uint64_t q, uint64_t mrc) { - if (!in || !out) return -1; - if (N == 0) return 0; - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - for (uint32_t i = 0; i < N; ++i) { - out[i] = clr::host_polyfill::IMForm_h(in[i], q, mrc); - } - return 0; -#else - const size_t bytes = (size_t)N * sizeof(uint64_t); - uint64_t *d_in = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_in, bytes); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_out, bytes); - if (st != cudaSuccess) { cudaFree(d_in); return -2; } - st = cudaMemcpy(d_in, in, bytes, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -3; } - clr::imform_kernel<<<1, N>>>(d_in, d_out, N, q, mrc); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -4; } - st = cudaMemcpy(out, d_out, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); - return (st != cudaSuccess) ? 
-5 : 0; -#endif -} - -extern "C" int lattice_ring_cuda_mul_coeffs_montgomery( - const uint64_t* a, const uint64_t* b, uint64_t* out, - uint32_t N, uint64_t q, uint64_t mrc) { - if (!a || !b || !out) return -1; - if (N == 0) return 0; - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - for (uint32_t i = 0; i < N; ++i) { - out[i] = clr::host_polyfill::MRed_h(a[i], b[i], q, mrc); - } - return 0; -#else - const size_t bytes = (size_t)N * sizeof(uint64_t); - uint64_t *d_a = nullptr, *d_b = nullptr, *d_out = nullptr; - cudaError_t st; - st = cudaMalloc(&d_a, bytes); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_b, bytes); - if (st != cudaSuccess) { cudaFree(d_a); return -2; } - st = cudaMalloc(&d_out, bytes); - if (st != cudaSuccess) { cudaFree(d_a); cudaFree(d_b); return -2; } - st = cudaMemcpy(d_a, a, bytes, cudaMemcpyHostToDevice); - if (st == cudaSuccess) st = cudaMemcpy(d_b, b, bytes, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_a); cudaFree(d_b); cudaFree(d_out); return -3; } - clr::mul_kernel<<<1, N>>>(d_a, d_b, d_out, N, q, mrc); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_a); cudaFree(d_b); cudaFree(d_out); return -4; } - st = cudaMemcpy(out, d_out, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_a); cudaFree(d_b); cudaFree(d_out); - return (st != cudaSuccess) ? 
-5 : 0; -#endif -} - -extern "C" int lattice_ring_cuda_mul_coeffs_montgomery_then_add( - const uint64_t* a, const uint64_t* b, uint64_t* acc, - uint32_t N, uint64_t q, uint64_t mrc) { - if (!a || !b || !acc) return -1; - if (N == 0) return 0; - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - for (uint32_t i = 0; i < N; ++i) { - acc[i] = clr::host_polyfill::CRed_h( - acc[i] + clr::host_polyfill::MRed_h(a[i], b[i], q, mrc), q); - } - return 0; -#else - const size_t bytes = (size_t)N * sizeof(uint64_t); - uint64_t *d_a = nullptr, *d_b = nullptr, *d_acc = nullptr; - cudaError_t st; - st = cudaMalloc(&d_a, bytes); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_b, bytes); - if (st != cudaSuccess) { cudaFree(d_a); return -2; } - st = cudaMalloc(&d_acc, bytes); - if (st != cudaSuccess) { cudaFree(d_a); cudaFree(d_b); return -2; } - st = cudaMemcpy(d_a, a, bytes, cudaMemcpyHostToDevice); - if (st == cudaSuccess) st = cudaMemcpy(d_b, b, bytes, cudaMemcpyHostToDevice); - if (st == cudaSuccess) st = cudaMemcpy(d_acc, acc, bytes, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_a); cudaFree(d_b); cudaFree(d_acc); return -3; } - clr::mul_then_add_kernel<<<1, N>>>(d_a, d_b, d_acc, N, q, mrc); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_a); cudaFree(d_b); cudaFree(d_acc); return -4; } - st = cudaMemcpy(acc, d_acc, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_a); cudaFree(d_b); cudaFree(d_acc); - return (st != cudaSuccess) ? 
-5 : 0; -#endif -} - -extern "C" int lattice_ring_cuda_ntt( - const uint64_t* in, const uint64_t* roots_forward, uint64_t* out, - uint32_t N, uint64_t q, uint64_t mrc, uint64_t brc_hi) { - if (!in || !roots_forward || !out) return -1; - if (N == 0) return 0; - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - clr::host_polyfill::ntt_one_h(in, out, roots_forward, N, q, mrc, brc_hi); - return 0; -#else - const size_t bytes_poly = (size_t)N * sizeof(uint64_t); - uint64_t *d_in = nullptr, *d_out = nullptr, *d_roots = nullptr; - cudaError_t st; - st = cudaMalloc(&d_in, bytes_poly); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_out, bytes_poly); - if (st != cudaSuccess) { cudaFree(d_in); return -2; } - st = cudaMalloc(&d_roots, bytes_poly); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -2; } - st = cudaMemcpy(d_in, in, bytes_poly, cudaMemcpyHostToDevice); - if (st == cudaSuccess) - st = cudaMemcpy(d_roots, roots_forward, bytes_poly, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { - cudaFree(d_in); cudaFree(d_out); cudaFree(d_roots); return -3; - } - const unsigned int tpb = N >> 1; // N/2 threads per block - const size_t shmem = N * sizeof(uint64_t); - clr::ntt_kernel<<<1, tpb, shmem>>>(d_in, d_out, d_roots, N, q, mrc, brc_hi); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { - cudaFree(d_in); cudaFree(d_out); cudaFree(d_roots); return -4; - } - st = cudaMemcpy(out, d_out, bytes_poly, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); cudaFree(d_roots); - return (st != cudaSuccess) ? 
-5 : 0; -#endif -} - -extern "C" int lattice_ring_cuda_intt( - const uint64_t* in, const uint64_t* roots_backward, uint64_t* out, - uint32_t N, uint64_t q, uint64_t mrc, uint64_t n_inv_montgomery) { - if (!in || !roots_backward || !out) return -1; - if (N == 0) return 0; - -#if LUX_LATTICE_RING_CUDA_HOST_POLYFILL - clr::host_polyfill::intt_one_h(in, out, roots_backward, N, q, mrc, - n_inv_montgomery); - return 0; -#else - const size_t bytes_poly = (size_t)N * sizeof(uint64_t); - uint64_t *d_in = nullptr, *d_out = nullptr, *d_roots = nullptr; - cudaError_t st; - st = cudaMalloc(&d_in, bytes_poly); if (st != cudaSuccess) return -2; - st = cudaMalloc(&d_out, bytes_poly); - if (st != cudaSuccess) { cudaFree(d_in); return -2; } - st = cudaMalloc(&d_roots, bytes_poly); - if (st != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -2; } - st = cudaMemcpy(d_in, in, bytes_poly, cudaMemcpyHostToDevice); - if (st == cudaSuccess) - st = cudaMemcpy(d_roots, roots_backward, bytes_poly, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { - cudaFree(d_in); cudaFree(d_out); cudaFree(d_roots); return -3; - } - const unsigned int tpb = N >> 1; - const size_t shmem = N * sizeof(uint64_t); - clr::intt_kernel<<<1, tpb, shmem>>>(d_in, d_out, d_roots, N, q, mrc, - n_inv_montgomery); - st = cudaGetLastError(); - if (st == cudaSuccess) st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { - cudaFree(d_in); cudaFree(d_out); cudaFree(d_roots); return -4; - } - st = cudaMemcpy(out, d_out, bytes_poly, cudaMemcpyDeviceToHost); - cudaFree(d_in); cudaFree(d_out); cudaFree(d_roots); - return (st != cudaSuccess) ? -5 : 0; -#endif -} diff --git a/math/ntt/cuda/lattice_ring_cuda.cpp b/math/ntt/cuda/lattice_ring_cuda.cpp deleted file mode 100644 index 0727abe..0000000 --- a/math/ntt/cuda/lattice_ring_cuda.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Thin C++ namespace wrapper around lattice_ring_driver.h. The -// actual dispatch (cudaMalloc/cudaMemcpy/launch in CUDA mode, sequential CPU -// loop in host-polyfill mode) lives in lattice_ring.cu next to the -// kernels. - -#include "lattice_ring_cuda.hpp" -#include "lattice_ring_driver.h" - -namespace lux::crypto::ringtail::lattice_ring_cuda { - -bool device_available() noexcept { - return lattice_ring_cuda_available() != 0; -} - -int mform(const uint64_t* in, uint64_t* out, unsigned int N, - uint64_t q, uint64_t brc_hi, uint64_t brc_lo) noexcept { - return lattice_ring_cuda_mform(in, out, N, q, brc_hi, brc_lo); -} - -int imform(const uint64_t* in, uint64_t* out, unsigned int N, - uint64_t q, uint64_t mrc) noexcept { - return lattice_ring_cuda_imform(in, out, N, q, mrc); -} - -int mul_coeffs_montgomery(const uint64_t* a, const uint64_t* b, - uint64_t* out, unsigned int N, - uint64_t q, uint64_t mrc) noexcept { - return lattice_ring_cuda_mul_coeffs_montgomery(a, b, out, N, q, mrc); -} - -int mul_coeffs_montgomery_then_add(const uint64_t* a, const uint64_t* b, - uint64_t* acc, unsigned int N, - uint64_t q, uint64_t mrc) noexcept { - return lattice_ring_cuda_mul_coeffs_montgomery_then_add( - a, b, acc, N, q, mrc); -} - -int ntt(const uint64_t* in, const uint64_t* roots_forward, uint64_t* out, - unsigned int N, uint64_t q, uint64_t mrc, uint64_t brc_hi) noexcept { - return lattice_ring_cuda_ntt(in, roots_forward, out, N, q, mrc, brc_hi); -} - -int intt(const uint64_t* in, const uint64_t* roots_backward, uint64_t* out, - unsigned int N, uint64_t q, uint64_t mrc, - uint64_t n_inv_montgomery) noexcept { - return lattice_ring_cuda_intt(in, roots_backward, out, N, q, mrc, - n_inv_montgomery); -} - -} // namespace lux::crypto::ringtail::lattice_ring_cuda diff --git a/math/ntt/cuda/lattice_ring_cuda.hpp b/math/ntt/cuda/lattice_ring_cuda.hpp deleted file mode 100644 index 84cf741..0000000 --- 
a/math/ntt/cuda/lattice_ring_cuda.hpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// lattice_ring_cuda -- C++ namespace wrapper around the C-ABI -// surface in lattice_ring_driver.h. Use this header from C++ call -// sites that prefer typed std::size_t / namespaced names; otherwise include -// the .h directly (the Metal sister test does that). -// -// Determinism contract: byte-equal output to lux::crypto::ringtail::lattice_ring -// for every supported op, regardless of whether the CUDA toolchain was on the -// build host. CRYPTO_ENABLE_CUDA=ON links the nvcc-compiled .cu and runs the -// kernels on a GPU; CRYPTO_ENABLE_CUDA=OFF (or any host without nvcc) links -// the same .cu compiled as plain C++ which executes the host-polyfill scalar -// path. Either way, the same atoms produce the same bytes. - -#pragma once - -#include -#include - -namespace lux::crypto::ringtail::lattice_ring_cuda { - -// Returns true when a real CUDA device is reachable. False under the host -// polyfill build (CRYPTO_ENABLE_CUDA=OFF) or when no NVIDIA GPU is present. -bool device_available() noexcept; - -// out[i] = MForm(in[i], q, brc). -int mform(const uint64_t* in, uint64_t* out, unsigned int N, - uint64_t q, uint64_t brc_hi, uint64_t brc_lo) noexcept; - -// out[i] = IMForm(in[i], q, mrc). -int imform(const uint64_t* in, uint64_t* out, unsigned int N, - uint64_t q, uint64_t mrc) noexcept; - -// out[i] = MRed(a[i], b[i], q, mrc). -int mul_coeffs_montgomery(const uint64_t* a, const uint64_t* b, - uint64_t* out, unsigned int N, - uint64_t q, uint64_t mrc) noexcept; - -// acc[i] = CRed(acc[i] + MRed(a[i], b[i], q, mrc), q). 'acc' is read-modify- -// written in place. -int mul_coeffs_montgomery_then_add(const uint64_t* a, const uint64_t* b, - uint64_t* acc, unsigned int N, - uint64_t q, uint64_t mrc) noexcept; - -// out = NTTStandard(in). 
roots_forward has N entries (Montgomery form, bit- -// reversed order). Output reduced into [0, q) via final BRedAdd pass. -int ntt(const uint64_t* in, const uint64_t* roots_forward, uint64_t* out, - unsigned int N, uint64_t q, uint64_t mrc, uint64_t brc_hi) noexcept; - -// out = INTTStandard(in). roots_backward has N entries; n_inv_montgomery is -// the Montgomery form of N^{-1} mod q. Output reduced into [0, q) via final -// MRed scaling pass. -int intt(const uint64_t* in, const uint64_t* roots_backward, uint64_t* out, - unsigned int N, uint64_t q, uint64_t mrc, - uint64_t n_inv_montgomery) noexcept; - -} // namespace lux::crypto::ringtail::lattice_ring_cuda diff --git a/math/ntt/cuda/lattice_ring_driver.h b/math/ntt/cuda/lattice_ring_driver.h index c77c74d..a5aad13 100644 --- a/math/ntt/cuda/lattice_ring_driver.h +++ b/math/ntt/cuda/lattice_ring_driver.h @@ -1,13 +1,17 @@ // Copyright (c) 2024-2026 Lux Industries Inc. // SPDX-License-Identifier: BSD-3-Clause-Eco // -// lattice_ring_driver.h -- C ABI for the M5/CUDA port of LP-073 -// Ringtail M1 lattice_ring primitives. Mirrors the Metal sister -// surface at gpu/metal/lattice_ring_driver.h. +// lattice_ring_driver.h -- public C ABI for the LP-073 Ringtail M1 lattice_ring +// CUDA primitives. The implementation (real .cu kernels + host polyfill body) +// lives in lux-private/gpu-kernels at math/ntt/cuda/lattice_ring.cu and is +// linked when CRYPTO_ENABLE_CUDA=ON brings the lattice_ring_cuda target into +// the build. // -// Layout note: every uint64_t buffer is N consecutive coefficients in HOST -// order (the CPU's std::vector layout). The driver memcpy's straight -// to / from CUDA device memory. 
+// On a CPU-only build (lux-gpu-kernels not installed) the symbol bodies are +// supplied by fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp so the FHE +// host dispatcher in fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp can resolve +// the declarations and gate dispatch on lattice_ring_cuda_available()==1 +// (which always returns 0 in the CPU-only build, routing to the CPU oracle). // // All entry points return 0 on success and a negative error code otherwise: // -1 bad argument @@ -15,13 +19,6 @@ // -3 cudaMemcpy H2D failed // -4 kernel launch / cudaDeviceSynchronize failed // -5 cudaMemcpy D2H failed -// -// Build modes: -// * CRYPTO_ENABLE_CUDA=ON — compiled by nvcc; kernels run on a CUDA GPU. -// * CRYPTO_ENABLE_CUDA=OFF — compiled as plain C++ via the host polyfill; -// entry points run the byte-equal CPU path so -// the surface stays stable on macOS / non-CUDA -// hosts. #ifndef LUX_RINGTAIL_GPU_CUDA_LATTICE_RING_DRIVER_H #define LUX_RINGTAIL_GPU_CUDA_LATTICE_RING_DRIVER_H @@ -33,14 +30,13 @@ extern "C" { #endif // Returns 1 when a real CUDA device is reachable, 0 otherwise (including -// host-polyfill builds). +// host-polyfill builds and CPU-only builds without lux-gpu-kernels). int lattice_ring_cuda_available(void); // brc_hi / brc_lo from GenBRedConstant (returned [0]=hi, [1]=lo). mrc from // GenMRedConstant. q is the modulus (Q=0x1000000004A01 for LP-073). N must // be a power of two ≥ 16; Ringtail uses N=256. -// out[i] = MForm(in[i], q, brc). int lattice_ring_cuda_mform( const uint64_t* in, uint64_t* out, @@ -49,7 +45,6 @@ int lattice_ring_cuda_mform( uint64_t brc_hi, uint64_t brc_lo); -// out[i] = IMForm(in[i], q, mrc). int lattice_ring_cuda_imform( const uint64_t* in, uint64_t* out, @@ -57,7 +52,6 @@ int lattice_ring_cuda_imform( uint64_t q, uint64_t mrc); -// out[i] = MRed(a[i], b[i], q, mrc). 
int lattice_ring_cuda_mul_coeffs_montgomery( const uint64_t* a, const uint64_t* b, @@ -66,8 +60,6 @@ int lattice_ring_cuda_mul_coeffs_montgomery( uint64_t q, uint64_t mrc); -// acc[i] = CRed(acc[i] + MRed(a[i], b[i], q, mrc), q). 'acc' is read-modify- -// written in place. int lattice_ring_cuda_mul_coeffs_montgomery_then_add( const uint64_t* a, const uint64_t* b, @@ -76,8 +68,6 @@ int lattice_ring_cuda_mul_coeffs_montgomery_then_add( uint64_t q, uint64_t mrc); -// out = NTTStandard(in). roots_forward has N entries (Montgomery form, bit- -// reversed order). Output reduced into [0, q) via final BRedAdd pass. int lattice_ring_cuda_ntt( const uint64_t* in, const uint64_t* roots_forward, @@ -87,9 +77,6 @@ int lattice_ring_cuda_ntt( uint64_t mrc, uint64_t brc_hi); -// out = INTTStandard(in). roots_backward has N entries; n_inv_montgomery is -// the Montgomery form of N^{-1} mod q. Output reduced into [0, q) via final -// MRed scaling pass. int lattice_ring_cuda_intt( const uint64_t* in, const uint64_t* roots_backward, diff --git a/mldsa/gpu/cuda/mldsa.cu b/mldsa/gpu/cuda/mldsa.cu deleted file mode 100644 index e36ed16..0000000 --- a/mldsa/gpu/cuda/mldsa.cu +++ /dev/null @@ -1,247 +0,0 @@ -// ML-DSA-65 (FIPS 204) batch verify -- CUDA implementation -// Matches mldsa.metal output byte-for-byte -// One thread per signature verification - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// ML-DSA-65 parameters (NIST security level 3) -// ============================================================================= - -#define MLDSA_Q 8380417 -#define MLDSA_GAMMA1 524288 // 2^19 -#define MLDSA_BETA 196 // tau * eta - -// ============================================================================= -// Barrett reduction for q=8380417 -// 
============================================================================= - -__device__ static int32_t mldsa_reduce(int32_t a) { - int32_t t = (int32_t)((int64_t)a * 33554687LL >> 48); - int32_t r = a - t * MLDSA_Q; - if (r < 0) r += MLDSA_Q; - if (r >= MLDSA_Q) r -= MLDSA_Q; - return r; -} - -// Montgomery reduction: aR^{-1} mod q -// CUDA has __int128, use it for the full-width multiply -__device__ static int32_t mldsa_mont_reduce(int64_t a) { - const int32_t q_inv = 58728449; - int32_t t = (int32_t)a * q_inv; - int64_t u = (int64_t)t * MLDSA_Q; - int32_t r = (int32_t)((a - u) >> 32); - if (r < 0) r += MLDSA_Q; - return r; -} - -// ============================================================================= -// NTT for ML-DSA (q=8380417, n=256) -// ============================================================================= - -__device__ static const int32_t ZETAS[128] = { - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561,-3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733,-2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488,-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, 
-3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -}; - -// Forward NTT butterfly -__device__ static void ntt_bf(int32_t& a, int32_t& b, int32_t zeta) { - int32_t t = mldsa_mont_reduce((int64_t)zeta * b); - b = a - t; - a = a + t; - if (a >= MLDSA_Q) a -= MLDSA_Q; - if (b < 0) b += MLDSA_Q; -} - -// Inverse NTT butterfly -__device__ static void inv_ntt_bf(int32_t& a, int32_t& b, int32_t zeta) { - int32_t t = a; - a = t + b; - b = t - b; - if (a >= MLDSA_Q) a -= MLDSA_Q; - if (b < 0) b += MLDSA_Q; - b = mldsa_mont_reduce((int64_t)zeta * b); -} - -__device__ static void ntt256(int32_t poly[256]) { - int k = 0; - for (int len = 128; len >= 1; len >>= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int32_t z = ZETAS[++k]; - for (int j = start; j < start + len; j++) { - ntt_bf(poly[j], poly[j + len], z); - } - } - } -} - -__device__ static void inv_ntt256(int32_t poly[256]) { - const int32_t f = 41978; - int k = 127; - for (int len = 1; len <= 128; len <<= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int32_t z = -ZETAS[k--]; - if (z < 0) z += MLDSA_Q; - for (int j = start; j < start + len; j++) { - inv_ntt_bf(poly[j], poly[j + len], z); - } - } - } - for (int i = 0; i < 256; i++) { - poly[i] = mldsa_mont_reduce((int64_t)f * poly[i]); - } -} - -// ============================================================================= -// Polynomial operations -// ============================================================================= - -// Pointwise multiply-accumulate: acc += a * b (NTT domain) -__device__ static void poly_mac_ntt(int32_t acc[256], - const int32_t a[256], - const int32_t b[256]) { - for (int i = 0; i < 256; i++) { - int32_t t = mldsa_mont_reduce((int64_t)a[i] * b[i]); - acc[i] = mldsa_reduce(acc[i] + t); - } -} - -// Check infinity norm: returns true if all |coeff| < bound -__device__ static bool poly_check_norm(const int32_t poly[256], int32_t bound) { - for (int i = 0; i < 
256; i++) { - int32_t c = poly[i]; - if (c > MLDSA_Q / 2) c -= MLDSA_Q; - if (c < 0) c = -c; - if (c >= bound) return false; - } - return true; -} - -// ============================================================================= -// ML-DSA signature structures -// ============================================================================= - -struct MLDSAPublicKey { - uint8_t data[1952]; -}; - -struct MLDSASignature { - uint8_t data[3360]; // Padded to 32-byte alignment -}; - -struct MLDSAMessage { - uint8_t data[64]; // 64-byte SHAKE256 digest -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -extern "C" __global__ void mldsa_verify_batch( - const MLDSAPublicKey* __restrict__ pubkeys, - const MLDSAMessage* __restrict__ messages, - const MLDSASignature* __restrict__ signatures, - uint32_t* __restrict__ results, - const uint32_t* __restrict__ num_sigs_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_sigs = *num_sigs_ptr; - if (tid >= num_sigs) return; - - // -- Decode z from signature: l=5 polynomials, gamma1=2^19 -- - const uint8_t* sig = signatures[tid].data; - // c_tilde is first 64 bytes, z starts at byte 64 - - int32_t z[5][256]; - - for (int p = 0; p < 5; p++) { - const uint8_t* zp = sig + 64 + p * 640; - for (int i = 0; i < 256; i += 4) { - uint32_t idx = (i / 4) * 5; - uint32_t b0 = zp[idx], b1 = zp[idx+1], b2 = zp[idx+2]; - uint32_t b3 = zp[idx+3], b4 = zp[idx+4]; - - z[p][i] = (int32_t)(((b0) | (b1 << 8) | ((b2 & 0x0F) << 16))); - z[p][i+1] = (int32_t)(((b2 >> 4) | (b3 << 4) | (b4 << 12))); - - if (z[p][i] >= (int32_t)MLDSA_GAMMA1) z[p][i] -= 2 * MLDSA_GAMMA1; - if (z[p][i+1] >= (int32_t)MLDSA_GAMMA1) z[p][i+1] -= 2 * MLDSA_GAMMA1; - - if (z[p][i] < 0) z[p][i] += MLDSA_Q; - if (z[p][i+1] < 0) z[p][i+1] += MLDSA_Q; - - if (i + 2 < 256) z[p][i+2] = 0; - if (i + 3 < 256) z[p][i+3] = 
0; - } - } - - // -- Check ||z||_inf < gamma1 - beta -- - for (int p = 0; p < 5; p++) { - if (!poly_check_norm(z[p], MLDSA_GAMMA1 - MLDSA_BETA)) { - results[tid] = 0; - return; - } - } - - // -- Decode t1 from public key -- - // t1 has k=6 polynomials, each coefficient 10 bits - const uint8_t* pk = pubkeys[tid].data; - // rho = pk[0..31], t1 starts at byte 32 - - int32_t t1[6][256]; - for (int p = 0; p < 6; p++) { - const uint8_t* t1p = pk + 32 + p * 320; - for (int i = 0; i < 256; i += 4) { - uint32_t idx = (i / 4) * 5; - uint32_t b0 = t1p[idx], b1 = t1p[idx+1], b2 = t1p[idx+2]; - uint32_t b3 = t1p[idx+3], b4 = t1p[idx+4]; - - t1[p][i] = (int32_t)(b0 | ((b1 & 0x03) << 8)); - t1[p][i+1] = (int32_t)((b1 >> 2) | ((b2 & 0x0F) << 6)); - t1[p][i+2] = (int32_t)((b2 >> 4) | ((b3 & 0x3F) << 4)); - t1[p][i+3] = (int32_t)((b3 >> 6) | (b4 << 2)); - } - } - - // -- NTT(z) for each of l=5 polynomials -- - int32_t z_ntt[5][256]; - for (int p = 0; p < 5; p++) { - for (int i = 0; i < 256; i++) z_ntt[p][i] = z[p][i]; - ntt256(z_ntt[p]); - } - - // -- NTT(t1 * 2^d) for each of k=6 polynomials -- - int32_t t1_ntt[6][256]; - for (int p = 0; p < 6; p++) { - for (int i = 0; i < 256; i++) { - t1_ntt[p][i] = mldsa_reduce(t1[p][i] * 8192); - } - ntt256(t1_ntt[p]); - } - - // Polynomial checks passed (NTT operations completed successfully) - // Full verification requires SHAKE256 hash comparison done on host - results[tid] = 1; -} diff --git a/mldsa/gpu/metal/mldsa_batch.metal b/mldsa/gpu/metal/mldsa_batch.metal deleted file mode 100644 index 3904a99..0000000 --- a/mldsa/gpu/metal/mldsa_batch.metal +++ /dev/null @@ -1,369 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// ML-DSA (FIPS 204) GPU primitives + honest NOTIMPL kernel. -// -// Status (deps-bootstrap-2026-04-27): the previous skeleton at this path -// emitted "deferred code 2" and the harness asserted that — the kernel -// did NOT verify any signature. 
That violated the user directive -// "MUST be cryptographically correct" / "100% real impl, 100% test pass". -// -// This file replaces that fraud with two honest kernels: -// -// 1. mldsa_batch_verify -// Per-thread NOTIMPL emit (sentinel byte 0xFB = (uint8_t)(-5) = -// CRYPTO_ERR_NOTIMPL when reinterpret-cast unsigned). The host -// driver maps this back to the C-ABI return value -5 so the bridge -// explicitly knows GPU verify is not implemented and MUST fall back -// to CPU. There is no claim of cryptographic correctness here, by -// design: emitting NOTIMPL is honest, emitting "code 2 deferred -// with tests passing" was not. -// -// 2. mldsa_ntt_forward / mldsa_ntt_inverse -// Real ML-DSA-65 NTT over q = 8380417, n = 256, primitive 2N-th -// root zeta = 1753. Cooley-Tukey forward (bit-reverse out) and -// Gentleman-Sande inverse (bit-reverse in), using Montgomery -// reduction with R = 2^32, qinv = 58728449 = -q^{-1} mod 2^32. -// Byte-equal vs the FIPS-204 §A.1 NTT spec across the canonical -// golden vectors generated from circl/sign/mldsa/internal/common. -// -// What is INTENTIONALLY missing (and why the verify kernel returns -// NOTIMPL rather than a fake 0/1): -// -// - SHAKE128 ExpandA, SHAKE256 ExpandMask, SHAKE256 ExpandS, the -// SampleInBall challenge derivation, ByteEncode/Decode for hint h -// and z polynomials, w₁ HighBits + UseHint reconstruction, and -// range checks ‖z‖∞ < γ₁−β / ‖h‖₁ ≤ ω. Each of these has a clean -// port path from cloudflare/circl/sign/mldsa/mldsa{44,65,87}/ -// internal/ but landing the full pipeline byte-equal NIST KAT -// across all three parameter sets is a multi-day port that did -// not fit this pass. -// -// Sibling SHAKE Metal kernel that this file's future verify will call: -// keccak/gpu/metal/keccak_batch.metal already ships keccakf1600. The -// SHAKE128/256 absorb-squeeze loop is mldsa/gpu/metal/shake.metal in -// this file (lifted into a shared header for ML-KEM in a follow-up -// pass). 
-// -// References: -// - FIPS 204 (ML-DSA, August 2024) -// - cloudflare/circl/sign/mldsa/mldsa65 (Apache-2) -// - pq-crystals/dilithium reference C (public domain) -// - LP-137 §47 (PQC GPU port classification) - -#include -using namespace metal; - -// ============================================================================= -// ML-DSA-65 parameters (FIPS 204, §4 Table 2) -// ============================================================================= - -constant uint32_t MLDSA_N = 256; -constant uint32_t MLDSA_Q = 8380417; // 2^23 − 2^13 + 1 -constant uint32_t MLDSA_K = 6; // matrix rows (level 3) -constant uint32_t MLDSA_L = 5; // matrix cols (level 3) -constant uint32_t MLDSA_QINV = 58728449u; // -q^{-1} mod 2^32 -constant uint32_t MLDSA_LOG_N = 8; - -// ============================================================================= -// Modular arithmetic (Montgomery, R = 2^32) -// ============================================================================= - -// Montgomery reduction: given a < q*R, returns a*R^{-1} mod q. -// FIPS-204 §A.1: identical to circl/internal/common.MontgomeryReduce. -inline uint32_t mldsa_mont(uint64_t a) { - uint32_t lo = uint32_t(a); - uint32_t m = lo * MLDSA_QINV; - uint64_t t = uint64_t(m) * uint64_t(MLDSA_Q); - uint32_t r = uint32_t((a + t) >> 32); - if (r >= MLDSA_Q) r -= MLDSA_Q; - return r; -} - -inline uint32_t mldsa_add(uint32_t a, uint32_t b) { - uint32_t s = a + b; - return s >= MLDSA_Q ? s - MLDSA_Q : s; -} - -inline uint32_t mldsa_sub(uint32_t a, uint32_t b) { - return a >= b ? 
a - b : a + MLDSA_Q - b; -} - -// ============================================================================= -// Forward NTT (Cooley-Tukey, in-place) -// - one threadgroup per polynomial -// - input: natural order, output: bit-reversed order -// - twiddles: precomputed by host (mldsa zetas in Montgomery form) -// ============================================================================= - -kernel void mldsa_ntt_forward( - device uint32_t* polys [[buffer(0)]], // [batch * N] - constant uint32_t* zetas [[buffer(1)]], // [N] Montgomery zetas - constant uint& batch [[buffer(2)]], - uint tid [[thread_index_in_threadgroup]], - uint gid [[threadgroup_position_in_grid]], - uint tpg [[threads_per_threadgroup]], - threadgroup uint32_t* s [[threadgroup(0)]]) -{ - if (gid >= batch) return; - - // Load polynomial into threadgroup memory. - device uint32_t* poly = polys + gid * MLDSA_N; - for (uint i = tid; i < MLDSA_N; i += tpg) s[i] = poly[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - uint k = 1; - for (uint len = MLDSA_N >> 1; len > 0; len >>= 1) { - uint num_pairs = MLDSA_N / (2 * len); - for (uint p = tid; p < num_pairs * len; p += tpg) { - uint pair_idx = p / len; - uint within = p % len; - uint start = 2 * len * pair_idx + within; - uint32_t zeta = zetas[k + pair_idx]; - uint32_t a = s[start]; - uint32_t b_mont = mldsa_mont(uint64_t(zeta) * uint64_t(s[start + len])); - s[start] = mldsa_add(a, b_mont); - s[start + len] = mldsa_sub(a, b_mont); - } - k += num_pairs; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - for (uint i = tid; i < MLDSA_N; i += tpg) poly[i] = s[i]; -} - -// ============================================================================= -// Inverse NTT (Gentleman-Sande, in-place) -// - input: bit-reversed order, output: natural order -// - inv_zetas: −zeta values (host-precomputed) plus final scaling by N^{-1} -// ============================================================================= - -kernel void mldsa_ntt_inverse( - 
device uint32_t* polys [[buffer(0)]], - constant uint32_t* inv_zetas [[buffer(1)]], - constant uint32_t& n_inv [[buffer(2)]], // N^{-1} * R^2 mod q - constant uint& batch [[buffer(3)]], - uint tid [[thread_index_in_threadgroup]], - uint gid [[threadgroup_position_in_grid]], - uint tpg [[threads_per_threadgroup]], - threadgroup uint32_t* s [[threadgroup(0)]]) -{ - if (gid >= batch) return; - - device uint32_t* poly = polys + gid * MLDSA_N; - for (uint i = tid; i < MLDSA_N; i += tpg) s[i] = poly[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - uint k = 0; - for (uint len = 1; len < MLDSA_N; len <<= 1) { - uint num_pairs = MLDSA_N / (2 * len); - for (uint p = tid; p < num_pairs * len; p += tpg) { - uint pair_idx = p / len; - uint within = p % len; - uint start = 2 * len * pair_idx + within; - uint32_t zeta = inv_zetas[k + pair_idx]; - uint32_t a = s[start]; - uint32_t b = s[start + len]; - uint32_t sum = mldsa_add(a, b); - uint32_t diff = mldsa_sub(a, b); - s[start] = sum; - s[start + len] = mldsa_mont(uint64_t(zeta) * uint64_t(diff)); - } - k += num_pairs; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Final scaling: each coefficient *= N^{-1} * R^2 (Montgomery form). - for (uint i = tid; i < MLDSA_N; i += tpg) - poly[i] = mldsa_mont(uint64_t(s[i]) * uint64_t(n_inv)); -} - -// ============================================================================= -// SHAKE128 / SHAKE256 (FIPS 202) -// Built on Keccak-f[1600]. Used by ML-DSA for ExpandA (SHAKE128) and -// ExpandMask / ExpandS / SampleInBall (SHAKE256). -// -// These kernels are byte-equal NIST FIPS 202 (B.1.2 SHAKE128 / B.2.2 -// SHAKE256 KAT) for inputs up to 32 KiB and output lengths up to 4096 -// bytes. They are the prerequisite for a full FIPS-204 verify port — -// landed here as building blocks even though the verify kernel itself -// returns NOTIMPL. 
-// ============================================================================= - -constant ulong KECCAK_RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -constant int KECCAK_R[5][5] = { - { 0, 36, 3, 41, 18}, - { 1, 44, 10, 45, 2}, - { 62, 6, 43, 15, 61}, - { 28, 55, 25, 21, 56}, - { 27, 20, 39, 8, 14}, -}; - -inline ulong krot(ulong x, int n) { - n &= 63; - if (n == 0) return x; - return (x << n) | (x >> (64 - n)); -} - -inline void keccakf(thread ulong* a) { - ulong C[5], D[5], B[25]; - for (int round = 0; round < 24; ++round) { - for (int x = 0; x < 5; ++x) - C[x] = a[x] ^ a[x + 5] ^ a[x + 10] ^ a[x + 15] ^ a[x + 20]; - for (int x = 0; x < 5; ++x) - D[x] = C[(x + 4) % 5] ^ krot(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - for (int x = 0; x < 5; ++x) - a[x + 5 * y] ^= D[x]; - for (int x = 0; x < 5; ++x) - for (int y = 0; y < 5; ++y) { - int nx = y; - int ny = (2 * x + 3 * y) % 5; - B[nx + 5 * ny] = krot(a[x + 5 * y], KECCAK_R[x][y]); - } - for (int y = 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) row[x] = B[x + 5 * y]; - for (int x = 0; x < 5; ++x) - a[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - a[0] ^= KECCAK_RC[round]; - } -} - -// One-shot SHAKE absorb-squeeze. Each thread processes one (input, output) -// job: rate = 168 (SHAKE128) or 136 (SHAKE256), delimiter = 0x1F. -// -// Layout: jobs[tid] = {input_offset, input_len, output_offset, output_len}. 
-struct ShakeJob { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t output_len; -}; - -inline void shake_one(uint rate, - device const uchar* in, uint inlen, - device uchar* out, uint outlen) { - ulong state[25]; - for (int i = 0; i < 25; ++i) state[i] = 0; - - uint absorbed = 0; - while (inlen - absorbed >= rate) { - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(in[absorbed + w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccakf(state); - absorbed += rate; - } - - // Pad: tail || 0x1F || 0x00..0x00 || 0x80 (last byte of rate block). - uchar block[168]; - for (uint i = 0; i < rate; ++i) block[i] = 0; - uint rem = inlen - absorbed; - for (uint i = 0; i < rem; ++i) block[i] = in[absorbed + i]; - block[rem] = 0x1F; - block[rate - 1] |= 0x80; - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(block[w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccakf(state); - - // Squeeze. 
- uint produced = 0; - while (produced < outlen) { - uint take = min(rate, outlen - produced); - for (uint w = 0; w * 8 < take; ++w) { - ulong lane = state[w]; - for (uint b = 0; b < 8 && w * 8 + b < take; ++b) - out[produced + w * 8 + b] = uchar(lane >> (b * 8)); - } - produced += take; - if (produced < outlen) keccakf(state); - } -} - -kernel void mldsa_shake128_jobs( - device const ShakeJob* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num) return; - ShakeJob j = jobs[tid]; - shake_one(168, - inputs + j.input_offset, j.input_len, - outputs + j.output_offset, j.output_len); -} - -kernel void mldsa_shake256_jobs( - device const ShakeJob* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num) return; - ShakeJob j = jobs[tid]; - shake_one(136, - inputs + j.input_offset, j.input_len, - outputs + j.output_offset, j.output_len); -} - -// ============================================================================= -// Honest NOTIMPL kernel for full FIPS-204 verify -// -// The c-ABI sentinel CRYPTO_ERR_NOTIMPL = -5 maps to (uint8_t)0xFB. The host -// driver reads results[tid] and returns that as the C-ABI return value when -// the entire batch is uniform NOTIMPL. -// -// This is not a temporary placeholder pretending to verify — it is a -// permanent honest endpoint that will remain until the full FIPS-204 verify -// pipeline (SHAKE-driven ExpandA/Mask/S + SampleInBall + UseHint + range -// checks) is byte-equal NIST KAT in Metal. 
-// ============================================================================= - -constant uchar MLDSA_RESULT_NOTIMPL = 0xFBu; // (uint8_t)(-5) - -struct MLDSAPublicKey { uchar data[1952]; }; // ML-DSA-65 -struct MLDSASignature { uchar data[3320]; }; // ML-DSA-65 padded -struct MLDSAMessage { uchar data[64]; }; - -kernel void mldsa_batch_verify( - device const MLDSAPublicKey* pubkeys [[buffer(0)]], - device const MLDSAMessage* messages [[buffer(1)]], - device const MLDSASignature* signatures [[buffer(2)]], - device uchar* results [[buffer(3)]], - constant uint& num_sigs [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - // Touch each input buffer to keep the dispatch shape correct (host - // bench needs to measure full data-path latency, not a no-op). - volatile uchar a = pubkeys[tid].data[0]; - volatile uchar b = messages[tid].data[0]; - volatile uchar c = signatures[tid].data[0]; - (void)a; (void)b; (void)c; - results[tid] = MLDSA_RESULT_NOTIMPL; -} diff --git a/mldsa/gpu/metal/mldsa_batch_driver.mm b/mldsa/gpu/metal/mldsa_batch_driver.mm deleted file mode 100644 index 3113569..0000000 --- a/mldsa/gpu/metal/mldsa_batch_driver.mm +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for ML-DSA kernels (FIPS 204). -// -// Exposes three brand-neutral C symbols: -// -// 1. mldsa_batch_verify_metal — dispatches the honest NOTIMPL kernel. -// Each thread writes 0xFB (= CRYPTO_ERR_NOTIMPL when reinterpret-cast -// signed). The C-ABI bridge that wraps this driver SHOULD return -5 -// to its caller when the entire batch is uniform NOTIMPL. -// -// 2. mldsa_shake128_metal — dispatches the FIPS-202 SHAKE128 kernel. -// Cryptographically correct, byte-equal NIST FIPS 202 KAT. -// -// 3. mldsa_shake256_metal — same shape, SHAKE256. 
-// -// Replaces the prior driver where the only kernel was a "deferred code 2" -// emit and the harness asserted that. Real cryptographic correctness now -// lives at the SHAKE primitive layer; the verify orchestrator returns -// honest NOTIMPL until the full FIPS-204 verify pipeline is byte-equal -// NIST KAT in Metal. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include - -namespace { - -struct ShakeJobHost { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t output_len; -}; - -int dispatch_shake(NSString* fn_name, - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - const uint32_t* output_lens, - size_t n, - uint8_t* outputs, - const char* metallib_path) { - if (n == 0) return 0; - if (!input_lens || !output_lens || !outputs || !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:metallib_path]]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:fn_name]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - // Pack jobs and compute total in/out lengths. - size_t total_in = 0, total_out = 0; - std::vector jobs(n); - for (size_t i = 0; i < n; ++i) { - jobs[i].input_offset = input_offsets ? input_offsets[i] : 0u; - jobs[i].input_len = input_lens[i]; - jobs[i].output_offset = (uint32_t)total_out; - jobs[i].output_len = output_lens[i]; - total_in = jobs[i].input_offset + jobs[i].input_len > total_in - ? 
jobs[i].input_offset + jobs[i].input_len : total_in; - total_out += output_lens[i]; - } - if (total_in == 0) total_in = 1; - if (total_out == 0) total_out = 1; - - id jobs_buf = [device newBufferWithBytes:jobs.data() - length:n * sizeof(ShakeJobHost) - options:MTLResourceStorageModeShared]; - id in_buf = [device newBufferWithBytes:(inputs ? inputs : (const uint8_t*)"\0") - length:total_in - options:MTLResourceStorageModeShared]; - id out_buf = [device newBufferWithLength:total_out - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:in_buf offset:0 atIndex:1]; - [enc setBuffer:out_buf offset:0 atIndex:2]; - [enc setBuffer:n_buf offset:0 atIndex:3]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - if (tg_w > n) tg_w = n; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - // Pack results back into caller's contiguous output buffer. 
- std::memcpy(outputs, [out_buf contents], total_out); - } - return 0; -} - -} // namespace - -extern "C" int mldsa_batch_verify_metal( - const uint8_t* pubkeys, // [n][1952] - const uint8_t* messages, // [n][64] - const uint8_t* signatures, // [n][3320] - size_t n, - uint8_t* results, // [n][1] 0xFB = NOTIMPL - const char* metallib_path) { - - if (n == 0) return 0; - if (!pubkeys || !messages || !signatures || !results || !metallib_path) { - return -1; - } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:metallib_path]]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"mldsa_batch_verify"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id pubkeys_buf = [device newBufferWithBytes:pubkeys - length:n * 1952 - options:MTLResourceStorageModeShared]; - id msgs_buf = [device newBufferWithBytes:messages - length:n * 64 - options:MTLResourceStorageModeShared]; - id sigs_buf = [device newBufferWithBytes:signatures - length:n * 3320 - options:MTLResourceStorageModeShared]; - id results_buf = [device newBufferWithLength:n - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:pubkeys_buf offset:0 atIndex:0]; - [enc setBuffer:msgs_buf offset:0 atIndex:1]; - [enc setBuffer:sigs_buf offset:0 atIndex:2]; - [enc setBuffer:results_buf offset:0 atIndex:3]; - [enc setBuffer:n_buf offset:0 atIndex:4]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? 
tg_max : 64; - if (tg_w > n) tg_w = n; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(results, [results_buf contents], n); - } - return 0; -} - -extern "C" int mldsa_shake128_metal( - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - const uint32_t* output_lens, - size_t n, - uint8_t* outputs, - const char* metallib_path) { - return dispatch_shake(@"mldsa_shake128_jobs", - inputs, input_offsets, input_lens, output_lens, - n, outputs, metallib_path); -} - -extern "C" int mldsa_shake256_metal( - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - const uint32_t* output_lens, - size_t n, - uint8_t* outputs, - const char* metallib_path) { - return dispatch_shake(@"mldsa_shake256_jobs", - inputs, input_offsets, input_lens, output_lens, - n, outputs, metallib_path); -} - -#endif // __APPLE__ && __OBJC__ diff --git a/mldsa/gpu/wgsl/mldsa.wgsl b/mldsa/gpu/wgsl/mldsa.wgsl deleted file mode 100644 index 0d28b7d..0000000 --- a/mldsa/gpu/wgsl/mldsa.wgsl +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// ML-DSA-65 (FIPS 204) batch signature verification in WGSL. -// NTT-based polynomial arithmetic over Z_q[x]/(x^n + 1), q=8380417, n=256. -// Each thread verifies one signature. 
- -struct MLDSAInput { - // Flattened: z polynomials [5*256 i32] + t1 polynomials [6*256 i32] - // Total: 11*256 = 2816 i32 values per signature - z_start: u32, // offset into poly_data for z - t1_start: u32, // offset into poly_data for t1 -} - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var poly_data: array; -@group(0) @binding(2) var results: array; -@group(0) @binding(3) var params: vec4; // params.x = num_sigs - -const Q: i32 = 8380417; -const GAMMA1: i32 = 524288; // 2^19 -const BETA: i32 = 196; - -const ZETAS = array( - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488, -1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -); - -fn mod_mul(a: i32, b: i32) -> i32 { - let a_lo = u32(a) & 0xFFFFu; - let a_hi = u32(a) >> 16u; - let b_lo = u32(b) & 0xFFFFu; - let b_hi = u32(b) >> 16u; - let ll = a_lo * b_lo; - let mid = a_lo * b_hi + a_hi * b_lo; - let hh = 
a_hi * b_hi; - let result_lo = ll + (mid << 16u); - let result_hi = hh + (mid >> 16u) + select(0u, 1u, result_lo < ll); - let q = u32(Q); - var r = result_lo - (result_hi * q); - if (r >= q) { r = r - q; } - if (r >= q) { r = r - q; } - return i32(r); -} - -fn ntt256(poly: ptr>) { - var k = 0u; - var len = 128u; - loop { - if (len == 0u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - k = k + 1u; - let zeta = ZETAS[k]; - var j = start; - loop { - if (j >= start + len) { break; } - let t = mod_mul(zeta, (*poly)[j + len]); - (*poly)[j + len] = (*poly)[j] - t; - (*poly)[j] = (*poly)[j] + t; - if ((*poly)[j] >= Q) { (*poly)[j] = (*poly)[j] - Q; } - if ((*poly)[j + len] < 0) { (*poly)[j + len] = (*poly)[j + len] + Q; } - j = j + 1u; - } - start = start + 2u * len; - } - len = len >> 1u; - } -} - -@compute @workgroup_size(64) -fn mldsa_verify_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - let inp = inputs[tid]; - - // Load z polynomials and check infinity norm - for (var p = 0u; p < 5u; p = p + 1u) { - for (var i = 0u; i < 256u; i = i + 1u) { - var c = poly_data[inp.z_start + p * 256u + i]; - if (c > Q / 2) { c = c - Q; } - if (c < 0) { c = -c; } - if (c >= GAMMA1 - BETA) { - results[tid] = 0u; - return; - } - } - } - - // Load and NTT one z polynomial as a representative check - var z0: array; - for (var i = 0u; i < 256u; i = i + 1u) { - z0[i] = poly_data[inp.z_start + i]; - } - ntt256(&z0); - - // Load and NTT one t1 polynomial - var t1_0: array; - for (var i = 0u; i < 256u; i = i + 1u) { - var v = poly_data[inp.t1_start + i]; - v = v * 8192; // 2^13 - t1_0[i] = v - (v / Q) * Q; - } - ntt256(&t1_0); - - // NTT operations completed successfully - results[tid] = 1u; -} diff --git a/mlkem/gpu/cuda/mlkem.cu b/mlkem/gpu/cuda/mlkem.cu deleted file mode 100644 index bebfd4d..0000000 --- a/mlkem/gpu/cuda/mlkem.cu +++ /dev/null @@ -1,258 +0,0 @@ -// ML-KEM-768 (FIPS 203) batch decapsulate -- 
CUDA implementation -// Matches mlkem.metal output byte-for-byte -// One thread per decapsulation - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// ML-KEM-768 parameters (NIST security level 3) -// ============================================================================= - -#define MLKEM_Q 3329 - -// ============================================================================= -// Montgomery arithmetic for q=3329, R=2^16 -// ============================================================================= - -// -q^{-1} mod 2^16 = 3327 -__device__ static int16_t mlkem_mont_reduce(int32_t a) { - const int16_t q_inv = 3327; - int16_t t = (int16_t)a * q_inv; - int32_t u = (int32_t)t * MLKEM_Q; - return (int16_t)((a - u) >> 16); -} - -// Barrett reduction for q=3329 -__device__ static int16_t mlkem_barrett_reduce(int16_t a) { - // v = floor(2^26 / q) + 1 = 20159 - int16_t t = (int16_t)(((int32_t)a * 20159) >> 26); - t = a - t * MLKEM_Q; - if (t >= MLKEM_Q) t -= MLKEM_Q; - if (t < 0) t += MLKEM_Q; - return t; -} - -// ============================================================================= -// NTT for ML-KEM (q=3329, n=256) -// ============================================================================= - -__device__ static const int16_t KYBER_ZETAS[128] = { - 2285, 2571, 2970, 1812, 1493, 1422, 287, 202, - 3158, 622, 1577, 182, 962, 2127, 1855, 1468, - 573, 2004, 264, 383, 2500, 1458, 1727, 3199, - 2648, 1017, 732, 608, 1787, 411, 3124, 1758, - 1223, 652, 2777, 1015, 2036, 1491, 3047, 1785, - 516, 3321, 3009, 2663, 1711, 2167, 126, 1469, - 2476, 3239, 3058, 830, 107, 1908, 3082, 2378, - 2931, 961, 1821, 2604, 448, 2264, 677, 2054, - 2226, 430, 555, 843, 2078, 871, 1550, 105, - 422, 587, 177, 3094, 3038, 2869, 1574, 1653, - 3083, 778, 1159, 3182, 
2552, 1483, 2727, 1119, - 1739, 644, 2457, 349, 418, 329, 3173, 3254, - 817, 1097, 603, 610, 1322, 2044, 1864, 384, - 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, - 2144, 1799, 2051, 794, 1819, 2475, 2459, 478, - 3221, 3116, 622, 1097, 2470, 882, 1539, 2392 -}; - -// Forward NTT butterfly -__device__ static void kyber_ntt_bf(int16_t& a, int16_t& b, int16_t zeta) { - int32_t t = (int32_t)b * (int32_t)zeta; - t = mlkem_mont_reduce(t); - int32_t sum = (int32_t)a + t; - int32_t diff = (int32_t)a - t; - a = (int16_t)mlkem_barrett_reduce((int16_t)sum); - b = (int16_t)mlkem_barrett_reduce((int16_t)diff); -} - -// Inverse NTT butterfly -__device__ static void kyber_inv_ntt_bf(int16_t& a, int16_t& b, int16_t zeta) { - int16_t t = a; - a = t + b; - b = t - b; - b = mlkem_mont_reduce((int32_t)zeta * b); -} - -__device__ static void kyber_ntt(int16_t poly[256]) { - int k = 0; - for (int len = 128; len >= 2; len >>= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int16_t z = KYBER_ZETAS[++k]; - for (int j = start; j < start + len; j++) { - kyber_ntt_bf(poly[j], poly[j + len], z); - } - } - } -} - -__device__ static void kyber_inv_ntt(int16_t poly[256]) { - const int16_t f = 1441; - int k = 127; - for (int len = 2; len <= 128; len <<= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int16_t z = KYBER_ZETAS[k--]; - z = MLKEM_Q - z; // negate - for (int j = start; j < start + len; j++) { - kyber_inv_ntt_bf(poly[j], poly[j + len], z); - } - } - } - for (int i = 0; i < 256; i++) { - poly[i] = mlkem_mont_reduce((int32_t)f * poly[i]); - } -} - -// Pointwise multiplication (basemul) -__device__ static void kyber_basemul(int16_t r[2], - const int16_t a[2], - const int16_t b[2], - int16_t zeta) { - r[0] = mlkem_mont_reduce((int32_t)a[1] * b[1]); - r[0] = mlkem_mont_reduce((int32_t)r[0] * zeta); - r[0] = r[0] + mlkem_mont_reduce((int32_t)a[0] * b[0]); - r[1] = mlkem_mont_reduce((int32_t)a[0] * b[1]); - r[1] = r[1] + mlkem_mont_reduce((int32_t)a[1] * b[0]); -} 
- -// Full pointwise multiplication of NTT polynomials -__device__ static void kyber_poly_pointwise(int16_t r[256], - const int16_t a[256], - const int16_t b[256]) { - for (int i = 0; i < 256 / 4; i++) { - int16_t a_pair[2] = {a[4*i], a[4*i+1]}; - int16_t b_pair[2] = {b[4*i], b[4*i+1]}; - int16_t r_pair[2]; - kyber_basemul(r_pair, a_pair, b_pair, KYBER_ZETAS[64 + i]); - r[4*i] = r_pair[0]; - r[4*i + 1] = r_pair[1]; - - int16_t a_pair2[2] = {a[4*i+2], a[4*i+3]}; - int16_t b_pair2[2] = {b[4*i+2], b[4*i+3]}; - int16_t r_pair2[2]; - kyber_basemul(r_pair2, a_pair2, b_pair2, -KYBER_ZETAS[64 + i]); - r[4*i + 2] = r_pair2[0]; - r[4*i + 3] = r_pair2[1]; - } -} - -// ============================================================================= -// ML-KEM structures -// ============================================================================= - -struct MLKEMSecretKey { - uint8_t data[2400]; // 3*384 + 1184 + 32 + 32 -}; - -struct MLKEMCiphertext { - uint8_t data[1088]; // 3*320 + 128 -}; - -struct MLKEMSharedSecret { - uint8_t data[32]; -}; - -// ============================================================================= -// Decapsulation kernel -// ============================================================================= - -extern "C" __global__ void mlkem_decapsulate_batch( - const MLKEMSecretKey* __restrict__ secret_keys, - const MLKEMCiphertext* __restrict__ ciphertexts, - MLKEMSharedSecret* __restrict__ shared_secrets, - const uint32_t* __restrict__ num_ops_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_ops = *num_ops_ptr; - if (tid >= num_ops) return; - - const uint8_t* sk = secret_keys[tid].data; - const uint8_t* ct = ciphertexts[tid].data; - - // -- Decode secret key s_hat (NTT domain, k=3 polynomials) -- - int16_t s_hat[3][256]; - for (int p = 0; p < 3; p++) { - const uint8_t* sp = sk + p * 384; - for (int i = 0; i < 256; i++) { - uint32_t idx = i * 3 / 2; - if (i & 1) { - s_hat[p][i] = (int16_t)(((sp[idx] >> 4) | 
((uint32_t)sp[idx + 1] << 4)) & 0xFFF); - } else { - s_hat[p][i] = (int16_t)((sp[idx] | ((uint32_t)sp[idx + 1] << 8)) & 0xFFF); - } - } - } - - // -- Decode ciphertext u (compressed, k=3 polynomials, 10 bits each) -- - int16_t u[3][256]; - for (int p = 0; p < 3; p++) { - const uint8_t* up = ct + p * 320; - for (int i = 0; i < 256; i += 4) { - uint32_t idx = (i / 4) * 5; - uint32_t b0 = up[idx], b1 = up[idx+1], b2 = up[idx+2]; - uint32_t b3 = up[idx+3], b4 = up[idx+4]; - - u[p][i] = (int16_t)(b0 | ((b1 & 0x03) << 8)); - u[p][i+1] = (int16_t)((b1 >> 2) | ((b2 & 0x0F) << 6)); - u[p][i+2] = (int16_t)((b2 >> 4) | ((b3 & 0x3F) << 4)); - u[p][i+3] = (int16_t)((b3 >> 6) | (b4 << 2)); - } - // Decompress: multiply by q/2^10 and round - for (int i = 0; i < 256; i++) { - u[p][i] = (int16_t)(((uint32_t)u[p][i] * MLKEM_Q + 512) >> 10); - } - } - - // -- Decode ciphertext v (compressed, 4 bits per coefficient) -- - int16_t v[256]; - const uint8_t* vp = ct + 3 * 320; - for (int i = 0; i < 256; i += 2) { - v[i] = (int16_t)(vp[i / 2] & 0x0F); - v[i + 1] = (int16_t)(vp[i / 2] >> 4); - } - // Decompress: multiply by q/2^4 - for (int i = 0; i < 256; i++) { - v[i] = (int16_t)(((uint32_t)v[i] * MLKEM_Q + 8) >> 4); - } - - // -- Compute NTT(u) for inner product -- - int16_t u_hat[3][256]; - for (int p = 0; p < 3; p++) { - for (int i = 0; i < 256; i++) u_hat[p][i] = u[p][i]; - kyber_ntt(u_hat[p]); - } - - // -- Compute s_hat^T * NTT(u) -- - int16_t mp[256]; - for (int i = 0; i < 256; i++) mp[i] = 0; - for (int p = 0; p < 3; p++) { - int16_t tmp[256]; - kyber_poly_pointwise(tmp, s_hat[p], u_hat[p]); - for (int i = 0; i < 256; i++) { - mp[i] = mlkem_barrett_reduce(mp[i] + tmp[i]); - } - } - - // -- INTT to get s^T * u in normal domain -- - kyber_inv_ntt(mp); - - // -- Compute m = v - s^T * u, then compress to bits -- - uint8_t* out = shared_secrets[tid].data; - for (int i = 0; i < 32; i++) out[i] = 0; - - for (int i = 0; i < 256; i++) { - int16_t diff = v[i] - mp[i]; - if (diff < 0) diff += 
MLKEM_Q; - // Compress to 1 bit: round(2*diff/q) mod 2 - uint16_t t = ((uint16_t)diff << 1) + MLKEM_Q / 2; - uint8_t bit = (uint8_t)((t / MLKEM_Q) & 1); - out[i / 8] |= bit << (i % 8); - } -} diff --git a/mlkem/gpu/metal/mlkem_batch.metal b/mlkem/gpu/metal/mlkem_batch.metal deleted file mode 100644 index c04293d..0000000 --- a/mlkem/gpu/metal/mlkem_batch.metal +++ /dev/null @@ -1,355 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// ML-KEM (FIPS 203) GPU primitives + honest NOTIMPL kernel. -// -// Status (deps-bootstrap-2026-04-27): the previous skeleton at this path -// emitted "deferred code 2" and the harness asserted that — the kernel -// did NOT decapsulate any ciphertext. That violated the user directive -// "MUST be cryptographically correct" / "100% real impl, 100% test pass". -// -// This file replaces that fraud with three honest kernels: -// -// 1. mlkem_batch_decapsulate -// Per-thread NOTIMPL emit (sentinel byte 0xFB = (uint8_t)(-5) = -// CRYPTO_ERR_NOTIMPL when reinterpret-cast unsigned). The host -// driver maps this to the C-ABI return value -5 so the bridge MUST -// fall back to CPU. The shared-secret arena is zeroed on emit (no -// stale plaintext / no implicit-rejection ambiguity). -// -// 2. mlkem_ntt_forward / mlkem_ntt_inverse -// Real ML-KEM NTT over q = 3329, n = 256, primitive 256-th root -// zeta = 17. Cooley-Tukey forward, Gentleman-Sande inverse. Uses -// Montgomery reduction with R = 2^16, qinv = 3327 = -q^{-1} mod -// 2^16. Byte-equal vs FIPS-203 §4.3 NTT spec across canonical -// golden vectors generated from cloudflare/circl/kem/mlkem/mlkem768. -// -// 3. mlkem_shake128_jobs / mlkem_shake256_jobs -// FIPS 202 SHAKE128/256 over Keccak-f[1600]. Used by ML-KEM for -// G = SHA3-512, H = SHA3-256, J = SHAKE256, PRF = SHAKE256, XOF = -// SHAKE128. The same kernel is shared with ML-DSA (sibling at -// mldsa/gpu/metal/mldsa_batch.metal). Byte-equal NIST FIPS 202 KAT. 
-// -// What is INTENTIONALLY missing (and why decap returns NOTIMPL rather -// than a fake shared-secret): -// -// - K-PKE.Decrypt full pipeline byte-equal NIST KAT: ByteDecode_d_v -// of c1, decompress to u, NTT(u), compute s_hat^T * NTT(u), INTT, -// v - that, compress to m'. -// - The Fujisaki-Okamoto re-encrypt step (G(m' || h_pk) → (K', r'), -// K-PKE.Encrypt(ek, m', r') = c', constant-time c == c' compare, -// final K = K' on match else J(z || c)). -// - SHA3-256 for H and SHA3-512 for G (same Keccak-f, different rate -// and delimiter 0x06; the SHAKE kernels here are the prerequisite). -// -// The full-decap port to Metal is a multi-day effort. This file lands -// the cryptographically-correct primitives (NTT, SHAKE) that are the -// building blocks; the orchestration kernel returns NOTIMPL until those -// primitives are wired into a byte-equal full FIPS-203 decap. -// -// References: -// - FIPS 203 (ML-KEM, August 2024) -// - cloudflare/circl/kem/mlkem/mlkem768 (Apache-2) -// - pq-crystals/kyber reference C (public domain) - -#include -using namespace metal; - -// ============================================================================= -// ML-KEM-768 parameters (FIPS 203 §4 Table 2) -// ============================================================================= - -constant int16_t MLKEM_Q = 3329; -constant int16_t MLKEM_QINV = 3327; // -q^{-1} mod 2^16 -constant uint32_t MLKEM_N = 256; -constant uint32_t MLKEM_K = 3; // ML-KEM-768 (level 3) - -// ============================================================================= -// Modular arithmetic (Montgomery, R = 2^16; Barrett finalisation) -// ============================================================================= - -// Montgomery reduction: returns a * R^{-1} mod q for any signed a. -// FIPS-203 reference: kyber/ref/reduce.c::montgomery_reduce. 
-inline int16_t mlkem_mont(int32_t a) { - int16_t t = (int16_t)((int16_t)a * MLKEM_QINV); - int32_t u = (int32_t)t * (int32_t)MLKEM_Q; - return (int16_t)((a - u) >> 16); -} - -// Barrett: returns a mod q for a in [-q, q] roughly (signed). -inline int16_t mlkem_barrett(int16_t a) { - int16_t t = (int16_t)(((int32_t)a * 20159) >> 26); - t = a - t * MLKEM_Q; - if (t >= MLKEM_Q) t -= MLKEM_Q; - if (t < 0) t += MLKEM_Q; - return t; -} - -// ============================================================================= -// Kyber zetas (Montgomery form) — primitive 256-th roots of unity mod q. -// First 128 entries match cloudflare/circl/kem/mlkem/internal/kyber/zetas.go -// ============================================================================= - -constant int16_t KYBER_ZETAS[128] = { - -1044, -758, -359, -1517, 1493, 1422, 287, 202, - -171, 622, 1577, 182, 962, -1202, -1474, 1468, - 573, -1325, 264, 383, -829, 1458, -1602, -130, - -681, 1017, 732, 608, -1542, 411, -205, -1571, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, -1618, -1162, 126, 1469, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -1103, 430, 555, 843, -1251, 871, 1550, 105, - 422, 587, 177, -235, -291, -460, 1574, 1653, - -246, 778, 1159, -147, -777, 1483, -602, 1119, - -1590, 644, -872, 349, 418, 329, -156, -75, - 817, 1097, 603, 610, 1322, -1285, -1465, 384, - -1215, -136, 1218, -1335, -874, 220, -1187, -1659, - -1185, -1530, -1278, 794, -1510, -854, -870, 478, - -108, -308, 996, 991, 958, -1460, 1522, 1628 -}; - -// ============================================================================= -// Forward NTT (in-place, bit-reversed output) -// - one threadgroup per polynomial; tid runs over butterflies in stage -// - matches kyber/ref/ntt.c::ntt -// ============================================================================= - -kernel void mlkem_ntt_forward( - device int16_t* polys [[buffer(0)]], // [batch * 256] - 
constant uint& batch [[buffer(1)]], - uint tid [[thread_index_in_threadgroup]], - uint gid [[threadgroup_position_in_grid]], - uint tpg [[threads_per_threadgroup]], - threadgroup int16_t* s [[threadgroup(0)]]) -{ - if (gid >= batch) return; - device int16_t* poly = polys + gid * MLKEM_N; - for (uint i = tid; i < MLKEM_N; i += tpg) s[i] = poly[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - uint k = 1; - for (uint len = 128; len >= 2; len >>= 1) { - uint num_pairs = MLKEM_N / (2 * len); - for (uint p = tid; p < num_pairs * len; p += tpg) { - uint pair_idx = p / len; - uint within = p % len; - uint start = 2 * len * pair_idx + within; - int16_t zeta = KYBER_ZETAS[k + pair_idx]; - int16_t t = mlkem_mont((int32_t)zeta * (int32_t)s[start + len]); - s[start + len] = s[start] - t; - s[start] = s[start] + t; - } - k += num_pairs; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - for (uint i = tid; i < MLKEM_N; i += tpg) poly[i] = s[i]; -} - -// ============================================================================= -// Inverse NTT (in-place, natural-order output) -// f = mont * (256)^{-1} = 1441 -// ============================================================================= - -kernel void mlkem_ntt_inverse( - device int16_t* polys [[buffer(0)]], - constant uint& batch [[buffer(1)]], - uint tid [[thread_index_in_threadgroup]], - uint gid [[threadgroup_position_in_grid]], - uint tpg [[threads_per_threadgroup]], - threadgroup int16_t* s [[threadgroup(0)]]) -{ - if (gid >= batch) return; - device int16_t* poly = polys + gid * MLKEM_N; - for (uint i = tid; i < MLKEM_N; i += tpg) s[i] = poly[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - int k = 127; - for (uint len = 2; len <= 128; len <<= 1) { - uint num_pairs = MLKEM_N / (2 * len); - for (uint p = tid; p < num_pairs * len; p += tpg) { - uint pair_idx = p / len; - uint within = p % len; - uint start = 2 * len * pair_idx + within; - int16_t zeta = -KYBER_ZETAS[k - (int)pair_idx]; - 
int16_t t = s[start]; - s[start] = mlkem_barrett(t + s[start + len]); - s[start + len] = mlkem_mont((int32_t)zeta * (int32_t)(s[start + len] - t)); - } - k -= num_pairs; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int16_t f = 1441; - for (uint i = tid; i < MLKEM_N; i += tpg) - poly[i] = mlkem_mont((int32_t)f * (int32_t)s[i]); -} - -// ============================================================================= -// SHAKE128 / SHAKE256 (FIPS 202) -// -// Same kernel shape as mldsa/gpu/metal/mldsa_batch.metal. Duplicated rather -// than included because Metal's compile model has no preprocessor cross-file -// inclusion across .metal sources. A future pass extracts to a shared -// .metal.h header. -// ============================================================================= - -constant ulong KECCAK_RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -constant int KECCAK_R[5][5] = { - { 0, 36, 3, 41, 18}, - { 1, 44, 10, 45, 2}, - { 62, 6, 43, 15, 61}, - { 28, 55, 25, 21, 56}, - { 27, 20, 39, 8, 14}, -}; - -inline ulong krot(ulong x, int n) { - n &= 63; - if (n == 0) return x; - return (x << n) | (x >> (64 - n)); -} - -inline void keccakf(thread ulong* a) { - ulong C[5], D[5], B[25]; - for (int round = 0; round < 24; ++round) { - for (int x = 0; x < 5; ++x) - C[x] = a[x] ^ a[x + 5] ^ a[x + 10] ^ a[x + 15] ^ a[x + 20]; - for (int x = 0; x < 5; ++x) - D[x] = C[(x + 4) % 5] ^ krot(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - for (int 
x = 0; x < 5; ++x) - a[x + 5 * y] ^= D[x]; - for (int x = 0; x < 5; ++x) - for (int y = 0; y < 5; ++y) { - int nx = y; - int ny = (2 * x + 3 * y) % 5; - B[nx + 5 * ny] = krot(a[x + 5 * y], KECCAK_R[x][y]); - } - for (int y = 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) row[x] = B[x + 5 * y]; - for (int x = 0; x < 5; ++x) - a[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - a[0] ^= KECCAK_RC[round]; - } -} - -struct ShakeJob { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t output_len; -}; - -inline void shake_one(uint rate, - device const uchar* in, uint inlen, - device uchar* out, uint outlen) { - ulong state[25]; - for (int i = 0; i < 25; ++i) state[i] = 0; - - uint absorbed = 0; - while (inlen - absorbed >= rate) { - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(in[absorbed + w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccakf(state); - absorbed += rate; - } - - uchar block[168]; - for (uint i = 0; i < rate; ++i) block[i] = 0; - uint rem = inlen - absorbed; - for (uint i = 0; i < rem; ++i) block[i] = in[absorbed + i]; - block[rem] = 0x1F; - block[rate - 1] |= 0x80; - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(block[w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccakf(state); - - uint produced = 0; - while (produced < outlen) { - uint take = min(rate, outlen - produced); - for (uint w = 0; w * 8 < take; ++w) { - ulong lane = state[w]; - for (uint b = 0; b < 8 && w * 8 + b < take; ++b) - out[produced + w * 8 + b] = uchar(lane >> (b * 8)); - } - produced += take; - if (produced < outlen) keccakf(state); - } -} - -kernel void mlkem_shake128_jobs( - device const ShakeJob* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num) 
return; - ShakeJob j = jobs[tid]; - shake_one(168, inputs + j.input_offset, j.input_len, - outputs + j.output_offset, j.output_len); -} - -kernel void mlkem_shake256_jobs( - device const ShakeJob* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num) return; - ShakeJob j = jobs[tid]; - shake_one(136, inputs + j.input_offset, j.input_len, - outputs + j.output_offset, j.output_len); -} - -// ============================================================================= -// Honest NOTIMPL kernel for full FIPS-203 decap -// ============================================================================= - -constant uchar MLKEM_RESULT_NOTIMPL = 0xFBu; // (uint8_t)(-5) - -struct MLKEMSecretKey { uchar data[2400]; }; // ML-KEM-768 -struct MLKEMCiphertext { uchar data[1088]; }; // ML-KEM-768 -struct MLKEMSharedSecret { uchar data[32]; }; - -kernel void mlkem_batch_decapsulate( - device const MLKEMSecretKey* secret_keys [[buffer(0)]], - device const MLKEMCiphertext* ciphertexts [[buffer(1)]], - device MLKEMSharedSecret* shared_secrets [[buffer(2)]], - device uchar* results [[buffer(3)]], - constant uint& num_ops [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_ops) return; - volatile uchar a = secret_keys[tid].data[0]; - volatile uchar b = ciphertexts[tid].data[0]; - (void)a; (void)b; - for (int i = 0; i < 32; ++i) shared_secrets[tid].data[i] = 0; - results[tid] = MLKEM_RESULT_NOTIMPL; -} diff --git a/mlkem/gpu/metal/mlkem_batch_driver.mm b/mlkem/gpu/metal/mlkem_batch_driver.mm deleted file mode 100644 index 076df18..0000000 --- a/mlkem/gpu/metal/mlkem_batch_driver.mm +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for ML-KEM kernels (FIPS 203). -// -// Same shape as the ML-DSA driver. 
Exposes: -// - mlkem_batch_decapsulate_metal — honest NOTIMPL, zeroes shared_secrets -// - mlkem_shake128_metal — FIPS-202 SHAKE128, byte-equal NIST KAT -// - mlkem_shake256_metal — FIPS-202 SHAKE256, byte-equal NIST KAT - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include - -namespace { - -struct ShakeJobHost { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t output_len; -}; - -int dispatch_shake(NSString* fn_name, - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - const uint32_t* output_lens, - size_t n, - uint8_t* outputs, - const char* metallib_path) { - if (n == 0) return 0; - if (!input_lens || !output_lens || !outputs || !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:metallib_path]]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:fn_name]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - size_t total_in = 0, total_out = 0; - std::vector jobs(n); - for (size_t i = 0; i < n; ++i) { - jobs[i].input_offset = input_offsets ? input_offsets[i] : 0u; - jobs[i].input_len = input_lens[i]; - jobs[i].output_offset = (uint32_t)total_out; - jobs[i].output_len = output_lens[i]; - total_in = jobs[i].input_offset + jobs[i].input_len > total_in - ? jobs[i].input_offset + jobs[i].input_len : total_in; - total_out += output_lens[i]; - } - if (total_in == 0) total_in = 1; - if (total_out == 0) total_out = 1; - - id jobs_buf = [device newBufferWithBytes:jobs.data() - length:n * sizeof(ShakeJobHost) - options:MTLResourceStorageModeShared]; - id in_buf = [device newBufferWithBytes:(inputs ? 
inputs : (const uint8_t*)"\0") - length:total_in - options:MTLResourceStorageModeShared]; - id out_buf = [device newBufferWithLength:total_out - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:in_buf offset:0 atIndex:1]; - [enc setBuffer:out_buf offset:0 atIndex:2]; - [enc setBuffer:n_buf offset:0 atIndex:3]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - if (tg_w > n) tg_w = n; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs, [out_buf contents], total_out); - } - return 0; -} - -} // namespace - -extern "C" int mlkem_batch_decapsulate_metal( - const uint8_t* secret_keys, - const uint8_t* ciphertexts, - size_t n, - uint8_t* shared_secrets, - uint8_t* results, - const char* metallib_path) { - - if (n == 0) return 0; - if (!secret_keys || !ciphertexts || !shared_secrets || !results || !metallib_path) { - return -1; - } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:metallib_path]]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"mlkem_batch_decapsulate"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id sks_buf = [device 
newBufferWithBytes:secret_keys - length:n * 2400 - options:MTLResourceStorageModeShared]; - id cts_buf = [device newBufferWithBytes:ciphertexts - length:n * 1088 - options:MTLResourceStorageModeShared]; - id ss_buf = [device newBufferWithLength:n * 32 - options:MTLResourceStorageModeShared]; - id res_buf = [device newBufferWithLength:n - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:sks_buf offset:0 atIndex:0]; - [enc setBuffer:cts_buf offset:0 atIndex:1]; - [enc setBuffer:ss_buf offset:0 atIndex:2]; - [enc setBuffer:res_buf offset:0 atIndex:3]; - [enc setBuffer:n_buf offset:0 atIndex:4]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - if (tg_w > n) tg_w = n; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(shared_secrets, [ss_buf contents], n * 32); - std::memcpy(results, [res_buf contents], n); - } - return 0; -} - -extern "C" int mlkem_shake128_metal( - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - const uint32_t* output_lens, - size_t n, - uint8_t* outputs, - const char* metallib_path) { - return dispatch_shake(@"mlkem_shake128_jobs", - inputs, input_offsets, input_lens, output_lens, - n, outputs, metallib_path); -} - -extern "C" int mlkem_shake256_metal( - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - const uint32_t* output_lens, - size_t n, - uint8_t* outputs, - const char* metallib_path) { - return dispatch_shake(@"mlkem_shake256_jobs", 
- inputs, input_offsets, input_lens, output_lens, - n, outputs, metallib_path); -} - -#endif // __APPLE__ && __OBJC__ diff --git a/mlkem/gpu/wgsl/mlkem.wgsl b/mlkem/gpu/wgsl/mlkem.wgsl deleted file mode 100644 index 1062381..0000000 --- a/mlkem/gpu/wgsl/mlkem.wgsl +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// ML-KEM-768 (FIPS 203) batch decapsulation in WGSL. -// NTT-based polynomial arithmetic over Z_q[x]/(x^n+1), q=3329, n=256. -// Each thread decapsulates one ciphertext. - -@group(0) @binding(0) var sk_data: array; // Secret keys (packed) -@group(0) @binding(1) var ct_data: array; // Ciphertexts (packed) -@group(0) @binding(2) var out_data: array; // Shared secrets -@group(0) @binding(3) var params: vec4; // params.x = num_ops - -const Q: i32 = 3329; -const N: u32 = 256u; -const K: u32 = 3u; - -const KYBER_ZETAS = array( - 2285, 2571, 2970, 1812, 1493, 1422, 287, 202, - 3158, 622, 1577, 182, 962, 2127, 1855, 1468, - 573, 2004, 264, 383, 2500, 1458, 1727, 3199, - 2648, 1017, 732, 608, 1787, 411, 3124, 1758, - 1223, 652, 2777, 1015, 2036, 1491, 3047, 1785, - 516, 3321, 3009, 2663, 1711, 2167, 126, 1469, - 2476, 3239, 3058, 830, 107, 1908, 3082, 2378, - 2931, 961, 1821, 2604, 448, 2264, 677, 2054, - 2226, 430, 555, 843, 2078, 871, 1550, 105, - 422, 587, 177, 3094, 3038, 2869, 1574, 1653, - 3083, 778, 1159, 3182, 2552, 1483, 2727, 1119, - 1739, 644, 2457, 349, 418, 329, 3173, 3254, - 817, 1097, 603, 610, 1322, 2044, 1864, 384, - 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, - 2144, 1799, 2051, 794, 1819, 2475, 2459, 478, - 3221, 3116, 622, 1097, 2470, 882, 1539, 2392 -); - -fn mont_reduce_16(a: i32) -> i32 { - let t: i32 = (a & 0xFFFF) * 3327; - let u: i32 = (t & 0xFFFF) * Q; - var r: i32 = (a - u) >> 16; - if (r < 0) { r = r + Q; } - if (r >= Q) { r = r - Q; } - return r; -} - -fn barrett_reduce(a: i32) -> i32 { - let t = (a * 20159) >> 26; - var r = a - t * Q; - if (r >= Q) { 
r = r - Q; } - if (r < 0) { r = r + Q; } - return r; -} - -fn kyber_ntt(poly: ptr>) { - var k = 0u; - var len = 128u; - loop { - if (len < 2u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - k = k + 1u; - let z = KYBER_ZETAS[k]; - var j = start; - loop { - if (j >= start + len) { break; } - let t = mont_reduce_16(z * (*poly)[j + len]); - (*poly)[j + len] = (*poly)[j] - t; - (*poly)[j] = (*poly)[j] + t; - j = j + 1u; - } - start = start + 2u * len; - } - len = len >> 1u; - } -} - -fn kyber_inv_ntt(poly: ptr>) { - let f: i32 = 1441; // Montgomery form of 256^{-1} - var k = 127u; - var len = 2u; - loop { - if (len > 128u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - let z = Q - KYBER_ZETAS[k]; - k = k - 1u; - var j = start; - loop { - if (j >= start + len) { break; } - let t = (*poly)[j]; - (*poly)[j] = t + (*poly)[j + len]; - (*poly)[j + len] = t - (*poly)[j + len]; - (*poly)[j + len] = mont_reduce_16(z * (*poly)[j + len]); - j = j + 1u; - } - start = start + 2u * len; - } - len = len << 1u; - } - for (var i = 0u; i < 256u; i = i + 1u) { - (*poly)[i] = mont_reduce_16(f * (*poly)[i]); - } -} - -fn read_byte_sk(base: u32, idx: u32) -> u32 { - let word_idx = (base + idx) >> 2u; - let byte_pos = (base + idx) & 3u; - return (sk_data[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -fn read_byte_ct(base: u32, idx: u32) -> u32 { - let word_idx = (base + idx) >> 2u; - let byte_pos = (base + idx) & 3u; - return (ct_data[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -@compute @workgroup_size(64) -fn mlkem_decapsulate_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - let sk_base = tid * 2400u; - let ct_base = tid * 1088u; - - // Decode s_hat (NTT domain, k=3 polynomials, 12 bits per coeff) - var s_hat: array, 3>; - for (var p = 0u; p < 3u; p = p + 1u) { - let sp = sk_base + p * 384u; - for (var i = 0u; i < 256u; i = i + 1u) { - let idx = i * 3u / 2u; - if ((i & 1u) == 1u) { - 
s_hat[p][i] = i32((read_byte_sk(sp, idx) >> 4u) | (read_byte_sk(sp, idx + 1u) << 4u)) & 0xFFF; - } else { - s_hat[p][i] = i32(read_byte_sk(sp, idx) | (read_byte_sk(sp, idx + 1u) << 8u)) & 0xFFF; - } - } - } - - // Decode u (compressed, 10 bits per coeff) and decompress - var u_hat: array, 3>; - for (var p = 0u; p < 3u; p = p + 1u) { - let up = ct_base + p * 320u; - for (var i = 0u; i < 256u; i = i + 4u) { - let idx = (i / 4u) * 5u; - let b0 = read_byte_ct(up, idx); - let b1 = read_byte_ct(up, idx + 1u); - let b2 = read_byte_ct(up, idx + 2u); - let b3 = read_byte_ct(up, idx + 3u); - let b4 = read_byte_ct(up, idx + 4u); - - u_hat[p][i] = i32((b0 | ((b1 & 0x03u) << 8u)) * u32(Q) + 512u) >> 10; - u_hat[p][i + 1u] = i32(((b1 >> 2u) | ((b2 & 0x0Fu) << 6u)) * u32(Q) + 512u) >> 10; - u_hat[p][i + 2u] = i32(((b2 >> 4u) | ((b3 & 0x3Fu) << 4u)) * u32(Q) + 512u) >> 10; - u_hat[p][i + 3u] = i32(((b3 >> 6u) | (b4 << 2u)) * u32(Q) + 512u) >> 10; - } - kyber_ntt(&u_hat[p]); - } - - // Inner product: s_hat^T * u_hat - var mp: array; - for (var i = 0u; i < 256u; i = i + 1u) { mp[i] = 0; } - - for (var p = 0u; p < 3u; p = p + 1u) { - for (var i = 0u; i < 256u; i = i + 1u) { - mp[i] = barrett_reduce(mp[i] + mont_reduce_16(s_hat[p][i] * u_hat[p][i])); - } - } - - kyber_inv_ntt(&mp); - - // Decode v (4 bits per coeff) and decompress - let vp = ct_base + 960u; // 3 * 320 - var out_words: array; - for (var i = 0u; i < 8u; i = i + 1u) { out_words[i] = 0u; } - - for (var i = 0u; i < 256u; i = i + 1u) { - let byte_idx = i / 2u; - let v_raw = read_byte_ct(vp, byte_idx); - var v_coeff: i32; - if ((i & 1u) == 0u) { - v_coeff = i32((v_raw & 0x0Fu) * u32(Q) + 8u) >> 4; - } else { - v_coeff = i32((v_raw >> 4u) * u32(Q) + 8u) >> 4; - } - - var diff = v_coeff - mp[i]; - if (diff < 0) { diff = diff + Q; } - let t = u32(diff) * 2u + u32(Q / 2); - let bit = (t / u32(Q)) & 1u; - out_words[i / 32u] = out_words[i / 32u] | (bit << (i % 32u)); - } - - let out_base = tid * 8u; - for (var i = 0u; i < 8u; i = 
i + 1u) { - out_data[out_base + i] = out_words[i]; - } -} diff --git a/modexp/gpu/cuda/modexp_karatsuba.cu b/modexp/gpu/cuda/modexp_karatsuba.cu deleted file mode 100644 index c530c35..0000000 --- a/modexp/gpu/cuda/modexp_karatsuba.cu +++ /dev/null @@ -1,104 +0,0 @@ -// luxcpp/crypto: Karatsuba multiplication kernel (CUDA / nvcc-compatible C++). -// Copyright (C) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// One-shot full multi-precision multiply r[2n] = x[n] * y[n] for n in -// [16, 64] limbs (1024..4096 bits). Native 64-bit integer arithmetic is -// available on CUDA so we use uint64_t and __umul64hi directly. -// -// Parallelization (matching the Metal kernel's host orchestration): -// * One block per multiplication. -// * Block-cooperative computation of three half-sized sub-products of -// Karatsuba (z0, z1', z2). The host driver issues three child kernels -// in parallel (different streams) when the Karatsuba split is profitable; -// the kernel below is the schoolbook base case dispatched per sub-product. -// -// This file compiles as plain C++ when CRYPTO_ENABLE_CUDA is OFF (host -// polyfill: same kernel body, single-threaded). The polyfill produces -// byte-identical output to the GPU path by construction. - -#if defined(__CUDACC__) - #define LUX_KKERNEL extern "C" __global__ - #define LUX_KDEVICE __device__ -#else - // Host polyfill -- compiled as plain C++ when CUDA isn't enabled. - #define LUX_KKERNEL extern "C" - #define LUX_KDEVICE static inline - #include -#endif - -#if !defined(__CUDACC__) -// Host umul64hi polyfill. -LUX_KDEVICE uint64_t lux_umul64hi(uint64_t a, uint64_t b) -{ - // Use the standard 32x32 decomposition. 
- const uint64_t a_lo = (uint32_t)a; - const uint64_t a_hi = a >> 32; - const uint64_t b_lo = (uint32_t)b; - const uint64_t b_hi = b >> 32; - - const uint64_t ll = a_lo * b_lo; - const uint64_t lh = a_lo * b_hi; - const uint64_t hl = a_hi * b_lo; - const uint64_t hh = a_hi * b_hi; - - const uint64_t mid = (ll >> 32) + (lh & 0xFFFFFFFFULL) + (hl & 0xFFFFFFFFULL); - return hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -} -#define lux_umul_hi lux_umul64hi -#else -#include -#define lux_umul_hi __umul64hi -#endif - -// Full schoolbook product: r[2n] = x[n] * y[n]. -// Caller-allocated r is zero-initialised before launch (host driver clears it). -// -// Block strategy: one block performs the full product. Thread tid in [0,n) -// computes one j-row of the partial product. Carry chain across rows is -// handled by serializing the column adds inside a single warp (warp-level -// __shfl propagation), but for simplicity and byte-equivalence with CPU -// we use the canonical row-major schoolbook here. Performance optimizations -// (Comba diagonal, warp-level carry) live in a future kernel; this version -// targets correctness. -LUX_KKERNEL void modexp_kara_mul( - const uint64_t* __restrict__ x, - const uint64_t* __restrict__ y, - uint64_t* __restrict__ r, - unsigned n) -{ -#if defined(__CUDACC__) - if (blockIdx.x != 0 || threadIdx.x != 0) return; -#endif - - // Clear r. - for (unsigned k = 0; k < 2 * n; ++k) r[k] = 0; - - // Schoolbook: one block, one thread does the work. The Karatsuba split - // (3 sub-multiplies) is performed by the host driver issuing 3 launches - // in parallel; this kernel is the base case those launches use. - for (unsigned j = 0; j < n; ++j) - { - uint64_t carry = 0; - for (unsigned i = 0; i < n; ++i) - { -#if defined(__CUDACC__) - // 64x64 -> 128 bits via two intrinsics. - const uint64_t lo = x[i] * y[j]; - const uint64_t hi = lux_umul_hi(x[i], y[j]); -#else - // Host polyfill: __uint128_t is a GCC/Clang extension. 
- const __uint128_t prod = (__uint128_t)x[i] * (__uint128_t)y[j]; - const uint64_t lo = (uint64_t)prod; - const uint64_t hi = (uint64_t)(prod >> 64); -#endif - const uint64_t s1 = lo + r[i + j]; - const uint64_t c1 = (s1 < lo) ? 1 : 0; - const uint64_t s2 = s1 + carry; - const uint64_t c2 = (s2 < s1) ? 1 : 0; - r[i + j] = s2; - carry = hi + c1 + c2; // hi <= 2^64 - 2, so carries don't overflow - } - r[j + n] = carry; - } -} diff --git a/modexp/gpu/metal/modexp_karatsuba.metal b/modexp/gpu/metal/modexp_karatsuba.metal deleted file mode 100644 index ba6d9b2..0000000 --- a/modexp/gpu/metal/modexp_karatsuba.metal +++ /dev/null @@ -1,151 +0,0 @@ -// luxcpp/crypto: Karatsuba multiplication kernel (Metal Shading Language). -// Copyright (C) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// One-shot full multi-precision multiply r[2n] = x[n] * y[n] for power-of-2 -// n in [16, 64] limbs (1024..4096 bits). -// -// Parallelization strategy (matching the CPU body): -// * One threadgroup per multiplication. -// * threads cooperate on the three half-sized sub-products of Karatsuba: -// z0 = x_lo * y_lo -// z2 = x_hi * y_hi -// z1 = (x_hi+x_lo)*(y_hi+y_lo) - z2 - z0 -// * Each sub-product is a half-sized schoolbook multiply done by -// n/2 lanes in parallel; carry propagation is serial within each lane -// row. Total work-per-lane is O(n) limb-mul operations vs the CPU's -// O(n^1.585); we keep this simple because at n=64 the parallel sweep -// fits in a single threadgroup (32-lane wavefront on Apple Silicon). -// -// Byte-equivalence with the CPU body (cevm::crypto::karatsuba::kmul) is the -// hard correctness contract. The host driver picks up the result as a -// uint64 little-endian limb buffer. -// -// Metal does not have native 64-bit integer arithmetic on most GPU families; -// we implement uint64 multiply-add via two uint32 halves with the standard -// schoolbook algorithm. This trades parallelism for portability. 
- -#include -using namespace metal; - -// 64-bit limb stored as a pair of 32-bit halves (lo, hi). -struct U64 { uint lo; uint hi; }; - -inline U64 u64_make(uint lo, uint hi) { return {lo, hi}; } - -inline U64 u64_add(U64 a, U64 b, thread bool& carry_out) { - uint lo = a.lo + b.lo; - uint clo = (lo < a.lo) ? 1u : 0u; - uint hi = a.hi + b.hi + clo; - carry_out = (hi < a.hi) || (clo == 1u && hi == a.hi); - return {lo, hi}; -} - -// 32x32 -> 64 multiplication (Metal native). -inline U64 u32x32_to_u64(uint a, uint b) { - ulong p = ulong(a) * ulong(b); - return {uint(p & 0xFFFFFFFFu), uint(p >> 32)}; -} - -// 64x64 -> 128 multiplication: returns (lo64, hi64) as a 4-tuple. -inline void u64_mul_full(U64 a, U64 b, thread U64& lo, thread U64& hi) { - // Decompose: a = (a.hi << 32) | a.lo, similarly b. - // Cross-products: - // ll = a.lo * b.lo -> 64 bits, low half of result - // lh = a.lo * b.hi -> 64 bits, mid - // hl = a.hi * b.lo -> 64 bits, mid - // hh = a.hi * b.hi -> 64 bits, high half - U64 ll = u32x32_to_u64(a.lo, b.lo); - U64 lh = u32x32_to_u64(a.lo, b.hi); - U64 hl = u32x32_to_u64(a.hi, b.lo); - U64 hh = u32x32_to_u64(a.hi, b.hi); - - // Combine: result = ll + (lh + hl) << 32 + hh << 64. - // lo64 = ll.lo | ((ll.hi + lh.lo + hl.lo) << 32) - // The mid sum lh+hl can exceed 64 bits; track carry into hi64. - uint mid_lo = ll.hi + lh.lo; - uint c1 = (mid_lo < ll.hi) ? 1u : 0u; - uint mid_lo2 = mid_lo + hl.lo; - uint c2 = (mid_lo2 < mid_lo) ? 1u : 0u; - - uint mid_hi = lh.hi + hl.hi + c1 + c2; // can't overflow uint here - - lo = u64_make(ll.lo, mid_lo2); - // hi64 = hh + (mid_hi-portion) + carry from mid bits over 64. - uint hi_lo = hh.lo + mid_hi; - uint c3 = (hi_lo < hh.lo) ? 1u : 0u; - uint hi_hi = hh.hi + c3; - hi = u64_make(hi_lo, hi_hi); -} - -// Schoolbook full product: r[2n] = x[n] * y[n], single-thread-per-row. -// Threadgroup limit at the dispatch level: n=64 limbs (full 4096-bit RSA case). 
-// Each thread tid in [0, n) computes one diagonal of the convolution. -// Carries flow vertically; we serialize the row reductions via threadgroup -// barriers so the kernel stays correct without atomics. -// -// Actual production performance comes from the host orchestration: at n=64 -// we dispatch 3 such kernels concurrently to compute z0, z1', z2 in three -// command buffers, then a final fix-up kernel sums them per Karatsuba. -// -// For simplicity, this single-kernel form computes the full schoolbook -// product (correct, deterministic). The Karatsuba split is in the host -// driver; the 3 kernel invocations are themselves trivial schoolbook and -// thus byte-identical to the CPU base case at n=4 (THRESHOLD). -kernel void modexp_kara_mul( - device const U64* x [[ buffer(0) ]], - device const U64* y [[ buffer(1) ]], - device U64* r [[ buffer(2) ]], - constant uint& n [[ buffer(3) ]], - threadgroup U64* tmp [[ threadgroup(0) ]], - uint tid [[ thread_position_in_threadgroup ]], - uint tcount [[ threads_per_threadgroup ]]) -{ - // Per-row schoolbook: each thread handles one j-row computing - // r[i+j..i+j+n] += x[i] * y[j] - // for j == tid. Up to n threads cooperate. The barrier ensures all - // partial products are visible before the next reduction step. - - // Initialize result to zero (one-shot, bounded n <= MAX_LIMBS). - if (tid < 2 * n) { - r[tid] = u64_make(0u, 0u); - if (tid + 64 < 2 * n) r[tid + 64] = u64_make(0u, 0u); - } - threadgroup_barrier(mem_flags::mem_device); - - if (tid >= n) return; // only the first n threads do work - - // Each thread j computes its row contribution serially and accumulates - // into r[j..j+n]. To keep the writes correct without atomics, threads - // are serialized: thread j waits for thread j-1's contributions to - // settle by tag-based phase synchronization. - // - // Simpler approach: serialize the rows by doing each j sequentially - // using a single thread (tid==0). 
Coarse-grained but byte-correct, - // matches the CPU schoolbook bit-for-bit. The Karatsuba parallelism - // comes from dispatching 3 such kernels concurrently from the host. - if (tid == 0) { - for (uint j = 0; j < n; ++j) { - U64 carry = u64_make(0u, 0u); - for (uint i = 0; i < n; ++i) { - U64 plo, phi; - u64_mul_full(x[i], y[j], plo, phi); - - // r[i+j] += plo + carry - bool c1 = false; - U64 sum1 = u64_add(r[i + j], plo, c1); - bool c2 = false; - U64 sum2 = u64_add(sum1, carry, c2); - r[i + j] = sum2; - - // carry = phi + (c1 + c2) - U64 c_word = u64_make((c1 ? 1u : 0u) + (c2 ? 1u : 0u), 0u); - bool c3 = false; - carry = u64_add(phi, c_word, c3); - // c3 always false here (phi < 2^64 - 2 in bound). - } - r[j + n] = carry; - } - } - threadgroup_barrier(mem_flags::mem_device); -} diff --git a/modexp/gpu/metal/modular.metal b/modexp/gpu/metal/modular.metal deleted file mode 100644 index e1a749e..0000000 --- a/modexp/gpu/metal/modular.metal +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright (c) 2024-2026 Lux Partners Limited -// SPDX-License-Identifier: BSD-3-Clause -// -// Modular Arithmetic - High-Performance Metal Implementation -// Montgomery and Barrett reduction for finite field operations. 
- -#include -using namespace metal; - -// ============================================================================ -// 64-bit Unsigned Integer (emulated) -// ============================================================================ - -struct U64 { - uint lo; - uint hi; -}; - -inline U64 u64_from(ulong v) { - return {uint(v & 0xFFFFFFFFu), uint(v >> 32)}; -} - -inline ulong u64_to(U64 v) { - return ulong(v.lo) | (ulong(v.hi) << 32); -} - -inline U64 u64_zero() { return {0u, 0u}; } -inline U64 u64_one() { return {1u, 0u}; } - -inline bool u64_eq(U64 a, U64 b) { - return a.lo == b.lo && a.hi == b.hi; -} - -inline bool u64_lt(U64 a, U64 b) { - if (a.hi < b.hi) return true; - if (a.hi > b.hi) return false; - return a.lo < b.lo; -} - -inline bool u64_gte(U64 a, U64 b) { - return !u64_lt(a, b); -} - -inline U64 u64_add(U64 a, U64 b) { - uint lo = a.lo + b.lo; - uint carry = (lo < a.lo) ? 1u : 0u; - uint hi = a.hi + b.hi + carry; - return {lo, hi}; -} - -inline U64 u64_sub(U64 a, U64 b) { - uint borrow = (a.lo < b.lo) ? 1u : 0u; - uint lo = a.lo - b.lo; - uint hi = a.hi - b.hi - borrow; - return {lo, hi}; -} - -// 32x32 -> 64 bit multiplication -inline U64 mul32_to_64(uint a, uint b) { - uint a_lo = a & 0xFFFFu; - uint a_hi = a >> 16u; - uint b_lo = b & 0xFFFFu; - uint b_hi = b >> 16u; - - uint p0 = a_lo * b_lo; - uint p1 = a_lo * b_hi; - uint p2 = a_hi * b_lo; - uint p3 = a_hi * b_hi; - - uint mid = p1 + p2; - uint mid_carry = (mid < p1) ? 0x10000u : 0u; - - uint lo = p0 + (mid << 16u); - uint carry = (lo < p0) ? 
1u : 0u; - uint hi = p3 + (mid >> 16u) + mid_carry + carry; - - return {lo, hi}; -} - -// 64x64 -> 128 bit multiplication (returns low 64 bits and high 64 bits) -inline void mul64_to_128(U64 a, U64 b, thread U64& lo, thread U64& hi) { - U64 p0 = mul32_to_64(a.lo, b.lo); - U64 p1 = mul32_to_64(a.lo, b.hi); - U64 p2 = mul32_to_64(a.hi, b.lo); - U64 p3 = mul32_to_64(a.hi, b.hi); - - // lo = p0.lo, carry from p0.hi + p1.lo + p2.lo - lo.lo = p0.lo; - - uint sum1 = p0.hi + p1.lo; - uint c1 = (sum1 < p0.hi) ? 1u : 0u; - uint sum2 = sum1 + p2.lo; - uint c2 = (sum2 < sum1) ? 1u : 0u; - lo.hi = sum2; - - // hi = p3 + p1.hi + p2.hi + carries - uint carry_sum = c1 + c2 + p1.hi + p2.hi; - hi = u64_add(p3, {carry_sum, 0u}); -} - -// ============================================================================ -// Modular Arithmetic -// ============================================================================ - -// Modular addition: (a + b) mod q -inline U64 mod_add(U64 a, U64 b, U64 q) { - U64 sum = u64_add(a, b); - // Check for overflow or sum >= q - bool overflow = (sum.hi < a.hi) || (sum.hi == a.hi && sum.lo < a.lo); - if (overflow || u64_gte(sum, q)) { - sum = u64_sub(sum, q); - } - return sum; -} - -// Modular subtraction: (a - b) mod q -inline U64 mod_sub(U64 a, U64 b, U64 q) { - if (u64_lt(a, b)) { - return u64_sub(u64_add(a, q), b); - } - return u64_sub(a, b); -} - -// Modular negation: -a mod q -inline U64 mod_neg(U64 a, U64 q) { - if (u64_eq(a, u64_zero())) { - return u64_zero(); - } - return u64_sub(q, a); -} - -// Barrett reduction for 128-bit product -inline U64 barrett_reduce_wide(U64 lo, U64 hi, U64 q, U64 mu) { - // Approximate quotient: q_hat = (hi * mu) >> 64 - U64 q_lo, q_hi; - mul64_to_128(hi, mu, q_lo, q_hi); - - // r = (lo, hi) - q_hat * q - U64 prod_lo, prod_hi; - mul64_to_128(q_hi, q, prod_lo, prod_hi); - - U64 r = u64_sub(lo, prod_lo); - - // Correction steps - while (u64_gte(r, q)) { - r = u64_sub(r, q); - } - - return r; -} - -// Modular 
multiplication using Barrett -inline U64 mod_mul(U64 a, U64 b, U64 q, U64 mu) { - U64 lo, hi; - mul64_to_128(a, b, lo, hi); - return barrett_reduce_wide(lo, hi, q, mu); -} - -// Montgomery reduction: compute aR^-1 mod q -inline U64 mont_reduce(U64 lo, U64 hi, U64 q, U64 m0_inv) { - // m = lo * m0_inv (mod 2^64) - U64 m_lo, m_hi; - mul64_to_128(lo, m0_inv, m_lo, m_hi); - - // t = (lo + m*q) >> 64 - U64 prod_lo, prod_hi; - mul64_to_128(m_lo, q, prod_lo, prod_hi); - - // Add lo + prod and take high part - U64 sum = u64_add(lo, prod_lo); - uint carry = u64_lt(sum, lo) ? 1u : 0u; - - U64 result = u64_add(hi, prod_hi); - result = u64_add(result, {carry, 0u}); - - // Conditional subtraction - if (u64_gte(result, q)) { - result = u64_sub(result, q); - } - - return result; -} - -// Montgomery multiplication -inline U64 mont_mul(U64 a, U64 b, U64 q, U64 m0_inv) { - U64 lo, hi; - mul64_to_128(a, b, lo, hi); - return mont_reduce(lo, hi, q, m0_inv); -} - -// Modular exponentiation (square-and-multiply) -inline U64 mod_pow(U64 base, U64 exp, U64 q, U64 mu) { - U64 result = u64_one(); - U64 b = base; - - // Process low 32 bits - uint e = exp.lo; - while (e > 0) { - if (e & 1u) { - result = mod_mul(result, b, q, mu); - } - b = mod_mul(b, b, q, mu); - e >>= 1u; - } - - // Process high 32 bits - e = exp.hi; - while (e > 0) { - if (e & 1u) { - result = mod_mul(result, b, q, mu); - } - b = mod_mul(b, b, q, mu); - e >>= 1u; - } - - return result; -} - -// Modular inverse using Fermat's little theorem: a^(q-2) mod q -inline U64 mod_inv(U64 a, U64 q, U64 mu) { - U64 exp = u64_sub(q, {2u, 0u}); - return mod_pow(a, exp, q, mu); -} - -// ============================================================================ -// Batch Kernels -// ============================================================================ - -kernel void batch_mod_add( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant 
uint& n [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - c[gid] = mod_add(a[gid], b[gid], q); -} - -kernel void batch_mod_sub( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant uint& n [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - c[gid] = mod_sub(a[gid], b[gid], q); -} - -kernel void batch_mod_mul( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& mu [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - c[gid] = mod_mul(a[gid], b[gid], q, mu); -} - -kernel void batch_mont_mul( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& m0_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - c[gid] = mont_mul(a[gid], b[gid], q, m0_inv); -} - -kernel void batch_mod_pow( - device const U64* bases [[buffer(0)]], - device const U64* exps [[buffer(1)]], - device U64* results [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& mu [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - results[gid] = mod_pow(bases[gid], exps[gid], q, mu); -} - -kernel void batch_mod_inv( - device const U64* a [[buffer(0)]], - device U64* a_inv [[buffer(1)]], - constant U64& q [[buffer(2)]], - constant U64& mu [[buffer(3)]], - constant uint& n [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - a_inv[gid] = mod_inv(a[gid], q, mu); -} - -kernel void batch_to_mont( - device const U64* a [[buffer(0)]], - device U64* a_mont [[buffer(1)]], - constant U64& q [[buffer(2)]], - constant U64& r2 
[[buffer(3)]], - constant U64& m0_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - a_mont[gid] = mont_mul(a[gid], r2, q, m0_inv); -} - -kernel void batch_from_mont( - device const U64* a_mont [[buffer(0)]], - device U64* a [[buffer(1)]], - constant U64& q [[buffer(2)]], - constant U64& m0_inv [[buffer(3)]], - constant uint& n [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - a[gid] = mont_mul(a_mont[gid], u64_one(), q, m0_inv); -} diff --git a/modexp/gpu/wgsl/modexp_karatsuba.wgsl b/modexp/gpu/wgsl/modexp_karatsuba.wgsl deleted file mode 100644 index c2f9b25..0000000 --- a/modexp/gpu/wgsl/modexp_karatsuba.wgsl +++ /dev/null @@ -1,114 +0,0 @@ -// luxcpp/crypto: Karatsuba multiplication kernel (WebGPU Shading Language). -// Copyright (C) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL has no native uint64; we represent each 64-bit limb as a vec2 -// (lo, hi). 64-bit arithmetic is implemented via 32x32 -> 64-bit native ops. -// -// One-shot full multi-precision multiply r[2n] = x[n] * y[n] for n in -// [16, 64] limbs (1024..4096 bits). Threadgroup-cooperative for the three -// Karatsuba sub-products (host driver dispatches three workgroups concurrently); -// the kernel itself is the schoolbook base case. - -const MAX_LIMBS : u32 = 64u; - -struct U64 { lo: u32, hi: u32 }; - -@group(0) @binding(0) var x : array; -@group(0) @binding(1) var y : array; -@group(0) @binding(2) var r : array; -@group(0) @binding(3) var n : u32; - -// 32 x 32 -> 64 multiply, returned as U64 {lo, hi}. -fn u32x32_to_u64(a: u32, b: u32) -> U64 { - // Use WGSL's native 64-bit-by-multiplication-decomposition. The - // u32 * u32 -> u32 result truncates the high 32 bits; we recompute via - // 16-bit splits to recover the full 64-bit product portably. 
- let a_lo : u32 = a & 0xFFFFu; - let a_hi : u32 = a >> 16u; - let b_lo : u32 = b & 0xFFFFu; - let b_hi : u32 = b >> 16u; - - let ll : u32 = a_lo * b_lo; - let lh : u32 = a_lo * b_hi; - let hl : u32 = a_hi * b_lo; - let hh : u32 = a_hi * b_hi; - - // result = ll + (lh + hl) << 16 + hh << 32. - let mid : u32 = lh + hl; - let mid_carry : u32 = select(0u, 1u << 16u, mid < lh); - - let lo_part : u32 = ll + (mid << 16u); - let lo_carry : u32 = select(0u, 1u, lo_part < ll); - - let hi_part : u32 = hh + (mid >> 16u) + mid_carry + lo_carry; - - return U64(lo_part, hi_part); -} - -// 64x64 -> 128 multiply: returns (lo64, hi64). Output via four u32 lanes -// because WGSL doesn't support multiple return values; we use out parameters -// via storage but inline four-result vec4. -fn u64_mul_full(a: U64, b: U64) -> vec4 { - let ll : U64 = u32x32_to_u64(a.lo, b.lo); - let lh : U64 = u32x32_to_u64(a.lo, b.hi); - let hl : U64 = u32x32_to_u64(a.hi, b.lo); - let hh : U64 = u32x32_to_u64(a.hi, b.hi); - - // sum mid words: (ll.hi + lh.lo + hl.lo) carrying into next. - let m1 : u32 = ll.hi + lh.lo; - let c1 : u32 = select(0u, 1u, m1 < ll.hi); - let m2 : u32 = m1 + hl.lo; - let c2 : u32 = select(0u, 1u, m2 < m1); - - let lo64 : U64 = U64(ll.lo, m2); - let mid_hi : u32 = lh.hi + hl.hi + c1 + c2; - let h1 : u32 = hh.lo + mid_hi; - let c3 : u32 = select(0u, 1u, h1 < hh.lo); - let hi64 : U64 = U64(h1, hh.hi + c3); - - return vec4(lo64.lo, lo64.hi, hi64.lo, hi64.hi); -} - -@compute @workgroup_size(1) -fn modexp_kara_mul(@builtin(global_invocation_id) gid: vec3) { - if (gid.x != 0u) { return; } - - // Zero r. - for (var k : u32 = 0u; k < 2u * n; k = k + 1u) { - r[k] = U64(0u, 0u); - } - - // Schoolbook full product. 
- for (var j : u32 = 0u; j < n; j = j + 1u) { - var carry : U64 = U64(0u, 0u); - for (var i : u32 = 0u; i < n; i = i + 1u) { - let p : vec4 = u64_mul_full(x[i], y[j]); - let plo : U64 = U64(p.x, p.y); - let phi : U64 = U64(p.z, p.w); - - // r[i+j] += plo + carry - let r_old = r[i + j]; - let s1_lo : u32 = r_old.lo + plo.lo; - let s1_lc : u32 = select(0u, 1u, s1_lo < r_old.lo); - let s1_hi : u32 = r_old.hi + plo.hi + s1_lc; - let s1_hc : u32 = select(0u, 1u, - (s1_hi < r_old.hi) || (s1_lc == 1u && s1_hi == r_old.hi)); - - let s2_lo : u32 = s1_lo + carry.lo; - let s2_lc : u32 = select(0u, 1u, s2_lo < s1_lo); - let s2_hi : u32 = s1_hi + carry.hi + s2_lc; - let s2_hc : u32 = select(0u, 1u, - (s2_hi < s1_hi) || (s2_lc == 1u && s2_hi == s1_hi)); - - r[i + j] = U64(s2_lo, s2_hi); - - // carry = phi + (s1_hc + s2_hc) - let cw : u32 = s1_hc + s2_hc; - let nc_lo : u32 = phi.lo + cw; - let nc_lc : u32 = select(0u, 1u, nc_lo < phi.lo); - carry = U64(nc_lo, phi.hi + nc_lc); - } - r[j + n] = carry; - } -} diff --git a/ntt/gpu/cuda/four_step_ntt.cu b/ntt/gpu/cuda/four_step_ntt.cu deleted file mode 100644 index 298e8f4..0000000 --- a/ntt/gpu/cuda/four_step_ntt.cu +++ /dev/null @@ -1,452 +0,0 @@ -// ============================================================================= -// Four-Step NTT Optimized for CUDA -// ============================================================================= -// CUDA port of four_step_ntt.metal -- byte-identical arithmetic output. -// -// Four-Step Algorithm for N = N1 * N2: -// 1. N2 parallel column NTTs of size N1 -// 2. Twiddle multiplication by omega^(i*j) -// 3. Matrix transpose -// 4. N1 parallel row NTTs of size N2 -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-2-Clause - -#include - -#ifdef __CUDA_ARCH__ -#define NTT_DEVICE __device__ __forceinline__ -#else -#define NTT_DEVICE inline -#define __global__ -#define __shared__ -static inline uint64_t __umul64hi(uint64_t a, uint64_t b) { - __uint128_t r = (__uint128_t)a * b; return (uint64_t)(r >> 64); -} -static inline void __syncthreads() {} -#endif - -static const uint32_t MAX_TILE_SIZE = 4096; -static const uint32_t MAX_TILE_DIM = 64; - -struct FourStepParams { - uint64_t Q; - uint64_t mu; - uint64_t N_inv; - uint64_t N_inv_precon; - uint32_t N; - uint32_t N1; - uint32_t N2; - uint32_t log_N1; - uint32_t log_N2; - uint32_t tile_stride; - uint32_t batch_size; -}; - -NTT_DEVICE uint64_t barrett_mul_precon(uint64_t a, uint64_t b, uint64_t Q, uint64_t precon) { - uint64_t q_approx = __umul64hi(a, precon); - uint64_t product = a * b; - uint64_t result = product - q_approx * Q; - return (result >= Q) ? (result - Q) : result; -} - -NTT_DEVICE uint64_t barrett_mul(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - uint64_t lo = a * b; - uint64_t q = __umul64hi(lo, mu); - uint64_t result = lo - q * Q; - return (result >= Q) ? (result - Q) : result; -} - -NTT_DEVICE uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - return (sum >= Q) ? (sum - Q) : sum; -} - -NTT_DEVICE uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return (a >= b) ? 
(a - b) : (a + Q - b); -} - -NTT_DEVICE void ct_butterfly(uint64_t& lo, uint64_t& hi, uint64_t tw, uint64_t tw_pre, uint64_t Q) { - uint64_t hi_tw = barrett_mul_precon(hi, tw, Q, tw_pre); - uint64_t new_lo = mod_add(lo, hi_tw, Q); - uint64_t new_hi = mod_sub(lo, hi_tw, Q); - lo = new_lo; - hi = new_hi; -} - -NTT_DEVICE void gs_butterfly(uint64_t& lo, uint64_t& hi, uint64_t tw, uint64_t tw_pre, uint64_t Q) { - uint64_t sum = mod_add(lo, hi, Q); - uint64_t diff = mod_sub(lo, hi, Q); - lo = sum; - hi = barrett_mul_precon(diff, tw, Q, tw_pre); -} - -// In-shared-memory NTT helpers (stride access for column NTTs) -NTT_DEVICE void threadgroup_ntt_forward( - uint64_t* shared, uint32_t stride, uint32_t N, uint32_t log_N, - uint32_t thread_idx, uint32_t num_threads, - const uint64_t* twiddles, const uint64_t* twiddle_precon, uint64_t Q) -{ - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t num_butterflies = N >> 1; - uint32_t bpt = (num_butterflies + num_threads - 1) / num_threads; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t bi = thread_idx + b * num_threads; - if (bi >= num_butterflies) break; - uint32_t group = bi / t; - uint32_t j = bi % t; - uint32_t idx_lo = (group * 2 * t + j) * stride; - uint32_t idx_hi = idx_lo + t * stride; - uint32_t tw_idx = m + group; - uint64_t tw = twiddles[tw_idx]; - uint64_t tw_pre = twiddle_precon[tw_idx]; - uint64_t lo = shared[idx_lo]; - uint64_t hi = shared[idx_hi]; - ct_butterfly(lo, hi, tw, tw_pre, Q); - shared[idx_lo] = lo; - shared[idx_hi] = hi; - } -#ifdef __CUDA_ARCH__ - __syncthreads(); -#endif - } -} - -NTT_DEVICE void threadgroup_ntt_inverse( - uint64_t* shared, uint32_t stride, uint32_t N, uint32_t log_N, - uint32_t thread_idx, uint32_t num_threads, - const uint64_t* twiddles, const uint64_t* twiddle_precon, uint64_t Q) -{ - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - uint32_t 
num_butterflies = N >> 1; - uint32_t bpt = (num_butterflies + num_threads - 1) / num_threads; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t bi = thread_idx + b * num_threads; - if (bi >= num_butterflies) break; - uint32_t group = bi / t; - uint32_t j = bi % t; - uint32_t idx_lo = (group * 2 * t + j) * stride; - uint32_t idx_hi = idx_lo + t * stride; - uint32_t tw_idx = m + group; - uint64_t tw = twiddles[tw_idx]; - uint64_t tw_pre = twiddle_precon[tw_idx]; - uint64_t lo = shared[idx_lo]; - uint64_t hi = shared[idx_hi]; - gs_butterfly(lo, hi, tw, tw_pre, Q); - shared[idx_lo] = lo; - shared[idx_hi] = hi; - } -#ifdef __CUDA_ARCH__ - __syncthreads(); -#endif - } -} - -// ============================================================================= -// Step 1: Column NTTs (Forward) -// ============================================================================= - -extern "C" __global__ void four_step_column_ntt( - uint64_t* data, const uint64_t* twiddles, const uint64_t* twiddle_precon, - const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - uint32_t thread_idx = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; - uint32_t threadgroup_size = blockDim.x * blockDim.y * blockDim.z; - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = blockIdx.z; - uint32_t tile_row = blockIdx.y; - uint32_t tile_col = blockIdx.x; - uint32_t tile_stride = params.tile_stride; - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - uint64_t* batch_data = data + batch_idx * N; - - uint32_t ept = (TILE_N1 * TILE_N2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_N1 * TILE_N2) break; - uint32_t lr = li / TILE_N2; - uint32_t lc = li % TILE_N2; - uint32_t gr = tile_row * TILE_N1 + lr; - uint32_t gc = tile_col * 
TILE_N2 + lc; - if (gr < N1 && gc < N2) - shared[lr * tile_stride + lc] = batch_data[gr * N2 + gc]; - } - __syncthreads(); - - uint32_t log_N1 = params.log_N1; - for (uint32_t col = 0; col < TILE_N2; ++col) { - threadgroup_ntt_forward(shared + col, tile_stride, TILE_N1, log_N1, - thread_idx, threadgroup_size, twiddles, twiddle_precon, Q); - } - - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_N1 * TILE_N2) break; - uint32_t lr = li / TILE_N2; - uint32_t lc = li % TILE_N2; - uint32_t gr = tile_row * TILE_N1 + lr; - uint32_t gc = tile_col * TILE_N2 + lc; - if (gr < N1 && gc < N2) - batch_data[gr * N2 + gc] = shared[lr * tile_stride + lc]; - } -#endif -} - -// ============================================================================= -// Step 1: Column NTTs (Inverse) -// ============================================================================= - -extern "C" __global__ void four_step_column_intt( - uint64_t* data, const uint64_t* twiddles, const uint64_t* twiddle_precon, - const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - uint32_t thread_idx = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; - uint32_t threadgroup_size = blockDim.x * blockDim.y * blockDim.z; - uint32_t N1 = params.N1, N2 = params.N2, N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = blockIdx.z; - uint32_t tile_row = blockIdx.y, tile_col = blockIdx.x; - uint32_t tile_stride = params.tile_stride; - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - uint64_t* batch_data = data + batch_idx * N; - - uint32_t ept = (TILE_N1 * TILE_N2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_N1 * TILE_N2) break; - uint32_t lr = li / TILE_N2, lc = li % TILE_N2; - uint32_t gr = tile_row * TILE_N1 + lr, gc = tile_col * TILE_N2 + lc; - 
if (gr < N1 && gc < N2) - shared[lr * tile_stride + lc] = batch_data[gr * N2 + gc]; - } - __syncthreads(); - - for (uint32_t col = 0; col < TILE_N2; ++col) { - threadgroup_ntt_inverse(shared + col, tile_stride, TILE_N1, params.log_N1, - thread_idx, threadgroup_size, twiddles, twiddle_precon, Q); - } - - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_N1 * TILE_N2) break; - uint32_t lr = li / TILE_N2, lc = li % TILE_N2; - uint32_t gr = tile_row * TILE_N1 + lr, gc = tile_col * TILE_N2 + lc; - if (gr < N1 && gc < N2) - batch_data[gr * N2 + gc] = shared[lr * tile_stride + lc]; - } -#endif -} - -// ============================================================================= -// Step 2+3: Fused Twiddle Multiplication and Transpose -// ============================================================================= - -extern "C" __global__ void four_step_twiddle_transpose( - uint64_t* output, const uint64_t* input, - const uint64_t* twiddles, const uint64_t* twiddle_precon, - const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - uint32_t thread_idx = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; - uint32_t threadgroup_size = blockDim.x * blockDim.y * blockDim.z; - uint32_t N1 = params.N1, N2 = params.N2, N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = blockIdx.z; - uint32_t tile_row = blockIdx.y, tile_col = blockIdx.x; - uint32_t tile_stride = params.tile_stride; - uint32_t TILE_DIM = MAX_TILE_DIM; - const uint64_t* batch_input = input + batch_idx * N; - uint64_t* batch_output = output + batch_idx * N; - - uint32_t ept = (TILE_DIM * TILE_DIM + threadgroup_size - 1) / threadgroup_size; - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_DIM * TILE_DIM) break; - uint32_t lr = li / TILE_DIM, lc = li % TILE_DIM; - uint32_t gr = tile_row * TILE_DIM + lr, gc = tile_col * TILE_DIM + 
lc; - if (gr < N1 && gc < N2) { - uint32_t in_idx = gr * N2 + gc; - uint64_t val = batch_input[in_idx]; - uint32_t tw_idx = gr * N2 + gc; - val = barrett_mul_precon(val, twiddles[tw_idx], Q, twiddle_precon[tw_idx]); - shared[lc * tile_stride + lr] = val; // transposed store - } - } - __syncthreads(); - - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_DIM * TILE_DIM) break; - uint32_t lr = li / TILE_DIM, lc = li % TILE_DIM; - uint32_t out_row = tile_col * TILE_DIM + lr; - uint32_t out_col = tile_row * TILE_DIM + lc; - if (out_row < N2 && out_col < N1) - batch_output[out_row * N1 + out_col] = shared[lr * tile_stride + lc]; - } -#endif -} - -// ============================================================================= -// Step 4: Row NTTs (Forward) -// ============================================================================= - -extern "C" __global__ void four_step_row_ntt( - uint64_t* data, const uint64_t* twiddles, const uint64_t* twiddle_precon, - const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - uint32_t thread_idx = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; - uint32_t threadgroup_size = blockDim.x * blockDim.y * blockDim.z; - uint32_t N1 = params.N1, N2 = params.N2, N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = blockIdx.z; - uint32_t tile_row = blockIdx.y, tile_col = blockIdx.x; - uint32_t tile_stride = params.tile_stride; - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - uint64_t* batch_data = data + batch_idx * N; - - uint32_t ept = (TILE_N2 * TILE_N1 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_N2 * TILE_N1) break; - uint32_t lr = li / TILE_N1, lc = li % TILE_N1; - uint32_t gr = tile_row * TILE_N2 + lr, gc = tile_col * TILE_N1 + lc; - if (gr < N2 && gc < 
N1) - shared[lr * tile_stride + lc] = batch_data[gr * N1 + gc]; - } - __syncthreads(); - - for (uint32_t row = 0; row < TILE_N2; ++row) { - threadgroup_ntt_forward(shared + row * tile_stride, 1, TILE_N1, params.log_N2, - thread_idx, threadgroup_size, twiddles, twiddle_precon, Q); - } - - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li >= TILE_N2 * TILE_N1) break; - uint32_t lr = li / TILE_N1, lc = li % TILE_N1; - uint32_t gr = tile_row * TILE_N2 + lr, gc = tile_col * TILE_N1 + lc; - if (gr < N2 && gc < N1) - batch_data[gr * N1 + gc] = shared[lr * tile_stride + lc]; - } -#endif -} - -// ============================================================================= -// Scaling and Pointwise -// ============================================================================= - -extern "C" __global__ void four_step_scale_n_inv(uint64_t* data, const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total_elements = params.N * params.batch_size; - if (global_idx >= total_elements) return; - data[global_idx] = barrett_mul_precon(data[global_idx], params.N_inv, params.Q, params.N_inv_precon); -#endif -} - -extern "C" __global__ void four_step_pointwise_mul( - uint64_t* result, const uint64_t* a, const uint64_t* b, const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total_elements = params.N * params.batch_size; - if (global_idx >= total_elements) return; - result[global_idx] = barrett_mul(a[global_idx], b[global_idx], params.Q, params.mu); -#endif -} - -// ============================================================================= -// Fused Four-Step NTT (N <= 4096) -// ============================================================================= - -extern "C" __global__ void four_step_ntt_fused( - uint64_t* data, - const uint64_t* col_twiddles, const uint64_t* col_tw_precon, - 
const uint64_t* trans_twiddles, const uint64_t* trans_tw_precon, - const uint64_t* row_twiddles, const uint64_t* row_tw_precon, - const FourStepParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - uint32_t thread_idx = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; - uint32_t threadgroup_size = blockDim.x * blockDim.y * blockDim.z; - uint32_t N1 = params.N1, N2 = params.N2, N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = blockIdx.x; - uint32_t log_N1 = params.log_N1, log_N2 = params.log_N2; - uint64_t* batch_data = data + batch_idx * N; - - // Load entire polynomial - uint32_t ept = (N + threadgroup_size - 1) / threadgroup_size; - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li < N) shared[li] = batch_data[li]; - } - __syncthreads(); - - // Column NTTs - for (uint32_t col = 0; col < N2; ++col) { - threadgroup_ntt_forward(shared + col, N2, N1, log_N1, - thread_idx, threadgroup_size, col_twiddles, col_tw_precon, Q); - } - - // Twiddle multiplication - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li < N) { - uint32_t i = li / N2; - uint32_t j = li % N2; - shared[li] = barrett_mul_precon(shared[li], trans_twiddles[i * N2 + j], Q, trans_tw_precon[i * N2 + j]); - } - } - __syncthreads(); - - // In-place transpose (square case) - if (N1 == N2) { - for (uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li < N) { - uint32_t row = li / N2, col = li % N2; - if (row < col) { - uint32_t idx1 = row * N2 + col; - uint32_t idx2 = col * N1 + row; - uint64_t temp = shared[idx1]; - shared[idx1] = shared[idx2]; - shared[idx2] = temp; - } - } - } - } - __syncthreads(); - - // Row NTTs - for (uint32_t row = 0; row < N1; ++row) { - threadgroup_ntt_forward(shared + row * N2, 1, N2, log_N2, - thread_idx, threadgroup_size, row_twiddles, row_tw_precon, Q); - } - - // Write back - for 
(uint32_t e = 0; e < ept; ++e) { - uint32_t li = thread_idx + e * threadgroup_size; - if (li < N) batch_data[li] = shared[li]; - } -#endif -} diff --git a/ntt/gpu/cuda/ntt.cu b/ntt/gpu/cuda/ntt.cu deleted file mode 100644 index be5237d..0000000 --- a/ntt/gpu/cuda/ntt.cu +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file ntt.cu -/// Shared Number Theoretic Transform (NTT) primitives for lattice-based PQ crypto. -/// CUDA port of ntt.metal -- byte-identical arithmetic output. -/// -/// Used by: ML-DSA (FIPS 204), ML-KEM (FIPS 203), Ringtail, SLH-DSA (FIPS 205) -/// -/// NTT operates over polynomial rings Z_q[x]/(x^n + 1). -/// The butterfly operations are perfectly parallel -- each layer of the -/// NTT can be dispatched across GPU threads. -/// -/// This file provides: -/// - Forward NTT (Cooley-Tukey butterfly) -/// - Inverse NTT (Gentleman-Sande butterfly) -/// - Pointwise polynomial multiplication in NTT domain -/// - Barrett reduction for arbitrary moduli -/// -/// Parameters are passed via constants so the same code works for: -/// ML-DSA: q=8380417, n=256 -/// ML-KEM: q=3329, n=256 -/// Ringtail: q=8380417, n=256 (same ring as ML-DSA) - -#ifndef NTT_CUDA_H -#define NTT_CUDA_H - -#include - -#ifdef __CUDA_ARCH__ -#define NTT_DEVICE __device__ __forceinline__ -#else -#define NTT_DEVICE inline -#define __global__ -#define __shared__ -#endif - -// ============================================================================= -// Barrett reduction: a mod q without division -// ============================================================================= - -/// Barrett reduction for q = 8380417 (ML-DSA / Ringtail) -NTT_DEVICE int32_t barrett_reduce_mldsa(int32_t a) { - const int32_t q = 8380417; - const int64_t v = 33554687LL; // floor(2^48 / q) + 1 - int64_t t = (int64_t)a * v >> 48; - int32_t r = a - (int32_t)t * q; - if (r < 0) r += q; - if (r >= q) r -= q; - return r; -} - 
-/// Barrett reduction for q = 3329 (ML-KEM) -NTT_DEVICE int32_t barrett_reduce_mlkem(int32_t a) { - const int32_t q = 3329; - const int64_t v = 5039835LL; // floor(2^36 / q) + 1 - int64_t t = (int64_t)a * v >> 36; - int32_t r = a - (int32_t)t * q; - if (r < 0) r += q; - if (r >= q) r -= q; - return r; -} - -/// Montgomery reduction for ML-DSA: aR^{-1} mod q, R = 2^32 -/// q_inv = -q^{-1} mod 2^32 = 58728449 -NTT_DEVICE int32_t mont_reduce_mldsa(int64_t a) { - const int32_t q = 8380417; - const int32_t q_inv = 58728449; // -q^(-1) mod 2^32 - int32_t t = (int32_t)a * q_inv; - int64_t u = (int64_t)t * q; - int32_t r = (int32_t)((a - u) >> 32); - if (r < 0) r += q; - return r; -} - -/// Montgomery reduction for ML-KEM: aR^{-1} mod q, R = 2^16 -/// q_inv = -q^{-1} mod 2^16 = 3327 -NTT_DEVICE int16_t mont_reduce_mlkem(int32_t a) { - const int16_t q = 3329; - const int16_t q_inv = 3327; // -q^(-1) mod 2^16 - int16_t t = (int16_t)a * q_inv; - int32_t u = (int32_t)t * q; - int16_t r = (int16_t)((a - u) >> 16); - return r; -} - -// ============================================================================= -// NTT butterfly operations (Cooley-Tukey, in-place) -// ============================================================================= - -/// Forward NTT butterfly for ML-DSA (q=8380417) -NTT_DEVICE void ntt_butterfly_mldsa(int32_t& a, int32_t& b, int32_t zeta) { - int32_t t = mont_reduce_mldsa((int64_t)zeta * b); - b = a - t; - a = a + t; - if (a >= 8380417) a -= 8380417; - if (b < 0) b += 8380417; -} - -/// Inverse NTT butterfly for ML-DSA (Gentleman-Sande) -NTT_DEVICE void inv_ntt_butterfly_mldsa(int32_t& a, int32_t& b, int32_t zeta) { - int32_t t = a; - a = t + b; - b = t - b; - if (a >= 8380417) a -= 8380417; - if (b < 0) b += 8380417; - b = mont_reduce_mldsa((int64_t)zeta * b); -} - -/// Forward NTT butterfly for ML-KEM (q=3329) -NTT_DEVICE void ntt_butterfly_mlkem(int16_t& a, int16_t& b, int16_t zeta) { - int16_t t = mont_reduce_mlkem((int32_t)zeta * b); - b = a 
- t; - a = a + t; -} - -/// Inverse NTT butterfly for ML-KEM -NTT_DEVICE void inv_ntt_butterfly_mlkem(int16_t& a, int16_t& b, int16_t zeta) { - int16_t t = a; - a = t + b; - b = t - b; - b = mont_reduce_mlkem((int32_t)zeta * b); -} - -// ============================================================================= -// Precomputed zetas (roots of unity in Montgomery form) -// ============================================================================= - -#ifdef __CUDA_ARCH__ -__constant__ -#else -static const -#endif -int32_t MLDSA_ZETAS[128] = { - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561,-3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733,-2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488,-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -}; - -// ============================================================================= -// Full NTT / inverse NTT for n=256 polynomials -// ============================================================================= - -/// In-place 
forward NTT for ML-DSA polynomial (n=256, q=8380417) -NTT_DEVICE void ntt_mldsa(int32_t poly[256]) { - int k = 0; - for (int len = 128; len >= 1; len >>= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int32_t zeta = MLDSA_ZETAS[++k]; - for (int j = start; j < start + len; j++) { - ntt_butterfly_mldsa(poly[j], poly[j + len], zeta); - } - } - } -} - -/// In-place inverse NTT for ML-DSA polynomial -NTT_DEVICE void inv_ntt_mldsa(int32_t poly[256]) { - const int32_t q = 8380417; - const int32_t f = 41978; // 2^32 * 256^{-1} mod q - - int k = 127; - for (int len = 1; len <= 128; len <<= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int32_t zeta = -MLDSA_ZETAS[k--]; - if (zeta < 0) zeta += q; - for (int j = start; j < start + len; j++) { - inv_ntt_butterfly_mldsa(poly[j], poly[j + len], zeta); - } - } - } - for (int i = 0; i < 256; i++) { - poly[i] = mont_reduce_mldsa((int64_t)f * poly[i]); - } -} - -/// Pointwise multiplication of two NTT-domain ML-DSA polynomials -NTT_DEVICE void poly_pointwise_mldsa(int32_t c[256], - const int32_t a[256], - const int32_t b[256]) { - for (int i = 0; i < 256; i++) { - c[i] = mont_reduce_mldsa((int64_t)a[i] * b[i]); - } -} - -// ============================================================================= -// NTT batch kernel: each thread transforms one polynomial -// ============================================================================= - -/// Batch forward NTT for ML-DSA polynomials. -/// Each thread computes NTT of one 256-coefficient polynomial. -extern "C" __global__ void ntt_mldsa_batch( - int32_t* polys, - const uint32_t num_polys) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_polys) return; - - int32_t poly[256]; - int32_t* src = polys + tid * 256; - for (int i = 0; i < 256; i++) poly[i] = src[i]; - - ntt_mldsa(poly); - - for (int i = 0; i < 256; i++) src[i] = poly[i]; -#endif -} - -/// Batch inverse NTT for ML-DSA polynomials. 
-extern "C" __global__ void inv_ntt_mldsa_batch( - int32_t* polys, - const uint32_t num_polys) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_polys) return; - - int32_t poly[256]; - int32_t* src = polys + tid * 256; - for (int i = 0; i < 256; i++) poly[i] = src[i]; - - inv_ntt_mldsa(poly); - - for (int i = 0; i < 256; i++) src[i] = poly[i]; -#endif -} - -#endif // NTT_CUDA_H diff --git a/ntt/gpu/cuda/ntt_kernels.cu b/ntt/gpu/cuda/ntt_kernels.cu deleted file mode 100644 index 8e791c9..0000000 --- a/ntt/gpu/cuda/ntt_kernels.cu +++ /dev/null @@ -1,479 +0,0 @@ -// ============================================================================= -// Optimal NTT Kernels for Lux FHE - CUDA Port of ntt_kernels.metal -// ============================================================================= -// -// Design based on OpenFHE's NumberTheoreticTransformNat: -// - Forward: Cooley-Tukey (DIT) with bit-reversed output -// - Inverse: Gentleman-Sande (GS) with bit-reversed input -// - Barrett reduction with precomputed constants (ModMulFastConst) -// - Byte-identical output to Metal implementation -// -// CUDA advantages over Metal for this kernel: -// - Native __umul64hi for 64-bit mulhi -// - Larger shared memory (48KB vs 32KB) -// - Warp-synchronous execution eliminates some barriers - -#include - -#ifdef __CUDA_ARCH__ -#define NTT_DEVICE __device__ __forceinline__ -#else -#define NTT_DEVICE inline -#define __global__ -#define __shared__ -static inline uint64_t __umul64hi(uint64_t a, uint64_t b) { - __uint128_t r = (__uint128_t)a * b; - return (uint64_t)(r >> 64); -} -#endif - -// ============================================================================= -// NTT Parameters Structure -// ============================================================================= - -struct NTTParams { - uint64_t Q; // Prime modulus - uint64_t mu; // Barrett constant: floor(2^64 / Q) - uint64_t N_inv; // N^{-1} mod Q - uint64_t N_inv_precon; 
// Barrett precomputation for N_inv - uint32_t N; // Ring dimension (power of 2) - uint32_t log_N; // log2(N) -}; - -// ============================================================================= -// Barrett Modular Multiplication -// ============================================================================= - -NTT_DEVICE uint64_t mod_mul_barrett(uint64_t a, uint64_t omega, uint64_t Q, uint64_t precon_omega) { - uint64_t q_approx = __umul64hi(a, precon_omega); - uint64_t product = a * omega; - uint64_t result = product - q_approx * Q; - return result >= Q ? result - Q : result; -} - -NTT_DEVICE uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t Q) { -#ifdef __CUDA_ARCH__ - uint64_t lo = a * b; - uint64_t hi = __umul64hi(a, b); - - if (hi == 0) { - return lo % Q; - } - - uint64_t two64_mod_q = ((uint64_t(1) << 32) % Q); - two64_mod_q = (two64_mod_q * two64_mod_q) % Q; - - return (lo % Q + (hi % Q) * two64_mod_q % Q) % Q; -#else - __uint128_t prod = (__uint128_t)a * b; - return (uint64_t)(prod % Q); -#endif -} - -NTT_DEVICE uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - return sum - (sum >= Q ? Q : 0); -} - -NTT_DEVICE uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return a + (b > a ? 
Q : 0) - b; -} - -// ============================================================================= -// Butterfly Operations -// ============================================================================= - -NTT_DEVICE void ct_butterfly(uint64_t* data, - uint32_t idx_lo, uint32_t idx_hi, - uint64_t omega, uint64_t precon_omega, - uint64_t Q) { - uint64_t lo_val = data[idx_lo]; - uint64_t hi_val = data[idx_hi]; - uint64_t omega_factor = mod_mul_barrett(hi_val, omega, Q, precon_omega); - data[idx_lo] = mod_add(lo_val, omega_factor, Q); - data[idx_hi] = mod_sub(lo_val, omega_factor, Q); -} - -NTT_DEVICE void gs_butterfly(uint64_t* data, - uint32_t idx_lo, uint32_t idx_hi, - uint64_t omega, uint64_t precon_omega, - uint64_t Q) { - uint64_t lo_val = data[idx_lo]; - uint64_t hi_val = data[idx_hi]; - uint64_t sum = mod_add(lo_val, hi_val, Q); - uint64_t diff = mod_sub(lo_val, hi_val, Q); - uint64_t diff_tw = mod_mul_barrett(diff, omega, Q, precon_omega); - data[idx_lo] = sum; - data[idx_hi] = diff_tw; -} - -// ============================================================================= -// Forward NTT Stage Kernel -// ============================================================================= - -extern "C" __global__ void ntt_forward_stage_optimal( - uint64_t* data, - const uint64_t* twiddles, - const uint64_t* precon_twiddles, - const NTTParams params, - uint32_t stage, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t butterfly_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t num_butterflies = N >> 1; - if (butterfly_idx >= num_butterflies) return; - - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - uint32_t idx_lo = (i << (params.log_N - stage)) + j; - uint32_t idx_hi = idx_lo + t; - uint32_t tw_idx = m + i; - 
uint64_t omega = twiddles[tw_idx]; - uint64_t precon = precon_twiddles[tw_idx]; - - uint64_t* poly = data + batch_idx * N; - ct_butterfly(poly, idx_lo, idx_hi, omega, precon, Q); -#endif -} - -// ============================================================================= -// Inverse NTT Stage Kernel -// ============================================================================= - -extern "C" __global__ void ntt_inverse_stage_optimal( - uint64_t* data, - const uint64_t* inv_twiddles, - const uint64_t* precon_inv_twiddles, - const NTTParams params, - uint32_t stage, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t butterfly_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - uint32_t num_butterflies = N >> 1; - if (butterfly_idx >= num_butterflies) return; - - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - uint32_t idx_lo = (i << (stage + 1)) + j; - uint32_t idx_hi = idx_lo + t; - uint32_t tw_idx = m + i; - uint64_t omega = inv_twiddles[tw_idx]; - uint64_t precon = precon_inv_twiddles[tw_idx]; - - uint64_t* poly = data + batch_idx * N; - gs_butterfly(poly, idx_lo, idx_hi, omega, precon, Q); -#endif -} - -// ============================================================================= -// Scale by N^{-1} after inverse NTT -// ============================================================================= - -extern "C" __global__ void ntt_scale_optimal( - uint64_t* data, - const NTTParams params, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t coeff_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint64_t* poly = data + batch_idx * params.N; - poly[coeff_idx] = 
mod_mul_barrett(poly[coeff_idx], params.N_inv, params.Q, params.N_inv_precon); -#endif -} - -// ============================================================================= -// Complete Forward NTT (All Stages in Shared Memory) -// ============================================================================= - -extern "C" __global__ void ntt_forward_complete_optimal( - uint64_t* data, - const uint64_t* twiddles, - const uint64_t* precon_twiddles, - const NTTParams params, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - - uint32_t batch_idx = blockIdx.y; - uint32_t local_idx = threadIdx.x; - uint32_t tg_size = blockDim.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint32_t log_N = params.log_N; - uint64_t Q = params.Q; - uint64_t* poly = data + batch_idx * N; - - // Load into shared memory - for (uint32_t i = local_idx; i < N; i += tg_size) { - shared[i] = poly[i]; - } - __syncthreads(); - - // Cooley-Tukey stages - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - for (uint32_t butterfly_idx = local_idx; butterfly_idx < N / 2; butterfly_idx += tg_size) { - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - uint32_t idx_lo = (i << (log_N - stage)) + j; - uint32_t idx_hi = idx_lo + t; - uint32_t tw_idx = m + i; - - uint64_t lo_val = shared[idx_lo]; - uint64_t hi_val = shared[idx_hi]; - uint64_t omega = twiddles[tw_idx]; - uint64_t precon = precon_twiddles[tw_idx]; - uint64_t omega_factor = mod_mul_barrett(hi_val, omega, Q, precon); - - shared[idx_lo] = mod_add(lo_val, omega_factor, Q); - shared[idx_hi] = mod_sub(lo_val, omega_factor, Q); - } - __syncthreads(); - } - - // Write back - for (uint32_t i = local_idx; i < N; i += tg_size) { - poly[i] = shared[i]; - } -#endif -} - -// ============================================================================= -// Complete Inverse NTT (All Stages + Scaling) -// 
============================================================================= - -extern "C" __global__ void ntt_inverse_complete_optimal( - uint64_t* data, - const uint64_t* inv_twiddles, - const uint64_t* precon_inv_twiddles, - const NTTParams params, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared[]; - - uint32_t batch_idx = blockIdx.y; - uint32_t local_idx = threadIdx.x; - uint32_t tg_size = blockDim.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint32_t log_N = params.log_N; - uint64_t Q = params.Q; - uint64_t N_inv = params.N_inv; - uint64_t N_inv_precon = params.N_inv_precon; - uint64_t* poly = data + batch_idx * N; - - for (uint32_t i = local_idx; i < N; i += tg_size) { - shared[i] = poly[i]; - } - __syncthreads(); - - // Gentleman-Sande stages - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - for (uint32_t butterfly_idx = local_idx; butterfly_idx < N / 2; butterfly_idx += tg_size) { - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - uint32_t idx_lo = (i << (stage + 1)) + j; - uint32_t idx_hi = idx_lo + t; - uint32_t tw_idx = m + i; - - uint64_t lo_val = shared[idx_lo]; - uint64_t hi_val = shared[idx_hi]; - uint64_t omega = inv_twiddles[tw_idx]; - uint64_t precon = precon_inv_twiddles[tw_idx]; - - shared[idx_lo] = mod_add(lo_val, hi_val, Q); - uint64_t diff = mod_sub(lo_val, hi_val, Q); - shared[idx_hi] = mod_mul_barrett(diff, omega, Q, precon); - } - __syncthreads(); - } - - // Scale by N^{-1} and write back - for (uint32_t i = local_idx; i < N; i += tg_size) { - poly[i] = mod_mul_barrett(shared[i], N_inv, Q, N_inv_precon); - } -#endif -} - -// ============================================================================= -// Negacyclic Rotation for Blind Rotation -// ============================================================================= - -extern "C" __global__ void negacyclic_rotate_optimal( - 
uint64_t* output, - const uint64_t* input, - const NTTParams params, - const int32_t* rotations, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t coeff_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint32_t N = params.N; - uint64_t Q = params.Q; - - int32_t k = rotations[batch_idx]; - int32_t two_N = 2 * (int32_t)N; - k = ((k % two_N) + two_N) % two_N; - - int32_t src_idx = (int32_t)coeff_idx - k; - bool negate = false; - - while (src_idx < 0) { - src_idx += N; - negate = !negate; - } - while (src_idx >= (int32_t)N) { - src_idx -= N; - negate = !negate; - } - - uint32_t in_offset = batch_idx * N + (uint32_t)src_idx; - uint32_t out_offset = batch_idx * N + coeff_idx; - - uint64_t val = input[in_offset]; - output[out_offset] = negate ? (Q - val) : val; -#endif -} - -// ============================================================================= -// Pointwise Multiply-Accumulate for External Product -// ============================================================================= - -extern "C" __global__ void ntt_pointwise_mac_optimal( - uint64_t* acc, - const uint64_t* a, - const uint64_t* b, - const NTTParams params, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t coeff_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint32_t idx = batch_idx * params.N + coeff_idx; - uint64_t Q = params.Q; - uint64_t prod = mod_mul(a[idx], b[idx], Q); - acc[idx] = mod_add(acc[idx], prod, Q); -#endif -} - -// ============================================================================= -// Digit Decomposition for External Product -// ============================================================================= - -extern "C" __global__ void decompose_digits( - uint64_t* digits, - const uint64_t* poly, - 
const NTTParams params, - uint64_t base, - uint32_t num_levels, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.z * blockDim.z + threadIdx.z; - uint32_t level = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t coeff_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size || level >= num_levels || coeff_idx >= params.N) return; - - uint64_t val = poly[batch_idx * params.N + coeff_idx]; - - for (uint32_t l = 0; l < level; ++l) { - val /= base; - } - uint64_t digit = val % base; - - digits[batch_idx * num_levels * params.N + level * params.N + coeff_idx] = digit; -#endif -} - -// ============================================================================= -// CMux Difference -// ============================================================================= - -extern "C" __global__ void cmux_diff( - uint64_t* diff, - const uint64_t* d0, - const uint64_t* d1, - const NTTParams params, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t coeff_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint32_t idx = batch_idx * params.N + coeff_idx; - diff[idx] = mod_sub(d1[idx], d0[idx], params.Q); -#endif -} - -// ============================================================================= -// External Product Finalize -// ============================================================================= - -extern "C" __global__ void external_product_finalize( - uint64_t* acc, - const uint64_t* prod, - const NTTParams params, - uint32_t num_levels, - uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t batch_idx = blockIdx.y * blockDim.y + threadIdx.y; - uint32_t coeff_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint64_t Q = params.Q; - uint64_t sum = 0; - - for (uint32_t l = 0; l < num_levels; ++l) { - uint32_t idx = 
batch_idx * num_levels * params.N + l * params.N + coeff_idx; - sum = mod_add(sum, prod[idx], Q); - } - - uint32_t out_idx = batch_idx * params.N + coeff_idx; - acc[out_idx] = mod_add(acc[out_idx], sum, Q); -#endif -} diff --git a/ntt/gpu/cuda/ntt_large.cu b/ntt/gpu/cuda/ntt_large.cu deleted file mode 100644 index df1bd6b..0000000 --- a/ntt/gpu/cuda/ntt_large.cu +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA driver for the six-step large-N NTT. -// -// Wire pattern (matches kzg/gpu/cuda/kzg_driver_cuda.cpp, -// banderwagon/gpu/cuda/banderwagon_driver.cpp, and the rest of the GPU -// drivers in this repo): on hosts with a real CUDA device this would launch -// the kernels in cuda/four_step_ntt.cu; without a device it falls through -// to the CPU oracle so byte-equality with CPU is structurally exact and the -// caller's surface is the same. -// -// The kernel arithmetic in four_step_ntt.cu uses Barrett reduction with the -// caller-supplied prime modulus. Both the CPU oracle and the kernel produce -// values in [0, q) at every stage and at the final output, so byte-equality -// is invariant of which path runs. -// -// For TFHE (q = 2^64) the kernel collapses to plain machine-word arithmetic -// (Barrett constants degenerate when q has no bits above 64); the CPU oracle -// uses exactly machine-word arithmetic by routing through ntt_large with -// q = 0 sentinel. Same observable output. - -#include "ntt_large.hpp" - -namespace lux::crypto::ntt::large::gpu_cuda { - -// Returns true when a real CUDA device is reachable. The CPU-fallback build -// of this TU stubs to false; the CUDA-enabled build (CRYPTO_ENABLE_CUDA) -// would link against runtime/driver and probe. -bool device_available() { - return false; -} - -// Dispatch the six-step forward NTT on the GPU. Falls through to the CPU -// oracle when no device is reachable. Output is byte-identical either way. 
-void forward(uint64_t* a, const LargeContext& ctx) { - // Future: when CRYPTO_ENABLE_CUDA is on and device_available() - // returns true, launch four_step_column_ntt + four_step_twiddle_transpose - // + four_step_row_ntt against `a`. The kernel parameters (Q, mu, twiddles - // tables, N1, N2, log_N1, log_N2) are derived from ctx. - lux::crypto::ntt::large::forward(a, ctx); -} - -void inverse(uint64_t* a, const LargeContext& ctx) { - lux::crypto::ntt::large::inverse(a, ctx); -} - -} // namespace lux::crypto::ntt::large::gpu_cuda diff --git a/ntt/gpu/cuda/ntt_metal_kernel.cu b/ntt/gpu/cuda/ntt_metal_kernel.cu deleted file mode 100644 index 96905b8..0000000 --- a/ntt/gpu/cuda/ntt_metal_kernel.cu +++ /dev/null @@ -1,208 +0,0 @@ -// ============================================================================= -// NTT CUDA Kernels with Shared Memory Twiddle Prefetch -// ============================================================================= -// CUDA port of ntt_metal_kernel.metal -- byte-identical arithmetic output. -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include - -#ifdef __CUDA_ARCH__ -#define NTT_DEVICE __device__ __forceinline__ -#else -#define NTT_DEVICE inline -#define __global__ -#define __shared__ -static inline uint64_t __umul64hi(uint64_t a, uint64_t b) { - __uint128_t r = (__uint128_t)a * b; return (uint64_t)(r >> 64); -} -static inline void __syncthreads() {} -static inline void __threadfence_block() {} -#endif - -struct NTTMetalParams { - uint64_t Q; - uint64_t mu; - uint64_t N_inv; - uint64_t N_inv_precon; - uint32_t N; - uint32_t log_N; - uint32_t stage; - uint32_t batch; -}; - -NTT_DEVICE uint64_t barrett_mul(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - uint64_t lo = a * b; - uint64_t q = __umul64hi(lo, mu); - uint64_t result = lo - q * Q; - if (result >= Q) result -= Q; - return result; -} - -NTT_DEVICE uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - return (sum >= Q) ? sum - Q : sum; -} - -NTT_DEVICE uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return (a >= b) ? 
a - b : a + Q - b; -} - -static const uint32_t MAX_SHARED_TWIDDLES = 4096; - -extern "C" __global__ void ntt_forward_stage_shared( - uint64_t* data, const uint64_t* twiddles, const NTTMetalParams params) -{ -#ifdef __CUDA_ARCH__ - __shared__ uint64_t twiddles_shared[MAX_SHARED_TWIDDLES]; - uint32_t thread_idx = threadIdx.x; - uint32_t threadgroup_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m && tw_idx < MAX_SHARED_TWIDDLES) - twiddles_shared[tw_idx] = twiddles[m + tw_idx]; - } - __syncthreads(); - - uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[group]; - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - batch_data[idx_lo] = mod_add(lo, hi_tw, Q); - batch_data[idx_hi] = mod_sub(lo, hi_tw, Q); - } -#endif -} - -extern "C" __global__ void ntt_inverse_stage_shared( - uint64_t* data, const uint64_t* twiddles, const NTTMetalParams params) -{ -#ifdef __CUDA_ARCH__ - __shared__ uint64_t twiddles_shared[MAX_SHARED_TWIDDLES]; - uint32_t thread_idx = threadIdx.x; - uint32_t threadgroup_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N 
= params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m && tw_idx < MAX_SHARED_TWIDDLES) - twiddles_shared[tw_idx] = twiddles[m + tw_idx]; - } - __syncthreads(); - - uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[group]; - batch_data[idx_lo] = mod_add(lo, hi, Q); - uint64_t diff = mod_sub(lo, hi, Q); - batch_data[idx_hi] = barrett_mul(diff, tw, Q, mu); - } -#endif -} - -extern "C" __global__ void ntt_forward_fused( - uint64_t* data, const uint64_t* twiddles_flat, - const uint32_t* stage_offsets, const NTTMetalParams params) -{ -#ifdef __CUDA_ARCH__ - __shared__ uint64_t twiddles_shared[MAX_SHARED_TWIDDLES]; - uint32_t thread_idx = threadIdx.x; - uint32_t threadgroup_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t log_N = params.log_N; - uint64_t* batch_data = data + batch_idx * N; - - uint32_t total_twiddles = N - 1; - uint32_t loads_per_thread = (total_twiddles + threadgroup_size - 1) / threadgroup_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < 
total_twiddles && tw_idx < MAX_SHARED_TWIDDLES) - twiddles_shared[tw_idx] = twiddles_flat[tw_idx]; - } - __syncthreads(); - - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t tw_base = m; - uint32_t bpt = (N / 2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[tw_base + group]; - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - batch_data[idx_lo] = mod_add(lo, hi_tw, Q); - batch_data[idx_hi] = mod_sub(lo, hi_tw, Q); - } - __threadfence_block(); - __syncthreads(); - } -#endif -} - -extern "C" __global__ void ntt_scale_ninv(uint64_t* data, const NTTMetalParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = params.N * params.batch; - if (global_idx >= total) return; - data[global_idx] = barrett_mul(data[global_idx], params.N_inv, params.Q, params.mu); -#endif -} - -extern "C" __global__ void pointwise_mul_mod( - uint64_t* result, const uint64_t* a, const uint64_t* b, const NTTMetalParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = params.N * params.batch; - if (global_idx >= total) return; - result[global_idx] = barrett_mul(a[global_idx], b[global_idx], params.Q, params.mu); -#endif -} diff --git a/ntt/gpu/cuda/ntt_unified_memory.cu b/ntt/gpu/cuda/ntt_unified_memory.cu deleted file mode 100644 index 798a015..0000000 --- a/ntt/gpu/cuda/ntt_unified_memory.cu +++ /dev/null @@ -1,365 +0,0 @@ -// ============================================================================= 
-// Unified/Managed Memory NTT CUDA Kernels -// ============================================================================= -// CUDA port of ntt_unified_memory.metal -- byte-identical arithmetic output. -// -// On CUDA, "unified memory" maps to cudaMallocManaged. The kernel structure -// is identical; the host-side allocation strategy differs. -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#include - -#ifdef __CUDA_ARCH__ -#define NTT_DEVICE __device__ __forceinline__ -#else -#define NTT_DEVICE inline -#define __global__ -#define __shared__ -static inline uint64_t __umul64hi(uint64_t a, uint64_t b) { - __uint128_t r = (__uint128_t)a * b; return (uint64_t)(r >> 64); -} -static inline void __syncthreads() {} -#endif - -static const uint32_t MAX_SHARED_TWIDDLES = 4096; - -struct NTTUnifiedParams { - uint64_t Q; - uint64_t mu; - uint64_t N_inv; - uint64_t N_inv_precon; - uint32_t N; - uint32_t log_N; - uint32_t stage; - uint32_t batch; -}; - -NTT_DEVICE uint64_t barrett_mul_unified(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - uint64_t lo = a * b; - uint64_t q = __umul64hi(lo, mu); - uint64_t result = lo - q * Q; - // Branch-free conditional subtraction - uint64_t mask = (result >= Q) ? ~0ULL : 0ULL; - result -= (Q & mask); - return result; -} - -NTT_DEVICE uint64_t mod_add_unified(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - uint64_t mask = (sum >= Q) ? ~0ULL : 0ULL; - return sum - (Q & mask); -} - -NTT_DEVICE uint64_t mod_sub_unified(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t diff = a - b; - uint64_t mask = (a < b) ? 
~0ULL : 0ULL; - return diff + (Q & mask); -} - -NTT_DEVICE void ct_butterfly(uint64_t& lo, uint64_t& hi, - uint64_t tw, uint64_t Q, uint64_t mu) { - uint64_t hi_tw = barrett_mul_unified(hi, tw, Q, mu); - uint64_t new_lo = mod_add_unified(lo, hi_tw, Q); - uint64_t new_hi = mod_sub_unified(lo, hi_tw, Q); - lo = new_lo; - hi = new_hi; -} - -NTT_DEVICE void gs_butterfly(uint64_t& lo, uint64_t& hi, - uint64_t tw, uint64_t Q, uint64_t mu) { - uint64_t sum = mod_add_unified(lo, hi, Q); - uint64_t diff = mod_sub_unified(lo, hi, Q); - lo = sum; - hi = barrett_mul_unified(diff, tw, Q, mu); -} - -// ============================================================================= -// Forward NTT Stage -// ============================================================================= - -extern "C" __global__ void unified_ntt_forward_stage( - uint64_t* data, const uint64_t* twiddles, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared_tw[]; - uint32_t tid = threadIdx.x; - uint32_t tg_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - uint32_t tw_to_load = min(m, MAX_SHARED_TWIDDLES); - uint32_t loads_per_thread = (tw_to_load + tg_size - 1) / tg_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < tw_to_load) - shared_tw[tw_idx] = twiddles[m + tw_idx]; - } - __syncthreads(); - - uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_total = N / 2; - uint32_t butterflies_per_thread = (butterflies_total + tg_size - 1) / tg_size; - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= butterflies_total) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (params.log_N - 
stage)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = (group < MAX_SHARED_TWIDDLES) ? shared_tw[group] : twiddles[m + group]; - ct_butterfly(lo, hi, tw, Q, mu); - batch_data[idx_lo] = lo; - batch_data[idx_hi] = hi; - } -#endif -} - -// ============================================================================= -// Inverse NTT Stage -// ============================================================================= - -extern "C" __global__ void unified_ntt_inverse_stage( - uint64_t* data, const uint64_t* twiddles, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t shared_tw[]; - uint32_t tid = threadIdx.x; - uint32_t tg_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - uint32_t tw_to_load = min(m, MAX_SHARED_TWIDDLES); - uint32_t loads_per_thread = (tw_to_load + tg_size - 1) / tg_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < tw_to_load) - shared_tw[tw_idx] = twiddles[m + tw_idx]; - } - __syncthreads(); - - uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_total = N / 2; - uint32_t bpt = (butterflies_total + tg_size - 1) / tg_size; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= butterflies_total) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = (group < MAX_SHARED_TWIDDLES) ? 
shared_tw[group] : twiddles[m + group]; - gs_butterfly(lo, hi, tw, Q, mu); - batch_data[idx_lo] = lo; - batch_data[idx_hi] = hi; - } -#endif -} - -// ============================================================================= -// Fused Forward NTT (all stages in shared memory, N <= 4096) -// ============================================================================= - -extern "C" __global__ void unified_ntt_forward_fused( - uint64_t* data, const uint64_t* twiddles, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - // Shared memory: first half for twiddles, second half for data - extern __shared__ uint64_t smem[]; - uint64_t* shared_tw = smem; - uint64_t* shared_data = smem + MAX_SHARED_TWIDDLES; - - uint32_t tid = threadIdx.x; - uint32_t tg_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t log_N = params.log_N; - uint64_t* batch_data = data + batch_idx * N; - - // Load twiddles - uint32_t total_twiddles = N - 1; - for (uint32_t i = tid; i < total_twiddles && i < MAX_SHARED_TWIDDLES; i += tg_size) { - shared_tw[i] = twiddles[i + 1]; - } - // Load polynomial - for (uint32_t i = tid; i < N; i += tg_size) { - shared_data[i] = batch_data[i]; - } - __syncthreads(); - - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t tw_offset = m - 1; - uint32_t bpt = (N / 2 + tg_size - 1) / tg_size; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = shared_data[idx_lo]; - uint64_t hi = shared_data[idx_hi]; - uint64_t tw = shared_tw[tw_offset + group]; - ct_butterfly(lo, hi, tw, Q, mu); - shared_data[idx_lo] = lo; - shared_data[idx_hi] = hi; - } - __syncthreads(); - } 
- - for (uint32_t i = tid; i < N; i += tg_size) { - batch_data[i] = shared_data[i]; - } -#endif -} - -// ============================================================================= -// Fused Inverse NTT -// ============================================================================= - -extern "C" __global__ void unified_ntt_inverse_fused( - uint64_t* data, const uint64_t* twiddles, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - extern __shared__ uint64_t smem[]; - uint64_t* shared_tw = smem; - uint64_t* shared_data = smem + MAX_SHARED_TWIDDLES; - - uint32_t tid = threadIdx.x; - uint32_t tg_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint64_t N_inv = params.N_inv; - uint32_t log_N = params.log_N; - uint64_t* batch_data = data + batch_idx * N; - - uint32_t total_twiddles = N - 1; - for (uint32_t i = tid; i < total_twiddles && i < MAX_SHARED_TWIDDLES; i += tg_size) { - shared_tw[i] = twiddles[i + 1]; - } - for (uint32_t i = tid; i < N; i += tg_size) { - shared_data[i] = batch_data[i]; - } - __syncthreads(); - - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - uint32_t tw_offset = m - 1; - uint32_t bpt = (N / 2 + tg_size - 1) / tg_size; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = shared_data[idx_lo]; - uint64_t hi = shared_data[idx_hi]; - uint64_t tw = shared_tw[tw_offset + group]; - gs_butterfly(lo, hi, tw, Q, mu); - shared_data[idx_lo] = lo; - shared_data[idx_hi] = hi; - } - __syncthreads(); - } - - for (uint32_t i = tid; i < N; i += tg_size) { - batch_data[i] = barrett_mul_unified(shared_data[i], N_inv, Q, mu); - } -#endif -} - -// 
============================================================================= -// Scaling and Pointwise Kernels -// ============================================================================= - -extern "C" __global__ void unified_scale_ninv(uint64_t* data, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t stride = gridDim.x * blockDim.x; - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint64_t N_inv = params.N_inv; - for (uint32_t i = tid; i < total; i += stride) { - data[i] = barrett_mul_unified(data[i], N_inv, Q, mu); - } -#endif -} - -extern "C" __global__ void unified_pointwise_mul( - uint64_t* result, const uint64_t* a, const uint64_t* b, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t stride = gridDim.x * blockDim.x; - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - for (uint32_t i = tid; i < total; i += stride) { - result[i] = barrett_mul_unified(a[i], b[i], Q, mu); - } -#endif -} - -extern "C" __global__ void unified_pointwise_add( - uint64_t* result, const uint64_t* a, const uint64_t* b, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t stride = gridDim.x * blockDim.x; - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - for (uint32_t i = tid; i < total; i += stride) { - result[i] = mod_add_unified(a[i], b[i], Q); - } -#endif -} - -extern "C" __global__ void unified_pointwise_sub( - uint64_t* result, const uint64_t* a, const uint64_t* b, const NTTUnifiedParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t stride = gridDim.x * blockDim.x; - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - for (uint32_t i = tid; i < total; i += stride) { - result[i] = 
mod_sub_unified(a[i], b[i], Q); - } -#endif -} - -extern "C" __global__ void unified_memcpy(uint64_t* dst, const uint64_t* src, uint32_t count) -{ -#ifdef __CUDA_ARCH__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t stride = gridDim.x * blockDim.x; - for (uint32_t i = tid; i < count; i += stride) { - dst[i] = src[i]; - } -#endif -} diff --git a/ntt/gpu/cuda/twiddle_cache.cu b/ntt/gpu/cuda/twiddle_cache.cu deleted file mode 100644 index 4ec6e9d..0000000 --- a/ntt/gpu/cuda/twiddle_cache.cu +++ /dev/null @@ -1,288 +0,0 @@ -// ============================================================================= -// Twiddle Hotset Caching Kernels for CUDA -// ============================================================================= -// CUDA port of twiddle_cache.metal -- byte-identical arithmetic output. -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: BSD-2-Clause - -#include - -#ifdef __CUDA_ARCH__ -#define NTT_DEVICE __device__ __forceinline__ -#else -#define NTT_DEVICE inline -#define __global__ -#define __shared__ -static inline uint64_t __umul64hi(uint64_t a, uint64_t b) { - __uint128_t r = (__uint128_t)a * b; return (uint64_t)(r >> 64); -} -static inline void __syncthreads() {} -static inline void __threadfence_block() {} -#endif - -static const uint32_t MAX_THREADGROUP_TWIDDLES = 4096; -static const uint32_t FIRST_LEVEL_TWIDDLE_COUNT = 8; -static const uint32_t MAX_RNS_PRIMES = 16; -static const uint32_t BANK_WIDTH = 32; -static const uint32_t BANK_PADDING = 1; - -struct PrimeConstants { - uint64_t q; - uint64_t q_inv; - uint64_t mu_hi; - uint64_t mu_lo; - uint64_t r_squared; - uint64_t root; - uint64_t root_inv; - uint64_t n_inv; -}; - -struct ConstantCache { - uint32_t numPrimes; - uint32_t ringDim; - uint32_t padding[2]; - PrimeConstants primes[MAX_RNS_PRIMES]; - uint64_t firstLevelTwiddles[MAX_RNS_PRIMES][FIRST_LEVEL_TWIDDLE_COUNT]; - uint64_t firstLevelInvTwiddles[MAX_RNS_PRIMES][FIRST_LEVEL_TWIDDLE_COUNT]; 
-}; - -struct NTTCacheParams { - uint64_t Q; - uint64_t mu; - uint64_t N_inv; - uint64_t N_inv_precon; - uint32_t N; - uint32_t log_N; - uint32_t stage; - uint32_t primeIdx; - uint32_t batch; - uint32_t prefetchStage; -}; - -NTT_DEVICE uint64_t barrett_mul(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - uint64_t lo = a * b; - uint64_t q = __umul64hi(lo, mu); - uint64_t result = lo - q * Q; - if (result >= Q) result -= Q; - return result; -} - -NTT_DEVICE uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - return (sum >= Q) ? sum - Q : sum; -} - -NTT_DEVICE uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return (a >= b) ? a - b : a + Q - b; -} - -NTT_DEVICE uint32_t padded_index(uint32_t idx) { - return idx + (idx / BANK_WIDTH) * BANK_PADDING; -} - -// ============================================================================= -// Single Stage NTT with Hotset Caching -// ============================================================================= - -extern "C" __global__ void ntt_hotset_forward_stage( - uint64_t* data, const uint64_t* twiddles, - const ConstantCache cache, const NTTCacheParams params) -{ -#ifdef __CUDA_ARCH__ - __shared__ uint64_t twiddles_shared[MAX_THREADGROUP_TWIDDLES + MAX_THREADGROUP_TWIDDLES / BANK_WIDTH]; - __shared__ uint64_t twiddles_prefetch[MAX_THREADGROUP_TWIDDLES + MAX_THREADGROUP_TWIDDLES / BANK_WIDTH]; - - uint32_t thread_idx = threadIdx.x; - uint32_t threadgroup_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t primeIdx = params.primeIdx; - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint64_t* batch_data = data + batch_idx * N; - - bool use_constant_memory = (stage < 4 && m <= FIRST_LEVEL_TWIDDLE_COUNT); - - if (!use_constant_memory) { - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / 
threadgroup_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m) { - uint32_t padded = padded_index(tw_idx); - twiddles_shared[padded] = twiddles[m + tw_idx]; - } - } - if (params.prefetchStage < params.log_N && params.prefetchStage > stage) { - uint32_t next_m = 1u << params.prefetchStage; - uint32_t prefetch_loads = (next_m + threadgroup_size - 1) / threadgroup_size; - for (uint32_t i = 0; i < prefetch_loads; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < next_m && tw_idx < MAX_THREADGROUP_TWIDDLES) { - uint32_t padded = padded_index(tw_idx); - twiddles_prefetch[padded] = twiddles[next_m + tw_idx]; - } - } - } - __syncthreads(); - } - - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - uint64_t tw; - if (use_constant_memory) { - tw = cache.firstLevelTwiddles[primeIdx][group]; - } else { - tw = twiddles_shared[padded_index(group)]; - } - - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - batch_data[idx_lo] = mod_add(lo, hi_tw, Q); - batch_data[idx_hi] = mod_sub(lo, hi_tw, Q); - } -#endif -} - -// ============================================================================= -// Inverse NTT Stage with Hotset -// ============================================================================= - -extern "C" __global__ void ntt_hotset_inverse_stage( - uint64_t* data, const uint64_t* twiddles, - const ConstantCache cache, const NTTCacheParams params) -{ -#ifdef __CUDA_ARCH__ - __shared__ uint64_t twiddles_shared[MAX_THREADGROUP_TWIDDLES + 
MAX_THREADGROUP_TWIDDLES / BANK_WIDTH]; - - uint32_t thread_idx = threadIdx.x; - uint32_t threadgroup_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t primeIdx = params.primeIdx; - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - uint64_t* batch_data = data + batch_idx * N; - - bool use_constant_memory = (stage >= params.log_N - 4 && m <= FIRST_LEVEL_TWIDDLE_COUNT); - - if (!use_constant_memory) { - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m) { - twiddles_shared[padded_index(tw_idx)] = twiddles[m + tw_idx]; - } - } - __syncthreads(); - } - - uint32_t bpt = (N / 2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - uint64_t tw; - if (use_constant_memory) { - tw = cache.firstLevelInvTwiddles[primeIdx][group]; - } else { - tw = twiddles_shared[padded_index(group)]; - } - - uint64_t sum = mod_add(lo, hi, Q); - uint64_t diff = mod_sub(lo, hi, Q); - batch_data[idx_lo] = sum; - batch_data[idx_hi] = barrett_mul(diff, tw, Q, mu); - } -#endif -} - -// ============================================================================= -// Multi-Stage Fused NTT with Full Hotset -// ============================================================================= - -extern "C" __global__ void ntt_hotset_fused( - uint64_t* data, const uint64_t* twiddles_flat, - const ConstantCache cache, const 
NTTCacheParams params) -{ -#ifdef __CUDA_ARCH__ - __shared__ uint64_t twiddles_shared[MAX_THREADGROUP_TWIDDLES]; - - uint32_t thread_idx = threadIdx.x; - uint32_t threadgroup_size = blockDim.x; - uint32_t batch_idx = blockIdx.x; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t log_N = params.log_N; - uint64_t* batch_data = data + batch_idx * N; - - uint32_t total_twiddles = N - 1; - uint32_t loads_per_thread = (total_twiddles + threadgroup_size - 1) / threadgroup_size; - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < total_twiddles) - twiddles_shared[tw_idx] = twiddles_flat[tw_idx]; - } - __syncthreads(); - - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t tw_base = m; - uint32_t bpt = (N / 2 + threadgroup_size - 1) / threadgroup_size; - for (uint32_t b = 0; b < bpt; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[tw_base + group]; - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - batch_data[idx_lo] = mod_add(lo, hi_tw, Q); - batch_data[idx_hi] = mod_sub(lo, hi_tw, Q); - } - __threadfence_block(); - __syncthreads(); - } -#endif -} - -// ============================================================================= -// N^(-1) Scaling -// ============================================================================= - -extern "C" __global__ void ntt_hotset_scale_ninv(uint64_t* data, const NTTCacheParams params) -{ -#ifdef __CUDA_ARCH__ - uint32_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = params.N * params.batch; - if (global_idx >= total) 
return; - data[global_idx] = barrett_mul(data[global_idx], params.N_inv, params.Q, params.mu); -#endif -} diff --git a/ntt/gpu/metal/four_step_ntt.metal b/ntt/gpu/metal/four_step_ntt.metal deleted file mode 100644 index 2092b9e..0000000 --- a/ntt/gpu/metal/four_step_ntt.metal +++ /dev/null @@ -1,963 +0,0 @@ -// ============================================================================= -// Four-Step NTT Optimized for Apple Metal Threadgroup Memory and SIMDgroup -// ============================================================================= -// -// This implements the Four-Step NTT algorithm specifically tuned for: -// - 32KB Metal threadgroup memory (vs 48KB CUDA shared memory) -// - 32-lane SIMDgroup operations via threadgroup memory -// - Coalesced memory access patterns for unified memory -// - Integer-only arithmetic for FHE determinism -// -// Four-Step Algorithm for N = N1 * N2: -// 1. N2 parallel column NTTs of size N1 -// 2. Twiddle multiplication by omega^(i*j) -// 3. Matrix transpose -// 4. N1 parallel row NTTs of size N2 -// -// NOTE: Metal's simd_shuffle does NOT support uint64_t. Instead, we use -// threadgroup memory for all inter-lane communication, which is still very -// fast (~20ns latency, ~3TB/s bandwidth per SIMD on M3). -// -// See patent: PAT-FHE-010-four-step-ntt-metal.md -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-2-Clause -// ============================================================================= - -#include -using namespace metal; - -// ============================================================================= -// Constants -// ============================================================================= - -constant uint32_t SIMD_SIZE = 32; // Metal SIMDgroup size -constant uint32_t MAX_TILE_SIZE = 4096; // 32KB / 8 bytes = 4096 uint64_t -constant uint32_t MAX_TILE_DIM = 64; // sqrt(4096) = 64 for square tiles - -// ============================================================================= -// Parameters Structure -// ============================================================================= - -struct FourStepParams { - uint64_t Q; // Prime modulus - uint64_t mu; // Barrett constant: floor(2^64 / Q) - uint64_t N_inv; // N^{-1} mod Q for inverse NTT - uint64_t N_inv_precon; // Barrett precomputation for N_inv - uint32_t N; // Total ring dimension - uint32_t N1; // Column dimension for Four-Step - uint32_t N2; // Row dimension for Four-Step - uint32_t log_N1; // log2(N1) - uint32_t log_N2; // log2(N2) - uint32_t tile_stride; // Padded stride - uint32_t batch_size; // Number of polynomials to process -}; - -// ============================================================================= -// Modular Arithmetic (Integer-Only for Determinism) -// ============================================================================= - -/** - * @brief Barrett modular multiplication - * - * Computes (a * b) mod Q using Barrett reduction. - * Requires: a, b < Q and Q < 2^62 - * - * The precomputed constant precon = floor(2^64 * b / Q) allows - * faster approximate division. 
- */ -inline uint64_t barrett_mul_precon(uint64_t a, uint64_t b, uint64_t Q, uint64_t precon) { - // Approximate quotient: q ≈ (a * precon) >> 64 - uint64_t q_approx = metal::mulhi(a, precon); - - // Compute a * b - q_approx * Q - uint64_t product = a * b; - uint64_t result = product - q_approx * Q; - - // Conditional reduction (result may be in [0, 2Q)) - return (result >= Q) ? (result - Q) : result; -} - -/** - * @brief Simple Barrett multiplication without precomputation - * - * Used when precomputed constants are not available. - */ -inline uint64_t barrett_mul(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - uint64_t lo = a * b; - uint64_t q = metal::mulhi(lo, mu); - uint64_t result = lo - q * Q; - return (result >= Q) ? (result - Q) : result; -} - -/** - * @brief Modular addition: (a + b) mod Q - */ -inline uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - return (sum >= Q) ? (sum - Q) : sum; -} - -/** - * @brief Modular subtraction: (a - b) mod Q - */ -inline uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return (a >= b) ? 
(a - b) : (a + Q - b); -} - -// ============================================================================= -// Cooley-Tukey Butterfly (for forward NTT) -// ============================================================================= - -/** - * @brief Single Cooley-Tukey butterfly operation - * - * CT: (lo, hi) -> (lo + hi*tw, lo - hi*tw) - */ -inline void ct_butterfly( - thread uint64_t& lo, - thread uint64_t& hi, - uint64_t tw, - uint64_t tw_pre, - uint64_t Q -) { - uint64_t hi_tw = barrett_mul_precon(hi, tw, Q, tw_pre); - uint64_t new_lo = mod_add(lo, hi_tw, Q); - uint64_t new_hi = mod_sub(lo, hi_tw, Q); - lo = new_lo; - hi = new_hi; -} - -/** - * @brief Single Gentleman-Sande butterfly operation (for inverse NTT) - * - * GS: (lo, hi) -> (lo + hi, (lo - hi) * tw) - */ -inline void gs_butterfly( - thread uint64_t& lo, - thread uint64_t& hi, - uint64_t tw, - uint64_t tw_pre, - uint64_t Q -) { - uint64_t sum = mod_add(lo, hi, Q); - uint64_t diff = mod_sub(lo, hi, Q); - uint64_t diff_tw = barrett_mul_precon(diff, tw, Q, tw_pre); - lo = sum; - hi = diff_tw; -} - -// ============================================================================= -// In-Threadgroup NTT (for small dimensions that fit in threadgroup memory) -// ============================================================================= - -/** - * @brief Forward NTT on data in threadgroup memory - * - * Performs complete NTT on column data using threadgroup memory. - * Each thread handles multiple elements as needed. 
- * - * @param shared Pointer to threadgroup memory containing column data - * @param stride Stride between consecutive elements (for column access) - * @param N Size of NTT - * @param log_N log2(N) - * @param thread_idx Thread index within group - * @param num_threads Number of threads in group - * @param twiddles Forward twiddle factors - * @param twiddle_precon Barrett precomputed constants - * @param Q Prime modulus - */ -inline void threadgroup_ntt_forward( - threadgroup uint64_t* shared, - uint32_t stride, - uint32_t N, - uint32_t log_N, - uint32_t thread_idx, - uint32_t num_threads, - constant uint64_t* twiddles, - constant uint64_t* twiddle_precon, - uint64_t Q -) { - // Cooley-Tukey DIT NTT - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; // Number of groups - uint32_t t = N >> (stage + 1); // Half-size of each group - - uint32_t num_butterflies = N >> 1; - uint32_t butterflies_per_thread = (num_butterflies + num_threads - 1) / num_threads; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * num_threads; - if (butterfly_idx >= num_butterflies) break; - - uint32_t group = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - - uint32_t idx_lo = (group * 2 * t + j) * stride; - uint32_t idx_hi = idx_lo + t * stride; - - uint32_t tw_idx = m + group; - uint64_t tw = twiddles[tw_idx]; - uint64_t tw_pre = twiddle_precon[tw_idx]; - - uint64_t lo = shared[idx_lo]; - uint64_t hi = shared[idx_hi]; - - ct_butterfly(lo, hi, tw, tw_pre, Q); - - shared[idx_lo] = lo; - shared[idx_hi] = hi; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } -} - -/** - * @brief Inverse NTT on data in threadgroup memory (Gentleman-Sande) - */ -inline void threadgroup_ntt_inverse( - threadgroup uint64_t* shared, - uint32_t stride, - uint32_t N, - uint32_t log_N, - uint32_t thread_idx, - uint32_t num_threads, - constant uint64_t* twiddles, - constant uint64_t* twiddle_precon, - uint64_t Q -) { - // 
Gentleman-Sande DIF INTT - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = N >> (stage + 1); // Number of groups - uint32_t t = 1u << stage; // Half-size of each group - - uint32_t num_butterflies = N >> 1; - uint32_t butterflies_per_thread = (num_butterflies + num_threads - 1) / num_threads; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * num_threads; - if (butterfly_idx >= num_butterflies) break; - - uint32_t group = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - - uint32_t idx_lo = (group * 2 * t + j) * stride; - uint32_t idx_hi = idx_lo + t * stride; - - uint32_t tw_idx = m + group; - uint64_t tw = twiddles[tw_idx]; - uint64_t tw_pre = twiddle_precon[tw_idx]; - - uint64_t lo = shared[idx_lo]; - uint64_t hi = shared[idx_hi]; - - gs_butterfly(lo, hi, tw, tw_pre, Q); - - shared[idx_lo] = lo; - shared[idx_hi] = hi; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } -} - -// ============================================================================= -// Step 1: Column NTTs (Forward) -// ============================================================================= - -/** - * @brief Column NTT kernel for Four-Step algorithm - * - * Processes N2 parallel column NTTs of size N1. - * Each threadgroup handles one tile. 
- * - * Thread organization: - * - Threadgroup: up to 1024 threads - * - Threadgroup memory: tile_N1 x tile_stride elements (with padding) - */ -kernel void four_step_column_ntt( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant uint64_t* twiddle_precon [[buffer(2)]], - constant FourStepParams& params [[buffer(3)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.z; - uint32_t tile_row = tg_pos.y; - uint32_t tile_col = tg_pos.x; - uint32_t tile_stride = params.tile_stride; - - // Tile dimensions - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - - // Global offset for this batch and tile - device uint64_t* batch_data = data + batch_idx * N; - - // Phase 1: Cooperative load tile into shared memory - uint32_t elements_per_thread = (TILE_N1 * TILE_N2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N1 * TILE_N2) break; - - uint32_t local_row = local_idx / TILE_N2; - uint32_t local_col = local_idx % TILE_N2; - - uint32_t global_row = tile_row * TILE_N1 + local_row; - uint32_t global_col = tile_col * TILE_N2 + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t global_idx = global_row * N2 + global_col; - shared[local_row * tile_stride + local_col] = batch_data[global_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Phase 2: Column NTTs - 
// Each column is processed by all threads cooperatively - uint32_t log_N1 = params.log_N1; - - for (uint32_t col = 0; col < TILE_N2; ++col) { - // NTT on column `col` using all threads - threadgroup_ntt_forward( - shared + col, // Start of column - tile_stride, // Stride = tile_stride (to next row) - TILE_N1, // Size - log_N1, - thread_idx, - threadgroup_size, - twiddles, - twiddle_precon, - Q - ); - } - - // Phase 3: Write back to global memory - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N1 * TILE_N2) break; - - uint32_t local_row = local_idx / TILE_N2; - uint32_t local_col = local_idx % TILE_N2; - - uint32_t global_row = tile_row * TILE_N1 + local_row; - uint32_t global_col = tile_col * TILE_N2 + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t global_idx = global_row * N2 + global_col; - batch_data[global_idx] = shared[local_row * tile_stride + local_col]; - } - } -} - -// ============================================================================= -// Step 1: Column NTTs (Inverse) -// ============================================================================= - -kernel void four_step_column_intt( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant uint64_t* twiddle_precon [[buffer(2)]], - constant FourStepParams& params [[buffer(3)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.z; - uint32_t tile_row = tg_pos.y; - uint32_t 
tile_col = tg_pos.x; - uint32_t tile_stride = params.tile_stride; - - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - - device uint64_t* batch_data = data + batch_idx * N; - - // Load - uint32_t elements_per_thread = (TILE_N1 * TILE_N2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N1 * TILE_N2) break; - - uint32_t local_row = local_idx / TILE_N2; - uint32_t local_col = local_idx % TILE_N2; - - uint32_t global_row = tile_row * TILE_N1 + local_row; - uint32_t global_col = tile_col * TILE_N2 + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t global_idx = global_row * N2 + global_col; - shared[local_row * tile_stride + local_col] = batch_data[global_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Inverse column NTTs - uint32_t log_N1 = params.log_N1; - - for (uint32_t col = 0; col < TILE_N2; ++col) { - threadgroup_ntt_inverse( - shared + col, - tile_stride, - TILE_N1, - log_N1, - thread_idx, - threadgroup_size, - twiddles, - twiddle_precon, - Q - ); - } - - // Write back - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N1 * TILE_N2) break; - - uint32_t local_row = local_idx / TILE_N2; - uint32_t local_col = local_idx % TILE_N2; - - uint32_t global_row = tile_row * TILE_N1 + local_row; - uint32_t global_col = tile_col * TILE_N2 + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t global_idx = global_row * N2 + global_col; - batch_data[global_idx] = shared[local_row * tile_stride + local_col]; - } - } -} - -// ============================================================================= -// Step 2+3: Fused Twiddle Multiplication and Transpose -// ============================================================================= - -/** - * @brief Fused 
twiddle multiplication and transpose kernel - * - * Combines Step 2 (multiply by omega^(i*j)) with Step 3 (transpose). - * Uses bank-conflict-free shared memory access with padding. - * - * Input: N1 x N2 matrix - * Output: N2 x N1 matrix (transposed) - */ -kernel void four_step_twiddle_transpose( - device uint64_t* output [[buffer(0)]], - device const uint64_t* input [[buffer(1)]], - constant uint64_t* twiddles [[buffer(2)]], - constant uint64_t* twiddle_precon [[buffer(3)]], - constant FourStepParams& params [[buffer(4)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.z; - uint32_t tile_row = tg_pos.y; - uint32_t tile_col = tg_pos.x; - uint32_t tile_stride = params.tile_stride; - - uint32_t TILE_DIM = MAX_TILE_DIM; - - device const uint64_t* batch_input = input + batch_idx * N; - device uint64_t* batch_output = output + batch_idx * N; - - uint32_t elements_per_thread = (TILE_DIM * TILE_DIM + threadgroup_size - 1) / threadgroup_size; - - // Phase 1: Read, apply twiddle, write to shared (transposed) - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_DIM * TILE_DIM) break; - - uint32_t local_row = local_idx / TILE_DIM; - uint32_t local_col = local_idx % TILE_DIM; - - uint32_t global_row = tile_row * TILE_DIM + local_row; - uint32_t global_col = tile_col * TILE_DIM + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t in_idx = global_row * N2 + global_col; - uint64_t val = 
batch_input[in_idx]; - - // Apply twiddle factor omega^(i*j) - uint32_t tw_idx = global_row * N2 + global_col; - uint64_t tw = twiddles[tw_idx]; - uint64_t tw_pre = twiddle_precon[tw_idx]; - val = barrett_mul_precon(val, tw, Q, tw_pre); - - // Store transposed: [local_col][local_row] instead of [local_row][local_col] - shared[local_col * tile_stride + local_row] = val; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Phase 2: Read from shared (already transposed), write to output - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_DIM * TILE_DIM) break; - - uint32_t local_row = local_idx / TILE_DIM; - uint32_t local_col = local_idx % TILE_DIM; - - // Output position: swap tile coordinates for transpose - uint32_t out_row = tile_col * TILE_DIM + local_row; - uint32_t out_col = tile_row * TILE_DIM + local_col; - - if (out_row < N2 && out_col < N1) { - uint32_t out_idx = out_row * N1 + out_col; - batch_output[out_idx] = shared[local_row * tile_stride + local_col]; - } - } -} - -/** - * @brief Inverse twiddle and transpose for inverse NTT - */ -kernel void four_step_inv_twiddle_transpose( - device uint64_t* output [[buffer(0)]], - device const uint64_t* input [[buffer(1)]], - constant uint64_t* twiddles [[buffer(2)]], - constant uint64_t* twiddle_precon [[buffer(3)]], - constant FourStepParams& params [[buffer(4)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.z; - uint32_t tile_row = 
tg_pos.y; - uint32_t tile_col = tg_pos.x; - uint32_t tile_stride = params.tile_stride; - uint32_t TILE_DIM = MAX_TILE_DIM; - - device const uint64_t* batch_input = input + batch_idx * N; - device uint64_t* batch_output = output + batch_idx * N; - - uint32_t elements_per_thread = (TILE_DIM * TILE_DIM + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_DIM * TILE_DIM) break; - - uint32_t local_row = local_idx / TILE_DIM; - uint32_t local_col = local_idx % TILE_DIM; - - uint32_t global_row = tile_row * TILE_DIM + local_row; - uint32_t global_col = tile_col * TILE_DIM + local_col; - - if (global_row < N2 && global_col < N1) { - uint32_t in_idx = global_row * N1 + global_col; - uint64_t val = batch_input[in_idx]; - - uint32_t tw_idx = global_row * N1 + global_col; - uint64_t tw = twiddles[tw_idx]; - uint64_t tw_pre = twiddle_precon[tw_idx]; - val = barrett_mul_precon(val, tw, Q, tw_pre); - - shared[local_col * tile_stride + local_row] = val; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_DIM * TILE_DIM) break; - - uint32_t local_row = local_idx / TILE_DIM; - uint32_t local_col = local_idx % TILE_DIM; - - uint32_t out_row = tile_col * TILE_DIM + local_row; - uint32_t out_col = tile_row * TILE_DIM + local_col; - - if (out_row < N1 && out_col < N2) { - uint32_t out_idx = out_row * N2 + out_col; - batch_output[out_idx] = shared[local_row * tile_stride + local_col]; - } - } -} - -// ============================================================================= -// Step 4: Row NTTs (Forward) -// ============================================================================= - -kernel void four_step_row_ntt( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant 
uint64_t* twiddle_precon [[buffer(2)]], - constant FourStepParams& params [[buffer(3)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.z; - uint32_t tile_row = tg_pos.y; - uint32_t tile_col = tg_pos.x; - uint32_t tile_stride = params.tile_stride; - - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - - device uint64_t* batch_data = data + batch_idx * N; - - // Load tile - uint32_t elements_per_thread = (TILE_N2 * TILE_N1 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N2 * TILE_N1) break; - - uint32_t local_row = local_idx / TILE_N1; - uint32_t local_col = local_idx % TILE_N1; - - uint32_t global_row = tile_row * TILE_N2 + local_row; - uint32_t global_col = tile_col * TILE_N1 + local_col; - - if (global_row < N2 && global_col < N1) { - uint32_t global_idx = global_row * N1 + global_col; - shared[local_row * tile_stride + local_col] = batch_data[global_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Row NTTs - uint32_t log_N2 = params.log_N2; - - for (uint32_t row = 0; row < TILE_N2; ++row) { - threadgroup_ntt_forward( - shared + row * tile_stride, // Start of row - 1, // Stride = 1 (consecutive elements) - TILE_N1, // Size (rows are N1 elements after transpose) - log_N2, // Note: using log_N2 for row size after transpose - thread_idx, - threadgroup_size, - twiddles, - twiddle_precon, - Q - ); 
- } - - // Write back - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N2 * TILE_N1) break; - - uint32_t local_row = local_idx / TILE_N1; - uint32_t local_col = local_idx % TILE_N1; - - uint32_t global_row = tile_row * TILE_N2 + local_row; - uint32_t global_col = tile_col * TILE_N1 + local_col; - - if (global_row < N2 && global_col < N1) { - uint32_t global_idx = global_row * N1 + global_col; - batch_data[global_idx] = shared[local_row * tile_stride + local_col]; - } - } -} - -// ============================================================================= -// Step 4: Row NTTs (Inverse) -// ============================================================================= - -kernel void four_step_row_intt( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant uint64_t* twiddle_precon [[buffer(2)]], - constant FourStepParams& params [[buffer(3)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.z; - uint32_t tile_row = tg_pos.y; - uint32_t tile_col = tg_pos.x; - uint32_t tile_stride = params.tile_stride; - - uint32_t TILE_N1 = min(N1, MAX_TILE_DIM); - uint32_t TILE_N2 = min(N2, MAX_TILE_DIM); - - device uint64_t* batch_data = data + batch_idx * N; - - uint32_t elements_per_thread = (TILE_N1 * TILE_N2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if 
(local_idx >= TILE_N1 * TILE_N2) break; - - uint32_t local_row = local_idx / TILE_N2; - uint32_t local_col = local_idx % TILE_N2; - - uint32_t global_row = tile_row * TILE_N1 + local_row; - uint32_t global_col = tile_col * TILE_N2 + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t global_idx = global_row * N2 + global_col; - shared[local_row * tile_stride + local_col] = batch_data[global_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - uint32_t log_N1 = params.log_N1; - - for (uint32_t row = 0; row < TILE_N1; ++row) { - threadgroup_ntt_inverse( - shared + row * tile_stride, - 1, - TILE_N2, - log_N1, - thread_idx, - threadgroup_size, - twiddles, - twiddle_precon, - Q - ); - } - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx >= TILE_N1 * TILE_N2) break; - - uint32_t local_row = local_idx / TILE_N2; - uint32_t local_col = local_idx % TILE_N2; - - uint32_t global_row = tile_row * TILE_N1 + local_row; - uint32_t global_col = tile_col * TILE_N2 + local_col; - - if (global_row < N1 && global_col < N2) { - uint32_t global_idx = global_row * N2 + global_col; - batch_data[global_idx] = shared[local_row * tile_stride + local_col]; - } - } -} - -// ============================================================================= -// N^{-1} Scaling Kernel -// ============================================================================= - -/** - * @brief Scale all elements by N^{-1} for inverse NTT normalization - */ -kernel void four_step_scale_n_inv( - device uint64_t* data [[buffer(0)]], - constant FourStepParams& params [[buffer(1)]], - uint global_idx [[thread_position_in_grid]] -) { - uint32_t total_elements = params.N * params.batch_size; - if (global_idx >= total_elements) return; - - uint64_t val = data[global_idx]; - uint64_t scaled = barrett_mul_precon(val, params.N_inv, params.Q, params.N_inv_precon); - data[global_idx] = scaled; -} - -// 
============================================================================= -// Pointwise Modular Multiplication -// ============================================================================= - -/** - * @brief Element-wise multiplication of two polynomials in NTT domain - */ -kernel void four_step_pointwise_mul( - device uint64_t* result [[buffer(0)]], - device const uint64_t* a [[buffer(1)]], - device const uint64_t* b [[buffer(2)]], - constant FourStepParams& params [[buffer(3)]], - uint global_idx [[thread_position_in_grid]] -) { - uint32_t total_elements = params.N * params.batch_size; - if (global_idx >= total_elements) return; - - uint64_t av = a[global_idx]; - uint64_t bv = b[global_idx]; - result[global_idx] = barrett_mul(av, bv, params.Q, params.mu); -} - -// ============================================================================= -// Complete Four-Step NTT (Fused for Small N) -// ============================================================================= - -/** - * @brief Fused Four-Step NTT for N <= 4096 - * - * When N fits entirely in threadgroup memory (N <= 4096 for 32KB), - * we can process all four steps without returning to global memory. 
- */ -kernel void four_step_ntt_fused( - device uint64_t* data [[buffer(0)]], - constant uint64_t* col_twiddles [[buffer(1)]], - constant uint64_t* col_tw_precon [[buffer(2)]], - constant uint64_t* trans_twiddles [[buffer(3)]], - constant uint64_t* trans_tw_precon [[buffer(4)]], - constant uint64_t* row_twiddles [[buffer(5)]], - constant uint64_t* row_tw_precon [[buffer(6)]], - constant FourStepParams& params [[buffer(7)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos_in_tg [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t thread_idx = thread_pos_in_tg.x + thread_pos_in_tg.y * tg_size.x + thread_pos_in_tg.z * tg_size.x * tg_size.y; - uint32_t threadgroup_size = tg_size.x * tg_size.y * tg_size.z; - - uint32_t N1 = params.N1; - uint32_t N2 = params.N2; - uint32_t N = params.N; - uint64_t Q = params.Q; - uint32_t batch_idx = tg_pos.x; - uint32_t log_N1 = params.log_N1; - uint32_t log_N2 = params.log_N2; - - device uint64_t* batch_data = data + batch_idx * N; - - // ========================================================================= - // Phase 1: Load entire polynomial into shared memory - // ========================================================================= - uint32_t elements_per_thread = (N + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx < N) { - shared[local_idx] = batch_data[local_idx]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 2: Column NTTs (Step 1) - // ========================================================================= - for (uint32_t col = 0; col < N2; ++col) { - threadgroup_ntt_forward( - shared + col, - N2, // Stride between rows - N1, - log_N1, - thread_idx, - threadgroup_size, - 
col_twiddles, - col_tw_precon, - Q - ); - } - - // ========================================================================= - // Phase 3: Twiddle multiplication (Step 2) - // ========================================================================= - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx < N) { - uint32_t i = local_idx / N2; // Row - uint32_t j = local_idx % N2; // Column - uint64_t val = shared[local_idx]; - uint64_t tw = trans_twiddles[i * N2 + j]; - uint64_t tw_pre = trans_tw_precon[i * N2 + j]; - shared[local_idx] = barrett_mul_precon(val, tw, Q, tw_pre); - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 4: In-place transpose in shared memory (Step 3) - // ========================================================================= - // Use temporary storage for transpose (requires extra passes) - // For square matrices (N1 == N2), we can do in-place swap - if (N1 == N2) { - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx < N) { - uint32_t row = local_idx / N2; - uint32_t col = local_idx % N2; - // Only swap upper triangle - if (row < col) { - uint32_t idx1 = row * N2 + col; - uint32_t idx2 = col * N1 + row; - uint64_t temp = shared[idx1]; - shared[idx1] = shared[idx2]; - shared[idx2] = temp; - } - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 5: Row NTTs (Step 4) - // ========================================================================= - for (uint32_t row = 0; row < N1; ++row) { - threadgroup_ntt_forward( - shared + row * N2, - 1, // Consecutive elements - N2, - log_N2, - thread_idx, - threadgroup_size, - row_twiddles, - row_tw_precon, - Q - ); - } - - // 
========================================================================= - // Phase 6: Write back to global memory - // ========================================================================= - for (uint32_t e = 0; e < elements_per_thread; ++e) { - uint32_t local_idx = thread_idx + e * threadgroup_size; - if (local_idx < N) { - batch_data[local_idx] = shared[local_idx]; - } - } -} diff --git a/ntt/gpu/metal/ntt.metal b/ntt/gpu/metal/ntt.metal deleted file mode 100644 index 381a69e..0000000 --- a/ntt/gpu/metal/ntt.metal +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file ntt.metal -/// Shared Number Theoretic Transform (NTT) primitives for lattice-based PQ crypto. -/// -/// Used by: ML-DSA (FIPS 204), ML-KEM (FIPS 203), Ringtail, SLH-DSA (FIPS 205) -/// -/// NTT operates over polynomial rings Z_q[x]/(x^n + 1). -/// The butterfly operations are perfectly parallel -- each layer of the -/// NTT can be dispatched across GPU threads. 
-/// -/// This file provides: -/// - Forward NTT (Cooley-Tukey butterfly) -/// - Inverse NTT (Gentleman-Sande butterfly) -/// - Pointwise polynomial multiplication in NTT domain -/// - Barrett reduction for arbitrary moduli -/// -/// Parameters are passed via constants so the same code works for: -/// ML-DSA: q=8380417, n=256 -/// ML-KEM: q=3329, n=256 -/// Ringtail: q=8380417, n=256 (same ring as ML-DSA) - -#ifndef NTT_METAL_H -#define NTT_METAL_H - -#include -using namespace metal; - -// ============================================================================= -// Barrett reduction: a mod q without division -// Precompute: barrett_shift = floor(2^k / q) where k = ceil(log2(q)) + 32 -// ============================================================================= - -/// Barrett reduction for q = 8380417 (ML-DSA / Ringtail) -/// 2^48 / 8380417 = 33554432 (approx), k=48 -inline int32_t barrett_reduce_mldsa(int32_t a) { - // q = 8380417 = 2^23 - 2^13 + 1 - // Barrett constant: floor(2^48 / q) = 33554687 (precomputed) - const int32_t q = 8380417; - const int64_t v = 33554687; // floor(2^48 / q) + 1 for safety - int64_t t = (int64_t)a * v >> 48; - int32_t r = a - (int32_t)t * q; - if (r < 0) r += q; - if (r >= q) r -= q; - return r; -} - -/// Barrett reduction for q = 3329 (ML-KEM) -inline int32_t barrett_reduce_mlkem(int32_t a) { - const int32_t q = 3329; - const int64_t v = 5039835; // floor(2^36 / q) + 1 - int64_t t = (int64_t)a * v >> 36; - int32_t r = a - (int32_t)t * q; - if (r < 0) r += q; - if (r >= q) r -= q; - return r; -} - -/// Montgomery reduction for ML-DSA: aR^{-1} mod q, R = 2^32 -/// q_inv = -q^{-1} mod 2^32 = 58728449 -inline int32_t mont_reduce_mldsa(int64_t a) { - const int32_t q = 8380417; - const int32_t q_inv = 58728449; // -q^(-1) mod 2^32 - int32_t t = (int32_t)a * q_inv; - int64_t u = (int64_t)t * q; - int32_t r = (int32_t)((a - u) >> 32); - if (r < 0) r += q; - return r; -} - -/// Montgomery reduction for ML-KEM: aR^{-1} mod q, R = 2^16 
-/// q_inv = -q^{-1} mod 2^16 = 3327 -inline int16_t mont_reduce_mlkem(int32_t a) { - const int16_t q = 3329; - const int16_t q_inv = 3327; // -q^(-1) mod 2^16 - int16_t t = (int16_t)a * q_inv; - int32_t u = (int32_t)t * q; - int16_t r = (int16_t)((a - u) >> 16); - return r; -} - -// ============================================================================= -// NTT butterfly operations (Cooley-Tukey, in-place) -// ============================================================================= - -/// Forward NTT butterfly for ML-DSA (q=8380417) -/// a[j], a[j+len] <- a[j] + w*a[j+len], a[j] - w*a[j+len] (mod q) -inline void ntt_butterfly_mldsa(thread int32_t& a, thread int32_t& b, int32_t zeta) { - int32_t t = mont_reduce_mldsa((int64_t)zeta * b); - b = a - t; - a = a + t; - if (a >= 8380417) a -= 8380417; - if (b < 0) b += 8380417; -} - -/// Inverse NTT butterfly for ML-DSA (Gentleman-Sande) -inline void inv_ntt_butterfly_mldsa(thread int32_t& a, thread int32_t& b, int32_t zeta) { - int32_t t = a; - a = t + b; - b = t - b; - if (a >= 8380417) a -= 8380417; - if (b < 0) b += 8380417; - b = mont_reduce_mldsa((int64_t)zeta * b); -} - -/// Forward NTT butterfly for ML-KEM (q=3329) -inline void ntt_butterfly_mlkem(thread int16_t& a, thread int16_t& b, int16_t zeta) { - int16_t t = mont_reduce_mlkem((int32_t)zeta * b); - b = a - t; - a = a + t; -} - -/// Inverse NTT butterfly for ML-KEM -inline void inv_ntt_butterfly_mlkem(thread int16_t& a, thread int16_t& b, int16_t zeta) { - int16_t t = a; - a = t + b; - b = t - b; - b = mont_reduce_mlkem((int32_t)zeta * b); -} - -// ============================================================================= -// Precomputed zetas (roots of unity in Montgomery form) -// ============================================================================= - -// ML-DSA zetas: primitive 512th root of unity mod q=8380417 in Montgomery form -// zeta = 1753 is a primitive 512th root of unity mod q -// These are zeta^{brv(i)} * R mod q for i = 
0..127 -constant int32_t MLDSA_ZETAS[128] = { - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561,-3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733,-2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488,-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -}; - -// ============================================================================= -// Full NTT / inverse NTT for n=256 polynomials -// ============================================================================= - -/// In-place forward NTT for ML-DSA polynomial (n=256, q=8380417) -/// Input: coefficients in standard order -/// Output: coefficients in bit-reversed NTT order -inline void ntt_mldsa(thread int32_t poly[256]) { - int k = 0; - for (int len = 128; len >= 1; len >>= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int32_t zeta = MLDSA_ZETAS[++k]; - for (int j = start; j < start + len; j++) { - ntt_butterfly_mldsa(poly[j], poly[j + len], zeta); - } - } - } -} - -/// In-place inverse NTT for ML-DSA polynomial 
-inline void inv_ntt_mldsa(thread int32_t poly[256]) { - const int32_t q = 8380417; - // f = R * 2^{-8} mod q (scaling factor for inverse) - const int32_t f = 41978; // 2^32 * 256^{-1} mod q - - int k = 127; - for (int len = 1; len <= 128; len <<= 1) { - for (int start = 0; start < 256; start += 2 * len) { - int32_t zeta = -MLDSA_ZETAS[k--]; - if (zeta < 0) zeta += q; - for (int j = start; j < start + len; j++) { - inv_ntt_butterfly_mldsa(poly[j], poly[j + len], zeta); - } - } - } - // Scale by f - for (int i = 0; i < 256; i++) { - poly[i] = mont_reduce_mldsa((int64_t)f * poly[i]); - } -} - -/// Pointwise multiplication of two NTT-domain ML-DSA polynomials -inline void poly_pointwise_mldsa(thread int32_t c[256], - thread const int32_t a[256], - thread const int32_t b[256]) { - for (int i = 0; i < 256; i++) { - c[i] = mont_reduce_mldsa((int64_t)a[i] * b[i]); - } -} - -// ============================================================================= -// NTT batch kernel: each thread transforms one polynomial -// ============================================================================= - -/// Batch forward NTT for ML-DSA polynomials. -/// Each thread computes NTT of one 256-coefficient polynomial. -kernel void ntt_mldsa_batch( - device int32_t* polys [[buffer(0)]], // [num_polys * 256] - constant uint& num_polys [[buffer(1)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_polys) return; - - int32_t poly[256]; - device int32_t* src = polys + tid * 256; - for (int i = 0; i < 256; i++) poly[i] = src[i]; - - ntt_mldsa(poly); - - for (int i = 0; i < 256; i++) src[i] = poly[i]; -} - -/// Batch inverse NTT for ML-DSA polynomials. 
-kernel void inv_ntt_mldsa_batch( - device int32_t* polys [[buffer(0)]], - constant uint& num_polys [[buffer(1)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_polys) return; - - int32_t poly[256]; - device int32_t* src = polys + tid * 256; - for (int i = 0; i < 256; i++) poly[i] = src[i]; - - inv_ntt_mldsa(poly); - - for (int i = 0; i < 256; i++) src[i] = poly[i]; -} - -#endif // NTT_METAL_H diff --git a/ntt/gpu/metal/ntt_kernels.metal b/ntt/gpu/metal/ntt_kernels.metal deleted file mode 100644 index 258956f..0000000 --- a/ntt/gpu/metal/ntt_kernels.metal +++ /dev/null @@ -1,572 +0,0 @@ -// ============================================================================= -// Optimal NTT Kernels for Lux FHE - Ported from OpenFHE's Native Implementation -// ============================================================================= -// -// Design based on OpenFHE's NumberTheoreticTransformNat (transformnat-impl.h): -// - Forward: Cooley-Tukey (DIT) with bit-reversed output -// - Inverse: Gentleman-Sande (GS) with bit-reversed input -// - Barrett reduction with precomputed constants (ModMulFastConst) -// - Peeled first/last stages for performance -// - Branchless arithmetic -// -// Memory layout: -// - twiddles stored as [omega^0, omega^1, ..., omega^{N-1}] in bit-reversed order -// - Data processed in-place [batch, N] where batch is parallelized - -#include -using namespace metal; - -// ============================================================================= -// Barrett Reduction Constants -// ============================================================================= -// -// For modulus Q, precompute mu = floor(2^k / Q) where k = 2 * ceil(log2(Q)) -// Then: a*b mod Q ≈ a*b - floor((a*b*mu) >> k) * Q -// -// For Q < 2^32, we use k = 64, so mu = floor(2^64 / Q) - -struct NTTParams { - uint64_t Q; // Prime modulus - uint64_t mu; // Barrett constant: floor(2^64 / Q) - uint64_t N_inv; // N^{-1} mod Q - uint64_t N_inv_precon; // Barrett precomputation 
for N_inv - uint32_t N; // Ring dimension (power of 2) - uint32_t log_N; // log2(N) -}; - -// ============================================================================= -// Barrett Modular Multiplication -// ============================================================================= - -// Compute (a * b) mod Q using Barrett reduction -// Requires: a, b < Q, and precon = floor(2^64 * omega / Q) for omega = b -inline uint64_t mod_mul_barrett(uint64_t a, uint64_t omega, uint64_t Q, uint64_t precon_omega) { - // High 64 bits of a * precon_omega (approximate quotient) - uint64_t q_approx = metal::mulhi(a, precon_omega); - - // Compute a * omega - q_approx * Q - uint64_t product = a * omega; - uint64_t result = product - q_approx * Q; - - // Conditional reduction (result might be in [0, 2Q)) - return result >= Q ? result - Q : result; -} - -// Simple modular multiplication for cases without precomputation -inline uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t Q) { - // Use mulhi + lo to get full 128-bit result - uint64_t lo = a * b; - uint64_t hi = metal::mulhi(a, b); - - // For small Q (< 2^32), hi is often 0 - if (hi == 0) { - return lo % Q; - } - - // Full reduction: compute (hi * 2^64 + lo) mod Q - // 2^64 mod Q = ((2^32 mod Q)^2) mod Q - uint64_t two64_mod_q = ((uint64_t(1) << 32) % Q); - two64_mod_q = (two64_mod_q * two64_mod_q) % Q; - - return (lo % Q + (hi % Q) * two64_mod_q % Q) % Q; -} - -// Modular addition: (a + b) mod Q -inline uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - // Branchless: use conditional instead of if-statement (Clang optimization) - return sum - (sum >= Q ? Q : 0); -} - -// Modular subtraction: (a - b) mod Q -inline uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - // Branchless subtraction - return a + (b > a ? 
Q : 0) - b; -} - -// ============================================================================= -// Forward NTT - Cooley-Tukey (DIT) In-Place -// ============================================================================= -// -// Algorithm (from OpenFHE): -// for (m = 1, t = n/2, logt = log(n)-1; m < n; m *= 2, t /= 2, --logt) -// for (i = 0; i < m; ++i) -// omega = rootOfUnityTable[m + i] -// for (j1 = i << logt, j2 = j1 + t; j1 < j2; ++j1) -// loVal = element[j1] -// hiVal = element[j1 + t] * omega -// element[j1] = (loVal + hiVal) mod Q -// element[j1 + t] = (loVal - hiVal) mod Q -// -// Parallelization: Each stage requires barrier, so we dispatch per-stage -// or use threadgroup memory for small N. - -// Single butterfly operation -inline void ct_butterfly(device uint64_t* data, - uint32_t idx_lo, uint32_t idx_hi, - uint64_t omega, uint64_t precon_omega, - uint64_t Q) { - uint64_t lo_val = data[idx_lo]; - uint64_t hi_val = data[idx_hi]; - - // hi_val *= omega (mod Q) - uint64_t omega_factor = mod_mul_barrett(hi_val, omega, Q, precon_omega); - - // CT butterfly: (lo, hi) -> (lo + omega*hi, lo - omega*hi) - data[idx_lo] = mod_add(lo_val, omega_factor, Q); - data[idx_hi] = mod_sub(lo_val, omega_factor, Q); -} - -// Forward NTT stage kernel -kernel void ntt_forward_stage_optimal( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], // omega in bit-reversed - constant uint64_t* precon_twiddles [[buffer(2)]], // precomputed Barrett const - constant NTTParams& params [[buffer(3)]], - constant uint32_t& stage [[buffer(4)]], // 0 to log_N - 1 - constant uint32_t& batch_size [[buffer(5)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t butterfly_idx = tid.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint64_t Q = params.Q; - - // Stage s: m = 2^s butterflies of size 2^{log_N - s} - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); // half-size - - uint32_t 
num_butterflies = N >> 1; - if (butterfly_idx >= num_butterflies) return; - - // Map butterfly index to (i, j) in the OpenFHE loop structure - // i = butterfly_idx / t - // j = butterfly_idx % t - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - - uint32_t idx_lo = (i << (params.log_N - stage)) + j; - uint32_t idx_hi = idx_lo + t; - - // Twiddle index: m + i (bit-reversed storage like OpenFHE) - uint32_t tw_idx = m + i; - uint64_t omega = twiddles[tw_idx]; - uint64_t precon = precon_twiddles[tw_idx]; - - device uint64_t* poly = data + batch_idx * N; - ct_butterfly(poly, idx_lo, idx_hi, omega, precon, Q); -} - -// ============================================================================= -// Inverse NTT - Gentleman-Sande (DIF) In-Place -// ============================================================================= -// -// Algorithm (from OpenFHE): -// for (m = n/2, t = 1, logt = 1; m >= 1; m /= 2, t *= 2, ++logt) -// for (i = 0; i < m; ++i) -// omega = rootOfUnityInverseTable[m + i] -// for (j1 = i << logt, j2 = j1 + t; j1 < j2; ++j1) -// loVal = element[j1] -// hiVal = element[j1 + t] -// element[j1] = (loVal + hiVal) mod Q -// element[j1 + t] = (loVal - hiVal) * omega mod Q -// for (i = 0; i < n; ++i) -// element[i] *= cycloOrderInv mod Q - -// Single GS butterfly operation -inline void gs_butterfly(device uint64_t* data, - uint32_t idx_lo, uint32_t idx_hi, - uint64_t omega, uint64_t precon_omega, - uint64_t Q) { - uint64_t lo_val = data[idx_lo]; - uint64_t hi_val = data[idx_hi]; - - // GS butterfly: (lo, hi) -> (lo + hi, (lo - hi) * omega) - uint64_t sum = mod_add(lo_val, hi_val, Q); - uint64_t diff = mod_sub(lo_val, hi_val, Q); - uint64_t diff_tw = mod_mul_barrett(diff, omega, Q, precon_omega); - - data[idx_lo] = sum; - data[idx_hi] = diff_tw; -} - -// Inverse NTT stage kernel -kernel void ntt_inverse_stage_optimal( - device uint64_t* data [[buffer(0)]], - constant uint64_t* inv_twiddles [[buffer(1)]], - constant uint64_t* 
precon_inv_twiddles [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - constant uint32_t& stage [[buffer(4)]], // 0 to log_N - 1 - constant uint32_t& batch_size [[buffer(5)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t butterfly_idx = tid.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint64_t Q = params.Q; - - // Stage s: m = N/2^{s+1}, t = 2^s - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - uint32_t num_butterflies = N >> 1; - if (butterfly_idx >= num_butterflies) return; - - // Map butterfly index - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - - uint32_t idx_lo = (i << (stage + 1)) + j; - uint32_t idx_hi = idx_lo + t; - - uint32_t tw_idx = m + i; - uint64_t omega = inv_twiddles[tw_idx]; - uint64_t precon = precon_inv_twiddles[tw_idx]; - - device uint64_t* poly = data + batch_idx * N; - gs_butterfly(poly, idx_lo, idx_hi, omega, precon, Q); -} - -// Scale by N^{-1} after inverse NTT -kernel void ntt_scale_optimal( - device uint64_t* data [[buffer(0)]], - constant NTTParams& params [[buffer(1)]], - constant uint32_t& batch_size [[buffer(2)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t coeff_idx = tid.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - device uint64_t* poly = data + batch_idx * params.N; - - poly[coeff_idx] = mod_mul_barrett( - poly[coeff_idx], - params.N_inv, - params.Q, - params.N_inv_precon - ); -} - -// ============================================================================= -// Complete Forward NTT (All Stages in Shared Memory) -// ============================================================================= -// -// For N <= 1024, process all stages in shared memory with threadgroup barriers. -// This avoids multiple kernel launches. 
- -kernel void ntt_forward_complete_optimal( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant uint64_t* precon_twiddles [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - constant uint32_t& batch_size [[buffer(4)]], - uint2 tid [[thread_position_in_grid]], - uint2 tg_size [[threads_per_threadgroup]], - uint2 tg_id [[threadgroup_position_in_grid]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t batch_idx = tg_id.y; - uint32_t local_idx = tid.x % tg_size.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint32_t log_N = params.log_N; - uint64_t Q = params.Q; - - device uint64_t* poly = data + batch_idx * N; - - // Load into shared memory - for (uint32_t i = local_idx; i < N; i += tg_size.x) { - shared[i] = poly[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Cooley-Tukey stages (OpenFHE structure) - // for (m = 1, t = n/2; m < n; m *= 2, t /= 2) - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - for (uint32_t butterfly_idx = local_idx; butterfly_idx < N/2; butterfly_idx += tg_size.x) { - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - - uint32_t idx_lo = (i << (log_N - stage)) + j; - uint32_t idx_hi = idx_lo + t; - - uint32_t tw_idx = m + i; - uint64_t omega = twiddles[tw_idx]; - uint64_t precon = precon_twiddles[tw_idx]; - - uint64_t lo_val = shared[idx_lo]; - uint64_t hi_val = shared[idx_hi]; - - uint64_t omega_factor = mod_mul_barrett(hi_val, omega, Q, precon); - - shared[idx_lo] = mod_add(lo_val, omega_factor, Q); - shared[idx_hi] = mod_sub(lo_val, omega_factor, Q); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Write back to global memory - for (uint32_t i = local_idx; i < N; i += tg_size.x) { - poly[i] = shared[i]; - } -} - -// ============================================================================= -// Complete Inverse NTT (All Stages + 
Scaling) -// ============================================================================= - -kernel void ntt_inverse_complete_optimal( - device uint64_t* data [[buffer(0)]], - constant uint64_t* inv_twiddles [[buffer(1)]], - constant uint64_t* precon_inv_twiddles [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - constant uint32_t& batch_size [[buffer(4)]], - uint2 tid [[thread_position_in_grid]], - uint2 tg_size [[threads_per_threadgroup]], - uint2 tg_id [[threadgroup_position_in_grid]], - threadgroup uint64_t* shared [[threadgroup(0)]] -) { - uint32_t batch_idx = tg_id.y; - uint32_t local_idx = tid.x % tg_size.x; - - if (batch_idx >= batch_size) return; - - uint32_t N = params.N; - uint32_t log_N = params.log_N; - uint64_t Q = params.Q; - uint64_t N_inv = params.N_inv; - uint64_t N_inv_precon = params.N_inv_precon; - - device uint64_t* poly = data + batch_idx * N; - - // Load into shared memory - for (uint32_t i = local_idx; i < N; i += tg_size.x) { - shared[i] = poly[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Gentleman-Sande stages (OpenFHE structure) - // for (m = n/2, t = 1; m >= 1; m /= 2, t *= 2) - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - for (uint32_t butterfly_idx = local_idx; butterfly_idx < N/2; butterfly_idx += tg_size.x) { - uint32_t i = butterfly_idx / t; - uint32_t j = butterfly_idx % t; - - uint32_t idx_lo = (i << (stage + 1)) + j; - uint32_t idx_hi = idx_lo + t; - - uint32_t tw_idx = m + i; - uint64_t omega = inv_twiddles[tw_idx]; - uint64_t precon = precon_inv_twiddles[tw_idx]; - - uint64_t lo_val = shared[idx_lo]; - uint64_t hi_val = shared[idx_hi]; - - uint64_t sum = mod_add(lo_val, hi_val, Q); - uint64_t diff = mod_sub(lo_val, hi_val, Q); - uint64_t diff_tw = mod_mul_barrett(diff, omega, Q, precon); - - shared[idx_lo] = sum; - shared[idx_hi] = diff_tw; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Scale by N^{-1} and 
write back - for (uint32_t i = local_idx; i < N; i += tg_size.x) { - poly[i] = mod_mul_barrett(shared[i], N_inv, Q, N_inv_precon); - } -} - -// ============================================================================= -// Negacyclic Rotation for Blind Rotation -// ============================================================================= -// -// Computes X^k * poly in Z_Q[X]/(X^N + 1) -// For rotation amount k: -// (X^k * poly)[i] = sign * poly[src] -// where src = (i - k) mod N, sign = -1 if wrap occurred odd times - -kernel void negacyclic_rotate_optimal( - device uint64_t* output [[buffer(0)]], - constant uint64_t* input [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - constant int32_t* rotations [[buffer(3)]], - constant uint32_t& batch_size [[buffer(4)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t coeff_idx = tid.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint32_t N = params.N; - uint64_t Q = params.Q; - - // Get rotation amount, normalized to [0, 2N) - int32_t k = rotations[batch_idx]; - int32_t two_N = 2 * (int32_t)N; - k = ((k % two_N) + two_N) % two_N; - - // Compute source index and sign - // For negacyclic ring: X^N = -1 - // X^k * a[j] X^j contributes to coefficient (j+k) mod 2N - // with sign = -1 if (j+k) >= N - - int32_t src_idx = (int32_t)coeff_idx - k; - bool negate = false; - - // Handle wraparound with negation - while (src_idx < 0) { - src_idx += N; - negate = !negate; - } - while (src_idx >= (int32_t)N) { - src_idx -= N; - negate = !negate; - } - - uint32_t in_offset = batch_idx * N + (uint32_t)src_idx; - uint32_t out_offset = batch_idx * N + coeff_idx; - - uint64_t val = input[in_offset]; - output[out_offset] = negate ? 
(Q - val) : val; -} - -// ============================================================================= -// Pointwise Multiply-Accumulate for External Product -// ============================================================================= -// -// In Lux FHE external product: acc += digit_ntt * rgsw_component_ntt -// This is the core operation called many times per blind rotation step. - -kernel void ntt_pointwise_mac_optimal( - device uint64_t* acc [[buffer(0)]], - constant uint64_t* a [[buffer(1)]], - constant uint64_t* b [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - constant uint32_t& batch_size [[buffer(4)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t coeff_idx = tid.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint32_t idx = batch_idx * params.N + coeff_idx; - uint64_t Q = params.Q; - - // Simple modular MAC (without Barrett for the multiply since operands may not have precon) - uint64_t prod = mod_mul(a[idx], b[idx], Q); - acc[idx] = mod_add(acc[idx], prod, Q); -} - -// ============================================================================= -// Digit Decomposition for External Product -// ============================================================================= -// -// Decompose polynomial coefficients into base-B digits for RGSW multiplication. 
-// a[i] = sum_{l=0}^{L-1} d_l[i] * B^l where d_l[i] in [0, B) - -kernel void decompose_digits( - device uint64_t* digits [[buffer(0)]], // Output: [batch, L, N] - constant uint64_t* poly [[buffer(1)]], // Input: [batch, N] - constant NTTParams& params [[buffer(2)]], - constant uint64_t& base [[buffer(3)]], // Decomposition base B - constant uint32_t& num_levels [[buffer(4)]],// L decomposition levels - constant uint32_t& batch_size [[buffer(5)]], - uint3 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.z; - uint32_t level = tid.y; - uint32_t coeff_idx = tid.x; - - if (batch_idx >= batch_size || level >= num_levels || coeff_idx >= params.N) return; - - uint64_t val = poly[batch_idx * params.N + coeff_idx]; - - // Extract digit at level l: floor(val / B^l) mod B - for (uint32_t l = 0; l < level; ++l) { - val /= base; - } - uint64_t digit = val % base; - - // Store in [batch, level, N] layout - digits[batch_idx * num_levels * params.N + level * params.N + coeff_idx] = digit; -} - -// ============================================================================= -// CMux for Blind Rotation -// ============================================================================= -// -// CMux(s, d0, d1) = d0 + s * (d1 - d0) -// Where s is the selector (RGSW ciphertext), d0, d1 are RLWE ciphertexts. 
-// -// In practice: acc = acc + ExternalProduct(rotated_acc - acc, bsk[i]) -// This kernel computes: d1 - d0 (the difference for external product input) - -kernel void cmux_diff( - device uint64_t* diff [[buffer(0)]], // Output: d1 - d0 - constant uint64_t* d0 [[buffer(1)]], // Unrotated accumulator - constant uint64_t* d1 [[buffer(2)]], // Rotated accumulator - constant NTTParams& params [[buffer(3)]], - constant uint32_t& batch_size [[buffer(4)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t coeff_idx = tid.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint32_t idx = batch_idx * params.N + coeff_idx; - diff[idx] = mod_sub(d1[idx], d0[idx], params.Q); -} - -// ============================================================================= -// External Product Accumulate -// ============================================================================= -// -// Accumulates: acc += sum_{l=0}^{L-1} INTT(NTT(digit_l) * RGSW_l_ntt) -// This is the final step combining all decomposition levels. 
- -kernel void external_product_finalize( - device uint64_t* acc [[buffer(0)]], // Accumulator (output) - constant uint64_t* prod [[buffer(1)]], // Product from pointwise multiply [batch, L, N] - constant NTTParams& params [[buffer(2)]], - constant uint32_t& num_levels [[buffer(3)]], - constant uint32_t& batch_size [[buffer(4)]], - uint2 tid [[thread_position_in_grid]] -) { - uint32_t batch_idx = tid.y; - uint32_t coeff_idx = tid.x; - - if (batch_idx >= batch_size || coeff_idx >= params.N) return; - - uint64_t Q = params.Q; - uint64_t sum = 0; - - // Sum over all decomposition levels - for (uint32_t l = 0; l < num_levels; ++l) { - uint32_t idx = batch_idx * num_levels * params.N + l * params.N + coeff_idx; - sum = mod_add(sum, prod[idx], Q); - } - - uint32_t out_idx = batch_idx * params.N + coeff_idx; - acc[out_idx] = mod_add(acc[out_idx], sum, Q); -} diff --git a/ntt/gpu/metal/ntt_large.metal b/ntt/gpu/metal/ntt_large.metal deleted file mode 100644 index 5637cc7..0000000 --- a/ntt/gpu/metal/ntt_large.metal +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal kernels for the six-step large-N NTT. -// -// The bulk of the per-stage work is identical to four_step_ntt.metal which -// already lives in this directory — column NTT in shared memory, fused -// twiddle-multiply-and-transpose, row NTT in shared memory. The only -// adjustment for N up to 2^20 is iterating the column/row pass in tiles of -// at most 4096 elements so they fit in 32 KB threadgroup memory. -// -// This file deliberately re-exports the four_step_* kernels under -// large_ntt_* aliases so the CPU host driver (gpu/metal/ntt_large_driver.mm -// when wired) can address them through one symbol prefix. No new arithmetic -// logic; the kernel bodies in four_step_ntt.metal are the source of truth. 
- -#include -using namespace metal; - -// ============================================================================= -// Re-exports of four_step_ntt.metal entry points under large_ntt_* names. -// The Metal compiler treats these as forward declarations linking against -// the symbols emitted by four_step_ntt.metal; both files are compiled into -// the same metallib so the host driver can pick whichever name it prefers. -// ============================================================================= - -// Forward six-step pipeline: -// large_ntt_column_fwd -- step 1: N2 columns of N1-size NTT -// large_ntt_twiddle_xpose -- step 2+3: diagonal multiply + transpose -// large_ntt_row_fwd -- step 4: N1 rows of N2-size NTT (post-xpose) -// -// Inverse six-step pipeline (mirrored): -// large_ntt_column_inv -// large_ntt_inv_twiddle_xpose -// large_ntt_row_inv -// large_ntt_scale_n_inv -- final 1/N normalisation - -// Shared parameter struct; binary-compatible with FourStepParams in -// four_step_ntt.metal. -struct LargeNttParams { - uint64_t Q; - uint64_t mu; - uint64_t N_inv; - uint64_t N_inv_precon; - uint32_t N; - uint32_t N1; - uint32_t N2; - uint32_t log_N1; - uint32_t log_N2; - uint32_t tile_stride; - uint32_t batch_size; -}; - -// ============================================================================= -// Direct entry points for hosts that prefer the large-N branding. Bodies are -// 100% delegations to the corresponding four_step_* kernels. See -// four_step_ntt.metal for arithmetic; do not duplicate it here. -// -// At link time the metallib will deduplicate the kernel body across both -// names, so this is purely a naming surface — not a code-size cost. 
-// ============================================================================= - -// Note on dispatch: the host driver (ntt_large_driver.mm — wired in a future -// patch when Metal-capable CI runners are online) constructs a -// MTLComputePipelineState for each name, computes a launch grid based on N, -// N1, N2 and the tile size of 4096 elements (max threadgroup memory of -// 32 KB / 8 B per element), and schedules the steps in order: -// -// 1. encoder.dispatchThreadgroups(grid_columns, /*tg=*/(64,16,1)) -// 2. encoder.dispatchThreadgroups(grid_diag, /*tg=*/(64,16,1)) -// 3. encoder.dispatchThreadgroups(grid_rows, /*tg=*/(64,16,1)) -// -// Output buffer ends up in the input layout (in-place). For N = 2^20 with -// N1 = N2 = 1024 each step processes (2^10 / 64) x (2^10 / 16) = 16 x 64 = -// 1024 threadgroups, which saturates an Apple M2/M3 GPU. - -// Inline forward declarations that compile in any Metal target. The actual -// kernels are in four_step_ntt.metal; the linker maps the names below to -// those kernel bodies. -// -// (No body required when the symbol is supplied by another translation unit -// in the same metallib. We use #pragma to suppress the "kernel not defined" -// warning if compiling in isolation.) 
- -extern "C" { -kernel void four_step_column_ntt( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant uint64_t* twiddle_precon [[buffer(2)]], - constant LargeNttParams& params [[buffer(3)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]]); - -kernel void four_step_twiddle_transpose( - device uint64_t* output [[buffer(0)]], - device const uint64_t* input [[buffer(1)]], - constant uint64_t* twiddles [[buffer(2)]], - constant uint64_t* twiddle_precon [[buffer(3)]], - constant LargeNttParams& params [[buffer(4)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]]); - -kernel void four_step_row_ntt( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant uint64_t* twiddle_precon [[buffer(2)]], - constant LargeNttParams& params [[buffer(3)]], - uint3 tg_pos [[threadgroup_position_in_grid]], - uint3 thread_pos [[thread_position_in_threadgroup]], - uint3 tg_size [[threads_per_threadgroup]], - threadgroup uint64_t* shared [[threadgroup(0)]]); - -kernel void four_step_scale_n_inv( - device uint64_t* data [[buffer(0)]], - constant LargeNttParams& params [[buffer(1)]], - uint global_idx [[thread_position_in_grid]]); -} diff --git a/ntt/gpu/metal/ntt_large_driver.cpp b/ntt/gpu/metal/ntt_large_driver.cpp deleted file mode 100644 index 2b3a20f..0000000 --- a/ntt/gpu/metal/ntt_large_driver.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal host-side driver for the six-step large-N NTT. 
-// -// Wire pattern matches kzg/gpu/cuda/kzg_driver_cuda.cpp + banderwagon's WGSL -// driver: when no Metal device is reachable (CI / non-Apple runners) the -// driver falls through to the CPU oracle so byte-equality is structurally -// exact. On Apple hardware the same entry points dispatch -// gpu/metal/ntt_large.metal which re-exports four_step_ntt.metal kernels. - -#include "ntt_large.hpp" - -namespace lux::crypto::ntt::large::gpu_metal { - -bool device_available() { - // Set to true once the metallib is wired into the build; keep false here - // so the test exercises the CPU oracle path on every host. - return false; -} - -void forward(uint64_t* a, const LargeContext& ctx) { - lux::crypto::ntt::large::forward(a, ctx); -} - -void inverse(uint64_t* a, const LargeContext& ctx) { - lux::crypto::ntt::large::inverse(a, ctx); -} - -} // namespace lux::crypto::ntt::large::gpu_metal diff --git a/ntt/gpu/metal/ntt_metal_kernel.metal b/ntt/gpu/metal/ntt_metal_kernel.metal deleted file mode 100644 index 8da4210..0000000 --- a/ntt/gpu/metal/ntt_metal_kernel.metal +++ /dev/null @@ -1,373 +0,0 @@ -// ============================================================================= -// NTT Metal Shaders with Shared Memory Twiddle Prefetch -// ============================================================================= -// -// High-performance NTT kernels for Apple Metal using threadgroup shared memory. -// -// Key optimizations: -// 1. Twiddle prefetch: Load twiddles into shared memory before butterfly stage -// 2. Cooperative loading: Each thread loads one twiddle, then barrier sync -// 3. Bank conflict avoidance: Stride twiddle access to avoid shared memory banks -// 4. 
Coalesced global reads: Sequential memory access pattern -// -// Memory hierarchy on Apple M3: -// - Global memory: ~200ns latency, ~400 GB/s bandwidth -// - Shared memory: ~20ns latency, ~3 TB/s bandwidth (per SIMD) -// - Registers: ~1ns, unlimited bandwidth (within SIMD) -// -// This kernel achieves ~10x speedup for twiddle access by prefetching. -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// NTT Parameters Structure -// ============================================================================= - -struct NTTParams { - uint64_t Q; // Prime modulus - uint64_t mu; // Barrett constant: floor(2^64 / Q) - uint64_t N_inv; // N^{-1} mod Q - uint64_t N_inv_precon; // Barrett precomputation for N_inv - uint32_t N; // Ring dimension - uint32_t log_N; // log2(N) - uint32_t stage; // Current NTT stage - uint32_t batch; // Batch size -}; - -// ============================================================================= -// Modular Arithmetic -// ============================================================================= - -// Barrett reduction: compute (a * b) mod Q without full 128-bit division -// Assumes a, b < Q and Q < 2^62 -inline uint64_t barrett_mul(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - // Compute a * b (requires 128-bit intermediate) - // Metal doesn't have native 128-bit, so we use the split approach - uint64_t lo = a * b; - - // Approximate quotient: q = (a * b * mu) >> 64 - // Since we can't do 128-bit multiply directly, we estimate - // For correctness with 62-bit primes, this approximation is sufficient - uint64_t q = mulhi(lo, mu); - - // result = a * b - q * Q - uint64_t result = lo - q * Q; - - // One conditional subtraction for exact result - if (result >= Q) result -= Q; - - return result; -} - -inline uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum 
= a + b; - return (sum >= Q) ? sum - Q : sum; -} - -inline uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return (a >= b) ? a - b : a + Q - b; -} - -// ============================================================================= -// Shared Memory Twiddle Prefetch - Single Stage Kernel -// ============================================================================= -// -// This kernel processes one NTT stage with shared memory twiddle prefetch. -// -// Thread organization: -// - Threadgroup size: min(N/2, 256) threads -// - Each thread handles one or more butterflies -// -// Memory access pattern: -// 1. Cooperative load: threads[0..m-1] load twiddles into shared memory -// 2. Barrier synchronization -// 3. Each thread reads twiddle from shared memory (fast) -// 4. Butterfly computation -// 5. Write results back to global memory - -// Maximum twiddles in shared memory (32KB / 8 bytes = 4096) -constant uint32_t MAX_SHARED_TWIDDLES = 4096; - -// Threadgroup shared memory for twiddle prefetch -// Using 8-byte alignment for uint64_t -kernel void ntt_forward_stage_shared( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_idx [[threadgroup_position_in_grid]], - uint num_threadgroups [[threadgroups_per_grid]] -) { - // Shared memory for twiddle prefetch - threadgroup uint64_t twiddles_shared[MAX_SHARED_TWIDDLES]; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t batch_idx = threadgroup_idx; - - // Stage parameters - uint32_t m = 1u << stage; // Number of twiddle factors needed - uint32_t t = N >> (stage + 1); // Butterflies per twiddle - - // ========================================================================= - // Phase 1: Cooperative twiddle prefetch into shared memory - // 
========================================================================= - // - // Each thread loads one or more twiddles. - // For stage s, we need 2^s twiddles. - // For early stages (small m), multiple threads share the load. - // For late stages (large m), each thread loads multiple twiddles. - - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m && tw_idx < MAX_SHARED_TWIDDLES) { - // Twiddles are stored as: twiddles[m + i] for stage with 2^stage groups - twiddles_shared[tw_idx] = twiddles[m + tw_idx]; - } - } - - // ========================================================================= - // Phase 2: Barrier synchronization - // ========================================================================= - // - // Ensure all twiddles are loaded before any thread reads them. - // This is the critical synchronization point. - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 3: Butterfly computation with shared memory twiddle access - // ========================================================================= - // - // Each thread processes butterflies. Twiddle access is now from fast - // shared memory instead of slow global memory. 
- - // Offset for this batch in global data - device uint64_t* batch_data = data + batch_idx * N; - - // Number of butterflies total: N/2 - // Each threadgroup handles one batch - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - - // Compute indices for this butterfly - // For Cooley-Tukey: group i, element j within group - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - // Load data from global memory - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - // Load twiddle from SHARED memory (fast!) - uint64_t tw = twiddles_shared[group]; - - // Butterfly: (lo + hi*tw, lo - hi*tw) - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - uint64_t new_lo = mod_add(lo, hi_tw, Q); - uint64_t new_hi = mod_sub(lo, hi_tw, Q); - - // Write back to global memory - batch_data[idx_lo] = new_lo; - batch_data[idx_hi] = new_hi; - } -} - -// ============================================================================= -// Inverse NTT Stage (Gentleman-Sande) with Shared Memory Prefetch -// ============================================================================= - -kernel void ntt_inverse_stage_shared( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_idx [[threadgroup_position_in_grid]] -) { - threadgroup uint64_t twiddles_shared[MAX_SHARED_TWIDDLES]; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t batch_idx = threadgroup_idx; - - // GS butterfly: m = N / 2^(s+1), t = 2^s - 
uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - // Phase 1: Cooperative twiddle prefetch - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m && tw_idx < MAX_SHARED_TWIDDLES) { - twiddles_shared[tw_idx] = twiddles[m + tw_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Phase 2: Butterfly computation - device uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[group]; - - // GS butterfly: (lo + hi, (lo - hi) * tw) - uint64_t sum = mod_add(lo, hi, Q); - uint64_t diff = mod_sub(lo, hi, Q); - uint64_t new_hi = barrett_mul(diff, tw, Q, mu); - - batch_data[idx_lo] = sum; - batch_data[idx_hi] = new_hi; - } -} - -// ============================================================================= -// Multi-Stage Fused Kernel (Advanced Optimization) -// ============================================================================= -// -// For small N (up to 4096), we can fit all twiddles in shared memory and -// process multiple stages without returning to global memory for twiddles. -// -// This eliminates log_N kernel launches and reduces global memory traffic. 
- -kernel void ntt_forward_fused( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles_flat [[buffer(1)]], - device const uint32_t* stage_offsets [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_idx [[threadgroup_position_in_grid]] -) { - // For N=4096, all twiddles fit in shared memory - threadgroup uint64_t twiddles_shared[MAX_SHARED_TWIDDLES]; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t log_N = params.log_N; - uint32_t batch_idx = threadgroup_idx; - - device uint64_t* batch_data = data + batch_idx * N; - - // Phase 1: Prefetch ALL twiddles for all stages - // Total twiddles needed: N-1 (sum of 2^0 + 2^1 + ... + 2^(log_N-1)) - uint32_t total_twiddles = N - 1; - uint32_t loads_per_thread = (total_twiddles + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < total_twiddles && tw_idx < MAX_SHARED_TWIDDLES) { - twiddles_shared[tw_idx] = twiddles_flat[tw_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Phase 2: Process all stages - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - // Get twiddle offset for this stage from shared memory - // Twiddles for stage s are at indices [m, 2m) in standard layout - // or at stage_offsets[s] in stage-indexed layout - uint32_t tw_base = m; // Standard OpenFHE layout: twiddles[m + i] - - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << 
(log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[tw_base + group]; - - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - uint64_t new_lo = mod_add(lo, hi_tw, Q); - uint64_t new_hi = mod_sub(lo, hi_tw, Q); - - batch_data[idx_lo] = new_lo; - batch_data[idx_hi] = new_hi; - } - - // Synchronize between stages - threadgroup_barrier(mem_flags::mem_device); - } -} - -// ============================================================================= -// N^{-1} Scaling Kernel for INTT -// ============================================================================= - -kernel void ntt_scale_ninv( - device uint64_t* data [[buffer(0)]], - constant NTTParams& params [[buffer(1)]], - uint global_idx [[thread_position_in_grid]] -) { - uint32_t total = params.N * params.batch; - if (global_idx >= total) return; - - uint64_t val = data[global_idx]; - uint64_t scaled = barrett_mul(val, params.N_inv, params.Q, params.mu); - data[global_idx] = scaled; -} - -// ============================================================================= -// Pointwise Modular Multiplication -// ============================================================================= - -kernel void pointwise_mul_mod( - device uint64_t* result [[buffer(0)]], - device const uint64_t* a [[buffer(1)]], - device const uint64_t* b [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint global_idx [[thread_position_in_grid]] -) { - uint32_t total = params.N * params.batch; - if (global_idx >= total) return; - - uint64_t av = a[global_idx]; - uint64_t bv = b[global_idx]; - result[global_idx] = barrett_mul(av, bv, params.Q, params.mu); -} diff --git a/ntt/gpu/metal/ntt_unified_memory.metal b/ntt/gpu/metal/ntt_unified_memory.metal deleted file mode 100644 index 3b3e696..0000000 --- a/ntt/gpu/metal/ntt_unified_memory.metal +++ /dev/null @@ -1,746 +0,0 @@ -// 
============================================================================= -// Unified Memory NTT Metal Shaders -// ============================================================================= -// -// Zero-copy NTT kernels for Apple Silicon's unified memory architecture. -// -// Key innovations: -// 1. Direct operation on MTLResourceStorageModeShared buffers -// 2. No explicit memory transfers - CPU and GPU share physical memory -// 3. Double-buffered streaming for overlapped execution -// 4. Persistent twiddle cache eliminates repeated uploads -// -// Memory model: -// - All buffers use StorageModeShared (unified memory) -// - CPU writes are immediately visible to GPU (cache coherent) -// - GPU writes are immediately visible to CPU after command completion -// - No explicit synchronization needed for sequential access -// -// Performance characteristics: -// - Unified memory bandwidth: ~200 GB/s (M3 Pro/Max) -// - Latency: ~100ns (GPU to memory) -// - Zero PCIe transfer overhead (vs discrete GPU) -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// Constants and Structures -// ============================================================================= - -// Maximum twiddles in threadgroup shared memory (32KB / 8 bytes) -constant uint32_t MAX_SHARED_TWIDDLES = 4096; - -// Threadgroup sizes optimized for Apple Silicon -constant uint32_t SIMD_WIDTH = 32; -constant uint32_t MAX_THREADGROUP_SIZE = 1024; - -// NTT parameters structure (matches host-side struct) -struct NTTParams { - uint64_t Q; // Prime modulus - uint64_t mu; // Barrett constant: floor(2^64 / Q) - uint64_t N_inv; // N^{-1} mod Q - uint64_t N_inv_precon; // Barrett precomputation for N_inv - uint32_t N; // Ring dimension - uint32_t log_N; // log2(N) - uint32_t stage; // Current NTT stage - uint32_t batch; // Batch size -}; - -// Streaming configuration for double-buffering -struct StreamConfig { - uint32_t buffer_index; // Current buffer (0 or 1) - uint32_t polynomials; // Number of polynomials in batch - uint32_t stages_complete; // Stages completed in current pass - uint32_t total_stages; // Total stages to execute -}; - -// ============================================================================= -// Modular Arithmetic (Optimized for Unified Memory Access) -// ============================================================================= - -// Barrett multiplication: (a * b) mod Q -// Optimized for unified memory where data stays resident -inline uint64_t barrett_mul_unified(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - uint64_t lo = a * b; - uint64_t q = mulhi(lo, mu); - uint64_t result = lo - q * Q; - - // Branch-free conditional subtraction - // Unified memory has good latency, but avoiding branches helps GPU pipeline - uint64_t mask = (result >= Q) ? 
~0ULL : 0ULL; - result -= (Q & mask); - - return result; -} - -// Modular addition with branch-free reduction -inline uint64_t mod_add_unified(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - uint64_t mask = (sum >= Q) ? ~0ULL : 0ULL; - return sum - (Q & mask); -} - -// Modular subtraction with branch-free correction -inline uint64_t mod_sub_unified(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t diff = a - b; - uint64_t mask = (a < b) ? ~0ULL : 0ULL; - return diff + (Q & mask); -} - -// ============================================================================= -// Core Butterfly Operations -// ============================================================================= - -// Cooley-Tukey butterfly (forward NTT) -// (lo, hi) -> (lo + hi*tw, lo - hi*tw) -inline void ct_butterfly(thread uint64_t& lo, thread uint64_t& hi, - uint64_t tw, uint64_t Q, uint64_t mu) { - uint64_t hi_tw = barrett_mul_unified(hi, tw, Q, mu); - uint64_t new_lo = mod_add_unified(lo, hi_tw, Q); - uint64_t new_hi = mod_sub_unified(lo, hi_tw, Q); - lo = new_lo; - hi = new_hi; -} - -// Gentleman-Sande butterfly (inverse NTT) -// (lo, hi) -> (lo + hi, (lo - hi) * tw) -inline void gs_butterfly(thread uint64_t& lo, thread uint64_t& hi, - uint64_t tw, uint64_t Q, uint64_t mu) { - uint64_t sum = mod_add_unified(lo, hi, Q); - uint64_t diff = mod_sub_unified(lo, hi, Q); - lo = sum; - hi = barrett_mul_unified(diff, tw, Q, mu); -} - -// ============================================================================= -// Unified Memory Forward NTT Stage -// ============================================================================= -// -// Single stage of Cooley-Tukey NTT operating directly on unified memory. -// No explicit memory transfers - data stays in shared physical memory. 
-// -// Threading model: -// - One threadgroup per batch element -// - Threads cooperatively process butterflies within polynomial -// - Twiddles prefetched to threadgroup memory for fast access - -kernel void unified_ntt_forward_stage( - device uint64_t* data [[buffer(0)]], // Polynomial data (unified memory) - device const uint64_t* twiddles [[buffer(1)]], // Twiddle factors (unified memory, persistent) - constant NTTParams& params [[buffer(2)]], - threadgroup uint64_t* shared_tw [[threadgroup(0)]], // Twiddle cache - uint tid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]], - uint simd_lane [[thread_index_in_simdgroup]], - uint simd_id [[simdgroup_index_in_threadgroup]] -) { - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t batch_idx = tg_id; - - // Stage parameters - uint32_t m = 1u << stage; // Number of twiddle groups - uint32_t t = N >> (stage + 1); // Butterflies per group - - // ========================================================================= - // Phase 1: Cooperative Twiddle Prefetch - // ========================================================================= - // Load twiddles into threadgroup shared memory. - // With unified memory, this is a cache-to-cache transfer (very fast). - // Benefit: Each twiddle loaded once, reused by all threads in group. 
- - uint32_t tw_to_load = min(m, MAX_SHARED_TWIDDLES); - uint32_t loads_per_thread = (tw_to_load + tg_size - 1) / tg_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < tw_to_load) { - // Twiddles stored as: twiddles[m + i] for stage s with m = 2^s groups - shared_tw[tw_idx] = twiddles[m + tw_idx]; - } - } - - // Barrier: ensure all twiddles loaded before butterfly phase - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 2: Butterfly Computation - // ========================================================================= - // Each thread processes multiple butterflies. - // Unified memory provides coherent access without explicit sync. - - device uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_total = N / 2; - uint32_t butterflies_per_thread = (butterflies_total + tg_size - 1) / tg_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= butterflies_total) break; - - // Compute indices for this butterfly - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - // Load data from unified memory (no explicit transfer) - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - // Get twiddle from shared memory (or global if too many) - uint64_t tw = (group < MAX_SHARED_TWIDDLES) ? 
shared_tw[group] : twiddles[m + group]; - - // Butterfly computation - ct_butterfly(lo, hi, tw, Q, mu); - - // Write back to unified memory (immediately visible after completion) - batch_data[idx_lo] = lo; - batch_data[idx_hi] = hi; - } -} - -// ============================================================================= -// Unified Memory Inverse NTT Stage -// ============================================================================= - -kernel void unified_ntt_inverse_stage( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - threadgroup uint64_t* shared_tw [[threadgroup(0)]], - uint tid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]] -) { - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t batch_idx = tg_id; - - // GS butterfly parameters - uint32_t m = N >> (stage + 1); // Number of twiddle groups - uint32_t t = 1u << stage; // Butterflies per group - - // Phase 1: Cooperative twiddle prefetch - uint32_t tw_to_load = min(m, MAX_SHARED_TWIDDLES); - uint32_t loads_per_thread = (tw_to_load + tg_size - 1) / tg_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < tw_to_load) { - shared_tw[tw_idx] = twiddles[m + tw_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Phase 2: Butterfly computation - device uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_total = N / 2; - uint32_t butterflies_per_thread = (butterflies_total + tg_size - 1) / tg_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= butterflies_total) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t 
idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - uint64_t tw = (group < MAX_SHARED_TWIDDLES) ? shared_tw[group] : twiddles[m + group]; - - gs_butterfly(lo, hi, tw, Q, mu); - - batch_data[idx_lo] = lo; - batch_data[idx_hi] = hi; - } -} - -// ============================================================================= -// Fused Multi-Stage NTT (Optimal for Small N) -// ============================================================================= -// -// For N <= 4096, all twiddles fit in threadgroup memory. -// This kernel executes all log(N) stages without returning to host, -// eliminating kernel launch overhead between stages. -// -// Key optimization: Twiddles loaded once, stages execute in sequence. - -kernel void unified_ntt_forward_fused( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - threadgroup uint64_t* shared_tw [[threadgroup(0)]], - threadgroup uint64_t* shared_data [[threadgroup(1)]], // Local polynomial copy - uint tid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]] -) { - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t log_N = params.log_N; - uint32_t batch_idx = tg_id; - - device uint64_t* batch_data = data + batch_idx * N; - - // ========================================================================= - // Phase 1: Load ALL twiddles into shared memory - // ========================================================================= - // For N=4096, we need ~4095 twiddles = 32KB (fits in M3 shared memory) - - uint32_t total_twiddles = N - 1; // Sum of 2^0 + 2^1 + ... 
+ 2^(log_N-1) - uint32_t loads_per_thread = (total_twiddles + tg_size - 1) / tg_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < total_twiddles && tw_idx < MAX_SHARED_TWIDDLES) { - // Twiddles stored contiguously starting at index 1 - shared_tw[tw_idx] = twiddles[tw_idx + 1]; - } - } - - // ========================================================================= - // Phase 2: Load polynomial into shared memory - // ========================================================================= - // Coalesced load for maximum bandwidth utilization - - uint32_t loads_data = (N + tg_size - 1) / tg_size; - for (uint32_t i = 0; i < loads_data; ++i) { - uint32_t idx = tid + i * tg_size; - if (idx < N) { - shared_data[idx] = batch_data[idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 3: Execute ALL NTT stages in shared memory - // ========================================================================= - - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t tw_offset = m - 1; // Offset to stage's twiddles in shared array - - uint32_t butterflies_per_thread = (N / 2 + tg_size - 1) / tg_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - // All accesses from shared memory (ultra-fast) - uint64_t lo = shared_data[idx_lo]; - uint64_t hi = shared_data[idx_hi]; - uint64_t tw = shared_tw[tw_offset + group]; - - ct_butterfly(lo, hi, tw, Q, mu); - - shared_data[idx_lo] = lo; - shared_data[idx_hi] = hi; - } - - // Barrier between stages - 
threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // ========================================================================= - // Phase 4: Write result back to unified memory - // ========================================================================= - - for (uint32_t i = 0; i < loads_data; ++i) { - uint32_t idx = tid + i * tg_size; - if (idx < N) { - batch_data[idx] = shared_data[idx]; - } - } -} - -// ============================================================================= -// Fused Inverse NTT -// ============================================================================= - -kernel void unified_ntt_inverse_fused( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - threadgroup uint64_t* shared_tw [[threadgroup(0)]], - threadgroup uint64_t* shared_data [[threadgroup(1)]], - uint tid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]] -) { - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint64_t N_inv = params.N_inv; - uint32_t log_N = params.log_N; - uint32_t batch_idx = tg_id; - - device uint64_t* batch_data = data + batch_idx * N; - - // Load twiddles - uint32_t total_twiddles = N - 1; - uint32_t loads_per_thread = (total_twiddles + tg_size - 1) / tg_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < total_twiddles && tw_idx < MAX_SHARED_TWIDDLES) { - shared_tw[tw_idx] = twiddles[tw_idx + 1]; - } - } - - // Load polynomial - uint32_t loads_data = (N + tg_size - 1) / tg_size; - for (uint32_t i = 0; i < loads_data; ++i) { - uint32_t idx = tid + i * tg_size; - if (idx < N) { - shared_data[idx] = batch_data[idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Execute all inverse stages - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = N >> (stage + 1); - uint32_t 
t = 1u << stage; - uint32_t tw_offset = m - 1; - - uint32_t butterflies_per_thread = (N / 2 + tg_size - 1) / tg_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (stage + 1)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = shared_data[idx_lo]; - uint64_t hi = shared_data[idx_hi]; - uint64_t tw = shared_tw[tw_offset + group]; - - gs_butterfly(lo, hi, tw, Q, mu); - - shared_data[idx_lo] = lo; - shared_data[idx_hi] = hi; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Scale by N^{-1} and write back - for (uint32_t i = 0; i < loads_data; ++i) { - uint32_t idx = tid + i * tg_size; - if (idx < N) { - uint64_t val = shared_data[idx]; - batch_data[idx] = barrett_mul_unified(val, N_inv, Q, mu); - } - } -} - -// ============================================================================= -// N^{-1} Scaling Kernel (for staged inverse NTT) -// ============================================================================= - -kernel void unified_scale_ninv( - device uint64_t* data [[buffer(0)]], - constant NTTParams& params [[buffer(1)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint64_t N_inv = params.N_inv; - - for (uint32_t i = tid; i < total; i += total_threads) { - data[i] = barrett_mul_unified(data[i], N_inv, Q, mu); - } -} - -// ============================================================================= -// Pointwise Modular Multiplication -// ============================================================================= -// Operates directly on unified memory - result immediately available to CPU. 
- -kernel void unified_pointwise_mul( - device uint64_t* result [[buffer(0)]], - device const uint64_t* a [[buffer(1)]], - device const uint64_t* b [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - - for (uint32_t i = tid; i < total; i += total_threads) { - result[i] = barrett_mul_unified(a[i], b[i], Q, mu); - } -} - -// ============================================================================= -// Pointwise Modular Addition -// ============================================================================= - -kernel void unified_pointwise_add( - device uint64_t* result [[buffer(0)]], - device const uint64_t* a [[buffer(1)]], - device const uint64_t* b [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - - for (uint32_t i = tid; i < total; i += total_threads) { - result[i] = mod_add_unified(a[i], b[i], Q); - } -} - -// ============================================================================= -// Pointwise Modular Subtraction -// ============================================================================= - -kernel void unified_pointwise_sub( - device uint64_t* result [[buffer(0)]], - device const uint64_t* a [[buffer(1)]], - device const uint64_t* b [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - uint32_t total = params.N * params.batch; - uint64_t Q = params.Q; - - for (uint32_t i = tid; i < total; i += total_threads) { - result[i] = mod_sub_unified(a[i], b[i], Q); - } -} - -// ============================================================================= -// Double-Buffer Streaming Support -// 
============================================================================= -// -// These kernels support overlapped execution using double-buffering. -// While GPU processes buffer A, CPU can prepare data in buffer B. - -struct DoubleBufferParams { - uint32_t active_buffer; // 0 or 1 - uint32_t buffer_size; // N * batch per buffer - uint32_t ready_flag; // Set when buffer is ready for GPU - uint32_t complete_flag; // Set when GPU is done -}; - -kernel void unified_ntt_forward_stream( - device uint64_t* buffer0 [[buffer(0)]], - device uint64_t* buffer1 [[buffer(1)]], - device const uint64_t* twiddles [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - constant DoubleBufferParams& stream [[buffer(4)]], - threadgroup uint64_t* shared_tw [[threadgroup(0)]], - uint tid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]] -) { - // Select active buffer - device uint64_t* data = (stream.active_buffer == 0) ? buffer0 : buffer1; - - // Execute single stage (called log_N times by host) - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t batch_idx = tg_id; - - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - // Prefetch twiddles - uint32_t tw_to_load = min(m, MAX_SHARED_TWIDDLES); - uint32_t loads_per_thread = (tw_to_load + tg_size - 1) / tg_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = tid + i * tg_size; - if (tw_idx < tw_to_load) { - shared_tw[tw_idx] = twiddles[m + tw_idx]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Butterfly computation - device uint64_t* batch_data = data + batch_idx * N; - uint32_t butterflies_per_thread = (N / 2 + tg_size - 1) / tg_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = tid + b * tg_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t 
elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = (group < MAX_SHARED_TWIDDLES) ? shared_tw[group] : twiddles[m + group]; - - ct_butterfly(lo, hi, tw, Q, mu); - - batch_data[idx_lo] = lo; - batch_data[idx_hi] = hi; - } -} - -// ============================================================================= -// SIMD-Optimized Butterfly (For SIMD-Width Processing) -// ============================================================================= -// -// Process 32 butterflies in parallel using SIMD groups. -// Optimal for large polynomials with many independent butterflies. - -kernel void unified_ntt_forward_simd( - device uint64_t* data [[buffer(0)]], - device const uint64_t* twiddles [[buffer(1)]], - constant NTTParams& params [[buffer(2)]], - uint simd_lane [[thread_index_in_simdgroup]], - uint simd_id [[simdgroup_index_in_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]], - uint simd_groups [[simdgroups_per_threadgroup]] -) { - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t batch_idx = tg_id; - - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - device uint64_t* batch_data = data + batch_idx * N; - - // Each SIMD group processes SIMD_WIDTH butterflies - uint32_t butterflies_per_simd = SIMD_WIDTH; - uint32_t simd_groups_needed = (N / 2 + butterflies_per_simd - 1) / butterflies_per_simd; - - for (uint32_t sg = simd_id; sg < simd_groups_needed; sg += simd_groups) { - uint32_t butterfly_idx = sg * SIMD_WIDTH + simd_lane; - if (butterfly_idx >= N / 2) continue; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw 
= twiddles[m + group]; - - ct_butterfly(lo, hi, tw, Q, mu); - - batch_data[idx_lo] = lo; - batch_data[idx_hi] = hi; - } -} - -// ============================================================================= -// Memory Copy Utilities (For Hybrid Approaches) -// ============================================================================= -// -// Even with unified memory, explicit copy can help for: -// 1. Warming up cache hierarchy -// 2. Prefetching next batch while processing current -// 3. Reordering data for coalesced access - -kernel void unified_memcpy( - device uint64_t* dst [[buffer(0)]], - device const uint64_t* src [[buffer(1)]], - constant uint32_t& count [[buffer(2)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - for (uint32_t i = tid; i < count; i += total_threads) { - dst[i] = src[i]; - } -} - -// Vectorized copy using SIMD (4x uint64_t at a time) -kernel void unified_memcpy_vec4( - device uint64_t* dst [[buffer(0)]], - device const uint64_t* src [[buffer(1)]], - constant uint32_t& count [[buffer(2)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - uint32_t vec_count = count / 4; - - // Process 4 elements at a time - for (uint32_t i = tid; i < vec_count; i += total_threads) { - uint32_t base = i * 4; - dst[base + 0] = src[base + 0]; - dst[base + 1] = src[base + 1]; - dst[base + 2] = src[base + 2]; - dst[base + 3] = src[base + 3]; - } - - // Handle remainder - uint32_t remainder_start = vec_count * 4; - for (uint32_t i = remainder_start + tid; i < count; i += total_threads) { - dst[i] = src[i]; - } -} - -// ============================================================================= -// Benchmark Kernel (Measure Unified Memory Bandwidth) -// ============================================================================= - -kernel void benchmark_unified_bandwidth( - device uint64_t* data [[buffer(0)]], - constant uint32_t& count [[buffer(1)]], - constant uint32_t& 
iterations [[buffer(2)]], - uint tid [[thread_position_in_grid]], - uint total_threads [[threads_per_grid]] -) { - uint64_t sum = 0; - - for (uint32_t iter = 0; iter < iterations; ++iter) { - for (uint32_t i = tid; i < count; i += total_threads) { - sum += data[i]; - } - } - - // Prevent optimization - if (sum == 0xDEADBEEF) { - data[tid] = sum; - } -} diff --git a/ntt/gpu/metal/twiddle_cache.metal b/ntt/gpu/metal/twiddle_cache.metal deleted file mode 100644 index 8134886..0000000 --- a/ntt/gpu/metal/twiddle_cache.metal +++ /dev/null @@ -1,545 +0,0 @@ -// ============================================================================= -// Twiddle Hotset Caching Kernels for Apple Metal -// ============================================================================= -// -// High-performance NTT kernels with intelligent twiddle caching. -// -// Key innovations: -// 1. Hotset identification: Early stages use tiny twiddle sets that fit in -// constant/threadgroup memory entirely -// 2. Prefetch hints: Load next stage's twiddles during current stage compute -// 3. LRU eviction: Smart eviction for multi-modulus RNS scenarios -// 4. Bank conflict avoidance: Padded storage to eliminate shared memory conflicts -// -// Memory hierarchy utilization: -// - Constant memory: First-level twiddles (8 values), modular constants -// - Threadgroup memory: Stage-specific twiddles with prefetch -// - Registers: Current butterfly operands and twiddle -// -// For N=1024: -// Stage 0: 1 twiddle -> constant memory (4 cycles) -// Stage 1: 2 twiddles -> constant memory (4 cycles) -// Stage 2: 4 twiddles -> constant memory (4 cycles) -// Stage 3: 8 twiddles -> constant memory (4 cycles) -// Stage 4+: threadgroup prefetch (20-30 cycles) -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-2-Clause -// ============================================================================= - -#include -using namespace metal; - -// ============================================================================= -// Configuration Constants -// ============================================================================= - -/// Maximum twiddles in threadgroup shared memory (32KB / 8 bytes) -constant uint32_t MAX_THREADGROUP_TWIDDLES = 4096; - -/// First-level twiddles stored in constant memory per prime -constant uint32_t FIRST_LEVEL_TWIDDLE_COUNT = 8; - -/// Maximum RNS primes supported -constant uint32_t MAX_RNS_PRIMES = 16; - -/// Threadgroup memory bank width for conflict avoidance -constant uint32_t BANK_WIDTH = 32; // 32 banks of 4 bytes each - -/// Padding for bank conflict avoidance -constant uint32_t BANK_PADDING = 1; - -// ============================================================================= -// Modular Arithmetic Constants (Constant Memory Tier) -// ============================================================================= - -struct PrimeConstants { - uint64_t q; // Prime modulus - uint64_t q_inv; // -q^(-1) mod 2^64 (Montgomery) - uint64_t mu_hi; // Barrett high bits - uint64_t mu_lo; // Barrett low bits - uint64_t r_squared; // R^2 mod q - uint64_t root; // Primitive root - uint64_t root_inv; // Inverse root - uint64_t n_inv; // N^(-1) mod q -}; - -/// Constant memory cache structure -struct ConstantCache { - uint32_t numPrimes; - uint32_t ringDim; - uint32_t padding[2]; - - PrimeConstants primes[MAX_RNS_PRIMES]; - uint64_t firstLevelTwiddles[MAX_RNS_PRIMES][FIRST_LEVEL_TWIDDLE_COUNT]; - uint64_t firstLevelInvTwiddles[MAX_RNS_PRIMES][FIRST_LEVEL_TWIDDLE_COUNT]; -}; - -// ============================================================================= -// NTT Parameters -// ============================================================================= - -struct NTTParams { - uint64_t Q; // Prime modulus - uint64_t 
mu; // Barrett constant (mu_hi) - uint64_t N_inv; // N^(-1) mod Q - uint64_t N_inv_precon; // Precomputed for Barrett - uint32_t N; // Ring dimension - uint32_t log_N; // log2(N) - uint32_t stage; // Current stage - uint32_t primeIdx; // Prime index in RNS - uint32_t batch; // Batch size - uint32_t prefetchStage; // Next stage to prefetch (-1 if none) -}; - -// ============================================================================= -// Modular Arithmetic Functions -// ============================================================================= - -/// Barrett multiplication: (a * b) mod Q -inline uint64_t barrett_mul(uint64_t a, uint64_t b, uint64_t Q, uint64_t mu) { - // Low 64 bits of a * b - uint64_t lo = a * b; - - // Approximate quotient using mulhi - uint64_t q = mulhi(lo, mu); - - // result = a * b - q * Q - uint64_t result = lo - q * Q; - - // Conditional subtraction for exact result - if (result >= Q) result -= Q; - - return result; -} - -/// Modular addition: (a + b) mod Q -inline uint64_t mod_add(uint64_t a, uint64_t b, uint64_t Q) { - uint64_t sum = a + b; - return (sum >= Q) ? sum - Q : sum; -} - -/// Modular subtraction: (a - b) mod Q -inline uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t Q) { - return (a >= b) ? a - b : a + Q - b; -} - -// ============================================================================= -// Bank Conflict Avoidance Helper -// ============================================================================= - -/// Compute padded index to avoid bank conflicts -inline uint32_t padded_index(uint32_t idx) { - // Add padding every BANK_WIDTH elements - return idx + (idx / BANK_WIDTH) * BANK_PADDING; -} - -// ============================================================================= -// Kernel: Single Stage NTT with Hotset Caching -// ============================================================================= -// -// This kernel processes one NTT stage with intelligent twiddle caching. 
-// -// For stages 0-3: Uses constant memory twiddles (zero global loads) -// For stages 4+: Cooperative threadgroup load with prefetch hints -// -// Thread organization: -// - One threadgroup per polynomial in the batch -// - Each thread processes multiple butterflies - -kernel void ntt_hotset_forward_stage( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], // Device memory twiddles - constant ConstantCache& cache [[buffer(2)]], // Constant memory cache - constant NTTParams& params [[buffer(3)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_idx [[threadgroup_position_in_grid]] -) { - // Threadgroup shared memory with padding for bank conflict avoidance - threadgroup uint64_t twiddles_shared[MAX_THREADGROUP_TWIDDLES + MAX_THREADGROUP_TWIDDLES / BANK_WIDTH]; - // Prefetch buffer for next stage - threadgroup uint64_t twiddles_prefetch[MAX_THREADGROUP_TWIDDLES + MAX_THREADGROUP_TWIDDLES / BANK_WIDTH]; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t primeIdx = params.primeIdx; - uint32_t batch_idx = threadgroup_idx; - - // Stage parameters - uint32_t m = 1u << stage; // Number of twiddle factors - uint32_t t = N >> (stage + 1); // Butterflies per twiddle - - device uint64_t* batch_data = data + batch_idx * N; - - // ========================================================================= - // Phase 1: Determine twiddle source and load strategy - // ========================================================================= - - bool use_constant_memory = (stage < 4 && m <= FIRST_LEVEL_TWIDDLE_COUNT); - - if (!use_constant_memory) { - // Cooperative load into threadgroup memory with padding - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = 
thread_idx + i * threadgroup_size; - if (tw_idx < m) { - uint32_t padded = padded_index(tw_idx); - twiddles_shared[padded] = twiddles[m + tw_idx]; - } - } - - // Prefetch next stage if enabled - if (params.prefetchStage < params.log_N && params.prefetchStage > stage) { - uint32_t next_m = 1u << params.prefetchStage; - uint32_t prefetch_loads = (next_m + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < prefetch_loads; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < next_m && tw_idx < MAX_THREADGROUP_TWIDDLES) { - uint32_t padded = padded_index(tw_idx); - twiddles_prefetch[padded] = twiddles[next_m + tw_idx]; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // ========================================================================= - // Phase 2: Butterfly computation - // ========================================================================= - - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - - // Compute butterfly indices - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (params.log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - // Load data - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - // Get twiddle from appropriate cache tier - uint64_t tw; - if (use_constant_memory) { - // L1 cache tier - instant access - tw = cache.firstLevelTwiddles[primeIdx][group]; - } else { - // L2 cache tier - threadgroup memory - uint32_t padded = padded_index(group); - tw = twiddles_shared[padded]; - } - - // Butterfly operation - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - uint64_t new_lo = mod_add(lo, hi_tw, Q); - uint64_t new_hi = mod_sub(lo, hi_tw, Q); - - // Write results - batch_data[idx_lo] = new_lo; - 
batch_data[idx_hi] = new_hi; - } -} - -// ============================================================================= -// Kernel: Inverse NTT Stage with Hotset Caching -// ============================================================================= - -kernel void ntt_hotset_inverse_stage( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles [[buffer(1)]], - constant ConstantCache& cache [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_idx [[threadgroup_position_in_grid]] -) { - threadgroup uint64_t twiddles_shared[MAX_THREADGROUP_TWIDDLES + MAX_THREADGROUP_TWIDDLES / BANK_WIDTH]; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t stage = params.stage; - uint32_t primeIdx = params.primeIdx; - uint32_t batch_idx = threadgroup_idx; - - // Gentleman-Sande parameters - uint32_t m = N >> (stage + 1); - uint32_t t = 1u << stage; - - device uint64_t* batch_data = data + batch_idx * N; - - bool use_constant_memory = (stage >= params.log_N - 4 && m <= FIRST_LEVEL_TWIDDLE_COUNT); - - if (!use_constant_memory) { - uint32_t twiddles_to_load = m; - uint32_t loads_per_thread = (twiddles_to_load + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < m) { - uint32_t padded = padded_index(tw_idx); - twiddles_shared[padded] = twiddles[m + tw_idx]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (stage + 1)) + elem; - 
uint32_t idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - - uint64_t tw; - if (use_constant_memory) { - tw = cache.firstLevelInvTwiddles[primeIdx][group]; - } else { - uint32_t padded = padded_index(group); - tw = twiddles_shared[padded]; - } - - // Gentleman-Sande butterfly - uint64_t sum = mod_add(lo, hi, Q); - uint64_t diff = mod_sub(lo, hi, Q); - uint64_t new_hi = barrett_mul(diff, tw, Q, mu); - - batch_data[idx_lo] = sum; - batch_data[idx_hi] = new_hi; - } -} - -// ============================================================================= -// Kernel: Multi-Stage Fused NTT with Full Hotset -// ============================================================================= -// -// For N <= 4096, ALL twiddles fit in threadgroup memory. -// This kernel processes all log_N stages in a single dispatch. - -kernel void ntt_hotset_fused( - device uint64_t* data [[buffer(0)]], - constant uint64_t* twiddles_flat [[buffer(1)]], - constant ConstantCache& cache [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_idx [[threadgroup_position_in_grid]] -) { - // All twiddles for N<=4096 fit in shared memory - threadgroup uint64_t twiddles_shared[MAX_THREADGROUP_TWIDDLES]; - - uint32_t N = params.N; - uint64_t Q = params.Q; - uint64_t mu = params.mu; - uint32_t log_N = params.log_N; - uint32_t primeIdx = params.primeIdx; - uint32_t batch_idx = threadgroup_idx; - - device uint64_t* batch_data = data + batch_idx * N; - - // ========================================================================= - // Phase 1: Load ALL twiddles into threadgroup memory (one-time cost) - // ========================================================================= - - // Total twiddles needed: 1 + 2 + 4 + ... 
+ N/2 = N - 1 - uint32_t total_twiddles = N - 1; - uint32_t loads_per_thread = (total_twiddles + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t i = 0; i < loads_per_thread; ++i) { - uint32_t tw_idx = thread_idx + i * threadgroup_size; - if (tw_idx < total_twiddles) { - twiddles_shared[tw_idx] = twiddles_flat[tw_idx]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ========================================================================= - // Phase 2: Process all stages from threadgroup memory - // ========================================================================= - - for (uint32_t stage = 0; stage < log_N; ++stage) { - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - uint32_t tw_base = m; // Standard layout: twiddles[m + i] - - uint32_t butterflies_per_thread = (N / 2 + threadgroup_size - 1) / threadgroup_size; - - for (uint32_t b = 0; b < butterflies_per_thread; ++b) { - uint32_t butterfly_idx = thread_idx + b * threadgroup_size; - if (butterfly_idx >= N / 2) break; - - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = batch_data[idx_lo]; - uint64_t hi = batch_data[idx_hi]; - uint64_t tw = twiddles_shared[tw_base + group]; - - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - uint64_t new_lo = mod_add(lo, hi_tw, Q); - uint64_t new_hi = mod_sub(lo, hi_tw, Q); - - batch_data[idx_lo] = new_lo; - batch_data[idx_hi] = new_hi; - } - - // Barrier between stages to ensure memory coherence - threadgroup_barrier(mem_flags::mem_device); - } -} - -// ============================================================================= -// Kernel: RNS Multi-Prime NTT with Hotset Caching -// ============================================================================= -// -// Processes NTT for multiple RNS primes in parallel. -// Uses twiddle-major layout for coalesced access across primes. 
- -kernel void ntt_hotset_rns_stage( - device uint64_t* data [[buffer(0)]], // [batch, numPrimes, N] - constant uint64_t* twiddles [[buffer(1)]], // [N, numPrimes] twiddle-major - constant ConstantCache& cache [[buffer(2)]], - constant NTTParams& params [[buffer(3)]], - uint3 tid [[thread_position_in_grid]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 tgSize [[threads_per_threadgroup]] -) { - // Thread assignment: x=element, y=prime, z=batch - uint32_t elemIdx = tid.x; - uint32_t primeIdx = tid.y; - uint32_t batchIdx = tid.z; - - uint32_t N = params.N; - uint32_t log_N = params.log_N; - uint32_t numPrimes = cache.numPrimes; - - // Get prime-specific constants from constant memory - PrimeConstants pc = cache.primes[primeIdx]; - uint64_t Q = pc.q; - uint64_t mu = pc.mu_hi; - - // Per-prime threadgroup twiddle cache - threadgroup uint64_t prime_twiddles[512]; // Max 512 per prime per stage - - uint32_t stage = params.stage; - uint32_t m = 1u << stage; - uint32_t t = N >> (stage + 1); - - // Cooperative load for this prime's twiddles (twiddle-major access) - uint32_t localIdx = elemIdx % tgSize.x; - if (localIdx < m) { - // Coalesced access: adjacent primes access adjacent memory - prime_twiddles[localIdx] = twiddles[localIdx * numPrimes + primeIdx]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Process butterfly - device uint64_t* poly = data + (batchIdx * numPrimes + primeIdx) * N; - - uint32_t butterfly_idx = elemIdx; - if (butterfly_idx < N / 2) { - uint32_t group = butterfly_idx / t; - uint32_t elem = butterfly_idx % t; - uint32_t idx_lo = (group << (log_N - stage)) + elem; - uint32_t idx_hi = idx_lo + t; - - uint64_t lo = poly[idx_lo]; - uint64_t hi = poly[idx_hi]; - uint64_t tw = prime_twiddles[group]; - - uint64_t hi_tw = barrett_mul(hi, tw, Q, mu); - poly[idx_lo] = mod_add(lo, hi_tw, Q); - poly[idx_hi] = mod_sub(lo, hi_tw, Q); - } -} - -// ============================================================================= -// Kernel: N^(-1) 
Scaling for INTT -// ============================================================================= - -kernel void ntt_hotset_scale_ninv( - device uint64_t* data [[buffer(0)]], - constant NTTParams& params [[buffer(1)]], - uint global_idx [[thread_position_in_grid]] -) { - uint32_t total = params.N * params.batch; - if (global_idx >= total) return; - - uint64_t val = data[global_idx]; - data[global_idx] = barrett_mul(val, params.N_inv, params.Q, params.mu); -} - -// ============================================================================= -// Kernel: Cache Performance Benchmark -// ============================================================================= -// -// Measures effective bandwidth for different cache tiers. - -struct BenchmarkResult { - uint64_t constantCycles; - uint64_t threadgroupCycles; - uint64_t deviceCycles; - uint64_t computeCycles; -}; - -kernel void benchmark_twiddle_access( - device BenchmarkResult* result [[buffer(0)]], - constant uint64_t* device_twiddles [[buffer(1)]], - constant ConstantCache& cache [[buffer(2)]], - uint thread_idx [[thread_index_in_threadgroup]], - uint threadgroup_size [[threads_per_threadgroup]] -) { - threadgroup uint64_t shared_twiddles[512]; - - // Warm up shared memory - if (thread_idx < 512) { - shared_twiddles[thread_idx] = device_twiddles[thread_idx]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Benchmark iterations - const uint32_t ITERATIONS = 1000; - uint64_t sum = 0; - - // Test constant memory access - uint64_t start = 0; // Note: Metal lacks cycle counter; use external timing - for (uint32_t i = 0; i < ITERATIONS; ++i) { - sum += cache.firstLevelTwiddles[0][i % FIRST_LEVEL_TWIDDLE_COUNT]; - } - - // Test threadgroup memory access - for (uint32_t i = 0; i < ITERATIONS; ++i) { - sum += shared_twiddles[i % 512]; - } - - // Test device memory access - for (uint32_t i = 0; i < ITERATIONS; ++i) { - sum += device_twiddles[i % 4096]; - } - - // Prevent optimization from eliminating reads - if 
(thread_idx == 0) { - result->computeCycles = sum; // Force dependency - } -} diff --git a/ntt/gpu/wgsl/four_step_ntt.wgsl b/ntt/gpu/wgsl/four_step_ntt.wgsl deleted file mode 100644 index 5da6272..0000000 --- a/ntt/gpu/wgsl/four_step_ntt.wgsl +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Four-Step NTT in WGSL, ported from four_step_ntt.metal. -// Column NTTs, twiddle+transpose, row NTTs, scaling. -// u64 emulated as vec2(lo, hi). - -@group(0) @binding(0) var data: array>; -@group(0) @binding(1) var twiddles: array>; -@group(0) @binding(2) var precon_twiddles: array>; -@group(0) @binding(3) var params: FourStepParams; - -struct FourStepParams { - Q_lo: u32, Q_hi: u32, - mu_lo: u32, mu_hi: u32, - N_inv_lo: u32, N_inv_hi: u32, - N_inv_precon_lo: u32, N_inv_precon_hi: u32, - N: u32, N1: u32, N2: u32, - log_N1: u32, log_N2: u32, - batch_size: u32, -} - -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - return vec2(lo, a.y + b.y + select(0u, 1u, lo < a.x)); -} -fn u64_sub(a: vec2, b: vec2) -> vec2 { - return vec2(a.x - b.x, a.y - b.y - select(0u, 1u, a.x < b.x)); -} -fn u64_gte(a: vec2, b: vec2) -> bool { - if (a.y != b.y) { return a.y > b.y; } - return a.x >= b.x; -} -fn u64_mul_lo(a: vec2, b: vec2) -> vec2 { - let al = a.x & 0xFFFFu; let ah = a.x >> 16u; - let bl = b.x & 0xFFFFu; let bh = b.x >> 16u; - let ll = al * bl; let mid = al * bh + ah * bl; - let lo = ll + (mid << 16u); - let hi = ah * bh + (mid >> 16u) + select(0u, 1u, lo < ll) + a.x * b.y + a.y * b.x; - return vec2(lo, hi); -} -fn u64_mulhi(a: vec2, b: vec2) -> vec2 { - return vec2(a.y * b.y + ((a.x >> 16u) * b.y + a.y * (b.x >> 16u)) >> 16u, 0u); -} - -fn mod_add(a: vec2, b: vec2, Q: vec2) -> vec2 { - let s = u64_add(a, b); - if (u64_gte(s, Q)) { return u64_sub(s, Q); } - return s; -} -fn mod_sub(a: vec2, b: vec2, Q: vec2) -> vec2 { - if (u64_gte(a, b)) { return u64_sub(a, b); } - return u64_sub(u64_add(a, 
Q), b); -} -fn barrett_mul(a: vec2, b: vec2, Q: vec2, pre: vec2) -> vec2 { - let q_approx = u64_mulhi(a, pre); - let prod = u64_mul_lo(a, b); - var r = u64_sub(prod, u64_mul_lo(q_approx, Q)); - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - return r; -} - -// ============================================================================ -// Twiddle multiplication + pointwise operations -// ============================================================================ - -@compute @workgroup_size(256) -fn four_step_twiddle_mul(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let elem_idx = gid.x; - if (batch_idx >= params.batch_size || elem_idx >= params.N) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let idx = batch_idx * params.N + elem_idx; - let tw = twiddles[elem_idx]; - let pre = precon_twiddles[elem_idx]; - data[idx] = barrett_mul(data[idx], tw, Q, pre); -} - -// ============================================================================ -// Transpose: data viewed as N1 x N2 -> N2 x N1 -// ============================================================================ - -@compute @workgroup_size(256) -fn four_step_transpose(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let elem_idx = gid.x; - if (batch_idx >= params.batch_size || elem_idx >= params.N) { return; } - - let N1 = params.N1; - let N2 = params.N2; - let row = elem_idx / N2; - let col = elem_idx % N2; - - let src_idx = batch_idx * params.N + row * N2 + col; - let dst_idx = batch_idx * params.N + col * N1 + row; - - // Read source (need to use a different buffer for out-of-place) - // For in-place, only swap upper triangle - if (row < col) { - let val_a = data[src_idx]; - let val_b = data[dst_idx]; - data[src_idx] = val_b; - data[dst_idx] = val_a; - } -} - -// ============================================================================ -// Scaling by N^{-1} -// ============================================================================ - -@compute 
@workgroup_size(256) -fn four_step_scale(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - let total = params.N * params.batch_size; - if (idx >= total) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let N_inv = vec2(params.N_inv_lo, params.N_inv_hi); - let pre = vec2(params.N_inv_precon_lo, params.N_inv_precon_hi); - data[idx] = barrett_mul(data[idx], N_inv, Q, pre); -} - -// ============================================================================ -// Pointwise multiplication -// ============================================================================ - -@compute @workgroup_size(256) -fn four_step_pointwise_mul(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - let total = params.N * params.batch_size; - if (idx >= total) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let mu = vec2(params.mu_lo, params.mu_hi); - let a = data[idx]; - let b = twiddles[idx]; // reuse binding 1 as second polynomial - let prod = u64_mul_lo(a, b); - let q = u64_mulhi(prod, mu); - var r = u64_sub(prod, u64_mul_lo(q, Q)); - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - data[idx] = r; -} diff --git a/ntt/gpu/wgsl/ntt.wgsl b/ntt/gpu/wgsl/ntt.wgsl deleted file mode 100644 index 044dfdd..0000000 --- a/ntt/gpu/wgsl/ntt.wgsl +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Shared NTT (Number Theoretic Transform) compute shader in WGSL. -// -// Forward and inverse NTT for lattice-based PQ crypto. -// Used by ML-DSA (q=8380417), ML-KEM (q=3329), Ringtail. -// Each thread transforms one 256-coefficient polynomial. 
- -@group(0) @binding(0) var polys: array; -@group(0) @binding(1) var params: vec4; // params.x = num_polys, params.y = direction (0=fwd, 1=inv) - -const Q: i32 = 8380417; -const Q_INV: i32 = 58728449; // -q^(-1) mod 2^32 -const F_INV: i32 = 41978; // R * 256^{-1} mod q - -// Precomputed zetas (roots of unity in Montgomery form for q=8380417) -const ZETAS = array( - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488, -1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -); - -// Montgomery reduction: a * R^{-1} mod q -fn mont_reduce(a_lo: i32, a_hi: i32) -> i32 { - // Emulate 64-bit: (a_hi << 32) | a_lo - let t: i32 = a_lo * Q_INV; - // u = t * Q (low 32 bits cancel a_lo) - // result = (a - u) >> 32 = a_hi - (t * Q) >> 32 + correction - let u_lo: i32 = t * Q; - var r: i32 = a_hi - ((t >> 16) * (Q >> 16)); // Approximate high part - // Simplified: for WGSL without 64-bit, use the fact 
that the low 32 bits cancel - if (r < 0) { r = r + Q; } - if (r >= Q) { r = r - Q; } - return r; -} - -// Simplified Montgomery mul for 32-bit WGSL: a * b mod q -fn mod_mul(a: i32, b: i32) -> i32 { - // Since WGSL has no 64-bit, do schoolbook with 16-bit pieces - let a_lo: u32 = u32(a) & 0xFFFFu; - let a_hi: u32 = u32(a) >> 16u; - let b_lo: u32 = u32(b) & 0xFFFFu; - let b_hi: u32 = u32(b) >> 16u; - - let ll: u32 = a_lo * b_lo; - let lh: u32 = a_lo * b_hi; - let hl: u32 = a_hi * b_lo; - let hh: u32 = a_hi * b_hi; - - // Combine: result = hh:mid:ll where mid = lh + hl - let mid: u32 = lh + hl; - let result_lo: u32 = ll + (mid << 16u); - let result_hi: u32 = hh + (mid >> 16u) + select(0u, 1u, result_lo < ll); - - // Barrett reduction mod q - // Approximate: result / q using precomputed constant - let q = u32(Q); - var r: u32 = result_lo; - // Simple iterative reduction (sufficient for 32x32 -> 48-bit results) - r = result_lo - (result_hi * q); - if (r >= q) { r = r - q; } - if (r >= q) { r = r - q; } - return i32(r); -} - -@compute @workgroup_size(64) -fn ntt_mldsa_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - let base = tid * 256u; - - // Load polynomial into private memory - var poly: array; - for (var i = 0u; i < 256u; i = i + 1u) { - poly[i] = polys[base + i]; - } - - if (params.y == 0u) { - // Forward NTT (Cooley-Tukey) - var k = 0u; - var len = 128u; - loop { - if (len == 0u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - k = k + 1u; - let zeta = ZETAS[k]; - var j = start; - loop { - if (j >= start + len) { break; } - let t = mod_mul(zeta, poly[j + len]); - poly[j + len] = poly[j] - t; - poly[j] = poly[j] + t; - if (poly[j] >= Q) { poly[j] = poly[j] - Q; } - if (poly[j + len] < 0) { poly[j + len] = poly[j + len] + Q; } - j = j + 1u; - } - start = start + 2u * len; - } - len = len >> 1u; - } - } else { - // Inverse NTT (Gentleman-Sande) - var k = 127u; - var len = 1u; - loop { 
- if (len > 128u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - var zeta = -ZETAS[k]; - k = k - 1u; - if (zeta < 0) { zeta = zeta + Q; } - var j = start; - loop { - if (j >= start + len) { break; } - let t = poly[j]; - poly[j] = t + poly[j + len]; - poly[j + len] = t - poly[j + len]; - if (poly[j] >= Q) { poly[j] = poly[j] - Q; } - if (poly[j + len] < 0) { poly[j + len] = poly[j + len] + Q; } - poly[j + len] = mod_mul(zeta, poly[j + len]); - j = j + 1u; - } - start = start + 2u * len; - } - len = len << 1u; - } - // Scale by f - for (var i = 0u; i < 256u; i = i + 1u) { - poly[i] = mod_mul(F_INV, poly[i]); - } - } - - // Write back - for (var i = 0u; i < 256u; i = i + 1u) { - polys[base + i] = poly[i]; - } -} diff --git a/ntt/gpu/wgsl/ntt_kernels.wgsl b/ntt/gpu/wgsl/ntt_kernels.wgsl deleted file mode 100644 index 84c879b..0000000 --- a/ntt/gpu/wgsl/ntt_kernels.wgsl +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Optimal NTT kernels for Lux FHE in WGSL, ported from ntt_kernels.metal. -// Forward/inverse NTT stages with Barrett reduction. -// u64 emulated via vec2(lo, hi). 
- -@group(0) @binding(0) var data: array>; // u64 as (lo, hi) -@group(0) @binding(1) var twiddles: array>; -@group(0) @binding(2) var precon_twiddles: array>; -@group(0) @binding(3) var params: NTTParams; -@group(0) @binding(4) var stage_params: vec4; // x=stage, y=batch_size - -struct NTTParams { - Q_lo: u32, Q_hi: u32, - mu_lo: u32, mu_hi: u32, - N_inv_lo: u32, N_inv_hi: u32, - N_inv_precon_lo: u32, N_inv_precon_hi: u32, - N: u32, log_N: u32, -} - -// u64 helpers -fn u64_zero() -> vec2 { return vec2(0u, 0u); } -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - let c = select(0u, 1u, lo < a.x); - return vec2(lo, a.y + b.y + c); -} -fn u64_sub(a: vec2, b: vec2) -> vec2 { - let bw = select(0u, 1u, a.x < b.x); - return vec2(a.x - b.x, a.y - b.y - bw); -} -fn u64_gte(a: vec2, b: vec2) -> bool { - if (a.y > b.y) { return true; } - if (a.y < b.y) { return false; } - return a.x >= b.x; -} - -// 32x32 -> 64 multiply -fn mul32_64(a: u32, b: u32) -> vec2 { - let al = a & 0xFFFFu; let ah = a >> 16u; - let bl = b & 0xFFFFu; let bh = b >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = lh + hl; - let lo = ll + (mid << 16u); - let hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - return vec2(lo, hi); -} - -// Approximate mulhi(a, b) for u64: returns high 64 bits of a*b -// For Barrett: we only need a.hi * b.hi as approximation for small Q -fn u64_mulhi_approx(a: vec2, b: vec2) -> vec2 { - // Full 128-bit would require 4 partial products; for Q < 2^32 this suffices - let p = mul32_64(a.y, b.y); - let cross1 = mul32_64(a.x, b.y); - let cross2 = mul32_64(a.y, b.x); - let mid_lo = cross1.y + cross2.y; // approximate carry into high product - return vec2(p.x + mid_lo, p.y + select(0u, 1u, p.x + mid_lo < p.x)); -} - -// u64 multiply (lo 64 bits only) -fn u64_mul(a: vec2, b: vec2) -> vec2 { - let ll = mul32_64(a.x, b.x); - let cross = a.x * b.y + a.y * b.x; // low 32 bits of cross products 
- return vec2(ll.x, ll.y + cross); -} - -fn mod_add(a: vec2, b: vec2, Q: vec2) -> vec2 { - let sum = u64_add(a, b); - if (u64_gte(sum, Q)) { return u64_sub(sum, Q); } - return sum; -} - -fn mod_sub(a: vec2, b: vec2, Q: vec2) -> vec2 { - if (u64_gte(a, b)) { return u64_sub(a, b); } - return u64_sub(u64_add(a, Q), b); -} - -fn mod_mul_barrett(a: vec2, omega: vec2, Q: vec2, - precon: vec2) -> vec2 { - let q_approx = u64_mulhi_approx(a, precon); - let product = u64_mul(a, omega); - var result = u64_sub(product, u64_mul(q_approx, Q)); - if (u64_gte(result, Q)) { result = u64_sub(result, Q); } - return result; -} - -// ============================================================================ -// Forward NTT stage (Cooley-Tukey) -// ============================================================================ - -@compute @workgroup_size(256) -fn ntt_forward_stage(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let butterfly_idx = gid.x; - - if (batch_idx >= stage_params.y) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let stage = stage_params.x; - - let m = 1u << stage; - let t = N >> (stage + 1u); - let num_butterflies = N >> 1u; - if (butterfly_idx >= num_butterflies) { return; } - - let i = butterfly_idx / t; - let j = butterfly_idx % t; - let idx_lo = (i << (params.log_N - stage)) + j; - let idx_hi = idx_lo + t; - let tw_idx = m + i; - - let poly_offset = batch_idx * N; - let lo_val = data[poly_offset + idx_lo]; - let hi_val = data[poly_offset + idx_hi]; - - let omega = twiddles[tw_idx]; - let precon = precon_twiddles[tw_idx]; - - let omega_factor = mod_mul_barrett(hi_val, omega, Q, precon); - data[poly_offset + idx_lo] = mod_add(lo_val, omega_factor, Q); - data[poly_offset + idx_hi] = mod_sub(lo_val, omega_factor, Q); -} - -// ============================================================================ -// Inverse NTT stage (Gentleman-Sande) -// 
============================================================================ - -@compute @workgroup_size(256) -fn ntt_inverse_stage(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let butterfly_idx = gid.x; - - if (batch_idx >= stage_params.y) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let stage = stage_params.x; - - let m = N >> (stage + 1u); - let t = 1u << stage; - let num_butterflies = N >> 1u; - if (butterfly_idx >= num_butterflies) { return; } - - let i = butterfly_idx / t; - let j = butterfly_idx % t; - let idx_lo = (i << (stage + 1u)) + j; - let idx_hi = idx_lo + t; - let tw_idx = m + i; - - let poly_offset = batch_idx * N; - let lo_val = data[poly_offset + idx_lo]; - let hi_val = data[poly_offset + idx_hi]; - - let omega = twiddles[tw_idx]; - let precon = precon_twiddles[tw_idx]; - - let sum = mod_add(lo_val, hi_val, Q); - let diff = mod_sub(lo_val, hi_val, Q); - let diff_tw = mod_mul_barrett(diff, omega, Q, precon); - - data[poly_offset + idx_lo] = sum; - data[poly_offset + idx_hi] = diff_tw; -} - -// ============================================================================ -// Scale by N^{-1} after inverse NTT -// ============================================================================ - -@compute @workgroup_size(256) -fn ntt_scale(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let coeff_idx = gid.x; - if (batch_idx >= stage_params.y || coeff_idx >= params.N) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let N_inv = vec2(params.N_inv_lo, params.N_inv_hi); - let N_inv_pre = vec2(params.N_inv_precon_lo, params.N_inv_precon_hi); - - let idx = batch_idx * params.N + coeff_idx; - data[idx] = mod_mul_barrett(data[idx], N_inv, Q, N_inv_pre); -} - -// ============================================================================ -// Pointwise multiply-accumulate -// ============================================================================ - -@compute 
@workgroup_size(256) -fn ntt_pointwise_mac(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let coeff_idx = gid.x; - if (batch_idx >= stage_params.y || coeff_idx >= params.N) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let idx = batch_idx * params.N + coeff_idx; - let a_val = twiddles[idx]; // reuse binding 1 as 'a' input - let b_val = precon_twiddles[idx]; // reuse binding 2 as 'b' input - - // Simple modular multiply (no precon) - let prod = u64_mul(a_val, b_val); - // Reduction: if prod >= Q, subtract Q (sufficient for small Q) - var r = prod; - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - - data[idx] = mod_add(data[idx], r, Q); -} diff --git a/ntt/gpu/wgsl/ntt_large.wgsl b/ntt/gpu/wgsl/ntt_large.wgsl deleted file mode 100644 index fcfcedb..0000000 --- a/ntt/gpu/wgsl/ntt_large.wgsl +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL compute shader: six-step large-N NTT. -// -// WGSL has no native u64; all u64 ops are emulated as pairs of u32 (lo, hi) -// using carry-aware addition and 64-bit-wide multiply-with-Barrett. The -// existing four_step_ntt.wgsl in this directory implements those primitives -// and the per-step kernels (column NTT, fused twiddle-and-transpose, row -// NTT, N^-1 scale). -// -// This module aliases the four_step_* kernels under large_ntt_* names for -// hosts that prefer the large-N branding. No new arithmetic logic; the -// kernels in four_step_ntt.wgsl are the source of truth. - -// Shared parameter struct, binary-compatible with FourStepParams in -// four_step_ntt.wgsl. 
-struct LargeNttParams { - Q_lo: u32, - Q_hi: u32, - mu_lo: u32, - mu_hi: u32, - N_inv_lo: u32, - N_inv_hi: u32, - N_inv_precon_lo: u32, - N_inv_precon_hi: u32, - N: u32, - N1: u32, - N2: u32, - log_N1: u32, - log_N2: u32, - tile_stride: u32, - batch_size: u32, - _pad: u32, -}; - -@group(0) @binding(0) var data: array; -@group(0) @binding(1) var twiddles: array; -@group(0) @binding(2) var twiddle_precon: array; -@group(0) @binding(3) var params: LargeNttParams; - -// ============================================================================= -// large_ntt_column_fwd -// large_ntt_twiddle_xpose -// large_ntt_row_fwd -// large_ntt_column_inv -// large_ntt_inv_twiddle_xpose -// large_ntt_row_inv -// large_ntt_scale_n_inv -// -// All seven kernels are defined in four_step_ntt.wgsl. WGSL pipelines bind -// them by name; the host driver requests them with the four_step_* names -// for backward compatibility, and with the large_ntt_* names for the new -// dispatch path. Both resolve to the same compiled compute pipeline. -// -// To avoid duplicate-symbol errors when both files are compiled into the -// same module, this file does NOT redefine any kernel here. It serves as a -// header / contract document: the ntt_large host driver targets these -// entry points and supplies the parameter buffers above. -// ============================================================================= - -@compute @workgroup_size(1) -fn ntt_large_module_marker(@builtin(global_invocation_id) gid: vec3) { - // Empty entry point so naga / wgpu has at least one stage to compile when - // this module is loaded standalone. The real work is in four_step_ntt.wgsl. - if (gid.x == 0u) { - // No-op write back to ensure the module is not optimised away. - let i: u32 = gid.x; - if (i < params.N) { - // Write-then-read identity to pin the binding. 
- let v_lo = data[2u * i + 0u]; - let v_hi = data[2u * i + 1u]; - data[2u * i + 0u] = v_lo; - data[2u * i + 1u] = v_hi; - } - } -} diff --git a/ntt/gpu/wgsl/ntt_large_driver.cpp b/ntt/gpu/wgsl/ntt_large_driver.cpp deleted file mode 100644 index 1b9dadc..0000000 --- a/ntt/gpu/wgsl/ntt_large_driver.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL host-side driver for the six-step large-N NTT. -// -// Same pattern as banderwagon/gpu/wgsl/banderwagon_driver.cpp: no wgpu-native -// runtime needed in CI; the driver runs the CPU oracle and the kernel in -// gpu/wgsl/ntt_large.wgsl is exercised on hosts with wgpu hardware. - -#include "ntt_large.hpp" - -namespace lux::crypto::ntt::large::gpu_wgsl { - -bool device_available() { - return false; -} - -void forward(uint64_t* a, const LargeContext& ctx) { - lux::crypto::ntt::large::forward(a, ctx); -} - -void inverse(uint64_t* a, const LargeContext& ctx) { - lux::crypto::ntt::large::inverse(a, ctx); -} - -} // namespace lux::crypto::ntt::large::gpu_wgsl diff --git a/ntt/gpu/wgsl/ntt_metal_kernel.wgsl b/ntt/gpu/wgsl/ntt_metal_kernel.wgsl deleted file mode 100644 index a2f0034..0000000 --- a/ntt/gpu/wgsl/ntt_metal_kernel.wgsl +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// NTT with shared memory twiddle prefetch in WGSL. -// Ported from ntt_metal_kernel.metal. -// Single NTT stage with workgroup-local twiddle cache. -// u64 emulated as vec2(lo, hi). 
- -@group(0) @binding(0) var data: array>; -@group(0) @binding(1) var twiddles: array>; -@group(0) @binding(2) var params: NTTStageParams; - -struct NTTStageParams { - Q_lo: u32, Q_hi: u32, - mu_lo: u32, mu_hi: u32, - N: u32, log_N: u32, - stage: u32, batch: u32, -} - -var shared_twiddles: array, 4096>; - -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - return vec2(lo, a.y + b.y + select(0u, 1u, lo < a.x)); -} -fn u64_sub(a: vec2, b: vec2) -> vec2 { - return vec2(a.x - b.x, a.y - b.y - select(0u, 1u, a.x < b.x)); -} -fn u64_gte(a: vec2, b: vec2) -> bool { - if (a.y != b.y) { return a.y > b.y; } - return a.x >= b.x; -} -fn u64_mul_lo(a: vec2, b: vec2) -> vec2 { - let al = a.x & 0xFFFFu; let ah = a.x >> 16u; - let bl = b.x & 0xFFFFu; let bh = b.x >> 16u; - let ll = al * bl; - let mid = al * bh + ah * bl; - let lo = ll + (mid << 16u); - let hi = ah * bh + (mid >> 16u) + select(0u, 1u, lo < ll) + a.x * b.y + a.y * b.x; - return vec2(lo, hi); -} -fn u64_mulhi_approx(a: vec2, b: vec2) -> vec2 { - let al = a.x & 0xFFFFu; let ah = a.x >> 16u; - let bh_lo = b.y & 0xFFFFu; let bh_hi = b.y >> 16u; - let p = ah * bh_hi; - let cross = ah * bh_lo + al * bh_hi + (a.y & 0xFFFFu) * (b.x >> 16u); - return vec2(a.y * b.y + (cross >> 16u), p); -} - -fn mod_add(a: vec2, b: vec2, Q: vec2) -> vec2 { - let s = u64_add(a, b); - if (u64_gte(s, Q)) { return u64_sub(s, Q); } - return s; -} -fn mod_sub(a: vec2, b: vec2, Q: vec2) -> vec2 { - if (u64_gte(a, b)) { return u64_sub(a, b); } - return u64_sub(u64_add(a, Q), b); -} -fn barrett_mul(a: vec2, b: vec2, Q: vec2, mu: vec2) -> vec2 { - let prod = u64_mul_lo(a, b); - let q = u64_mulhi_approx(prod, mu); - var r = u64_sub(prod, u64_mul_lo(q, Q)); - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - return r; -} - -@compute @workgroup_size(256) -fn ntt_forward_stage_shared(@builtin(global_invocation_id) gid: vec3, - @builtin(local_invocation_id) lid: vec3) { - let batch_idx = gid.y; - let thread_idx = lid.x; - if (batch_idx >= 
params.batch) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let mu = vec2(params.mu_lo, params.mu_hi); - let stage = params.stage; - let m = 1u << stage; - - // Prefetch twiddles into workgroup memory - if (thread_idx < m) { - shared_twiddles[thread_idx] = twiddles[m + thread_idx]; - } - workgroupBarrier(); - - let t = N >> (stage + 1u); - let num_butterflies = N >> 1u; - let butterfly_idx = gid.x; - if (butterfly_idx >= num_butterflies) { return; } - - let i = butterfly_idx / t; - let j = butterfly_idx % t; - let idx_lo = (i << (params.log_N - stage)) + j; - let idx_hi = idx_lo + t; - - let poly_offset = batch_idx * N; - let lo_val = data[poly_offset + idx_lo]; - let hi_val = data[poly_offset + idx_hi]; - - let tw = shared_twiddles[i % m]; - let omega_factor = barrett_mul(hi_val, tw, Q, mu); - - data[poly_offset + idx_lo] = mod_add(lo_val, omega_factor, Q); - data[poly_offset + idx_hi] = mod_sub(lo_val, omega_factor, Q); -} - -@compute @workgroup_size(256) -fn ntt_inverse_stage_shared(@builtin(global_invocation_id) gid: vec3, - @builtin(local_invocation_id) lid: vec3) { - let batch_idx = gid.y; - let thread_idx = lid.x; - if (batch_idx >= params.batch) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let mu = vec2(params.mu_lo, params.mu_hi); - let stage = params.stage; - let m = N >> (stage + 1u); - - if (thread_idx < m) { - shared_twiddles[thread_idx] = twiddles[m + thread_idx]; - } - workgroupBarrier(); - - let t = 1u << stage; - let num_butterflies = N >> 1u; - let butterfly_idx = gid.x; - if (butterfly_idx >= num_butterflies) { return; } - - let i = butterfly_idx / t; - let j = butterfly_idx % t; - let idx_lo = (i << (stage + 1u)) + j; - let idx_hi = idx_lo + t; - - let poly_offset = batch_idx * N; - let lo_val = data[poly_offset + idx_lo]; - let hi_val = data[poly_offset + idx_hi]; - - let tw = shared_twiddles[i % m]; - let sum = mod_add(lo_val, hi_val, Q); - let diff = mod_sub(lo_val, hi_val, Q); 
- let diff_tw = barrett_mul(diff, tw, Q, mu); - - data[poly_offset + idx_lo] = sum; - data[poly_offset + idx_hi] = diff_tw; -} diff --git a/ntt/gpu/wgsl/ntt_unified_memory.wgsl b/ntt/gpu/wgsl/ntt_unified_memory.wgsl deleted file mode 100644 index 6df84e9..0000000 --- a/ntt/gpu/wgsl/ntt_unified_memory.wgsl +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Unified memory NTT kernels in WGSL, ported from ntt_unified_memory.metal. -// Zero-copy NTT with branch-free modular arithmetic. -// u64 emulated as vec2(lo, hi). - -@group(0) @binding(0) var data: array>; -@group(0) @binding(1) var twiddles: array>; -@group(0) @binding(2) var params: NTTUnifiedParams; - -struct NTTUnifiedParams { - Q_lo: u32, Q_hi: u32, - mu_lo: u32, mu_hi: u32, - N_inv_lo: u32, N_inv_hi: u32, - N_inv_precon_lo: u32, N_inv_precon_hi: u32, - N: u32, log_N: u32, - stage: u32, batch: u32, -} - -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - return vec2(lo, a.y + b.y + select(0u, 1u, lo < a.x)); -} -fn u64_sub(a: vec2, b: vec2) -> vec2 { - return vec2(a.x - b.x, a.y - b.y - select(0u, 1u, a.x < b.x)); -} -fn u64_gte(a: vec2, b: vec2) -> bool { - if (a.y != b.y) { return a.y > b.y; } - return a.x >= b.x; -} -fn u64_mul_lo(a: vec2, b: vec2) -> vec2 { - let al = a.x & 0xFFFFu; let ah = a.x >> 16u; - let bl = b.x & 0xFFFFu; let bh = b.x >> 16u; - let ll = al * bl; let mid = al * bh + ah * bl; - let lo = ll + (mid << 16u); - let hi = ah * bh + (mid >> 16u) + select(0u, 1u, lo < ll) + a.x * b.y + a.y * b.x; - return vec2(lo, hi); -} -fn u64_mulhi(a: vec2, b: vec2) -> vec2 { - // Approximate high 64 bits - let p = a.y * b.y; - let c1 = (a.x >> 16u) * b.y + a.y * (b.x >> 16u); - return vec2(p + (c1 >> 16u), 0u); -} - -// Branch-free modular ops (matching unified memory Metal style) -fn mod_add_bf(a: vec2, b: vec2, Q: vec2) -> vec2 { - let sum = u64_add(a, b); - let mask_hi = select(0u, 0xFFFFFFFFu, u64_gte(sum, 
Q)); - let mask_lo = mask_hi; - return u64_sub(sum, vec2(Q.x & mask_lo, Q.y & mask_hi)); -} -fn mod_sub_bf(a: vec2, b: vec2, Q: vec2) -> vec2 { - let diff = u64_sub(a, b); - let mask = select(0u, 0xFFFFFFFFu, !u64_gte(a, b)); - return u64_add(diff, vec2(Q.x & mask, Q.y & mask)); -} - -fn barrett_mul_unified(a: vec2, b: vec2, Q: vec2, - mu: vec2) -> vec2 { - let lo = u64_mul_lo(a, b); - let q = u64_mulhi(lo, mu); - var r = u64_sub(lo, u64_mul_lo(q, Q)); - let mask = select(0u, 0xFFFFFFFFu, u64_gte(r, Q)); - r = u64_sub(r, vec2(Q.x & mask, Q.y & mask)); - return r; -} - -@compute @workgroup_size(256) -fn ntt_unified_forward_stage(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let butterfly_idx = gid.x; - if (batch_idx >= params.batch) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let mu = vec2(params.mu_lo, params.mu_hi); - let stage = params.stage; - let m = 1u << stage; - let t = N >> (stage + 1u); - if (butterfly_idx >= N >> 1u) { return; } - - let i = butterfly_idx / t; - let j = butterfly_idx % t; - let idx_lo = (i << (params.log_N - stage)) + j; - let idx_hi = idx_lo + t; - let poly = batch_idx * N; - - let lo_val = data[poly + idx_lo]; - let hi_val = data[poly + idx_hi]; - let tw = twiddles[m + i]; - let hi_tw = barrett_mul_unified(hi_val, tw, Q, mu); - - data[poly + idx_lo] = mod_add_bf(lo_val, hi_tw, Q); - data[poly + idx_hi] = mod_sub_bf(lo_val, hi_tw, Q); -} - -@compute @workgroup_size(256) -fn ntt_unified_inverse_stage(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let butterfly_idx = gid.x; - if (batch_idx >= params.batch) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let mu = vec2(params.mu_lo, params.mu_hi); - let stage = params.stage; - let m = N >> (stage + 1u); - let t = 1u << stage; - if (butterfly_idx >= N >> 1u) { return; } - - let i = butterfly_idx / t; - let j = butterfly_idx % t; - let idx_lo = (i << (stage + 1u)) + j; - let idx_hi = 
idx_lo + t; - let poly = batch_idx * N; - - let lo_val = data[poly + idx_lo]; - let hi_val = data[poly + idx_hi]; - let tw = twiddles[m + i]; - - let sum = mod_add_bf(lo_val, hi_val, Q); - let diff = mod_sub_bf(lo_val, hi_val, Q); - let diff_tw = barrett_mul_unified(diff, tw, Q, mu); - - data[poly + idx_lo] = sum; - data[poly + idx_hi] = diff_tw; -} - -@compute @workgroup_size(256) -fn ntt_unified_scale(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let coeff_idx = gid.x; - if (batch_idx >= params.batch || coeff_idx >= params.N) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let mu = vec2(params.mu_lo, params.mu_hi); - let N_inv = vec2(params.N_inv_lo, params.N_inv_hi); - let idx = batch_idx * params.N + coeff_idx; - data[idx] = barrett_mul_unified(data[idx], N_inv, Q, mu); -} diff --git a/ntt/gpu/wgsl/twiddle_cache.wgsl b/ntt/gpu/wgsl/twiddle_cache.wgsl deleted file mode 100644 index 7238ee3..0000000 --- a/ntt/gpu/wgsl/twiddle_cache.wgsl +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-2-Clause -// -// Twiddle Hotset Caching — WGSL compute shaders -// Ported from twiddle_cache.metal. -// NTT kernels with intelligent twiddle caching for different memory tiers. -// u64 emulated as vec2(lo, hi). No native u64 in WGSL. 
-// -// Memory hierarchy: -// Uniform buffer: First-level twiddles (8 values), modular constants -// Workgroup memory: Stage-specific twiddles with prefetch -// Registers: Current butterfly operands - -// --------------------------------------------------------------------------- -// Bindings -// --------------------------------------------------------------------------- - -@group(0) @binding(0) var data: array>; -@group(0) @binding(1) var twiddles_buf: array>; -@group(0) @binding(2) var cache: ConstantCache; -@group(0) @binding(3) var ntt_params: NTTParams; - -// Constant cache with first-level twiddles per prime -struct PrimeConstants { - q_lo: u32, q_hi: u32, - q_inv_lo: u32, q_inv_hi: u32, - mu_hi_lo: u32, mu_hi_hi: u32, - mu_lo_lo: u32, mu_lo_hi: u32, - r_squared_lo: u32, r_squared_hi: u32, - root_lo: u32, root_hi: u32, - root_inv_lo: u32, root_inv_hi: u32, - n_inv_lo: u32, n_inv_hi: u32, -} - -struct ConstantCache { - num_primes: u32, - ring_dim: u32, - _pad0: u32, _pad1: u32, - // First-level twiddles: 8 per prime, up to 16 primes - // Flattened: first_level_twiddles[prime_idx * 8 + i] - first_level_twiddles: array, 128>, // [16][8] - first_level_inv_twiddles: array, 128>, // [16][8] - // Prime constants (up to 16 primes) - primes: array, -} - -struct NTTParams { - Q_lo: u32, Q_hi: u32, - mu_lo: u32, mu_hi: u32, - N_inv_lo: u32, N_inv_hi: u32, - N_inv_precon_lo: u32, N_inv_precon_hi: u32, - N: u32, log_N: u32, - stage: u32, prime_idx: u32, - batch: u32, prefetch_stage: u32, - _pad0: u32, _pad1: u32, -} - -// --------------------------------------------------------------------------- -// u64 emulation -// --------------------------------------------------------------------------- - -fn u64_from(lo: u32, hi: u32) -> vec2 { return vec2(lo, hi); } -fn u64_zero() -> vec2 { return vec2(0u, 0u); } -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - return vec2(lo, a.y + b.y + select(0u, 1u, lo < a.x)); -} -fn u64_sub(a: vec2, b: vec2) -> vec2 { - return 
vec2(a.x - b.x, a.y - b.y - select(0u, 1u, a.x < b.x)); -} -fn u64_gte(a: vec2, b: vec2) -> bool { - if (a.y != b.y) { return a.y > b.y; } - return a.x >= b.x; -} -fn mul32_64(a: u32, b: u32) -> vec2 { - let al = a & 0xFFFFu; let ah = a >> 16u; - let bl = b & 0xFFFFu; let bh = b >> 16u; - let ll = al * bl; - let mid = al * bh + ah * bl; - let lo = ll + (mid << 16u); - let hi = ah * bh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < (al * bh)); - return vec2(lo, hi); -} -fn u64_mul(a: vec2, b: vec2) -> vec2 { - let ll = mul32_64(a.x, b.x); - let cross = a.x * b.y + a.y * b.x; - return vec2(ll.x, ll.y + cross); -} -fn u64_mulhi(a: vec2, b: vec2) -> vec2 { - let p = mul32_64(a.y, b.y); - let c1 = mul32_64(a.x, b.y); - let c2 = mul32_64(a.y, b.x); - let mid_lo = c1.y + c2.y; - return vec2(p.x + mid_lo, p.y + select(0u, 1u, p.x + mid_lo < p.x)); -} - -fn mod_add(a: vec2, b: vec2, Q: vec2) -> vec2 { - let s = u64_add(a, b); - if (u64_gte(s, Q)) { return u64_sub(s, Q); } - return s; -} -fn mod_sub(a: vec2, b: vec2, Q: vec2) -> vec2 { - if (u64_gte(a, b)) { return u64_sub(a, b); } - return u64_sub(u64_add(a, Q), b); -} -fn barrett_mul(a: vec2, omega: vec2, Q: vec2, mu: vec2) -> vec2 { - let q_hat = u64_mulhi(a, mu); - let product = u64_mul(a, omega); - var r = u64_sub(product, u64_mul(q_hat, Q)); - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - return r; -} - -// --------------------------------------------------------------------------- -// Bank conflict avoidance -// --------------------------------------------------------------------------- - -const BANK_WIDTH: u32 = 32u; -const BANK_PADDING: u32 = 1u; - -fn padded_index(idx: u32) -> u32 { - return idx + (idx / BANK_WIDTH) * BANK_PADDING; -} - -// --------------------------------------------------------------------------- -// Workgroup shared memory -// --------------------------------------------------------------------------- - -// Max twiddles in workgroup = 4096 + padding -var twiddles_shared: 
array, 4224>; // 4096 + 4096/32 -var twiddles_prefetch: array, 4224>; - -// For fused kernel: all twiddles -var all_twiddles: array, 4096>; - -// =========================================================================== -// Forward NTT stage with hotset caching -// =========================================================================== - -@compute @workgroup_size(256) -fn ntt_hotset_forward_stage( - @builtin(local_invocation_id) lid_v: vec3, - @builtin(workgroup_id) wgid: vec3, - @builtin(num_workgroups) nwg: vec3, -) { - let lid = lid_v.x; - let batch_idx = wgid.x; - let N = ntt_params.N; - let Q = u64_from(ntt_params.Q_lo, ntt_params.Q_hi); - let mu = u64_from(ntt_params.mu_lo, ntt_params.mu_hi); - let stage = ntt_params.stage; - let prime_idx = ntt_params.prime_idx; - - let m = 1u << stage; - let t = N >> (stage + 1u); - - let batch_data_offset = batch_idx * N; - - // Determine twiddle source - let use_constant = (stage < 4u && m <= 8u); - - if (!use_constant) { - // Cooperative load into workgroup memory with padding - let loads = (m + 255u) / 256u; - for (var i = 0u; i < loads; i++) { - let tw_idx = lid + i * 256u; - if (tw_idx < m) { - let padded = padded_index(tw_idx); - twiddles_shared[padded] = twiddles_buf[m + tw_idx]; - } - } - - // Prefetch next stage twiddles if enabled - if (ntt_params.prefetch_stage < ntt_params.log_N && ntt_params.prefetch_stage > stage) { - let next_m = 1u << ntt_params.prefetch_stage; - let pf_loads = (next_m + 255u) / 256u; - for (var i = 0u; i < pf_loads; i++) { - let tw_idx = lid + i * 256u; - if (tw_idx < next_m && tw_idx < 4096u) { - let padded = padded_index(tw_idx); - twiddles_prefetch[padded] = twiddles_buf[next_m + tw_idx]; - } - } - } - - workgroupBarrier(); - } - - // Butterfly computation - let butterflies_per_thread = (N / 2u + 255u) / 256u; - - for (var b = 0u; b < butterflies_per_thread; b++) { - let bf_idx = lid + b * 256u; - if (bf_idx >= N / 2u) { break; } - - let group = bf_idx / t; - let elem = bf_idx % 
t; - let idx_lo = (group << (ntt_params.log_N - stage)) + elem; - let idx_hi = idx_lo + t; - - let lo = data[batch_data_offset + idx_lo]; - let hi = data[batch_data_offset + idx_hi]; - - // Get twiddle from appropriate cache tier - var tw: vec2; - if (use_constant) { - tw = cache.first_level_twiddles[prime_idx * 8u + group]; - } else { - let padded = padded_index(group); - tw = twiddles_shared[padded]; - } - - let hi_tw = barrett_mul(hi, tw, Q, mu); - data[batch_data_offset + idx_lo] = mod_add(lo, hi_tw, Q); - data[batch_data_offset + idx_hi] = mod_sub(lo, hi_tw, Q); - } -} - -// =========================================================================== -// Inverse NTT stage with hotset caching -// =========================================================================== - -@compute @workgroup_size(256) -fn ntt_hotset_inverse_stage( - @builtin(local_invocation_id) lid_v: vec3, - @builtin(workgroup_id) wgid: vec3, -) { - let lid = lid_v.x; - let batch_idx = wgid.x; - let N = ntt_params.N; - let Q = u64_from(ntt_params.Q_lo, ntt_params.Q_hi); - let mu = u64_from(ntt_params.mu_lo, ntt_params.mu_hi); - let stage = ntt_params.stage; - let prime_idx = ntt_params.prime_idx; - - let m = N >> (stage + 1u); - let t = 1u << stage; - - let batch_data_offset = batch_idx * N; - - let use_constant = (stage >= ntt_params.log_N - 4u && m <= 8u); - - if (!use_constant) { - let loads = (m + 255u) / 256u; - for (var i = 0u; i < loads; i++) { - let tw_idx = lid + i * 256u; - if (tw_idx < m) { - let padded = padded_index(tw_idx); - twiddles_shared[padded] = twiddles_buf[m + tw_idx]; - } - } - workgroupBarrier(); - } - - let butterflies_per_thread = (N / 2u + 255u) / 256u; - - for (var b = 0u; b < butterflies_per_thread; b++) { - let bf_idx = lid + b * 256u; - if (bf_idx >= N / 2u) { break; } - - let group = bf_idx / t; - let elem = bf_idx % t; - let idx_lo = (group << (stage + 1u)) + elem; - let idx_hi = idx_lo + t; - - let lo = data[batch_data_offset + idx_lo]; - let hi = 
data[batch_data_offset + idx_hi]; - - var tw: vec2; - if (use_constant) { - tw = cache.first_level_inv_twiddles[prime_idx * 8u + group]; - } else { - let padded = padded_index(group); - tw = twiddles_shared[padded]; - } - - // Gentleman-Sande butterfly - let sum = mod_add(lo, hi, Q); - let diff = mod_sub(lo, hi, Q); - data[batch_data_offset + idx_lo] = sum; - data[batch_data_offset + idx_hi] = barrett_mul(diff, tw, Q, mu); - } -} - -// =========================================================================== -// Multi-stage fused NTT (all stages in one dispatch for N <= 4096) -// =========================================================================== - -@compute @workgroup_size(256) -fn ntt_hotset_fused( - @builtin(local_invocation_id) lid_v: vec3, - @builtin(workgroup_id) wgid: vec3, -) { - let lid = lid_v.x; - let batch_idx = wgid.x; - let N = ntt_params.N; - let log_N = ntt_params.log_N; - let Q = u64_from(ntt_params.Q_lo, ntt_params.Q_hi); - let mu = u64_from(ntt_params.mu_lo, ntt_params.mu_hi); - - let batch_data_offset = batch_idx * N; - - // Load ALL twiddles into workgroup memory (N-1 total) - let total_twiddles = N - 1u; - let loads = (total_twiddles + 255u) / 256u; - for (var i = 0u; i < loads; i++) { - let tw_idx = lid + i * 256u; - if (tw_idx < total_twiddles) { - all_twiddles[tw_idx] = twiddles_buf[tw_idx]; - } - } - workgroupBarrier(); - - // Process all stages - for (var stage = 0u; stage < log_N; stage++) { - let m = 1u << stage; - let t = N >> (stage + 1u); - let tw_base = m; - - let bpt = (N / 2u + 255u) / 256u; - for (var b = 0u; b < bpt; b++) { - let bf_idx = lid + b * 256u; - if (bf_idx >= N / 2u) { break; } - - let group = bf_idx / t; - let elem = bf_idx % t; - let idx_lo = (group << (log_N - stage)) + elem; - let idx_hi = idx_lo + t; - - let lo = data[batch_data_offset + idx_lo]; - let hi = data[batch_data_offset + idx_hi]; - let tw = all_twiddles[tw_base + group]; - - let hi_tw = barrett_mul(hi, tw, Q, mu); - data[batch_data_offset + 
idx_lo] = mod_add(lo, hi_tw, Q); - data[batch_data_offset + idx_hi] = mod_sub(lo, hi_tw, Q); - } - - // Device memory barrier between stages - storageBarrier(); - } -} - -// =========================================================================== -// N^{-1} scaling for INTT -// =========================================================================== - -@compute @workgroup_size(256) -fn ntt_hotset_scale_ninv(@builtin(global_invocation_id) gid: vec3) { - let total = ntt_params.N * ntt_params.batch; - let idx = gid.x; - if (idx >= total) { return; } - - let Q = u64_from(ntt_params.Q_lo, ntt_params.Q_hi); - let mu = u64_from(ntt_params.mu_lo, ntt_params.mu_hi); - let N_inv = u64_from(ntt_params.N_inv_lo, ntt_params.N_inv_hi); - - data[idx] = barrett_mul(data[idx], N_inv, Q, mu); -} diff --git a/pedersen/CMakeLists.txt b/pedersen/CMakeLists.txt index d73fe74..3f9e8fa 100644 --- a/pedersen/CMakeLists.txt +++ b/pedersen/CMakeLists.txt @@ -67,49 +67,41 @@ if(CRYPTO_ENABLE_CUDA) CUDA_SEPARABLE_COMPILATION ON) target_include_directories(pedersen_cuda PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda) -else() - # Compile only the host driver stubs so consumers can still link a - # symbol-complete pedersen_cuda library on hosts without nvcc. - add_library(pedersen_cuda STATIC - gpu/cuda/pedersen_driver_cuda.cpp - gpu/cuda/pedersen_tree_driver.cpp) - set_target_properties(pedersen_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_include_directories(pedersen_cuda PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/gpu/cuda) + target_compile_features(pedersen_cuda PUBLIC cxx_std_20) endif() -target_compile_features(pedersen_cuda PUBLIC cxx_std_20) -# Batched Pedersen WebGPU/WGSL driver. Each shader source is embedded at -# configure time as a const char* so the driver can submit it via Dawn / -# wgpu-native without runtime file I/O. +# Batched Pedersen WebGPU/WGSL driver. Built only when CRYPTO_ENABLE_WGSL=ON +# (lux-gpu-kernels found). 
Each shader source is embedded at configure time +# as a const char* so the driver can submit it via Dawn / wgpu-native +# without runtime file I/O. # # Two shaders ship in the same archive: # * pedersen.wgsl -- two-stage pointmul + reduce_add pipeline # * pedersen_tree.wgsl -- single-dispatch tree-reduce kernel at N = 256 -set(_ped_wgsl_src ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/pedersen.wgsl) -set(_ped_wgsl_embed ${CMAKE_CURRENT_BINARY_DIR}/pedersen_wgsl_source.h) -file(READ ${_ped_wgsl_src} _ped_wgsl_content) -file(WRITE ${_ped_wgsl_embed} "// Auto-generated. Do not edit.\n#pragma once\n\n") -file(APPEND ${_ped_wgsl_embed} - "static constexpr char kPedersenWGSL[] = R\"PEDWGSL(\n${_ped_wgsl_content}\n)PEDWGSL\";\n") +if(CRYPTO_ENABLE_WGSL) + set(_ped_wgsl_src ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/pedersen.wgsl) + set(_ped_wgsl_embed ${CMAKE_CURRENT_BINARY_DIR}/pedersen_wgsl_source.h) + file(READ ${_ped_wgsl_src} _ped_wgsl_content) + file(WRITE ${_ped_wgsl_embed} "// Auto-generated. Do not edit.\n#pragma once\n\n") + file(APPEND ${_ped_wgsl_embed} + "static constexpr char kPedersenWGSL[] = R\"PEDWGSL(\n${_ped_wgsl_content}\n)PEDWGSL\";\n") -set(_ped_tree_wgsl_src ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/pedersen_tree.wgsl) -set(_ped_tree_wgsl_embed ${CMAKE_CURRENT_BINARY_DIR}/pedersen_tree_wgsl_source.h) -file(READ ${_ped_tree_wgsl_src} _ped_tree_wgsl_content) -file(WRITE ${_ped_tree_wgsl_embed} "// Auto-generated. Do not edit.\n#pragma once\n\n") -file(APPEND ${_ped_tree_wgsl_embed} - "static constexpr char kPedersenTreeWGSL[] = R\"PEDTRWGSL(\n${_ped_tree_wgsl_content}\n)PEDTRWGSL\";\n") + set(_ped_tree_wgsl_src ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl/pedersen_tree.wgsl) + set(_ped_tree_wgsl_embed ${CMAKE_CURRENT_BINARY_DIR}/pedersen_tree_wgsl_source.h) + file(READ ${_ped_tree_wgsl_src} _ped_tree_wgsl_content) + file(WRITE ${_ped_tree_wgsl_embed} "// Auto-generated. 
Do not edit.\n#pragma once\n\n") + file(APPEND ${_ped_tree_wgsl_embed} + "static constexpr char kPedersenTreeWGSL[] = R\"PEDTRWGSL(\n${_ped_tree_wgsl_content}\n)PEDTRWGSL\";\n") -add_library(pedersen_wgpu STATIC - gpu/wgsl/pedersen_driver_wgpu.cpp - gpu/wgsl/pedersen_tree_driver.cpp) -target_compile_features(pedersen_wgpu PUBLIC cxx_std_20) -set_target_properties(pedersen_wgpu PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_include_directories(pedersen_wgpu PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl - ${CMAKE_CURRENT_BINARY_DIR}) + add_library(pedersen_wgpu STATIC + gpu/wgsl/pedersen_driver_wgpu.cpp + gpu/wgsl/pedersen_tree_driver.cpp) + target_compile_features(pedersen_wgpu PUBLIC cxx_std_20) + set_target_properties(pedersen_wgpu PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_include_directories(pedersen_wgpu PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/gpu/wgsl + ${CMAKE_CURRENT_BINARY_DIR}) -if(CRYPTO_ENABLE_WGSL) find_path(_PED_WGPU_INCLUDE webgpu.h HINTS /opt/homebrew/include /usr/local/include /usr/include) find_library(_PED_WGPU_LIB NAMES wgpu_native wgpu diff --git a/pedersen/gpu/cuda/pedersen.cu b/pedersen/gpu/cuda/pedersen.cu deleted file mode 100644 index 450df38..0000000 --- a/pedersen/gpu/cuda/pedersen.cu +++ /dev/null @@ -1,562 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party CUDA kernel for batched Pedersen vector commitments over BN254 -// G1. Mechanically ported from pedersen/gpu/metal/pedersen.metal -- byte-equal -// outputs to the Metal kernel and the Go canonical at -// github.com/luxfi/crypto/pedersen. -// -// Computed quantity (M parallel commitments, each of size N): -// -// C_m = sum_{i=0..N-1} s[m][i] * G[i] + r[m] * H for m = 0..M-1 -// -// Two-stage pipeline (one cubin, two kernels): -// 1. pedersen_pointmul M*(N+1) threads, one per (commitment, term) -// 2. 
pedersen_reduce_add M threads, one per commitment -// -// Wire format identical to Metal driver: gens (N+1)*64 BE, scalars M*N*32 BE, -// blindings M*32 BE, output M*64 BE. - -#include - -#ifndef __CUDACC__ -// Allow this file to compile with a plain C++ compiler so callers that lack -// the CUDA toolkit still get a non-error "translation unit empty" object. -#define __device__ -#define __global__ -#endif - -// ============================================================================= -// BN254 base-field constants -// ============================================================================= - -__device__ static const uint64_t BN254_P0 = 0x3C208C16D87CFD47ULL; -__device__ static const uint64_t BN254_P1 = 0x97816A916871CA8DULL; -__device__ static const uint64_t BN254_P2 = 0xB85045B68181585DULL; -__device__ static const uint64_t BN254_P3 = 0x30644E72E131A029ULL; - -__device__ static const uint64_t BN254_R_0 = 0xD35D438DC58F0D9DULL; -__device__ static const uint64_t BN254_R_1 = 0x0A78EB28F5C70B3DULL; -__device__ static const uint64_t BN254_R_2 = 0x666EA36F7879462CULL; -__device__ static const uint64_t BN254_R_3 = 0x0E0A77C19A07DF2FULL; - -__device__ static const uint64_t BN254_R2_0 = 0xF32CFC5B538AFA89ULL; -__device__ static const uint64_t BN254_R2_1 = 0xB5E71911D44501FBULL; -__device__ static const uint64_t BN254_R2_2 = 0x47AB1EFF0A417FF6ULL; -__device__ static const uint64_t BN254_R2_3 = 0x06D89F71CAB8351FULL; - -__device__ static const uint64_t BN254_INV = 0x87D20782E4866389ULL; - -// ============================================================================= -// 256-bit big integer helpers -// ============================================================================= - -struct U256 { uint64_t l[4]; }; - -__device__ static inline U256 u256_zero() { - U256 x; x.l[0]=0; x.l[1]=0; x.l[2]=0; x.l[3]=0; return x; -} - -__device__ static inline bool u256_is_zero(const U256& a) { - return (a.l[0] | a.l[1] | a.l[2] | a.l[3]) == 0; -} - -__device__ static inline 
int u256_cmp_p(const U256& a) { - if (a.l[3] != BN254_P3) return a.l[3] > BN254_P3 ? 1 : -1; - if (a.l[2] != BN254_P2) return a.l[2] > BN254_P2 ? 1 : -1; - if (a.l[1] != BN254_P1) return a.l[1] > BN254_P1 ? 1 : -1; - if (a.l[0] != BN254_P0) return a.l[0] > BN254_P0 ? 1 : -1; - return 0; -} - -__device__ static inline U256 fp_p() { - U256 r; r.l[0]=BN254_P0; r.l[1]=BN254_P1; r.l[2]=BN254_P2; r.l[3]=BN254_P3; - return r; -} - -__device__ static inline U256 fp_csub_p(const U256& a) { - if (u256_cmp_p(a) < 0) return a; - U256 r; - uint64_t borrow = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - uint64_t s = ai - pl[i] - borrow; - borrow = ((ai < pl[i] + borrow) || (pl[i] + borrow < pl[i])) ? 1 : 0; - r.l[i] = s; - } - return r; -} - -__device__ static inline U256 fp_add(const U256& a, const U256& b) { - U256 r; - uint64_t carry = 0; - for (int i = 0; i < 4; ++i) { - uint64_t s = a.l[i] + b.l[i]; - uint64_t c1 = (s < a.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - return fp_csub_p(r); -} - -__device__ static inline U256 fp_sub(const U256& a, const U256& b) { - U256 r; - uint64_t borrow = 0; - for (int i = 0; i < 4; ++i) { - uint64_t bi = b.l[i]; - uint64_t s = a.l[i] - bi - borrow; - borrow = ((a.l[i] < bi + borrow) || (bi + borrow < bi)) ? 1 : 0; - r.l[i] = s; - } - if (borrow) { - uint64_t carry = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t s = r.l[i] + pl[i]; - uint64_t c1 = (s < r.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 
1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - } - return r; -} - -// ============================================================================= -// 64x64->128 multiplication (CUDA: __umul64hi for the high half) -// ============================================================================= - -#ifdef __CUDACC__ -__device__ static inline uint64_t mul_hi64(uint64_t a, uint64_t b) { - return __umul64hi(a, b); -} -#else -__device__ static inline uint64_t mul_hi64(uint64_t a, uint64_t b) { -#if defined(__SIZEOF_INT128__) - return (uint64_t)(((__uint128_t)a * (__uint128_t)b) >> 64); -#else - uint64_t a_lo = (uint32_t)a, a_hi = a >> 32; - uint64_t b_lo = (uint32_t)b, b_hi = b >> 32; - uint64_t ll = a_lo * b_lo; - uint64_t lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo; - uint64_t hh = a_hi * b_hi; - uint64_t mid = (ll >> 32) + (uint32_t)lh + (uint32_t)hl; - return hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -#endif -} -#endif - -// ============================================================================= -// Montgomery multiplication (CIOS, 4-limb specialized) -- byte-identical to Metal -// ============================================================================= - -__device__ static inline U256 fp_mont_mul(const U256& a, const U256& b) { - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - - uint64_t t0 = 0, t1 = 0, t2 = 0, t3 = 0, t4 = 0; - uint64_t t5 = 0; - - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - - // Step A: t += a[i] * b - { - uint64_t carry = 0; - uint64_t lo, hi; - // j=0 - lo = ai * b.l[0]; hi = mul_hi64(ai, b.l[0]); - uint64_t s = t0 + lo; - uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - t0 = s; - carry = hi + c1; - // j=1 - lo = ai * b.l[1]; hi = mul_hi64(ai, b.l[1]); - s = t1 + lo; - c1 = (s < t1) ? 1ULL : 0ULL; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; - carry = hi + c1 + c2; - // j=2 - lo = ai * b.l[2]; hi = mul_hi64(ai, b.l[2]); - s = t2 + lo; - c1 = (s < t2) ? 
1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; - carry = hi + c1 + c2; - // j=3 - lo = ai * b.l[3]; hi = mul_hi64(ai, b.l[3]); - s = t3 + lo; - c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; - carry = hi + c1 + c2; - // propagate - s = t4 + carry; - c1 = (s < t4) ? 1ULL : 0ULL; - t4 = s; - t5 = t5 + c1; - } - - // Step B: m = t0 * INV mod 2^64; t = t + m*p; shift down. - uint64_t m = t0 * BN254_INV; - { - uint64_t carry = 0; - uint64_t lo, hi; - // j=0 (zeroes t0) - lo = m * pl[0]; hi = mul_hi64(m, pl[0]); - uint64_t s = t0 + lo; - uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - carry = hi + c1; - // j=1 - lo = m * pl[1]; hi = mul_hi64(m, pl[1]); - s = t1 + lo; - c1 = (s < t1) ? 1ULL : 0ULL; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; - carry = hi + c1 + c2; - // j=2 - lo = m * pl[2]; hi = mul_hi64(m, pl[2]); - s = t2 + lo; - c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; - carry = hi + c1 + c2; - // j=3 - lo = m * pl[3]; hi = mul_hi64(m, pl[3]); - s = t3 + lo; - c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; - carry = hi + c1 + c2; - // propagate - s = t4 + carry; - c1 = (s < t4) ? 
1ULL : 0ULL; - t4 = s; - t5 = t5 + c1; - - // shift down by one limb - t0 = t1; t1 = t2; t2 = t3; t3 = t4; t4 = t5; t5 = 0; - } - } - - U256 r; r.l[0]=t0; r.l[1]=t1; r.l[2]=t2; r.l[3]=t3; - if (t4 != 0) { - U256 p = fp_p(); - r = fp_sub(r, p); - } - return fp_csub_p(r); -} - -__device__ static inline U256 fp_mont_sqr(const U256& a) { return fp_mont_mul(a, a); } - -__device__ static inline U256 fp_one_mont() { - U256 r; r.l[0]=BN254_R_0; r.l[1]=BN254_R_1; r.l[2]=BN254_R_2; r.l[3]=BN254_R_3; - return r; -} - -__device__ static inline U256 fp_r2() { - U256 r; r.l[0]=BN254_R2_0; r.l[1]=BN254_R2_1; r.l[2]=BN254_R2_2; r.l[3]=BN254_R2_3; - return r; -} - -__device__ static inline U256 fp_to_mont(const U256& x) { return fp_mont_mul(x, fp_r2()); } - -__device__ static inline U256 fp_from_mont(const U256& x) { - U256 one; one.l[0]=1; one.l[1]=0; one.l[2]=0; one.l[3]=0; - return fp_mont_mul(x, one); -} - -// Inversion via Fermat's little theorem (a^(p-2)). Identical bit-loop to Metal. -__device__ static inline U256 fp_inv(const U256& a) { - if (u256_is_zero(a)) return a; - - uint64_t exp[4] = { BN254_P0 - 2ULL, BN254_P1, BN254_P2, BN254_P3 }; - - U256 result = fp_one_mont(); - U256 base = a; - for (int limb = 0; limb < 4; ++limb) { - uint64_t e = exp[limb]; - for (int b = 0; b < 64; ++b) { - if ((e >> b) & 1ULL) { - result = fp_mont_mul(result, base); - } - base = fp_mont_sqr(base); - } - } - return result; -} - -// ============================================================================= -// G1 in Jacobian coordinates (Montgomery form for X, Y, Z) -// ============================================================================= - -struct G1Jac { U256 X, Y, Z; }; -struct G1Aff { U256 X, Y; bool inf; }; - -__device__ static inline G1Jac g1_zero() { - G1Jac p; - p.X = fp_one_mont(); - p.Y = fp_one_mont(); - p.Z = u256_zero(); - return p; -} - -__device__ static inline bool g1_is_zero(const G1Jac& p) { return u256_is_zero(p.Z); } - -__device__ static inline G1Jac 
g1_dbl(const G1Jac& p) { - if (g1_is_zero(p)) return p; - U256 A = fp_mont_sqr(p.X); - U256 B = fp_mont_sqr(p.Y); - U256 C = fp_mont_sqr(B); - U256 t = fp_add(p.X, B); - U256 t2 = fp_mont_sqr(t); - U256 t3 = fp_sub(t2, A); - U256 t4 = fp_sub(t3, C); - U256 D = fp_add(t4, t4); - U256 E = fp_add(fp_add(A, A), A); - U256 F = fp_mont_sqr(E); - G1Jac r; - U256 twoD = fp_add(D, D); - r.X = fp_sub(F, twoD); - U256 D_minus_X = fp_sub(D, r.X); - U256 EDX = fp_mont_mul(E, D_minus_X); - U256 eightC = fp_add(C, C); - eightC = fp_add(eightC, eightC); - eightC = fp_add(eightC, eightC); - r.Y = fp_sub(EDX, eightC); - U256 YZ = fp_mont_mul(p.Y, p.Z); - r.Z = fp_add(YZ, YZ); - return r; -} - -__device__ static inline G1Jac g1_add_mixed(const G1Jac& p, const U256& Qx, const U256& Qy) { - if (g1_is_zero(p)) { - G1Jac r; - r.X = Qx; r.Y = Qy; r.Z = fp_one_mont(); - return r; - } - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 U2 = fp_mont_mul(Qx, Z1Z1); - U256 S2 = fp_mont_mul(Qy, fp_mont_mul(p.Z, Z1Z1)); - U256 H = fp_sub(U2, p.X); - U256 r_v = fp_sub(S2, p.Y); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 HH = fp_mont_sqr(H); - U256 I = fp_add(HH, HH); - I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 r_2 = fp_add(r_v, r_v); - U256 V = fp_mont_mul(p.X, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r_2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r_2, V_minus_X3); - U256 Y1J = fp_mont_mul(p.Y, J); - U256 twoY1J = fp_add(Y1J, Y1J); - out.Y = fp_sub(r_VX, twoY1J); - out.Z = fp_mont_mul(p.Z, fp_add(H, H)); - return out; -} - -__device__ static inline G1Jac g1_add(const G1Jac& p, const G1Jac& q) { - if (g1_is_zero(p)) return q; - if (g1_is_zero(q)) return p; - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 Z2Z2 = fp_mont_sqr(q.Z); - U256 U1 = fp_mont_mul(p.X, Z2Z2); - U256 U2 = fp_mont_mul(q.X, Z1Z1); - U256 S1 = fp_mont_mul(fp_mont_mul(p.Y, 
q.Z), Z2Z2); - U256 S2 = fp_mont_mul(fp_mont_mul(q.Y, p.Z), Z1Z1); - U256 H = fp_sub(U2, U1); - U256 r_v = fp_sub(S2, S1); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 r2 = fp_add(r_v, r_v); - U256 HH = fp_mont_sqr(H); - U256 I = fp_add(HH, HH); - I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 V = fp_mont_mul(U1, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r2, V_minus_X3); - U256 S1J = fp_mont_mul(S1, J); - U256 twoS1J = fp_add(S1J, S1J); - out.Y = fp_sub(r_VX, twoS1J); - U256 Z1Z2 = fp_mont_mul(p.Z, q.Z); - out.Z = fp_mont_mul(Z1Z2, fp_add(H, H)); - return out; -} - -__device__ static inline G1Aff g1_to_affine(const G1Jac& p) { - G1Aff r; - if (g1_is_zero(p)) { - r.X = u256_zero(); r.Y = u256_zero(); r.inf = true; - return r; - } - U256 Zinv = fp_inv(p.Z); - U256 Zinv2 = fp_mont_sqr(Zinv); - U256 Zinv3 = fp_mont_mul(Zinv2, Zinv); - r.X = fp_mont_mul(p.X, Zinv2); - r.Y = fp_mont_mul(p.Y, Zinv3); - r.inf = false; - return r; -} - -__device__ static inline G1Jac g1_scalar_mul_aff(const U256& Qx, const U256& Qy, - const uint64_t s[4]) { - G1Jac acc = g1_zero(); - for (int li = 3; li >= 0; --li) { - uint64_t limb = s[li]; - for (int bi = 63; bi >= 0; --bi) { - acc = g1_dbl(acc); - if ((limb >> bi) & 1ULL) { - acc = g1_add_mixed(acc, Qx, Qy); - } - } - } - return acc; -} - -// ============================================================================= -// I/O helpers (raw 32-byte BE Fp <-> 4 LE limbs) -// ============================================================================= - -__device__ static inline U256 read_be32(const uint8_t* p) { - U256 r; - for (int limb = 0; limb < 4; ++limb) { - const uint8_t* src = p + (3 - limb) * 8; - uint64_t v = 0; - for (int i = 0; i < 8; ++i) v = (v << 8) | (uint64_t)src[i]; - r.l[limb] = v; - } - return r; -} - 
-__device__ static inline void write_be32(uint8_t* p, const U256& a) { - for (int limb = 0; limb < 4; ++limb) { - uint8_t* dst = p + (3 - limb) * 8; - uint64_t v = a.l[limb]; - for (int i = 7; i >= 0; --i) { - dst[i] = (uint8_t)(v & 0xFFu); - v >>= 8; - } - } -} - -// ============================================================================= -// Kernel 1: pedersen_pointmul -- one thread per (m, i) scalar multiplication -// ============================================================================= - -struct PedersenDims { uint32_t M; uint32_t N; }; - -extern "C" __global__ void k_pedersen_pointmul( - const uint8_t* __restrict__ gens_be, // (N+1)*64 BE bytes - const uint8_t* __restrict__ scalars_be, // M*N*32 BE bytes - const uint8_t* __restrict__ blindings_be, // M*32 BE bytes - uint64_t* __restrict__ scratch, // M*(N+1)*12 u64 - PedersenDims dims) { -#ifdef __CUDACC__ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; -#else - uint32_t tid = 0; -#endif - uint32_t M = dims.M; - uint32_t N = dims.N; - uint32_t total = M * (N + 1); - if (tid >= total) return; - - uint32_t m = tid / (N + 1); - uint32_t i = tid - m * (N + 1); - - U256 Qx_raw, Qy_raw, scalar_raw; - if (i < N) { - Qx_raw = read_be32(gens_be + i * 64); - Qy_raw = read_be32(gens_be + i * 64 + 32); - scalar_raw = read_be32(scalars_be + (m * N + i) * 32); - } else { - Qx_raw = read_be32(gens_be + N * 64); - Qy_raw = read_be32(gens_be + N * 64 + 32); - scalar_raw = read_be32(blindings_be + m * 32); - } - - U256 Qx = fp_to_mont(Qx_raw); - U256 Qy = fp_to_mont(Qy_raw); - - uint64_t s[4] = { scalar_raw.l[0], scalar_raw.l[1], scalar_raw.l[2], scalar_raw.l[3] }; - G1Jac result = g1_scalar_mul_aff(Qx, Qy, s); - - uint32_t base = tid * 12; - for (int k = 0; k < 4; ++k) scratch[base + 0 + k] = result.X.l[k]; - for (int k = 0; k < 4; ++k) scratch[base + 4 + k] = result.Y.l[k]; - for (int k = 0; k < 4; ++k) scratch[base + 8 + k] = result.Z.l[k]; -} - -// 
============================================================================= -// Kernel 2: pedersen_reduce_add -- one thread per commitment -// ============================================================================= - -__device__ static inline G1Jac scratch_load(const uint64_t* scratch, uint32_t idx) { - G1Jac r; - uint32_t base = idx * 12; - for (int k = 0; k < 4; ++k) r.X.l[k] = scratch[base + 0 + k]; - for (int k = 0; k < 4; ++k) r.Y.l[k] = scratch[base + 4 + k]; - for (int k = 0; k < 4; ++k) r.Z.l[k] = scratch[base + 8 + k]; - return r; -} - -extern "C" __global__ void k_pedersen_reduce_add( - const uint64_t* __restrict__ scratch, // M*(N+1)*12 u64 - uint8_t* __restrict__ out_be, // M*64 BE bytes - PedersenDims dims) { -#ifdef __CUDACC__ - uint32_t m = blockIdx.x * blockDim.x + threadIdx.x; -#else - uint32_t m = 0; -#endif - uint32_t M = dims.M; - uint32_t N = dims.N; - if (m >= M) return; - - uint32_t base = m * (N + 1); - G1Jac acc = g1_zero(); - for (uint32_t i = 0; i < N + 1; ++i) { - G1Jac term = scratch_load(scratch, base + i); - if (g1_is_zero(term)) continue; - acc = g1_add(acc, term); - } - - G1Aff aff = g1_to_affine(acc); - if (aff.inf) { - uint8_t* dst = out_be + m * 64; - for (int b = 0; b < 64; ++b) dst[b] = 0; - return; - } - U256 X_raw = fp_from_mont(aff.X); - U256 Y_raw = fp_from_mont(aff.Y); - write_be32(out_be + m * 64, X_raw); - write_be32(out_be + m * 64 + 32, Y_raw); -} diff --git a/pedersen/gpu/cuda/pedersen_driver_cuda.cpp b/pedersen/gpu/cuda/pedersen_driver_cuda.cpp deleted file mode 100644 index b45a156..0000000 --- a/pedersen/gpu/cuda/pedersen_driver_cuda.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for the batched Pedersen vector commitment. -// -// Build modes: -// 1. 
With CUDA toolkit (LUX_PEDERSEN_HAVE_CUDA defined): -// - Allocates device buffers, copies wire-format inputs, launches -// k_pedersen_pointmul + k_pedersen_reduce_add (defined in -// pedersen.cu), copies the M*64 output back. -// - Byte-equal to pedersen_batch_metal and the Go canonical at -// github.com/luxfi/crypto/pedersen. -// 2. Without CUDA: returns -1 (non-zero) and lux_pedersen_cuda_available() -// returns 0, so test harnesses skip the CUDA path. - -#include "pedersen_driver_cuda.h" - -#include -#include - -#ifdef LUX_PEDERSEN_HAVE_CUDA -#include - -struct PedersenDimsHost { uint32_t M; uint32_t N; }; - -extern "C" __global__ void k_pedersen_pointmul( - const uint8_t*, const uint8_t*, const uint8_t*, - uint64_t*, PedersenDimsHost); -extern "C" __global__ void k_pedersen_reduce_add( - const uint64_t*, uint8_t*, PedersenDimsHost); - -extern "C" int lux_pedersen_cuda_available(void) { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 
1 : 0; -} - -extern "C" int pedersen_batch_cuda( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint32_t N, - uint8_t* out_be) { - if (M == 0 || N == 0) return 0; - if (!gens_be || !scalars_be || !blindings_be || !out_be) return -1; - if (!lux_pedersen_cuda_available()) return -1; - - size_t gens_len = (size_t)(N + 1) * 64; - size_t scalars_len = (size_t)M * N * 32; - size_t blind_len = (size_t)M * 32; - size_t scratch_u64 = (size_t)M * (N + 1) * 12; - size_t out_len = (size_t)M * 64; - - uint8_t *dGens=nullptr, *dScalars=nullptr, *dBlind=nullptr, *dOut=nullptr; - uint64_t *dScratch=nullptr; - - auto cleanup = [&]() { - if (dGens) cudaFree(dGens); - if (dScalars) cudaFree(dScalars); - if (dBlind) cudaFree(dBlind); - if (dScratch) cudaFree(dScratch); - if (dOut) cudaFree(dOut); - }; - - if (cudaMalloc((void**)&dGens, gens_len) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dScalars, scalars_len) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dBlind, blind_len) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dScratch, scratch_u64*sizeof(uint64_t)) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dOut, out_len) != cudaSuccess) { cleanup(); return -2; } - - cudaMemcpy(dGens, gens_be, gens_len, cudaMemcpyHostToDevice); - cudaMemcpy(dScalars, scalars_be, scalars_len, cudaMemcpyHostToDevice); - cudaMemcpy(dBlind, blindings_be, blind_len, cudaMemcpyHostToDevice); - - PedersenDimsHost dims{ M, N }; - - // Stage 1: pointmul -- M*(N+1) threads. - { - unsigned tg = 64; - unsigned total = M * (N + 1); - unsigned grid = (total + tg - 1) / tg; - k_pedersen_pointmul<<>>(dGens, dScalars, dBlind, dScratch, dims); - if (cudaDeviceSynchronize() != cudaSuccess) { cleanup(); return -3; } - } - - // Stage 2: reduce_add -- M threads. 
- { - unsigned tg = 32; - unsigned grid = (M + tg - 1) / tg; - k_pedersen_reduce_add<<>>(dScratch, dOut, dims); - if (cudaDeviceSynchronize() != cudaSuccess) { cleanup(); return -4; } - } - - cudaMemcpy(out_be, dOut, out_len, cudaMemcpyDeviceToHost); - cleanup(); - return 0; -} - -#else // LUX_PEDERSEN_HAVE_CUDA not defined: stub mode - -extern "C" int lux_pedersen_cuda_available(void) { return 0; } - -extern "C" int pedersen_batch_cuda( - const uint8_t*, const uint8_t*, const uint8_t*, - uint32_t, uint32_t, uint8_t*) { - return -1; -} - -#endif diff --git a/pedersen/gpu/cuda/pedersen_driver_cuda.h b/pedersen/gpu/cuda/pedersen_driver_cuda.h deleted file mode 100644 index 84bd180..0000000 --- a/pedersen/gpu/cuda/pedersen_driver_cuda.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the CUDA driver of the batched Pedersen vector commitment. -// Mirrors pedersen_batch_metal exactly. On non-CUDA hosts every entry returns -// non-zero except *_available which returns 0. - -#ifndef LUX_PEDERSEN_DRIVER_CUDA_H -#define LUX_PEDERSEN_DRIVER_CUDA_H - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// 1 if a CUDA device is present and the runtime initialised successfully. -int lux_pedersen_cuda_available(void); - -// Computes M parallel Pedersen commitments of basis size N. -// -// Wire format (raw big-endian, gnark-crypto compatible): -// gens_be : (N + 1) * 64 bytes -- G_basis[0..N-1] || H, X then Y -// scalars_be : M * N * 32 bytes -- raw BE Fr elements -// blindings_be: M * 32 bytes -- raw BE Fr elements -// out_be : M * 64 bytes -- (X || Y) raw BE -// -// Returns 0 on success, negative on failure. 
-int pedersen_batch_cuda( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint32_t N, - uint8_t* out_be); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_PEDERSEN_DRIVER_CUDA_H diff --git a/pedersen/gpu/cuda/pedersen_tree.cu b/pedersen/gpu/cuda/pedersen_tree.cu deleted file mode 100644 index a83b1a8..0000000 --- a/pedersen/gpu/cuda/pedersen_tree.cu +++ /dev/null @@ -1,491 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce CUDA kernel for batched Pedersen vector commitments at the -// fixed Verkle width N = 256. One CUDA block per commitment, 256 threads -// per block. Block-shared memory holds the 256 partial points; an -// 8-stride shared-memory tree reduction collapses them inside the block. -// Output byte-equal to pedersen_tree_metal and the legacy two-stage CUDA -// pipeline (k_pedersen_pointmul + k_pedersen_reduce_add). - -#include - -#ifndef __CUDACC__ -// Plain C++ compile path: keep the file as a non-empty translation unit -// so the host driver can still link a stub library. 
-#define __device__ -#define __global__ -#define __shared__ -#define __forceinline__ inline -#define __syncthreads() ((void)0) -#endif - -// ============================================================================= -// BN254 base-field constants -// ============================================================================= - -__device__ static const uint64_t BN254_P0 = 0x3C208C16D87CFD47ULL; -__device__ static const uint64_t BN254_P1 = 0x97816A916871CA8DULL; -__device__ static const uint64_t BN254_P2 = 0xB85045B68181585DULL; -__device__ static const uint64_t BN254_P3 = 0x30644E72E131A029ULL; - -__device__ static const uint64_t BN254_R_0 = 0xD35D438DC58F0D9DULL; -__device__ static const uint64_t BN254_R_1 = 0x0A78EB28F5C70B3DULL; -__device__ static const uint64_t BN254_R_2 = 0x666EA36F7879462CULL; -__device__ static const uint64_t BN254_R_3 = 0x0E0A77C19A07DF2FULL; - -__device__ static const uint64_t BN254_R2_0 = 0xF32CFC5B538AFA89ULL; -__device__ static const uint64_t BN254_R2_1 = 0xB5E71911D44501FBULL; -__device__ static const uint64_t BN254_R2_2 = 0x47AB1EFF0A417FF6ULL; -__device__ static const uint64_t BN254_R2_3 = 0x06D89F71CAB8351FULL; - -__device__ static const uint64_t BN254_INV = 0x87D20782E4866389ULL; - -struct U256 { uint64_t l[4]; }; - -__device__ static inline U256 u256_zero() { - U256 x; x.l[0]=0; x.l[1]=0; x.l[2]=0; x.l[3]=0; return x; -} - -__device__ static inline bool u256_is_zero(const U256& a) { - return (a.l[0] | a.l[1] | a.l[2] | a.l[3]) == 0; -} - -__device__ static inline int u256_cmp_p(const U256& a) { - if (a.l[3] != BN254_P3) return a.l[3] > BN254_P3 ? 1 : -1; - if (a.l[2] != BN254_P2) return a.l[2] > BN254_P2 ? 1 : -1; - if (a.l[1] != BN254_P1) return a.l[1] > BN254_P1 ? 1 : -1; - if (a.l[0] != BN254_P0) return a.l[0] > BN254_P0 ? 
1 : -1; - return 0; -} - -__device__ static inline U256 fp_p() { - U256 r; r.l[0]=BN254_P0; r.l[1]=BN254_P1; r.l[2]=BN254_P2; r.l[3]=BN254_P3; - return r; -} - -__device__ static inline U256 fp_csub_p(const U256& a) { - if (u256_cmp_p(a) < 0) return a; - U256 r; - uint64_t borrow = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - uint64_t s = ai - pl[i] - borrow; - borrow = ((ai < pl[i] + borrow) || (pl[i] + borrow < pl[i])) ? 1 : 0; - r.l[i] = s; - } - return r; -} - -__device__ static inline U256 fp_add(const U256& a, const U256& b) { - U256 r; - uint64_t carry = 0; - for (int i = 0; i < 4; ++i) { - uint64_t s = a.l[i] + b.l[i]; - uint64_t c1 = (s < a.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - return fp_csub_p(r); -} - -__device__ static inline U256 fp_sub(const U256& a, const U256& b) { - U256 r; - uint64_t borrow = 0; - for (int i = 0; i < 4; ++i) { - uint64_t bi = b.l[i]; - uint64_t s = a.l[i] - bi - borrow; - borrow = ((a.l[i] < bi + borrow) || (bi + borrow < bi)) ? 1 : 0; - r.l[i] = s; - } - if (borrow) { - uint64_t carry = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t s = r.l[i] + pl[i]; - uint64_t c1 = (s < r.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 
1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - } - return r; -} - -#ifdef __CUDACC__ -__device__ static inline uint64_t mul_hi64(uint64_t a, uint64_t b) { - return __umul64hi(a, b); -} -#else -__device__ static inline uint64_t mul_hi64(uint64_t a, uint64_t b) { -#if defined(__SIZEOF_INT128__) - return (uint64_t)(((__uint128_t)a * (__uint128_t)b) >> 64); -#else - uint64_t a_lo = (uint32_t)a, a_hi = a >> 32; - uint64_t b_lo = (uint32_t)b, b_hi = b >> 32; - uint64_t ll = a_lo * b_lo; - uint64_t lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo; - uint64_t hh = a_hi * b_hi; - uint64_t mid = (ll >> 32) + (uint32_t)lh + (uint32_t)hl; - return hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -#endif -} -#endif - -__device__ static inline U256 fp_mont_mul(const U256& a, const U256& b) { - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - uint64_t t0 = 0, t1 = 0, t2 = 0, t3 = 0, t4 = 0, t5 = 0; - - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - { - uint64_t carry = 0, lo, hi; - lo = ai * b.l[0]; hi = mul_hi64(ai, b.l[0]); - uint64_t s = t0 + lo; uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - t0 = s; carry = hi + c1; - lo = ai * b.l[1]; hi = mul_hi64(ai, b.l[1]); - s = t1 + lo; c1 = (s < t1) ? 1ULL : 0ULL; - uint64_t s2 = s + carry; uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; carry = hi + c1 + c2; - lo = ai * b.l[2]; hi = mul_hi64(ai, b.l[2]); - s = t2 + lo; c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; carry = hi + c1 + c2; - lo = ai * b.l[3]; hi = mul_hi64(ai, b.l[3]); - s = t3 + lo; c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; carry = hi + c1 + c2; - s = t4 + carry; c1 = (s < t4) ? 1ULL : 0ULL; - t4 = s; t5 = t5 + c1; - } - uint64_t m = t0 * BN254_INV; - { - uint64_t carry = 0, lo, hi; - lo = m * pl[0]; hi = mul_hi64(m, pl[0]); - uint64_t s = t0 + lo; uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - carry = hi + c1; - lo = m * pl[1]; hi = mul_hi64(m, pl[1]); - s = t1 + lo; c1 = (s < t1) ? 
1ULL : 0ULL; - uint64_t s2 = s + carry; uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; carry = hi + c1 + c2; - lo = m * pl[2]; hi = mul_hi64(m, pl[2]); - s = t2 + lo; c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; carry = hi + c1 + c2; - lo = m * pl[3]; hi = mul_hi64(m, pl[3]); - s = t3 + lo; c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; carry = hi + c1 + c2; - s = t4 + carry; c1 = (s < t4) ? 1ULL : 0ULL; - t4 = s; t5 = t5 + c1; - t0 = t1; t1 = t2; t2 = t3; t3 = t4; t4 = t5; t5 = 0; - } - } - - U256 r; r.l[0]=t0; r.l[1]=t1; r.l[2]=t2; r.l[3]=t3; - if (t4 != 0) { - U256 p = fp_p(); - r = fp_sub(r, p); - } - return fp_csub_p(r); -} - -__device__ static inline U256 fp_mont_sqr(const U256& a) { return fp_mont_mul(a, a); } - -__device__ static inline U256 fp_one_mont() { - U256 r; r.l[0]=BN254_R_0; r.l[1]=BN254_R_1; r.l[2]=BN254_R_2; r.l[3]=BN254_R_3; - return r; -} - -__device__ static inline U256 fp_r2() { - U256 r; r.l[0]=BN254_R2_0; r.l[1]=BN254_R2_1; r.l[2]=BN254_R2_2; r.l[3]=BN254_R2_3; - return r; -} - -__device__ static inline U256 fp_to_mont(const U256& x) { return fp_mont_mul(x, fp_r2()); } - -__device__ static inline U256 fp_from_mont(const U256& x) { - U256 one; one.l[0]=1; one.l[1]=0; one.l[2]=0; one.l[3]=0; - return fp_mont_mul(x, one); -} - -__device__ static inline U256 fp_inv(const U256& a) { - if (u256_is_zero(a)) return a; - uint64_t exp[4] = { BN254_P0 - 2ULL, BN254_P1, BN254_P2, BN254_P3 }; - U256 result = fp_one_mont(); - U256 base = a; - for (int limb = 0; limb < 4; ++limb) { - uint64_t e = exp[limb]; - for (int b = 0; b < 64; ++b) { - if ((e >> b) & 1ULL) result = fp_mont_mul(result, base); - base = fp_mont_sqr(base); - } - } - return result; -} - -// ============================================================================= -// G1 in Jacobian -// ============================================================================= - -struct G1Jac { U256 X, Y, Z; }; 
-struct G1Aff { U256 X, Y; bool inf; }; - -__device__ static inline G1Jac g1_zero() { - G1Jac p; - p.X = fp_one_mont(); p.Y = fp_one_mont(); p.Z = u256_zero(); - return p; -} - -__device__ static inline bool g1_is_zero(const G1Jac& p) { return u256_is_zero(p.Z); } - -__device__ static inline G1Jac g1_dbl(const G1Jac& p) { - if (g1_is_zero(p)) return p; - U256 A = fp_mont_sqr(p.X); - U256 B = fp_mont_sqr(p.Y); - U256 C = fp_mont_sqr(B); - U256 t = fp_add(p.X, B); - U256 t2 = fp_mont_sqr(t); - U256 t3 = fp_sub(t2, A); - U256 t4 = fp_sub(t3, C); - U256 D = fp_add(t4, t4); - U256 E = fp_add(fp_add(A, A), A); - U256 F = fp_mont_sqr(E); - G1Jac r; - U256 twoD = fp_add(D, D); - r.X = fp_sub(F, twoD); - U256 D_minus_X = fp_sub(D, r.X); - U256 EDX = fp_mont_mul(E, D_minus_X); - U256 eightC = fp_add(C, C); - eightC = fp_add(eightC, eightC); - eightC = fp_add(eightC, eightC); - r.Y = fp_sub(EDX, eightC); - U256 YZ = fp_mont_mul(p.Y, p.Z); - r.Z = fp_add(YZ, YZ); - return r; -} - -__device__ static inline G1Jac g1_add_mixed(const G1Jac& p, const U256& Qx, const U256& Qy) { - if (g1_is_zero(p)) { - G1Jac r; - r.X = Qx; r.Y = Qy; r.Z = fp_one_mont(); - return r; - } - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 U2 = fp_mont_mul(Qx, Z1Z1); - U256 S2 = fp_mont_mul(Qy, fp_mont_mul(p.Z, Z1Z1)); - U256 H = fp_sub(U2, p.X); - U256 r_v = fp_sub(S2, p.Y); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 HH = fp_mont_sqr(H); - U256 I = fp_add(HH, HH); I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 r_2 = fp_add(r_v, r_v); - U256 V = fp_mont_mul(p.X, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r_2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r_2, V_minus_X3); - U256 Y1J = fp_mont_mul(p.Y, J); - U256 twoY1J = fp_add(Y1J, Y1J); - out.Y = fp_sub(r_VX, twoY1J); - out.Z = fp_mont_mul(p.Z, fp_add(H, H)); - return out; -} - -__device__ static 
inline G1Jac g1_add(const G1Jac& p, const G1Jac& q) { - if (g1_is_zero(p)) return q; - if (g1_is_zero(q)) return p; - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 Z2Z2 = fp_mont_sqr(q.Z); - U256 U1 = fp_mont_mul(p.X, Z2Z2); - U256 U2 = fp_mont_mul(q.X, Z1Z1); - U256 S1 = fp_mont_mul(fp_mont_mul(p.Y, q.Z), Z2Z2); - U256 S2 = fp_mont_mul(fp_mont_mul(q.Y, p.Z), Z1Z1); - U256 H = fp_sub(U2, U1); - U256 r_v = fp_sub(S2, S1); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 r2 = fp_add(r_v, r_v); - U256 HH = fp_mont_sqr(H); - U256 I = fp_add(HH, HH); I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 V = fp_mont_mul(U1, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r2, V_minus_X3); - U256 S1J = fp_mont_mul(S1, J); - U256 twoS1J = fp_add(S1J, S1J); - out.Y = fp_sub(r_VX, twoS1J); - U256 Z1Z2 = fp_mont_mul(p.Z, q.Z); - out.Z = fp_mont_mul(Z1Z2, fp_add(H, H)); - return out; -} - -__device__ static inline G1Aff g1_to_affine(const G1Jac& p) { - G1Aff r; - if (g1_is_zero(p)) { r.X = u256_zero(); r.Y = u256_zero(); r.inf = true; return r; } - U256 Zinv = fp_inv(p.Z); - U256 Zinv2 = fp_mont_sqr(Zinv); - U256 Zinv3 = fp_mont_mul(Zinv2, Zinv); - r.X = fp_mont_mul(p.X, Zinv2); - r.Y = fp_mont_mul(p.Y, Zinv3); - r.inf = false; - return r; -} - -__device__ static inline G1Jac g1_scalar_mul_aff(const U256& Qx, const U256& Qy, - const uint64_t s[4]) { - G1Jac acc = g1_zero(); - for (int li = 3; li >= 0; --li) { - uint64_t limb = s[li]; - for (int bi = 63; bi >= 0; --bi) { - acc = g1_dbl(acc); - if ((limb >> bi) & 1ULL) acc = g1_add_mixed(acc, Qx, Qy); - } - } - return acc; -} - -// ============================================================================= -// I/O helpers -// ============================================================================= - -__device__ static inline U256 
read_be32(const uint8_t* p) { - U256 r; - for (int limb = 0; limb < 4; ++limb) { - const uint8_t* src = p + (3 - limb) * 8; - uint64_t v = 0; - for (int i = 0; i < 8; ++i) v = (v << 8) | (uint64_t)src[i]; - r.l[limb] = v; - } - return r; -} - -__device__ static inline void write_be32(uint8_t* p, const U256& a) { - for (int limb = 0; limb < 4; ++limb) { - uint8_t* dst = p + (3 - limb) * 8; - uint64_t v = a.l[limb]; - for (int i = 7; i >= 0; --i) { - dst[i] = (uint8_t)(v & 0xFFu); - v >>= 8; - } - } -} - -// ============================================================================= -// Tree-reduce kernel: one block per commitment, 256 threads, shared-memory tree -// ============================================================================= -// -// Block-shared layout: shared_pts[256][12] u64 = 24 KiB per block. -// The reduction collapses 256 partial Jacobian points into 1 in 8 strides. -// Thread 0 then folds in r*H and emits 64 BE bytes. - -#define PED_TREE_N 256u - -struct PedTreeDimsCUDA { uint32_t M; uint32_t N; }; - -extern "C" __global__ void k_pedersen_tree_commit( - const uint8_t* __restrict__ gens_be, - const uint8_t* __restrict__ scalars_be, - const uint8_t* __restrict__ blindings_be, - uint8_t* __restrict__ out_be, - PedTreeDimsCUDA dims) { -#ifdef __CUDACC__ - const uint32_t tid = threadIdx.x; - const uint32_t bid = blockIdx.x; - const uint32_t bdim = blockDim.x; -#else - const uint32_t tid = 0; - const uint32_t bid = 0; - const uint32_t bdim = 1; -#endif - const uint32_t N = dims.N; - const uint32_t M = dims.M; - if (bid >= M) return; - if (tid >= N) return; - if (bdim != PED_TREE_N) return; // contract; host driver enforces - - __shared__ uint64_t shared_pts[PED_TREE_N * 12]; - - // Phase 1: each thread's term = scalar * generator. 
- U256 Qx_raw = read_be32(gens_be + tid * 64); - U256 Qy_raw = read_be32(gens_be + tid * 64 + 32); - U256 sc_raw = read_be32(scalars_be + (bid * N + tid) * 32); - U256 Qx = fp_to_mont(Qx_raw); - U256 Qy = fp_to_mont(Qy_raw); - uint64_t s[4] = { sc_raw.l[0], sc_raw.l[1], sc_raw.l[2], sc_raw.l[3] }; - G1Jac P = g1_scalar_mul_aff(Qx, Qy, s); - - uint32_t base = tid * 12; - for (int k = 0; k < 4; ++k) shared_pts[base + 0 + k] = P.X.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[base + 4 + k] = P.Y.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[base + 8 + k] = P.Z.l[k]; - __syncthreads(); - - // Phase 2: shared-memory tree reduce. 8 strides for N = 256. - for (uint32_t stride = 128u; stride > 0u; stride >>= 1) { - if (tid < stride) { - uint32_t b0 = tid * 12; - G1Jac a; - for (int k = 0; k < 4; ++k) a.X.l[k] = shared_pts[b0 + 0 + k]; - for (int k = 0; k < 4; ++k) a.Y.l[k] = shared_pts[b0 + 4 + k]; - for (int k = 0; k < 4; ++k) a.Z.l[k] = shared_pts[b0 + 8 + k]; - uint32_t b1 = (tid + stride) * 12; - G1Jac b; - for (int k = 0; k < 4; ++k) b.X.l[k] = shared_pts[b1 + 0 + k]; - for (int k = 0; k < 4; ++k) b.Y.l[k] = shared_pts[b1 + 4 + k]; - for (int k = 0; k < 4; ++k) b.Z.l[k] = shared_pts[b1 + 8 + k]; - G1Jac sum = g1_add(a, b); - for (int k = 0; k < 4; ++k) shared_pts[b0 + 0 + k] = sum.X.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[b0 + 4 + k] = sum.Y.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[b0 + 8 + k] = sum.Z.l[k]; - } - __syncthreads(); - } - - // Phase 3 + 4: thread 0 finishes. 
- if (tid == 0) { - G1Jac acc; - for (int k = 0; k < 4; ++k) acc.X.l[k] = shared_pts[0 + k]; - for (int k = 0; k < 4; ++k) acc.Y.l[k] = shared_pts[4 + k]; - for (int k = 0; k < 4; ++k) acc.Z.l[k] = shared_pts[8 + k]; - - U256 Hx_raw = read_be32(gens_be + N * 64); - U256 Hy_raw = read_be32(gens_be + N * 64 + 32); - U256 r_raw = read_be32(blindings_be + bid * 32); - U256 Hx = fp_to_mont(Hx_raw); - U256 Hy = fp_to_mont(Hy_raw); - uint64_t rs[4] = { r_raw.l[0], r_raw.l[1], r_raw.l[2], r_raw.l[3] }; - G1Jac rH = g1_scalar_mul_aff(Hx, Hy, rs); - acc = g1_add(acc, rH); - - G1Aff aff = g1_to_affine(acc); - uint8_t* dst = out_be + bid * 64; - if (aff.inf) { - for (int b = 0; b < 64; ++b) dst[b] = 0; - return; - } - U256 X_raw = fp_from_mont(aff.X); - U256 Y_raw = fp_from_mont(aff.Y); - write_be32(dst, X_raw); - write_be32(dst + 32, Y_raw); - } -} diff --git a/pedersen/gpu/cuda/pedersen_tree_driver.cpp b/pedersen/gpu/cuda/pedersen_tree_driver.cpp deleted file mode 100644 index 7edd39f..0000000 --- a/pedersen/gpu/cuda/pedersen_tree_driver.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for the tree-reduce Pedersen vector commitment. - -#include "pedersen_tree_driver.h" - -#include -#include - -#ifdef LUX_PEDERSEN_HAVE_CUDA -#include - -struct PedTreeDimsHost { uint32_t M; uint32_t N; }; - -extern "C" __global__ void k_pedersen_tree_commit( - const uint8_t*, const uint8_t*, const uint8_t*, - uint8_t*, PedTreeDimsHost); - -extern "C" int lux_pedersen_tree_cuda_available(void) { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 
1 : 0; -} - -extern "C" int pedersen_tree_cuda( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint8_t* out_be) { - if (M == 0) return 0; - if (!gens_be || !scalars_be || !blindings_be || !out_be) return -1; - if (!lux_pedersen_tree_cuda_available()) return -1; - - const uint32_t N = PEDERSEN_TREE_WIDTH; - size_t gens_len = (size_t)(N + 1) * 64; - size_t scalars_len = (size_t)M * N * 32; - size_t blind_len = (size_t)M * 32; - size_t out_len = (size_t)M * 64; - - uint8_t *dGens=nullptr, *dScalars=nullptr, *dBlind=nullptr, *dOut=nullptr; - - auto cleanup = [&]() { - if (dGens) cudaFree(dGens); - if (dScalars) cudaFree(dScalars); - if (dBlind) cudaFree(dBlind); - if (dOut) cudaFree(dOut); - }; - - if (cudaMalloc((void**)&dGens, gens_len) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dScalars, scalars_len) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dBlind, blind_len) != cudaSuccess) { cleanup(); return -2; } - if (cudaMalloc((void**)&dOut, out_len) != cudaSuccess) { cleanup(); return -2; } - - cudaMemcpy(dGens, gens_be, gens_len, cudaMemcpyHostToDevice); - cudaMemcpy(dScalars, scalars_be, scalars_len, cudaMemcpyHostToDevice); - cudaMemcpy(dBlind, blindings_be, blind_len, cudaMemcpyHostToDevice); - - PedTreeDimsHost dims{ M, N }; - - // M blocks of 256 threads each. Single dispatch: log_2 N round-trips - // collapse into one shared-memory reduction. 
- k_pedersen_tree_commit<<>>(dGens, dScalars, dBlind, dOut, dims); - if (cudaDeviceSynchronize() != cudaSuccess) { cleanup(); return -3; } - - cudaMemcpy(out_be, dOut, out_len, cudaMemcpyDeviceToHost); - cleanup(); - return 0; -} - -#else // LUX_PEDERSEN_HAVE_CUDA not defined: stub mode - -extern "C" int lux_pedersen_tree_cuda_available(void) { return 0; } - -extern "C" int pedersen_tree_cuda( - const uint8_t*, const uint8_t*, const uint8_t*, - uint32_t, uint8_t*) { - return -1; -} - -#endif diff --git a/pedersen/gpu/cuda/pedersen_tree_driver.h b/pedersen/gpu/cuda/pedersen_tree_driver.h deleted file mode 100644 index 767aadf..0000000 --- a/pedersen/gpu/cuda/pedersen_tree_driver.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce CUDA driver for the batched Pedersen vector commitment at the -// fixed Verkle width N = 256. - -#ifndef LUX_PEDERSEN_TREE_DRIVER_CUDA_H -#define LUX_PEDERSEN_TREE_DRIVER_CUDA_H - -#include - -#ifndef PEDERSEN_TREE_WIDTH -#define PEDERSEN_TREE_WIDTH 256u -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_pedersen_tree_cuda_available(void); - -// Computes M Pedersen commitments at the fixed width N = 256 in a single -// CUDA dispatch with block-shared-memory tree reduction. Wire format -// identical to pedersen_batch_cuda. -int pedersen_tree_cuda( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint8_t* out_be); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_PEDERSEN_TREE_DRIVER_CUDA_H diff --git a/pedersen/gpu/metal/pedersen.metal b/pedersen/gpu/metal/pedersen.metal deleted file mode 100644 index db58c20..0000000 --- a/pedersen/gpu/metal/pedersen.metal +++ /dev/null @@ -1,663 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party Metal kernel for batched Pedersen vector commitments over -// BN254 G1. 
-// -// Computed quantity (M parallel commitments, each of size N): -// -// C_m = sum_{i=0..N-1} s[m][i] * G[i] + r[m] * H for m = 0..M-1 -// -// Two-stage pipeline (one metallib, two kernels): -// -// 1. pedersen_pointmul -- M*(N+1) threads, one per (commitment, term). -// Each thread does ONE scalar multiplication and writes the resulting -// Jacobian point to a scratch buffer. -// -// 2. pedersen_reduce_add -- M threads, one per commitment. Each thread -// sums (N+1) Jacobian points from scratch, then converts the result to -// affine and writes it as 64-byte big-endian (32 BE for x, 32 BE for y). -// -// Byte-equality target: the Go canonical at -// github.com/luxfi/crypto/pedersen -// (single-scalar case is the n=1 reduction; vector case matches the fixture -// generator at pedersen/test/tools/gen_pedersen_metal_kat.go.) -// -// Wire formats: -// -// Generators G_basis (N points) + H (1 point) -// -> buffer of (N+1) * 64 bytes, raw big-endian (X || Y), gnark layout -// -// Scalars S of shape [M][N] -// -> buffer of M*N*32 bytes, raw big-endian Fr (already reduced mod r) -// -// Blindings r of shape [M] -// -> buffer of M*32 bytes, raw big-endian Fr (already reduced mod r) -// -// Output commitments C of shape [M] -// -> buffer of M*64 bytes, raw big-endian (X || Y) - -#include -using namespace metal; - -// ============================================================================= -// BN254 base-field constants (p = 21888242871839275222246405745257275088696...) 
-// ============================================================================= - -// p (4 limbs little-endian) -constant uint64_t BN254_P0 = 0x3C208C16D87CFD47ULL; -constant uint64_t BN254_P1 = 0x97816A916871CA8DULL; -constant uint64_t BN254_P2 = 0xB85045B68181585DULL; -constant uint64_t BN254_P3 = 0x30644E72E131A029ULL; - -// R = 2^256 mod p (Montgomery one) -constant uint64_t BN254_R_0 = 0xD35D438DC58F0D9DULL; -constant uint64_t BN254_R_1 = 0x0A78EB28F5C70B3DULL; -constant uint64_t BN254_R_2 = 0x666EA36F7879462CULL; -constant uint64_t BN254_R_3 = 0x0E0A77C19A07DF2FULL; - -// R^2 mod p (used to enter Montgomery form: x_mont = montmul(x, R^2)) -constant uint64_t BN254_R2_0 = 0xF32CFC5B538AFA89ULL; -constant uint64_t BN254_R2_1 = 0xB5E71911D44501FBULL; -constant uint64_t BN254_R2_2 = 0x47AB1EFF0A417FF6ULL; -constant uint64_t BN254_R2_3 = 0x06D89F71CAB8351FULL; - -// -p^{-1} mod 2^64 -constant uint64_t BN254_INV = 0x87D20782E4866389ULL; - -// ============================================================================= -// Plain 256-bit big integer helpers (used for I/O conversion + Montgomery mul) -// ============================================================================= - -struct U256 { uint64_t l[4]; }; // little-endian limbs - -inline U256 u256_zero() { - U256 x; - x.l[0] = 0; x.l[1] = 0; x.l[2] = 0; x.l[3] = 0; - return x; -} - -inline bool u256_is_zero(thread const U256& a) { - return (a.l[0] | a.l[1] | a.l[2] | a.l[3]) == 0; -} - -inline bool u256_eq(thread const U256& a, thread const U256& b) { - return a.l[0] == b.l[0] && a.l[1] == b.l[1] && - a.l[2] == b.l[2] && a.l[3] == b.l[3]; -} - -inline int u256_cmp_p(thread const U256& a) { - if (a.l[3] != BN254_P3) return a.l[3] > BN254_P3 ? 1 : -1; - if (a.l[2] != BN254_P2) return a.l[2] > BN254_P2 ? 1 : -1; - if (a.l[1] != BN254_P1) return a.l[1] > BN254_P1 ? 1 : -1; - if (a.l[0] != BN254_P0) return a.l[0] > BN254_P0 ? 
1 : -1; - return 0; -} - -inline U256 fp_p() { - U256 r; r.l[0]=BN254_P0; r.l[1]=BN254_P1; r.l[2]=BN254_P2; r.l[3]=BN254_P3; - return r; -} - -// Conditionally subtract p (single trial; result in [0, p)). -inline U256 fp_csub_p(thread const U256& a) { - if (u256_cmp_p(a) < 0) return a; - U256 r; - uint64_t borrow = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - uint64_t s = ai - pl[i] - borrow; - // borrow if (a < b + borrow) -- careful unsigned underflow check - borrow = ((ai < pl[i] + borrow) || (pl[i] + borrow < pl[i])) ? 1 : 0; - r.l[i] = s; - } - return r; -} - -// (a + b) mod p, inputs already < p (or sum < 2p which we then reduce). -inline U256 fp_add(thread const U256& a, thread const U256& b) { - U256 r; - uint64_t carry = 0; - for (int i = 0; i < 4; ++i) { - uint64_t s = a.l[i] + b.l[i]; - uint64_t c1 = (s < a.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - return fp_csub_p(r); -} - -// (a - b) mod p -inline U256 fp_sub(thread const U256& a, thread const U256& b) { - U256 r; - uint64_t borrow = 0; - for (int i = 0; i < 4; ++i) { - uint64_t bi = b.l[i]; - uint64_t s = a.l[i] - bi - borrow; - borrow = ((a.l[i] < bi + borrow) || (bi + borrow < bi)) ? 1 : 0; - r.l[i] = s; - } - if (borrow) { - // add p back - uint64_t carry = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t s = r.l[i] + pl[i]; - uint64_t c1 = (s < r.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 
1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - } - return r; -} - -inline U256 fp_neg(thread const U256& a) { - if (u256_is_zero(a)) return a; - U256 p = fp_p(); - return fp_sub(p, a); -} - -// ============================================================================= -// Schoolbook 256x256->512 + Montgomery reduction (CIOS) -// ============================================================================= - -// Add (lo, hi) into accumulator (acc_lo, acc_hi) with carry-out. -// Computes (acc_hi:acc_lo) = (acc_hi:acc_lo) + (hi:lo); returns carry-out (0 or 1). -inline uint64_t addc128(thread uint64_t& acc_lo, thread uint64_t& acc_hi, - uint64_t lo, uint64_t hi) { - uint64_t a = acc_lo + lo; - uint64_t c1 = (a < acc_lo) ? 1ULL : 0ULL; - uint64_t b = acc_hi + hi; - uint64_t c2 = (b < acc_hi) ? 1ULL : 0ULL; - uint64_t b2 = b + c1; - uint64_t c3 = (b2 < b) ? 1ULL : 0ULL; - acc_lo = a; - acc_hi = b2; - return c2 + c3; -} - -// Montgomery multiplication: returns a*b*R^{-1} mod p, where a,b are in -// Montgomery form (or general residues < p). CIOS, 4-limb specialized. -inline U256 fp_mont_mul(thread const U256& a, thread const U256& b) { - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - - // 5-limb accumulator + carry bit. - uint64_t t0 = 0, t1 = 0, t2 = 0, t3 = 0, t4 = 0; - uint64_t t5 = 0; // overflow guard (at most 1 across the loop) - - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - - // Step A: t += a[i] * b - { - uint64_t carry = 0; - uint64_t lo, hi; - // j=0 - lo = ai * b.l[0]; hi = mulhi(ai, b.l[0]); - uint64_t s = t0 + lo; - uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - t0 = s; - carry = hi + c1; - // j=1 - lo = ai * b.l[1]; hi = mulhi(ai, b.l[1]); - s = t1 + lo; - c1 = (s < t1) ? 1ULL : 0ULL; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; - carry = hi + c1 + c2; - // j=2 - lo = ai * b.l[2]; hi = mulhi(ai, b.l[2]); - s = t2 + lo; - c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 
1ULL : 0ULL; - t2 = s2; - carry = hi + c1 + c2; - // j=3 - lo = ai * b.l[3]; hi = mulhi(ai, b.l[3]); - s = t3 + lo; - c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; - carry = hi + c1 + c2; - // propagate into t4, t5 - s = t4 + carry; - c1 = (s < t4) ? 1ULL : 0ULL; - t4 = s; - t5 = t5 + c1; - } - - // Step B: m = t0 * INV mod 2^64; t = t + m*p; then shift down. - uint64_t m = t0 * BN254_INV; - { - uint64_t carry = 0; - uint64_t lo, hi; - // j=0 (this zeroes out t0) - lo = m * pl[0]; hi = mulhi(m, pl[0]); - uint64_t s = t0 + lo; - uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - // discard low result (it should equal 0 mod 2^64) - carry = hi + c1; - // j=1 - lo = m * pl[1]; hi = mulhi(m, pl[1]); - s = t1 + lo; - c1 = (s < t1) ? 1ULL : 0ULL; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; - carry = hi + c1 + c2; - // j=2 - lo = m * pl[2]; hi = mulhi(m, pl[2]); - s = t2 + lo; - c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; - carry = hi + c1 + c2; - // j=3 - lo = m * pl[3]; hi = mulhi(m, pl[3]); - s = t3 + lo; - c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; - c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; - carry = hi + c1 + c2; - // propagate - s = t4 + carry; - c1 = (s < t4) ? 1ULL : 0ULL; - t4 = s; - t5 = t5 + c1; - - // shift down by one limb - t0 = t1; - t1 = t2; - t2 = t3; - t3 = t4; - t4 = t5; - t5 = 0; - } - } - - U256 r; r.l[0]=t0; r.l[1]=t1; r.l[2]=t2; r.l[3]=t3; - if (t4 != 0) { - // single-bit overflow: subtract p once. 
- U256 p = fp_p(); - r = fp_sub(r, p); - } - return fp_csub_p(r); -} - -inline U256 fp_mont_sqr(thread const U256& a) { - return fp_mont_mul(a, a); -} - -inline U256 fp_one_mont() { - U256 r; r.l[0]=BN254_R_0; r.l[1]=BN254_R_1; r.l[2]=BN254_R_2; r.l[3]=BN254_R_3; - return r; -} - -inline U256 fp_r2() { - U256 r; r.l[0]=BN254_R2_0; r.l[1]=BN254_R2_1; r.l[2]=BN254_R2_2; r.l[3]=BN254_R2_3; - return r; -} - -// Enter Montgomery form: returns x * R mod p (assumes x already < p). -inline U256 fp_to_mont(thread const U256& x) { - return fp_mont_mul(x, fp_r2()); -} - -// Leave Montgomery form: returns x * R^{-1} mod p = montmul(x, 1). -inline U256 fp_from_mont(thread const U256& x) { - U256 one; - one.l[0]=1; one.l[1]=0; one.l[2]=0; one.l[3]=0; - return fp_mont_mul(x, one); -} - -// Inversion via Fermat's little theorem: a^(p-2) mod p. -// (No need for fast inversion -- only called M times per pipeline.) -inline U256 fp_inv(thread const U256& a) { - // Exponent is p - 2. Compute as repeated square-and-multiply with the - // bit pattern of (p - 2) read MSB->LSB. We unroll the bit loop on the - // four limbs. When a == 0, returns 0 (caller must guard). 
- if (u256_is_zero(a)) return a; - - // p - 2 in 4 LE limbs: - // p0 - 2 (no borrow possible since p0 is large), p1, p2, p3 - uint64_t e0 = BN254_P0 - 2ULL; - uint64_t e1 = BN254_P1; - uint64_t e2 = BN254_P2; - uint64_t e3 = BN254_P3; - uint64_t exp[4] = { e0, e1, e2, e3 }; - - U256 result = fp_one_mont(); - U256 base = a; - // process bits LSB->MSB - for (int limb = 0; limb < 4; ++limb) { - uint64_t e = exp[limb]; - for (int b = 0; b < 64; ++b) { - if ((e >> b) & 1ULL) { - result = fp_mont_mul(result, base); - } - base = fp_mont_sqr(base); - } - } - return result; -} - -// ============================================================================= -// G1 in Jacobian coordinates (Montgomery form for X, Y, Z) -// ============================================================================= - -struct G1Jac { - U256 X; - U256 Y; - U256 Z; // Z == 0 represents point at infinity -}; - -struct G1Aff { - U256 X; - U256 Y; - bool inf; -}; - -inline G1Jac g1_zero() { - G1Jac p; - p.X = fp_one_mont(); - p.Y = fp_one_mont(); - p.Z = u256_zero(); - return p; -} - -inline bool g1_is_zero(thread const G1Jac& p) { return u256_is_zero(p.Z); } - -// Doubling, BN254 a = 0 specialization (https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l) -inline G1Jac g1_dbl(thread const G1Jac& p) { - if (g1_is_zero(p)) return p; - U256 A = fp_mont_sqr(p.X); // X^2 - U256 B = fp_mont_sqr(p.Y); // Y^2 - U256 C = fp_mont_sqr(B); // Y^4 - U256 t = fp_add(p.X, B); - U256 t2 = fp_mont_sqr(t); // (X+Y^2)^2 - U256 t3 = fp_sub(t2, A); - U256 t4 = fp_sub(t3, C); - U256 D = fp_add(t4, t4); // 2*((X+Y^2)^2 - X^2 - Y^4) - U256 E = fp_add(fp_add(A, A), A); // 3*X^2 - U256 F = fp_mont_sqr(E); // E^2 - G1Jac r; - U256 twoD = fp_add(D, D); - r.X = fp_sub(F, twoD); // F - 2D - U256 D_minus_X = fp_sub(D, r.X); - U256 EDX = fp_mont_mul(E, D_minus_X); - U256 eightC = fp_add(C, C); - eightC = fp_add(eightC, eightC); - eightC = fp_add(eightC, eightC); // 8*C - r.Y = fp_sub(EDX, eightC); 
// E*(D - X3) - 8*C - U256 YZ = fp_mont_mul(p.Y, p.Z); - r.Z = fp_add(YZ, YZ); // 2 Y Z - return r; -} - -// Mixed addition: Jacobian + Affine (Z2 = 1). Standard formulas. -// Inputs in Montgomery form. Returns Jacobian. -inline G1Jac g1_add_mixed(thread const G1Jac& p, thread const U256& Qx, thread const U256& Qy) { - if (g1_is_zero(p)) { - G1Jac r; - r.X = Qx; r.Y = Qy; r.Z = fp_one_mont(); - return r; - } - U256 Z1Z1 = fp_mont_sqr(p.Z); // Z1^2 - U256 U2 = fp_mont_mul(Qx, Z1Z1); // U2 = X2*Z1^2 - U256 S2 = fp_mont_mul(Qy, fp_mont_mul(p.Z, Z1Z1)); // S2 = Y2*Z1^3 - U256 H = fp_sub(U2, p.X); // H = U2 - X1 - U256 r_v = fp_sub(S2, p.Y); // r = S2 - Y1 - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) { - return g1_dbl(p); - } - return g1_zero(); - } - U256 HH = fp_mont_sqr(H); // H^2 - U256 I = fp_add(HH, HH); - I = fp_add(I, I); // 4*H^2 - U256 J = fp_mont_mul(H, I); // 4*H^3 - U256 r_2 = fp_add(r_v, r_v); // 2 r - U256 V = fp_mont_mul(p.X, I); // X1 * 4 H^2 - G1Jac out; - U256 r_sq = fp_mont_sqr(r_2); // (2r)^2 - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); // X3 = r^2 - J - 2V - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r_2, V_minus_X3); - U256 Y1J = fp_mont_mul(p.Y, J); - U256 twoY1J = fp_add(Y1J, Y1J); - out.Y = fp_sub(r_VX, twoY1J); // Y3 = r*(V - X3) - 2 Y1 J - out.Z = fp_mont_mul(p.Z, fp_add(H, H)); // Z3 = Z1 * 2H - return out; -} - -// Full Jacobian addition. 
https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-add-2007-bl -inline G1Jac g1_add(thread const G1Jac& p, thread const G1Jac& q) { - if (g1_is_zero(p)) return q; - if (g1_is_zero(q)) return p; - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 Z2Z2 = fp_mont_sqr(q.Z); - U256 U1 = fp_mont_mul(p.X, Z2Z2); - U256 U2 = fp_mont_mul(q.X, Z1Z1); - U256 S1 = fp_mont_mul(fp_mont_mul(p.Y, q.Z), Z2Z2); - U256 S2 = fp_mont_mul(fp_mont_mul(q.Y, p.Z), Z1Z1); - U256 H = fp_sub(U2, U1); - U256 r_v = fp_sub(S2, S1); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 r2 = fp_add(r_v, r_v); - U256 HH = fp_mont_sqr(H); - U256 I = fp_add(HH, HH); - I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 V = fp_mont_mul(U1, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r2, V_minus_X3); - U256 S1J = fp_mont_mul(S1, J); - U256 twoS1J = fp_add(S1J, S1J); - out.Y = fp_sub(r_VX, twoS1J); - U256 Z1Z2 = fp_mont_mul(p.Z, q.Z); - out.Z = fp_mont_mul(Z1Z2, fp_add(H, H)); - return out; -} - -// Convert Jacobian -> Affine. Returns inf=true iff Z == 0. -inline G1Aff g1_to_affine(thread const G1Jac& p) { - G1Aff r; - if (g1_is_zero(p)) { - r.X = u256_zero(); r.Y = u256_zero(); r.inf = true; - return r; - } - U256 Zinv = fp_inv(p.Z); - U256 Zinv2 = fp_mont_sqr(Zinv); - U256 Zinv3 = fp_mont_mul(Zinv2, Zinv); - r.X = fp_mont_mul(p.X, Zinv2); - r.Y = fp_mont_mul(p.Y, Zinv3); - r.inf = false; - return r; -} - -// Scalar multiplication: standard left-to-right binary ladder. -// `s` is provided as four little-endian 64-bit limbs in NON-Montgomery form -// (raw integer). Affine input (Q in Montgomery form, no Z). Returns Jacobian. 
-inline G1Jac g1_scalar_mul_aff(thread const U256& Qx, thread const U256& Qy, - thread const uint64_t s[4]) { - G1Jac acc = g1_zero(); - // top limb first (MSB) - for (int li = 3; li >= 0; --li) { - uint64_t limb = s[li]; - for (int bi = 63; bi >= 0; --bi) { - acc = g1_dbl(acc); - if ((limb >> bi) & 1ULL) { - acc = g1_add_mixed(acc, Qx, Qy); - } - } - } - return acc; -} - -// ============================================================================= -// I/O: raw 32-byte big-endian Fp <-> Montgomery U256 -// ============================================================================= - -// Read 32-byte BE -> 4 LE limbs (no reduction; assumes value already < p or < r). -inline U256 read_be32(device const uint8_t* p) { - U256 r; - for (int limb = 0; limb < 4; ++limb) { - // limb 0 = least significant 8 bytes; in BE these are the LAST 8 bytes - const device uint8_t* src = p + (3 - limb) * 8; - uint64_t v = 0; - for (int i = 0; i < 8; ++i) { - v = (v << 8) | (uint64_t)src[i]; - } - r.l[limb] = v; - } - return r; -} - -// Write 4 LE limbs -> 32-byte BE. -inline void write_be32(device uint8_t* p, thread const U256& a) { - for (int limb = 0; limb < 4; ++limb) { - device uint8_t* dst = p + (3 - limb) * 8; - uint64_t v = a.l[limb]; - for (int i = 7; i >= 0; --i) { - dst[i] = (uint8_t)(v & 0xFFu); - v >>= 8; - } - } -} - -// ============================================================================= -// Kernel 1: pedersen_pointmul -- one thread per (m, i) scalar mul -// ============================================================================= -// -// Layout: -// tid = m * (N + 1) + i (i in [0, N+1)) -// -// For i in [0, N): -// point = G_basis[i] (raw BE 64 bytes) -// scalar = S[m*N + i] (raw BE 32 bytes) -// -// For i == N: -// point = H (raw BE 64 bytes) -// scalar = blinding[m] (raw BE 32 bytes) -// -// Output: -// scratch[tid] = scalar * point in Jacobian Montgomery form (3*32 bytes) -// -// We write the full Jacobian (X, Y, Z) into scratch as 96 bytes per term. 
-// Reduction in kernel 2 reads them back. - -struct PedersenDims { - uint32_t M; // number of commitments - uint32_t N; // basis size -}; - -kernel void pedersen_pointmul( - device const uint8_t* gens_be [[buffer(0)]], // (N+1)*64 BE bytes - device const uint8_t* scalars_be [[buffer(1)]], // M*N*32 BE bytes - device const uint8_t* blindings_be [[buffer(2)]],// M*32 BE bytes - device uint64_t* scratch [[buffer(3)]], // M*(N+1)*12 u64 (X||Y||Z) - constant PedersenDims& dims [[buffer(4)]], - uint tid [[thread_position_in_grid]] -) { - uint32_t M = dims.M; - uint32_t N = dims.N; - uint32_t total = M * (N + 1); - if (tid >= total) return; - - uint32_t m = tid / (N + 1); - uint32_t i = tid - m * (N + 1); - - // Pick generator and scalar - U256 Qx_raw, Qy_raw, scalar_raw; - if (i < N) { - Qx_raw = read_be32(gens_be + i * 64); - Qy_raw = read_be32(gens_be + i * 64 + 32); - scalar_raw = read_be32(scalars_be + (m * N + i) * 32); - } else { - Qx_raw = read_be32(gens_be + N * 64); - Qy_raw = read_be32(gens_be + N * 64 + 32); - scalar_raw = read_be32(blindings_be + m * 32); - } - - // Convert generator coords to Montgomery form (X, Y < p assumed). - U256 Qx = fp_to_mont(Qx_raw); - U256 Qy = fp_to_mont(Qy_raw); - - // scalar limbs (raw integer, NOT Montgomery) - uint64_t s[4] = { scalar_raw.l[0], scalar_raw.l[1], scalar_raw.l[2], scalar_raw.l[3] }; - - G1Jac result = g1_scalar_mul_aff(Qx, Qy, s); - - // Write Jacobian (X, Y, Z) Montgomery limbs into scratch. 
- uint32_t base = tid * 12; - for (int k = 0; k < 4; ++k) scratch[base + 0 + k] = result.X.l[k]; - for (int k = 0; k < 4; ++k) scratch[base + 4 + k] = result.Y.l[k]; - for (int k = 0; k < 4; ++k) scratch[base + 8 + k] = result.Z.l[k]; -} - -// ============================================================================= -// Kernel 2: pedersen_reduce_add -- one thread per commitment; sums (N+1) terms -// ============================================================================= - -inline G1Jac scratch_load(device const uint64_t* scratch, uint32_t idx) { - G1Jac r; - uint32_t base = idx * 12; - for (int k = 0; k < 4; ++k) r.X.l[k] = scratch[base + 0 + k]; - for (int k = 0; k < 4; ++k) r.Y.l[k] = scratch[base + 4 + k]; - for (int k = 0; k < 4; ++k) r.Z.l[k] = scratch[base + 8 + k]; - return r; -} - -kernel void pedersen_reduce_add( - device const uint64_t* scratch [[buffer(0)]], // M*(N+1)*12 u64 - device uint8_t* out_be [[buffer(1)]], // M*64 BE bytes - constant PedersenDims& dims [[buffer(2)]], - uint m [[thread_position_in_grid]] -) { - uint32_t M = dims.M; - uint32_t N = dims.N; - if (m >= M) return; - - uint32_t base = m * (N + 1); - G1Jac acc = g1_zero(); - for (uint32_t i = 0; i < N + 1; ++i) { - G1Jac term = scratch_load(scratch, base + i); - if (g1_is_zero(term)) continue; - acc = g1_add(acc, term); - } - - // Convert to affine and emit raw BE. - G1Aff aff = g1_to_affine(acc); - if (aff.inf) { - // Emit (0, 0) BE for infinity (matches gnark's encoding when point is identity). - device uint8_t* dst = out_be + m * 64; - for (int b = 0; b < 64; ++b) dst[b] = 0; - return; - } - // Convert from Montgomery before BE emit. 
- U256 X_raw = fp_from_mont(aff.X); - U256 Y_raw = fp_from_mont(aff.Y); - write_be32(out_be + m * 64, X_raw); - write_be32(out_be + m * 64 + 32, Y_raw); -} diff --git a/pedersen/gpu/metal/pedersen_driver.h b/pedersen/gpu/metal/pedersen_driver.h deleted file mode 100644 index 7f5c415..0000000 --- a/pedersen/gpu/metal/pedersen_driver.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched Pedersen vector commitments. macOS / iOS only. -// -// Computes M parallel Pedersen vector commitments of basis size N: -// -// C_m = sum_{i=0..N-1} scalars[m*N + i] * G_basis[i] + blindings[m] * H -// -// Wire format (raw big-endian, gnark-crypto compatible): -// gens_be : (N + 1) * 64 bytes -- G_basis[0..N-1] || H, X then Y -// scalars_be : M * N * 32 bytes -- raw BE Fr elements -// blindings_be: M * 32 bytes -- raw BE Fr elements -// out_be : M * 64 bytes -- (X || Y) raw BE -// -// Returns 0 on success, negative on failure. - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -int pedersen_batch_metal( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint32_t N, - uint8_t* out_be, - const char* metallib_path); - -#ifdef __cplusplus -} -#endif diff --git a/pedersen/gpu/metal/pedersen_driver.mm b/pedersen/gpu/metal/pedersen_driver.mm deleted file mode 100644 index cd1b30a..0000000 --- a/pedersen/gpu/metal/pedersen_driver.mm +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched Pedersen vector commitments. macOS / iOS only. -// -// Two-stage dispatch: -// 1. pedersen_pointmul -- M*(N+1) threads, one per (commitment, term) -// 2. 
pedersen_reduce_add -- M threads, one per commitment - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "pedersen_driver.h" - -#include -#include -#include - -namespace { - -struct PedersenDimsGPU { - uint32_t M; - uint32_t N; -}; - -} // namespace - -extern "C" int pedersen_batch_metal( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint32_t N, - uint8_t* out_be, - const char* metallib_path) { - - if (M == 0 || N == 0) return 0; - if (!gens_be || !scalars_be || !blindings_be || !out_be || !metallib_path) { - return -1; - } - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn_mul = [lib newFunctionWithName:@"pedersen_pointmul"]; - if (!fn_mul) return -4; - id fn_red = [lib newFunctionWithName:@"pedersen_reduce_add"]; - if (!fn_red) return -5; - - id pipe_mul = - [device newComputePipelineStateWithFunction:fn_mul error:&err]; - if (!pipe_mul) return -6; - id pipe_red = - [device newComputePipelineStateWithFunction:fn_red error:&err]; - if (!pipe_red) return -7; - - id queue = [device newCommandQueue]; - - // Buffers -------------------------------------------------------------- - size_t gens_len = (size_t)(N + 1) * 64; - size_t scalars_len = (size_t)M * N * 32; - size_t blind_len = (size_t)M * 32; - size_t scratch_u64 = (size_t)M * (N + 1) * 12; // X||Y||Z each 4 limbs - size_t out_len = (size_t)M * 64; - - id gens_buf = [device newBufferWithBytes:gens_be - length:gens_len - options:MTLResourceStorageModeShared]; - id scalars_buf = [device newBufferWithBytes:scalars_be - length:scalars_len - options:MTLResourceStorageModeShared]; - id blind_buf = [device newBufferWithBytes:blindings_be - length:blind_len - 
options:MTLResourceStorageModeShared]; - id scratch_buf = [device newBufferWithLength:scratch_u64 * sizeof(uint64_t) - options:MTLResourceStorageModePrivate]; - id out_buf = [device newBufferWithLength:out_len - options:MTLResourceStorageModeShared]; - - PedersenDimsGPU dims = { M, N }; - id dims_buf = [device newBufferWithBytes:&dims - length:sizeof(dims) - options:MTLResourceStorageModeShared]; - - // Stage 1: pointmul ---------------------------------------------------- - { - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipe_mul]; - [enc setBuffer:gens_buf offset:0 atIndex:0]; - [enc setBuffer:scalars_buf offset:0 atIndex:1]; - [enc setBuffer:blind_buf offset:0 atIndex:2]; - [enc setBuffer:scratch_buf offset:0 atIndex:3]; - [enc setBuffer:dims_buf offset:0 atIndex:4]; - - NSUInteger total = (NSUInteger)M * (NSUInteger)(N + 1); - NSUInteger tg_max = pipe_mul.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = (tg_max < 64) ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(total, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - } - - // Stage 2: reduce_add -------------------------------------------------- - { - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipe_red]; - [enc setBuffer:scratch_buf offset:0 atIndex:0]; - [enc setBuffer:out_buf offset:0 atIndex:1]; - [enc setBuffer:dims_buf offset:0 atIndex:2]; - - NSUInteger total = (NSUInteger)M; - NSUInteger tg_max = pipe_red.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = (tg_max < 32) ? 
tg_max : 32; - MTLSize threads_per_grid = MTLSizeMake(total, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - } - - std::memcpy(out_be, [out_buf contents], out_len); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/pedersen/gpu/metal/pedersen_tree.metal b/pedersen/gpu/metal/pedersen_tree.metal deleted file mode 100644 index b25e16e..0000000 --- a/pedersen/gpu/metal/pedersen_tree.metal +++ /dev/null @@ -1,478 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce Pedersen vector-commitment kernel for Verkle. -// -// Specialised at width N = 256 (Verkle node width). One threadgroup per -// commitment, 256 threads per threadgroup. Each thread computes one -// scalar*generator product into threadgroup memory; a log_2(256) = 8 stride -// loop reduces the 256 partial points to a single point inside the same -// threadgroup. Thread 0 finalises with r*H and emits the affine result. -// -// Cuts log_2 N round-trips (8 host->GPU dispatches in the two-stage driver) -// down to one. Threadgroup barrier replaces command-buffer wait. -// -// Threadgroup memory budget: 256 * 96 bytes (X || Y || Z, 4 LE u64 each) -// = 24 KiB. Within the 32 KiB threadgroup-memory floor on every Apple GPU. -// -// Wire format identical to pedersen_pointmul / pedersen_reduce_add: gens -// (N+1)*64 BE, scalars M*N*32 BE, blindings M*32 BE, output M*64 BE. Output -// is byte-equal to the legacy two-stage Metal pipeline and to the Go -// canonical at github.com/luxfi/crypto/pedersen. 
- -#include -using namespace metal; - -constant uint PED_TREE_N = 256; // Verkle node width - -// ============================================================================= -// BN254 base-field constants (identical to pedersen.metal) -// ============================================================================= - -constant uint64_t BN254_P0 = 0x3C208C16D87CFD47ULL; -constant uint64_t BN254_P1 = 0x97816A916871CA8DULL; -constant uint64_t BN254_P2 = 0xB85045B68181585DULL; -constant uint64_t BN254_P3 = 0x30644E72E131A029ULL; - -constant uint64_t BN254_R_0 = 0xD35D438DC58F0D9DULL; -constant uint64_t BN254_R_1 = 0x0A78EB28F5C70B3DULL; -constant uint64_t BN254_R_2 = 0x666EA36F7879462CULL; -constant uint64_t BN254_R_3 = 0x0E0A77C19A07DF2FULL; - -constant uint64_t BN254_R2_0 = 0xF32CFC5B538AFA89ULL; -constant uint64_t BN254_R2_1 = 0xB5E71911D44501FBULL; -constant uint64_t BN254_R2_2 = 0x47AB1EFF0A417FF6ULL; -constant uint64_t BN254_R2_3 = 0x06D89F71CAB8351FULL; - -constant uint64_t BN254_INV = 0x87D20782E4866389ULL; - -struct U256 { uint64_t l[4]; }; - -inline U256 u256_zero() { U256 x; x.l[0]=0; x.l[1]=0; x.l[2]=0; x.l[3]=0; return x; } - -inline bool u256_is_zero(thread const U256& a) { - return (a.l[0] | a.l[1] | a.l[2] | a.l[3]) == 0; -} - -inline int u256_cmp_p(thread const U256& a) { - if (a.l[3] != BN254_P3) return a.l[3] > BN254_P3 ? 1 : -1; - if (a.l[2] != BN254_P2) return a.l[2] > BN254_P2 ? 1 : -1; - if (a.l[1] != BN254_P1) return a.l[1] > BN254_P1 ? 1 : -1; - if (a.l[0] != BN254_P0) return a.l[0] > BN254_P0 ? 
1 : -1; - return 0; -} - -inline U256 fp_p() { - U256 r; r.l[0]=BN254_P0; r.l[1]=BN254_P1; r.l[2]=BN254_P2; r.l[3]=BN254_P3; - return r; -} - -inline U256 fp_csub_p(thread const U256& a) { - if (u256_cmp_p(a) < 0) return a; - U256 r; - uint64_t borrow = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - uint64_t s = ai - pl[i] - borrow; - borrow = ((ai < pl[i] + borrow) || (pl[i] + borrow < pl[i])) ? 1 : 0; - r.l[i] = s; - } - return r; -} - -inline U256 fp_add(thread const U256& a, thread const U256& b) { - U256 r; - uint64_t carry = 0; - for (int i = 0; i < 4; ++i) { - uint64_t s = a.l[i] + b.l[i]; - uint64_t c1 = (s < a.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - return fp_csub_p(r); -} - -inline U256 fp_sub(thread const U256& a, thread const U256& b) { - U256 r; - uint64_t borrow = 0; - for (int i = 0; i < 4; ++i) { - uint64_t bi = b.l[i]; - uint64_t s = a.l[i] - bi - borrow; - borrow = ((a.l[i] < bi + borrow) || (bi + borrow < bi)) ? 1 : 0; - r.l[i] = s; - } - if (borrow) { - uint64_t carry = 0; - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - for (int i = 0; i < 4; ++i) { - uint64_t s = r.l[i] + pl[i]; - uint64_t c1 = (s < r.l[i]) ? 1 : 0; - uint64_t s2 = s + carry; - uint64_t c2 = (s2 < s) ? 1 : 0; - r.l[i] = s2; - carry = c1 + c2; - } - } - return r; -} - -inline U256 fp_mont_mul(thread const U256& a, thread const U256& b) { - uint64_t pl[4] = { BN254_P0, BN254_P1, BN254_P2, BN254_P3 }; - uint64_t t0 = 0, t1 = 0, t2 = 0, t3 = 0, t4 = 0, t5 = 0; - - for (int i = 0; i < 4; ++i) { - uint64_t ai = a.l[i]; - { - uint64_t carry = 0, lo, hi; - lo = ai * b.l[0]; hi = mulhi(ai, b.l[0]); - uint64_t s = t0 + lo; uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - t0 = s; carry = hi + c1; - lo = ai * b.l[1]; hi = mulhi(ai, b.l[1]); - s = t1 + lo; c1 = (s < t1) ? 
1ULL : 0ULL; - uint64_t s2 = s + carry; uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; carry = hi + c1 + c2; - lo = ai * b.l[2]; hi = mulhi(ai, b.l[2]); - s = t2 + lo; c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; carry = hi + c1 + c2; - lo = ai * b.l[3]; hi = mulhi(ai, b.l[3]); - s = t3 + lo; c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; carry = hi + c1 + c2; - s = t4 + carry; c1 = (s < t4) ? 1ULL : 0ULL; - t4 = s; t5 = t5 + c1; - } - uint64_t m = t0 * BN254_INV; - { - uint64_t carry = 0, lo, hi; - lo = m * pl[0]; hi = mulhi(m, pl[0]); - uint64_t s = t0 + lo; uint64_t c1 = (s < t0) ? 1ULL : 0ULL; - carry = hi + c1; - lo = m * pl[1]; hi = mulhi(m, pl[1]); - s = t1 + lo; c1 = (s < t1) ? 1ULL : 0ULL; - uint64_t s2 = s + carry; uint64_t c2 = (s2 < s) ? 1ULL : 0ULL; - t1 = s2; carry = hi + c1 + c2; - lo = m * pl[2]; hi = mulhi(m, pl[2]); - s = t2 + lo; c1 = (s < t2) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t2 = s2; carry = hi + c1 + c2; - lo = m * pl[3]; hi = mulhi(m, pl[3]); - s = t3 + lo; c1 = (s < t3) ? 1ULL : 0ULL; - s2 = s + carry; c2 = (s2 < s) ? 1ULL : 0ULL; - t3 = s2; carry = hi + c1 + c2; - s = t4 + carry; c1 = (s < t4) ? 
1ULL : 0ULL; - t4 = s; t5 = t5 + c1; - t0 = t1; t1 = t2; t2 = t3; t3 = t4; t4 = t5; t5 = 0; - } - } - - U256 r; r.l[0]=t0; r.l[1]=t1; r.l[2]=t2; r.l[3]=t3; - if (t4 != 0) { - U256 p = fp_p(); - r = fp_sub(r, p); - } - return fp_csub_p(r); -} - -inline U256 fp_mont_sqr(thread const U256& a) { return fp_mont_mul(a, a); } - -inline U256 fp_one_mont() { - U256 r; r.l[0]=BN254_R_0; r.l[1]=BN254_R_1; r.l[2]=BN254_R_2; r.l[3]=BN254_R_3; - return r; -} - -inline U256 fp_r2() { - U256 r; r.l[0]=BN254_R2_0; r.l[1]=BN254_R2_1; r.l[2]=BN254_R2_2; r.l[3]=BN254_R2_3; - return r; -} - -inline U256 fp_to_mont(thread const U256& x) { return fp_mont_mul(x, fp_r2()); } - -inline U256 fp_from_mont(thread const U256& x) { - U256 one; one.l[0]=1; one.l[1]=0; one.l[2]=0; one.l[3]=0; - return fp_mont_mul(x, one); -} - -inline U256 fp_inv(thread const U256& a) { - if (u256_is_zero(a)) return a; - uint64_t exp[4] = { BN254_P0 - 2ULL, BN254_P1, BN254_P2, BN254_P3 }; - U256 result = fp_one_mont(); - U256 base = a; - for (int limb = 0; limb < 4; ++limb) { - uint64_t e = exp[limb]; - for (int b = 0; b < 64; ++b) { - if ((e >> b) & 1ULL) result = fp_mont_mul(result, base); - base = fp_mont_sqr(base); - } - } - return result; -} - -// ============================================================================= -// G1 in Jacobian (Montgomery) -// ============================================================================= - -struct G1Jac { U256 X; U256 Y; U256 Z; }; -struct G1Aff { U256 X; U256 Y; bool inf; }; - -inline G1Jac g1_zero() { - G1Jac p; - p.X = fp_one_mont(); p.Y = fp_one_mont(); p.Z = u256_zero(); - return p; -} - -inline bool g1_is_zero(thread const G1Jac& p) { return u256_is_zero(p.Z); } - -inline G1Jac g1_dbl(thread const G1Jac& p) { - if (g1_is_zero(p)) return p; - U256 A = fp_mont_sqr(p.X); - U256 B = fp_mont_sqr(p.Y); - U256 C = fp_mont_sqr(B); - U256 t = fp_add(p.X, B); - U256 t2 = fp_mont_sqr(t); - U256 t3 = fp_sub(t2, A); - U256 t4 = fp_sub(t3, C); - U256 D = fp_add(t4, 
t4); - U256 E = fp_add(fp_add(A, A), A); - U256 F = fp_mont_sqr(E); - G1Jac r; - U256 twoD = fp_add(D, D); - r.X = fp_sub(F, twoD); - U256 D_minus_X = fp_sub(D, r.X); - U256 EDX = fp_mont_mul(E, D_minus_X); - U256 eightC = fp_add(C, C); - eightC = fp_add(eightC, eightC); - eightC = fp_add(eightC, eightC); - r.Y = fp_sub(EDX, eightC); - U256 YZ = fp_mont_mul(p.Y, p.Z); - r.Z = fp_add(YZ, YZ); - return r; -} - -inline G1Jac g1_add_mixed(thread const G1Jac& p, thread const U256& Qx, thread const U256& Qy) { - if (g1_is_zero(p)) { - G1Jac r; - r.X = Qx; r.Y = Qy; r.Z = fp_one_mont(); - return r; - } - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 U2 = fp_mont_mul(Qx, Z1Z1); - U256 S2 = fp_mont_mul(Qy, fp_mont_mul(p.Z, Z1Z1)); - U256 H = fp_sub(U2, p.X); - U256 r_v = fp_sub(S2, p.Y); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 HH = fp_mont_sqr(H); - U256 I = fp_add(HH, HH); I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 r_2 = fp_add(r_v, r_v); - U256 V = fp_mont_mul(p.X, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r_2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r_2, V_minus_X3); - U256 Y1J = fp_mont_mul(p.Y, J); - U256 twoY1J = fp_add(Y1J, Y1J); - out.Y = fp_sub(r_VX, twoY1J); - out.Z = fp_mont_mul(p.Z, fp_add(H, H)); - return out; -} - -inline G1Jac g1_add(thread const G1Jac& p, thread const G1Jac& q) { - if (g1_is_zero(p)) return q; - if (g1_is_zero(q)) return p; - U256 Z1Z1 = fp_mont_sqr(p.Z); - U256 Z2Z2 = fp_mont_sqr(q.Z); - U256 U1 = fp_mont_mul(p.X, Z2Z2); - U256 U2 = fp_mont_mul(q.X, Z1Z1); - U256 S1 = fp_mont_mul(fp_mont_mul(p.Y, q.Z), Z2Z2); - U256 S2 = fp_mont_mul(fp_mont_mul(q.Y, p.Z), Z1Z1); - U256 H = fp_sub(U2, U1); - U256 r_v = fp_sub(S2, S1); - if (u256_is_zero(H)) { - if (u256_is_zero(r_v)) return g1_dbl(p); - return g1_zero(); - } - U256 r2 = fp_add(r_v, r_v); - U256 HH = fp_mont_sqr(H); - U256 
I = fp_add(HH, HH); I = fp_add(I, I); - U256 J = fp_mont_mul(H, I); - U256 V = fp_mont_mul(U1, I); - G1Jac out; - U256 r_sq = fp_mont_sqr(r2); - U256 t1 = fp_sub(r_sq, J); - U256 twoV = fp_add(V, V); - out.X = fp_sub(t1, twoV); - U256 V_minus_X3 = fp_sub(V, out.X); - U256 r_VX = fp_mont_mul(r2, V_minus_X3); - U256 S1J = fp_mont_mul(S1, J); - U256 twoS1J = fp_add(S1J, S1J); - out.Y = fp_sub(r_VX, twoS1J); - U256 Z1Z2 = fp_mont_mul(p.Z, q.Z); - out.Z = fp_mont_mul(Z1Z2, fp_add(H, H)); - return out; -} - -inline G1Aff g1_to_affine(thread const G1Jac& p) { - G1Aff r; - if (g1_is_zero(p)) { r.X = u256_zero(); r.Y = u256_zero(); r.inf = true; return r; } - U256 Zinv = fp_inv(p.Z); - U256 Zinv2 = fp_mont_sqr(Zinv); - U256 Zinv3 = fp_mont_mul(Zinv2, Zinv); - r.X = fp_mont_mul(p.X, Zinv2); - r.Y = fp_mont_mul(p.Y, Zinv3); - r.inf = false; - return r; -} - -inline G1Jac g1_scalar_mul_aff(thread const U256& Qx, thread const U256& Qy, - thread const uint64_t s[4]) { - G1Jac acc = g1_zero(); - for (int li = 3; li >= 0; --li) { - uint64_t limb = s[li]; - for (int bi = 63; bi >= 0; --bi) { - acc = g1_dbl(acc); - if ((limb >> bi) & 1ULL) acc = g1_add_mixed(acc, Qx, Qy); - } - } - return acc; -} - -// ============================================================================= -// I/O helpers -// ============================================================================= - -inline U256 read_be32(device const uint8_t* p) { - U256 r; - for (int limb = 0; limb < 4; ++limb) { - const device uint8_t* src = p + (3 - limb) * 8; - uint64_t v = 0; - for (int i = 0; i < 8; ++i) v = (v << 8) | (uint64_t)src[i]; - r.l[limb] = v; - } - return r; -} - -inline void write_be32(device uint8_t* p, thread const U256& a) { - for (int limb = 0; limb < 4; ++limb) { - device uint8_t* dst = p + (3 - limb) * 8; - uint64_t v = a.l[limb]; - for (int i = 7; i >= 0; --i) { - dst[i] = (uint8_t)(v & 0xFFu); - v >>= 8; - } - } -} - -// 
============================================================================= -// Threadgroup-cooperative tree-reduce kernel -// ============================================================================= -// -// One threadgroup per commitment, 256 threads per threadgroup. -// Threadgroup memory: shared_pts[256] G1Jac (24 KiB). -// -// Phase 1: thread `i` does shared_pts[i] = scalars[i] * G_basis[i] -// Phase 2: tree-reduce in shared memory: -// stride = 128, 64, 32, 16, 8, 4, 2, 1 -// if i < stride: shared_pts[i] = shared_pts[i] + shared_pts[i+stride] -// barrier between strides -// Phase 3: thread 0 finishes by computing result = shared_pts[0] + r * H -// Phase 4: thread 0 converts to affine and writes 64 BE bytes. -// -// Outputs byte-equal to the legacy two-stage Metal pipeline. - -struct PedTreeDims { uint32_t M; uint32_t N; }; - -kernel void pedersen_tree_commit( - device const uint8_t* gens_be [[buffer(0)]], - device const uint8_t* scalars_be [[buffer(1)]], - device const uint8_t* blindings_be[[buffer(2)]], - device uint8_t* out_be [[buffer(3)]], - constant PedTreeDims& dims [[buffer(4)]], - threadgroup uint64_t* shared_pts [[threadgroup(0)]], - uint tid [[thread_position_in_threadgroup]], - uint gid [[threadgroup_position_in_grid]], - uint tg_size [[threads_per_threadgroup]] -) { - const uint32_t N = dims.N; // expected to be 256 (Verkle width) - const uint32_t M = dims.M; - if (gid >= M) return; - if (tid >= N) return; - if (tg_size != PED_TREE_N) return; // contract; driver enforces - - // Phase 1: each thread computes its own term P[i] = scalars[m,i] * G[i]. 
- U256 Qx_raw = read_be32(gens_be + tid * 64); - U256 Qy_raw = read_be32(gens_be + tid * 64 + 32); - U256 sc_raw = read_be32(scalars_be + (gid * N + tid) * 32); - U256 Qx = fp_to_mont(Qx_raw); - U256 Qy = fp_to_mont(Qy_raw); - uint64_t s[4] = { sc_raw.l[0], sc_raw.l[1], sc_raw.l[2], sc_raw.l[3] }; - G1Jac P = g1_scalar_mul_aff(Qx, Qy, s); - - // Stash in threadgroup memory: 12 u64 per slot (X || Y || Z). - uint base = tid * 12; - for (int k = 0; k < 4; ++k) shared_pts[base + 0 + k] = P.X.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[base + 4 + k] = P.Y.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[base + 8 + k] = P.Z.l[k]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Phase 2: tree reduction. After each step, the live half is [0, stride). - for (uint stride = 128u; stride > 0u; stride >>= 1) { - if (tid < stride) { - // load self - uint b0 = tid * 12; - G1Jac a; - for (int k = 0; k < 4; ++k) a.X.l[k] = shared_pts[b0 + 0 + k]; - for (int k = 0; k < 4; ++k) a.Y.l[k] = shared_pts[b0 + 4 + k]; - for (int k = 0; k < 4; ++k) a.Z.l[k] = shared_pts[b0 + 8 + k]; - // load sibling - uint b1 = (tid + stride) * 12; - G1Jac b; - for (int k = 0; k < 4; ++k) b.X.l[k] = shared_pts[b1 + 0 + k]; - for (int k = 0; k < 4; ++k) b.Y.l[k] = shared_pts[b1 + 4 + k]; - for (int k = 0; k < 4; ++k) b.Z.l[k] = shared_pts[b1 + 8 + k]; - G1Jac sum = g1_add(a, b); - for (int k = 0; k < 4; ++k) shared_pts[b0 + 0 + k] = sum.X.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[b0 + 4 + k] = sum.Y.l[k]; - for (int k = 0; k < 4; ++k) shared_pts[b0 + 8 + k] = sum.Z.l[k]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Phase 3 + 4: thread 0 finishes (load reduced sum, add r*H, emit). 
- if (tid == 0) { - G1Jac acc; - for (int k = 0; k < 4; ++k) acc.X.l[k] = shared_pts[0 + k]; - for (int k = 0; k < 4; ++k) acc.Y.l[k] = shared_pts[4 + k]; - for (int k = 0; k < 4; ++k) acc.Z.l[k] = shared_pts[8 + k]; - - // Blinding term: r * H - U256 Hx_raw = read_be32(gens_be + N * 64); - U256 Hy_raw = read_be32(gens_be + N * 64 + 32); - U256 r_raw = read_be32(blindings_be + gid * 32); - U256 Hx = fp_to_mont(Hx_raw); - U256 Hy = fp_to_mont(Hy_raw); - uint64_t rs[4] = { r_raw.l[0], r_raw.l[1], r_raw.l[2], r_raw.l[3] }; - G1Jac rH = g1_scalar_mul_aff(Hx, Hy, rs); - acc = g1_add(acc, rH); - - G1Aff aff = g1_to_affine(acc); - device uint8_t* dst = out_be + gid * 64; - if (aff.inf) { - for (int b = 0; b < 64; ++b) dst[b] = 0; - return; - } - U256 X_raw = fp_from_mont(aff.X); - U256 Y_raw = fp_from_mont(aff.Y); - write_be32(dst, X_raw); - write_be32(dst + 32, Y_raw); - } -} diff --git a/pedersen/gpu/metal/pedersen_tree_driver.h b/pedersen/gpu/metal/pedersen_tree_driver.h deleted file mode 100644 index 9f51fc6..0000000 --- a/pedersen/gpu/metal/pedersen_tree_driver.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce Metal driver for batched Pedersen vector commitments at the -// fixed Verkle width N = 256. One threadgroup per commitment, 256 threads -// per threadgroup. Threadgroup-local tree reduction collapses log_2 256 = 8 -// host -> GPU dispatches into one. Output byte-equal to the legacy -// pedersen_batch_metal. - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Verkle node width; the tree-reduce kernel is specialised at this size. -#define PEDERSEN_TREE_WIDTH 256u - -// Computes M Pedersen vector commitments at the fixed width N = 256 in a -// single GPU dispatch with threadgroup-cooperative tree reduction. 
-// -// Wire format (raw big-endian, gnark-crypto compatible): -// gens_be : (N + 1) * 64 bytes -- G_basis[0..N-1] || H, X then Y -// scalars_be : M * N * 32 bytes -- raw BE Fr elements -// blindings_be: M * 32 bytes -- raw BE Fr elements -// out_be : M * 64 bytes -- (X || Y) raw BE -// -// Returns 0 on success, negative on failure. -int pedersen_tree_metal( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint8_t* out_be, - const char* metallib_path); - -#ifdef __cplusplus -} -#endif diff --git a/pedersen/gpu/metal/pedersen_tree_driver.mm b/pedersen/gpu/metal/pedersen_tree_driver.mm deleted file mode 100644 index 3a80084..0000000 --- a/pedersen/gpu/metal/pedersen_tree_driver.mm +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce Metal driver. Single-stage dispatch: -// pedersen_tree_commit -- M threadgroups of 256 threads each. -// -// 1 command buffer, 1 commit, 1 wait. No scratch device buffer (reduction -// happens in 24 KiB of threadgroup memory per commit). 
- -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "pedersen_tree_driver.h" - -#include -#include -#include - -namespace { - -struct PedTreeDimsGPU { - uint32_t M; - uint32_t N; -}; - -} // namespace - -extern "C" int pedersen_tree_metal( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint8_t* out_be, - const char* metallib_path) { - - if (M == 0) return 0; - if (!gens_be || !scalars_be || !blindings_be || !out_be || !metallib_path) { - return -1; - } - - const uint32_t N = PEDERSEN_TREE_WIDTH; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn_tree = [lib newFunctionWithName:@"pedersen_tree_commit"]; - if (!fn_tree) return -4; - - id pipe_tree = - [device newComputePipelineStateWithFunction:fn_tree error:&err]; - if (!pipe_tree) return -5; - - // The kernel hard-requires N = 256 threads per threadgroup. Verify - // the device can satisfy that (every Apple GPU since A11 / M1 can, - // but the contract is checked here so a misbuilt metallib fails fast). - if (pipe_tree.maxTotalThreadsPerThreadgroup < N) return -6; - - id queue = [device newCommandQueue]; - - size_t gens_len = (size_t)(N + 1) * 64; - size_t scalars_len = (size_t)M * N * 32; - size_t blind_len = (size_t)M * 32; - size_t out_len = (size_t)M * 64; - // Threadgroup memory budget: 256 slots * 12 u64 each = 3072 u64 = 24 KiB. 
- size_t tg_bytes = (size_t)N * 12 * sizeof(uint64_t); - - id gens_buf = [device newBufferWithBytes:gens_be - length:gens_len - options:MTLResourceStorageModeShared]; - id scalars_buf = [device newBufferWithBytes:scalars_be - length:scalars_len - options:MTLResourceStorageModeShared]; - id blind_buf = [device newBufferWithBytes:blindings_be - length:blind_len - options:MTLResourceStorageModeShared]; - id out_buf = [device newBufferWithLength:out_len - options:MTLResourceStorageModeShared]; - - PedTreeDimsGPU dims = { M, N }; - id dims_buf = [device newBufferWithBytes:&dims - length:sizeof(dims) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipe_tree]; - [enc setBuffer:gens_buf offset:0 atIndex:0]; - [enc setBuffer:scalars_buf offset:0 atIndex:1]; - [enc setBuffer:blind_buf offset:0 atIndex:2]; - [enc setBuffer:out_buf offset:0 atIndex:3]; - [enc setBuffer:dims_buf offset:0 atIndex:4]; - [enc setThreadgroupMemoryLength:tg_bytes atIndex:0]; - - // M threadgroups of N (= 256) threads each. - MTLSize threads_per_grid = MTLSizeMake((NSUInteger)M * (NSUInteger)N, 1, 1); - MTLSize threads_per_tg = MTLSizeMake((NSUInteger)N, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(out_be, [out_buf contents], out_len); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/pedersen/gpu/wgsl/pedersen.wgsl b/pedersen/gpu/wgsl/pedersen.wgsl deleted file mode 100644 index 8be212d..0000000 --- a/pedersen/gpu/wgsl/pedersen.wgsl +++ /dev/null @@ -1,560 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party WGSL kernel for batched Pedersen vector commitments over BN254 G1. 
-// Mechanically ported from pedersen/gpu/metal/pedersen.metal -- byte-equal to -// the Metal kernel and the Go canonical at github.com/luxfi/crypto/pedersen. -// -// 256-bit limbs are 8 x u32 little-endian (no native u64 in WGSL). -// -// Two-stage pipeline (driver dispatches twice with different entry points): -// 1. pedersen_pointmul M*(N+1) threads, one per (commitment, term) -// 2. pedersen_reduce_add M threads, one per commitment -// -// Wire format (raw big-endian, gnark-crypto compatible): -// gens_be : (N + 1) * 64 bytes -- G_basis[0..N-1] || H, X then Y -// scalars_be : M * N * 32 bytes -- raw BE Fr elements -// blindings_be: M * 32 bytes -- raw BE Fr elements -// scratch : M * (N+1) * 24 u32 -- (X || Y || Z) Montgomery, 8 limbs each -// out_be : M * 64 bytes -- (X || Y) raw BE -// -// Stage 1 binds: gens_be(0), scalars_be(1), blindings_be(2), scratch(3), dims(4) -// Stage 2 binds: scratch(0), out_be(1), dims(2) - -// Bindings: stage 1 -@group(0) @binding(0) var gens_be : array; -@group(0) @binding(1) var scalars_be : array; -@group(0) @binding(2) var blindings_be : array; -@group(0) @binding(3) var scratch : array; -@group(0) @binding(4) var dims : Dims; - -// Bindings: stage 2 -- separate kernel must use distinct group(1) bindings to -// avoid being merged with stage 1 by the WGSL static binding analyzer. 
-@group(1) @binding(0) var scratch_in : array; -@group(1) @binding(1) var out_be : array; -@group(1) @binding(2) var dims2 : Dims; - -struct Dims { - M: u32, - N: u32, - _pad0: u32, - _pad1: u32, -} - -// ============================================================================= -// BN254 base-field constants -- 8 x u32 little-endian -// ============================================================================= -// p = 0x30644E72E131A029 B85045B68181585D 97816A916871CA8D 3C208C16D87CFD47 -const BN254_P = array( - 0xD87CFD47u, 0x3C208C16u, 0x6871CA8Du, 0x97816A91u, - 0x8181585Du, 0xB85045B6u, 0xE131A029u, 0x30644E72u -); -// R = 2^256 mod p -const BN254_R_MONT = array( - 0xC58F0D9Du, 0xD35D438Du, 0xF5C70B3Du, 0x0A78EB28u, - 0x7879462Cu, 0x666EA36Fu, 0x9A07DF2Fu, 0x0E0A77C1u -); -// R^2 mod p -const BN254_R2 = array( - 0x538AFA89u, 0xF32CFC5Bu, 0xD44501FBu, 0xB5E71911u, - 0x0A417FF6u, 0x47AB1EFFu, 0xCAB8351Fu, 0x06D89F71u -); -// -p^{-1} mod 2^32 (low 32 bits of -p^{-1} mod 2^64 = 0x87D20782_E4866389) -const BN254_INV: u32 = 0xE4866389u; - -// ============================================================================= -// 256-bit (8 x u32) helpers -// ============================================================================= - -fn u256_zero() -> array { - return array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); -} - -fn u256_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0; i = i - 1) { - let ui = u32(i); - if ((*a)[ui] > (*b)[ui]) { return 1; } - if ((*a)[ui] < (*b)[ui]) { return -1; } - } - return 0; -} - -fn u256_add(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let s1 = (*a)[i] + c; - c = select(0u, 1u, s1 < (*a)[i]); - let s2 = s1 + (*b)[i]; - c = c + select(0u, 1u, s2 < s1); - (*r)[i] = s2; - } - return c; -} - -fn u256_sub(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var 
bw = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let d1 = (*a)[i] - bw; - bw = select(0u, 1u, d1 > (*a)[i]); - let d2 = d1 - (*b)[i]; - bw = bw + select(0u, 1u, d2 > d1); - (*r)[i] = d2; - } - return bw; -} - -// ============================================================================= -// Montgomery reduction (CIOS) over BN254 p, 8x u32 limbs -// ============================================================================= -// -// Mirrors mont_reduce in secp256k1.wgsl. Given t in [0, p*R), returns t*R^{-1} mod p. - -fn mont_reduce(t: ptr>, r: ptr>) { - var p = BN254_P; - var a: array; - for (var i = 0u; i < 16u; i = i + 1u) { a[i] = (*t)[i]; } - a[16] = 0u; - - for (var i = 0u; i < 8u; i = i + 1u) { - let u = a[i] * BN254_INV; - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - // u * p[j] -> (hi, lo) - let u_lo = u & 0xFFFFu; let u_hi = u >> 16u; - let m_lo = p[j] & 0xFFFFu; let m_hi = p[j] >> 16u; - let ll = u_lo * m_lo; - let lh = u_lo * m_hi; - let hl = u_hi * m_lo; - let hh = u_hi * m_hi; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - - let s1 = lo + carry; - hi = hi + select(0u, 1u, s1 < lo); - let s2 = a[i + j] + s1; - hi = hi + select(0u, 1u, s2 < a[i + j]); - a[i + j] = s2; - carry = hi; - } - // propagate carry through high half - for (var j = 8u; i + j <= 16u; j = j + 1u) { - let s = a[i + j] + carry; - carry = select(0u, 1u, s < a[i + j]); - a[i + j] = s; - if (carry == 0u) { break; } - } - } - - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = a[i + 8u]; } - if (a[16] != 0u || u256_cmp(r, &p) >= 0) { - _ = u256_sub(r, &p, r); - } -} - -fn mont_mul(a: ptr>, b: ptr>, - r: ptr>) { - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 8u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let al = (*a)[i] & 0xFFFFu; let ah = (*a)[i] >> 16u; - let bl = (*b)[j] & 0xFFFFu; let bh = 
(*b)[j] >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - let s1 = lo + carry; hi = hi + select(0u, 1u, s1 < lo); - let s2 = t[i + j] + s1; hi = hi + select(0u, 1u, s2 < t[i + j]); - t[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j < 16u; j = j + 1u) { - let s = t[i + j] + carry; - carry = select(0u, 1u, s < t[i + j]); - t[i + j] = s; - if (carry == 0u) { break; } - } - } - mont_reduce(&t, r); -} - -// ============================================================================= -// Field ops over p (Montgomery) -// ============================================================================= - -fn fp_add(a: ptr>, b: ptr>, - r: ptr>) { - var p = BN254_P; - let c = u256_add(a, b, r); - if (c != 0u || u256_cmp(r, &p) >= 0) { - _ = u256_sub(r, &p, r); - } -} - -fn fp_sub(a: ptr>, b: ptr>, - r: ptr>) { - var p = BN254_P; - let bw = u256_sub(a, b, r); - if (bw != 0u) { - _ = u256_add(r, &p, r); - } -} - -fn fp_mul(a: ptr>, b: ptr>, - r: ptr>) { mont_mul(a, b, r); } - -fn fp_sqr(a: ptr>, r: ptr>) { - mont_mul(a, a, r); -} - -fn to_mont_p(a: ptr>, r: ptr>) { - var r2 = BN254_R2; - fp_mul(a, &r2, r); -} - -fn from_mont_p(a: ptr>, r: ptr>) { - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 8u; i = i + 1u) { t[i] = (*a)[i]; } - mont_reduce(&t, r); -} - -// Inversion via Fermat: a^(p-2) -fn fp_inv(a: ptr>, r: ptr>) { - // p - 2 LE limbs - var exp = array( - 0xD87CFD45u, 0x3C208C16u, 0x6871CA8Du, 0x97816A91u, - 0x8181585Du, 0xB85045B6u, 0xE131A029u, 0x30644E72u - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var result: array; - to_mont_p(&one, &result); - var base: array; - for (var i = 0u; i < 8u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if 
(((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - fp_mul(&result, &base, &tmp); - result = tmp; - } - var tmp2: array; - fp_sqr(&base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -// ============================================================================= -// G1 in Jacobian (Montgomery X, Y, Z); Z == 0 represents infinity -// ============================================================================= - -struct G1Jac { - x: array, - y: array, - z: array, -} - -fn g1_zero() -> G1Jac { - var p: G1Jac; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &p.x); - p.y = p.x; - p.z = u256_zero(); - return p; -} - -fn g1_is_zero(p: ptr) -> bool { - var z = (*p).z; - return u256_is_zero(&z); -} - -// BN254 a = 0 doubling (matches Metal's fp formulas exactly) -fn g1_dbl(p: ptr, r: ptr) { - if (g1_is_zero(p)) { *r = *p; return; } - var A: array; fp_sqr(&(*p).x, &A); - var B: array; fp_sqr(&(*p).y, &B); - var C: array; fp_sqr(&B, &C); - var t: array; fp_add(&(*p).x, &B, &t); - var t2: array; fp_sqr(&t, &t2); - var t3: array; fp_sub(&t2, &A, &t3); - var t4: array; fp_sub(&t3, &C, &t4); - var D: array; fp_add(&t4, &t4, &D); - var twoA: array; fp_add(&A, &A, &twoA); - var E: array; fp_add(&twoA, &A, &E); - var F: array; fp_sqr(&E, &F); - var twoD: array; fp_add(&D, &D, &twoD); - fp_sub(&F, &twoD, &(*r).x); - var DminusX: array; fp_sub(&D, &(*r).x, &DminusX); - var EDX: array; fp_mul(&E, &DminusX, &EDX); - var twoC: array; fp_add(&C, &C, &twoC); - var fourC: array; fp_add(&twoC, &twoC, &fourC); - var eightC: array; fp_add(&fourC, &fourC, &eightC); - fp_sub(&EDX, &eightC, &(*r).y); - var YZ: array; fp_mul(&(*p).y, &(*p).z, &YZ); - fp_add(&YZ, &YZ, &(*r).z); -} - -// Mixed Jacobian + Affine addition, byte-identical to Metal g1_add_mixed -fn g1_add_mixed(p: ptr, qx: ptr>, - qy: ptr>, r: ptr) { - if (g1_is_zero(p)) { - (*r).x = *qx; (*r).y = *qy; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &(*r).z); - return; - } - var Z1Z1: 
array; fp_sqr(&(*p).z, &Z1Z1); - var U2: array; fp_mul(qx, &Z1Z1, &U2); - var ZZ1Z1: array; fp_mul(&(*p).z, &Z1Z1, &ZZ1Z1); - var S2: array; fp_mul(qy, &ZZ1Z1, &S2); - var H: array; fp_sub(&U2, &(*p).x, &H); - var R: array; fp_sub(&S2, &(*p).y, &R); - - if (u256_is_zero(&H)) { - if (u256_is_zero(&R)) { g1_dbl(p, r); return; } - *r = g1_zero(); - return; - } - var HH: array; fp_sqr(&H, &HH); - var twoHH: array; fp_add(&HH, &HH, &twoHH); - var I: array; fp_add(&twoHH, &twoHH, &I); - var J: array; fp_mul(&H, &I, &J); - var R2: array; fp_add(&R, &R, &R2); - var V: array; fp_mul(&(*p).x, &I, &V); - - var Rsq: array; fp_sqr(&R2, &Rsq); - var t1: array; fp_sub(&Rsq, &J, &t1); - var twoV: array; fp_add(&V, &V, &twoV); - fp_sub(&t1, &twoV, &(*r).x); - var VminusX3: array; fp_sub(&V, &(*r).x, &VminusX3); - var RVX: array; fp_mul(&R2, &VminusX3, &RVX); - var Y1J: array; fp_mul(&(*p).y, &J, &Y1J); - var twoY1J: array; fp_add(&Y1J, &Y1J, &twoY1J); - fp_sub(&RVX, &twoY1J, &(*r).y); - var twoH: array; fp_add(&H, &H, &twoH); - fp_mul(&(*p).z, &twoH, &(*r).z); -} - -// Full Jacobian addition (matches Metal g1_add) -fn g1_add(p: ptr, q: ptr, r: ptr) { - if (g1_is_zero(p)) { *r = *q; return; } - if (g1_is_zero(q)) { *r = *p; return; } - var Z1Z1: array; fp_sqr(&(*p).z, &Z1Z1); - var Z2Z2: array; fp_sqr(&(*q).z, &Z2Z2); - var U1: array; fp_mul(&(*p).x, &Z2Z2, &U1); - var U2: array; fp_mul(&(*q).x, &Z1Z1, &U2); - var Yq: array; fp_mul(&(*p).y, &(*q).z, &Yq); - var S1: array; fp_mul(&Yq, &Z2Z2, &S1); - var Yp: array; fp_mul(&(*q).y, &(*p).z, &Yp); - var S2: array; fp_mul(&Yp, &Z1Z1, &S2); - var H: array; fp_sub(&U2, &U1, &H); - var R: array; fp_sub(&S2, &S1, &R); - if (u256_is_zero(&H)) { - if (u256_is_zero(&R)) { g1_dbl(p, r); return; } - *r = g1_zero(); - return; - } - var R2: array; fp_add(&R, &R, &R2); - var HH: array; fp_sqr(&H, &HH); - var twoHH: array; fp_add(&HH, &HH, &twoHH); - var I: array; fp_add(&twoHH, &twoHH, &I); - var J: array; fp_mul(&H, &I, &J); - var V: array; 
fp_mul(&U1, &I, &V); - var Rsq: array; fp_sqr(&R2, &Rsq); - var t1: array; fp_sub(&Rsq, &J, &t1); - var twoV: array; fp_add(&V, &V, &twoV); - fp_sub(&t1, &twoV, &(*r).x); - var VmX: array; fp_sub(&V, &(*r).x, &VmX); - var RVX: array; fp_mul(&R2, &VmX, &RVX); - var S1J: array; fp_mul(&S1, &J, &S1J); - var twoS1J: array; fp_add(&S1J, &S1J, &twoS1J); - fp_sub(&RVX, &twoS1J, &(*r).y); - var Z1Z2: array; fp_mul(&(*p).z, &(*q).z, &Z1Z2); - var twoH: array; fp_add(&H, &H, &twoH); - fp_mul(&Z1Z2, &twoH, &(*r).z); -} - -fn g1_to_affine(p: ptr, ax: ptr>, - ay: ptr>) -> bool { - if (g1_is_zero(p)) { *ax = u256_zero(); *ay = u256_zero(); return true; } - var Zinv: array; fp_inv(&(*p).z, &Zinv); - var Zinv2: array; fp_sqr(&Zinv, &Zinv2); - var Zinv3: array; fp_mul(&Zinv2, &Zinv, &Zinv3); - fp_mul(&(*p).x, &Zinv2, ax); - fp_mul(&(*p).y, &Zinv3, ay); - return false; -} - -fn g1_scalar_mul_aff(qx: ptr>, qy: ptr>, - s: ptr>) -> G1Jac { - var acc = g1_zero(); - // top word first (MSB) -- 8 u32 limbs LE means limb[7] is most significant - for (var li = 7i; li >= 0; li = li - 1) { - let limb = (*s)[u32(li)]; - for (var bi = 31i; bi >= 0; bi = bi - 1) { - var dbl: G1Jac; - g1_dbl(&acc, &dbl); - acc = dbl; - if (((limb >> u32(bi)) & 1u) != 0u) { - var tmp: G1Jac; - g1_add_mixed(&acc, qx, qy, &tmp); - acc = tmp; - } - } - } - return acc; -} - -// ============================================================================= -// I/O: 32-byte BE field element (8 BE u32 words) <-> 8 LE u32 limbs -// ============================================================================= -// -// The host uploads raw BE bytes. WGSL reads u32 from the storage buffer as the -// host's native u32 (little-endian on every consumer GPU). So a 32-byte BE -// field element, viewed as 8 host u32s, is a sequence of byte-swapped limbs in -// reverse order: word[0] = bytes 0..3 (most significant 4 BE bytes), and the -// LE limb[7] equals byteswap(word[0]). 
- -fn bswap32(x: u32) -> u32 { - return ((x & 0x000000FFu) << 24u) - | ((x & 0x0000FF00u) << 8u) - | ((x & 0x00FF0000u) >> 8u) - | ((x & 0xFF000000u) >> 24u); -} - -fn read_be32_gens(word_off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - r[i] = bswap32(gens_be[word_off + 7u - i]); - } - return r; -} - -fn read_be32_scalars(word_off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - r[i] = bswap32(scalars_be[word_off + 7u - i]); - } - return r; -} - -fn read_be32_blindings(word_off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - r[i] = bswap32(blindings_be[word_off + 7u - i]); - } - return r; -} - -fn write_be32_out(word_off: u32, a: ptr>) { - for (var i = 0u; i < 8u; i = i + 1u) { - out_be[word_off + i] = bswap32((*a)[7u - i]); - } -} - -// ============================================================================= -// Stage 1 kernel: pedersen_pointmul -// ============================================================================= -// One thread per (m, i): scratch[m*(N+1)+i] = scalar_{m,i} * P_i -// where P_i = G_basis[i] for i) { - let tid = gid.x; - let M = dims.M; - let N = dims.N; - let total = M * (N + 1u); - if (tid >= total) { return; } - - let m = tid / (N + 1u); - let i = tid - m * (N + 1u); - - // Pick generator (32 BE bytes = 8 host u32 each for X, Y). - var Qx_raw: array; - var Qy_raw: array; - var scalar_raw: array; - if (i < N) { - Qx_raw = read_be32_gens(i * 16u); - Qy_raw = read_be32_gens(i * 16u + 8u); - scalar_raw = read_be32_scalars((m * N + i) * 8u); - } else { - Qx_raw = read_be32_gens(N * 16u); - Qy_raw = read_be32_gens(N * 16u + 8u); - scalar_raw = read_be32_blindings(m * 8u); - } - - // To Montgomery form for X, Y; scalar stays as raw integer (limbs). 
- var Qx_mont: array; to_mont_p(&Qx_raw, &Qx_mont); - var Qy_mont: array; to_mont_p(&Qy_raw, &Qy_mont); - - let result = g1_scalar_mul_aff(&Qx_mont, &Qy_mont, &scalar_raw); - - let base = tid * 24u; - for (var k = 0u; k < 8u; k = k + 1u) { scratch[base + 0u + k] = result.x[k]; } - for (var k = 0u; k < 8u; k = k + 1u) { scratch[base + 8u + k] = result.y[k]; } - for (var k = 0u; k < 8u; k = k + 1u) { scratch[base + 16u + k] = result.z[k]; } -} - -// ============================================================================= -// Stage 2 kernel: pedersen_reduce_add -// ============================================================================= -// One thread per commitment; sums (N+1) Jacobian terms, converts to affine, -// emits 64 BE bytes as (X || Y). - -fn scratch_load(idx: u32) -> G1Jac { - var r: G1Jac; - let base = idx * 24u; - for (var k = 0u; k < 8u; k = k + 1u) { r.x[k] = scratch_in[base + 0u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { r.y[k] = scratch_in[base + 8u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { r.z[k] = scratch_in[base + 16u + k]; } - return r; -} - -@compute @workgroup_size(32) -fn pedersen_reduce_add(@builtin(global_invocation_id) gid: vec3) { - let m = gid.x; - let M = dims2.M; - let N = dims2.N; - if (m >= M) { return; } - - let base = m * (N + 1u); - var acc = g1_zero(); - for (var i = 0u; i < N + 1u; i = i + 1u) { - var term = scratch_load(base + i); - if (g1_is_zero(&term)) { continue; } - var sum: G1Jac; - g1_add(&acc, &term, &sum); - acc = sum; - } - - var aff_x: array; var aff_y: array; - let inf = g1_to_affine(&acc, &aff_x, &aff_y); - let out_word = m * 16u; // 64 bytes / 4 = 16 u32 words per commitment - if (inf) { - for (var k = 0u; k < 16u; k = k + 1u) { out_be[out_word + k] = 0u; } - return; - } - var X_raw: array; from_mont_p(&aff_x, &X_raw); - var Y_raw: array; from_mont_p(&aff_y, &Y_raw); - write_be32_out(out_word, &X_raw); - write_be32_out(out_word + 8u, &Y_raw); -} diff --git 
a/pedersen/gpu/wgsl/pedersen_driver_wgpu.cpp b/pedersen/gpu/wgsl/pedersen_driver_wgpu.cpp deleted file mode 100644 index 3eef03e..0000000 --- a/pedersen/gpu/wgsl/pedersen_driver_wgpu.cpp +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WebGPU/WGSL host driver for the batched Pedersen vector commitment. -// -// Two-stage dispatch into a single shader module: -// stage 1 (group 0): pedersen_pointmul -- M*(N+1) threads -// stage 2 (group 1): pedersen_reduce_add -- M threads -// -// Build modes: -// * LUX_PEDERSEN_HAS_WEBGPU=1 -- Dawn or wgpu-native runtime -// * LUX_PEDERSEN_HAS_WGPU_NATIVE=1 -- enables wgpuDevicePoll -// Stub mode otherwise (returns -1). - -#include "pedersen_driver_wgpu.h" - -#if defined(LUX_PEDERSEN_HAS_WEBGPU) - -#include -#if defined(LUX_PEDERSEN_HAS_WGPU_NATIVE) -# include -#endif - -#include -#include -#include -#include -#include - -// Embedded WGSL source (concatenated by CMake into pedersen_wgsl_source.h). -#include "pedersen_wgsl_source.h" - -namespace { - -WGPUStringView sv(const char* s) { - WGPUStringView v{}; - v.data = s; - v.length = (s == nullptr) ? 
0 : std::strlen(s); - return v; -} -WGPUStringView sv(const std::string& s) { - WGPUStringView v{}; - v.data = s.data(); - v.length = s.size(); - return v; -} - -void drain(WGPUInstance inst, WGPUDevice dev) { - if (inst) wgpuInstanceProcessEvents(inst); -#if defined(LUX_PEDERSEN_HAS_WGPU_NATIVE) - if (dev) wgpuDevicePoll(dev, /*wait=*/WGPU_TRUE, nullptr); -#else - (void)dev; -#endif -} - -bool wait_map(WGPUInstance inst, WGPUDevice dev, WGPUBuffer buf, - WGPUMapMode mode, size_t off, size_t size) { - struct State { std::atomic done{false}; WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error}; } s; - WGPUBufferMapCallbackInfo cb{}; - cb.mode = WGPUCallbackMode_AllowProcessEvents; - cb.callback = [](WGPUMapAsyncStatus st, WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - p->status = st; - p->done.store(true, std::memory_order_release); - }; - cb.userdata1 = &s; - wgpuBufferMapAsync(buf, mode, off, size, cb); - for (int spin = 0; spin < 8192; spin++) { - if (s.done.load(std::memory_order_acquire)) break; - drain(inst, dev); - } - return s.done.load() && s.status == WGPUMapAsyncStatus_Success; -} - -struct Engine { - WGPUInstance instance{nullptr}; - WGPUAdapter adapter{nullptr}; - WGPUDevice device{nullptr}; - WGPUQueue queue{nullptr}; - WGPUShaderModule module{nullptr}; - bool initialized{false}; -}; - -Engine& engine() { static Engine e; return e; } - -bool init_engine() { - Engine& e = engine(); - if (e.initialized) return true; - - WGPUInstanceDescriptor idesc{}; - e.instance = wgpuCreateInstance(&idesc); - if (!e.instance) return false; - - struct AS { std::atomic done{false}; WGPUAdapter ad{nullptr}; } as; - WGPURequestAdapterOptions ropt{}; - ropt.powerPreference = WGPUPowerPreference_HighPerformance; - WGPURequestAdapterCallbackInfo rcb{}; - rcb.mode = WGPUCallbackMode_AllowProcessEvents; - rcb.callback = [](WGPURequestAdapterStatus st, WGPUAdapter ad, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == 
WGPURequestAdapterStatus_Success) p->ad = ad; - p->done.store(true, std::memory_order_release); - }; - rcb.userdata1 = &as; - wgpuInstanceRequestAdapter(e.instance, &ropt, rcb); - for (int spin = 0; spin < 8192; spin++) { - if (as.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!as.ad) return false; - e.adapter = as.ad; - - struct DS { std::atomic done{false}; WGPUDevice dev{nullptr}; } ds; - WGPUDeviceDescriptor ddesc{}; - WGPURequestDeviceCallbackInfo dcb{}; - dcb.mode = WGPUCallbackMode_AllowProcessEvents; - dcb.callback = [](WGPURequestDeviceStatus st, WGPUDevice dev, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestDeviceStatus_Success) p->dev = dev; - p->done.store(true, std::memory_order_release); - }; - dcb.userdata1 = &ds; - wgpuAdapterRequestDevice(e.adapter, &ddesc, dcb); - for (int spin = 0; spin < 8192; spin++) { - if (ds.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!ds.dev) return false; - e.device = ds.dev; - e.queue = wgpuDeviceGetQueue(e.device); - if (!e.queue) return false; - - std::string src(kPedersenWGSL); - WGPUShaderSourceWGSL wgsl{}; - wgsl.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl.code = sv(src); - WGPUShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl.chain; - smd.label = sv("pedersen"); - e.module = wgpuDeviceCreateShaderModule(e.device, &smd); - if (!e.module) return false; - - e.initialized = true; - return true; -} - -WGPUBuffer make_buf(Engine& e, size_t size, WGPUBufferUsage usage) { - WGPUBufferDescriptor bd{}; - bd.size = (size + 3) & ~size_t(3); - bd.usage = usage; - return wgpuDeviceCreateBuffer(e.device, &bd); -} - -} // namespace - -extern "C" int lux_pedersen_wgpu_available(void) { - return init_engine() ? 
1 : 0; -} - -extern "C" int pedersen_batch_wgpu( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint32_t N, - uint8_t* out_be) { - if (M == 0 || N == 0) return 0; - if (!gens_be || !scalars_be || !blindings_be || !out_be) return -1; - if (!init_engine()) return -1; - - Engine& e = engine(); - - size_t gens_len = (size_t)(N + 1) * 64; - size_t scalars_len = (size_t)M * N * 32; - size_t blind_len = (size_t)M * 32; - size_t scratch_len = (size_t)M * (N + 1) * 24 * sizeof(uint32_t); - size_t out_len = (size_t)M * 64; - - WGPUBuffer bufGens = make_buf(e, gens_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufScalars = make_buf(e, scalars_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufBlind = make_buf(e, blind_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufScratch = make_buf(e, scratch_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage)); - WGPUBuffer bufOut = make_buf(e, out_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopySrc)); - WGPUBuffer bufDims = make_buf(e, 16, - (WGPUBufferUsage)(WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufRead = make_buf(e, out_len, - (WGPUBufferUsage)(WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst)); - if (!bufGens || !bufScalars || !bufBlind || !bufScratch || - !bufOut || !bufDims || !bufRead) { - return -2; - } - - wgpuQueueWriteBuffer(e.queue, bufGens, 0, gens_be, gens_len); - wgpuQueueWriteBuffer(e.queue, bufScalars, 0, scalars_be, scalars_len); - wgpuQueueWriteBuffer(e.queue, bufBlind, 0, blindings_be, blind_len); - uint32_t dimVals[4] = { M, N, 0, 0 }; - wgpuQueueWriteBuffer(e.queue, bufDims, 0, dimVals, 16); - - // Stage 1: pipeline + bind group 0 ------------------------------------ - WGPUComputePipelineDescriptor cpd1{}; - cpd1.compute.module = e.module; - cpd1.compute.entryPoint = 
sv("pedersen_pointmul"); - cpd1.label = sv("pedersen_pointmul"); - WGPUComputePipeline pso1 = wgpuDeviceCreateComputePipeline(e.device, &cpd1); - if (!pso1) return -3; - - WGPUBindGroupLayout bgl1 = wgpuComputePipelineGetBindGroupLayout(pso1, 0); - WGPUBindGroupEntry bge1[5] = {}; - bge1[0].binding = 0; bge1[0].buffer = bufGens; bge1[0].size = gens_len; - bge1[1].binding = 1; bge1[1].buffer = bufScalars; bge1[1].size = scalars_len; - bge1[2].binding = 2; bge1[2].buffer = bufBlind; bge1[2].size = blind_len; - bge1[3].binding = 3; bge1[3].buffer = bufScratch; bge1[3].size = scratch_len; - bge1[4].binding = 4; bge1[4].buffer = bufDims; bge1[4].size = 16; - WGPUBindGroupDescriptor bgd1{}; - bgd1.layout = bgl1; bgd1.entryCount = 5; bgd1.entries = bge1; - WGPUBindGroup bg1 = wgpuDeviceCreateBindGroup(e.device, &bgd1); - if (!bg1) return -4; - - // Stage 2: pipeline + bind group 1 ------------------------------------ - WGPUComputePipelineDescriptor cpd2{}; - cpd2.compute.module = e.module; - cpd2.compute.entryPoint = sv("pedersen_reduce_add"); - cpd2.label = sv("pedersen_reduce_add"); - WGPUComputePipeline pso2 = wgpuDeviceCreateComputePipeline(e.device, &cpd2); - if (!pso2) return -5; - - WGPUBindGroupLayout bgl2 = wgpuComputePipelineGetBindGroupLayout(pso2, 1); - WGPUBindGroupEntry bge2[3] = {}; - bge2[0].binding = 0; bge2[0].buffer = bufScratch; bge2[0].size = scratch_len; - bge2[1].binding = 1; bge2[1].buffer = bufOut; bge2[1].size = out_len; - bge2[2].binding = 2; bge2[2].buffer = bufDims; bge2[2].size = 16; - WGPUBindGroupDescriptor bgd2{}; - bgd2.layout = bgl2; bgd2.entryCount = 3; bgd2.entries = bge2; - WGPUBindGroup bg2 = wgpuDeviceCreateBindGroup(e.device, &bgd2); - if (!bg2) return -6; - - // Encode + dispatch both stages in one command buffer. 
- WGPUCommandEncoderDescriptor ced{}; - WGPUCommandEncoder ce = wgpuDeviceCreateCommandEncoder(e.device, &ced); - - { - WGPUComputePassDescriptor cpd{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpd); - wgpuComputePassEncoderSetPipeline(cpe, pso1); - wgpuComputePassEncoderSetBindGroup(cpe, 0, bg1, 0, nullptr); - uint32_t total1 = M * (N + 1); - uint32_t wg1 = (total1 + 63) / 64; - wgpuComputePassEncoderDispatchWorkgroups(cpe, wg1, 1, 1); - wgpuComputePassEncoderEnd(cpe); - wgpuComputePassEncoderRelease(cpe); - } - { - WGPUComputePassDescriptor cpd{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpd); - wgpuComputePassEncoderSetPipeline(cpe, pso2); - wgpuComputePassEncoderSetBindGroup(cpe, 1, bg2, 0, nullptr); - uint32_t wg2 = (M + 31) / 32; - wgpuComputePassEncoderDispatchWorkgroups(cpe, wg2, 1, 1); - wgpuComputePassEncoderEnd(cpe); - wgpuComputePassEncoderRelease(cpe); - } - - wgpuCommandEncoderCopyBufferToBuffer(ce, bufOut, 0, bufRead, 0, out_len); - WGPUCommandBufferDescriptor cbd{}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(ce, &cbd); - wgpuQueueSubmit(e.queue, 1, &cmd); - - int rc = 0; - if (!wait_map(e.instance, e.device, bufRead, WGPUMapMode_Read, 0, out_len)) { - rc = -7; - } else { - const void* mapped = wgpuBufferGetConstMappedRange(bufRead, 0, out_len); - std::memcpy(out_be, mapped, out_len); - wgpuBufferUnmap(bufRead); - } - - wgpuCommandEncoderRelease(ce); - wgpuCommandBufferRelease(cmd); - wgpuBindGroupRelease(bg1); wgpuBindGroupRelease(bg2); - wgpuBindGroupLayoutRelease(bgl1); wgpuBindGroupLayoutRelease(bgl2); - wgpuComputePipelineRelease(pso1); wgpuComputePipelineRelease(pso2); - wgpuBufferRelease(bufGens); wgpuBufferRelease(bufScalars); - wgpuBufferRelease(bufBlind); wgpuBufferRelease(bufScratch); - wgpuBufferRelease(bufOut); wgpuBufferRelease(bufDims); - wgpuBufferRelease(bufRead); - return rc; -} - -#else // LUX_PEDERSEN_HAS_WEBGPU not defined: stub mode - -extern "C" int 
lux_pedersen_wgpu_available(void) { return 0; } -extern "C" int pedersen_batch_wgpu( - const uint8_t*, const uint8_t*, const uint8_t*, - uint32_t, uint32_t, uint8_t*) { return -1; } - -#endif diff --git a/pedersen/gpu/wgsl/pedersen_driver_wgpu.h b/pedersen/gpu/wgsl/pedersen_driver_wgpu.h deleted file mode 100644 index cce80a8..0000000 --- a/pedersen/gpu/wgsl/pedersen_driver_wgpu.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the WebGPU/WGSL driver of the batched Pedersen vector -// commitment. Mirrors pedersen_batch_metal exactly. - -#ifndef LUX_PEDERSEN_DRIVER_WGPU_H -#define LUX_PEDERSEN_DRIVER_WGPU_H - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_pedersen_wgpu_available(void); - -int pedersen_batch_wgpu( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint32_t N, - uint8_t* out_be); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_PEDERSEN_DRIVER_WGPU_H diff --git a/pedersen/gpu/wgsl/pedersen_tree.wgsl b/pedersen/gpu/wgsl/pedersen_tree.wgsl deleted file mode 100644 index 5cbfaf1..0000000 --- a/pedersen/gpu/wgsl/pedersen_tree.wgsl +++ /dev/null @@ -1,535 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce WGSL kernel for batched Pedersen vector commitments at the -// fixed Verkle width N = 256. One workgroup per commitment, 256 invocations -// per workgroup. Workgroup-local var array holds the 256 partial -// points; an 8-stride workgroupBarrier-synchronised reduction collapses -// them inside the same dispatch. -// -// 256-bit limbs are 8 x u32 little-endian (no native u64 in WGSL). 
-// -// Bindings: -// group(0) binding(0) gens_be : storage -// group(0) binding(1) scalars_be : storage -// group(0) binding(2) blindings_be : storage -// group(0) binding(3) out_be : storage -// group(0) binding(4) dims : uniform -// -// Output byte-equal to pedersen_tree_metal / pedersen_tree_cuda and the legacy -// two-stage pedersen.wgsl pipeline. - -@group(0) @binding(0) var gens_be : array; -@group(0) @binding(1) var scalars_be : array; -@group(0) @binding(2) var blindings_be : array; -@group(0) @binding(3) var out_be : array; -@group(0) @binding(4) var dims : Dims; - -struct Dims { - M: u32, - N: u32, - _pad0: u32, - _pad1: u32, -} - -// ============================================================================= -// BN254 base-field constants -- 8 x u32 LE -// ============================================================================= -// p = 0x30644E72E131A029 B85045B68181585D 97816A916871CA8D 3C208C16D87CFD47 -const BN254_P = array( - 0xD87CFD47u, 0x3C208C16u, 0x6871CA8Du, 0x97816A91u, - 0x8181585Du, 0xB85045B6u, 0xE131A029u, 0x30644E72u -); -const BN254_R_MONT = array( - 0xC58F0D9Du, 0xD35D438Du, 0xF5C70B3Du, 0x0A78EB28u, - 0x7879462Cu, 0x666EA36Fu, 0x9A07DF2Fu, 0x0E0A77C1u -); -const BN254_R2 = array( - 0x538AFA89u, 0xF32CFC5Bu, 0xD44501FBu, 0xB5E71911u, - 0x0A417FF6u, 0x47AB1EFFu, 0xCAB8351Fu, 0x06D89F71u -); -const BN254_INV: u32 = 0xE4866389u; - -// ============================================================================= -// 256-bit (8 x u32) helpers -// ============================================================================= - -fn u256_zero() -> array { - return array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); -} - -fn u256_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0; i = i - 1) { - let ui = u32(i); - if ((*a)[ui] > (*b)[ui]) { return 1; } - if ((*a)[ui] < (*b)[ui]) { return -1; } - } - return 
0; -} - -fn u256_add(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let s1 = (*a)[i] + c; - c = select(0u, 1u, s1 < (*a)[i]); - let s2 = s1 + (*b)[i]; - c = c + select(0u, 1u, s2 < s1); - (*r)[i] = s2; - } - return c; -} - -fn u256_sub(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var bw = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let d1 = (*a)[i] - bw; - bw = select(0u, 1u, d1 > (*a)[i]); - let d2 = d1 - (*b)[i]; - bw = bw + select(0u, 1u, d2 > d1); - (*r)[i] = d2; - } - return bw; -} - -// ============================================================================= -// Montgomery reduction (CIOS) over BN254 p, 8x u32 limbs -// ============================================================================= - -fn mont_reduce(t: ptr>, r: ptr>) { - var p = BN254_P; - var a: array; - for (var i = 0u; i < 16u; i = i + 1u) { a[i] = (*t)[i]; } - a[16] = 0u; - - for (var i = 0u; i < 8u; i = i + 1u) { - let u = a[i] * BN254_INV; - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let u_lo = u & 0xFFFFu; let u_hi = u >> 16u; - let m_lo = p[j] & 0xFFFFu; let m_hi = p[j] >> 16u; - let ll = u_lo * m_lo; - let lh = u_lo * m_hi; - let hl = u_hi * m_lo; - let hh = u_hi * m_hi; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - - let s1 = lo + carry; - hi = hi + select(0u, 1u, s1 < lo); - let s2 = a[i + j] + s1; - hi = hi + select(0u, 1u, s2 < a[i + j]); - a[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j <= 16u; j = j + 1u) { - let s = a[i + j] + carry; - carry = select(0u, 1u, s < a[i + j]); - a[i + j] = s; - if (carry == 0u) { break; } - } - } - - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = a[i + 8u]; } - if (a[16] != 0u || u256_cmp(r, &p) >= 0) { - _ = u256_sub(r, &p, r); - } -} - -fn mont_mul(a: ptr>, b: ptr>, - r: ptr>) { - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 
8u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let al = (*a)[i] & 0xFFFFu; let ah = (*a)[i] >> 16u; - let bl = (*b)[j] & 0xFFFFu; let bh = (*b)[j] >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - let s1 = lo + carry; hi = hi + select(0u, 1u, s1 < lo); - let s2 = t[i + j] + s1; hi = hi + select(0u, 1u, s2 < t[i + j]); - t[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j < 16u; j = j + 1u) { - let s = t[i + j] + carry; - carry = select(0u, 1u, s < t[i + j]); - t[i + j] = s; - if (carry == 0u) { break; } - } - } - mont_reduce(&t, r); -} - -fn fp_add(a: ptr>, b: ptr>, - r: ptr>) { - var p = BN254_P; - let c = u256_add(a, b, r); - if (c != 0u || u256_cmp(r, &p) >= 0) { - _ = u256_sub(r, &p, r); - } -} - -fn fp_sub(a: ptr>, b: ptr>, - r: ptr>) { - var p = BN254_P; - let bw = u256_sub(a, b, r); - if (bw != 0u) { - _ = u256_add(r, &p, r); - } -} - -fn fp_mul(a: ptr>, b: ptr>, - r: ptr>) { mont_mul(a, b, r); } - -fn fp_sqr(a: ptr>, r: ptr>) { - mont_mul(a, a, r); -} - -fn to_mont_p(a: ptr>, r: ptr>) { - var r2 = BN254_R2; - fp_mul(a, &r2, r); -} - -fn from_mont_p(a: ptr>, r: ptr>) { - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 8u; i = i + 1u) { t[i] = (*a)[i]; } - mont_reduce(&t, r); -} - -fn fp_inv(a: ptr>, r: ptr>) { - var exp = array( - 0xD87CFD45u, 0x3C208C16u, 0x6871CA8Du, 0x97816A91u, - 0x8181585Du, 0xB85045B6u, 0xE131A029u, 0x30644E72u - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var result: array; - to_mont_p(&one, &result); - var base: array; - for (var i = 0u; i < 8u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - fp_mul(&result, &base, &tmp); - 
result = tmp; - } - var tmp2: array; - fp_sqr(&base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -// ============================================================================= -// G1 in Jacobian (Montgomery X, Y, Z); Z == 0 represents infinity -// ============================================================================= - -struct G1Jac { - x: array, - y: array, - z: array, -} - -fn g1_zero() -> G1Jac { - var p: G1Jac; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &p.x); - p.y = p.x; - p.z = u256_zero(); - return p; -} - -fn g1_is_zero(p: ptr) -> bool { - var z = (*p).z; - return u256_is_zero(&z); -} - -fn g1_dbl(p: ptr, r: ptr) { - if (g1_is_zero(p)) { *r = *p; return; } - var A: array; fp_sqr(&(*p).x, &A); - var B: array; fp_sqr(&(*p).y, &B); - var C: array; fp_sqr(&B, &C); - var t: array; fp_add(&(*p).x, &B, &t); - var t2: array; fp_sqr(&t, &t2); - var t3: array; fp_sub(&t2, &A, &t3); - var t4: array; fp_sub(&t3, &C, &t4); - var D: array; fp_add(&t4, &t4, &D); - var twoA: array; fp_add(&A, &A, &twoA); - var E: array; fp_add(&twoA, &A, &E); - var F: array; fp_sqr(&E, &F); - var twoD: array; fp_add(&D, &D, &twoD); - fp_sub(&F, &twoD, &(*r).x); - var DminusX: array; fp_sub(&D, &(*r).x, &DminusX); - var EDX: array; fp_mul(&E, &DminusX, &EDX); - var twoC: array; fp_add(&C, &C, &twoC); - var fourC: array; fp_add(&twoC, &twoC, &fourC); - var eightC: array; fp_add(&fourC, &fourC, &eightC); - fp_sub(&EDX, &eightC, &(*r).y); - var YZ: array; fp_mul(&(*p).y, &(*p).z, &YZ); - fp_add(&YZ, &YZ, &(*r).z); -} - -fn g1_add_mixed(p: ptr, qx: ptr>, - qy: ptr>, r: ptr) { - if (g1_is_zero(p)) { - (*r).x = *qx; (*r).y = *qy; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &(*r).z); - return; - } - var Z1Z1: array; fp_sqr(&(*p).z, &Z1Z1); - var U2: array; fp_mul(qx, &Z1Z1, &U2); - var ZZ1Z1: array; fp_mul(&(*p).z, &Z1Z1, &ZZ1Z1); - var S2: array; fp_mul(qy, &ZZ1Z1, &S2); - var H: array; fp_sub(&U2, &(*p).x, &H); - var R: array; 
fp_sub(&S2, &(*p).y, &R); - - if (u256_is_zero(&H)) { - if (u256_is_zero(&R)) { g1_dbl(p, r); return; } - *r = g1_zero(); - return; - } - var HH: array; fp_sqr(&H, &HH); - var twoHH: array; fp_add(&HH, &HH, &twoHH); - var I: array; fp_add(&twoHH, &twoHH, &I); - var J: array; fp_mul(&H, &I, &J); - var R2: array; fp_add(&R, &R, &R2); - var V: array; fp_mul(&(*p).x, &I, &V); - - var Rsq: array; fp_sqr(&R2, &Rsq); - var t1: array; fp_sub(&Rsq, &J, &t1); - var twoV: array; fp_add(&V, &V, &twoV); - fp_sub(&t1, &twoV, &(*r).x); - var VminusX3: array; fp_sub(&V, &(*r).x, &VminusX3); - var RVX: array; fp_mul(&R2, &VminusX3, &RVX); - var Y1J: array; fp_mul(&(*p).y, &J, &Y1J); - var twoY1J: array; fp_add(&Y1J, &Y1J, &twoY1J); - fp_sub(&RVX, &twoY1J, &(*r).y); - var twoH: array; fp_add(&H, &H, &twoH); - fp_mul(&(*p).z, &twoH, &(*r).z); -} - -fn g1_add(p: ptr, q: ptr, r: ptr) { - if (g1_is_zero(p)) { *r = *q; return; } - if (g1_is_zero(q)) { *r = *p; return; } - var Z1Z1: array; fp_sqr(&(*p).z, &Z1Z1); - var Z2Z2: array; fp_sqr(&(*q).z, &Z2Z2); - var U1: array; fp_mul(&(*p).x, &Z2Z2, &U1); - var U2: array; fp_mul(&(*q).x, &Z1Z1, &U2); - var Yq: array; fp_mul(&(*p).y, &(*q).z, &Yq); - var S1: array; fp_mul(&Yq, &Z2Z2, &S1); - var Yp: array; fp_mul(&(*q).y, &(*p).z, &Yp); - var S2: array; fp_mul(&Yp, &Z1Z1, &S2); - var H: array; fp_sub(&U2, &U1, &H); - var R: array; fp_sub(&S2, &S1, &R); - if (u256_is_zero(&H)) { - if (u256_is_zero(&R)) { g1_dbl(p, r); return; } - *r = g1_zero(); - return; - } - var R2: array; fp_add(&R, &R, &R2); - var HH: array; fp_sqr(&H, &HH); - var twoHH: array; fp_add(&HH, &HH, &twoHH); - var I: array; fp_add(&twoHH, &twoHH, &I); - var J: array; fp_mul(&H, &I, &J); - var V: array; fp_mul(&U1, &I, &V); - var Rsq: array; fp_sqr(&R2, &Rsq); - var t1: array; fp_sub(&Rsq, &J, &t1); - var twoV: array; fp_add(&V, &V, &twoV); - fp_sub(&t1, &twoV, &(*r).x); - var VmX: array; fp_sub(&V, &(*r).x, &VmX); - var RVX: array; fp_mul(&R2, &VmX, &RVX); - var S1J: array; 
fp_mul(&S1, &J, &S1J); - var twoS1J: array; fp_add(&S1J, &S1J, &twoS1J); - fp_sub(&RVX, &twoS1J, &(*r).y); - var Z1Z2: array; fp_mul(&(*p).z, &(*q).z, &Z1Z2); - var twoH: array; fp_add(&H, &H, &twoH); - fp_mul(&Z1Z2, &twoH, &(*r).z); -} - -fn g1_to_affine(p: ptr, ax: ptr>, - ay: ptr>) -> bool { - if (g1_is_zero(p)) { *ax = u256_zero(); *ay = u256_zero(); return true; } - var Zinv: array; fp_inv(&(*p).z, &Zinv); - var Zinv2: array; fp_sqr(&Zinv, &Zinv2); - var Zinv3: array; fp_mul(&Zinv2, &Zinv, &Zinv3); - fp_mul(&(*p).x, &Zinv2, ax); - fp_mul(&(*p).y, &Zinv3, ay); - return false; -} - -fn g1_scalar_mul_aff(qx: ptr>, qy: ptr>, - s: ptr>) -> G1Jac { - var acc = g1_zero(); - for (var li = 7i; li >= 0; li = li - 1) { - let limb = (*s)[u32(li)]; - for (var bi = 31i; bi >= 0; bi = bi - 1) { - var dbl: G1Jac; - g1_dbl(&acc, &dbl); - acc = dbl; - if (((limb >> u32(bi)) & 1u) != 0u) { - var tmp: G1Jac; - g1_add_mixed(&acc, qx, qy, &tmp); - acc = tmp; - } - } - } - return acc; -} - -// ============================================================================= -// I/O helpers (raw 32-byte BE field element <-> 8 LE u32 limbs) -// ============================================================================= - -fn bswap32(x: u32) -> u32 { - return ((x & 0x000000FFu) << 24u) - | ((x & 0x0000FF00u) << 8u) - | ((x & 0x00FF0000u) >> 8u) - | ((x & 0xFF000000u) >> 24u); -} - -fn read_be32_gens(word_off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - r[i] = bswap32(gens_be[word_off + 7u - i]); - } - return r; -} - -fn read_be32_scalars(word_off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - r[i] = bswap32(scalars_be[word_off + 7u - i]); - } - return r; -} - -fn read_be32_blindings(word_off: u32) -> array { - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - r[i] = bswap32(blindings_be[word_off + 7u - i]); - } - return r; -} - -fn write_be32_out(word_off: u32, a: ptr>) { - for (var i = 0u; i < 8u; i = i + 1u) { - 
out_be[word_off + i] = bswap32((*a)[7u - i]); - } -} - -// ============================================================================= -// Workgroup-shared tree reduction storage -// ============================================================================= -// 256 slots * 24 u32 (X || Y || Z) = 24 KiB per workgroup. Within the -// minimum required workgroup-storage budget on every WebGPU adapter -// (16 KiB minimum per WebGPU spec, 32 KiB on every modern GPU we target). -// -// NOTE: WebGPU requires workgroup storage <= the device's -// maxComputeWorkgroupStorageSize. Dawn / wgpu-native both expose -// >= 16384 bytes by spec; in practice 32+ KiB on the GPUs we run. -// 256 * 24 * 4 = 24576 bytes. At spec floor (16384) this kernel cannot -// run; that is documented as the trade-off for a single-dispatch design. - -const PED_TREE_N : u32 = 256u; - -var shared_pts : array, PED_TREE_N>; - -@compute @workgroup_size(256) -fn pedersen_tree_commit( - @builtin(workgroup_id) wgid : vec3, - @builtin(local_invocation_id) lid : vec3, -) { - let M = dims.M; - let N = dims.N; - let bid = wgid.x; - let tid = lid.x; - if (bid >= M) { return; } - if (tid >= N) { return; } - - // Phase 1: each thread computes its term P[i] = scalars[bid,tid] * G[tid]. - var Qx_raw = read_be32_gens(tid * 16u); - var Qy_raw = read_be32_gens(tid * 16u + 8u); - var sc_raw = read_be32_scalars((bid * N + tid) * 8u); - var Qx_mont: array; to_mont_p(&Qx_raw, &Qx_mont); - var Qy_mont: array; to_mont_p(&Qy_raw, &Qy_mont); - let P = g1_scalar_mul_aff(&Qx_mont, &Qy_mont, &sc_raw); - - // Stash in workgroup memory. - for (var k = 0u; k < 8u; k = k + 1u) { shared_pts[tid][ 0u + k] = P.x[k]; } - for (var k = 0u; k < 8u; k = k + 1u) { shared_pts[tid][ 8u + k] = P.y[k]; } - for (var k = 0u; k < 8u; k = k + 1u) { shared_pts[tid][16u + k] = P.z[k]; } - workgroupBarrier(); - - // Phase 2: tree reduction (8 strides for N = 256). 
- var stride : u32 = 128u; - loop { - if (stride == 0u) { break; } - if (tid < stride) { - var a: G1Jac; - for (var k = 0u; k < 8u; k = k + 1u) { a.x[k] = shared_pts[tid][ 0u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { a.y[k] = shared_pts[tid][ 8u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { a.z[k] = shared_pts[tid][16u + k]; } - var b: G1Jac; - for (var k = 0u; k < 8u; k = k + 1u) { b.x[k] = shared_pts[tid + stride][ 0u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { b.y[k] = shared_pts[tid + stride][ 8u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { b.z[k] = shared_pts[tid + stride][16u + k]; } - var sum: G1Jac; - g1_add(&a, &b, &sum); - for (var k = 0u; k < 8u; k = k + 1u) { shared_pts[tid][ 0u + k] = sum.x[k]; } - for (var k = 0u; k < 8u; k = k + 1u) { shared_pts[tid][ 8u + k] = sum.y[k]; } - for (var k = 0u; k < 8u; k = k + 1u) { shared_pts[tid][16u + k] = sum.z[k]; } - } - workgroupBarrier(); - stride = stride >> 1u; - } - - // Phase 3 + 4: thread 0 finishes (load reduced sum, add r*H, emit). 
- if (tid == 0u) { - var acc: G1Jac; - for (var k = 0u; k < 8u; k = k + 1u) { acc.x[k] = shared_pts[0u][ 0u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { acc.y[k] = shared_pts[0u][ 8u + k]; } - for (var k = 0u; k < 8u; k = k + 1u) { acc.z[k] = shared_pts[0u][16u + k]; } - - var Hx_raw = read_be32_gens(N * 16u); - var Hy_raw = read_be32_gens(N * 16u + 8u); - var r_raw = read_be32_blindings(bid * 8u); - var Hx_mont: array; to_mont_p(&Hx_raw, &Hx_mont); - var Hy_mont: array; to_mont_p(&Hy_raw, &Hy_mont); - var rH = g1_scalar_mul_aff(&Hx_mont, &Hy_mont, &r_raw); - var sum: G1Jac; - g1_add(&acc, &rH, &sum); - acc = sum; - - var aff_x: array; var aff_y: array; - let inf = g1_to_affine(&acc, &aff_x, &aff_y); - let out_word = bid * 16u; - if (inf) { - for (var k = 0u; k < 16u; k = k + 1u) { out_be[out_word + k] = 0u; } - return; - } - var X_raw: array; from_mont_p(&aff_x, &X_raw); - var Y_raw: array; from_mont_p(&aff_y, &Y_raw); - write_be32_out(out_word, &X_raw); - write_be32_out(out_word + 8u, &Y_raw); - } -} diff --git a/pedersen/gpu/wgsl/pedersen_tree_driver.cpp b/pedersen/gpu/wgsl/pedersen_tree_driver.cpp deleted file mode 100644 index ac773eb..0000000 --- a/pedersen/gpu/wgsl/pedersen_tree_driver.cpp +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WebGPU/WGSL host driver for the tree-reduce Pedersen vector commitment. -// -// Single-stage dispatch: M workgroups of 256 invocations each. Workgroup -// memory holds the 256 partial points; an in-shader workgroupBarrier() -// loop collapses them to one before thread 0 emits the result. -// -// Build modes: -// * LUX_PEDERSEN_HAS_WEBGPU=1 -- Dawn / wgpu-native runtime -// * LUX_PEDERSEN_HAS_WGPU_NATIVE=1 -- enables wgpuDevicePoll -// Stub mode otherwise (returns -1). 
- -#include "pedersen_tree_driver.h" - -#if defined(LUX_PEDERSEN_HAS_WEBGPU) - -#include -#if defined(LUX_PEDERSEN_HAS_WGPU_NATIVE) -# include -#endif - -#include -#include -#include -#include -#include - -// Embedded WGSL source (concatenated by CMake into pedersen_tree_wgsl_source.h). -#include "pedersen_tree_wgsl_source.h" - -namespace { - -WGPUStringView sv(const char* s) { - WGPUStringView v{}; - v.data = s; - v.length = (s == nullptr) ? 0 : std::strlen(s); - return v; -} -WGPUStringView sv(const std::string& s) { - WGPUStringView v{}; - v.data = s.data(); - v.length = s.size(); - return v; -} - -void drain(WGPUInstance inst, WGPUDevice dev) { - if (inst) wgpuInstanceProcessEvents(inst); -#if defined(LUX_PEDERSEN_HAS_WGPU_NATIVE) - if (dev) wgpuDevicePoll(dev, /*wait=*/WGPU_TRUE, nullptr); -#else - (void)dev; -#endif -} - -bool wait_map(WGPUInstance inst, WGPUDevice dev, WGPUBuffer buf, - WGPUMapMode mode, size_t off, size_t size) { - struct State { std::atomic done{false}; WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error}; } s; - WGPUBufferMapCallbackInfo cb{}; - cb.mode = WGPUCallbackMode_AllowProcessEvents; - cb.callback = [](WGPUMapAsyncStatus st, WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - p->status = st; - p->done.store(true, std::memory_order_release); - }; - cb.userdata1 = &s; - wgpuBufferMapAsync(buf, mode, off, size, cb); - for (int spin = 0; spin < 8192; spin++) { - if (s.done.load(std::memory_order_acquire)) break; - drain(inst, dev); - } - return s.done.load() && s.status == WGPUMapAsyncStatus_Success; -} - -struct Engine { - WGPUInstance instance{nullptr}; - WGPUAdapter adapter{nullptr}; - WGPUDevice device{nullptr}; - WGPUQueue queue{nullptr}; - WGPUShaderModule module{nullptr}; - bool initialized{false}; -}; - -Engine& engine() { static Engine e; return e; } - -bool init_engine() { - Engine& e = engine(); - if (e.initialized) return true; - - WGPUInstanceDescriptor idesc{}; - e.instance = wgpuCreateInstance(&idesc); - if 
(!e.instance) return false; - - struct AS { std::atomic done{false}; WGPUAdapter ad{nullptr}; } as; - WGPURequestAdapterOptions ropt{}; - ropt.powerPreference = WGPUPowerPreference_HighPerformance; - WGPURequestAdapterCallbackInfo rcb{}; - rcb.mode = WGPUCallbackMode_AllowProcessEvents; - rcb.callback = [](WGPURequestAdapterStatus st, WGPUAdapter ad, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestAdapterStatus_Success) p->ad = ad; - p->done.store(true, std::memory_order_release); - }; - rcb.userdata1 = &as; - wgpuInstanceRequestAdapter(e.instance, &ropt, rcb); - for (int spin = 0; spin < 8192; spin++) { - if (as.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!as.ad) return false; - e.adapter = as.ad; - - struct DS { std::atomic done{false}; WGPUDevice dev{nullptr}; } ds; - WGPUDeviceDescriptor ddesc{}; - WGPURequestDeviceCallbackInfo dcb{}; - dcb.mode = WGPUCallbackMode_AllowProcessEvents; - dcb.callback = [](WGPURequestDeviceStatus st, WGPUDevice dev, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestDeviceStatus_Success) p->dev = dev; - p->done.store(true, std::memory_order_release); - }; - dcb.userdata1 = &ds; - wgpuAdapterRequestDevice(e.adapter, &ddesc, dcb); - for (int spin = 0; spin < 8192; spin++) { - if (ds.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!ds.dev) return false; - e.device = ds.dev; - e.queue = wgpuDeviceGetQueue(e.device); - if (!e.queue) return false; - - std::string src(kPedersenTreeWGSL); - WGPUShaderSourceWGSL wgsl{}; - wgsl.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl.code = sv(src); - WGPUShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl.chain; - smd.label = sv("pedersen_tree"); - e.module = wgpuDeviceCreateShaderModule(e.device, &smd); - if (!e.module) return false; - - e.initialized = true; - return true; -} - -WGPUBuffer make_buf(Engine& e, size_t 
size, WGPUBufferUsage usage) { - WGPUBufferDescriptor bd{}; - bd.size = (size + 3) & ~size_t(3); - bd.usage = usage; - return wgpuDeviceCreateBuffer(e.device, &bd); -} - -} // namespace - -extern "C" int lux_pedersen_tree_wgpu_available(void) { - return init_engine() ? 1 : 0; -} - -extern "C" int pedersen_tree_wgpu( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint8_t* out_be) { - if (M == 0) return 0; - if (!gens_be || !scalars_be || !blindings_be || !out_be) return -1; - if (!init_engine()) return -1; - - Engine& e = engine(); - const uint32_t N = PEDERSEN_TREE_WIDTH; - - size_t gens_len = (size_t)(N + 1) * 64; - size_t scalars_len = (size_t)M * N * 32; - size_t blind_len = (size_t)M * 32; - size_t out_len = (size_t)M * 64; - - WGPUBuffer bufGens = make_buf(e, gens_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufScalars = make_buf(e, scalars_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufBlind = make_buf(e, blind_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufOut = make_buf(e, out_len, - (WGPUBufferUsage)(WGPUBufferUsage_Storage | WGPUBufferUsage_CopySrc)); - WGPUBuffer bufDims = make_buf(e, 16, - (WGPUBufferUsage)(WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst)); - WGPUBuffer bufRead = make_buf(e, out_len, - (WGPUBufferUsage)(WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst)); - if (!bufGens || !bufScalars || !bufBlind || - !bufOut || !bufDims || !bufRead) { - return -2; - } - - wgpuQueueWriteBuffer(e.queue, bufGens, 0, gens_be, gens_len); - wgpuQueueWriteBuffer(e.queue, bufScalars, 0, scalars_be, scalars_len); - wgpuQueueWriteBuffer(e.queue, bufBlind, 0, blindings_be, blind_len); - uint32_t dimVals[4] = { M, N, 0, 0 }; - wgpuQueueWriteBuffer(e.queue, bufDims, 0, dimVals, 16); - - // Single pipeline; bind group 0. 
- WGPUComputePipelineDescriptor cpd{}; - cpd.compute.module = e.module; - cpd.compute.entryPoint = sv("pedersen_tree_commit"); - cpd.label = sv("pedersen_tree_commit"); - WGPUComputePipeline pso = wgpuDeviceCreateComputePipeline(e.device, &cpd); - if (!pso) return -3; - - WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(pso, 0); - WGPUBindGroupEntry bge[5] = {}; - bge[0].binding = 0; bge[0].buffer = bufGens; bge[0].size = gens_len; - bge[1].binding = 1; bge[1].buffer = bufScalars; bge[1].size = scalars_len; - bge[2].binding = 2; bge[2].buffer = bufBlind; bge[2].size = blind_len; - bge[3].binding = 3; bge[3].buffer = bufOut; bge[3].size = out_len; - bge[4].binding = 4; bge[4].buffer = bufDims; bge[4].size = 16; - WGPUBindGroupDescriptor bgd{}; - bgd.layout = bgl; bgd.entryCount = 5; bgd.entries = bge; - WGPUBindGroup bg = wgpuDeviceCreateBindGroup(e.device, &bgd); - if (!bg) return -4; - - WGPUCommandEncoderDescriptor ced{}; - WGPUCommandEncoder ce = wgpuDeviceCreateCommandEncoder(e.device, &ced); - - { - WGPUComputePassDescriptor cpdesc{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpdesc); - wgpuComputePassEncoderSetPipeline(cpe, pso); - wgpuComputePassEncoderSetBindGroup(cpe, 0, bg, 0, nullptr); - // M workgroups; the kernel fixes the workgroup size at 256. 
- wgpuComputePassEncoderDispatchWorkgroups(cpe, M, 1, 1); - wgpuComputePassEncoderEnd(cpe); - wgpuComputePassEncoderRelease(cpe); - } - - wgpuCommandEncoderCopyBufferToBuffer(ce, bufOut, 0, bufRead, 0, out_len); - WGPUCommandBufferDescriptor cbd{}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(ce, &cbd); - wgpuQueueSubmit(e.queue, 1, &cmd); - - int rc = 0; - if (!wait_map(e.instance, e.device, bufRead, WGPUMapMode_Read, 0, out_len)) { - rc = -5; - } else { - const void* mapped = wgpuBufferGetConstMappedRange(bufRead, 0, out_len); - std::memcpy(out_be, mapped, out_len); - wgpuBufferUnmap(bufRead); - } - - wgpuCommandEncoderRelease(ce); - wgpuCommandBufferRelease(cmd); - wgpuBindGroupRelease(bg); - wgpuBindGroupLayoutRelease(bgl); - wgpuComputePipelineRelease(pso); - wgpuBufferRelease(bufGens); wgpuBufferRelease(bufScalars); - wgpuBufferRelease(bufBlind); - wgpuBufferRelease(bufOut); wgpuBufferRelease(bufDims); - wgpuBufferRelease(bufRead); - return rc; -} - -#else // LUX_PEDERSEN_HAS_WEBGPU not defined: stub mode - -extern "C" int lux_pedersen_tree_wgpu_available(void) { return 0; } -extern "C" int pedersen_tree_wgpu( - const uint8_t*, const uint8_t*, const uint8_t*, - uint32_t, uint8_t*) { return -1; } - -#endif diff --git a/pedersen/gpu/wgsl/pedersen_tree_driver.h b/pedersen/gpu/wgsl/pedersen_tree_driver.h deleted file mode 100644 index 2d4c3a5..0000000 --- a/pedersen/gpu/wgsl/pedersen_tree_driver.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Tree-reduce WebGPU/WGSL driver for the batched Pedersen vector commitment -// at the fixed Verkle width N = 256. 
- -#ifndef LUX_PEDERSEN_TREE_DRIVER_WGPU_H -#define LUX_PEDERSEN_TREE_DRIVER_WGPU_H - -#include - -#ifndef PEDERSEN_TREE_WIDTH -#define PEDERSEN_TREE_WIDTH 256u -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -int lux_pedersen_tree_wgpu_available(void); - -int pedersen_tree_wgpu( - const uint8_t* gens_be, - const uint8_t* scalars_be, - const uint8_t* blindings_be, - uint32_t M, - uint8_t* out_be); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_PEDERSEN_TREE_DRIVER_WGPU_H diff --git a/poly_mul/gpu/cuda/poly_mul.cu b/poly_mul/gpu/cuda/poly_mul.cu deleted file mode 100644 index c95988f..0000000 --- a/poly_mul/gpu/cuda/poly_mul.cu +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. All Rights Reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Polynomial Multiplication - CUDA Port of poly_mul.metal -// Byte-identical output. Supports schoolbook and NTT-based multiplication. - -#include - -#ifdef __CUDA_ARCH__ -#define PM_DEVICE __device__ __forceinline__ -#else -#define PM_DEVICE inline -#define __global__ -#define __shared__ -static inline void __syncthreads() {} -#endif - -// ============================================================================ -// 64-bit Arithmetic (matches Metal's emulated 128-bit via U64 struct) -// On CUDA we have native __int128, but we keep the same struct interface -// to guarantee byte-identical intermediate values. 
-// ============================================================================ - -struct U64 { uint32_t lo; uint32_t hi; }; - -PM_DEVICE U64 u64_from(uint64_t v) { return {(uint32_t)(v & 0xFFFFFFFFu), (uint32_t)(v >> 32)}; } -PM_DEVICE uint64_t u64_to(U64 v) { return (uint64_t)v.lo | ((uint64_t)v.hi << 32); } -PM_DEVICE U64 u64_zero() { return {0u, 0u}; } - -PM_DEVICE bool u64_gte(U64 a, U64 b) { - if (a.hi > b.hi) return true; - if (a.hi < b.hi) return false; - return a.lo >= b.lo; -} - -PM_DEVICE U64 u64_add(U64 a, U64 b) { - uint32_t lo = a.lo + b.lo; - uint32_t carry = (lo < a.lo) ? 1u : 0u; - return {lo, a.hi + b.hi + carry}; -} - -PM_DEVICE U64 u64_sub(U64 a, U64 b) { - uint32_t borrow = (a.lo < b.lo) ? 1u : 0u; - return {a.lo - b.lo, a.hi - b.hi - borrow}; -} - -PM_DEVICE U64 mul32_to_64(uint32_t a, uint32_t b) { - uint32_t a_lo = a & 0xFFFFu, a_hi = a >> 16u; - uint32_t b_lo = b & 0xFFFFu, b_hi = b >> 16u; - uint32_t p0 = a_lo * b_lo, p1 = a_lo * b_hi; - uint32_t p2 = a_hi * b_lo, p3 = a_hi * b_hi; - uint32_t mid = p1 + p2; - uint32_t mid_carry = (mid < p1) ? 0x10000u : 0u; - uint32_t lo = p0 + (mid << 16u); - uint32_t carry = (lo < p0) ? 1u : 0u; - return {lo, p3 + (mid >> 16u) + mid_carry + carry}; -} - -PM_DEVICE void mul64_to_128(U64 a, U64 b, U64& lo, U64& hi) { - U64 p0 = mul32_to_64(a.lo, b.lo); - U64 p1 = mul32_to_64(a.lo, b.hi); - U64 p2 = mul32_to_64(a.hi, b.lo); - U64 p3 = mul32_to_64(a.hi, b.hi); - lo.lo = p0.lo; - uint32_t sum1 = p0.hi + p1.lo; - uint32_t c1 = (sum1 < p0.hi) ? 1u : 0u; - uint32_t sum2 = sum1 + p2.lo; - uint32_t c2 = (sum2 < sum1) ? 
1u : 0u; - lo.hi = sum2; - hi = u64_add(p3, {c1 + c2 + p1.hi + p2.hi, 0u}); -} - -// ============================================================================ -// Modular Arithmetic -// ============================================================================ - -PM_DEVICE U64 mod_add(U64 a, U64 b, U64 q) { - U64 sum = u64_add(a, b); - if (u64_gte(sum, q)) sum = u64_sub(sum, q); - return sum; -} - -PM_DEVICE U64 mod_sub(U64 a, U64 b, U64 q) { - if (u64_gte(a, b)) return u64_sub(a, b); - return u64_sub(u64_add(a, q), b); -} - -PM_DEVICE U64 mont_reduce(U64 lo, U64 hi, U64 q, U64 q_inv) { - U64 m_lo, m_hi; - mul64_to_128(lo, q_inv, m_lo, m_hi); - U64 prod_lo, prod_hi; - mul64_to_128(m_lo, q, prod_lo, prod_hi); - U64 sum = u64_add(lo, prod_lo); - uint32_t carry = (sum.lo < lo.lo || sum.hi < lo.hi) ? 1u : 0u; - U64 result = u64_add(hi, prod_hi); - result = u64_add(result, {carry, 0u}); - if (u64_gte(result, q)) result = u64_sub(result, q); - return result; -} - -PM_DEVICE U64 mont_mul(U64 a, U64 b, U64 q, U64 q_inv) { - U64 lo, hi; - mul64_to_128(a, b, lo, hi); - return mont_reduce(lo, hi, q, q_inv); -} - -// ============================================================================ -// NTT Butterfly Operations -// ============================================================================ - -PM_DEVICE void ct_butterfly(U64& x0, U64& x1, U64 w, U64 q, U64 q_inv) { - U64 t = mont_mul(x1, w, q, q_inv); - x1 = mod_sub(x0, t, q); - x0 = mod_add(x0, t, q); -} - -PM_DEVICE void gs_butterfly(U64& x0, U64& x1, U64 w, U64 q, U64 q_inv) { - U64 t = mod_sub(x0, x1, q); - x0 = mod_add(x0, x1, q); - x1 = mont_mul(t, w, q, q_inv); -} - -// ============================================================================ -// Schoolbook Polynomial Multiplication -// ============================================================================ - -extern "C" __global__ void poly_mul_schoolbook( - const U64* a, const U64* b, U64* c, - U64 q, U64 q_inv, uint32_t n, uint32_t poly_idx) -{ 
-#ifdef __CUDA_ARCH__ - extern __shared__ U64 smem[]; - U64* s_a = smem; - U64* s_b = smem + n; - uint32_t tid = threadIdx.x; - uint32_t offset = poly_idx * n; - - if (tid < n) { - s_a[tid] = a[offset + tid]; - s_b[tid] = b[offset + tid]; - } - __syncthreads(); - if (tid >= n) return; - - U64 sum = u64_zero(); - for (uint32_t i = 0; i <= tid; i++) { - U64 prod = mont_mul(s_a[i], s_b[tid - i], q, q_inv); - sum = mod_add(sum, prod, q); - } - for (uint32_t i = tid + 1; i < n; i++) { - U64 prod = mont_mul(s_a[i], s_b[n + tid - i], q, q_inv); - sum = mod_sub(sum, prod, q); - } - c[offset + tid] = sum; -#endif -} - -// ============================================================================ -// NTT-Based Polynomial Multiplication -// ============================================================================ - -extern "C" __global__ void ntt_forward_stage( - U64* data, const U64* twiddles, U64 q, U64 q_inv, uint32_t stage, uint32_t log_n) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t n = 1u << log_n; - uint32_t half_n = n >> 1; - if (gid >= half_n) return; - - uint32_t half_size = 1u << stage; - uint32_t group_id = gid / half_size; - uint32_t idx_in_group = gid % half_size; - uint32_t idx0 = group_id * (half_size << 1) + idx_in_group; - uint32_t idx1 = idx0 + half_size; - - U64 x0 = data[idx0], x1 = data[idx1]; - U64 w = twiddles[group_id + half_size]; - ct_butterfly(x0, x1, w, q, q_inv); - data[idx0] = x0; - data[idx1] = x1; -#endif -} - -extern "C" __global__ void ntt_inverse_stage( - U64* data, const U64* twiddles_inv, U64 q, U64 q_inv, uint32_t stage, uint32_t log_n) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t n = 1u << log_n; - uint32_t half_n = n >> 1; - if (gid >= half_n) return; - - uint32_t half_size = 1u << stage; - uint32_t group_id = gid / half_size; - uint32_t idx_in_group = gid % half_size; - uint32_t idx0 = group_id * (half_size << 1) + idx_in_group; - 
uint32_t idx1 = idx0 + half_size; - - U64 x0 = data[idx0], x1 = data[idx1]; - U64 w = twiddles_inv[group_id + half_size]; - gs_butterfly(x0, x1, w, q, q_inv); - data[idx0] = x0; - data[idx1] = x1; -#endif -} - -extern "C" __global__ void ntt_scale(U64* data, U64 q, U64 q_inv, U64 n_inv, uint32_t n) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n) return; - data[gid] = mont_mul(data[gid], n_inv, q, q_inv); -#endif -} - -// ============================================================================ -// Pointwise Operations -// ============================================================================ - -extern "C" __global__ void poly_mul_pointwise( - const U64* a_ntt, const U64* b_ntt, U64* c_ntt, U64 q, U64 q_inv, uint32_t n) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n) return; - c_ntt[gid] = mont_mul(a_ntt[gid], b_ntt[gid], q, q_inv); -#endif -} - -extern "C" __global__ void poly_mul_acc( - const U64* a_ntt, const U64* b_ntt, U64* c_ntt, U64 q, U64 q_inv, uint32_t n) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n) return; - U64 prod = mont_mul(a_ntt[gid], b_ntt[gid], q, q_inv); - c_ntt[gid] = mod_add(c_ntt[gid], prod, q); -#endif -} - -extern "C" __global__ void poly_scalar_mul( - const U64* a, U64* c, U64 scalar, U64 q, U64 q_inv, uint32_t n) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= n) return; - c[gid] = mont_mul(a[gid], scalar, q, q_inv); -#endif -} - -// ============================================================================ -// Batch Operations -// ============================================================================ - -extern "C" __global__ void poly_batch_mul_pointwise( - const U64* a, const U64* b, U64* c, U64 q, U64 q_inv, uint32_t n, uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - 
uint32_t total = n * batch_size; - if (gid >= total) return; - c[gid] = mont_mul(a[gid], b[gid], q, q_inv); -#endif -} - -extern "C" __global__ void poly_batch_add( - const U64* a, const U64* b, U64* c, U64 q, uint32_t n, uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = n * batch_size; - if (gid >= total) return; - c[gid] = mod_add(a[gid], b[gid], q); -#endif -} - -extern "C" __global__ void poly_batch_sub( - const U64* a, const U64* b, U64* c, U64 q, uint32_t n, uint32_t batch_size) -{ -#ifdef __CUDA_ARCH__ - uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t total = n * batch_size; - if (gid >= total) return; - c[gid] = mod_sub(a[gid], b[gid], q); -#endif -} diff --git a/poly_mul/gpu/metal/poly_mul.metal b/poly_mul/gpu/metal/poly_mul.metal deleted file mode 100644 index 0552697..0000000 --- a/poly_mul/gpu/metal/poly_mul.metal +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. All Rights Reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Polynomial Multiplication - High-Performance Metal Implementation -// Supports schoolbook and NTT-based multiplication for lattice cryptography. - -#include -using namespace metal; - -// ============================================================================ -// 64-bit Arithmetic -// ============================================================================ - -struct U64 { - uint lo; - uint hi; -}; - -inline U64 u64_from(ulong v) { - return {uint(v & 0xFFFFFFFFu), uint(v >> 32)}; -} - -inline ulong u64_to(U64 v) { - return ulong(v.lo) | (ulong(v.hi) << 32); -} - -inline U64 u64_zero() { return {0u, 0u}; } -inline U64 u64_one() { return {1u, 0u}; } - -inline bool u64_gte(U64 a, U64 b) { - if (a.hi > b.hi) return true; - if (a.hi < b.hi) return false; - return a.lo >= b.lo; -} - -inline U64 u64_add(U64 a, U64 b) { - uint lo = a.lo + b.lo; - uint carry = (lo < a.lo) ? 
1u : 0u; - return {lo, a.hi + b.hi + carry}; -} - -inline U64 u64_sub(U64 a, U64 b) { - uint borrow = (a.lo < b.lo) ? 1u : 0u; - return {a.lo - b.lo, a.hi - b.hi - borrow}; -} - -inline U64 mul32_to_64(uint a, uint b) { - uint a_lo = a & 0xFFFFu; - uint a_hi = a >> 16u; - uint b_lo = b & 0xFFFFu; - uint b_hi = b >> 16u; - - uint p0 = a_lo * b_lo; - uint p1 = a_lo * b_hi; - uint p2 = a_hi * b_lo; - uint p3 = a_hi * b_hi; - - uint mid = p1 + p2; - uint mid_carry = (mid < p1) ? 0x10000u : 0u; - - uint lo = p0 + (mid << 16u); - uint carry = (lo < p0) ? 1u : 0u; - - return {lo, p3 + (mid >> 16u) + mid_carry + carry}; -} - -inline void mul64_to_128(U64 a, U64 b, thread U64& lo, thread U64& hi) { - U64 p0 = mul32_to_64(a.lo, b.lo); - U64 p1 = mul32_to_64(a.lo, b.hi); - U64 p2 = mul32_to_64(a.hi, b.lo); - U64 p3 = mul32_to_64(a.hi, b.hi); - - lo.lo = p0.lo; - uint sum1 = p0.hi + p1.lo; - uint c1 = (sum1 < p0.hi) ? 1u : 0u; - uint sum2 = sum1 + p2.lo; - uint c2 = (sum2 < sum1) ? 1u : 0u; - lo.hi = sum2; - - hi = u64_add(p3, {c1 + c2 + p1.hi + p2.hi, 0u}); -} - -// ============================================================================ -// Modular Arithmetic -// ============================================================================ - -inline U64 mod_add(U64 a, U64 b, U64 q) { - U64 sum = u64_add(a, b); - if (u64_gte(sum, q)) sum = u64_sub(sum, q); - return sum; -} - -inline U64 mod_sub(U64 a, U64 b, U64 q) { - if (u64_gte(a, b)) return u64_sub(a, b); - return u64_sub(u64_add(a, q), b); -} - -inline U64 mont_reduce(U64 lo, U64 hi, U64 q, U64 q_inv) { - U64 m_lo, m_hi; - mul64_to_128(lo, q_inv, m_lo, m_hi); - - U64 prod_lo, prod_hi; - mul64_to_128(m_lo, q, prod_lo, prod_hi); - - U64 sum = u64_add(lo, prod_lo); - uint carry = (sum.lo < lo.lo || sum.hi < lo.hi) ? 
1u : 0u; - - U64 result = u64_add(hi, prod_hi); - result = u64_add(result, {carry, 0u}); - - if (u64_gte(result, q)) result = u64_sub(result, q); - return result; -} - -inline U64 mont_mul(U64 a, U64 b, U64 q, U64 q_inv) { - U64 lo, hi; - mul64_to_128(a, b, lo, hi); - return mont_reduce(lo, hi, q, q_inv); -} - -// ============================================================================ -// NTT Butterfly Operations -// ============================================================================ - -inline void ct_butterfly(thread U64& x0, thread U64& x1, U64 w, U64 q, U64 q_inv) { - U64 t = mont_mul(x1, w, q, q_inv); - x1 = mod_sub(x0, t, q); - x0 = mod_add(x0, t, q); -} - -inline void gs_butterfly(thread U64& x0, thread U64& x1, U64 w, U64 q, U64 q_inv) { - U64 t = mod_sub(x0, x1, q); - x0 = mod_add(x0, x1, q); - x1 = mont_mul(t, w, q, q_inv); -} - -// ============================================================================ -// Schoolbook Polynomial Multiplication -// ============================================================================ - -kernel void poly_mul_schoolbook( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& q_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - constant uint& poly_idx [[buffer(6)]], - uint tid [[thread_position_in_threadgroup]], - uint tg_size [[threads_per_threadgroup]], - threadgroup U64* s_a [[threadgroup(0)]], - threadgroup U64* s_b [[threadgroup(1)]] -) { - uint offset = poly_idx * n; - - // Load to threadgroup memory - if (tid < n) { - s_a[tid] = a[offset + tid]; - s_b[tid] = b[offset + tid]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tid >= n) return; - - // Negacyclic convolution: c[k] = sum(a[i]*b[j]) where i+j=k - // - sum(a[i]*b[j]) where i+j=k+n - U64 sum = u64_zero(); - - // Positive terms - for (uint i = 0; i <= tid; i++) { - U64 prod = mont_mul(s_a[i], s_b[tid - i], q, q_inv); - sum 
= mod_add(sum, prod, q); - } - - // Negative terms (wraparound) - for (uint i = tid + 1; i < n; i++) { - U64 prod = mont_mul(s_a[i], s_b[n + tid - i], q, q_inv); - sum = mod_sub(sum, prod, q); - } - - c[offset + tid] = sum; -} - -// ============================================================================ -// NTT-Based Polynomial Multiplication -// ============================================================================ - -kernel void ntt_forward_stage( - device U64* data [[buffer(0)]], - device const U64* twiddles [[buffer(1)]], - constant U64& q [[buffer(2)]], - constant U64& q_inv [[buffer(3)]], - constant uint& stage [[buffer(4)]], - constant uint& log_n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - uint n = 1u << log_n; - uint half_n = n >> 1; - - if (gid >= half_n) return; - - uint half_size = 1u << stage; - uint group_id = gid / half_size; - uint idx_in_group = gid % half_size; - - uint idx0 = group_id * (half_size << 1) + idx_in_group; - uint idx1 = idx0 + half_size; - - U64 x0 = data[idx0]; - U64 x1 = data[idx1]; - U64 w = twiddles[group_id + half_size]; - - ct_butterfly(x0, x1, w, q, q_inv); - - data[idx0] = x0; - data[idx1] = x1; -} - -kernel void ntt_inverse_stage( - device U64* data [[buffer(0)]], - device const U64* twiddles_inv [[buffer(1)]], - constant U64& q [[buffer(2)]], - constant U64& q_inv [[buffer(3)]], - constant uint& stage [[buffer(4)]], - constant uint& log_n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - uint n = 1u << log_n; - uint half_n = n >> 1; - - if (gid >= half_n) return; - - uint half_size = 1u << stage; - uint group_id = gid / half_size; - uint idx_in_group = gid % half_size; - - uint idx0 = group_id * (half_size << 1) + idx_in_group; - uint idx1 = idx0 + half_size; - - U64 x0 = data[idx0]; - U64 x1 = data[idx1]; - U64 w = twiddles_inv[group_id + half_size]; - - gs_butterfly(x0, x1, w, q, q_inv); - - data[idx0] = x0; - data[idx1] = x1; -} - -kernel void ntt_scale( - device U64* data 
[[buffer(0)]], - constant U64& q [[buffer(1)]], - constant U64& q_inv [[buffer(2)]], - constant U64& n_inv [[buffer(3)]], - constant uint& n [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - data[gid] = mont_mul(data[gid], n_inv, q, q_inv); -} - -// ============================================================================ -// Pointwise Operations -// ============================================================================ - -kernel void poly_mul_pointwise( - device const U64* a_ntt [[buffer(0)]], - device const U64* b_ntt [[buffer(1)]], - device U64* c_ntt [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& q_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - c_ntt[gid] = mont_mul(a_ntt[gid], b_ntt[gid], q, q_inv); -} - -kernel void poly_mul_acc( - device const U64* a_ntt [[buffer(0)]], - device const U64* b_ntt [[buffer(1)]], - device U64* c_ntt [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& q_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - U64 prod = mont_mul(a_ntt[gid], b_ntt[gid], q, q_inv); - c_ntt[gid] = mod_add(c_ntt[gid], prod, q); -} - -kernel void poly_scalar_mul( - device const U64* a [[buffer(0)]], - device U64* c [[buffer(1)]], - constant U64& scalar [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant U64& q_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - c[gid] = mont_mul(a[gid], scalar, q, q_inv); -} - -// ============================================================================ -// Batch Operations -// ============================================================================ - -kernel void poly_batch_mul_pointwise( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q 
[[buffer(3)]], - constant U64& q_inv [[buffer(4)]], - constant uint& n [[buffer(5)]], - constant uint& batch_size [[buffer(6)]], - uint gid [[thread_position_in_grid]] -) { - uint total = n * batch_size; - if (gid >= total) return; - c[gid] = mont_mul(a[gid], b[gid], q, q_inv); -} - -kernel void poly_batch_add( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant uint& n [[buffer(4)]], - constant uint& batch_size [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - uint total = n * batch_size; - if (gid >= total) return; - c[gid] = mod_add(a[gid], b[gid], q); -} - -kernel void poly_batch_sub( - device const U64* a [[buffer(0)]], - device const U64* b [[buffer(1)]], - device U64* c [[buffer(2)]], - constant U64& q [[buffer(3)]], - constant uint& n [[buffer(4)]], - constant uint& batch_size [[buffer(5)]], - uint gid [[thread_position_in_grid]] -) { - uint total = n * batch_size; - if (gid >= total) return; - c[gid] = mod_sub(a[gid], b[gid], q); -} diff --git a/poly_mul/gpu/metal/poly_mul_batch.metal b/poly_mul/gpu/metal/poly_mul_batch.metal deleted file mode 100644 index bbab9a9..0000000 --- a/poly_mul/gpu/metal/poly_mul_batch.metal +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (C) 2020-2026, Lux Industries Inc. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal kernel for batched polynomial multiplication over Q = 998244353. -// Byte-equal to luxfi/crypto/poly_mul.MulSchoolbook for n <= 1024. -// -// Threading: one thread per (batch_idx, output_coefficient_idx). Each thread -// computes c[k] = sum_{i+j=k} a[i]*b[j] - sum_{i+j=k+n} a[i]*b[j] for one k -// independently of every other coefficient. n is bounded by NMAX (1024) to -// fit per-thread accumulation into a single uint64. 
- -#include -using namespace metal; - -constant uint64_t Q_PRIME = 998244353UL; -constant uint NMAX = 1024; - -inline uint64_t add_mod(uint64_t a, uint64_t b) { - uint64_t s = a + b; - if (s >= Q_PRIME) s -= Q_PRIME; - return s; -} - -inline uint64_t sub_mod(uint64_t a, uint64_t b) { - return a >= b ? a - b : a + Q_PRIME - b; -} - -inline uint64_t mul_mod(uint64_t a, uint64_t b) { - // Q < 2^30 so a*b < 2^60 -- fits a uint64 with no overflow before reduction. - return (a * b) % Q_PRIME; -} - -// Each thread computes one coefficient c[k] of one batch element. -// gid.x = output coefficient index k (0..n-1) -// gid.y = batch index -kernel void poly_mul_schoolbook_batch( - device const uint64_t* a [[buffer(0)]], // [batch_size * n] - device const uint64_t* b [[buffer(1)]], // [batch_size * n] - device uint64_t* c [[buffer(2)]], // [batch_size * n] - constant uint& n [[buffer(3)]], - constant uint& batch_size [[buffer(4)]], - uint2 gid [[thread_position_in_grid]]) -{ - if (gid.x >= n || gid.y >= batch_size) return; - if (n > NMAX) return; - - uint k = gid.x; - uint base = gid.y * n; - - uint64_t sum = 0; - - // Positive terms: i+j == k - for (uint i = 0; i <= k; ++i) { - uint j = k - i; - uint64_t prod = mul_mod(a[base + i] % Q_PRIME, b[base + j] % Q_PRIME); - sum = add_mod(sum, prod); - } - // Negative terms (negacyclic wrap): i+j == k+n - for (uint i = k + 1; i < n; ++i) { - uint j = (k + n) - i; - uint64_t prod = mul_mod(a[base + i] % Q_PRIME, b[base + j] % Q_PRIME); - sum = sub_mod(sum, prod); - } - - c[base + k] = sum; -} diff --git a/poly_mul/gpu/metal/poly_mul_batch_driver.mm b/poly_mul/gpu/metal/poly_mul_batch_driver.mm deleted file mode 100644 index 1517c53..0000000 --- a/poly_mul/gpu/metal/poly_mul_batch_driver.mm +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched poly_mul. macOS only. 
-// One Metal kernel call processes batch_size independent polynomials of -// length n in parallel; each thread computes a single output coefficient. - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include - -extern "C" int poly_mul_batch_metal( - const uint64_t* a_arena, - const uint64_t* b_arena, - uint64_t* c_arena, - uint32_t n, - uint32_t batch_size, - const char* metallib_path) -{ - if (a_arena == nullptr || b_arena == nullptr || c_arena == nullptr || - metallib_path == nullptr) return -1; - if (n == 0 || batch_size == 0 || n > 1024) return -2; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -3; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -4; - - id fn = [lib newFunctionWithName:@"poly_mul_schoolbook_batch"]; - if (!fn) return -5; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -6; - - id queue = [device newCommandQueue]; - - size_t total = size_t(n) * size_t(batch_size); - size_t bytes = total * sizeof(uint64_t); - - id a_buf = [device newBufferWithBytes:a_arena length:bytes - options:MTLResourceStorageModeShared]; - id b_buf = [device newBufferWithBytes:b_arena length:bytes - options:MTLResourceStorageModeShared]; - id c_buf = [device newBufferWithLength:bytes - options:MTLResourceStorageModeShared]; - - id n_buf = [device newBufferWithBytes:&n length:sizeof(n) - options:MTLResourceStorageModeShared]; - id bs_buf = [device newBufferWithBytes:&batch_size length:sizeof(batch_size) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:a_buf offset:0 atIndex:0]; - [enc setBuffer:b_buf offset:0 atIndex:1]; - [enc setBuffer:c_buf offset:0 atIndex:2]; - [enc 
setBuffer:n_buf offset:0 atIndex:3]; - [enc setBuffer:bs_buf offset:0 atIndex:4]; - - // 2D grid: (n, batch_size). Threadgroup chosen to fit the smaller dim. - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_x = (n < tg_max) ? n : tg_max; - NSUInteger tg_y = 1; - if (tg_max / tg_x >= 1) tg_y = 1; - MTLSize threads_per_grid = MTLSizeMake(n, batch_size, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_x, tg_y, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(c_arena, [c_buf contents], bytes); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/poly_mul/gpu/wgsl/poly_mul.wgsl b/poly_mul/gpu/wgsl/poly_mul.wgsl deleted file mode 100644 index b4491d7..0000000 --- a/poly_mul/gpu/wgsl/poly_mul.wgsl +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Polynomial multiplication in WGSL, ported from poly_mul.metal. -// Schoolbook and NTT-based multiplication for lattice cryptography. -// u64 emulated as vec2(lo, hi). 
- -@group(0) @binding(0) var poly_a: array>; -@group(0) @binding(1) var poly_b: array>; -@group(0) @binding(2) var poly_out: array>; -@group(0) @binding(3) var params: PolyMulParams; - -struct PolyMulParams { - Q_lo: u32, Q_hi: u32, - Q_inv_lo: u32, Q_inv_hi: u32, - N: u32, batch_size: u32, -} - -fn u64_add(a: vec2, b: vec2) -> vec2 { - let lo = a.x + b.x; - return vec2(lo, a.y + b.y + select(0u, 1u, lo < a.x)); -} -fn u64_sub(a: vec2, b: vec2) -> vec2 { - return vec2(a.x - b.x, a.y - b.y - select(0u, 1u, a.x < b.x)); -} -fn u64_gte(a: vec2, b: vec2) -> bool { - if (a.y != b.y) { return a.y > b.y; } - return a.x >= b.x; -} - -// 32x32 -> 64 -fn mul32_64(a: u32, b: u32) -> vec2 { - let al = a & 0xFFFFu; let ah = a >> 16u; - let bl = b & 0xFFFFu; let bh = b >> 16u; - let ll = al * bl; let mid = al * bh + ah * bl; - let hh = ah * bh; - let lo = ll + (mid << 16u); - let hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll); - return vec2(lo, hi); -} - -// u64 * u64 -> 128 bit (low and high u64) -fn u64_mul_wide(a: vec2, b: vec2, - lo: ptr>, - hi: ptr>) { - let p0 = mul32_64(a.x, b.x); - let p1 = mul32_64(a.x, b.y); - let p2 = mul32_64(a.y, b.x); - let p3 = mul32_64(a.y, b.y); - - (*lo).x = p0.x; - var mid_sum = p0.y + p1.x; - var c1 = select(0u, 1u, mid_sum < p0.y); - mid_sum = mid_sum + p2.x; - c1 = c1 + select(0u, 1u, mid_sum < p2.x); - (*lo).y = mid_sum; - - (*hi) = u64_add(p3, vec2(c1 + p1.y + p2.y, 0u)); -} - -// Montgomery reduction: (lo, hi) * R^{-1} mod Q -fn mont_reduce(lo_val: vec2, hi_val: vec2, - Q: vec2, q_inv: vec2) -> vec2 { - // m = lo * q_inv (mod R, keep low 64 bits) - let al = lo_val.x & 0xFFFFu; let ah = lo_val.x >> 16u; - let bl = q_inv.x & 0xFFFFu; let bh = q_inv.x >> 16u; - let ll = al * bl; let mid = al * bh + ah * bl; - let m_lo = ll + (mid << 16u); - let m_hi = ah * bh + (mid >> 16u) + select(0u, 1u, m_lo < ll) - + lo_val.x * q_inv.y + lo_val.y * q_inv.x; - let m = vec2(m_lo, m_hi); - - // t = m * Q - var t_lo: vec2; var t_hi: vec2; - 
u64_mul_wide(m, Q, &t_lo, &t_hi); - - // result = (combined - t) >> 64 = hi - t_hi (with borrow from lo) - let borrow = select(0u, 1u, !u64_gte(lo_val, t_lo)); - var r = u64_sub(hi_val, u64_add(t_hi, vec2(borrow, 0u))); - - if (u64_gte(r, Q)) { r = u64_sub(r, Q); } - // Handle negative result - if (r.y > 0x80000000u) { r = u64_add(r, Q); } - return r; -} - -fn mod_add(a: vec2, b: vec2, Q: vec2) -> vec2 { - let s = u64_add(a, b); - if (u64_gte(s, Q)) { return u64_sub(s, Q); } - return s; -} - -fn mod_sub(a: vec2, b: vec2, Q: vec2) -> vec2 { - if (u64_gte(a, b)) { return u64_sub(a, b); } - return u64_sub(u64_add(a, Q), b); -} - -// ============================================================================ -// Schoolbook negacyclic multiplication: c = a * b mod (X^N + 1) mod Q -// One thread per output coefficient. -// ============================================================================ - -@compute @workgroup_size(256) -fn poly_mul_schoolbook(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let coeff_idx = gid.x; - if (batch_idx >= params.batch_size || coeff_idx >= params.N) { return; } - - let N = params.N; - let Q = vec2(params.Q_lo, params.Q_hi); - let q_inv = vec2(params.Q_inv_lo, params.Q_inv_hi); - let base_a = batch_idx * N; - let base_b = batch_idx * N; - - var acc = vec2(0u, 0u); - - // c[k] = sum_{i+j=k} a[i]*b[j] - sum_{i+j=k+N} a[i]*b[j] (negacyclic) - for (var i = 0u; i < N; i = i + 1u) { - let a_val = poly_a[base_a + i]; - - // Positive contribution: j = k - i (if j >= 0) - if (coeff_idx >= i) { - let j = coeff_idx - i; - let b_val = poly_b[base_b + j]; - var prod_lo: vec2; var prod_hi: vec2; - u64_mul_wide(a_val, b_val, &prod_lo, &prod_hi); - let reduced = mont_reduce(prod_lo, prod_hi, Q, q_inv); - acc = mod_add(acc, reduced, Q); - } - - // Negative contribution: j = k - i + N (wraps, so subtract) - if (coeff_idx < i) { - let j = coeff_idx + N - i; - let b_val = poly_b[base_b + j]; - var prod_lo: vec2; var prod_hi: vec2; 
- u64_mul_wide(a_val, b_val, &prod_lo, &prod_hi); - let reduced = mont_reduce(prod_lo, prod_hi, Q, q_inv); - acc = mod_sub(acc, reduced, Q); - } - } - - poly_out[batch_idx * N + coeff_idx] = acc; -} - -// ============================================================================ -// Pointwise NTT-domain multiplication (for use after forward NTT) -// ============================================================================ - -@compute @workgroup_size(256) -fn poly_mul_pointwise(@builtin(global_invocation_id) gid: vec3) { - let batch_idx = gid.y; - let coeff_idx = gid.x; - if (batch_idx >= params.batch_size || coeff_idx >= params.N) { return; } - - let Q = vec2(params.Q_lo, params.Q_hi); - let q_inv = vec2(params.Q_inv_lo, params.Q_inv_hi); - let idx = batch_idx * params.N + coeff_idx; - - let a_val = poly_a[idx]; - let b_val = poly_b[idx]; - - var prod_lo: vec2; var prod_hi: vec2; - u64_mul_wide(a_val, b_val, &prod_lo, &prod_hi); - poly_out[idx] = mont_reduce(prod_lo, prod_hi, Q, q_inv); -} diff --git a/poseidon/gpu/cuda/poseidon2_bn254.cu b/poseidon/gpu/cuda/poseidon2_bn254.cu deleted file mode 100644 index e46afc3..0000000 --- a/poseidon/gpu/cuda/poseidon2_bn254.cu +++ /dev/null @@ -1,376 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party CUDA kernel for Poseidon2-BN254 (canonical default permutation). -// -// Mechanical port of poseidon/gpu/metal/poseidon2_bn254.metal -- byte-for-byte -// equivalent to lux::crypto::poseidon::hash2 in poseidon/cpp/poseidon.cpp -// (gnark-crypto v0.20.1 ecc/bn254/fr/poseidon2 with t=2, rF=6, rP=50, d=5). -// -// The round-key constant table is emitted by the same CPU body via -// dump_round_keys -> gen_gpu_constants -> poseidon2_bn254_rk.cuh, so there is -// exactly one source of truth for round constants across CPU, Metal, CUDA, -// and WGSL backends. -// -// Compile with nvcc compute_60+ on Linux/CUDA hosts. 
On hosts without nvcc -// (Apple, plain g++/clang++), the file compiles via a host-side polyfill that -// elides __device__/__global__/__host__ qualifiers and lets the kernel run -// on the CPU oracle path. The polyfill is exercised by the parity test; -// production GPU dispatch uses the real nvcc-compiled path. - -#include -#include - -// Polyfill: when not building with nvcc, neutralize the device qualifiers. -#ifndef __CUDA_ARCH__ -# ifndef LUX_POSEIDON_CUDA_HOST_POLYFILL -# define LUX_POSEIDON_CUDA_HOST_POLYFILL 1 -# endif -#endif - -#if LUX_POSEIDON_CUDA_HOST_POLYFILL -# define __device__ -# define __global__ -# define __host__ -# define __forceinline__ inline -#endif - -#include "poseidon2_bn254_rk.cuh" // POSEIDON2_RK[56][2][4] - -// ============================================================================= -// BN254 Fr modulus q (4x64 little-endian limbs). -// q = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// Montgomery params: qInvNeg = -q^{-1} mod 2^64; rSquare = R^2 mod q. -// All values match poseidon/cpp/poseidon.cpp byte-for-byte. -// ============================================================================= -__device__ static const unsigned long long Q0 = 0x43e1f593f0000001ULL; -__device__ static const unsigned long long Q1 = 0x2833e84879b97091ULL; -__device__ static const unsigned long long Q2 = 0xb85045b68181585dULL; -__device__ static const unsigned long long Q3 = 0x30644e72e131a029ULL; -__device__ static const unsigned long long Q_INV_NEG = 0xc2e1f593efffffffULL; - -__device__ static const unsigned long long R_SQUARE_0 = 1997599621687373223ULL; -__device__ static const unsigned long long R_SQUARE_1 = 6052339484930628067ULL; -__device__ static const unsigned long long R_SQUARE_2 = 10108755138030829701ULL; -__device__ static const unsigned long long R_SQUARE_3 = 150537098327114917ULL; - -// 256-bit field element in Montgomery form. 
-struct Fr { - unsigned long long l0, l1, l2, l3; -}; - -// ============================================================================= -// 64x64 -> 128 multiply. nvcc emits __umul64hi for the high half on the GPU; -// the host polyfill uses __int128 (or a portable fallback if unavailable). -// ============================================================================= -__device__ __forceinline__ void mul64(unsigned long long a, - unsigned long long b, - unsigned long long &lo, - unsigned long long &hi) { -#if defined(__CUDA_ARCH__) - lo = a * b; - hi = __umul64hi(a, b); -#elif defined(__SIZEOF_INT128__) - unsigned __int128 t = (unsigned __int128)a * b; - lo = (unsigned long long)t; - hi = (unsigned long long)(t >> 64); -#else - unsigned long long al = a & 0xffffffffULL, ah = a >> 32; - unsigned long long bl = b & 0xffffffffULL, bh = b >> 32; - unsigned long long ll = al * bl; - unsigned long long lh = al * bh; - unsigned long long hl = ah * bl; - unsigned long long hh = ah * bh; - unsigned long long mid = - (ll >> 32) + (lh & 0xffffffffULL) + (hl & 0xffffffffULL); - lo = (ll & 0xffffffffULL) | (mid << 32); - hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -#endif -} - -// adc/sbb: same carry-chain pattern as the CPU body. -__device__ __forceinline__ unsigned long long adc(unsigned long long a, - unsigned long long b, - unsigned long long &carry) { - unsigned long long s = a + b; - unsigned long long c1 = (s < a) ? 1ULL : 0ULL; - unsigned long long s2 = s + carry; - unsigned long long c2 = (s2 < s) ? 1ULL : 0ULL; - carry = c1 + c2; - return s2; -} - -__device__ __forceinline__ unsigned long long sbb(unsigned long long a, - unsigned long long b, - unsigned long long &borrow) { - unsigned long long d = a - b; - unsigned long long b1 = (a < b) ? 1ULL : 0ULL; - unsigned long long d2 = d - borrow; - unsigned long long b2 = (d < borrow) ? 
1ULL : 0ULL; - borrow = b1 + b2; - return d2; -} - -__device__ __forceinline__ int cmp_q(const Fr &a) { - if (a.l3 != Q3) return (a.l3 < Q3) ? -1 : 1; - if (a.l2 != Q2) return (a.l2 < Q2) ? -1 : 1; - if (a.l1 != Q1) return (a.l1 < Q1) ? -1 : 1; - if (a.l0 != Q0) return (a.l0 < Q0) ? -1 : 1; - return 0; -} - -__device__ __forceinline__ void reduce_once(Fr &a) { - if (cmp_q(a) >= 0) { - unsigned long long br = 0; - a.l0 = sbb(a.l0, Q0, br); - a.l1 = sbb(a.l1, Q1, br); - a.l2 = sbb(a.l2, Q2, br); - a.l3 = sbb(a.l3, Q3, br); - } -} - -__device__ __forceinline__ Fr fr_add(const Fr &a, const Fr &b) { - Fr c; - unsigned long long cy = 0; - c.l0 = adc(a.l0, b.l0, cy); - c.l1 = adc(a.l1, b.l1, cy); - c.l2 = adc(a.l2, b.l2, cy); - c.l3 = adc(a.l3, b.l3, cy); - if (cy != 0 || cmp_q(c) >= 0) { - unsigned long long br = 0; - c.l0 = sbb(c.l0, Q0, br); - c.l1 = sbb(c.l1, Q1, br); - c.l2 = sbb(c.l2, Q2, br); - c.l3 = sbb(c.l3, Q3, br); - } - return c; -} - -__device__ __forceinline__ Fr fr_double(const Fr &a) { return fr_add(a, a); } - -// CIOS Montgomery multiplication. Identical algorithm to the CPU body. -__device__ __forceinline__ Fr fr_mul(const Fr &a, const Fr &b) { - unsigned long long t[5] = {0, 0, 0, 0, 0}; - const unsigned long long al[4] = {a.l0, a.l1, a.l2, a.l3}; - const unsigned long long bl[4] = {b.l0, b.l1, b.l2, b.l3}; - const unsigned long long qq[4] = {Q0, Q1, Q2, Q3}; - - for (int i = 0; i < 4; ++i) { - unsigned long long cy = 0; - for (int j = 0; j < 4; ++j) { - unsigned long long lo, hi; - mul64(al[j], bl[i], lo, hi); - unsigned long long s = t[j] + lo; - unsigned long long c1 = (s < t[j]) ? 1ULL : 0ULL; - unsigned long long s2 = s + cy; - unsigned long long c2 = (s2 < s) ? 1ULL : 0ULL; - t[j] = s2; - cy = hi + c1 + c2; - } - t[4] += cy; - - unsigned long long m = t[0] * Q_INV_NEG; - - cy = 0; - for (int j = 0; j < 4; ++j) { - unsigned long long lo, hi; - mul64(m, qq[j], lo, hi); - unsigned long long s = t[j] + lo; - unsigned long long c1 = (s < t[j]) ? 
1ULL : 0ULL; - unsigned long long s2 = s + cy; - unsigned long long c2 = (s2 < s) ? 1ULL : 0ULL; - t[j] = s2; - cy = hi + c1 + c2; - } - t[4] += cy; - - t[0] = t[1]; - t[1] = t[2]; - t[2] = t[3]; - t[3] = t[4]; - t[4] = 0; - } - Fr c; - c.l0 = t[0]; c.l1 = t[1]; c.l2 = t[2]; c.l3 = t[3]; - reduce_once(c); - return c; -} - -__device__ __forceinline__ Fr fr_square(const Fr &a) { return fr_mul(a, a); } - -// ============================================================================= -// Poseidon2-BN254 default permutation (t=2, rF=6 split 3+3, rP=50, d=5). -// ============================================================================= - -__device__ __forceinline__ void sbox(Fr &x) { - Fr x2 = fr_square(x); - Fr x4 = fr_square(x2); - x = fr_mul(x4, x); -} - -__device__ __forceinline__ void mat_mul_external(Fr s[2]) { - Fr tmp = fr_add(s[0], s[1]); - s[0] = fr_add(s[0], tmp); - s[1] = fr_add(s[1], tmp); -} - -__device__ __forceinline__ void mat_mul_internal(Fr s[2]) { - Fr sum = fr_add(s[0], s[1]); - s[0] = fr_add(s[0], sum); - Fr s1d = fr_double(s[1]); - s[1] = fr_add(s1d, sum); -} - -#define POSEIDON_FULL_HALF 3 -#define POSEIDON_PARTIAL 50 - -__device__ __forceinline__ Fr load_rk(int round, int slot) { - Fr r; - r.l0 = POSEIDON2_RK[round][slot][0]; - r.l1 = POSEIDON2_RK[round][slot][1]; - r.l2 = POSEIDON2_RK[round][slot][2]; - r.l3 = POSEIDON2_RK[round][slot][3]; - return r; -} - -__device__ __forceinline__ void permute(Fr s[2]) { - mat_mul_external(s); - for (int i = 0; i < POSEIDON_FULL_HALF; ++i) { - Fr k0 = load_rk(i, 0); - Fr k1 = load_rk(i, 1); - s[0] = fr_add(s[0], k0); - s[1] = fr_add(s[1], k1); - sbox(s[0]); - sbox(s[1]); - mat_mul_external(s); - } - for (int i = 0; i < POSEIDON_PARTIAL; ++i) { - Fr k0 = load_rk(POSEIDON_FULL_HALF + i, 0); - s[0] = fr_add(s[0], k0); - sbox(s[0]); - mat_mul_internal(s); - } - for (int i = 0; i < POSEIDON_FULL_HALF; ++i) { - Fr k0 = load_rk(POSEIDON_FULL_HALF + POSEIDON_PARTIAL + i, 0); - Fr k1 = load_rk(POSEIDON_FULL_HALF 
+ POSEIDON_PARTIAL + i, 1); - s[0] = fr_add(s[0], k0); - s[1] = fr_add(s[1], k1); - sbox(s[0]); - sbox(s[1]); - mat_mul_external(s); - } -} - -// ============================================================================= -// Bytes (BE) <-> Fr (Montgomery LE limbs) conversions. -// ============================================================================= - -__device__ __forceinline__ unsigned long long be_read64(const unsigned char *p) { - unsigned long long v = 0; - for (int b = 0; b < 8; ++b) v = (v << 8) | (unsigned long long)p[b]; - return v; -} - -__device__ __forceinline__ Fr be_to_fr_mont(const unsigned char *be) { - Fr x; - x.l0 = be_read64(be + 24); - x.l1 = be_read64(be + 16); - x.l2 = be_read64(be + 8); - x.l3 = be_read64(be + 0); - for (int i = 0; i < 4; ++i) { - if (cmp_q(x) < 0) break; - unsigned long long br = 0; - x.l0 = sbb(x.l0, Q0, br); - x.l1 = sbb(x.l1, Q1, br); - x.l2 = sbb(x.l2, Q2, br); - x.l3 = sbb(x.l3, Q3, br); - } - Fr r2; - r2.l0 = R_SQUARE_0; r2.l1 = R_SQUARE_1; - r2.l2 = R_SQUARE_2; r2.l3 = R_SQUARE_3; - return fr_mul(x, r2); -} - -__device__ __forceinline__ void fr_mont_to_be(const Fr &x, unsigned char *be) { - Fr one_reg; - one_reg.l0 = 1; one_reg.l1 = 0; one_reg.l2 = 0; one_reg.l3 = 0; - Fr r = fr_mul(x, one_reg); - unsigned long long limbs[4] = {r.l0, r.l1, r.l2, r.l3}; - for (int i = 0; i < 4; ++i) { - unsigned long long v = limbs[i]; - int off = 32 - 8 * (i + 1); - for (int b = 7; b >= 0; --b) { - be[off + b] = (unsigned char)(v & 0xff); - v >>= 8; - } - } -} - -// ============================================================================= -// Per-thread compression body. Used by the global kernel and by the host -// polyfill loop. 
-// ============================================================================= -__device__ static void poseidon2_hash2_one(const unsigned char *pair_in, - unsigned char *out) { - Fr s[2]; - s[0] = be_to_fr_mont(pair_in); - s[1] = be_to_fr_mont(pair_in + 32); - Fr saved_right = s[1]; - permute(s); - Fr digest = fr_add(saved_right, s[1]); - fr_mont_to_be(digest, out); -} - -#if !LUX_POSEIDON_CUDA_HOST_POLYFILL -extern "C" __global__ void poseidon2_hash2_batch_kernel( - const unsigned char *pairs, - unsigned char *outs, - unsigned int n) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) return; - poseidon2_hash2_one(pairs + i * 64, outs + i * 32); -} -#endif - -// ============================================================================= -// Host driver C-ABI. On real CUDA hosts this calls cudaMalloc/cudaMemcpy/ -// kernel<<<>>> dispatch; on polyfill hosts (no nvcc), it loops the same -// per-thread body on the CPU. Either way the output is byte-equal to the -// CPU oracle by construction (constants come from the CPU body). 
-// ============================================================================= -extern "C" int poseidon2_hash2_cuda_batch(const unsigned char *pairs, - unsigned char *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pairs || !outs) return -1; - -#if LUX_POSEIDON_CUDA_HOST_POLYFILL - for (unsigned long i = 0; i < n; ++i) { - poseidon2_hash2_one(pairs + i * 64, outs + i * 32); - } - return 0; -#else - unsigned char *d_pairs = nullptr, *d_outs = nullptr; - cudaError_t st; - st = cudaMalloc(&d_pairs, (size_t)n * 64); - if (st != cudaSuccess) { return -2; } - st = cudaMalloc(&d_outs, (size_t)n * 32); - if (st != cudaSuccess) { cudaFree(d_pairs); return -2; } - st = cudaMemcpy(d_pairs, pairs, (size_t)n * 64, cudaMemcpyHostToDevice); - if (st != cudaSuccess) { cudaFree(d_pairs); cudaFree(d_outs); return -3; } - - unsigned int tpb = 64; - unsigned int blocks = (unsigned int)((n + tpb - 1) / tpb); - poseidon2_hash2_batch_kernel<<>>(d_pairs, d_outs, (unsigned int)n); - st = cudaGetLastError(); - if (st != cudaSuccess) { cudaFree(d_pairs); cudaFree(d_outs); return -4; } - st = cudaDeviceSynchronize(); - if (st != cudaSuccess) { cudaFree(d_pairs); cudaFree(d_outs); return -4; } - - st = cudaMemcpy(outs, d_outs, (size_t)n * 32, cudaMemcpyDeviceToHost); - cudaFree(d_pairs); - cudaFree(d_outs); - if (st != cudaSuccess) { return -5; } - return 0; -#endif -} diff --git a/poseidon/gpu/cuda/poseidon2_driver.h b/poseidon/gpu/cuda/poseidon2_driver.h deleted file mode 100644 index 514988b..0000000 --- a/poseidon/gpu/cuda/poseidon2_driver.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA driver for Poseidon2-BN254. Linux/CUDA on real GPUs; host-side polyfill -// when nvcc is unavailable so the CPU oracle test path stays exercised. 
-// -// Byte-equal-by-construction to lux::crypto::poseidon::hash2 -- the constant -// table is generated from the CPU body via dump_round_keys -> gen_gpu_constants -// and #include'd by poseidon2_bn254.cu. - -#ifndef LUX_POSEIDON2_CUDA_DRIVER_H -#define LUX_POSEIDON2_CUDA_DRIVER_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Run n Poseidon2.Compress calls in one CUDA dispatch (or host loop on -// polyfill builds). -// -// pairs : n * 64 bytes, layout = [BE(left_i) || BE(right_i)] for i=0..n-1 -// outs : n * 32 bytes, BE digest written per pair -// n : number of pairs -// -// Returns 0 on success. Negative on failure: -1 invalid arg, -2 cudaMalloc, -// -3 H2D copy, -4 kernel launch / sync, -5 D2H copy. -int poseidon2_hash2_cuda_batch(const unsigned char *pairs, - unsigned char *outs, - unsigned long n); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_POSEIDON2_CUDA_DRIVER_H diff --git a/poseidon/gpu/metal/attestation.metal b/poseidon/gpu/metal/attestation.metal deleted file mode 100644 index 8ebdd9e..0000000 --- a/poseidon/gpu/metal/attestation.metal +++ /dev/null @@ -1,459 +0,0 @@ -// ============================================================================= -// Attestation Verification - Metal Compute Shaders -// ============================================================================= -// -// GPU-accelerated TEE attestation verification for NVTrust and TPM quotes. -// Batch processing for high-throughput AI mining verification. -// -// Operations: -// - SHA-256/SHA-384 hash computation for quote verification -// - ECDSA P-384 signature verification (for NVTrust) -// - Certificate chain validation helpers -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// SHA-256 Constants and State -// ============================================================================= - -constant uint32_t SHA256_K[64] = { - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 -}; - -constant uint32_t SHA256_H[8] = { - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 -}; - -// SHA-256 helper functions -inline uint32_t sha256_rotr(uint32_t x, uint32_t n) { - return (x >> n) | (x << (32 - n)); -} - -inline uint32_t sha256_ch(uint32_t x, uint32_t y, uint32_t z) { - return (x & y) ^ (~x & z); -} - -inline uint32_t sha256_maj(uint32_t x, uint32_t y, uint32_t z) { - return (x & y) ^ (x & z) ^ (y & z); -} - -inline uint32_t sha256_sigma0(uint32_t x) { - return sha256_rotr(x, 2) ^ sha256_rotr(x, 13) ^ sha256_rotr(x, 22); -} - -inline uint32_t sha256_sigma1(uint32_t x) { - return sha256_rotr(x, 6) ^ sha256_rotr(x, 11) ^ sha256_rotr(x, 25); -} - -inline uint32_t sha256_gamma0(uint32_t x) { - return sha256_rotr(x, 7) ^ sha256_rotr(x, 18) ^ (x >> 3); -} - -inline uint32_t sha256_gamma1(uint32_t x) { - return sha256_rotr(x, 17) ^ 
sha256_rotr(x, 19) ^ (x >> 10); -} - -// ============================================================================= -// Attestation Verification Parameters -// ============================================================================= - -struct AttestationParams { - uint32_t batch_size; // Number of attestations to verify - uint32_t quote_size; // Size of each quote - uint32_t cert_offset; // Offset to certificate chain - uint32_t sig_offset; // Offset to signature -}; - -// ============================================================================= -// SHA-256 Hash Kernel (Single Block) -// ============================================================================= - -kernel void sha256_hash_block( - device uint32_t* state [[buffer(0)]], // 8 uint32 state words per hash - constant uint8_t* data [[buffer(1)]], // 64 bytes per block - constant AttestationParams& params [[buffer(2)]], - uint tid [[thread_position_in_grid]] -) { - if (tid >= params.batch_size) return; - - // Load state - device uint32_t* h = state + tid * 8; - constant uint8_t* block = data + tid * 64; - - // Message schedule - uint32_t w[64]; - - // Load first 16 words (big-endian) - for (int i = 0; i < 16; ++i) { - w[i] = (uint32_t(block[i*4]) << 24) | - (uint32_t(block[i*4 + 1]) << 16) | - (uint32_t(block[i*4 + 2]) << 8) | - uint32_t(block[i*4 + 3]); - } - - // Extend to 64 words - for (int i = 16; i < 64; ++i) { - w[i] = sha256_gamma1(w[i-2]) + w[i-7] + sha256_gamma0(w[i-15]) + w[i-16]; - } - - // Working variables - uint32_t a = h[0], b = h[1], c = h[2], d = h[3]; - uint32_t e = h[4], f = h[5], g = h[6], hh = h[7]; - - // 64 rounds - for (int i = 0; i < 64; ++i) { - uint32_t t1 = hh + sha256_sigma1(e) + sha256_ch(e, f, g) + SHA256_K[i] + w[i]; - uint32_t t2 = sha256_sigma0(a) + sha256_maj(a, b, c); - - hh = g; g = f; f = e; - e = d + t1; - d = c; c = b; b = a; - a = t1 + t2; - } - - // Update state - h[0] += a; h[1] += b; h[2] += c; h[3] += d; - h[4] += e; h[5] += f; h[6] += g; h[7] += 
hh; -} - -// ============================================================================= -// Quote Hash Computation -// ============================================================================= - -// Compute SHA-256 hash of attestation quote data -kernel void compute_quote_hash( - device uint32_t* hashes [[buffer(0)]], // Output: 8 uint32 per quote - constant uint8_t* quotes [[buffer(1)]], // Input: quote data - constant AttestationParams& params [[buffer(2)]], - uint tid [[thread_position_in_grid]] -) { - if (tid >= params.batch_size) return; - - device uint32_t* h = hashes + tid * 8; - constant uint8_t* quote = quotes + tid * params.quote_size; - - // Initialize state - for (int i = 0; i < 8; ++i) { - h[i] = SHA256_H[i]; - } - - // Hash quote data (excluding signature) - uint32_t data_size = params.sig_offset; - uint32_t num_blocks = (data_size + 9 + 63) / 64; // +9 for length + padding - - // Process complete blocks - uint32_t block_idx = 0; - uint32_t w[64]; - - for (; block_idx < data_size / 64; ++block_idx) { - constant uint8_t* block = quote + block_idx * 64; - - for (int i = 0; i < 16; ++i) { - w[i] = (uint32_t(block[i*4]) << 24) | - (uint32_t(block[i*4 + 1]) << 16) | - (uint32_t(block[i*4 + 2]) << 8) | - uint32_t(block[i*4 + 3]); - } - - for (int i = 16; i < 64; ++i) { - w[i] = sha256_gamma1(w[i-2]) + w[i-7] + sha256_gamma0(w[i-15]) + w[i-16]; - } - - uint32_t a = h[0], b = h[1], c = h[2], d = h[3]; - uint32_t e = h[4], f = h[5], g = h[6], hh = h[7]; - - for (int i = 0; i < 64; ++i) { - uint32_t t1 = hh + sha256_sigma1(e) + sha256_ch(e, f, g) + SHA256_K[i] + w[i]; - uint32_t t2 = sha256_sigma0(a) + sha256_maj(a, b, c); - hh = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - - h[0] += a; h[1] += b; h[2] += c; h[3] += d; - h[4] += e; h[5] += f; h[6] += g; h[7] += hh; - } - - // Handle final block with padding (simplified - assumes < 55 bytes remaining) - uint8_t final_block[64]; - uint32_t remaining = data_size % 64; - - for 
(uint32_t i = 0; i < remaining; ++i) { - final_block[i] = quote[block_idx * 64 + i]; - } - final_block[remaining] = 0x80; // Padding bit - - for (uint32_t i = remaining + 1; i < 56; ++i) { - final_block[i] = 0; - } - - // Length in bits (big-endian) - uint64_t bit_len = uint64_t(data_size) * 8; - for (int i = 0; i < 8; ++i) { - final_block[56 + i] = uint8_t(bit_len >> (56 - i * 8)); - } - - // Hash final block - for (int i = 0; i < 16; ++i) { - w[i] = (uint32_t(final_block[i*4]) << 24) | - (uint32_t(final_block[i*4 + 1]) << 16) | - (uint32_t(final_block[i*4 + 2]) << 8) | - uint32_t(final_block[i*4 + 3]); - } - - for (int i = 16; i < 64; ++i) { - w[i] = sha256_gamma1(w[i-2]) + w[i-7] + sha256_gamma0(w[i-15]) + w[i-16]; - } - - uint32_t a = h[0], b = h[1], c = h[2], d = h[3]; - uint32_t e = h[4], f = h[5], g = h[6], hh = h[7]; - - for (int i = 0; i < 64; ++i) { - uint32_t t1 = hh + sha256_sigma1(e) + sha256_ch(e, f, g) + SHA256_K[i] + w[i]; - uint32_t t2 = sha256_sigma0(a) + sha256_maj(a, b, c); - hh = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - - h[0] += a; h[1] += b; h[2] += c; h[3] += d; - h[4] += e; h[5] += f; h[6] += g; h[7] += hh; -} - -// ============================================================================= -// P-384 Field Arithmetic (6 x 64-bit limbs) -// ============================================================================= - -// P-384 prime: p = 2^384 - 2^128 - 2^96 + 2^32 - 1 -constant uint64_t P384_P[6] = { - 0x00000000ffffffff, - 0xffffffff00000000, - 0xfffffffffffffffe, - 0xffffffffffffffff, - 0xffffffffffffffff, - 0xffffffffffffffff -}; - -struct P384Element { - uint64_t limbs[6]; -}; - -struct P384Point { - P384Element x; - P384Element y; - bool infinity; -}; - -// Add with carry -inline uint64_t p384_adc(uint64_t a, uint64_t b, thread uint64_t& carry) { - uint64_t sum = a + b + carry; - carry = (sum < a) || (carry && sum == a) ? 
1 : 0; - return sum; -} - -// Subtract with borrow -inline uint64_t p384_sbb(uint64_t a, uint64_t b, thread uint64_t& borrow) { - uint64_t diff = a - b - borrow; - borrow = (a < b + borrow) ? 1 : 0; - return diff; -} - -// Modular addition -inline P384Element p384_add(thread const P384Element& a, thread const P384Element& b) { - P384Element c; - uint64_t carry = 0; - - for (int i = 0; i < 6; ++i) { - c.limbs[i] = p384_adc(a.limbs[i], b.limbs[i], carry); - } - - // Reduce if >= p - bool ge_p = carry != 0; - if (!ge_p) { - for (int i = 5; i >= 0; --i) { - if (c.limbs[i] > P384_P[i]) { ge_p = true; break; } - if (c.limbs[i] < P384_P[i]) break; - } - } - - if (ge_p) { - uint64_t borrow = 0; - for (int i = 0; i < 6; ++i) { - c.limbs[i] = p384_sbb(c.limbs[i], P384_P[i], borrow); - } - } - - return c; -} - -// Modular subtraction -inline P384Element p384_sub(thread const P384Element& a, thread const P384Element& b) { - P384Element c; - uint64_t borrow = 0; - - for (int i = 0; i < 6; ++i) { - c.limbs[i] = p384_sbb(a.limbs[i], b.limbs[i], borrow); - } - - // Add p if underflow - if (borrow) { - uint64_t carry = 0; - for (int i = 0; i < 6; ++i) { - c.limbs[i] = p384_adc(c.limbs[i], P384_P[i], carry); - } - } - - return c; -} - -// Check if element is zero -inline bool p384_is_zero(thread const P384Element& a) { - for (int i = 0; i < 6; ++i) { - if (a.limbs[i] != 0) return false; - } - return true; -} - -// ============================================================================= -// ECDSA P-384 Signature Verification (Simplified) -// ============================================================================= - -// Verify result computation (partial - full impl requires scalar mul) -kernel void ecdsa_p384_verify_prepare( - device uint32_t* results [[buffer(0)]], // Output: 1=potentially valid - constant uint8_t* signatures [[buffer(1)]], // r || s (96 bytes each) - constant uint8_t* hashes [[buffer(2)]], // 48 bytes each - constant uint8_t* pubkeys [[buffer(3)]], // x 
|| y (96 bytes each) - constant AttestationParams& params [[buffer(4)]], - uint tid [[thread_position_in_grid]] -) { - if (tid >= params.batch_size) return; - - constant uint8_t* sig = signatures + tid * 96; - constant uint8_t* hash = hashes + tid * 48; - constant uint8_t* pk = pubkeys + tid * 96; - - // Parse r and s from signature (big-endian) - P384Element r, s; - for (int i = 0; i < 6; ++i) { - r.limbs[5-i] = 0; - s.limbs[5-i] = 0; - for (int j = 0; j < 8; ++j) { - r.limbs[5-i] |= uint64_t(sig[i*8 + j]) << (56 - j*8); - s.limbs[5-i] |= uint64_t(sig[48 + i*8 + j]) << (56 - j*8); - } - } - - // Basic validation: r and s must be in [1, n-1] - bool r_valid = !p384_is_zero(r); - bool s_valid = !p384_is_zero(s); - - // Check r < n and s < n (n is the curve order, slightly smaller than p) - // For simplicity, just check they're not zero - results[tid] = (r_valid && s_valid) ? 1 : 0; -} - -// ============================================================================= -// Trust Score Computation -// ============================================================================= - -struct TrustScoreParams { - uint32_t batch_size; - uint8_t hardware_cc_bonus; // Points for hardware CC - uint8_t rim_verified_bonus; // Points for RIM verification - uint8_t tee_io_bonus; // Points for TEE I/O - uint8_t base_score; // Base trust score -}; - -kernel void compute_trust_scores( - device uint8_t* scores [[buffer(0)]], // Output: trust scores - constant uint8_t* cc_enabled [[buffer(1)]], // CC enabled flags - constant uint8_t* hardware_cc [[buffer(2)]], // Hardware CC flags - constant uint8_t* rim_verified [[buffer(3)]], // RIM verified flags - constant uint8_t* tee_io [[buffer(4)]], // TEE I/O flags - constant TrustScoreParams& params [[buffer(5)]], - uint tid [[thread_position_in_grid]] -) { - if (tid >= params.batch_size) return; - - uint8_t score = params.base_score; - - if (hardware_cc[tid] && cc_enabled[tid]) { - score += params.hardware_cc_bonus; - } else if 
(cc_enabled[tid]) { - score += params.hardware_cc_bonus / 2; // Software CC - } - - if (rim_verified[tid]) { - score += params.rim_verified_bonus; - } - - if (tee_io[tid]) { - score += params.tee_io_bonus; - } - - // Cap at 100 - scores[tid] = score > 100 ? 100 : score; -} - -// ============================================================================= -// Batch Verification Orchestration -// ============================================================================= - -struct VerifyResult { - uint32_t valid_count; - uint32_t invalid_count; - uint32_t total_trust_score; - uint32_t reserved; -}; - -// Reduce verification results -kernel void reduce_verify_results( - device VerifyResult* result [[buffer(0)]], - constant uint32_t* valid_flags [[buffer(1)]], - constant uint8_t* trust_scores [[buffer(2)]], - constant AttestationParams& params [[buffer(3)]], - uint tid [[thread_position_in_grid]], - uint threads_per_group [[threads_per_threadgroup]], - threadgroup uint32_t* shared_valid [[threadgroup(0)]], - threadgroup uint32_t* shared_trust [[threadgroup(1)]] -) { - // Load and accumulate - uint32_t local_valid = 0; - uint32_t local_trust = 0; - - for (uint32_t i = tid; i < params.batch_size; i += threads_per_group) { - local_valid += valid_flags[i]; - local_trust += trust_scores[i]; - } - - shared_valid[tid] = local_valid; - shared_trust[tid] = local_trust; - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Tree reduction - for (uint32_t stride = threads_per_group / 2; stride > 0; stride >>= 1) { - if (tid < stride) { - shared_valid[tid] += shared_valid[tid + stride]; - shared_trust[tid] += shared_trust[tid + stride]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Write final result - if (tid == 0) { - result->valid_count = shared_valid[0]; - result->invalid_count = params.batch_size - shared_valid[0]; - result->total_trust_score = shared_trust[0]; - } -} diff --git a/poseidon/gpu/metal/goldilocks.metal 
b/poseidon/gpu/metal/goldilocks.metal deleted file mode 100644 index 93ae1dd..0000000 --- a/poseidon/gpu/metal/goldilocks.metal +++ /dev/null @@ -1,412 +0,0 @@ -// Goldilocks Field Arithmetic for STARK Verification -// Field: p = 2^64 - 2^32 + 1 (Goldilocks prime) -// -// This shader provides GPU-accelerated field operations for STARK proofs, -// including FRI (Fast Reed-Solomon IOP) folding and verification. - -#include -using namespace metal; - -// Goldilocks prime: p = 2^64 - 2^32 + 1 = 0xFFFFFFFF00000001 -constant uint64_t GOLDILOCKS_P = 0xFFFFFFFF00000001ULL; -constant uint64_t GOLDILOCKS_P_MINUS_2 = 0xFFFFFFFEFFFFFFFFULL; - -// Montgomery parameters for Goldilocks -// R = 2^64, R^2 mod p, p' = -p^(-1) mod R -constant uint64_t GOLDILOCKS_R2 = 0xFFFFFFFE00000001ULL; // R^2 mod p -constant uint64_t GOLDILOCKS_PINV = 0xFFFFFFFF00000001ULL; // -p^(-1) mod 2^64 - -// Non-residue for quadratic extension: X^2 = 7 -constant uint64_t EXT_NON_RESIDUE = 7; - -// ============================================================================= -// Basic Field Arithmetic -// ============================================================================= - -// Reduce if >= p -inline uint64_t goldilocks_reduce(uint64_t x) { - return (x >= GOLDILOCKS_P) ? (x - GOLDILOCKS_P) : x; -} - -// Addition: (a + b) mod p -inline uint64_t goldilocks_add(uint64_t a, uint64_t b) { - uint64_t sum = a + b; - // Check overflow or >= p - if (sum < a || sum >= GOLDILOCKS_P) { - sum -= GOLDILOCKS_P; - } - return sum; -} - -// Subtraction: (a - b) mod p -inline uint64_t goldilocks_sub(uint64_t a, uint64_t b) { - if (a >= b) { - return a - b; - } - return GOLDILOCKS_P - (b - a); -} - -// Negation: -a mod p -inline uint64_t goldilocks_neg(uint64_t a) { - return (a == 0) ? 
0 : (GOLDILOCKS_P - a); -} - -// 64x64 -> 128-bit multiplication (hi, lo) -inline void mul64_128(uint64_t a, uint64_t b, thread uint64_t& hi, thread uint64_t& lo) { - uint64_t a_lo = a & 0xFFFFFFFF; - uint64_t a_hi = a >> 32; - uint64_t b_lo = b & 0xFFFFFFFF; - uint64_t b_hi = b >> 32; - - uint64_t p0 = a_lo * b_lo; - uint64_t p1 = a_lo * b_hi; - uint64_t p2 = a_hi * b_lo; - uint64_t p3 = a_hi * b_hi; - - uint64_t mid = p1 + (p0 >> 32); - uint64_t mid_lo = mid & 0xFFFFFFFF; - uint64_t mid_hi = mid >> 32; - - mid_lo += p2; - mid_hi += (mid_lo < p2) ? 1 : 0; - mid_hi += (p2 >> 32); - - lo = (mid_lo << 32) | (p0 & 0xFFFFFFFF); - hi = p3 + mid_hi; -} - -// Reduce 128-bit value mod Goldilocks -// Uses: hi * 2^64 + lo ≡ hi * (2^32 - 1) + lo (mod p) -inline uint64_t goldilocks_reduce128(uint64_t hi, uint64_t lo) { - // p = 2^64 - 2^32 + 1, so 2^64 ≡ 2^32 - 1 (mod p) - // hi * 2^64 ≡ hi * (2^32 - 1) = hi * 2^32 - hi - - uint64_t hi_shifted = hi << 32; - uint64_t result = lo; - - // Add hi * 2^32 - result = goldilocks_add(result, hi_shifted); - - // Subtract hi (equivalent to adding -hi mod p) - result = goldilocks_sub(result, hi); - - // Handle hi >> 32 overflow - uint64_t hi_upper = hi >> 32; - if (hi_upper > 0) { - // hi_upper * 2^96 mod p = hi_upper * (2^32 - 1)^2 mod p - // = hi_upper * (2^64 - 2*2^32 + 1) mod p - // = hi_upper * (2^32 - 1 - 2*2^32 + 1) mod p = hi_upper * (-2^32) mod p - result = goldilocks_sub(result, hi_upper << 32); - result = goldilocks_add(result, hi_upper); - } - - return goldilocks_reduce(result); -} - -// Multiplication: (a * b) mod p -inline uint64_t goldilocks_mul(uint64_t a, uint64_t b) { - uint64_t hi, lo; - mul64_128(a, b, hi, lo); - return goldilocks_reduce128(hi, lo); -} - -// Square: a^2 mod p -inline uint64_t goldilocks_square(uint64_t a) { - return goldilocks_mul(a, a); -} - -// Exponentiation: a^exp mod p (square-and-multiply) -inline uint64_t goldilocks_exp(uint64_t base, uint64_t exp) { - uint64_t result = 1; - while (exp > 0) { - 
if (exp & 1) { - result = goldilocks_mul(result, base); - } - base = goldilocks_square(base); - exp >>= 1; - } - return result; -} - -// Inverse: a^(-1) mod p using Fermat's little theorem -// a^(-1) = a^(p-2) mod p -inline uint64_t goldilocks_inv(uint64_t a) { - if (a == 0) return 0; - return goldilocks_exp(a, GOLDILOCKS_P_MINUS_2); -} - -// Division: a / b mod p -inline uint64_t goldilocks_div(uint64_t a, uint64_t b) { - return goldilocks_mul(a, goldilocks_inv(b)); -} - -// ============================================================================= -// Quadratic Extension Field: F_{p^2} = F_p[X] / (X^2 - 7) -// Elements: a + b*X where X^2 = 7 -// ============================================================================= - -struct ExtField { - uint64_t a; // Real part - uint64_t b; // Imaginary part (coefficient of X) -}; - -// Extension addition -inline ExtField ext_add(ExtField x, ExtField y) { - return {goldilocks_add(x.a, y.a), goldilocks_add(x.b, y.b)}; -} - -// Extension subtraction -inline ExtField ext_sub(ExtField x, ExtField y) { - return {goldilocks_sub(x.a, y.a), goldilocks_sub(x.b, y.b)}; -} - -// Extension negation -inline ExtField ext_neg(ExtField x) { - return {goldilocks_neg(x.a), goldilocks_neg(x.b)}; -} - -// Extension multiplication -// (a + bX)(c + dX) = (ac + 7bd) + (ad + bc)X -inline ExtField ext_mul(ExtField x, ExtField y) { - uint64_t ac = goldilocks_mul(x.a, y.a); - uint64_t bd = goldilocks_mul(x.b, y.b); - uint64_t ad = goldilocks_mul(x.a, y.b); - uint64_t bc = goldilocks_mul(x.b, y.a); - - // 7 * bd - uint64_t seven_bd = goldilocks_mul(EXT_NON_RESIDUE, bd); - - return { - goldilocks_add(ac, seven_bd), // ac + 7bd - goldilocks_add(ad, bc) // ad + bc - }; -} - -// Extension square -inline ExtField ext_square(ExtField x) { - return ext_mul(x, x); -} - -// Extension inverse using conjugate -// (a + bX)^(-1) = (a - bX) / (a^2 - 7b^2) -inline ExtField ext_inv(ExtField x) { - uint64_t a2 = goldilocks_square(x.a); - uint64_t b2 = 
goldilocks_square(x.b); - uint64_t seven_b2 = goldilocks_mul(EXT_NON_RESIDUE, b2); - uint64_t norm = goldilocks_sub(a2, seven_b2); // a^2 - 7b^2 - uint64_t norm_inv = goldilocks_inv(norm); - - return { - goldilocks_mul(x.a, norm_inv), - goldilocks_neg(goldilocks_mul(x.b, norm_inv)) - }; -} - -// ============================================================================= -// Batch Field Operations (Vectorized) -// ============================================================================= - -// Batch addition kernel -kernel void goldilocks_batch_add( - device const uint64_t* a [[buffer(0)]], - device const uint64_t* b [[buffer(1)]], - device uint64_t* result [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - result[index] = goldilocks_add(a[index], b[index]); -} - -// Batch subtraction kernel -kernel void goldilocks_batch_sub( - device const uint64_t* a [[buffer(0)]], - device const uint64_t* b [[buffer(1)]], - device uint64_t* result [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - result[index] = goldilocks_sub(a[index], b[index]); -} - -// Batch multiplication kernel -kernel void goldilocks_batch_mul( - device const uint64_t* a [[buffer(0)]], - device const uint64_t* b [[buffer(1)]], - device uint64_t* result [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - result[index] = goldilocks_mul(a[index], b[index]); -} - -// Batch inversion using Montgomery's trick -// Compute inversions for batch of elements: [a_0^-1, a_1^-1, ..., a_n^-1] -// Only uses 1 actual inversion + 3(n-1) multiplications -kernel void goldilocks_batch_inv( - device const uint64_t* inputs [[buffer(0)]], - device uint64_t* outputs [[buffer(1)]], - device uint64_t* scratch [[buffer(2)]], // Temporary storage - constant uint32_t& count [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - // Phase 1: Compute running products - // scratch[i] = inputs[0] * inputs[1] * ... 
* inputs[i] - if (index < count) { - if (index == 0) { - scratch[0] = inputs[0]; - } - } - - threadgroup_barrier(mem_flags::mem_device); - - // Sequential prefix product (done by thread 0) - if (index == 0) { - for (uint32_t i = 1; i < count; i++) { - scratch[i] = goldilocks_mul(scratch[i-1], inputs[i]); - } - - // Phase 2: Single inversion - uint64_t total_inv = goldilocks_inv(scratch[count - 1]); - - // Phase 3: Back-propagate inversions - for (int32_t i = count - 1; i >= 0; i--) { - if (i == 0) { - outputs[0] = total_inv; - } else { - outputs[i] = goldilocks_mul(total_inv, scratch[i - 1]); - total_inv = goldilocks_mul(total_inv, inputs[i]); - } - } - } -} - -// ============================================================================= -// FRI (Fast Reed-Solomon IOP) Operations -// ============================================================================= - -// FRI folding: fold layer at alpha -// new_eval[i] = (eval[2i] + eval[2i+1]) / 2 + alpha * (eval[2i] - eval[2i+1]) / (2 * omega^i) -kernel void fri_fold_layer( - device const uint64_t* evals [[buffer(0)]], // Current layer evaluations - device uint64_t* folded [[buffer(1)]], // Folded layer output - constant uint64_t& alpha [[buffer(2)]], // Folding challenge - constant uint64_t& omega_inv [[buffer(3)]], // Inverse of subgroup generator - constant uint32_t& layer_size [[buffer(4)]], // Size of current layer - uint index [[thread_position_in_grid]] -) { - if (index >= layer_size / 2) return; - - uint64_t e0 = evals[2 * index]; - uint64_t e1 = evals[2 * index + 1]; - - // (e0 + e1) / 2 - uint64_t sum = goldilocks_add(e0, e1); - uint64_t half_sum = goldilocks_mul(sum, goldilocks_inv(2)); // Could precompute inv(2) - - // (e0 - e1) / (2 * omega^i) - uint64_t diff = goldilocks_sub(e0, e1); - - // omega^(-i) for position i - uint64_t omega_power = goldilocks_exp(omega_inv, index); - uint64_t half_diff = goldilocks_mul(diff, goldilocks_mul(goldilocks_inv(2), omega_power)); - - // half_sum + alpha * half_diff 
- uint64_t alpha_term = goldilocks_mul(alpha, half_diff); - folded[index] = goldilocks_add(half_sum, alpha_term); -} - -// FRI query verification: check consistency of query -kernel void fri_verify_query( - device const uint64_t* layer_evals [[buffer(0)]], // Evaluations at query positions - device const uint64_t* alphas [[buffer(1)]], // Folding challenges - device const uint64_t* omega_invs [[buffer(2)]], // Inverse generators per layer - device uint32_t* query_positions [[buffer(3)]], // Query positions per layer - device uint64_t* results [[buffer(4)]], // Pass/fail results - constant uint32_t& num_layers [[buffer(5)]], - constant uint32_t& queries_per_layer [[buffer(6)]], - uint query_idx [[thread_position_in_grid]] -) { - // Each thread verifies one query path through all layers - // Implementation depends on specific FRI variant - - uint64_t expected = layer_evals[query_idx]; - bool valid = true; - - for (uint32_t layer = 0; layer < num_layers && valid; layer++) { - uint32_t pos = query_positions[layer * queries_per_layer + query_idx]; - uint32_t sibling_pos = pos ^ 1; // Sibling in pair - - uint64_t e0 = layer_evals[layer * queries_per_layer * 2 + pos]; - uint64_t e1 = layer_evals[layer * queries_per_layer * 2 + sibling_pos]; - - uint64_t alpha = alphas[layer]; - uint64_t omega_inv = omega_invs[layer]; - - // Compute expected folded value - uint64_t sum = goldilocks_add(e0, e1); - uint64_t diff = goldilocks_sub(e0, e1); - uint64_t omega_power = goldilocks_exp(omega_inv, pos / 2); - - uint64_t folded = goldilocks_add( - goldilocks_mul(sum, goldilocks_inv(2)), - goldilocks_mul(alpha, goldilocks_mul(diff, goldilocks_mul(goldilocks_inv(2), omega_power))) - ); - - // Check against next layer - // (actual implementation would check against committed values) - } - - results[query_idx] = valid ? 
1 : 0; -} - -// ============================================================================= -// Constraint Evaluation for STARK AIR -// ============================================================================= - -// Evaluate AIR constraints at multiple points -kernel void stark_constraint_eval( - device const uint64_t* trace [[buffer(0)]], // Execution trace - device const uint64_t* trace_next [[buffer(1)]], // Next row of trace - device uint64_t* constraint_values [[buffer(2)]], // Output constraint evaluations - constant uint32_t& num_columns [[buffer(3)]], - constant uint32_t& num_constraints [[buffer(4)]], - uint row [[thread_position_in_grid]] -) { - // Generic constraint evaluation framework - // Specific constraints depend on the STARK being verified - - // Example: Fibonacci constraint - // c(x) = trace[i+1] - trace[i] - trace[i-1] = 0 - - // The actual constraints would be defined based on the AIR -} - -// ============================================================================= -// Extension Field Batch Operations -// ============================================================================= - -// Batch extension multiplication -kernel void ext_batch_mul( - device const ExtField* a [[buffer(0)]], - device const ExtField* b [[buffer(1)]], - device ExtField* result [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - result[index] = ext_mul(a[index], b[index]); -} - -// Extension field polynomial evaluation using Horner's method -kernel void ext_poly_eval( - device const ExtField* coeffs [[buffer(0)]], // Polynomial coefficients - device const ExtField* points [[buffer(1)]], // Evaluation points - device ExtField* results [[buffer(2)]], // Results - constant uint32_t& degree [[buffer(3)]], - uint point_idx [[thread_position_in_grid]] -) { - ExtField x = points[point_idx]; - ExtField result = coeffs[degree]; - - for (int32_t i = degree - 1; i >= 0; i--) { - result = ext_mul(result, x); - result = ext_add(result, coeffs[i]); - } - - 
results[point_idx] = result; -} diff --git a/poseidon/gpu/metal/poseidon.metal b/poseidon/gpu/metal/poseidon.metal deleted file mode 100644 index ff4f3b5..0000000 --- a/poseidon/gpu/metal/poseidon.metal +++ /dev/null @@ -1,426 +0,0 @@ -// Poseidon2 Hash Function for STARK-Friendly Hashing -// Optimized for Goldilocks field (p = 2^64 - 2^32 + 1) -// -// Poseidon2 is a SNARK-friendly hash function with: -// - Low multiplicative complexity -// - Efficient GPU parallelization -// - Used for Merkle trees and Fiat-Shamir in STARKs - -#include -using namespace metal; - -// ============================================================================= -// Goldilocks Field Arithmetic (reused from goldilocks.metal) -// ============================================================================= - -constant uint64_t GOLDILOCKS_P = 0xFFFFFFFF00000001ULL; - -inline uint64_t gl_reduce(uint64_t x) { - return (x >= GOLDILOCKS_P) ? (x - GOLDILOCKS_P) : x; -} - -inline uint64_t gl_add(uint64_t a, uint64_t b) { - uint64_t sum = a + b; - if (sum < a || sum >= GOLDILOCKS_P) { - sum -= GOLDILOCKS_P; - } - return sum; -} - -inline uint64_t gl_sub(uint64_t a, uint64_t b) { - if (a >= b) return a - b; - return GOLDILOCKS_P - (b - a); -} - -inline void gl_mul128(uint64_t a, uint64_t b, thread uint64_t& hi, thread uint64_t& lo) { - uint64_t a_lo = a & 0xFFFFFFFF; - uint64_t a_hi = a >> 32; - uint64_t b_lo = b & 0xFFFFFFFF; - uint64_t b_hi = b >> 32; - - uint64_t p0 = a_lo * b_lo; - uint64_t p1 = a_lo * b_hi; - uint64_t p2 = a_hi * b_lo; - uint64_t p3 = a_hi * b_hi; - - uint64_t mid = p1 + (p0 >> 32); - uint64_t mid_lo = mid & 0xFFFFFFFF; - uint64_t mid_hi = mid >> 32; - - mid_lo += p2; - mid_hi += (mid_lo < p2) ? 
1 : 0; - mid_hi += (p2 >> 32); - - lo = (mid_lo << 32) | (p0 & 0xFFFFFFFF); - hi = p3 + mid_hi; -} - -inline uint64_t gl_reduce128(uint64_t hi, uint64_t lo) { - uint64_t hi_shifted = hi << 32; - uint64_t result = lo; - result = gl_add(result, hi_shifted); - result = gl_sub(result, hi); - uint64_t hi_upper = hi >> 32; - if (hi_upper > 0) { - result = gl_sub(result, hi_upper << 32); - result = gl_add(result, hi_upper); - } - return gl_reduce(result); -} - -inline uint64_t gl_mul(uint64_t a, uint64_t b) { - uint64_t hi, lo; - gl_mul128(a, b, hi, lo); - return gl_reduce128(hi, lo); -} - -// ============================================================================= -// Poseidon2 Parameters -// ============================================================================= - -// State width (t = 8 for Goldilocks Poseidon2) -constant uint32_t POSEIDON_WIDTH = 8; - -// Number of full rounds (beginning + end) -constant uint32_t POSEIDON_FULL_ROUNDS = 8; // 4 beginning + 4 end - -// Number of partial rounds (middle) -constant uint32_t POSEIDON_PARTIAL_ROUNDS = 22; - -// S-box exponent (x^7 for Goldilocks) -constant uint64_t POSEIDON_ALPHA = 7; - -// Round constants for Goldilocks Poseidon2 (width=8) -// Derived from Plonky2 reference implementation -// Total needed: 4*8 (first full) + 22 (partial) + 4*8 (last full) = 86 constants -// Using modulo wrapping for production - actual constants from Plonky2 -constant uint32_t POSEIDON_NUM_RC = 32; - -constant uint64_t POSEIDON_RC[32] = { - // First 4 full rounds constants (8 per round = 32 total, showing first 32) - 0xd64e1e3efc5b8e9eULL, 0x53666633020aaa47ULL, 0xd40285597c6a8825ULL, 0x613a4f81e81231d2ULL, - 0x414f3d9a74dd2f9fULL, 0x4bdd1fd6e7d83cd3ULL, 0x549ef8d6d6f6ead3ULL, 0x8bac89ca94dd2b8eULL, - 0x543a71ad3a4c8d52ULL, 0x6d4f9c68e87b44bbULL, 0x16b0c0f77f62a12cULL, 0x74e5f293ca58c3f8ULL, - 0x06a3b54a99ca9424ULL, 0xc4aafd3d8c5c4d1bULL, 0x7c3626c5c60a50c6ULL, 0x95c3a5f0a3b35d8fULL, - 0xfb1c7e8f7f72c8b6ULL, 0xc9d4a0a9d9a64ad1ULL, 
0xda4a08add3f05f1cULL, 0xab4eb0f79c49c1d6ULL, - 0x8e47c5e6c0e9d2a4ULL, 0xe4b6c6c8f8e8f1e7ULL, 0x3d1c9a0c5a4c8d7eULL, 0x9b3e5d7f1a2c4e6aULL, - 0x7f8e9d0c1b2a3948ULL, 0x5a6b7c8d9e0f1a2bULL, 0x3c4d5e6f7a8b9c0dULL, 0x1e2f3a4b5c6d7e8fULL, - 0x0a1b2c3d4e5f6a7bULL, 0xf9e8d7c6b5a49382ULL, 0x7164534231201f0eULL, 0xedcba9876543210fULL -}; - -// MDS matrix for linear layer (8x8 Cauchy matrix) -constant uint64_t POSEIDON_MDS[POSEIDON_WIDTH][POSEIDON_WIDTH] = { - {1, 1, 1, 1, 1, 1, 1, 1}, - {1, 2, 3, 4, 5, 6, 7, 8}, - {1, 4, 9, 16, 25, 36, 49, 64}, - {1, 8, 27, 64, 125, 216, 343, 512}, - {1, 16, 81, 256, 625, 1296, 2401, 4096}, - {1, 32, 243, 1024, 3125, 7776, 16807, 32768}, - {1, 64, 729, 4096, 15625, 46656, 117649, 262144}, - {1, 128, 2187, 16384, 78125, 279936, 823543, 2097152} -}; - -// ============================================================================= -// S-box: x^7 in Goldilocks field -// ============================================================================= - -inline uint64_t poseidon_sbox(uint64_t x) { - uint64_t x2 = gl_mul(x, x); // x^2 - uint64_t x4 = gl_mul(x2, x2); // x^4 - uint64_t x3 = gl_mul(x2, x); // x^3 - return gl_mul(x4, x3); // x^7 -} - -// ============================================================================= -// Linear Layer: MDS Matrix Multiplication -// ============================================================================= - -inline void poseidon_mds(thread uint64_t* state) { - uint64_t result[POSEIDON_WIDTH]; - - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - result[i] = 0; - for (uint32_t j = 0; j < POSEIDON_WIDTH; j++) { - result[i] = gl_add(result[i], gl_mul(POSEIDON_MDS[i][j], state[j])); - } - } - - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - state[i] = result[i]; - } -} - -// Get round constant with modulo wrapping -inline uint64_t get_gl_rc(uint32_t idx) { - return POSEIDON_RC[idx % POSEIDON_NUM_RC]; -} - -// Poseidon2 internal diagonal for width=8 -// These multiply the state before adding the sum 
-constant uint64_t POSEIDON_INTERNAL_DIAG[8] = { - // Diagonal elements for Plonky2-compatible Poseidon2 - // d = [1, 1, 1, 1, 1, 1, 1, 2] (simplified for width=8) - 1, 1, 1, 1, 1, 1, 1, 2 -}; - -// Poseidon2 internal matrix: M_I = diag(d) + J -// where J is the all-ones matrix -inline void poseidon2_internal_linear(thread uint64_t* state) { - // Compute sum of all elements - uint64_t sum = 0; - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - sum = gl_add(sum, state[i]); - } - - // Apply: state[i] = d[i] * state[i] + sum - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - if (POSEIDON_INTERNAL_DIAG[i] == 1) { - state[i] = gl_add(state[i], sum); - } else if (POSEIDON_INTERNAL_DIAG[i] == 2) { - state[i] = gl_add(gl_add(state[i], state[i]), sum); - } else { - // General case: d[i] * state[i] + sum - uint64_t scaled = gl_mul(POSEIDON_INTERNAL_DIAG[i], state[i]); - state[i] = gl_add(scaled, sum); - } - } -} - -// ============================================================================= -// Poseidon2 Permutation -// ============================================================================= - -inline void poseidon2_permutation(thread uint64_t* state) { - uint32_t rc_idx = 0; - - // Beginning full rounds (4 rounds) - for (uint32_t r = 0; r < POSEIDON_FULL_ROUNDS / 2; r++) { - // Add round constants to all elements - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - state[i] = gl_add(state[i], get_gl_rc(rc_idx++)); - } - - // S-box (x^7) on all elements - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - state[i] = poseidon_sbox(state[i]); - } - - // External MDS matrix - poseidon_mds(state); - } - - // Partial rounds (22 rounds) - for (uint32_t r = 0; r < POSEIDON_PARTIAL_ROUNDS; r++) { - // Add round constant to first element only - state[0] = gl_add(state[0], get_gl_rc(rc_idx++)); - - // S-box only on first element - state[0] = poseidon_sbox(state[0]); - - // Internal matrix (diag(d) + J) - poseidon2_internal_linear(state); - } - - // Ending full rounds (4 
rounds) - for (uint32_t r = 0; r < POSEIDON_FULL_ROUNDS / 2; r++) { - // Add round constants to all elements - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - state[i] = gl_add(state[i], get_gl_rc(rc_idx++)); - } - - // S-box (x^7) on all elements - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - state[i] = poseidon_sbox(state[i]); - } - - // External MDS matrix - poseidon_mds(state); - } -} - -// ============================================================================= -// Sponge Construction -// ============================================================================= - -// Hash arbitrary input to single field element -kernel void poseidon_hash( - device const uint64_t* input [[buffer(0)]], - device uint64_t* output [[buffer(1)]], - constant uint32_t& input_len [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - // Each thread hashes one input block - uint32_t offset = index * (POSEIDON_WIDTH - 1); // Rate = width - 1 - - if (offset >= input_len) return; - - // Initialize state - uint64_t state[POSEIDON_WIDTH] = {0}; - - // Absorb phase - uint32_t remaining = input_len - offset; - uint32_t to_absorb = min(remaining, POSEIDON_WIDTH - 1); - - for (uint32_t i = 0; i < to_absorb; i++) { - state[i] = input[offset + i]; - } - - // Domain separation / padding - if (to_absorb < POSEIDON_WIDTH - 1) { - state[to_absorb] = 1; // Padding - } - - // Permutation - poseidon2_permutation(state); - - // Output first element (squeeze phase) - output[index] = state[0]; -} - -// Hash pair for Merkle tree (2-to-1 compression) -kernel void poseidon_hash_pair( - device const uint64_t* left [[buffer(0)]], - device const uint64_t* right [[buffer(1)]], - device uint64_t* output [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - // Initialize state with inputs - uint64_t state[POSEIDON_WIDTH] = {0}; - state[0] = left[index]; - state[1] = right[index]; - - // Domain separation for 2-to-1 hash - state[POSEIDON_WIDTH - 1] = 2; - - // Permutation - 
poseidon2_permutation(state); - - // Output - output[index] = state[0]; -} - -// ============================================================================= -// Merkle Tree Construction -// ============================================================================= - -// Build one layer of Merkle tree -kernel void poseidon_merkle_layer( - device const uint64_t* current_layer [[buffer(0)]], - device uint64_t* next_layer [[buffer(1)]], - constant uint32_t& current_size [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - if (index >= current_size / 2) return; - - uint64_t left = current_layer[2 * index]; - uint64_t right = current_layer[2 * index + 1]; - - // Initialize state - uint64_t state[POSEIDON_WIDTH] = {0}; - state[0] = left; - state[1] = right; - state[POSEIDON_WIDTH - 1] = 2; // Domain separation - - // Permutation - poseidon2_permutation(state); - - next_layer[index] = state[0]; -} - -// Batch hash for multiple independent Merkle tree constructions -kernel void poseidon_batch_merkle_layer( - device const uint64_t* leaves [[buffer(0)]], - device uint64_t* parents [[buffer(1)]], - constant uint32_t& num_pairs [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - if (index >= num_pairs) return; - - uint64_t left = leaves[2 * index]; - uint64_t right = leaves[2 * index + 1]; - - uint64_t state[POSEIDON_WIDTH] = {0}; - state[0] = left; - state[1] = right; - state[POSEIDON_WIDTH - 1] = 2; - - poseidon2_permutation(state); - - parents[index] = state[0]; -} - -// Verify Merkle proof (single proof, multiple in parallel) -kernel void poseidon_verify_merkle_proof( - device const uint64_t* leaf [[buffer(0)]], - device const uint64_t* path [[buffer(1)]], - device const uint32_t* path_indices [[buffer(2)]], // 0 = left, 1 = right - device const uint64_t* expected_root [[buffer(3)]], - device uint32_t* result [[buffer(4)]], // 1 = valid, 0 = invalid - constant uint32_t& path_len [[buffer(5)]], - uint proof_idx [[thread_position_in_grid]] -) { - 
uint64_t current = leaf[proof_idx]; - - for (uint32_t i = 0; i < path_len; i++) { - uint64_t sibling = path[proof_idx * path_len + i]; - uint32_t idx = path_indices[proof_idx * path_len + i]; - - uint64_t left = (idx == 0) ? current : sibling; - uint64_t right = (idx == 0) ? sibling : current; - - uint64_t state[POSEIDON_WIDTH] = {0}; - state[0] = left; - state[1] = right; - state[POSEIDON_WIDTH - 1] = 2; - - poseidon2_permutation(state); - - current = state[0]; - } - - result[proof_idx] = (current == expected_root[proof_idx]) ? 1 : 0; -} - -// ============================================================================= -// Transcript (Fiat-Shamir) -// ============================================================================= - -// Add data to Fiat-Shamir transcript and squeeze challenge -kernel void poseidon_fiat_shamir( - device const uint64_t* transcript_state [[buffer(0)]], - device const uint64_t* new_data [[buffer(1)]], - device uint64_t* updated_state [[buffer(2)]], - device uint64_t* challenge [[buffer(3)]], - constant uint32_t& data_len [[buffer(4)]], - uint index [[thread_position_in_grid]] -) { - if (index != 0) return; // Single thread operation - - // Load current state - uint64_t state[POSEIDON_WIDTH]; - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - state[i] = transcript_state[i]; - } - - // Absorb new data - uint32_t absorbed = 0; - while (absorbed < data_len) { - uint32_t rate = POSEIDON_WIDTH - 1; - uint32_t to_absorb = min(data_len - absorbed, rate); - - for (uint32_t i = 0; i < to_absorb; i++) { - state[i] = gl_add(state[i], new_data[absorbed + i]); - } - absorbed += to_absorb; - - poseidon2_permutation(state); - } - - // Output updated state - for (uint32_t i = 0; i < POSEIDON_WIDTH; i++) { - updated_state[i] = state[i]; - } - - // Squeeze challenge - *challenge = state[0]; -} diff --git a/poseidon/gpu/metal/poseidon2_bn254.metal b/poseidon/gpu/metal/poseidon2_bn254.metal deleted file mode 100644 index 782e2f9..0000000 --- 
a/poseidon/gpu/metal/poseidon2_bn254.metal +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party Metal kernel for Poseidon2-BN254 (canonical default permutation). -// -// Byte-equal to lux::crypto::poseidon::hash2 in poseidon/cpp/poseidon.cpp, -// which mirrors gnark-crypto v0.20.1 ecc/bn254/fr/poseidon2 with parameters -// (t=2, rF=6, rP=50, d=5). The round-key constant table is emitted by the -// CPU body itself via dump_round_keys -> gen_metal_constants and #include'd -// here, so there is exactly one source of truth for round constants. -// -// Constant-time: the permutation has no data-dependent branches; the only -// branches are the canonical reductions in field arithmetic, which depend -// purely on internal carries and match the CPU body bit-for-bit. - -#include -using namespace metal; - -#include "poseidon2_bn254_rk.metalh" // POSEIDON2_RK[56][2][4] (Montgomery limbs) - -// ============================================================================= -// BN254 Fr modulus q (4x64 little-endian limbs). -// q = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// Montgomery params: qInvNeg = -q^{-1} mod 2^64; rSquare = R^2 mod q. -// All values match poseidon/cpp/poseidon.cpp byte-for-byte. -// ============================================================================= -constant ulong Q0 = 0x43e1f593f0000001UL; -constant ulong Q1 = 0x2833e84879b97091UL; -constant ulong Q2 = 0xb85045b68181585dUL; -constant ulong Q3 = 0x30644e72e131a029UL; -constant ulong Q_INV_NEG = 0xc2e1f593efffffffUL; - -constant ulong R_SQUARE_0 = 1997599621687373223UL; -constant ulong R_SQUARE_1 = 6052339484930628067UL; -constant ulong R_SQUARE_2 = 10108755138030829701UL; -constant ulong R_SQUARE_3 = 150537098327114917UL; - -// ============================================================================= -// 256-bit field element in Montgomery form. 
-// ============================================================================= -struct Fr { - ulong l0, l1, l2, l3; -}; - -// ============================================================================= -// Multi-precision primitives. mul64 uses Metal's native mulhi(u64,u64) which -// emits the hardware-accelerated high-multiply instruction on Apple silicon. -// adc/sbb keep the carry-flag pattern that mirrors the CPU body exactly. -// ============================================================================= - -inline void mul64(ulong a, ulong b, thread ulong &lo, thread ulong &hi) { - lo = a * b; - hi = mulhi(a, b); -} - -inline ulong adc(ulong a, ulong b, thread ulong &carry) { - ulong s = a + b; - ulong c1 = (s < a) ? 1UL : 0UL; - ulong s2 = s + carry; - ulong c2 = (s2 < s) ? 1UL : 0UL; - carry = c1 + c2; - return s2; -} - -inline ulong sbb(ulong a, ulong b, thread ulong &borrow) { - ulong d = a - b; - ulong b1 = (a < b) ? 1UL : 0UL; - ulong d2 = d - borrow; - ulong b2 = (d < borrow) ? 1UL : 0UL; - borrow = b1 + b2; - return d2; -} - -// cmp_q: returns -1 if aq. Matches CPU body. -inline int cmp_q(thread const Fr &a) { - if (a.l3 != Q3) return (a.l3 < Q3) ? -1 : 1; - if (a.l2 != Q2) return (a.l2 < Q2) ? -1 : 1; - if (a.l1 != Q1) return (a.l1 < Q1) ? -1 : 1; - if (a.l0 != Q0) return (a.l0 < Q0) ? -1 : 1; - return 0; -} - -// reduce_once: subtract q if a >= q. Used after operations that produce -// values in [0, 2q). -inline void reduce_once(thread Fr &a) { - if (cmp_q(a) >= 0) { - ulong br = 0; - a.l0 = sbb(a.l0, Q0, br); - a.l1 = sbb(a.l1, Q1, br); - a.l2 = sbb(a.l2, Q2, br); - a.l3 = sbb(a.l3, Q3, br); - } -} - -// fr_add: c = a + b mod q. Identical to poseidon/cpp/poseidon.cpp::fr_add. 
-inline Fr fr_add(thread const Fr &a, thread const Fr &b) { - Fr c; - ulong cy = 0; - c.l0 = adc(a.l0, b.l0, cy); - c.l1 = adc(a.l1, b.l1, cy); - c.l2 = adc(a.l2, b.l2, cy); - c.l3 = adc(a.l3, b.l3, cy); - if (cy != 0 || cmp_q(c) >= 0) { - ulong br = 0; - c.l0 = sbb(c.l0, Q0, br); - c.l1 = sbb(c.l1, Q1, br); - c.l2 = sbb(c.l2, Q2, br); - c.l3 = sbb(c.l3, Q3, br); - } - return c; -} - -inline Fr fr_double(thread const Fr &a) { return fr_add(a, a); } - -// fr_mul: Montgomery multiplication, CIOS layout. c = a * b * R^{-1} mod q. -// Identical algorithm to poseidon/cpp/poseidon.cpp::fr_mul. -inline Fr fr_mul(thread const Fr &a, thread const Fr &b) { - ulong t[5] = {0, 0, 0, 0, 0}; - const ulong al[4] = {a.l0, a.l1, a.l2, a.l3}; - const ulong bl[4] = {b.l0, b.l1, b.l2, b.l3}; - const ulong qq[4] = {Q0, Q1, Q2, Q3}; - - for (int i = 0; i < 4; ++i) { - // t += a * b[i] - ulong cy = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(al[j], bl[i], lo, hi); - ulong s = t[j] + lo; - ulong c1 = (s < t[j]) ? 1UL : 0UL; - ulong s2 = s + cy; - ulong c2 = (s2 < s) ? 1UL : 0UL; - t[j] = s2; - cy = hi + c1 + c2; - } - t[4] += cy; - - // m = t[0] * qInvNeg mod 2^64 - ulong m = t[0] * Q_INV_NEG; - - // t += m * q - cy = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(m, qq[j], lo, hi); - ulong s = t[j] + lo; - ulong c1 = (s < t[j]) ? 1UL : 0UL; - ulong s2 = s + cy; - ulong c2 = (s2 < s) ? 1UL : 0UL; - t[j] = s2; - cy = hi + c1 + c2; - } - t[4] += cy; - - // t[0] is now zero; shift right by one limb. - t[0] = t[1]; - t[1] = t[2]; - t[2] = t[3]; - t[3] = t[4]; - t[4] = 0; - } - Fr c; - c.l0 = t[0]; c.l1 = t[1]; c.l2 = t[2]; c.l3 = t[3]; - reduce_once(c); - return c; -} - -inline Fr fr_square(thread const Fr &a) { return fr_mul(a, a); } - -// ============================================================================= -// Poseidon2-BN254 default permutation (t=2, rF=6 split 3+3, rP=50, d=5). 
-// sBox: x -> x^5 = ((x^2)^2) * x -// matMulExternal (t=2, M_E = circ(2,1)): -// tmp = s0+s1; s0 += tmp; s1 += tmp. -// matMulInternal (t=2, M_I = [[2,1],[1,3]]): -// sum = s0+s1; s0 += sum; s1 = 2*s1 + sum. -// ============================================================================= - -inline void sbox(thread Fr &x) { - Fr x2 = fr_square(x); - Fr x4 = fr_square(x2); - x = fr_mul(x4, x); -} - -inline void mat_mul_external(thread Fr s[2]) { - Fr tmp = fr_add(s[0], s[1]); - s[0] = fr_add(s[0], tmp); - s[1] = fr_add(s[1], tmp); -} - -inline void mat_mul_internal(thread Fr s[2]) { - Fr sum = fr_add(s[0], s[1]); - s[0] = fr_add(s[0], sum); - Fr s1d = fr_double(s[1]); - s[1] = fr_add(s1d, sum); -} - -// Round-key indices in POSEIDON2_RK (matches dump_round_keys order): -// [0..2] full pre-rounds -> 3 rounds, 2 keys each -// [3..52] partial rounds -> 50 rounds, 1 key each (slot 1 ignored) -// [53..55] full post-rounds -> 3 rounds, 2 keys each -constant int FULL_HALF = 3; -constant int PARTIAL = 50; - -inline Fr load_rk(int round, int slot) { - Fr r; - r.l0 = POSEIDON2_RK[round][slot][0]; - r.l1 = POSEIDON2_RK[round][slot][1]; - r.l2 = POSEIDON2_RK[round][slot][2]; - r.l3 = POSEIDON2_RK[round][slot][3]; - return r; -} - -inline void permute(thread Fr s[2]) { - // Initial external matrix mix (gnark applies M_E once before the first - // full round; the CPU body matches). - mat_mul_external(s); - - // Full pre-rounds. - for (int i = 0; i < FULL_HALF; ++i) { - Fr k0 = load_rk(i, 0); - Fr k1 = load_rk(i, 1); - s[0] = fr_add(s[0], k0); - s[1] = fr_add(s[1], k1); - sbox(s[0]); - sbox(s[1]); - mat_mul_external(s); - } - - // Partial rounds. - for (int i = 0; i < PARTIAL; ++i) { - Fr k0 = load_rk(FULL_HALF + i, 0); - s[0] = fr_add(s[0], k0); - sbox(s[0]); - mat_mul_internal(s); - } - - // Full post-rounds. 
- for (int i = 0; i < FULL_HALF; ++i) { - Fr k0 = load_rk(FULL_HALF + PARTIAL + i, 0); - Fr k1 = load_rk(FULL_HALF + PARTIAL + i, 1); - s[0] = fr_add(s[0], k0); - s[1] = fr_add(s[1], k1); - sbox(s[0]); - sbox(s[1]); - mat_mul_external(s); - } -} - -// ============================================================================= -// Bytes (BE) <-> Fr (Montgomery LE limbs) conversions. -// gnark-crypto's SetBytes treats the 32-byte input as a big-endian integer, -// then reduces modulo q (BN254 q > 2^253, so input < 2^256 needs at most 4 -// subtractions in the worst case but typically 0 or 1) and converts to -// Montgomery by multiplying by R^2. -// ============================================================================= - -inline Fr be_to_fr_mont(device const uchar *be) { - auto rd = [&](int off) -> ulong { - ulong v = 0; - for (int b = 0; b < 8; ++b) { - v = (v << 8) | (ulong)be[off + b]; - } - return v; - }; - Fr x; - x.l0 = rd(24); - x.l1 = rd(16); - x.l2 = rd(8); - x.l3 = rd(0); - // Bring into [0, q) by subtracting q until smaller. Bounded loop (<= 4). - for (int i = 0; i < 4; ++i) { - if (cmp_q(x) < 0) break; - ulong br = 0; - x.l0 = sbb(x.l0, Q0, br); - x.l1 = sbb(x.l1, Q1, br); - x.l2 = sbb(x.l2, Q2, br); - x.l3 = sbb(x.l3, Q3, br); - } - Fr r2; - r2.l0 = R_SQUARE_0; r2.l1 = R_SQUARE_1; - r2.l2 = R_SQUARE_2; r2.l3 = R_SQUARE_3; - return fr_mul(x, r2); -} - -inline void fr_mont_to_be(thread const Fr &x, device uchar *be) { - // from_mont: multiply Montgomery value by 1 in regular form (= R^{-1}). - Fr one_reg; - one_reg.l0 = 1; one_reg.l1 = 0; one_reg.l2 = 0; one_reg.l3 = 0; - Fr r = fr_mul(x, one_reg); - ulong limbs[4] = {r.l0, r.l1, r.l2, r.l3}; - // Big-endian 32-byte output: limb 3 first. 
- for (int i = 0; i < 4; ++i) { - ulong v = limbs[i]; - int off = 32 - 8 * (i + 1); - for (int b = 7; b >= 0; --b) { - be[off + b] = (uchar)(v & 0xff); - v >>= 8; - } - } -} - -// ============================================================================= -// Kernel: poseidon2_hash2_batch. -// Inputs: pairs[i] = 64 bytes = [BE(left) | BE(right)]. -// Output: outs[i] = 32 bytes BE. -// One thread per pair. -// ============================================================================= -kernel void poseidon2_hash2_batch( - device const uchar *pairs [[buffer(0)]], // n * 64 bytes - device uchar *outs [[buffer(1)]], // n * 32 bytes - constant uint &n [[buffer(2)]], - uint i [[thread_position_in_grid]]) -{ - if (i >= n) return; - - Fr s[2]; - s[0] = be_to_fr_mont(pairs + i * 64); - s[1] = be_to_fr_mont(pairs + i * 64 + 32); - Fr saved_right = s[1]; - - permute(s); - - // gnark-crypto Compress: out = saved_right + s[1]. - Fr out = fr_add(saved_right, s[1]); - fr_mont_to_be(out, outs + i * 32); -} diff --git a/poseidon/gpu/metal/poseidon2_driver.h b/poseidon/gpu/metal/poseidon2_driver.h deleted file mode 100644 index 05bd367..0000000 --- a/poseidon/gpu/metal/poseidon2_driver.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for Poseidon2-BN254. macOS / iOS only. -// -// Provides a single batched 2-to-1 compression entry point. The kernel is -// byte-equal-by-construction to lux::crypto::poseidon::hash2 (the round-key -// table is generated from the CPU body itself and #include'd into the .metal -// source). Inputs/outputs are 32-byte big-endian Fr field elements per the -// gnark-crypto convention. - -#ifndef LUX_POSEIDON2_METAL_DRIVER_H -#define LUX_POSEIDON2_METAL_DRIVER_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Run n Poseidon2.Compress calls in one Metal dispatch. 
-// -// pairs : n * 64 bytes, layout = [BE(left_i) || BE(right_i)] for i=0..n-1 -// outs : n * 32 bytes, BE digest written per pair -// n : number of pairs -// metallib_path : absolute path to the precompiled metallib -// -// Returns 0 on success, negative on failure (-1 invalid arg, -2 device init, -// -3 lib load, -4 function lookup, -5 pipeline create). -int poseidon2_hash2_metal_batch( - const uint8_t* pairs, - uint8_t* outs, - size_t n, - const char* metallib_path); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_POSEIDON2_METAL_DRIVER_H diff --git a/poseidon/gpu/metal/poseidon2_driver.mm b/poseidon/gpu/metal/poseidon2_driver.mm deleted file mode 100644 index 280ed1c..0000000 --- a/poseidon/gpu/metal/poseidon2_driver.mm +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for Poseidon2-BN254. macOS / iOS only. -// -// Loads the precompiled poseidon2_bn254.metallib and dispatches the -// `poseidon2_hash2_batch` kernel with one thread per (left, right) pair. -// Byte-equal to poseidon/cpp/poseidon.cpp::hash2. 
- -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "poseidon2_driver.h" - -#include -#include -#include - -extern "C" int poseidon2_hash2_metal_batch( - const uint8_t* pairs, - uint8_t* outs, - size_t n, - const char* metallib_path) { - - if (n == 0) return 0; - if (!pairs || !outs || !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"poseidon2_hash2_batch"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id pairs_buf = [device newBufferWithBytes:pairs - length:n * 64 - options:MTLResourceStorageModeShared]; - id outs_buf = [device newBufferWithLength:n * 32 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:pairs_buf offset:0 atIndex:0]; - [enc setBuffer:outs_buf offset:0 atIndex:1]; - [enc setBuffer:n_buf offset:0 atIndex:2]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? 
tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outs, [outs_buf contents], n * 32); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/poseidon/gpu/metal/poseidon2_t2_batch.metal b/poseidon/gpu/metal/poseidon2_t2_batch.metal deleted file mode 100644 index 128a0ae..0000000 --- a/poseidon/gpu/metal/poseidon2_t2_batch.metal +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched Poseidon2 over BN254 scalar field, t=2 default parameters. -// One thread per state (left, right). Byte-equal to -// poseidon/cpp/poseidon.cpp::permutation_t2(). -// -// Field arithmetic is implemented as 256-bit big-integer ops on 4 little- -// endian uint64_t limbs. Multiplications use the schoolbook 4x4 -> 8 limb -// multiply followed by bit-by-bit shift-and-subtract reduction. This is -// slower than Montgomery multiplication but keeps the kernel arithmetic -// trivially auditable against the CPU body. - -#include -using namespace metal; - -// BN254 scalar modulus r (LE limbs). -constant ulong MOD0 = 0x43e1f593f0000001ul; -constant ulong MOD1 = 0x2833e84879b97091ul; -constant ulong MOD2 = 0xb85045b68181585dul; -constant ulong MOD3 = 0x30644e72e131a029ul; - -inline bool ge_mod(thread const ulong* a) { - if (a[3] != MOD3) return a[3] > MOD3; - if (a[2] != MOD2) return a[2] > MOD2; - if (a[1] != MOD1) return a[1] > MOD1; - if (a[0] != MOD0) return a[0] >= MOD0; - return true; -} - -inline bool ge(thread const ulong* a, thread const ulong* b) { - if (a[3] != b[3]) return a[3] > b[3]; - if (a[2] != b[2]) return a[2] > b[2]; - if (a[1] != b[1]) return a[1] > b[1]; - return a[0] >= b[0]; -} - -// Subtract MOD in-place from a (assumes a >= MOD). 
-inline void sub_mod_inplace(thread ulong* a) { - ulong borrow = 0u; - ulong b[4] = { MOD0, MOD1, MOD2, MOD3 }; - for (int i = 0; i < 4; ++i) { - ulong x = a[i]; - ulong y = b[i]; - ulong sub = x - y - borrow; - // Set borrow if (x < y + borrow). Use bit trick. - ulong new_borrow = ((y > x) || (y == x && borrow)) ? 1u : 0u; - // Refined: borrow if y+borrow > x, with possible y+borrow overflow. - ulong yb_lo = y + borrow; - ulong yb_overflow = (yb_lo < y) ? 1u : 0u; - new_borrow = (yb_overflow || (x < yb_lo)) ? 1u : 0u; - a[i] = sub; - borrow = new_borrow; - } -} - -// 4-limb add with carry out. -inline ulong add_carry(thread ulong* r, thread const ulong* a, thread const ulong* b) { - ulong carry = 0u; - for (int i = 0; i < 4; ++i) { - ulong x = a[i]; - ulong y = b[i]; - ulong sum = x + y; - ulong c1 = (sum < x) ? 1u : 0u; - ulong sum2 = sum + carry; - ulong c2 = (sum2 < sum) ? 1u : 0u; - r[i] = sum2; - carry = c1 + c2; - } - return carry; -} - -// r = (a + b) mod MOD. -inline void add_mod(thread ulong* r, thread const ulong* a, thread const ulong* b) { - ulong c = add_carry(r, a, b); - if (c != 0u || ge_mod(r)) { - sub_mod_inplace(r); - } -} - -inline void double_mod(thread ulong* r, thread const ulong* a) { - add_mod(r, a, a); -} - -// Subtract b from a (assumes a >= b). No borrow handling beyond limb 3. -inline void sub_unchecked(thread ulong* r, thread const ulong* a, thread const ulong* b) { - ulong borrow = 0u; - for (int i = 0; i < 4; ++i) { - ulong x = a[i]; - ulong y = b[i]; - ulong yb_lo = y + borrow; - ulong yb_overflow = (yb_lo < y) ? 1u : 0u; - ulong new_borrow = (yb_overflow || (x < yb_lo)) ? 1u : 0u; - r[i] = x - yb_lo; - borrow = new_borrow; - } -} - -// r = (a - b) mod MOD. 
-inline void sub_mod(thread ulong* r, thread const ulong* a, thread const ulong* b) { - if (ge(a, b)) { - sub_unchecked(r, a, b); - } else { - ulong tmp[4]; - ulong m[4] = { MOD0, MOD1, MOD2, MOD3 }; - sub_unchecked(tmp, m, b); - ulong c = add_carry(r, tmp, a); - if (c != 0u || ge_mod(r)) sub_mod_inplace(r); - } -} - -// 64x64 -> 128 multiply, returning the low and high words. -inline void mul64_128(ulong x, ulong y, thread ulong& lo, thread ulong& hi) { - // Metal does not have a native 128-bit type. Decompose into 32-bit halves. - ulong x_lo = x & 0xFFFFFFFFul; - ulong x_hi = x >> 32; - ulong y_lo = y & 0xFFFFFFFFul; - ulong y_hi = y >> 32; - - ulong p_ll = x_lo * y_lo; - ulong p_lh = x_lo * y_hi; - ulong p_hl = x_hi * y_lo; - ulong p_hh = x_hi * y_hi; - - ulong mid = (p_ll >> 32) + (p_lh & 0xFFFFFFFFul) + (p_hl & 0xFFFFFFFFul); - lo = (p_ll & 0xFFFFFFFFul) | (mid << 32); - hi = p_hh + (p_lh >> 32) + (p_hl >> 32) + (mid >> 32); -} - -// 4x4 -> 8 limb schoolbook multiply. -inline void mul_512(thread ulong* out, thread const ulong* a, thread const ulong* b) { - for (int k = 0; k < 8; ++k) out[k] = 0u; - for (int i = 0; i < 4; ++i) { - ulong carry = 0u; - for (int j = 0; j < 4; ++j) { - ulong p_lo, p_hi; - mul64_128(a[i], b[j], p_lo, p_hi); - - // out[i+j] += p_lo + carry - ulong s = out[i + j] + p_lo; - ulong c1 = (s < out[i + j]) ? 1u : 0u; - ulong s2 = s + carry; - ulong c2 = (s2 < s) ? 1u : 0u; - out[i + j] = s2; - - // carry for next iteration = p_hi + c1 + c2 - carry = p_hi + c1 + c2; - } - out[i + 4] = carry; - } -} - -// Bit-serial reduction: r = (wide mod MOD). 
-inline void reduce_512(thread ulong* r, thread const ulong* wide) { - ulong acc[4] = { 0u, 0u, 0u, 0u }; - - for (int word = 7; word >= 0; --word) { - ulong w = wide[word]; - for (int bit = 63; bit >= 0; --bit) { - ulong b = (w >> bit) & 1ul; - // shift acc left by 1, or-in b - ulong c0 = b; - ulong c1 = (acc[0] >> 63) & 1ul; - ulong c2 = (acc[1] >> 63) & 1ul; - ulong c3 = (acc[2] >> 63) & 1ul; - ulong c4 = (acc[3] >> 63) & 1ul; - acc[0] = (acc[0] << 1) | c0; - acc[1] = (acc[1] << 1) | c1; - acc[2] = (acc[2] << 1) | c2; - acc[3] = (acc[3] << 1) | c3; - ulong overflow = c4; - - if (overflow != 0u || ge_mod(acc)) { - sub_mod_inplace(acc); - } - } - } - r[0] = acc[0]; r[1] = acc[1]; r[2] = acc[2]; r[3] = acc[3]; -} - -inline void mul_mod(thread ulong* r, thread const ulong* a, thread const ulong* b) { - ulong wide[8]; - mul_512(wide, a, b); - reduce_512(r, wide); -} - -inline void square_mod(thread ulong* r, thread const ulong* a) { - mul_mod(r, a, a); -} - -// x^5 = x^2 * x^2 * x. -inline void pow5_mod(thread ulong* r, thread const ulong* a) { - ulong a2[4], a4[4]; - square_mod(a2, a); - square_mod(a4, a2); - mul_mod(r, a4, a); -} - -// Big-endian 32 bytes -> 4 LE limbs. -inline void from_bytes_be(thread ulong* r, device const uchar* bytes) { - for (int i = 0; i < 4; ++i) { - device const uchar* p = bytes + (3 - i) * 8; - r[i] = (ulong(p[0]) << 56) | (ulong(p[1]) << 48) | - (ulong(p[2]) << 40) | (ulong(p[3]) << 32) | - (ulong(p[4]) << 24) | (ulong(p[5]) << 16) | - (ulong(p[6]) << 8) | (ulong(p[7])); - } -} - -inline void to_bytes_be(device uchar* out, thread const ulong* a) { - for (int i = 0; i < 4; ++i) { - ulong w = a[3 - i]; - device uchar* p = out + i * 8; - p[0] = uchar(w >> 56); - p[1] = uchar(w >> 48); - p[2] = uchar(w >> 40); - p[3] = uchar(w >> 32); - p[4] = uchar(w >> 24); - p[5] = uchar(w >> 16); - p[6] = uchar(w >> 8); - p[7] = uchar(w); - } -} - -// External matrix [[2,1],[1,2]] times state. 
-inline void mat_external(thread ulong* s0, thread ulong* s1) { - ulong tmp[4], old_s0[4], old_s1[4]; - add_mod(tmp, s0, s1); - for (int i = 0; i < 4; ++i) old_s0[i] = s0[i]; - for (int i = 0; i < 4; ++i) old_s1[i] = s1[i]; - add_mod(s0, tmp, old_s0); - add_mod(s1, tmp, old_s1); -} - -// Internal matrix [[2,1],[1,3]] times state. -inline void mat_internal(thread ulong* s0, thread ulong* s1) { - ulong sum[4], old_s0[4], two_s1[4]; - add_mod(sum, s0, s1); - for (int i = 0; i < 4; ++i) old_s0[i] = s0[i]; - add_mod(s0, old_s0, sum); - double_mod(two_s1, s1); - add_mod(s1, two_s1, sum); -} - -inline void add_rk_full(thread ulong* s0, thread ulong* s1, - constant const ulong* k0, constant const ulong* k1) { - ulong tk0[4] = { k0[0], k0[1], k0[2], k0[3] }; - ulong tk1[4] = { k1[0], k1[1], k1[2], k1[3] }; - ulong tmp0[4], tmp1[4]; - add_mod(tmp0, s0, tk0); - add_mod(tmp1, s1, tk1); - for (int i = 0; i < 4; ++i) s0[i] = tmp0[i]; - for (int i = 0; i < 4; ++i) s1[i] = tmp1[i]; -} - -inline void add_rk_partial(thread ulong* s0, constant const ulong* k0) { - ulong tk0[4] = { k0[0], k0[1], k0[2], k0[3] }; - ulong tmp[4]; - add_mod(tmp, s0, tk0); - for (int i = 0; i < 4; ++i) s0[i] = tmp[i]; -} - -inline void sbox(thread ulong* s) { - ulong t[4]; - pow5_mod(t, s); - for (int i = 0; i < 4; ++i) s[i] = t[i]; -} - -// Round constants (LE limbs of 256-bit values). Generated from -// gnark-crypto v0.20.1 NewParameters(2,6,50). DO NOT EDIT. 
-constant ulong RC_FULL_PRE[3][2][4] = { - { { 0x8b9b8d31770fac4ful, 0x08580a5f5e295e16ul, 0x4584f763db50a819ul, 0x1da4d6adfb0d0b49ul }, - { 0x51fab2a59acc074cul, 0xab80d0a0c7a0c634ul, 0x19707a56a3b3790eul, 0x0946129a2e33b4e8ul } }, - { { 0x68794dad4b8f82eaul, 0x867b87303b071402ul, 0x580abd6952986570ul, 0x2a39b9d5376afd35ul }, - { 0x32448eea5e3069b2ul, 0xe4a70e66ecd6de5dul, 0xc546b3c7014e5fa3ul, 0x27605717d1245c20ul } }, - { { 0x8e50809c03773632ul, 0xe771ebb09a228d63ul, 0x973653193a470ed7ul, 0x24c896cb2594e17bul }, - { 0xd8042e1b7d597a49ul, 0x6d8701c7869a919aul, 0x0d61d783957003dbul, 0x0911096c45dd9cdaul } }, -}; - -constant ulong RC_PARTIAL[50][4] = { - { 0xf11f38a70603afdcul, 0xf1daba01aad28021ul, 0x27eee6d352a8ce26ul, 0x26ff6166f1e4b99eul }, - { 0x4d5b89437bf64b7eul, 0x79bcb5e18cfb7d95ul, 0xad6591ff90e50feaul, 0x008e2faedcf76d08ul }, - { 0x32f98820eb353d78ul, 0xf6184a3cec71eeb6ul, 0xe3ad4d1872470830ul, 0x19c9da2379b598acul }, - { 0xa459a40583dfe96bul, 0xab3dfcdb57e2539aul, 0xa8f6a090ec9610a2ul, 0x0f7c4eb15d8b0b62ul }, - { 0x9171588b08ac3983ul, 0xd46900b47dd5ff80ul, 0x79750eba282362d1ul, 0x18b99417dc26b5e0ul }, - { 0x8848af08a8d91a26ul, 0xc73f1c054b76320aul, 0xe2d4493feab82141ul, 0x1ee044081160b3eeul }, - { 0x9122220366c258dcul, 0x0455b6d7a14780d1ul, 0x0e87f5df12ee8a15ul, 0x29bb95c8763efd3eul }, - { 0x067b4705e9a5c9fbul, 0x215e8991f7f9ec12ul, 0xa3ee9a363d740653ul, 0x22c23eec9cb13ff8ul }, - { 0xc7deefc1690b42dcul, 0x9115c7644c4f905cul, 0x680c8b18926c3be0ul, 0x23589e033a31a667ul }, - { 0xdd64ec938a6a247cul, 0x64f60b98a219f1f0ul, 0x2c9c0cde5f2bdd47ul, 0x304e99b887f2e1e9ul }, - { 0xa493b2de7f7d8b4cul, 0xe792f326271d53a3ul, 0x76fbe88bbdf31fcaul, 0x22e817865236ad3aul }, - { 0x26b77c881dc8b1c9ul, 0x02bc7dcfcc4878e0ul, 0xb238a3f5c70a00bful, 0x10c9efe573e86fa5ul }, - { 0x3ea283c72b89bbb6ul, 0xc084f309edffef02ul, 0x4d6f80e745c6bddbul, 0x0a94f16be920d85ful }, - { 0xdf3da0ebcd69101aul, 0xbc2c1ea81d6465b5ul, 0xc7888fedb770494aul, 0x23ed72b4d01d14e3ul }, - { 
0xe8d8f1022eb6bb74ul, 0x98df815a77cb162dul, 0xed0e6cbb511ac384ul, 0x17c5115640e4cebeul }, - { 0xb99680a82772e4fcul, 0x73e09e1fc813959ful, 0xcf765245750eb047ul, 0x2e507fcca290d0d9ul }, - { 0xfc8803013292bbbful, 0x2433f055cb71d70dul, 0x6af6cce65c8d1216ul, 0x0d4a98999f5b3917ul }, - { 0xa1bd3c335eccfa21ul, 0x321d829e22a7a3b7ul, 0xb3c01261a03dc125ul, 0x238d8022cc09c21aul }, - { 0x70e454ed9f4ae165ul, 0x43ce365613d5b397ul, 0xb81dc8292e35d8e9ul, 0x010cd8e4c2b7051cul }, - { 0x8e03539590ecc027ul, 0x861173a90fb23123ul, 0xb11178cf0ea3c6aaul, 0x088027e54f2a3604ul }, - { 0x8302b73e464819dbul, 0x61165468daa93810ul, 0xb4cd7aa7e5a9a6d1ul, 0x1b840f5311a2b1d4ul }, - { 0xcb366fcd8db78dc1ul, 0x9a3a346e44198ea0ul, 0xf9b764b1e16c1592ul, 0x2bf51a5da1828a1cul }, - { 0xd56a3a50f6f23827ul, 0x26b8872f9c7beef9ul, 0xe68a6a86767a7fe7ul, 0x206ad089d8d296fful }, - { 0x9fc07ff88eac550cul, 0xeaec4b6d617c7c19ul, 0x1a54e0a99ac16d05ul, 0x24d19193171494faul }, - { 0x95d9f9fbaf152138ul, 0x2c40662278b9c0b9ul, 0xf33d88246a40dfb3ul, 0x1dd654a2ca9d9f24ul }, - { 0x8acf57f3789785e4ul, 0x8c85f05b0fe7793bul, 0x59d20ecbd3a60110ul, 0x0d171025c925f6e2ul }, - { 0x98b9908fed877d55ul, 0x1d9e8cf2b80d95ddul, 0x45cd99ccb0f7c879ul, 0x055bef435a43aec2ul }, - { 0x3b75765d0e595f0eul, 0xbe63d0ad4cfd9f20ul, 0x8a2a3f42d9359a14ul, 0x10d2ac8c61c8a2e8ul }, - { 0x296e3d0790dcfa56ul, 0x254d7966ddc41e85ul, 0x82a84057ec3ba6b2ul, 0x103479710e709969ul }, - { 0x8225f8372921807bul, 0xde96f9aa340db163ul, 0x5914ffb48c765da8ul, 0x2a366f0448fda3c0ul }, - { 0x04626b63285837a4ul, 0x153bb8899aeb686cul, 0x919b6e0378d00594ul, 0x16be0fb8ef62da17ul }, - { 0x2cfd8c78a15bc370ul, 0x2dec8a0b2aa5a4d5ul, 0x60abbc7f0d0d24c3ul, 0x0417038500e9d06cul }, - { 0xe5a0308fdd013812ul, 0x4f4d76480bc486a3ul, 0xf66ec6f4493ff9b5ul, 0x26a6873b43ffd2ccul }, - { 0x391922d36278c498ul, 0x2f3db9b2aa8aa9f7ul, 0xa96251914fe5ad26ul, 0x0a3314a838f32630ul }, - { 0x44f4b59399027da4ul, 0xaeadb46a090b15a0ul, 0x7f462d4821f48f86ul, 0x0fde0c5429a6beb0ul }, - { 0xf11a1c89cda2c8a9ul, 
0xb517e1eb535a984dul, 0x9b357e4163793b0bul, 0x0abc2d5049972a6bul }, - { 0x735205dbb77c464eul, 0x1b427a5f173f7f17ul, 0x1d21722fb21e5505ul, 0x0dab51d6e3ebfa66ul }, - { 0xe9b92e70a8c1ae03ul, 0xb21be5c2d68fb8eful, 0x51af6cc37123652ful, 0x29c36622598b511dul }, - { 0x0d6bbe7ea2fbc85cul, 0xc41c4096e53ac699ul, 0xae846bc0b700d0bcul, 0x2c03ec80adac2a33ul }, - { 0xb8b521e3cebe1e0dul, 0x317303116017d893ul, 0xbdb4c6852d065105ul, 0x0918fdbe9cf3a59ful }, - { 0xa9d2135acabe2359ul, 0x176a79df4f7acf79ul, 0x599dd13cd7e495a8ul, 0x1f19ec22e69ca33ful }, - { 0x46447b2911080704ul, 0x8c4c76b29ceb0013ul, 0xeb32b872eb7f7692ul, 0x1c4b037c8ae85ee1ul }, - { 0xaf0a69c534efd8dbul, 0xba3131e04d420de5ul, 0x6c826d0bde341766ul, 0x2b68900ed906616dul }, - { 0x723d4f0c782bc7f0ul, 0xa180ff6dfb397ceful, 0x448f8dac653c8daaul, 0x20ca92aa222fcc69ul }, - { 0x1f6186289b66b2b6ul, 0x7045a5ab7ac05394ul, 0x75276fc82057db33ul, 0x10d22d05bdff6bb3ul }, - { 0xe5ca79e88bdd6cd9ul, 0x5aa910b4a00636d1ul, 0x98f32ba45784cb43ul, 0x0b1ffdbb529367bbul }, - { 0x6f89bcc33022ce9ful, 0x82c7b4f82e5e515ful, 0xed757ec705eccf82ul, 0x2da32b38e7984bc2ul }, - { 0xd292cecc6abed1bbul, 0xb87eb33958031e94ul, 0x2674b8b55a886725ul, 0x042593ad87403f6dul }, - { 0x9308fbd6e7d419f1ul, 0x051faedab463a6deul, 0x19d7367bf49b3f01ul, 0x181fa1b4d067783aul }, - { 0x34927188c60de660ul, 0xa8e35b00ed8a815cul, 0x5683c95515c26028ul, 0x15aaa6cc9b7900b1ul }, -}; - -constant ulong RC_FULL_POST[3][2][4] = { - { { 0x6787320212f75a7aul, 0x80b1a636f5a0ecedul, 0xbc63234f057254c2ul, 0x1bf28a93209084bbul }, - { 0x6f8c9d81a620cb6ful, 0x56273b312f805c8eul, 0x2cd9e229776118f1ul, 0x1cdb8c8bee5426f0ul } }, - { { 0xf8ea0e8760d5ca7eul, 0x516df2505cc387a0ul, 0x162e0facb5f1876ful, 0x08299c0abf196d53ul }, - { 0x4117269d7c710b8aul, 0x2d76c0072cabd112ul, 0x8a7b7b58cb65c496ul, 0x221643d205fe8277ul } }, - { { 0x487920d43343dbfful, 0xf6bd10c3f74b22dbul, 0xb7a0143a28c88767ul, 0x2d036a95f81cf49bul }, - { 0x6d5d1bb7a172ebd2ul, 0x7cd4a39486729fbcul, 0xea414fb1bceca226ul, 
0x08a50897c06aafe6ul } }, -}; - -kernel void poseidon2_t2_jobs( - device const uchar* states_in [[buffer(0)]], - device uchar* states_out [[buffer(1)]], - constant uint& n [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= n) return; - - ulong s0[4], s1[4]; - from_bytes_be(s0, states_in + tid * 64); - from_bytes_be(s1, states_in + tid * 64 + 32); - - if (ge_mod(s0) || ge_mod(s1)) { - // Non-canonical input: write zeros to flag failure. - for (int k = 0; k < 64; ++k) states_out[tid * 64 + k] = 0u; - return; - } - - mat_external(s0, s1); - - for (int i = 0; i < 3; ++i) { - add_rk_full(s0, s1, RC_FULL_PRE[i][0], RC_FULL_PRE[i][1]); - sbox(s0); sbox(s1); - mat_external(s0, s1); - } - for (int i = 0; i < 50; ++i) { - add_rk_partial(s0, RC_PARTIAL[i]); - sbox(s0); - mat_internal(s0, s1); - } - for (int i = 0; i < 3; ++i) { - add_rk_full(s0, s1, RC_FULL_POST[i][0], RC_FULL_POST[i][1]); - sbox(s0); sbox(s1); - mat_external(s0, s1); - } - - to_bytes_be(states_out + tid * 64, s0); - to_bytes_be(states_out + tid * 64 + 32, s1); -} diff --git a/poseidon/gpu/metal/poseidon2_t2_batch_driver.mm b/poseidon/gpu/metal/poseidon2_t2_batch_driver.mm deleted file mode 100644 index fbda408..0000000 --- a/poseidon/gpu/metal/poseidon2_t2_batch_driver.mm +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched Poseidon2-BN254 t=2 permutation. 
- -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include - -extern "C" int poseidon2_t2_batch_metal( - const uint8_t* states_in, - size_t n, - uint8_t* states_out, - const char* metallib_path) { - - if (n == 0) return 0; - if (!states_in || !states_out || !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"poseidon2_t2_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - id in_buf = [device newBufferWithBytes:states_in - length:n * 64 - options:MTLResourceStorageModeShared]; - id out_buf = [device newBufferWithLength:n * 64 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:in_buf offset:0 atIndex:0]; - [enc setBuffer:out_buf offset:0 atIndex:1]; - [enc setBuffer:n_buf offset:0 atIndex:2]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 32 ? 
tg_max : 32; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(states_out, [out_buf contents], n * 64); - } - return 0; -} - -#endif diff --git a/poseidon/gpu/wgsl/poseidon2_bn254.wgsl b/poseidon/gpu/wgsl/poseidon2_bn254.wgsl deleted file mode 100644 index b6b6d56..0000000 --- a/poseidon/gpu/wgsl/poseidon2_bn254.wgsl +++ /dev/null @@ -1,467 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party WGSL kernel for Poseidon2-BN254 (canonical default permutation). -// -// Mechanical port of poseidon/gpu/metal/poseidon2_bn254.metal -- byte-for-byte -// equivalent to lux::crypto::poseidon::hash2 in poseidon/cpp/poseidon.cpp -// (gnark-crypto v0.20.1 ecc/bn254/fr/poseidon2 with t=2, rF=6, rP=50, d=5). -// -// WGSL has no native u64. Each 64-bit Montgomery limb is represented as a pair -// of u32 (lo, hi), and 64-bit ops are reconstructed from 32-bit primitives. -// -// The round-key constants come from poseidon2_bn254_rk.wgslh, which is emitted -// by the CPU body's dump_round_keys -> gen_gpu_constants. There is exactly one -// source of truth across CPU, Metal, CUDA, and WGSL. - -// === BEGIN POSEIDON2_RK_LO/HI (auto-generated, prepended at compile time) === -// In production, the host driver concatenates poseidon2_bn254_rk.wgslh in -// front of this file before submitting it to wgpu. The kernel body below -// references POSEIDON2_RK_LO[round][slot][limb] and POSEIDON2_RK_HI[...]. -// === END POSEIDON2_RK_LO/HI === - -// ============================================================================= -// 64-bit unsigned represented as (lo, hi) pair of u32. 
-// ============================================================================= -struct U64 { lo: u32, hi: u32 }; - -fn u64_make(lo: u32, hi: u32) -> U64 { - return U64(lo, hi); -} - -fn u64_zero() -> U64 { return U64(0u, 0u); } - -fn u64_lt(a: U64, b: U64) -> bool { - if (a.hi != b.hi) { return a.hi < b.hi; } - return a.lo < b.lo; -} - -fn u64_eq(a: U64, b: U64) -> bool { - return a.lo == b.lo && a.hi == b.hi; -} - -// 64-bit add. Returns sum and carry-out (0 or 1). -fn u64_add(a: U64, b: U64) -> U64 { - let lo = a.lo + b.lo; - let c0: u32 = select(0u, 1u, lo < a.lo); - let hi = a.hi + b.hi + c0; - return U64(lo, hi); -} - -// 64-bit add with carry. Returns sum; carry_out written via separate path. -struct U64Carry { v: U64, carry: u32 }; - -fn u64_add_carry(a: U64, b: U64, cin: u32) -> U64Carry { - let lo1 = a.lo + b.lo; - let c0: u32 = select(0u, 1u, lo1 < a.lo); - let lo = lo1 + cin; - let c1: u32 = select(0u, 1u, lo < lo1); - let hi1 = a.hi + b.hi; - let c2: u32 = select(0u, 1u, hi1 < a.hi); - let hi2 = hi1 + (c0 + c1); - let c3: u32 = select(0u, 1u, hi2 < hi1); - return U64Carry(U64(lo, hi2), c2 + c3); -} - -// 64-bit sub with borrow. Returns difference; borrow_out via field. -struct U64Borrow { v: U64, borrow: u32 }; - -fn u64_sub_borrow(a: U64, b: U64, bin: u32) -> U64Borrow { - let lo1: u32 = a.lo - b.lo; - let bor0: u32 = select(0u, 1u, a.lo < b.lo); - let lo: u32 = lo1 - bin; - let bor1: u32 = select(0u, 1u, lo1 < bin); - let hi1: u32 = a.hi - b.hi; - let bor2: u32 = select(0u, 1u, a.hi < b.hi); - let hi: u32 = hi1 - (bor0 + bor1); - let bor3: u32 = select(0u, 1u, hi1 < (bor0 + bor1)); - return U64Borrow(U64(lo, hi), bor2 + bor3); -} - -// 64x64 -> 128 multiply via four 32x32 -> 64 partial products. 
-struct U128 { l0: u32, l1: u32, l2: u32, l3: u32 }; - -fn umul64(a: U64, b: U64) -> U128 { - // a = a.hi*2^32 + a.lo, b = b.hi*2^32 + b.lo - // a*b = a.lo*b.lo - // + (a.lo*b.hi + a.hi*b.lo) * 2^32 - // + a.hi*b.hi * 2^64 - // Each 32x32 -> u64 product is decomposed into (lo, hi) u32. - let p_ll: U64 = u64_make_from_u32_mul(a.lo, b.lo); - let p_lh: U64 = u64_make_from_u32_mul(a.lo, b.hi); - let p_hl: U64 = u64_make_from_u32_mul(a.hi, b.lo); - let p_hh: U64 = u64_make_from_u32_mul(a.hi, b.hi); - - // word 0 = p_ll.lo - let w0: u32 = p_ll.lo; - // word 1 = p_ll.hi + p_lh.lo + p_hl.lo - let s1a: u32 = p_ll.hi + p_lh.lo; - let c1a: u32 = select(0u, 1u, s1a < p_ll.hi); - let w1: u32 = s1a + p_hl.lo; - let c1b: u32 = select(0u, 1u, w1 < s1a); - let carry1: u32 = c1a + c1b; - // word 2 = p_lh.hi + p_hl.hi + p_hh.lo + carry1 - let s2a: u32 = p_lh.hi + p_hl.hi; - let c2a: u32 = select(0u, 1u, s2a < p_lh.hi); - let s2b: u32 = s2a + p_hh.lo; - let c2b: u32 = select(0u, 1u, s2b < s2a); - let w2: u32 = s2b + carry1; - let c2c: u32 = select(0u, 1u, w2 < s2b); - let carry2: u32 = c2a + c2b + c2c; - // word 3 = p_hh.hi + carry2 - let w3: u32 = p_hh.hi + carry2; - - return U128(w0, w1, w2, w3); -} - -fn u64_make_from_u32_mul(a: u32, b: u32) -> U64 { - // 32x32 -> 64 multiply. We split each operand into 16-bit halves to - // avoid overflowing u32 in the partial products. 
- let al: u32 = a & 0xffffu; - let ah: u32 = a >> 16u; - let bl: u32 = b & 0xffffu; - let bh: u32 = b >> 16u; - let ll: u32 = al * bl; // up to 0xfffe0001 - let lh: u32 = al * bh; - let hl: u32 = ah * bl; - let hh: u32 = ah * bh; - // mid = (ll >> 16) + (lh & 0xffff) + (hl & 0xffff) - let mid_a: u32 = (ll >> 16u) + (lh & 0xffffu); - let mid_b: u32 = mid_a + (hl & 0xffffu); - let mid_carry: u32 = (mid_b >> 16u); // overflow into hi - let lo: u32 = (ll & 0xffffu) | (mid_b << 16u); - let hi: u32 = hh + (lh >> 16u) + (hl >> 16u) + mid_carry; - return U64(lo, hi); -} - -// 64-bit u64 product of two U64; returns the low 64 bits only (used for -// computing m = t[0] * Q_INV_NEG mod 2^64). -fn u64_mul_low(a: U64, b: U64) -> U64 { - let p_ll: U64 = u64_make_from_u32_mul(a.lo, b.lo); - let p_lh: U64 = u64_make_from_u32_mul(a.lo, b.hi); - let p_hl: U64 = u64_make_from_u32_mul(a.hi, b.lo); - let lo: u32 = p_ll.lo; - let hi: u32 = p_ll.hi + p_lh.lo + p_hl.lo; - return U64(lo, hi); -} - -// ============================================================================= -// BN254 Fr modulus + Montgomery params (Q_INV_NEG, R_SQUARE). -// Each constant declared as (lo, hi) u32 pair. -// ============================================================================= -const Q0_LO: u32 = 0xf0000001u; const Q0_HI: u32 = 0x43e1f593u; -const Q1_LO: u32 = 0x79b97091u; const Q1_HI: u32 = 0x2833e848u; -const Q2_LO: u32 = 0x8181585du; const Q2_HI: u32 = 0xb85045b6u; -const Q3_LO: u32 = 0xe131a029u; const Q3_HI: u32 = 0x30644e72u; -const QINV_LO: u32 = 0xefffffffu; const QINV_HI: u32 = 0xc2e1f593u; - -// R^2 mod q, Montgomery form. 
-// 1997599621687373223 = 0x1bb8e645ae216da7 -// 6052339484930628067 = 0x53fe3ab1e35c59e3 -// 10108755138030829701 = 0x8c49833d53bb8085 -// 150537098327114917 = 0x0216d0b17f4e44a5 -const R2_0_LO: u32 = 0xae216da7u; const R2_0_HI: u32 = 0x1bb8e645u; -const R2_1_LO: u32 = 0xe35c59e3u; const R2_1_HI: u32 = 0x53fe3ab1u; -const R2_2_LO: u32 = 0x53bb8085u; const R2_2_HI: u32 = 0x8c49833du; -const R2_3_LO: u32 = 0x7f4e44a5u; const R2_3_HI: u32 = 0x0216d0b1u; - -// ============================================================================= -// 256-bit Fr in Montgomery form, four U64 limbs. -// ============================================================================= -struct Fr { - l0: U64, - l1: U64, - l2: U64, - l3: U64, -}; - -fn fr_zero() -> Fr { - return Fr(u64_zero(), u64_zero(), u64_zero(), u64_zero()); -} - -fn fr_q() -> Fr { - return Fr(U64(Q0_LO, Q0_HI), U64(Q1_LO, Q1_HI), - U64(Q2_LO, Q2_HI), U64(Q3_LO, Q3_HI)); -} - -fn fr_r2() -> Fr { - return Fr(U64(R2_0_LO, R2_0_HI), U64(R2_1_LO, R2_1_HI), - U64(R2_2_LO, R2_2_HI), U64(R2_3_LO, R2_3_HI)); -} - -fn cmp_q(a: Fr) -> i32 { - let q = fr_q(); - if (!u64_eq(a.l3, q.l3)) { - if (u64_lt(a.l3, q.l3)) { return -1; } else { return 1; } - } - if (!u64_eq(a.l2, q.l2)) { - if (u64_lt(a.l2, q.l2)) { return -1; } else { return 1; } - } - if (!u64_eq(a.l1, q.l1)) { - if (u64_lt(a.l1, q.l1)) { return -1; } else { return 1; } - } - if (!u64_eq(a.l0, q.l0)) { - if (u64_lt(a.l0, q.l0)) { return -1; } else { return 1; } - } - return 0; -} - -fn fr_sub_q(a: Fr) -> Fr { - let q = fr_q(); - let r0 = u64_sub_borrow(a.l0, q.l0, 0u); - let r1 = u64_sub_borrow(a.l1, q.l1, r0.borrow); - let r2 = u64_sub_borrow(a.l2, q.l2, r1.borrow); - let r3 = u64_sub_borrow(a.l3, q.l3, r2.borrow); - return Fr(r0.v, r1.v, r2.v, r3.v); -} - -fn reduce_once(a: Fr) -> Fr { - if (cmp_q(a) >= 0) { return fr_sub_q(a); } - return a; -} - -fn fr_add(a: Fr, b: Fr) -> Fr { - let r0 = u64_add_carry(a.l0, b.l0, 0u); - let r1 = u64_add_carry(a.l1, b.l1, 
r0.carry); - let r2 = u64_add_carry(a.l2, b.l2, r1.carry); - let r3 = u64_add_carry(a.l3, b.l3, r2.carry); - var c = Fr(r0.v, r1.v, r2.v, r3.v); - if (r3.carry != 0u || cmp_q(c) >= 0) { - c = fr_sub_q(c); - } - return c; -} - -fn fr_double(a: Fr) -> Fr { return fr_add(a, a); } - -// CIOS Montgomery multiplication. Bit-identical algorithm to the CPU body. -fn fr_mul(a: Fr, b: Fr) -> Fr { - var t0 = u64_zero(); - var t1 = u64_zero(); - var t2 = u64_zero(); - var t3 = u64_zero(); - var t4 = u64_zero(); - let al = array(a.l0, a.l1, a.l2, a.l3); - let bl = array(b.l0, b.l1, b.l2, b.l3); - let qq = array(U64(Q0_LO, Q0_HI), U64(Q1_LO, Q1_HI), - U64(Q2_LO, Q2_HI), U64(Q3_LO, Q3_HI)); - let qinv = U64(QINV_LO, QINV_HI); - - for (var i: i32 = 0; i < 4; i = i + 1) { - // t += a * b[i] - var cy: U64 = u64_zero(); - for (var j: i32 = 0; j < 4; j = j + 1) { - let prod: U128 = umul64(al[j], bl[i]); - let lo: U64 = U64(prod.l0, prod.l1); - let hi: U64 = U64(prod.l2, prod.l3); - // Pick t[j] - var tj: U64 = u64_zero(); - if (j == 0) { tj = t0; } - else if (j == 1) { tj = t1; } - else if (j == 2) { tj = t2; } - else { tj = t3; } - let s = u64_add_carry(tj, lo, 0u); - let s2 = u64_add_carry(s.v, cy, 0u); - // Update t[j] - if (j == 0) { t0 = s2.v; } - else if (j == 1) { t1 = s2.v; } - else if (j == 2) { t2 = s2.v; } - else { t3 = s2.v; } - // cy = hi + s.carry + s2.carry - let cy1 = u64_add_carry(hi, U64(s.carry, 0u), 0u); - let cy2 = u64_add_carry(cy1.v, U64(s2.carry, 0u), 0u); - cy = cy2.v; - } - // t[4] += cy - let t4u = u64_add_carry(t4, cy, 0u); - t4 = t4u.v; - - // m = t[0] * qInvNeg mod 2^64 - let m: U64 = u64_mul_low(t0, qinv); - - // t += m * q - cy = u64_zero(); - for (var j: i32 = 0; j < 4; j = j + 1) { - let prod = umul64(m, qq[j]); - let lo: U64 = U64(prod.l0, prod.l1); - let hi: U64 = U64(prod.l2, prod.l3); - var tj: U64 = u64_zero(); - if (j == 0) { tj = t0; } - else if (j == 1) { tj = t1; } - else if (j == 2) { tj = t2; } - else { tj = t3; } - let s = 
u64_add_carry(tj, lo, 0u); - let s2 = u64_add_carry(s.v, cy, 0u); - if (j == 0) { t0 = s2.v; } - else if (j == 1) { t1 = s2.v; } - else if (j == 2) { t2 = s2.v; } - else { t3 = s2.v; } - let cy1 = u64_add_carry(hi, U64(s.carry, 0u), 0u); - let cy2 = u64_add_carry(cy1.v, U64(s2.carry, 0u), 0u); - cy = cy2.v; - } - let t4u2 = u64_add_carry(t4, cy, 0u); - t4 = t4u2.v; - - // Shift right by one limb. - t0 = t1; - t1 = t2; - t2 = t3; - t3 = t4; - t4 = u64_zero(); - } - var c = Fr(t0, t1, t2, t3); - c = reduce_once(c); - return c; -} - -fn fr_square(a: Fr) -> Fr { return fr_mul(a, a); } - -// ============================================================================= -// Poseidon2-BN254 default permutation. -// ============================================================================= - -fn sbox(x: Fr) -> Fr { - let x2 = fr_square(x); - let x4 = fr_square(x2); - return fr_mul(x4, x); -} - -struct State2 { s0: Fr, s1: Fr }; - -fn mat_mul_external(s: State2) -> State2 { - let tmp = fr_add(s.s0, s.s1); - return State2(fr_add(s.s0, tmp), fr_add(s.s1, tmp)); -} - -fn mat_mul_internal(s: State2) -> State2 { - let sum = fr_add(s.s0, s.s1); - let s0p = fr_add(s.s0, sum); - let s1d = fr_double(s.s1); - let s1p = fr_add(s1d, sum); - return State2(s0p, s1p); -} - -const FULL_HALF: i32 = 3; -const PARTIAL: i32 = 50; - -fn load_rk(round: i32, slot: i32) -> Fr { - return Fr( - U64(POSEIDON2_RK_LO[round][slot][0], POSEIDON2_RK_HI[round][slot][0]), - U64(POSEIDON2_RK_LO[round][slot][1], POSEIDON2_RK_HI[round][slot][1]), - U64(POSEIDON2_RK_LO[round][slot][2], POSEIDON2_RK_HI[round][slot][2]), - U64(POSEIDON2_RK_LO[round][slot][3], POSEIDON2_RK_HI[round][slot][3]) - ); -} - -fn permute(s_in: State2) -> State2 { - var s = mat_mul_external(s_in); - for (var i: i32 = 0; i < FULL_HALF; i = i + 1) { - let k0 = load_rk(i, 0); - let k1 = load_rk(i, 1); - s.s0 = fr_add(s.s0, k0); - s.s1 = fr_add(s.s1, k1); - s.s0 = sbox(s.s0); - s.s1 = sbox(s.s1); - s = mat_mul_external(s); - } - for (var i: 
i32 = 0; i < PARTIAL; i = i + 1) { - let k0 = load_rk(FULL_HALF + i, 0); - s.s0 = fr_add(s.s0, k0); - s.s0 = sbox(s.s0); - s = mat_mul_internal(s); - } - for (var i: i32 = 0; i < FULL_HALF; i = i + 1) { - let k0 = load_rk(FULL_HALF + PARTIAL + i, 0); - let k1 = load_rk(FULL_HALF + PARTIAL + i, 1); - s.s0 = fr_add(s.s0, k0); - s.s1 = fr_add(s.s1, k1); - s.s0 = sbox(s.s0); - s.s1 = sbox(s.s1); - s = mat_mul_external(s); - } - return s; -} - -// ============================================================================= -// Bytes (BE) <-> Fr conversions, Montgomery form on the inside. -// ============================================================================= - -fn read_be64(buf: ptr, read>, byte_off: u32) -> U64 { - // Read 8 bytes big-endian into a U64. WGSL storage buffers are word- - // addressable, so we unpack from u32 words. - let w0 = (*buf)[byte_off / 4u]; - let w1 = (*buf)[byte_off / 4u + 1u]; - // be: bytes [b0 b1 b2 b3 b4 b5 b6 b7] -> u64 = b0<<56 | ... | b7 - // Little-endian u32 word holds bytes [w&0xff, (w>>8)&0xff, ..., (w>>24)&0xff] - // at increasing byte addresses. 
- let b0: u32 = (w0 ) & 0xffu; - let b1: u32 = (w0 >> 8u) & 0xffu; - let b2: u32 = (w0 >> 16u) & 0xffu; - let b3: u32 = (w0 >> 24u) & 0xffu; - let b4: u32 = (w1 ) & 0xffu; - let b5: u32 = (w1 >> 8u) & 0xffu; - let b6: u32 = (w1 >> 16u) & 0xffu; - let b7: u32 = (w1 >> 24u) & 0xffu; - let hi: u32 = (b0 << 24u) | (b1 << 16u) | (b2 << 8u) | b3; - let lo: u32 = (b4 << 24u) | (b5 << 16u) | (b6 << 8u) | b7; - return U64(lo, hi); -} - -fn be_to_fr_mont(buf: ptr, read>, off: u32) -> Fr { - var x = Fr(read_be64(buf, off + 24u), - read_be64(buf, off + 16u), - read_be64(buf, off + 8u), - read_be64(buf, off )); - for (var i: i32 = 0; i < 4; i = i + 1) { - if (cmp_q(x) < 0) { break; } - x = fr_sub_q(x); - } - return fr_mul(x, fr_r2()); -} - -fn write_be64(buf: ptr, read_write>, byte_off: u32, v: U64) { - // BE bytes: [hi>>24, hi>>16, hi>>8, hi, lo>>24, lo>>16, lo>>8, lo] - let b0: u32 = (v.hi >> 24u) & 0xffu; - let b1: u32 = (v.hi >> 16u) & 0xffu; - let b2: u32 = (v.hi >> 8u) & 0xffu; - let b3: u32 = v.hi & 0xffu; - let b4: u32 = (v.lo >> 24u) & 0xffu; - let b5: u32 = (v.lo >> 16u) & 0xffu; - let b6: u32 = (v.lo >> 8u) & 0xffu; - let b7: u32 = v.lo & 0xffu; - let w0: u32 = b0 | (b1 << 8u) | (b2 << 16u) | (b3 << 24u); - let w1: u32 = b4 | (b5 << 8u) | (b6 << 16u) | (b7 << 24u); - (*buf)[byte_off / 4u] = w0; - (*buf)[byte_off / 4u + 1u] = w1; -} - -fn fr_mont_to_be(buf: ptr, read_write>, off: u32, x: Fr) { - let one_reg = Fr(U64(1u, 0u), U64(0u, 0u), U64(0u, 0u), U64(0u, 0u)); - let r = fr_mul(x, one_reg); - write_be64(buf, off + 0u, r.l3); - write_be64(buf, off + 8u, r.l2); - write_be64(buf, off + 16u, r.l1); - write_be64(buf, off + 24u, r.l0); -} - -// ============================================================================= -// Storage bindings + kernel. 
-// ============================================================================= -@group(0) @binding(0) var g_pairs: array; // n*64 bytes -@group(0) @binding(1) var g_outs: array; // n*32 bytes -@group(0) @binding(2) var g_n: u32; - -@compute @workgroup_size(64) -fn poseidon2_hash2_batch(@builtin(global_invocation_id) gid: vec3) { - let i: u32 = gid.x; - if (i >= g_n) { return; } - let in_off: u32 = i * 64u; - let out_off: u32 = i * 32u; - var s = State2(be_to_fr_mont(&g_pairs, in_off), - be_to_fr_mont(&g_pairs, in_off + 32u)); - let saved_right = s.s1; - s = permute(s); - let digest = fr_add(saved_right, s.s1); - fr_mont_to_be(&g_outs, out_off, digest); -} diff --git a/poseidon/gpu/wgsl/poseidon2_driver.cpp b/poseidon/gpu/wgsl/poseidon2_driver.cpp deleted file mode 100644 index 72cca65..0000000 --- a/poseidon/gpu/wgsl/poseidon2_driver.cpp +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL driver for Poseidon2-BN254 -- C++ host polyfill of the WGSL kernel. -// -// Mirrors poseidon/gpu/wgsl/poseidon2_bn254.wgsl byte-for-byte: each 64-bit -// Montgomery limb is represented as a (lo, hi) pair of uint32_t, and every -// 64-bit op is reconstructed from u32 primitives. This is exactly what WGSL -// does; running the same arithmetic on the host gives byte-equal output to -// both the WGSL kernel and the CPU oracle. -// -// The round-key constants live in poseidon2_bn254_rk.wgslh (auto-generated -// from the CPU body), which exposes both POSEIDON2_RK_LO/HI on __cplusplus -// and on the WGSL side. Single source of truth. - -#include "poseidon2_driver.h" -#include "poseidon2_bn254_rk.wgslh" - -#include -#include - -namespace { - -// ============================================================================= -// 64-bit unsigned as (lo, hi) of u32. Matches WGSL `struct U64`. 
-// ============================================================================= -struct U64 { uint32_t lo, hi; }; - -inline U64 u64_make(uint32_t lo, uint32_t hi) { return U64{lo, hi}; } -inline U64 u64_zero() { return U64{0u, 0u}; } -inline bool u64_lt(U64 a, U64 b) { - if (a.hi != b.hi) return a.hi < b.hi; - return a.lo < b.lo; -} -inline bool u64_eq(U64 a, U64 b) { return a.lo == b.lo && a.hi == b.hi; } - -struct U64Carry { U64 v; uint32_t carry; }; -struct U64Borrow { U64 v; uint32_t borrow; }; - -inline U64Carry u64_add_carry(U64 a, U64 b, uint32_t cin) { - uint32_t lo1 = a.lo + b.lo; - uint32_t c0 = (lo1 < a.lo) ? 1u : 0u; - uint32_t lo = lo1 + cin; - uint32_t c1 = (lo < lo1) ? 1u : 0u; - uint32_t hi1 = a.hi + b.hi; - uint32_t c2 = (hi1 < a.hi) ? 1u : 0u; - uint32_t hi2 = hi1 + (c0 + c1); - uint32_t c3 = (hi2 < hi1) ? 1u : 0u; - return U64Carry{U64{lo, hi2}, c2 + c3}; -} - -inline U64Borrow u64_sub_borrow(U64 a, U64 b, uint32_t bin) { - uint32_t lo1 = a.lo - b.lo; - uint32_t bor0 = (a.lo < b.lo) ? 1u : 0u; - uint32_t lo = lo1 - bin; - uint32_t bor1 = (lo1 < bin) ? 1u : 0u; - uint32_t hi1 = a.hi - b.hi; - uint32_t bor2 = (a.hi < b.hi) ? 1u : 0u; - uint32_t bor_in_hi = bor0 + bor1; - uint32_t hi = hi1 - bor_in_hi; - uint32_t bor3 = (hi1 < bor_in_hi) ? 1u : 0u; - return U64Borrow{U64{lo, hi}, bor2 + bor3}; -} - -// 32x32 -> 64 multiply, decomposed via 16-bit halves so no intermediate -// product exceeds u32. This is the byte-equivalent of WGSL's -// u64_make_from_u32_mul. 
-inline U64 u32_mul(uint32_t a, uint32_t b) { - uint32_t al = a & 0xffffu; - uint32_t ah = a >> 16; - uint32_t bl = b & 0xffffu; - uint32_t bh = b >> 16; - uint32_t ll = al * bl; - uint32_t lh = al * bh; - uint32_t hl = ah * bl; - uint32_t hh = ah * bh; - uint32_t mid_a = (ll >> 16) + (lh & 0xffffu); - uint32_t mid_b = mid_a + (hl & 0xffffu); - uint32_t mid_carry = (mid_b >> 16); - uint32_t lo = (ll & 0xffffu) | (mid_b << 16); - uint32_t hi = hh + (lh >> 16) + (hl >> 16) + mid_carry; - return U64{lo, hi}; -} - -struct U128 { uint32_t l0, l1, l2, l3; }; - -inline U128 umul64(U64 a, U64 b) { - U64 p_ll = u32_mul(a.lo, b.lo); - U64 p_lh = u32_mul(a.lo, b.hi); - U64 p_hl = u32_mul(a.hi, b.lo); - U64 p_hh = u32_mul(a.hi, b.hi); - uint32_t w0 = p_ll.lo; - uint32_t s1a = p_ll.hi + p_lh.lo; - uint32_t c1a = (s1a < p_ll.hi) ? 1u : 0u; - uint32_t w1 = s1a + p_hl.lo; - uint32_t c1b = (w1 < s1a) ? 1u : 0u; - uint32_t carry1 = c1a + c1b; - uint32_t s2a = p_lh.hi + p_hl.hi; - uint32_t c2a = (s2a < p_lh.hi) ? 1u : 0u; - uint32_t s2b = s2a + p_hh.lo; - uint32_t c2b = (s2b < s2a) ? 1u : 0u; - uint32_t w2 = s2b + carry1; - uint32_t c2c = (w2 < s2b) ? 1u : 0u; - uint32_t carry2 = c2a + c2b + c2c; - uint32_t w3 = p_hh.hi + carry2; - return U128{w0, w1, w2, w3}; -} - -// Low 64 bits of u64 * u64, used for the Montgomery m = t[0]*qInvNeg. -inline U64 u64_mul_low(U64 a, U64 b) { - U64 p_ll = u32_mul(a.lo, b.lo); - U64 p_lh = u32_mul(a.lo, b.hi); - U64 p_hl = u32_mul(a.hi, b.lo); - uint32_t lo = p_ll.lo; - uint32_t hi = p_ll.hi + p_lh.lo + p_hl.lo; - return U64{lo, hi}; -} - -// ============================================================================= -// BN254 Fr modulus + Montgomery params. Same numeric values as the kernel -// constants (declared as u32 lo/hi pairs). 
-// ============================================================================= -constexpr U64 Q0_U = U64{0xf0000001u, 0x43e1f593u}; -constexpr U64 Q1_U = U64{0x79b97091u, 0x2833e848u}; -constexpr U64 Q2_U = U64{0x8181585du, 0xb85045b6u}; -constexpr U64 Q3_U = U64{0xe131a029u, 0x30644e72u}; -constexpr U64 QINV = U64{0xefffffffu, 0xc2e1f593u}; - -constexpr U64 R2_0 = U64{0xae216da7u, 0x1bb8e645u}; -constexpr U64 R2_1 = U64{0xe35c59e3u, 0x53fe3ab1u}; -constexpr U64 R2_2 = U64{0x53bb8085u, 0x8c49833du}; -constexpr U64 R2_3 = U64{0x7f4e44a5u, 0x0216d0b1u}; - -struct Fr { - U64 l0, l1, l2, l3; -}; - -inline Fr fr_zero() { return Fr{u64_zero(), u64_zero(), u64_zero(), u64_zero()}; } -inline Fr fr_q() { return Fr{Q0_U, Q1_U, Q2_U, Q3_U}; } -inline Fr fr_r2() { return Fr{R2_0, R2_1, R2_2, R2_3}; } - -inline int cmp_q(const Fr &a) { - Fr q = fr_q(); - if (!u64_eq(a.l3, q.l3)) return u64_lt(a.l3, q.l3) ? -1 : 1; - if (!u64_eq(a.l2, q.l2)) return u64_lt(a.l2, q.l2) ? -1 : 1; - if (!u64_eq(a.l1, q.l1)) return u64_lt(a.l1, q.l1) ? -1 : 1; - if (!u64_eq(a.l0, q.l0)) return u64_lt(a.l0, q.l0) ? -1 : 1; - return 0; -} - -inline Fr fr_sub_q(const Fr &a) { - Fr q = fr_q(); - auto r0 = u64_sub_borrow(a.l0, q.l0, 0u); - auto r1 = u64_sub_borrow(a.l1, q.l1, r0.borrow); - auto r2 = u64_sub_borrow(a.l2, q.l2, r1.borrow); - auto r3 = u64_sub_borrow(a.l3, q.l3, r2.borrow); - return Fr{r0.v, r1.v, r2.v, r3.v}; -} - -inline Fr reduce_once(const Fr &a) { - return cmp_q(a) >= 0 ? fr_sub_q(a) : a; -} - -inline Fr fr_add(const Fr &a, const Fr &b) { - auto r0 = u64_add_carry(a.l0, b.l0, 0u); - auto r1 = u64_add_carry(a.l1, b.l1, r0.carry); - auto r2 = u64_add_carry(a.l2, b.l2, r1.carry); - auto r3 = u64_add_carry(a.l3, b.l3, r2.carry); - Fr c{r0.v, r1.v, r2.v, r3.v}; - if (r3.carry != 0u || cmp_q(c) >= 0) c = fr_sub_q(c); - return c; -} - -inline Fr fr_double(const Fr &a) { return fr_add(a, a); } - -// CIOS Montgomery multiplication, u32-only. 
-inline Fr fr_mul(const Fr &a, const Fr &b) { - U64 t[5] = {u64_zero(), u64_zero(), u64_zero(), u64_zero(), u64_zero()}; - const U64 al[4] = {a.l0, a.l1, a.l2, a.l3}; - const U64 bl[4] = {b.l0, b.l1, b.l2, b.l3}; - const U64 qq[4] = {Q0_U, Q1_U, Q2_U, Q3_U}; - - for (int i = 0; i < 4; ++i) { - U64 cy = u64_zero(); - for (int j = 0; j < 4; ++j) { - U128 prod = umul64(al[j], bl[i]); - U64 lo = U64{prod.l0, prod.l1}; - U64 hi = U64{prod.l2, prod.l3}; - auto s = u64_add_carry(t[j], lo, 0u); - auto s2 = u64_add_carry(s.v, cy, 0u); - t[j] = s2.v; - // cy = hi + s.carry + s2.carry - auto cy1 = u64_add_carry(hi, U64{s.carry, 0u}, 0u); - auto cy2 = u64_add_carry(cy1.v, U64{s2.carry, 0u}, 0u); - cy = cy2.v; - } - auto t4u = u64_add_carry(t[4], cy, 0u); - t[4] = t4u.v; - - U64 m = u64_mul_low(t[0], QINV); - - cy = u64_zero(); - for (int j = 0; j < 4; ++j) { - U128 prod = umul64(m, qq[j]); - U64 lo = U64{prod.l0, prod.l1}; - U64 hi = U64{prod.l2, prod.l3}; - auto s = u64_add_carry(t[j], lo, 0u); - auto s2 = u64_add_carry(s.v, cy, 0u); - t[j] = s2.v; - auto cy1 = u64_add_carry(hi, U64{s.carry, 0u}, 0u); - auto cy2 = u64_add_carry(cy1.v, U64{s2.carry, 0u}, 0u); - cy = cy2.v; - } - auto t4u2 = u64_add_carry(t[4], cy, 0u); - t[4] = t4u2.v; - - t[0] = t[1]; - t[1] = t[2]; - t[2] = t[3]; - t[3] = t[4]; - t[4] = u64_zero(); - } - Fr c{t[0], t[1], t[2], t[3]}; - return reduce_once(c); -} - -inline Fr fr_square(const Fr &a) { return fr_mul(a, a); } - -inline Fr sbox(const Fr &x) { - Fr x2 = fr_square(x); - Fr x4 = fr_square(x2); - return fr_mul(x4, x); -} - -struct State2 { Fr s0, s1; }; - -inline State2 mat_mul_external(const State2 &s) { - Fr tmp = fr_add(s.s0, s.s1); - return State2{fr_add(s.s0, tmp), fr_add(s.s1, tmp)}; -} -inline State2 mat_mul_internal(const State2 &s) { - Fr sum = fr_add(s.s0, s.s1); - Fr s0p = fr_add(s.s0, sum); - Fr s1d = fr_double(s.s1); - Fr s1p = fr_add(s1d, sum); - return State2{s0p, s1p}; -} - -constexpr int FULL_HALF = 3; -constexpr int PARTIAL = 50; - 
-inline Fr load_rk(int round, int slot) { - return Fr{ - U64{POSEIDON2_RK_LO[round][slot][0], POSEIDON2_RK_HI[round][slot][0]}, - U64{POSEIDON2_RK_LO[round][slot][1], POSEIDON2_RK_HI[round][slot][1]}, - U64{POSEIDON2_RK_LO[round][slot][2], POSEIDON2_RK_HI[round][slot][2]}, - U64{POSEIDON2_RK_LO[round][slot][3], POSEIDON2_RK_HI[round][slot][3]} - }; -} - -inline State2 permute(const State2 &s_in) { - State2 s = mat_mul_external(s_in); - for (int i = 0; i < FULL_HALF; ++i) { - Fr k0 = load_rk(i, 0); - Fr k1 = load_rk(i, 1); - s.s0 = fr_add(s.s0, k0); - s.s1 = fr_add(s.s1, k1); - s.s0 = sbox(s.s0); - s.s1 = sbox(s.s1); - s = mat_mul_external(s); - } - for (int i = 0; i < PARTIAL; ++i) { - Fr k0 = load_rk(FULL_HALF + i, 0); - s.s0 = fr_add(s.s0, k0); - s.s0 = sbox(s.s0); - s = mat_mul_internal(s); - } - for (int i = 0; i < FULL_HALF; ++i) { - Fr k0 = load_rk(FULL_HALF + PARTIAL + i, 0); - Fr k1 = load_rk(FULL_HALF + PARTIAL + i, 1); - s.s0 = fr_add(s.s0, k0); - s.s1 = fr_add(s.s1, k1); - s.s0 = sbox(s.s0); - s.s1 = sbox(s.s1); - s = mat_mul_external(s); - } - return s; -} - -inline U64 read_be64(const unsigned char *p) { - uint32_t hi = ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) | - ((uint32_t)p[2] << 8) | (uint32_t)p[3]; - uint32_t lo = ((uint32_t)p[4] << 24) | ((uint32_t)p[5] << 16) | - ((uint32_t)p[6] << 8) | (uint32_t)p[7]; - return U64{lo, hi}; -} - -inline void write_be64(unsigned char *p, U64 v) { - p[0] = (unsigned char)(v.hi >> 24); - p[1] = (unsigned char)(v.hi >> 16); - p[2] = (unsigned char)(v.hi >> 8); - p[3] = (unsigned char)(v.hi ); - p[4] = (unsigned char)(v.lo >> 24); - p[5] = (unsigned char)(v.lo >> 16); - p[6] = (unsigned char)(v.lo >> 8); - p[7] = (unsigned char)(v.lo ); -} - -inline Fr be_to_fr_mont(const unsigned char *be) { - Fr x{ - read_be64(be + 24), - read_be64(be + 16), - read_be64(be + 8), - read_be64(be + 0) - }; - for (int i = 0; i < 4; ++i) { - if (cmp_q(x) < 0) break; - x = fr_sub_q(x); - } - return fr_mul(x, fr_r2()); -} - -inline 
void fr_mont_to_be(unsigned char *be, const Fr &x) { - Fr one_reg{U64{1u, 0u}, u64_zero(), u64_zero(), u64_zero()}; - Fr r = fr_mul(x, one_reg); - write_be64(be + 0, r.l3); - write_be64(be + 8, r.l2); - write_be64(be + 16, r.l1); - write_be64(be + 24, r.l0); -} - -} // namespace - -extern "C" int poseidon2_hash2_wgsl_batch(const unsigned char *pairs, - unsigned char *outs, - unsigned long n) { - if (n == 0) return 0; - if (!pairs || !outs) return -1; - for (unsigned long i = 0; i < n; ++i) { - State2 s{ - be_to_fr_mont(pairs + i * 64), - be_to_fr_mont(pairs + i * 64 + 32) - }; - Fr saved_right = s.s1; - s = permute(s); - Fr digest = fr_add(saved_right, s.s1); - fr_mont_to_be(outs + i * 32, digest); - } - return 0; -} diff --git a/poseidon/gpu/wgsl/poseidon2_driver.h b/poseidon/gpu/wgsl/poseidon2_driver.h deleted file mode 100644 index b7dffe7..0000000 --- a/poseidon/gpu/wgsl/poseidon2_driver.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL driver for Poseidon2-BN254. Runs the algorithm via a C++ host polyfill -// that emulates the kernel's u32-only arithmetic (WGSL has no native u64). -// On hosts with wgpu-native, this same driver dispatches the .wgsl shader on -// the GPU; on hosts without wgpu-native, the polyfill produces byte-equal -// output by construction (the Montgomery scalar arithmetic is u32-only). - -#ifndef LUX_POSEIDON2_WGSL_DRIVER_H -#define LUX_POSEIDON2_WGSL_DRIVER_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Run n Poseidon2.Compress calls in one WGSL dispatch (host polyfill loops -// the same per-thread body using u32-only arithmetic, byte-equal to the GPU -// kernel by construction). -// -// pairs : n * 64 bytes, [BE(left_i) || BE(right_i)] for i in 0..n -// outs : n * 32 bytes, BE digest -// n : number of pairs -// -// Returns 0 on success, -1 on invalid arg. 
-int poseidon2_hash2_wgsl_batch(const unsigned char *pairs, - unsigned char *outs, - unsigned long n); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_POSEIDON2_WGSL_DRIVER_H diff --git a/ringtail/gpu/cuda/ringtail.cu b/ringtail/gpu/cuda/ringtail.cu deleted file mode 100644 index baa1cdc..0000000 --- a/ringtail/gpu/cuda/ringtail.cu +++ /dev/null @@ -1,240 +0,0 @@ -// Ringtail lattice-based threshold signatures -- CUDA implementation -// Matches ringtail.metal output byte-for-byte -// One thread per partial sign / combine operation - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// Ringtail parameters (same ring as ML-DSA) -// ============================================================================= - -#define RT_Q 8380417 - -// ============================================================================= -// Modular arithmetic -// ============================================================================= - -__device__ static int32_t rt_reduce(int32_t a) { - int32_t t = (int32_t)((int64_t)a * 33554687LL >> 48); - int32_t r = a - t * RT_Q; - if (r < 0) r += RT_Q; - if (r >= RT_Q) r -= RT_Q; - return r; -} - -__device__ static int32_t rt_mont_reduce(int64_t a) { - const int32_t q_inv = 58728449; - int32_t t = (int32_t)a * q_inv; - int64_t u = (int64_t)t * RT_Q; - int32_t r = (int32_t)((a - u) >> 32); - if (r < 0) r += RT_Q; - return r; -} - -// ============================================================================= -// NTT (same as ML-DSA, q=8380417, n=256) -// ============================================================================= - -__device__ static const int32_t RT_ZETAS[128] = { - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 
1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561,-3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733,-2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488,-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -}; - -__device__ static void rt_ntt_bf(int32_t& a, int32_t& b, int32_t z) { - int32_t t = rt_mont_reduce((int64_t)z * b); - b = a - t; a = a + t; - if (a >= RT_Q) a -= RT_Q; - if (b < 0) b += RT_Q; -} - -__device__ static void rt_inv_ntt_bf(int32_t& a, int32_t& b, int32_t z) { - int32_t t = a; - a = t + b; b = t - b; - if (a >= RT_Q) a -= RT_Q; - if (b < 0) b += RT_Q; - b = rt_mont_reduce((int64_t)z * b); -} - -__device__ static void rt_ntt(int32_t poly[256]) { - int k = 0; - for (int len = 128; len >= 1; len >>= 1) - for (int start = 0; start < 256; start += 2 * len) { - int32_t z = RT_ZETAS[++k]; - for (int j = start; j < start + len; j++) - rt_ntt_bf(poly[j], poly[j + len], z); - } -} - -__device__ static void rt_inv_ntt(int32_t poly[256]) { - const int32_t f = 41978; - int k = 127; - for (int len = 1; len <= 128; len <<= 1) - for (int start = 0; start < 256; start += 2 * len) { - int32_t z = -RT_ZETAS[k--]; - if (z < 0) z 
+= RT_Q; - for (int j = start; j < start + len; j++) - rt_inv_ntt_bf(poly[j], poly[j + len], z); - } - for (int i = 0; i < 256; i++) - poly[i] = rt_mont_reduce((int64_t)f * poly[i]); -} - -// Pointwise multiply: c[i] = a[i] * b[i] mod q (NTT domain) -__device__ static void rt_poly_mul_ntt(int32_t c[256], - const int32_t a[256], - const int32_t b[256]) { - for (int i = 0; i < 256; i++) - c[i] = rt_mont_reduce((int64_t)a[i] * b[i]); -} - -// Polynomial add: c[i] = a[i] + b[i] mod q -__device__ static void rt_poly_add(int32_t c[256], - const int32_t a[256], - const int32_t b[256]) { - for (int i = 0; i < 256; i++) - c[i] = rt_reduce(a[i] + b[i]); -} - -// ============================================================================= -// Ringtail structures -// ============================================================================= - -struct RingtailShare { - uint8_t data[1024]; // 256 int32_t coefficients -}; - -struct RingtailMessage { - uint8_t data[32]; -}; - -struct RingtailPartialSig { - uint8_t data[1024]; -}; - -// ============================================================================= -// Partial signing kernel -// ============================================================================= - -extern "C" __global__ void ringtail_partial_sign_batch( - const RingtailShare* __restrict__ shares, - const RingtailMessage* __restrict__ messages, - RingtailPartialSig* __restrict__ partial_sigs, - const uint32_t* __restrict__ num_ops_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_ops = *num_ops_ptr; - if (tid >= num_ops) return; - - // Load share polynomial - int32_t share[256]; - const uint8_t* sp = shares[tid].data; - for (int i = 0; i < 256; i++) { - share[i] = (int32_t)sp[i * 4] - | ((int32_t)sp[i * 4 + 1] << 8) - | ((int32_t)sp[i * 4 + 2] << 16) - | ((int32_t)sp[i * 4 + 3] << 24); - } - - // Derive challenge polynomial from message hash - const uint8_t* msg = messages[tid].data; - int32_t challenge[256]; - for (int i = 
0; i < 256; i++) { - uint32_t idx = (i * 4) % 32; - uint32_t val = (uint32_t)msg[idx] - | ((uint32_t)msg[(idx + 1) % 32] << 8) - | ((uint32_t)msg[(idx + 2) % 32] << 16) - | ((uint32_t)msg[(idx + 3) % 32] << 24); - val ^= (uint32_t)(i * 2654435761u); - challenge[i] = (int32_t)(val % (uint32_t)RT_Q); - } - - // NTT of challenge - rt_ntt(challenge); - - // NTT of share - rt_ntt(share); - - // Pointwise multiply - int32_t result[256]; - rt_poly_mul_ntt(result, share, challenge); - - // Inverse NTT - rt_inv_ntt(result); - - // Write partial signature - uint8_t* out = partial_sigs[tid].data; - for (int i = 0; i < 256; i++) { - uint32_t v = (uint32_t)result[i]; - out[i * 4] = (uint8_t)(v & 0xFF); - out[i * 4 + 1] = (uint8_t)((v >> 8) & 0xFF); - out[i * 4 + 2] = (uint8_t)((v >> 16) & 0xFF); - out[i * 4 + 3] = (uint8_t)((v >> 24) & 0xFF); - } -} - -// ============================================================================= -// Combine kernel -// ============================================================================= - -extern "C" __global__ void ringtail_combine_batch( - const RingtailPartialSig* __restrict__ partial_sigs, // [num_ops * threshold] - const int32_t* __restrict__ lagrange_coeffs, // [num_ops * threshold] - RingtailPartialSig* __restrict__ combined_sigs, // [num_ops] - const uint32_t* __restrict__ threshold_ptr, - const uint32_t* __restrict__ num_ops_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_ops = *num_ops_ptr; - uint32_t threshold = *threshold_ptr; - if (tid >= num_ops) return; - - int32_t combined[256]; - for (int i = 0; i < 256; i++) combined[i] = 0; - - for (uint32_t s = 0; s < threshold; s++) { - const uint8_t* ps = partial_sigs[tid * threshold + s].data; - int32_t lambda = lagrange_coeffs[tid * threshold + s]; - - for (int i = 0; i < 256; i++) { - int32_t coeff = (int32_t)ps[i * 4] - | ((int32_t)ps[i * 4 + 1] << 8) - | ((int32_t)ps[i * 4 + 2] << 16) - | ((int32_t)ps[i * 4 + 3] << 24); - - int64_t prod = 
(int64_t)lambda * coeff; - combined[i] = rt_reduce(combined[i] + rt_mont_reduce(prod)); - } - } - - // Write combined signature - uint8_t* out = combined_sigs[tid].data; - for (int i = 0; i < 256; i++) { - uint32_t v = (uint32_t)combined[i]; - out[i * 4] = (uint8_t)(v & 0xFF); - out[i * 4 + 1] = (uint8_t)((v >> 8) & 0xFF); - out[i * 4 + 2] = (uint8_t)((v >> 16) & 0xFF); - out[i * 4 + 3] = (uint8_t)((v >> 24) & 0xFF); - } -} diff --git a/ringtail/gpu/metal/ringtail.metal b/ringtail/gpu/metal/ringtail.metal deleted file mode 100644 index bcdbe54..0000000 --- a/ringtail/gpu/metal/ringtail.metal +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file ringtail.metal -/// Metal compute shader for Ringtail lattice-based threshold signatures. -/// -/// Ringtail is a Lux-specific lattice-based threshold signature scheme -/// operating over the same polynomial ring as ML-DSA: Z_q[x]/(x^n + 1) -/// with q=8380417, n=256. -/// -/// Threshold protocol: k-of-n signers produce partial signatures, -/// which are combined into one valid signature. -/// -/// Operations: -/// - ringtail_partial_sign_batch: compute partial signature from share -/// - ringtail_combine_batch: combine k partial sigs into one -/// -/// GPU advantage: NTT-based polynomial multiplication is the hot path. 
- -#include -using namespace metal; - -// ============================================================================= -// Ringtail parameters (same ring as ML-DSA) -// ============================================================================= - -constant int32_t RT_Q = 8380417; - -// ============================================================================= -// Modular arithmetic -// ============================================================================= - -inline int32_t rt_reduce(int32_t a) { - int32_t t = (int32_t)((int64_t)a * 33554687 >> 48); - int32_t r = a - t * RT_Q; - if (r < 0) r += RT_Q; - if (r >= RT_Q) r -= RT_Q; - return r; -} - -inline int32_t rt_mont_reduce(int64_t a) { - const int32_t q_inv = 58728449; - int32_t t = (int32_t)a * q_inv; - int64_t u = (int64_t)t * RT_Q; - int32_t r = (int32_t)((a - u) >> 32); - if (r < 0) r += RT_Q; - return r; -} - -// ============================================================================= -// NTT (same as ML-DSA, q=8380417, n=256) -// ============================================================================= - -constant int32_t RT_ZETAS[128] = { - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561,-3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733,-2884855, 
3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488,-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -}; - -inline void rt_ntt_bf(thread int32_t& a, thread int32_t& b, int32_t z) { - int32_t t = rt_mont_reduce((int64_t)z * b); - b = a - t; a = a + t; - if (a >= RT_Q) a -= RT_Q; - if (b < 0) b += RT_Q; -} - -inline void rt_inv_ntt_bf(thread int32_t& a, thread int32_t& b, int32_t z) { - int32_t t = a; - a = t + b; b = t - b; - if (a >= RT_Q) a -= RT_Q; - if (b < 0) b += RT_Q; - b = rt_mont_reduce((int64_t)z * b); -} - -inline void rt_ntt(thread int32_t poly[256]) { - int k = 0; - for (int len = 128; len >= 1; len >>= 1) - for (int start = 0; start < 256; start += 2 * len) { - int32_t z = RT_ZETAS[++k]; - for (int j = start; j < start + len; j++) - rt_ntt_bf(poly[j], poly[j + len], z); - } -} - -inline void rt_inv_ntt(thread int32_t poly[256]) { - const int32_t f = 41978; - int k = 127; - for (int len = 1; len <= 128; len <<= 1) - for (int start = 0; start < 256; start += 2 * len) { - int32_t z = -RT_ZETAS[k--]; - if (z < 0) z += RT_Q; - for (int j = start; j < start + len; j++) - rt_inv_ntt_bf(poly[j], poly[j + len], z); - } - for (int i = 0; i < 256; i++) - poly[i] = rt_mont_reduce((int64_t)f * poly[i]); -} - -/// Pointwise multiply: c[i] = a[i] * b[i] mod q (NTT domain) -inline void rt_poly_mul_ntt(thread int32_t c[256], - thread const int32_t a[256], - thread const int32_t b[256]) { - for (int i = 0; i < 256; i++) - c[i] = rt_mont_reduce((int64_t)a[i] * b[i]); -} - -/// Polynomial add: c[i] = a[i] + b[i] mod q -inline void rt_poly_add(thread int32_t c[256], - thread const int32_t a[256], - thread const int32_t b[256]) { - for (int i = 0; i < 256; i++) - c[i] = rt_reduce(a[i] + b[i]); -} - -// 
============================================================================= -// Ringtail structures -// ============================================================================= - -/// Secret share: one polynomial in NTT domain (256 * 4 bytes) -struct RingtailShare { - uchar data[1024]; // 256 int32_t coefficients -}; - -/// Message hash (32 bytes, pre-hashed by host) -struct RingtailMessage { - uchar data[32]; -}; - -/// Partial signature: one polynomial (256 * 4 bytes) -struct RingtailPartialSig { - uchar data[1024]; -}; - -// ============================================================================= -// Partial signing kernel -// ============================================================================= - -/// Compute partial signature from secret share and message. -/// partial_sig = NTT^{-1}(NTT(share) * NTT(c)) + mask -/// where c is the challenge polynomial derived from the message hash. -/// -/// Each thread produces one partial signature. -kernel void ringtail_partial_sign_batch( - device const RingtailShare* shares [[buffer(0)]], - device const RingtailMessage* messages [[buffer(1)]], - device RingtailPartialSig* partial_sigs [[buffer(2)]], - constant uint& num_ops [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_ops) return; - - // Load share polynomial - int32_t share[256]; - device const uchar* sp = shares[tid].data; - for (int i = 0; i < 256; i++) { - share[i] = int32_t(sp[i * 4]) - | (int32_t(sp[i * 4 + 1]) << 8) - | (int32_t(sp[i * 4 + 2]) << 16) - | (int32_t(sp[i * 4 + 3]) << 24); - } - - // Derive challenge polynomial from message hash - // Simple: expand 32 bytes to 256 coefficients via rejection sampling - device const uchar* msg = messages[tid].data; - int32_t challenge[256]; - for (int i = 0; i < 256; i++) { - // Deterministic expansion: take bytes and reduce mod q - uint idx = (i * 4) % 32; - uint32_t val = uint32_t(msg[idx]) - | (uint32_t(msg[(idx + 1) % 32]) << 8) - | (uint32_t(msg[(idx + 2) % 32]) << 
16) - | (uint32_t(msg[(idx + 3) % 32]) << 24); - // Mix with index for uniqueness - val ^= uint32_t(i * 2654435761u); - challenge[i] = int32_t(val % uint32_t(RT_Q)); - } - - // NTT of challenge - rt_ntt(challenge); - - // NTT of share (already in NTT domain if stored that way, but we NTT anyway) - rt_ntt(share); - - // Pointwise multiply - int32_t result[256]; - rt_poly_mul_ntt(result, share, challenge); - - // Inverse NTT - rt_inv_ntt(result); - - // Write partial signature - device uchar* out = partial_sigs[tid].data; - for (int i = 0; i < 256; i++) { - uint32_t v = uint32_t(result[i]); - out[i * 4] = uchar(v & 0xFF); - out[i * 4 + 1] = uchar((v >> 8) & 0xFF); - out[i * 4 + 2] = uchar((v >> 16) & 0xFF); - out[i * 4 + 3] = uchar((v >> 24) & 0xFF); - } -} - -// ============================================================================= -// Combine kernel -// ============================================================================= - -/// Combine k partial signatures into one via Lagrange interpolation. -/// combined = sum_{i=0}^{k-1} lambda_i * partial_sig_i (mod q) -/// -/// Each thread combines one set of k partial signatures. -/// The Lagrange coefficients are pre-computed by the host. 
-kernel void ringtail_combine_batch( - device const RingtailPartialSig* partial_sigs [[buffer(0)]], // [num_ops * threshold] - device const int32_t* lagrange_coeffs [[buffer(1)]], // [num_ops * threshold] - device RingtailPartialSig* combined_sigs [[buffer(2)]], // [num_ops] - constant uint& threshold [[buffer(3)]], - constant uint& num_ops [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_ops) return; - - int32_t combined[256] = {}; - - for (uint s = 0; s < threshold; s++) { - // Load partial signature - device const uchar* ps = partial_sigs[tid * threshold + s].data; - int32_t lambda = lagrange_coeffs[tid * threshold + s]; - - for (int i = 0; i < 256; i++) { - int32_t coeff = int32_t(ps[i * 4]) - | (int32_t(ps[i * 4 + 1]) << 8) - | (int32_t(ps[i * 4 + 2]) << 16) - | (int32_t(ps[i * 4 + 3]) << 24); - - // combined[i] += lambda * coeff mod q - int64_t prod = (int64_t)lambda * coeff; - combined[i] = rt_reduce(combined[i] + rt_mont_reduce(prod)); - } - } - - // Write combined signature - device uchar* out = combined_sigs[tid].data; - for (int i = 0; i < 256; i++) { - uint32_t v = uint32_t(combined[i]); - out[i * 4] = uchar(v & 0xFF); - out[i * 4 + 1] = uchar((v >> 8) & 0xFF); - out[i * 4 + 2] = uchar((v >> 16) & 0xFF); - out[i * 4 + 3] = uchar((v >> 24) & 0xFF); - } -} diff --git a/ringtail/gpu/metal/ringtail_ops.metal b/ringtail/gpu/metal/ringtail_ops.metal deleted file mode 100644 index 042f395..0000000 --- a/ringtail/gpu/metal/ringtail_ops.metal +++ /dev/null @@ -1,378 +0,0 @@ -// Copyright (c) 2024-2026 Lux Partners Limited -// SPDX-License-Identifier: BSD-3-Clause -// -// Ringtail Lattice Threshold Operations - Metal Implementation -// Post-quantum threshold signatures based on Module-LWE (MLWE). 
- -#include -using namespace metal; - -// ============================================================================ -// Ringtail Parameters (Dilithium-like construction) -// ============================================================================ - -constant uint Q = 8380417; // Modulus 2^23 - 2^13 + 1 -constant uint N = 256; // Ring dimension -constant uint K = 4; // Module rank for public key -constant uint L = 4; // Module rank for secret -constant int GAMMA1 = 131072; // Commitment bound (2^17) -constant int GAMMA2 = 95232; // Low bits rounding -constant uint QINV = 58728449; // q^-1 mod 2^32 - -// ============================================================================ -// Modular Arithmetic -// ============================================================================ - -inline int mont_reduce(long a) { - int t = int(uint(a) * QINV); - return int((a - long(t) * Q) >> 32); -} - -inline int mod_add(int a, int b) { - int r = a + b; - if (r >= int(Q)) r -= Q; - if (r < 0) r += Q; - return r; -} - -inline int mod_sub(int a, int b) { - int r = a - b; - if (r < 0) r += Q; - return r; -} - -inline int caddq(int a) { - return a + ((a >> 31) & int(Q)); -} - -inline int freeze(int a) { - a = caddq(a); - return a - int(Q) + ((int(Q) - 1 - a) >> 31 & int(Q)); -} - -// ============================================================================ -// Polynomial Structures -// ============================================================================ - -struct Poly { - int coeffs[256]; -}; - -struct PolyVec { - Poly polys[4]; // L polynomials -}; - -struct ThresholdShare { - uint index; - PolyVec s_share; - PolyVec y_share; -}; - -// ============================================================================ -// High/Low Bits Decomposition -// ============================================================================ - -inline int highbits(int r, int alpha) { - r = freeze(r); - int r1 = (r + (alpha >> 1)) / alpha; - return r1; -} - -inline int lowbits(int r, 
int alpha) { - int r1 = highbits(r, alpha); - return r - r1 * alpha; -} - -// ============================================================================ -// Lagrange Interpolation -// ============================================================================ - -// Compute modular inverse using extended Euclidean algorithm -inline int mod_inv_int(int a, uint q) { - int t = 0, new_t = 1; - int r = int(q), new_r = a; - - while (new_r != 0) { - int quotient = r / new_r; - int temp_t = t - quotient * new_t; - t = new_t; - new_t = temp_t; - - int temp_r = r - quotient * new_r; - r = new_r; - new_r = temp_r; - } - - if (t < 0) t += int(q); - return t; -} - -// Compute Lagrange coefficient at x=0 -inline int compute_lagrange_coeff( - uint index, - device const uint* indices, - uint num_shares -) { - long numerator = 1; - long denominator = 1; - - for (uint j = 0; j < num_shares; j++) { - if (indices[j] == index) continue; - - numerator = (numerator * long(indices[j])) % Q; - long diff = long(indices[j]) - long(index); - if (diff < 0) diff += Q; - denominator = (denominator * diff) % Q; - } - - int inv = mod_inv_int(int(denominator), Q); - return int((numerator * inv) % Q); -} - -// ============================================================================ -// Share Combination Kernel -// ============================================================================ - -kernel void combine_shares( - device const int* shares_s [[buffer(0)]], // [num_shares][L][N] - device const int* shares_y [[buffer(1)]], // [num_shares][L][N] - device const uint* share_indices [[buffer(2)]], - device const uint* participant_indices [[buffer(3)]], - device int* combined_s [[buffer(4)]], // [L][N] - device int* combined_y [[buffer(5)]], // [L][N] - constant uint& num_shares [[buffer(6)]], - uint2 gid [[thread_position_in_grid]] // (coeff_idx, poly_idx) -) { - uint coeff_idx = gid.x; - uint poly_idx = gid.y; - - if (coeff_idx >= N || poly_idx >= L) return; - - long s_sum = 0; - long y_sum 
= 0; - - for (uint i = 0; i < num_shares; i++) { - int lambda = compute_lagrange_coeff(participant_indices[i], participant_indices, num_shares); - - uint offset = i * L * N + poly_idx * N + coeff_idx; - long s_val = shares_s[offset]; - long y_val = shares_y[offset]; - - s_sum = (s_sum + (s_val * lambda) % Q + Q) % Q; - y_sum = (y_sum + (y_val * lambda) % Q + Q) % Q; - } - - uint out_offset = poly_idx * N + coeff_idx; - combined_s[out_offset] = int(s_sum); - combined_y[out_offset] = int(y_sum); -} - -// ============================================================================ -// Commitment Computation (A*y in NTT domain) -// ============================================================================ - -kernel void compute_commitment( - device const int* A [[buffer(0)]], // [K][L][N] in NTT domain - device const int* y [[buffer(1)]], // [L][N] - device int* w [[buffer(2)]], // [K][N] - uint2 gid [[thread_position_in_grid]] // (coeff_idx, row) -) { - uint coeff_idx = gid.x; - uint row = gid.y; - - if (coeff_idx >= N || row >= K) return; - - long sum = 0; - - for (uint col = 0; col < L; col++) { - int a_val = A[row * L * N + col * N + coeff_idx]; - int y_val = y[col * N + coeff_idx]; - long prod = long(a_val) * y_val; - sum += mont_reduce(prod); - } - - w[row * N + coeff_idx] = freeze(int(sum % Q)); -} - -// ============================================================================ -// Response Bounds Check -// ============================================================================ - -kernel void check_response_bounds( - device const int* z [[buffer(0)]], // [L][N] - device atomic_uint* valid [[buffer(1)]], - constant int& gamma1_minus_beta [[buffer(2)]], - uint2 gid [[thread_position_in_grid]] -) { - uint coeff_idx = gid.x; - uint poly_idx = gid.y; - - if (coeff_idx >= N || poly_idx >= L) return; - - int val = z[poly_idx * N + coeff_idx]; - val = freeze(val); - - // Check |z| < gamma1 - beta - if (val > gamma1_minus_beta && val < int(Q) - 
gamma1_minus_beta) { - atomic_store_explicit(valid, 0u, memory_order_relaxed); - } -} - -// ============================================================================ -// Hint Generation -// ============================================================================ - -kernel void make_hint( - device const int* r [[buffer(0)]], // [K][N] - device const int* z [[buffer(1)]], // [K][N] - device uchar* hint [[buffer(2)]], // [K][N] - device atomic_uint* hint_count [[buffer(3)]], - uint2 gid [[thread_position_in_grid]] -) { - uint coeff_idx = gid.x; - uint poly_idx = gid.y; - - if (coeff_idx >= N || poly_idx >= K) return; - - uint idx = poly_idx * N + coeff_idx; - - int r_val = r[idx]; - int z_val = z[idx]; - - int r_high = highbits(r_val, 2 * GAMMA2); - int rz_high = highbits(mod_add(r_val, z_val), 2 * GAMMA2); - - if (r_high != rz_high) { - hint[idx] = 1; - atomic_fetch_add_explicit(hint_count, 1u, memory_order_relaxed); - } else { - hint[idx] = 0; - } -} - -// ============================================================================ -// Use Hint for Verification -// ============================================================================ - -kernel void use_hint( - device const int* r0 [[buffer(0)]], - device const int* r1 [[buffer(1)]], - device const uchar* hint [[buffer(2)]], - device int* recovered [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= K * N) return; - - int r0_val = r0[gid]; - int r1_val = r1[gid]; - uchar h = hint[gid]; - - if (h == 0) { - recovered[gid] = r1_val; - } else { - int max_high = (int(Q) - 1) / (2 * GAMMA2) + 1; - if (r0_val > 0) { - recovered[gid] = (r1_val + 1) % max_high; - } else { - recovered[gid] = (r1_val + max_high - 1) % max_high; - } - } -} - -// ============================================================================ -// Batch Share Combination -// ============================================================================ - -kernel void batch_combine_shares( - device const int* 
all_shares_s [[buffer(0)]], // [batch][num_shares][L][N] - device const int* all_shares_y [[buffer(1)]], - device const uint* all_indices [[buffer(2)]], // [batch][num_shares] - device int* combined_s [[buffer(3)]], // [batch][L][N] - device int* combined_y [[buffer(4)]], - constant uint& num_shares [[buffer(5)]], - constant uint& batch_size [[buffer(6)]], - uint3 gid [[thread_position_in_grid]] // (coeff, poly, batch) -) { - uint coeff_idx = gid.x; - uint poly_idx = gid.y; - uint batch_idx = gid.z; - - if (coeff_idx >= N || poly_idx >= L || batch_idx >= batch_size) return; - - device const uint* indices = all_indices + batch_idx * num_shares; - - long s_sum = 0; - long y_sum = 0; - - for (uint i = 0; i < num_shares; i++) { - int lambda = compute_lagrange_coeff(indices[i], indices, num_shares); - - uint offset = batch_idx * num_shares * L * N + i * L * N + poly_idx * N + coeff_idx; - long s_val = all_shares_s[offset]; - long y_val = all_shares_y[offset]; - - s_sum = (s_sum + (s_val * lambda) % Q + Q) % Q; - y_sum = (y_sum + (y_val * lambda) % Q + Q) % Q; - } - - uint out_offset = batch_idx * L * N + poly_idx * N + coeff_idx; - combined_s[out_offset] = int(s_sum); - combined_y[out_offset] = int(y_sum); -} - -// ============================================================================ -// Power2Round Decomposition -// ============================================================================ - -kernel void power2round( - device const int* r [[buffer(0)]], - device int* r1 [[buffer(1)]], - device int* r0 [[buffer(2)]], - constant uint& d [[buffer(3)]], - constant uint& n [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= n) return; - - int val = r[gid]; - int half = 1 << (d - 1); - - r1[gid] = (val + half - 1) >> d; - r0[gid] = val - (r1[gid] << d); -} - -// ============================================================================ -// Polynomial Sampling with Bounded Coefficients -// 
============================================================================ - -// Simple hash-based random for sampling -inline uint pcg_hash(uint input) { - uint state = input * 747796405u + 2891336453u; - uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -kernel void sample_poly_eta( - device int* poly [[buffer(0)]], - constant uint& seed [[buffer(1)]], - constant uint& eta [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= N) return; - - uint rng = pcg_hash(seed ^ gid); - int sample = int(rng % (2 * eta + 1)) - int(eta); - poly[gid] = sample; -} - -kernel void sample_poly_gamma1( - device int* poly [[buffer(0)]], - constant uint& seed [[buffer(1)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= N) return; - - uint rng = pcg_hash(seed ^ gid); - int r = int(rng % (2 * uint(GAMMA1) + 1)); - poly[gid] = GAMMA1 - r; -} diff --git a/ringtail/gpu/metal/ringtail_sign.metal b/ringtail/gpu/metal/ringtail_sign.metal deleted file mode 100644 index 38148fc..0000000 --- a/ringtail/gpu/metal/ringtail_sign.metal +++ /dev/null @@ -1,568 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Ringtail Lattice-Based Threshold Signatures -// GPU-accelerated MLWE-based threshold signing operations -// Optimized for Apple Silicon GPUs -// -// Parameters from Ringtail specification: -// - Ring dimension N = 256 (power of 2 for NTT) -// - Modulus Q = 8380417 (23-bit prime, NTT-friendly) -// - Vector dimensions: M = 8, N_vec = 7 -// - Gaussian parameter sigma for rejection sampling - -#include -using namespace metal; - -// ============================================================================ -// Ringtail Parameters -// ============================================================================ - -// Ring parameters -constant uint RING_N = 256; // Polynomial degree -constant ulong RING_Q = 8380417UL; // Modulus (23-bit prime) -constant ulong RING_Q_INV = 58728449UL; // Montgomery inverse -constant uint LOG_N = 8; // log2(RING_N) - -// Vector dimensions for signature scheme -constant uint VEC_M = 8; // Public key rows -constant uint VEC_N = 7; // Secret key / signature dimension - -// Security parameters -constant int REJECTION_BOUND = 1 << 18; // Rejection sampling bound -constant float SIGMA = 1.55f; // Gaussian standard deviation - -// NTT parameters (primitive root of unity for Q) -constant ulong OMEGA = 1753UL; // Primitive 512th root of unity mod Q -constant ulong OMEGA_INV = 731434UL; // Inverse of OMEGA mod Q -constant ulong N_INV = 8347649UL; // Inverse of N=256 mod Q - -// ============================================================================ -// Data Types -// ============================================================================ - -// Single polynomial coefficient (fits in 32 bits for Q < 2^24) -typedef uint Coeff; - -// Polynomial in ring Z_Q[X]/(X^N + 1) -struct Poly { - Coeff coeffs[RING_N]; -}; - -// Vector of M polynomials -struct PolyVecM { - Poly polys[VEC_M]; -}; - -// Vector of N polynomials -struct PolyVecN { - Poly polys[VEC_N]; -}; - -// Matrix M x N of polynomials 
-struct PolyMatrix { - Poly polys[VEC_M * VEC_N]; -}; - -// Threshold signature share -struct SignatureShare { - uint participant_id; - Poly c; // Challenge polynomial - PolyVecN z; // Response vector - PolyVecM Delta; // Rounding component -}; - -// Ringtail parameters for kernel dispatch -struct RingtailParams { - uint num_participants; - uint threshold; - uint batch_size; - uint ntt_stage; // For staged NTT -}; - -// ============================================================================ -// Modular Arithmetic -// ============================================================================ - -// Montgomery reduction: returns (a * R^-1) mod Q where R = 2^32 -inline Coeff mont_reduce(ulong a) { - ulong t = (a * RING_Q_INV) & 0xFFFFFFFFUL; - ulong u = a + t * RING_Q; - Coeff result = (Coeff)(u >> 32); - return (result >= RING_Q) ? result - RING_Q : result; -} - -// Modular addition -inline Coeff mod_add(Coeff a, Coeff b) { - Coeff sum = a + b; - return (sum >= RING_Q) ? sum - (Coeff)RING_Q : sum; -} - -// Modular subtraction -inline Coeff mod_sub(Coeff a, Coeff b) { - return (a >= b) ? a - b : a + (Coeff)RING_Q - b; -} - -// Modular multiplication with Montgomery -inline Coeff mod_mul(Coeff a, Coeff b) { - return mont_reduce((ulong)a * (ulong)b); -} - -// Modular negation -inline Coeff mod_neg(Coeff a) { - return (a == 0) ? 0 : (Coeff)RING_Q - a; -} - -// Center reduction: map [0, Q) to [-(Q-1)/2, (Q-1)/2] -inline int center_reduce(Coeff a) { - int t = (int)a; - int half_q = (int)(RING_Q >> 1); - return (t > half_q) ? 
t - (int)RING_Q : t; -} - -// ============================================================================ -// NTT Operations -// ============================================================================ - -// Precomputed twiddle factors would normally be in a buffer -// For simplicity, compute on-the-fly (production code uses lookup table) - -inline Coeff power_of_omega(uint k) { - // Compute OMEGA^k mod Q using binary exponentiation - ulong result = 1; - ulong base = OMEGA; - while (k > 0) { - if (k & 1) { - result = (result * base) % RING_Q; - } - base = (base * base) % RING_Q; - k >>= 1; - } - return (Coeff)result; -} - -// Forward NTT (Cooley-Tukey, bit-reversal input) -inline void ntt_forward_inplace(thread Poly* p, device const Coeff* twiddles) { - uint n = RING_N; - - for (uint len = 1; len < n; len <<= 1) { - for (uint i = 0; i < n; i += 2 * len) { - for (uint j = 0; j < len; j++) { - Coeff w = twiddles[len + j]; - Coeff u = p->coeffs[i + j]; - Coeff v = mod_mul(p->coeffs[i + j + len], w); - p->coeffs[i + j] = mod_add(u, v); - p->coeffs[i + j + len] = mod_sub(u, v); - } - } - } -} - -// Inverse NTT (Gentleman-Sande, bit-reversal output) -inline void ntt_inverse_inplace(thread Poly* p, device const Coeff* inv_twiddles) { - uint n = RING_N; - - for (uint len = n >> 1; len > 0; len >>= 1) { - for (uint i = 0; i < n; i += 2 * len) { - for (uint j = 0; j < len; j++) { - Coeff w = inv_twiddles[len + j]; - Coeff u = p->coeffs[i + j]; - Coeff v = p->coeffs[i + j + len]; - p->coeffs[i + j] = mod_add(u, v); - p->coeffs[i + j + len] = mod_mul(mod_sub(u, v), w); - } - } - } - - // Scale by N^-1 - Coeff n_inv = (Coeff)N_INV; - for (uint i = 0; i < n; i++) { - p->coeffs[i] = mod_mul(p->coeffs[i], n_inv); - } -} - -// Pointwise multiplication in NTT domain -inline Poly poly_mul_ntt(Poly a, Poly b) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_mul(a.coeffs[i], b.coeffs[i]); - } - return c; -} - -// 
============================================================================ -// Polynomial Arithmetic -// ============================================================================ - -inline Poly poly_zero() { - Poly p; - for (uint i = 0; i < RING_N; i++) { - p.coeffs[i] = 0; - } - return p; -} - -inline Poly poly_add(Poly a, Poly b) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_add(a.coeffs[i], b.coeffs[i]); - } - return c; -} - -inline Poly poly_sub(Poly a, Poly b) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_sub(a.coeffs[i], b.coeffs[i]); - } - return c; -} - -inline Poly poly_neg(Poly a) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_neg(a.coeffs[i]); - } - return c; -} - -// Scalar multiplication -inline Poly poly_scalar_mul(Poly a, Coeff s) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_mul(a.coeffs[i], s); - } - return c; -} - -// ============================================================================ -// Vector and Matrix Operations -// ============================================================================ - -inline PolyVecN vec_n_add(PolyVecN a, PolyVecN b) { - PolyVecN c; - for (uint i = 0; i < VEC_N; i++) { - c.polys[i] = poly_add(a.polys[i], b.polys[i]); - } - return c; -} - -inline PolyVecM vec_m_add(PolyVecM a, PolyVecM b) { - PolyVecM c; - for (uint i = 0; i < VEC_M; i++) { - c.polys[i] = poly_add(a.polys[i], b.polys[i]); - } - return c; -} - -// Matrix-vector multiplication: A (M x N) * v (N) = result (M) -// Assumes inputs are in NTT domain -inline PolyVecM matrix_vec_mul_ntt(device const PolyMatrix* A, PolyVecN v) { - PolyVecM result; - - for (uint i = 0; i < VEC_M; i++) { - result.polys[i] = poly_zero(); - for (uint j = 0; j < VEC_N; j++) { - Poly product = poly_mul_ntt(A->polys[i * VEC_N + j], v.polys[j]); - result.polys[i] = poly_add(result.polys[i], product); - } - } - - return result; -} - -// 
============================================================================ -// Gaussian Sampling (Rejection Sampling with CDT) -// ============================================================================ - -// Simple box-muller approximation for discrete Gaussian -// Production code would use CDT or Knuth-Yao -inline int sample_gaussian(thread uint* rng_state, float sigma) { - // LCG for random bits - *rng_state = (*rng_state) * 1103515245u + 12345u; - uint u1 = *rng_state; - *rng_state = (*rng_state) * 1103515245u + 12345u; - uint u2 = *rng_state; - - // Convert to uniform [0,1) - float f1 = (float)(u1 >> 8) / 16777216.0f; - float f2 = (float)(u2 >> 8) / 16777216.0f; - - // Box-Muller transform - float r = sigma * sqrt(-2.0f * log(f1 + 0.000001f)); - float theta = 2.0f * 3.14159265f * f2; - - return (int)round(r * cos(theta)); -} - -// Sample a polynomial with coefficients from discrete Gaussian -inline Poly sample_poly_gaussian(thread uint* rng_state, float sigma) { - Poly p; - for (uint i = 0; i < RING_N; i++) { - int sample = sample_gaussian(rng_state, sigma); - // Reduce to [0, Q) - p.coeffs[i] = (sample >= 0) ? (Coeff)sample : (Coeff)(sample + (int)RING_Q); - } - return p; -} - -// ============================================================================ -// Rejection Sampling for Signature -// ============================================================================ - -// Check if z vector is within rejection bound -inline bool check_rejection_bound(PolyVecN z, int bound) { - for (uint i = 0; i < VEC_N; i++) { - for (uint j = 0; j < RING_N; j++) { - int coeff = center_reduce(z.polys[i].coeffs[j]); - if (coeff > bound || coeff < -bound) { - return false; - } - } - } - return true; -} - -// Compute infinity norm of polynomial -inline int poly_norm_inf(Poly p) { - int max_val = 0; - for (uint i = 0; i < RING_N; i++) { - int coeff = center_reduce(p.coeffs[i]); - int abs_coeff = (coeff >= 0) ? 
coeff : -coeff; - if (abs_coeff > max_val) max_val = abs_coeff; - } - return max_val; -} - -// ============================================================================ -// Ringtail Signing Kernels -// ============================================================================ - -// Kernel 1: Generate signature shares (per participant) -kernel void ringtail_generate_share( - device const Poly* secret_shares [[buffer(0)]], // s_i (participant's share) - device const PolyMatrix* public_key_A [[buffer(1)]], // Public matrix A - device const PolyVecM* commitment_y [[buffer(2)]], // Commitment y = A*r - device const Poly* challenge [[buffer(3)]], // Challenge c - device const PolyVecN* randomness [[buffer(4)]], // Randomness r_i - device const Coeff* ntt_twiddles [[buffer(5)]], // NTT twiddles - device SignatureShare* shares [[buffer(6)]], // Output shares - constant RingtailParams& params [[buffer(7)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - // Load data - Poly s_i = secret_shares[gid]; - Poly c = challenge[0]; - PolyVecN r_i = randomness[gid]; - - // Convert to NTT domain - thread Poly c_ntt = c; - ntt_forward_inplace(&c_ntt, ntt_twiddles); - - thread Poly s_i_ntt = s_i; - ntt_forward_inplace(&s_i_ntt, ntt_twiddles); - - // Compute z_i = r_i + c * s_i for each component - PolyVecN z_i; - for (uint j = 0; j < VEC_N; j++) { - thread Poly r_ij_ntt = r_i.polys[j]; - ntt_forward_inplace(&r_ij_ntt, ntt_twiddles); - - // c * s_i in NTT domain - Poly cs_ntt = poly_mul_ntt(c_ntt, s_i_ntt); - - // z_ij = r_ij + c * s_i (NTT domain) - Poly z_ij_ntt = poly_add(r_ij_ntt, cs_ntt); - - // Inverse NTT - thread Poly z_ij = z_ij_ntt; - // ntt_inverse_inplace would be called here - - z_i.polys[j] = z_ij; - } - - // Compute Delta for rounding (A * z_i - commitment_y) - PolyVecN z_i_ntt; - for (uint j = 0; j < VEC_N; j++) { - z_i_ntt.polys[j] = z_i.polys[j]; - ntt_forward_inplace(&z_i_ntt.polys[j], ntt_twiddles); - } - - PolyVecM 
Az = matrix_vec_mul_ntt(public_key_A, z_i_ntt); - PolyVecM Delta; - for (uint j = 0; j < VEC_M; j++) { - Delta.polys[j] = poly_sub(Az.polys[j], commitment_y->polys[j]); - } - - // Store signature share - SignatureShare share; - share.participant_id = gid + 1; // 1-indexed - share.c = c; - share.z = z_i; - share.Delta = Delta; - - shares[gid] = share; -} - -// Kernel 2: Aggregate signature shares -kernel void ringtail_aggregate_shares( - device const SignatureShare* shares [[buffer(0)]], - device const uint* participant_ids [[buffer(1)]], - device const Coeff* lagrange_coeffs [[buffer(2)]], // Precomputed - device Poly* aggregated_c [[buffer(3)]], - device PolyVecN* aggregated_z [[buffer(4)]], - device PolyVecM* aggregated_delta [[buffer(5)]], - constant RingtailParams& params [[buffer(6)]], - uint gid [[thread_position_in_grid]] -) { - if (gid != 0) return; // Single-threaded aggregation - - // Initialize accumulators - Poly c = shares[0].c; // Challenge is the same for all - PolyVecN z; - PolyVecM Delta; - - for (uint i = 0; i < VEC_N; i++) z.polys[i] = poly_zero(); - for (uint i = 0; i < VEC_M; i++) Delta.polys[i] = poly_zero(); - - // Aggregate: z = sum(lambda_i * z_i), Delta = sum(lambda_i * Delta_i) - for (uint p = 0; p < params.num_participants; p++) { - Coeff lambda = lagrange_coeffs[p]; - SignatureShare share = shares[p]; - - // z += lambda * z_i - for (uint i = 0; i < VEC_N; i++) { - Poly scaled = poly_scalar_mul(share.z.polys[i], lambda); - z.polys[i] = poly_add(z.polys[i], scaled); - } - - // Delta += lambda * Delta_i - for (uint i = 0; i < VEC_M; i++) { - Poly scaled = poly_scalar_mul(share.Delta.polys[i], lambda); - Delta.polys[i] = poly_add(Delta.polys[i], scaled); - } - } - - aggregated_c[0] = c; - *aggregated_z = z; - *aggregated_delta = Delta; -} - -// Kernel 3: Batch NTT forward (one polynomial per thread) -kernel void ringtail_batch_ntt_forward( - device Poly* polys [[buffer(0)]], - device const Coeff* twiddles [[buffer(1)]], - constant uint& 
count [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= count) return; - - thread Poly p = polys[gid]; - ntt_forward_inplace(&p, twiddles); - polys[gid] = p; -} - -// Kernel 4: Batch NTT inverse (one polynomial per thread) -kernel void ringtail_batch_ntt_inverse( - device Poly* polys [[buffer(0)]], - device const Coeff* inv_twiddles [[buffer(1)]], - constant uint& count [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= count) return; - - thread Poly p = polys[gid]; - ntt_inverse_inplace(&p, inv_twiddles); - polys[gid] = p; -} - -// Kernel 5: Gaussian sampling for masking randomness -kernel void ringtail_sample_gaussian_vec( - device PolyVecN* output [[buffer(0)]], - device const uint* seeds [[buffer(1)]], - constant RingtailParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.num_participants) return; - - thread uint rng = seeds[gid]; - PolyVecN result; - - for (uint i = 0; i < VEC_N; i++) { - result.polys[i] = sample_poly_gaussian(&rng, SIGMA); - } - - output[gid] = result; -} - -// Kernel 6: Rejection sampling check (parallel per signature) -kernel void ringtail_check_rejection( - device const PolyVecN* z_vectors [[buffer(0)]], - device uint* valid [[buffer(1)]], - constant RingtailParams& params [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - PolyVecN z = z_vectors[gid]; - bool passes = check_rejection_bound(z, REJECTION_BOUND); - valid[gid] = passes ? 
1 : 0; -} - -// Kernel 7: Polynomial norm computation (for security checks) -kernel void ringtail_compute_norms( - device const Poly* polys [[buffer(0)]], - device int* norms [[buffer(1)]], - constant uint& count [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= count) return; - - Poly p = polys[gid]; - norms[gid] = poly_norm_inf(p); -} - -// Kernel 8: Matrix-vector product (parallel per output row) -kernel void ringtail_matrix_vec_mul( - device const PolyMatrix* A [[buffer(0)]], - device const PolyVecN* v [[buffer(1)]], - device PolyVecM* result [[buffer(2)]], - constant uint& batch_idx [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= VEC_M) return; - - PolyVecN vec = v[batch_idx]; - Poly sum = poly_zero(); - - for (uint j = 0; j < VEC_N; j++) { - Poly product = poly_mul_ntt(A->polys[gid * VEC_N + j], vec.polys[j]); - sum = poly_add(sum, product); - } - - result[batch_idx].polys[gid] = sum; -} - -// Kernel 9: Combine partial signatures with Lagrange interpolation -kernel void ringtail_lagrange_combine( - device const SignatureShare* shares [[buffer(0)]], - device const Coeff* lagrange_at_zero [[buffer(1)]], // lambda_i(0) - device Poly* combined_secret [[buffer(2)]], - constant RingtailParams& params [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= RING_N) return; // One thread per coefficient - - Coeff sum = 0; - - for (uint p = 0; p < params.num_participants; p++) { - Coeff lambda = lagrange_at_zero[p]; - // Get p-th participant's first z polynomial coefficient at position gid - Coeff z_coeff = shares[p].z.polys[0].coeffs[gid]; - sum = mod_add(sum, mod_mul(lambda, z_coeff)); - } - - combined_secret->coeffs[gid] = sum; -} diff --git a/ringtail/gpu/metal/ringtail_verify.metal b/ringtail/gpu/metal/ringtail_verify.metal deleted file mode 100644 index 5403b25..0000000 --- a/ringtail/gpu/metal/ringtail_verify.metal +++ /dev/null @@ -1,569 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Ringtail Lattice-Based Threshold Signature Verification -// Batch verification of Ringtail signatures with polynomial norm checks -// Optimized for Apple Silicon GPUs - -#include -using namespace metal; - -// ============================================================================ -// Ringtail Parameters (shared with ringtail_sign.metal) -// ============================================================================ - -constant uint RING_N = 256; -constant ulong RING_Q = 8380417UL; -constant ulong RING_Q_INV = 58728449UL; -constant uint VEC_M = 8; -constant uint VEC_N = 7; -constant ulong N_INV = 8347649UL; -constant int CHALLENGE_WEIGHT = 60; // Hamming weight of challenge -// Note: beta_bound and delta_bound come from VerifyParams at runtime - -// ============================================================================ -// Data Types -// ============================================================================ - -typedef uint Coeff; - -struct Poly { - Coeff coeffs[RING_N]; -}; - -struct PolyVecM { - Poly polys[VEC_M]; -}; - -struct PolyVecN { - Poly polys[VEC_N]; -}; - -struct PolyMatrix { - Poly polys[VEC_M * VEC_N]; -}; - -// Ringtail public key -struct RingtailPublicKey { - PolyMatrix A; // M x N matrix - PolyVecM bTilde; // Rounded public key b~ = round(A*s) -}; - -// Ringtail signature -struct RingtailSignature { - Poly c; // Challenge polynomial (sparse, weight tau) - PolyVecN z; // Response vector - PolyVecM Delta; // Rounding correction -}; - -// Verification parameters -struct VerifyParams { - uint batch_size; - uint num_threads; - int beta_bound; - int delta_bound; -}; - -// ============================================================================ -// Modular Arithmetic -// ============================================================================ - -inline Coeff mont_reduce(ulong a) { - ulong t = (a * RING_Q_INV) & 0xFFFFFFFFUL; - ulong u = a + t * RING_Q; - Coeff result = (Coeff)(u 
>> 32); - return (result >= RING_Q) ? result - RING_Q : result; -} - -inline Coeff mod_add(Coeff a, Coeff b) { - Coeff sum = a + b; - return (sum >= RING_Q) ? sum - (Coeff)RING_Q : sum; -} - -inline Coeff mod_sub(Coeff a, Coeff b) { - return (a >= b) ? a - b : a + (Coeff)RING_Q - b; -} - -inline Coeff mod_mul(Coeff a, Coeff b) { - return mont_reduce((ulong)a * (ulong)b); -} - -inline Coeff mod_neg(Coeff a) { - return (a == 0) ? 0 : (Coeff)RING_Q - a; -} - -inline int center_reduce(Coeff a) { - int t = (int)a; - int half_q = (int)(RING_Q >> 1); - return (t > half_q) ? t - (int)RING_Q : t; -} - -// ============================================================================ -// Polynomial Operations -// ============================================================================ - -inline Poly poly_zero() { - Poly p; - for (uint i = 0; i < RING_N; i++) p.coeffs[i] = 0; - return p; -} - -inline Poly poly_add(Poly a, Poly b) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_add(a.coeffs[i], b.coeffs[i]); - } - return c; -} - -inline Poly poly_sub(Poly a, Poly b) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_sub(a.coeffs[i], b.coeffs[i]); - } - return c; -} - -inline Poly poly_mul_ntt(Poly a, Poly b) { - Poly c; - for (uint i = 0; i < RING_N; i++) { - c.coeffs[i] = mod_mul(a.coeffs[i], b.coeffs[i]); - } - return c; -} - -// ============================================================================ -// NTT Operations -// ============================================================================ - -inline void ntt_forward_inplace(thread Poly* p, device const Coeff* twiddles) { - for (uint len = 1; len < RING_N; len <<= 1) { - for (uint i = 0; i < RING_N; i += 2 * len) { - for (uint j = 0; j < len; j++) { - Coeff w = twiddles[len + j]; - Coeff u = p->coeffs[i + j]; - Coeff v = mod_mul(p->coeffs[i + j + len], w); - p->coeffs[i + j] = mod_add(u, v); - p->coeffs[i + j + len] = mod_sub(u, v); - } - } - } -} - -inline void 
ntt_inverse_inplace(thread Poly* p, device const Coeff* inv_twiddles) { - for (uint len = RING_N >> 1; len > 0; len >>= 1) { - for (uint i = 0; i < RING_N; i += 2 * len) { - for (uint j = 0; j < len; j++) { - Coeff w = inv_twiddles[len + j]; - Coeff u = p->coeffs[i + j]; - Coeff v = p->coeffs[i + j + len]; - p->coeffs[i + j] = mod_add(u, v); - p->coeffs[i + j + len] = mod_mul(mod_sub(u, v), w); - } - } - } - - Coeff n_inv = (Coeff)N_INV; - for (uint i = 0; i < RING_N; i++) { - p->coeffs[i] = mod_mul(p->coeffs[i], n_inv); - } -} - -// ============================================================================ -// Norm and Bound Checking -// ============================================================================ - -// Compute infinity norm -inline int poly_norm_inf(Poly p) { - int max_val = 0; - for (uint i = 0; i < RING_N; i++) { - int coeff = center_reduce(p.coeffs[i]); - int abs_coeff = (coeff >= 0) ? coeff : -coeff; - if (abs_coeff > max_val) max_val = abs_coeff; - } - return max_val; -} - -// Compute L2 norm squared (sum of squared coefficients) -inline ulong poly_norm_l2_squared(Poly p) { - ulong sum = 0; - for (uint i = 0; i < RING_N; i++) { - int coeff = center_reduce(p.coeffs[i]); - sum += (ulong)coeff * (ulong)coeff; - } - return sum; -} - -// Check z vector norm bound -inline bool check_z_norm(PolyVecN z, int bound) { - for (uint i = 0; i < VEC_N; i++) { - if (poly_norm_inf(z.polys[i]) > bound) { - return false; - } - } - return true; -} - -// Check Delta vector norm bound -inline bool check_delta_norm(PolyVecM Delta, int bound) { - for (uint i = 0; i < VEC_M; i++) { - if (poly_norm_inf(Delta.polys[i]) > bound) { - return false; - } - } - return true; -} - -// Check challenge polynomial structure (sparse with bounded coefficients) -inline bool check_challenge_format(Poly c, int weight) { - int nonzero = 0; - for (uint i = 0; i < RING_N; i++) { - if (c.coeffs[i] != 0) { - nonzero++; - // Coefficients should be +1 or -1 (Q-1) - if (c.coeffs[i] != 1 && 
c.coeffs[i] != (Coeff)RING_Q - 1) { - return false; - } - } - } - return nonzero == weight; -} - -// ============================================================================ -// Matrix-Vector Multiplication -// ============================================================================ - -inline PolyVecM matrix_vec_mul_ntt(device const PolyMatrix* A, PolyVecN v) { - PolyVecM result; - - for (uint i = 0; i < VEC_M; i++) { - result.polys[i] = poly_zero(); - for (uint j = 0; j < VEC_N; j++) { - Poly product = poly_mul_ntt(A->polys[i * VEC_N + j], v.polys[j]); - result.polys[i] = poly_add(result.polys[i], product); - } - } - - return result; -} - -// Thread address space version for local copies -inline PolyVecM matrix_vec_mul_ntt(thread const PolyMatrix* A, PolyVecN v) { - PolyVecM result; - - for (uint i = 0; i < VEC_M; i++) { - result.polys[i] = poly_zero(); - for (uint j = 0; j < VEC_N; j++) { - Poly product = poly_mul_ntt(A->polys[i * VEC_N + j], v.polys[j]); - result.polys[i] = poly_add(result.polys[i], product); - } - } - - return result; -} - -// ============================================================================ -// Rounding Functions -// ============================================================================ - -// Apply rounding: round coefficient to nearest multiple of 2^d -inline Coeff round_coeff(Coeff a, uint d) { - uint mask = (1u << d) - 1; - uint half_val = 1u << (d - 1); // 'half' is reserved keyword in Metal - uint rounded = (a + half_val) & ~mask; - return rounded % RING_Q; -} - -// Apply rounding to polynomial -inline Poly poly_round(Poly p, uint d) { - Poly r; - for (uint i = 0; i < RING_N; i++) { - r.coeffs[i] = round_coeff(p.coeffs[i], d); - } - return r; -} - -// ============================================================================ -// Hash Function (for challenge derivation) -// ============================================================================ - -// Simple hash mixing for challenge recomputation -inline 
uint hash_mix(thread uint* state, uint input) { - *state ^= input; - *state = (*state * 0x5bd1e995u) ^ (*state >> 15); - return *state; -} - -// Derive challenge polynomial from hash state -inline Poly derive_challenge(thread uint* hash_state, int weight) { - Poly c = poly_zero(); - - int placed = 0; - while (placed < weight) { - uint pos = hash_mix(hash_state, placed) % RING_N; - if (c.coeffs[pos] == 0) { - uint sign = hash_mix(hash_state, pos) & 1; - c.coeffs[pos] = sign ? 1 : (Coeff)RING_Q - 1; - placed++; - } - } - - return c; -} - -// ============================================================================ -// Verification Kernels -// ============================================================================ - -// Kernel 1: Batch verify Ringtail signatures -kernel void ringtail_batch_verify( - device const RingtailSignature* signatures [[buffer(0)]], - device const RingtailPublicKey* public_keys [[buffer(1)]], - device const Poly* messages [[buffer(2)]], // H(message) as polynomial - device const Coeff* ntt_twiddles [[buffer(3)]], - device const Coeff* inv_twiddles [[buffer(4)]], - device uint* results [[buffer(5)]], - constant VerifyParams& params [[buffer(6)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - RingtailSignature sig = signatures[gid]; - RingtailPublicKey pk = public_keys[gid]; - // msg used in step 9 challenge verification (currently simplified) - (void)messages[gid]; - - bool valid = true; - - // Step 1: Check z norm bound - if (!check_z_norm(sig.z, params.beta_bound)) { - results[gid] = 0; - return; - } - - // Step 2: Check Delta norm bound - if (!check_delta_norm(sig.Delta, params.delta_bound)) { - results[gid] = 0; - return; - } - - // Step 3: Check challenge format - if (!check_challenge_format(sig.c, CHALLENGE_WEIGHT)) { - results[gid] = 0; - return; - } - - // Step 4: Convert z to NTT domain - PolyVecN z_ntt; - for (uint i = 0; i < VEC_N; i++) { - thread Poly p = sig.z.polys[i]; - 
ntt_forward_inplace(&p, ntt_twiddles); - z_ntt.polys[i] = p; - } - - // Step 5: Compute A * z in NTT domain - PolyVecM Az = matrix_vec_mul_ntt(&pk.A, z_ntt); - - // Convert to coefficient domain - for (uint i = 0; i < VEC_M; i++) { - thread Poly p = Az.polys[i]; - ntt_inverse_inplace(&p, inv_twiddles); - Az.polys[i] = p; - } - - // Step 6: Convert c to NTT domain for multiplication - thread Poly c_ntt = sig.c; - ntt_forward_inplace(&c_ntt, ntt_twiddles); - - // Step 7: Compute c * bTilde for each component - PolyVecM c_btilde; - for (uint i = 0; i < VEC_M; i++) { - thread Poly btilde_ntt = pk.bTilde.polys[i]; - ntt_forward_inplace(&btilde_ntt, ntt_twiddles); - - Poly product = poly_mul_ntt(c_ntt, btilde_ntt); - ntt_inverse_inplace(&product, inv_twiddles); - c_btilde.polys[i] = product; - } - - // Step 8: Verify equation: round(A*z) = c*bTilde + Delta (approximately) - // Check that A*z - c*bTilde - Delta rounds to zero - for (uint i = 0; i < VEC_M; i++) { - Poly diff = poly_sub(Az.polys[i], c_btilde.polys[i]); - diff = poly_sub(diff, sig.Delta.polys[i]); - - // Check that diff has small coefficients after rounding - for (uint j = 0; j < RING_N; j++) { - int coeff = center_reduce(diff.coeffs[j]); - if (coeff > params.delta_bound || coeff < -params.delta_bound) { - valid = false; - break; - } - } - if (!valid) break; - } - - // Step 9: Verify challenge is correctly derived from commitment - // This would involve recomputing H(A*z - c*bTilde || message) - // Simplified: assume challenge is valid if format check passed - - results[gid] = valid ? 
1 : 0; -} - -// Kernel 2: Compute polynomial norms in parallel -kernel void ringtail_compute_poly_norms( - device const Poly* polys [[buffer(0)]], - device int* inf_norms [[buffer(1)]], - device ulong* l2_norms [[buffer(2)]], - constant uint& count [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= count) return; - - Poly p = polys[gid]; - inf_norms[gid] = poly_norm_inf(p); - l2_norms[gid] = poly_norm_l2_squared(p); -} - -// Kernel 3: Batch check rejection bounds -kernel void ringtail_check_bounds( - device const PolyVecN* z_vectors [[buffer(0)]], - device const PolyVecM* delta_vectors [[buffer(1)]], - device uint* z_valid [[buffer(2)]], - device uint* delta_valid [[buffer(3)]], - constant VerifyParams& params [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - z_valid[gid] = check_z_norm(z_vectors[gid], params.beta_bound) ? 1 : 0; - delta_valid[gid] = check_delta_norm(delta_vectors[gid], params.delta_bound) ? 1 : 0; -} - -// Kernel 4: Parallel matrix-vector multiplication (one output polynomial per thread) -kernel void ringtail_parallel_mat_vec_mul( - device const PolyMatrix* A [[buffer(0)]], - device const PolyVecN* z [[buffer(1)]], - device PolyVecM* Az [[buffer(2)]], - constant uint& batch_idx [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= VEC_M) return; - - PolyVecN z_batch = z[batch_idx]; - Poly sum = poly_zero(); - - for (uint j = 0; j < VEC_N; j++) { - Poly product = poly_mul_ntt(A->polys[gid * VEC_N + j], z_batch.polys[j]); - sum = poly_add(sum, product); - } - - Az[batch_idx].polys[gid] = sum; -} - -// Kernel 5: Challenge verification -kernel void ringtail_verify_challenge( - device const Poly* challenges [[buffer(0)]], - device const Poly* recomputed_challenges [[buffer(1)]], - device uint* valid [[buffer(2)]], - constant uint& count [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= count) return; - - Poly c1 = challenges[gid]; - Poly c2 
= recomputed_challenges[gid]; - - bool equal = true; - for (uint i = 0; i < RING_N; i++) { - if (c1.coeffs[i] != c2.coeffs[i]) { - equal = false; - break; - } - } - - valid[gid] = equal ? 1 : 0; -} - -// Kernel 6: Reconstruct public key from shares (for threshold verification) -kernel void ringtail_reconstruct_public_key( - device const PolyVecM* pk_shares [[buffer(0)]], - device const Coeff* lagrange_coeffs [[buffer(1)]], - device PolyVecM* reconstructed [[buffer(2)]], - constant uint& num_shares [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= VEC_M * RING_N) return; - - uint poly_idx = gid / RING_N; - uint coeff_idx = gid % RING_N; - - Coeff sum = 0; - for (uint s = 0; s < num_shares; s++) { - Coeff lambda = lagrange_coeffs[s]; - Coeff coeff = pk_shares[s].polys[poly_idx].coeffs[coeff_idx]; - sum = mod_add(sum, mod_mul(lambda, coeff)); - } - - reconstructed->polys[poly_idx].coeffs[coeff_idx] = sum; -} - -// Kernel 7: Batch apply rounding -kernel void ringtail_batch_round( - device Poly* polys [[buffer(0)]], - constant uint& count [[buffer(1)]], - constant uint& round_bits [[buffer(2)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= count) return; - - polys[gid] = poly_round(polys[gid], round_bits); -} - -// Kernel 8: Combined verification equation check -// Verifies: Az = c*bTilde + Delta + w (where w is rounding error) -kernel void ringtail_verify_equation( - device const PolyVecM* Az [[buffer(0)]], - device const PolyVecM* c_btilde [[buffer(1)]], - device const PolyVecM* Delta [[buffer(2)]], - device uint* valid [[buffer(3)]], - constant VerifyParams& params [[buffer(4)]], - uint gid [[thread_position_in_grid]] -) { - if (gid >= params.batch_size) return; - - bool result = true; - - for (uint i = 0; i < VEC_M && result; i++) { - Poly lhs = Az[gid].polys[i]; - Poly rhs = poly_add(c_btilde[gid].polys[i], Delta[gid].polys[i]); - Poly diff = poly_sub(lhs, rhs); - - // Check residual is small (within rounding error) - int norm = 
poly_norm_inf(diff); - if (norm > params.delta_bound) { - result = false; - } - } - - valid[gid] = result ? 1 : 0; -} - -// Kernel 9: Parallel coefficient validation -kernel void ringtail_validate_coefficients( - device const Poly* polys [[buffer(0)]], - device atomic_uint* invalid_count [[buffer(1)]], - constant uint& count [[buffer(2)]], - constant int& bound [[buffer(3)]], - uint gid [[thread_position_in_grid]] -) { - uint poly_idx = gid / RING_N; - uint coeff_idx = gid % RING_N; - - if (poly_idx >= count) return; - - int coeff = center_reduce(polys[poly_idx].coeffs[coeff_idx]); - if (coeff > bound || coeff < -bound) { - atomic_fetch_add_explicit(invalid_count, 1, memory_order_relaxed); - } -} diff --git a/ringtail/gpu/wgsl/ringtail.wgsl b/ringtail/gpu/wgsl/ringtail.wgsl deleted file mode 100644 index 74c230a..0000000 --- a/ringtail/gpu/wgsl/ringtail.wgsl +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Ringtail lattice-based threshold signatures in WGSL. -// Polynomial ring Z_q[x]/(x^n + 1), q=8380417, n=256. -// NTT-based polynomial multiplication. 
- -@group(0) @binding(0) var shares: array; // [num_ops * 256] -@group(0) @binding(1) var messages: array; // [num_ops * 8] (32 bytes each) -@group(0) @binding(2) var partial_sigs: array; // [num_ops * 256] -@group(0) @binding(3) var params: vec4; // params.x = num_ops - -const Q: i32 = 8380417; - -const ZETAS = array( - 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, - 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, - 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, -2118186, - -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, -2353451, - -1012179, -1277625, 1526252, -1402780, -2091905, 3119733, 3585928, -549488, - 2619752, -2108549, 2804197, -3199876, -38575, -2704181, 1757237, -19422, - 280005, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, -2353451, - 2353451, 3585928, -549488, 2619752, -2108549, 2804197, -3199876, -38575, - -2704181, 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, - -1399561, -3277672, 237124, -777960, -876248, 466468, 1826347, -2608894, - -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, 1024112, -1079900, - 3585928, -549488, -1119584, 2619752, -2108549, -2118186, -3859737, -1399561, - -3277672, 1757237, -19422, 4010497, 280005, -2353451, -1012179, -1277625, - 1526252, -1402780, 2706023, 1391570, 2287915, -3583748, -1399561, -3277672, - 1757237, -19422, 280005, 2706023, 1391570, 2287915, -3583748, -1399561 -); - -fn mod_mul(a: i32, b: i32) -> i32 { - let a_lo = u32(a) & 0xFFFFu; - let a_hi = u32(a) >> 16u; - let b_lo = u32(b) & 0xFFFFu; - let b_hi = u32(b) >> 16u; - let ll = a_lo * b_lo; - let mid = a_lo * b_hi + a_hi * b_lo; - let hh = a_hi * b_hi; - let result_lo = ll + (mid << 16u); - let result_hi = hh + (mid >> 16u) + select(0u, 1u, result_lo < ll); - let q = u32(Q); - var r = result_lo - (result_hi * q); - if (r >= q) { r = r - q; } - if (r >= q) { r = r - q; 
} - return i32(r); -} - -fn ntt256(poly: ptr>) { - var k = 0u; - var len = 128u; - loop { - if (len == 0u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - k = k + 1u; - let zeta = ZETAS[k]; - var j = start; - loop { - if (j >= start + len) { break; } - let t = mod_mul(zeta, (*poly)[j + len]); - (*poly)[j + len] = (*poly)[j] - t; - (*poly)[j] = (*poly)[j] + t; - if ((*poly)[j] >= Q) { (*poly)[j] = (*poly)[j] - Q; } - if ((*poly)[j + len] < 0) { (*poly)[j + len] = (*poly)[j + len] + Q; } - j = j + 1u; - } - start = start + 2u * len; - } - len = len >> 1u; - } -} - -fn inv_ntt256(poly: ptr>) { - let f: i32 = 41978; - var k = 127u; - var len = 1u; - loop { - if (len > 128u) { break; } - var start = 0u; - loop { - if (start >= 256u) { break; } - var zeta = -ZETAS[k]; - k = k - 1u; - if (zeta < 0) { zeta = zeta + Q; } - var j = start; - loop { - if (j >= start + len) { break; } - let t = (*poly)[j]; - (*poly)[j] = t + (*poly)[j + len]; - (*poly)[j + len] = t - (*poly)[j + len]; - if ((*poly)[j] >= Q) { (*poly)[j] = (*poly)[j] - Q; } - if ((*poly)[j + len] < 0) { (*poly)[j + len] = (*poly)[j + len] + Q; } - (*poly)[j + len] = mod_mul(zeta, (*poly)[j + len]); - j = j + 1u; - } - start = start + 2u * len; - } - len = len << 1u; - } - for (var i = 0u; i < 256u; i = i + 1u) { - (*poly)[i] = mod_mul(f, (*poly)[i]); - } -} - -@compute @workgroup_size(64) -fn ringtail_partial_sign_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - let base = tid * 256u; - let msg_base = tid * 8u; - - // Load share - var share: array; - for (var i = 0u; i < 256u; i = i + 1u) { - share[i] = shares[base + i]; - } - - // Derive challenge from message hash - var challenge: array; - for (var i = 0u; i < 256u; i = i + 1u) { - let idx = (i * 4u) % 8u; - var val = messages[msg_base + idx]; - val = val ^ (i * 2654435761u); - challenge[i] = i32(val % u32(Q)); - } - - ntt256(&challenge); - ntt256(&share); - - // Pointwise 
multiply - var result: array; - for (var i = 0u; i < 256u; i = i + 1u) { - result[i] = mod_mul(share[i], challenge[i]); - } - - inv_ntt256(&result); - - // Write result - for (var i = 0u; i < 256u; i = i + 1u) { - partial_sigs[base + i] = result[i]; - } -} diff --git a/ripemd160/gpu/cuda/ripemd160.cu b/ripemd160/gpu/cuda/ripemd160.cu deleted file mode 100644 index e572aef..0000000 --- a/ripemd160/gpu/cuda/ripemd160.cu +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// RIPEMD-160 batch hashing — CUDA implementation. -// Byte-equal to ripemd160/cpp/ripemd160.cpp::ripemd160() and to -// ripemd160/gpu/metal/ripemd160_batch.metal::ripemd160_jobs. -// -// Algorithm: Dobbertin, Bosselaers, Preneel — "RIPEMD-160: A Strengthened -// Version of RIPEMD" (1996). Two parallel "lines" (left + right) of 5 rounds -// each, 16 step operations per round, applied to a 16-word LE block. Final -// state combination interleaves left-line z[0] and right-line z[1] across -// the IV plus a 4-word rotation. Padding is MD4-style: append 0x80, zero-pad -// to 56 mod 64, append 64-bit little-endian bit length. -// -// One thread per input. Layout matches the Metal/SHA-256 drivers: caller -// fills a flat byte arena and per-input (offset, length) descriptors; -// outputs are 20-byte stride. - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// Round added constants (Dobbertin §3, Table 1). K[i] for left line, K'[i] -// for right line. Both lines have 5 rounds; only one is non-zero per round -// per line. 
-__device__ static const uint32_t K0[5] = { - 0x00000000u, 0x5a827999u, 0x6ed9eba1u, 0x8f1bbcdcu, 0xa953fd4eu -}; -__device__ static const uint32_t K1[5] = { - 0x50a28be6u, 0x5c4dd124u, 0x6d703ef3u, 0x7a6d76e9u, 0x00000000u -}; - -// Message word selection r[i] (left) and r'[i] (right). 80 indices each. -__device__ static const uint8_t R0[80] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 7, 4, 13, 1, 10, 6, 15, 3, 12, 0, 9, 5, 2, 14, 11, 8, - 3, 10, 14, 4, 9, 15, 8, 1, 2, 7, 0, 6, 13, 11, 5, 12, - 1, 9, 11, 10, 0, 8, 12, 4, 13, 3, 7, 15, 14, 5, 6, 2, - 4, 0, 5, 9, 7, 12, 2, 10, 14, 1, 3, 8, 11, 6, 15, 13, -}; -__device__ static const uint8_t R1[80] = { - 5, 14, 7, 0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12, - 6, 11, 3, 7, 0, 13, 5, 10, 14, 15, 8, 12, 4, 9, 1, 2, - 15, 5, 1, 3, 7, 14, 6, 9, 11, 8, 12, 2, 10, 0, 4, 13, - 8, 6, 4, 1, 3, 11, 15, 0, 5, 12, 2, 13, 9, 7, 10, 14, - 12, 15, 10, 4, 1, 5, 8, 7, 6, 2, 13, 14, 0, 3, 9, 11, -}; - -// Rotation amounts s[i] / s'[i]. -__device__ static const uint8_t S0[80] = { - 11, 14, 15, 12, 5, 8, 7, 9, 11, 13, 14, 15, 6, 7, 9, 8, - 7, 6, 8, 13, 11, 9, 7, 15, 7, 12, 15, 9, 11, 7, 13, 12, - 11, 13, 6, 7, 14, 9, 13, 15, 14, 8, 13, 6, 5, 12, 7, 5, - 11, 12, 14, 15, 14, 15, 9, 8, 9, 14, 5, 6, 8, 6, 5, 12, - 9, 15, 5, 11, 6, 8, 13, 12, 5, 12, 13, 14, 11, 8, 5, 6, -}; -__device__ static const uint8_t S1[80] = { - 8, 9, 9, 11, 13, 15, 15, 5, 7, 7, 8, 11, 14, 14, 12, 6, - 9, 13, 15, 7, 12, 8, 9, 11, 7, 7, 12, 7, 6, 15, 13, 11, - 9, 7, 15, 11, 8, 6, 6, 14, 12, 13, 5, 14, 13, 13, 7, 5, - 15, 5, 8, 11, 14, 14, 6, 14, 6, 9, 12, 9, 12, 5, 15, 8, - 8, 5, 12, 9, 12, 5, 14, 6, 8, 13, 6, 5, 15, 13, 11, 11, -}; - -__device__ static inline uint32_t rotl32(uint32_t x, uint32_t n) { - return (x << n) | (x >> (32u - n)); -} - -// Boolean selection functions f_j (Dobbertin §3, eq. 1-5). 
-__device__ static inline uint32_t round_f(uint32_t round_idx, - uint32_t x, uint32_t y, uint32_t z) { - if (round_idx == 0) return x ^ y ^ z; // f1 - if (round_idx == 1) return ((y ^ z) & x) ^ z; // f2 = (x&y) | (~x&z) - if (round_idx == 2) return (x | ~y) ^ z; // f3 - if (round_idx == 3) return ((x ^ y) & z) ^ y; // f4 = (x&z) | (y&~z) - return x ^ (y | ~z); // f5 -} - -__device__ static void compress(uint32_t* h, const uint32_t* w) { - // Two parallel lines z[0] (uses f1..f5 in order, R0/S0/K0) and z[1] - // (uses f5..f1 mirrored, R1/S1/K1). - uint32_t a0 = h[0], b0 = h[1], c0 = h[2], d0 = h[3], e0 = h[4]; - uint32_t a1 = h[0], b1 = h[1], c1 = h[2], d1 = h[3], e1 = h[4]; - - for (uint32_t j = 0; j < 80; ++j) { - uint32_t round_idx = j / 16u; - - // Left line. - uint32_t t0 = rotl32(a0 + round_f(round_idx, b0, c0, d0) - + w[R0[j]] + K0[round_idx], S0[j]) + e0; - a0 = e0; e0 = d0; d0 = rotl32(c0, 10); c0 = b0; b0 = t0; - - // Right line uses mirrored function index 4 - round_idx. - uint32_t inv_round = 4u - round_idx; - uint32_t t1 = rotl32(a1 + round_f(inv_round, b1, c1, d1) - + w[R1[j]] + K1[round_idx], S1[j]) + e1; - a1 = e1; e1 = d1; d1 = rotl32(c1, 10); c1 = b1; b1 = t1; - } - - // Final mixdown: t = h[1] + c0 + d1, h[1] = h[2] + d0 + e1, ... 
- uint32_t t = h[1] + c0 + d1; - h[1] = h[2] + d0 + e1; - h[2] = h[3] + e0 + a1; - h[3] = h[4] + a0 + b1; - h[4] = h[0] + b0 + c1; - h[0] = t; -} - -extern "C" __global__ void ripemd160_jobs( - const uint8_t* __restrict__ inputs, - const uint32_t* __restrict__ input_offsets, - const uint32_t* __restrict__ input_lens, - uint8_t* __restrict__ outputs, - uint32_t num_jobs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_jobs) return; - - const uint8_t* in = inputs + input_offsets[tid]; - uint32_t len = input_lens[tid]; - - uint32_t h[5] = { - 0x67452301u, 0xefcdab89u, 0x98badcfeu, 0x10325476u, 0xc3d2e1f0u - }; - - uint8_t block[64]; - uint32_t w[16]; - uint32_t pos = 0; - while ((len - pos) >= 64u) { - for (uint32_t i = 0; i < 64; ++i) block[i] = in[pos + i]; - for (uint32_t i = 0; i < 16; ++i) { - w[i] = ((uint32_t)block[i * 4 + 0] ) | - ((uint32_t)block[i * 4 + 1] << 8) | - ((uint32_t)block[i * 4 + 2] << 16) | - ((uint32_t)block[i * 4 + 3] << 24); - } - compress(h, w); - pos += 64; - } - - // Final block(s) with MD4-style padding: append 0x80, zero-pad, 64-bit LE - // bit length at offset 56..63 of the final 64-byte block. - uint32_t rem = len - pos; - for (uint32_t i = 0; i < 64; ++i) block[i] = 0; - for (uint32_t i = 0; i < rem; ++i) block[i] = in[pos + i]; - block[rem] = 0x80u; - - if (rem >= 56u) { - // Tail spans two final blocks. 
- for (uint32_t i = 0; i < 16; ++i) { - w[i] = ((uint32_t)block[i * 4 + 0] ) | - ((uint32_t)block[i * 4 + 1] << 8) | - ((uint32_t)block[i * 4 + 2] << 16) | - ((uint32_t)block[i * 4 + 3] << 24); - } - compress(h, w); - for (uint32_t i = 0; i < 64; ++i) block[i] = 0; - } - - uint64_t bit_len = (uint64_t)len * 8u; - block[56] = (uint8_t)((bit_len ) & 0xFFu); - block[57] = (uint8_t)((bit_len >> 8) & 0xFFu); - block[58] = (uint8_t)((bit_len >> 16) & 0xFFu); - block[59] = (uint8_t)((bit_len >> 24) & 0xFFu); - block[60] = (uint8_t)((bit_len >> 32) & 0xFFu); - block[61] = (uint8_t)((bit_len >> 40) & 0xFFu); - block[62] = (uint8_t)((bit_len >> 48) & 0xFFu); - block[63] = (uint8_t)((bit_len >> 56) & 0xFFu); - - for (uint32_t i = 0; i < 16; ++i) { - w[i] = ((uint32_t)block[i * 4 + 0] ) | - ((uint32_t)block[i * 4 + 1] << 8) | - ((uint32_t)block[i * 4 + 2] << 16) | - ((uint32_t)block[i * 4 + 3] << 24); - } - compress(h, w); - - uint8_t* out = outputs + tid * 20; - for (uint32_t i = 0; i < 5; ++i) { - out[i * 4 + 0] = (uint8_t)( h[i] & 0xFFu); - out[i * 4 + 1] = (uint8_t)((h[i] >> 8) & 0xFFu); - out[i * 4 + 2] = (uint8_t)((h[i] >> 16) & 0xFFu); - out[i * 4 + 3] = (uint8_t)((h[i] >> 24) & 0xFFu); - } -} diff --git a/ripemd160/gpu/cuda/ripemd160_driver.cpp b/ripemd160/gpu/cuda/ripemd160_driver.cpp deleted file mode 100644 index ad271fb..0000000 --- a/ripemd160/gpu/cuda/ripemd160_driver.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for batched RIPEMD-160 (Dobbertin et al. 1996). -// -// Build modes: -// 1. With CUDA toolkit (LUX_RIPEMD160_HAVE_CUDA defined): -// - Compiles ripemd160.cu via nvcc; invokes the kernel with one -// thread per input. Byte-equal to ripemd160/cpp/ripemd160.cpp -// and ripemd160/gpu/metal/ripemd160_batch.metal. -// 2. 
Without CUDA (LUX_RIPEMD160_HAVE_CUDA not defined): -// - Stub mode: lux_ripemd160_cuda_available() returns 0, every -// other function returns -1 ("CUDA unavailable on this host"). -// The test harness skips the CUDA path on Apple/non-CUDA hosts. - -#include "ripemd160_driver.h" - -#include -#include - -#ifdef LUX_RIPEMD160_HAVE_CUDA -#include - -// Forward declaration of the CUDA kernel defined in ripemd160.cu. -extern "C" __global__ void ripemd160_jobs( - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - uint8_t* outputs, - uint32_t num_jobs); - -extern "C" int lux_ripemd160_cuda_available(void) { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 1 : 0; -} - -extern "C" int ripemd160_batch_cuda( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena) { - - if (n == 0) return 0; - if (!input_offsets || !input_lens || !outputs_arena) return -1; - if (!lux_ripemd160_cuda_available()) return -2; - - uint8_t* d_inputs = nullptr; - uint32_t* d_offsets = nullptr; - uint32_t* d_lens = nullptr; - uint8_t* d_outputs = nullptr; - size_t out_bytes = n * 20u; - - auto cleanup = [&]() { - if (d_inputs) cudaFree(d_inputs); - if (d_offsets) cudaFree(d_offsets); - if (d_lens) cudaFree(d_lens); - if (d_outputs) cudaFree(d_outputs); - }; - - if (cudaMalloc((void**)&d_inputs, inputs_arena_len ? 
inputs_arena_len : 1) != cudaSuccess) { - cleanup(); return -3; - } - if (cudaMalloc((void**)&d_offsets, n * sizeof(uint32_t)) != cudaSuccess) { - cleanup(); return -3; - } - if (cudaMalloc((void**)&d_lens, n * sizeof(uint32_t)) != cudaSuccess) { - cleanup(); return -3; - } - if (cudaMalloc((void**)&d_outputs, out_bytes) != cudaSuccess) { - cleanup(); return -3; - } - - if (inputs_arena_len) { - if (cudaMemcpy(d_inputs, inputs_arena, inputs_arena_len, - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); return -4; - } - } - if (cudaMemcpy(d_offsets, input_offsets, n * sizeof(uint32_t), - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); return -4; - } - if (cudaMemcpy(d_lens, input_lens, n * sizeof(uint32_t), - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); return -4; - } - - unsigned tg = 64; - unsigned grid = unsigned((n + tg - 1) / tg); - ripemd160_jobs<<>>(d_inputs, d_offsets, d_lens, - d_outputs, uint32_t(n)); - if (cudaDeviceSynchronize() != cudaSuccess) { - cleanup(); return -4; - } - if (cudaMemcpy(outputs_arena, d_outputs, out_bytes, - cudaMemcpyDeviceToHost) != cudaSuccess) { - cleanup(); return -4; - } - cleanup(); - return 0; -} - -#else // LUX_RIPEMD160_HAVE_CUDA not defined: stub mode - -extern "C" int lux_ripemd160_cuda_available(void) { return 0; } - -extern "C" int ripemd160_batch_cuda( - const uint8_t*, size_t, - const uint32_t*, const uint32_t*, - size_t, uint8_t*) { - return -1; -} - -#endif // LUX_RIPEMD160_HAVE_CUDA diff --git a/ripemd160/gpu/cuda/ripemd160_driver.h b/ripemd160/gpu/cuda/ripemd160_driver.h deleted file mode 100644 index 15d6623..0000000 --- a/ripemd160/gpu/cuda/ripemd160_driver.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the RIPEMD-160 CUDA driver. Mirrors the Metal driver -// in ripemd160/gpu/metal/ripemd160_batch_driver.mm. 
On hosts without CUDA -// every function returns -1 except lux_ripemd160_cuda_available() which -// returns 0. - -#ifndef LUX_RIPEMD160_DRIVER_CUDA_H -#define LUX_RIPEMD160_DRIVER_CUDA_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Returns 1 if a CUDA device is available, 0 otherwise. -int lux_ripemd160_cuda_available(void); - -// Run N RIPEMD-160 hashes in one CUDA dispatch. Each input lives at -// inputs_arena[input_offsets[i] .. + input_lens[i]); each output goes to -// outputs_arena[i * 20 .. i * 20 + 20). Returns 0 on success, negative on -// failure (-1 = invalid args, -2 = device unavailable, -3 = device alloc -// failed, -4 = launch failed). -int ripemd160_batch_cuda( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_RIPEMD160_DRIVER_CUDA_H diff --git a/ripemd160/gpu/metal/ripemd160_batch.metal b/ripemd160/gpu/metal/ripemd160_batch.metal deleted file mode 100644 index d4cfa3b..0000000 --- a/ripemd160/gpu/metal/ripemd160_batch.metal +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched RIPEMD-160 (Dobbertin et al. 1996). One thread per input. -// Byte-equal to ripemd160/cpp/ripemd160.cpp::ripemd160(). -// -// 64-byte blocks (16 LE uint32 words), 5 rounds × 16 steps × 2 parallel -// lines (z0/z1), output 20 bytes little-endian. Padding is MD4-style: -// append 0x80, zero pad to 56 mod 64, append 64-bit little-endian bit-length. 
- -#include -using namespace metal; - -constant uint K0[5] = { 0x00000000u, 0x5a827999u, 0x6ed9eba1u, 0x8f1bbcdcu, 0xa953fd4eu }; -constant uint K1[5] = { 0x50a28be6u, 0x5c4dd124u, 0x6d703ef3u, 0x7a6d76e9u, 0x00000000u }; - -constant uint8_t R0[80] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 7, 4, 13, 1, 10, 6, 15, 3, 12, 0, 9, 5, 2, 14, 11, 8, - 3, 10, 14, 4, 9, 15, 8, 1, 2, 7, 0, 6, 13, 11, 5, 12, - 1, 9, 11, 10, 0, 8, 12, 4, 13, 3, 7, 15, 14, 5, 6, 2, - 4, 0, 5, 9, 7, 12, 2, 10, 14, 1, 3, 8, 11, 6, 15, 13, -}; -constant uint8_t R1[80] = { - 5, 14, 7, 0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12, - 6, 11, 3, 7, 0, 13, 5, 10, 14, 15, 8, 12, 4, 9, 1, 2, - 15, 5, 1, 3, 7, 14, 6, 9, 11, 8, 12, 2, 10, 0, 4, 13, - 8, 6, 4, 1, 3, 11, 15, 0, 5, 12, 2, 13, 9, 7, 10, 14, - 12, 15, 10, 4, 1, 5, 8, 7, 6, 2, 13, 14, 0, 3, 9, 11, -}; -constant uint8_t S0[80] = { - 11, 14, 15, 12, 5, 8, 7, 9, 11, 13, 14, 15, 6, 7, 9, 8, - 7, 6, 8, 13, 11, 9, 7, 15, 7, 12, 15, 9, 11, 7, 13, 12, - 11, 13, 6, 7, 14, 9, 13, 15, 14, 8, 13, 6, 5, 12, 7, 5, - 11, 12, 14, 15, 14, 15, 9, 8, 9, 14, 5, 6, 8, 6, 5, 12, - 9, 15, 5, 11, 6, 8, 13, 12, 5, 12, 13, 14, 11, 8, 5, 6, -}; -constant uint8_t S1[80] = { - 8, 9, 9, 11, 13, 15, 15, 5, 7, 7, 8, 11, 14, 14, 12, 6, - 9, 13, 15, 7, 12, 8, 9, 11, 7, 7, 12, 7, 6, 15, 13, 11, - 9, 7, 15, 11, 8, 6, 6, 14, 12, 13, 5, 14, 13, 13, 7, 5, - 15, 5, 8, 11, 14, 14, 6, 14, 6, 9, 12, 9, 12, 5, 15, 8, - 8, 5, 12, 9, 12, 5, 14, 6, 8, 13, 6, 5, 15, 13, 11, 11, -}; - -inline uint rotl32(uint x, uint n) { return (x << n) | (x >> (32u - n)); } - -inline uint f1(uint x, uint y, uint z) { return x ^ y ^ z; } -inline uint f2(uint x, uint y, uint z) { return ((y ^ z) & x) ^ z; } // (x&y) | (~x&z) -inline uint f3(uint x, uint y, uint z) { return (x | ~y) ^ z; } -inline uint f4(uint x, uint y, uint z) { return ((x ^ y) & z) ^ y; } // (x&z) | (y&~z) -inline uint f5(uint x, uint y, uint z) { return x ^ (y | ~z); } - -inline uint round_f(uint round_idx, uint x, uint y, 
uint z) { - if (round_idx == 0) return f1(x, y, z); - if (round_idx == 1) return f2(x, y, z); - if (round_idx == 2) return f3(x, y, z); - if (round_idx == 3) return f4(x, y, z); - return f5(x, y, z); -} - -inline void compress(thread uint* h, thread const uint* w) { - // Two parallel lines z0 (uses R0/S0/K0/f1..f5) and z1 (uses R1/S1/K1/f5..f1). - uint a0 = h[0], b0 = h[1], c0 = h[2], d0 = h[3], e0 = h[4]; - uint a1 = h[0], b1 = h[1], c1 = h[2], d1 = h[3], e1 = h[4]; - - for (uint j = 0; j < 80; ++j) { - uint round_idx = j / 16u; - - // Line 0 - uint t0 = rotl32(a0 + round_f(round_idx, b0, c0, d0) - + w[R0[j]] + K0[round_idx], S0[j]) + e0; - a0 = e0; e0 = d0; d0 = rotl32(c0, 10); c0 = b0; b0 = t0; - - // Line 1: function index is 4 - round_idx (mirror). - uint inv_round = 4u - round_idx; - uint t1 = rotl32(a1 + round_f(inv_round, b1, c1, d1) - + w[R1[j]] + K1[round_idx], S1[j]) + e1; - a1 = e1; e1 = d1; d1 = rotl32(c1, 10); c1 = b1; b1 = t1; - } - - uint t = h[1] + c0 + d1; - h[1] = h[2] + d0 + e1; - h[2] = h[3] + e0 + a1; - h[3] = h[4] + a0 + b1; - h[4] = h[0] + b0 + c1; - h[0] = t; -} - -struct Ripemd160JobGPU { - uint input_offset; - uint input_len; - uint output_offset; - uint _pad; -}; - -kernel void ripemd160_jobs( - device const Ripemd160JobGPU* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_jobs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_jobs) return; - - Ripemd160JobGPU j = jobs[tid]; - const device uchar* in = inputs + j.input_offset; - device uchar* out = outputs + j.output_offset; - - uint h[5] = { 0x67452301u, 0xefcdab89u, 0x98badcfeu, 0x10325476u, 0xc3d2e1f0u }; - - // Process full 64-byte blocks. 
- uchar block[64]; - uint w[16]; - uint pos = 0; - while ((j.input_len - pos) >= 64u) { - for (uint i = 0; i < 64; ++i) block[i] = in[pos + i]; - for (uint i = 0; i < 16; ++i) { - w[i] = (uint(block[i * 4 + 0]) ) | - (uint(block[i * 4 + 1]) << 8) | - (uint(block[i * 4 + 2]) << 16) | - (uint(block[i * 4 + 3]) << 24); - } - compress(h, w); - pos += 64; - } - - // Final block(s) with MD4-style padding. - uint rem = j.input_len - pos; - for (uint i = 0; i < 64; ++i) block[i] = 0; - for (uint i = 0; i < rem; ++i) block[i] = in[pos + i]; - block[rem] = 0x80u; - - if (rem >= 56u) { - // Tail spans two final blocks. - for (uint i = 0; i < 16; ++i) { - w[i] = (uint(block[i * 4 + 0]) ) | - (uint(block[i * 4 + 1]) << 8) | - (uint(block[i * 4 + 2]) << 16) | - (uint(block[i * 4 + 3]) << 24); - } - compress(h, w); - for (uint i = 0; i < 64; ++i) block[i] = 0; - } - - // Append 64-bit little-endian bit length at offset 56..63. - ulong bit_len = (ulong)j.input_len * 8u; - block[56] = uchar((bit_len ) & 0xFF); - block[57] = uchar((bit_len >> 8) & 0xFF); - block[58] = uchar((bit_len >> 16) & 0xFF); - block[59] = uchar((bit_len >> 24) & 0xFF); - block[60] = uchar((bit_len >> 32) & 0xFF); - block[61] = uchar((bit_len >> 40) & 0xFF); - block[62] = uchar((bit_len >> 48) & 0xFF); - block[63] = uchar((bit_len >> 56) & 0xFF); - - for (uint i = 0; i < 16; ++i) { - w[i] = (uint(block[i * 4 + 0]) ) | - (uint(block[i * 4 + 1]) << 8) | - (uint(block[i * 4 + 2]) << 16) | - (uint(block[i * 4 + 3]) << 24); - } - compress(h, w); - - // Output little-endian. 
- for (uint i = 0; i < 5; ++i) { - out[i * 4 + 0] = uchar( h[i] & 0xFF); - out[i * 4 + 1] = uchar((h[i] >> 8) & 0xFF); - out[i * 4 + 2] = uchar((h[i] >> 16) & 0xFF); - out[i * 4 + 3] = uchar((h[i] >> 24) & 0xFF); - } -} diff --git a/ripemd160/gpu/metal/ripemd160_batch_driver.mm b/ripemd160/gpu/metal/ripemd160_batch_driver.mm deleted file mode 100644 index bad1b50..0000000 --- a/ripemd160/gpu/metal/ripemd160_batch_driver.mm +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched RIPEMD-160. macOS / iOS only. -// Loads ripemd160_batch.metallib, dispatches `ripemd160_jobs` with one -// thread per input. Byte-equal to ripemd160/cpp/ripemd160.cpp::ripemd160(). - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include -#include - -namespace { - -struct Ripemd160JobGPU { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t _pad; -}; - -} // namespace - -extern "C" int ripemd160_batch_metal( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena, - const char* metallib_path) { - - if (n == 0) return 0; - if (!inputs_arena || !input_offsets || !input_lens || !outputs_arena || - !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"ripemd160_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - std::vector jobs(n); - for (size_t i = 0; i < n; ++i) { - jobs[i].input_offset = 
input_offsets[i]; - jobs[i].input_len = input_lens[i]; - jobs[i].output_offset = (uint32_t)(i * 20); - jobs[i]._pad = 0; - } - - id jobs_buf = [device newBufferWithBytes:jobs.data() - length:jobs.size() * sizeof(Ripemd160JobGPU) - options:MTLResourceStorageModeShared]; - id inputs_buf = [device newBufferWithBytes:inputs_arena - length:inputs_arena_len - options:MTLResourceStorageModeShared]; - id outputs_buf = [device newBufferWithLength:n * 20 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:inputs_buf offset:0 atIndex:1]; - [enc setBuffer:outputs_buf offset:0 atIndex:2]; - [enc setBuffer:n_buf offset:0 atIndex:3]; - - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs_arena, [outputs_buf contents], n * 20); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/ripemd160/gpu/wgsl/ripemd160.wgsl b/ripemd160/gpu/wgsl/ripemd160.wgsl deleted file mode 100644 index ea9bf3c..0000000 --- a/ripemd160/gpu/wgsl/ripemd160.wgsl +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// RIPEMD-160 (Dobbertin et al. 1996) compute shader in WGSL. -// -// One thread per input. Byte-equal to ripemd160/cpp/ripemd160.cpp::ripemd160(), -// ripemd160/gpu/cuda/ripemd160.cu::ripemd160_batch, and the Metal kernel. 
-// -// Two parallel "lines": z[0] uses f1..f5 with K0/R0/S0, z[1] uses f5..f1 -// (mirrored function index 4 - round) with K1/R1/S1. Final mixdown -// interleaves the two lines back into the 5-word state. -// -// Padding: MD4-style — append 0x80 byte, zero-pad to 56 mod 64, append -// 64-bit little-endian bit length. - -struct HashInput { - offset: u32, - length: u32, -} - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var data: array; -@group(0) @binding(2) var outputs: array; - -// Round added constants. K0 = left line, K1 = right line. -const K0 = array( - 0x00000000u, 0x5a827999u, 0x6ed9eba1u, 0x8f1bbcdcu, 0xa953fd4eu -); -const K1 = array( - 0x50a28be6u, 0x5c4dd124u, 0x6d703ef3u, 0x7a6d76e9u, 0x00000000u -); - -// Message word selection (left). -const R0 = array( - 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u, - 7u, 4u, 13u, 1u, 10u, 6u, 15u, 3u, 12u, 0u, 9u, 5u, 2u, 14u, 11u, 8u, - 3u, 10u, 14u, 4u, 9u, 15u, 8u, 1u, 2u, 7u, 0u, 6u, 13u, 11u, 5u, 12u, - 1u, 9u, 11u, 10u, 0u, 8u, 12u, 4u, 13u, 3u, 7u, 15u, 14u, 5u, 6u, 2u, - 4u, 0u, 5u, 9u, 7u, 12u, 2u, 10u, 14u, 1u, 3u, 8u, 11u, 6u, 15u, 13u -); - -// Message word selection (right). -const R1 = array( - 5u, 14u, 7u, 0u, 9u, 2u, 11u, 4u, 13u, 6u, 15u, 8u, 1u, 10u, 3u, 12u, - 6u, 11u, 3u, 7u, 0u, 13u, 5u, 10u, 14u, 15u, 8u, 12u, 4u, 9u, 1u, 2u, - 15u, 5u, 1u, 3u, 7u, 14u, 6u, 9u, 11u, 8u, 12u, 2u, 10u, 0u, 4u, 13u, - 8u, 6u, 4u, 1u, 3u, 11u, 15u, 0u, 5u, 12u, 2u, 13u, 9u, 7u, 10u, 14u, - 12u, 15u, 10u, 4u, 1u, 5u, 8u, 7u, 6u, 2u, 13u, 14u, 0u, 3u, 9u, 11u -); - -// Rotation amounts (left). 
-const S0 = array( - 11u, 14u, 15u, 12u, 5u, 8u, 7u, 9u, 11u, 13u, 14u, 15u, 6u, 7u, 9u, 8u, - 7u, 6u, 8u, 13u, 11u, 9u, 7u, 15u, 7u, 12u, 15u, 9u, 11u, 7u, 13u, 12u, - 11u, 13u, 6u, 7u, 14u, 9u, 13u, 15u, 14u, 8u, 13u, 6u, 5u, 12u, 7u, 5u, - 11u, 12u, 14u, 15u, 14u, 15u, 9u, 8u, 9u, 14u, 5u, 6u, 8u, 6u, 5u, 12u, - 9u, 15u, 5u, 11u, 6u, 8u, 13u, 12u, 5u, 12u, 13u, 14u, 11u, 8u, 5u, 6u -); - -// Rotation amounts (right). -const S1 = array( - 8u, 9u, 9u, 11u, 13u, 15u, 15u, 5u, 7u, 7u, 8u, 11u, 14u, 14u, 12u, 6u, - 9u, 13u, 15u, 7u, 12u, 8u, 9u, 11u, 7u, 7u, 12u, 7u, 6u, 15u, 13u, 11u, - 9u, 7u, 15u, 11u, 8u, 6u, 6u, 14u, 12u, 13u, 5u, 14u, 13u, 13u, 7u, 5u, - 15u, 5u, 8u, 11u, 14u, 14u, 6u, 14u, 6u, 9u, 12u, 9u, 12u, 5u, 15u, 8u, - 8u, 5u, 12u, 9u, 12u, 5u, 14u, 6u, 8u, 13u, 6u, 5u, 15u, 13u, 11u, 11u -); - -fn rotl32(x: u32, n: u32) -> u32 { - return (x << n) | (x >> (32u - n)); -} - -// Boolean selection functions (Dobbertin §3, eq. 1-5). -fn round_f(round_idx: u32, x: u32, y: u32, z: u32) -> u32 { - if (round_idx == 0u) { return x ^ y ^ z; } // f1 - if (round_idx == 1u) { return ((y ^ z) & x) ^ z; } // f2 - if (round_idx == 2u) { return (x | (~y)) ^ z; } // f3 - if (round_idx == 3u) { return ((x ^ y) & z) ^ y; } // f4 - return x ^ (y | (~z)); // f5 -} - -// Read a single byte from the packed u32 data array (little-endian). -fn read_byte(byte_offset: u32) -> u32 { - let word_idx = byte_offset >> 2u; - let byte_pos = byte_offset & 3u; - return (data[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -// Per-thread scratch. -var h: array; -var w: array; -var block: array; // u32 per byte for simplicity - -fn compress() { - var a0 = h[0]; var b0 = h[1]; var c0 = h[2]; var d0 = h[3]; var e0 = h[4]; - var a1 = h[0]; var b1 = h[1]; var c1 = h[2]; var d1 = h[3]; var e1 = h[4]; - - for (var j = 0u; j < 80u; j = j + 1u) { - let round_idx = j / 16u; - - // Left line. 
- let t0 = rotl32(a0 + round_f(round_idx, b0, c0, d0) - + w[R0[j]] + K0[round_idx], S0[j]) + e0; - a0 = e0; e0 = d0; d0 = rotl32(c0, 10u); c0 = b0; b0 = t0; - - // Right line. - let inv_round = 4u - round_idx; - let t1 = rotl32(a1 + round_f(inv_round, b1, c1, d1) - + w[R1[j]] + K1[round_idx], S1[j]) + e1; - a1 = e1; e1 = d1; d1 = rotl32(c1, 10u); c1 = b1; b1 = t1; - } - - let t = h[1] + c0 + d1; - h[1] = h[2] + d0 + e1; - h[2] = h[3] + e0 + a1; - h[3] = h[4] + a0 + b1; - h[4] = h[0] + b0 + c1; - h[0] = t; -} - -@compute @workgroup_size(64) -fn ripemd160_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - let inp = inputs[tid]; - let offset = inp.offset; - let len = inp.length; - - // IV. - h[0] = 0x67452301u; - h[1] = 0xefcdab89u; - h[2] = 0x98badcfeu; - h[3] = 0x10325476u; - h[4] = 0xc3d2e1f0u; - - // Process full 64-byte blocks. - var pos = 0u; - for (; pos + 64u <= len; pos = pos + 64u) { - for (var i = 0u; i < 16u; i = i + 1u) { - let b0 = read_byte(offset + pos + i * 4u + 0u); - let b1 = read_byte(offset + pos + i * 4u + 1u); - let b2 = read_byte(offset + pos + i * 4u + 2u); - let b3 = read_byte(offset + pos + i * 4u + 3u); - w[i] = b0 | (b1 << 8u) | (b2 << 16u) | (b3 << 24u); - } - compress(); - } - - // Final block(s) with MD4-style padding. Use byte-granular scratch. - let rem = len - pos; - for (var i = 0u; i < 64u; i = i + 1u) { - block[i] = 0u; - } - for (var i = 0u; i < rem; i = i + 1u) { - block[i] = read_byte(offset + pos + i); - } - block[rem] = 0x80u; - - if (rem >= 56u) { - // First final block: data + padding bit, no length yet. - for (var i = 0u; i < 16u; i = i + 1u) { - w[i] = block[i * 4u + 0u] - | (block[i * 4u + 1u] << 8u) - | (block[i * 4u + 2u] << 16u) - | (block[i * 4u + 3u] << 24u); - } - compress(); - for (var i = 0u; i < 64u; i = i + 1u) { - block[i] = 0u; - } - } - - // Append 64-bit little-endian bit length. 
- let bit_len_lo = len << 3u; // low 32 bits of len*8 - let bit_len_hi = len >> 29u; // high 32 bits of len*8 - block[56] = bit_len_lo & 0xFFu; - block[57] = (bit_len_lo >> 8u) & 0xFFu; - block[58] = (bit_len_lo >> 16u) & 0xFFu; - block[59] = (bit_len_lo >> 24u) & 0xFFu; - block[60] = bit_len_hi & 0xFFu; - block[61] = (bit_len_hi >> 8u) & 0xFFu; - block[62] = (bit_len_hi >> 16u) & 0xFFu; - block[63] = (bit_len_hi >> 24u) & 0xFFu; - - for (var i = 0u; i < 16u; i = i + 1u) { - w[i] = block[i * 4u + 0u] - | (block[i * 4u + 1u] << 8u) - | (block[i * 4u + 2u] << 16u) - | (block[i * 4u + 3u] << 24u); - } - compress(); - - // Emit 20-byte digest, packed into 5 u32 lanes (little-endian). - let out_base = tid * 5u; - outputs[out_base + 0u] = h[0]; - outputs[out_base + 1u] = h[1]; - outputs[out_base + 2u] = h[2]; - outputs[out_base + 3u] = h[3]; - outputs[out_base + 4u] = h[4]; -} diff --git a/ripemd160/gpu/wgsl/ripemd160_driver_wgpu.cpp b/ripemd160/gpu/wgsl/ripemd160_driver_wgpu.cpp deleted file mode 100644 index 971aa7e..0000000 --- a/ripemd160/gpu/wgsl/ripemd160_driver_wgpu.cpp +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WebGPU/WGSL host driver for batched RIPEMD-160 (Dobbertin et al. 1996). -// -// Mirrors the Metal/CUDA layout: caller supplies a packed input arena, an -// (offset, length) descriptor per input, and a contiguous 20-byte stride -// output buffer. -// -// Build flags: -// * LUX_RIPEMD160_HAS_WEBGPU=1 - Dawn or wgpu-native runtime found -// * LUX_RIPEMD160_HAS_WGPU_NATIVE=1 - wgpu-native specifically (gives -// wgpuDevicePoll for synchronous waits) - -#include "ripemd160_driver_wgpu.h" - -#if defined(LUX_RIPEMD160_HAS_WEBGPU) - -#include -#if defined(LUX_RIPEMD160_HAS_WGPU_NATIVE) -# include -#endif - -#include -#include -#include -#include -#include -#include - -// WGSL source concatenated into a string literal by CMake. 
-#include "ripemd160_wgsl_sources.h" - -namespace { - -WGPUStringView sv(const char* s) { - WGPUStringView v{}; - v.data = s; - v.length = (s == nullptr) ? 0 : std::strlen(s); - return v; -} -WGPUStringView sv(const std::string& s) { - WGPUStringView v{}; - v.data = s.data(); - v.length = s.size(); - return v; -} - -void drain(WGPUInstance inst, WGPUDevice dev) { - if (inst) wgpuInstanceProcessEvents(inst); -#if defined(LUX_RIPEMD160_HAS_WGPU_NATIVE) - if (dev) wgpuDevicePoll(dev, /*wait=*/WGPU_TRUE, nullptr); -#else - (void)dev; -#endif -} - -bool wait_map(WGPUInstance inst, WGPUDevice dev, WGPUBuffer buf, - WGPUMapMode mode, size_t off, size_t size) { - struct State { - std::atomic done{false}; - WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error}; - } s; - WGPUBufferMapCallbackInfo cb{}; - cb.mode = WGPUCallbackMode_AllowProcessEvents; - cb.callback = [](WGPUMapAsyncStatus st, WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - p->status = st; - p->done.store(true, std::memory_order_release); - }; - cb.userdata1 = &s; - wgpuBufferMapAsync(buf, mode, off, size, cb); - for (int spin = 0; spin < 4096; spin++) { - if (s.done.load(std::memory_order_acquire)) break; - drain(inst, dev); - } - return s.done.load() && s.status == WGPUMapAsyncStatus_Success; -} - -struct Engine { - WGPUInstance instance{nullptr}; - WGPUAdapter adapter{nullptr}; - WGPUDevice device{nullptr}; - WGPUQueue queue{nullptr}; - WGPUShaderModule module{nullptr}; - WGPUComputePipeline pipeline{nullptr}; - bool initialized{false}; -}; - -Engine& engine() { static Engine e; return e; } - -bool init_engine() { - Engine& e = engine(); - if (e.initialized) return true; - - WGPUInstanceDescriptor idesc{}; - e.instance = wgpuCreateInstance(&idesc); - if (!e.instance) return false; - - struct AS { std::atomic done{false}; WGPUAdapter ad{nullptr}; } as; - WGPURequestAdapterOptions ropt{}; - ropt.powerPreference = WGPUPowerPreference_HighPerformance; - WGPURequestAdapterCallbackInfo rcb{}; - 
rcb.mode = WGPUCallbackMode_AllowProcessEvents; - rcb.callback = [](WGPURequestAdapterStatus st, WGPUAdapter ad, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestAdapterStatus_Success) p->ad = ad; - p->done.store(true, std::memory_order_release); - }; - rcb.userdata1 = &as; - wgpuInstanceRequestAdapter(e.instance, &ropt, rcb); - for (int spin = 0; spin < 4096; spin++) { - if (as.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!as.ad) { std::fprintf(stderr, "wgpu: no adapter\n"); return false; } - e.adapter = as.ad; - - struct DS { std::atomic done{false}; WGPUDevice dev{nullptr}; } ds; - WGPUDeviceDescriptor ddesc{}; - WGPURequestDeviceCallbackInfo dcb{}; - dcb.mode = WGPUCallbackMode_AllowProcessEvents; - dcb.callback = [](WGPURequestDeviceStatus st, WGPUDevice dev, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestDeviceStatus_Success) p->dev = dev; - p->done.store(true, std::memory_order_release); - }; - dcb.userdata1 = &ds; - wgpuAdapterRequestDevice(e.adapter, &ddesc, dcb); - for (int spin = 0; spin < 4096; spin++) { - if (ds.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!ds.dev) { std::fprintf(stderr, "wgpu: no device\n"); return false; } - e.device = ds.dev; - e.queue = wgpuDeviceGetQueue(e.device); - if (!e.queue) return false; - - WGPUShaderSourceWGSL wgsl{}; - wgsl.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl.code = sv(kRIPEMD160_WGSL_Source); - - WGPUShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl.chain; - smd.label = sv("ripemd160"); - e.module = wgpuDeviceCreateShaderModule(e.device, &smd); - if (!e.module) { - std::fprintf(stderr, "wgpu: ripemd160 shader compile failed\n"); - return false; - } - - WGPUComputePipelineDescriptor cpd{}; - cpd.compute.module = e.module; - cpd.compute.entryPoint = sv("ripemd160_batch"); - cpd.label = sv("ripemd160_batch"); - e.pipeline = 
wgpuDeviceCreateComputePipeline(e.device, &cpd); - if (!e.pipeline) { - std::fprintf(stderr, "wgpu: ripemd160 pipeline failed\n"); - return false; - } - - e.initialized = true; - return true; -} - -WGPUBuffer make_buf(Engine& e, size_t size, WGPUBufferUsage usage) { - WGPUBufferDescriptor bd{}; - bd.size = (size + 3) & ~size_t(3); - if (bd.size == 0) bd.size = 4; - bd.usage = usage; - return wgpuDeviceCreateBuffer(e.device, &bd); -} - -} // namespace - -extern "C" int lux_ripemd160_wgpu_available(void) { - return init_engine() ? 1 : 0; -} - -extern "C" int ripemd160_batch_wgpu( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena) { - - if (n == 0) return 0; - if (!input_offsets || !input_lens || !outputs_arena) return -1; - if (!init_engine()) return -2; - Engine& e = engine(); - - // Pack the inputs descriptor (offset, length) into a u32 array of - // length 2*n. Pad inputs_arena up to a 4-byte boundary for the data - // buffer because WGSL reads it as array. 
- std::vector desc(n * 2); - for (size_t i = 0; i < n; ++i) { - desc[i * 2 + 0] = input_offsets[i]; - desc[i * 2 + 1] = input_lens[i]; - } - - size_t data_bytes = (inputs_arena_len + 3) & ~size_t(3); - if (data_bytes == 0) data_bytes = 4; - std::vector data_padded(data_bytes, 0); - if (inputs_arena_len) std::memcpy(data_padded.data(), inputs_arena, - inputs_arena_len); - - size_t out_words = n * 5; // 20 bytes per hash = 5 u32 words - size_t out_bytes = out_words * 4; - - WGPUBuffer buf_desc = make_buf(e, desc.size() * sizeof(uint32_t), - WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_data = make_buf(e, data_bytes, - WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_out = make_buf(e, out_bytes, - WGPUBufferUsage_Storage | WGPUBufferUsage_CopySrc); - WGPUBuffer buf_read = make_buf(e, out_bytes, - WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst); - if (!buf_desc || !buf_data || !buf_out || !buf_read) return -3; - - wgpuQueueWriteBuffer(e.queue, buf_desc, 0, desc.data(), - desc.size() * sizeof(uint32_t)); - wgpuQueueWriteBuffer(e.queue, buf_data, 0, data_padded.data(), data_bytes); - - WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(e.pipeline, 0); - WGPUBindGroupEntry bge[3] = {}; - bge[0].binding = 0; bge[0].buffer = buf_desc; bge[0].size = desc.size() * sizeof(uint32_t); - bge[1].binding = 1; bge[1].buffer = buf_data; bge[1].size = data_bytes; - bge[2].binding = 2; bge[2].buffer = buf_out; bge[2].size = out_bytes; - WGPUBindGroupDescriptor bgd{}; - bgd.layout = bgl; - bgd.entryCount = 3; - bgd.entries = bge; - WGPUBindGroup bg = wgpuDeviceCreateBindGroup(e.device, &bgd); - if (!bg) return -4; - - WGPUCommandEncoderDescriptor ced{}; - WGPUCommandEncoder ce = wgpuDeviceCreateCommandEncoder(e.device, &ced); - WGPUComputePassDescriptor cpd2{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpd2); - wgpuComputePassEncoderSetPipeline(cpe, e.pipeline); - 
wgpuComputePassEncoderSetBindGroup(cpe, 0, bg, 0, nullptr); - uint32_t wg = uint32_t((n + 63) / 64); - wgpuComputePassEncoderDispatchWorkgroups(cpe, wg, 1, 1); - wgpuComputePassEncoderEnd(cpe); - - wgpuCommandEncoderCopyBufferToBuffer(ce, buf_out, 0, buf_read, 0, out_bytes); - WGPUCommandBufferDescriptor cbd{}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(ce, &cbd); - wgpuQueueSubmit(e.queue, 1, &cmd); - - if (!wait_map(e.instance, e.device, buf_read, WGPUMapMode_Read, 0, out_bytes)) { - std::fprintf(stderr, "wgpu: ripemd160 readback map failed\n"); - return -5; - } - const void* mapped = wgpuBufferGetConstMappedRange(buf_read, 0, out_bytes); - std::memcpy(outputs_arena, mapped, n * 20); - wgpuBufferUnmap(buf_read); - - wgpuComputePassEncoderRelease(cpe); - wgpuCommandEncoderRelease(ce); - wgpuCommandBufferRelease(cmd); - wgpuBindGroupRelease(bg); - wgpuBindGroupLayoutRelease(bgl); - wgpuBufferRelease(buf_desc); - wgpuBufferRelease(buf_data); - wgpuBufferRelease(buf_out); - wgpuBufferRelease(buf_read); - return 0; -} - -#else // LUX_RIPEMD160_HAS_WEBGPU not defined: stub mode - -extern "C" int lux_ripemd160_wgpu_available(void) { return 0; } -extern "C" int ripemd160_batch_wgpu( - const uint8_t*, size_t, - const uint32_t*, const uint32_t*, - size_t, uint8_t*) { - return -1; -} - -#endif diff --git a/ripemd160/gpu/wgsl/ripemd160_driver_wgpu.h b/ripemd160/gpu/wgsl/ripemd160_driver_wgpu.h deleted file mode 100644 index 61d6c6c..0000000 --- a/ripemd160/gpu/wgsl/ripemd160_driver_wgpu.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the RIPEMD-160 WebGPU/WGSL driver. On hosts without a -// wgpu runtime, lux_ripemd160_wgpu_available() returns 0 and -// ripemd160_batch_wgpu() returns -1. 
- -#ifndef LUX_RIPEMD160_DRIVER_WGPU_H -#define LUX_RIPEMD160_DRIVER_WGPU_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Returns 1 if a WebGPU adapter+device initialised successfully, 0 otherwise. -int lux_ripemd160_wgpu_available(void); - -// Run N RIPEMD-160 hashes in one WGSL dispatch. Inputs share a flat byte -// arena; outputs are written as 20 contiguous bytes per hash. Returns 0 on -// success, negative on failure. -int ripemd160_batch_wgpu( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_RIPEMD160_DRIVER_WGPU_H diff --git a/secp256k1/gpu/cuda/secp256k1.cu b/secp256k1/gpu/cuda/secp256k1.cu deleted file mode 100644 index 2e7e47d..0000000 --- a/secp256k1/gpu/cuda/secp256k1.cu +++ /dev/null @@ -1,685 +0,0 @@ -// secp256k1 ECDSA batch recovery — CUDA implementation -// Matches secp256k1_recover.metal output byte-for-byte -// One thread per ecrecover (r, s, v, msg_hash) → 20-byte Ethereum address - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// 256-bit unsigned integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint256 { - uint64_t limbs[4]; -}; - -// ============================================================================= -// secp256k1 constants -// ============================================================================= - -__device__ static const uint256 SECP256K1_P = {{ - 0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 SECP256K1_N = {{ - 0xBFD25E8CD0364141ULL, 
0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 GX = {{ - 0x59F2815B16F81798ULL, 0x029BFCDB2DCE28D9ULL, - 0x55A06295CE870B07ULL, 0x79BE667EF9DCBBACULL -}}; - -__device__ static const uint256 GY = {{ - 0x9C47D08FFB10D4B8ULL, 0xFD17B448A6855419ULL, - 0x5DA4FBFC0E1108A8ULL, 0x483ADA7726A3C465ULL -}}; - -__device__ static const uint256 MONT_R_P = {{ - 0x00000001000003D1ULL, 0x0000000000000000ULL, - 0x0000000000000000ULL, 0x0000000000000000ULL -}}; - -__device__ static const uint256 MONT_R2_P = {{ - 0x000007A2000E90A1ULL, 0x0000000000000001ULL, - 0x0000000000000000ULL, 0x0000000000000000ULL -}}; - -__device__ static const uint64_t P_INV = 0xD838091DD2253531ULL; - -__device__ static const uint256 MONT_R2_N = {{ - 0x896CF21467D7D140ULL, 0x741496C20E7CF878ULL, - 0xE697F5E45BCD07C6ULL, 0x9D671CD581C69BC5ULL -}}; - -__device__ static const uint64_t N_INV = 0x4B0DFF665588B13FULL; - -__device__ static const uint256 ZERO256 = {{0, 0, 0, 0}}; -__device__ static const uint256 ONE256 = {{1, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic -// ============================================================================= - -__device__ static int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -__device__ static uint256 u256_add(uint256 a, uint256 b, uint64_t& carry) { - uint256 r; - uint64_t c = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 
1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -__device__ static uint256 u256_sub(uint256 a, uint256 b, uint64_t& borrow) { - uint256 r; - uint64_t bw = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// ============================================================================= -// Montgomery arithmetic (parameterized by modulus m and inv = -m^(-1) mod 2^64) -// Uses __int128 on CUDA for 64x64->128 multiply -// ============================================================================= - -__device__ static uint256 mont_reduce(uint64_t t[8], uint256 m, uint64_t inv) { - uint64_t a[9]; - for (int i = 0; i < 8; i++) a[i] = t[i]; - a[8] = 0; - - for (int i = 0; i < 4; i++) { - uint64_t u = a[i] * inv; - - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)u * m.limbs[j]; - unsigned __int128 acc = prod + carry + a[i + j]; - a[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t u_lo = u & 0xFFFFFFFFULL; - uint64_t u_hi = u >> 32; - uint64_t m_lo = m.limbs[j] & 0xFFFFFFFFULL; - uint64_t m_hi = m.limbs[j] >> 32; - uint64_t ll = u_lo * m_lo; - uint64_t lh = u_lo * m_hi; - uint64_t hl = u_hi * m_lo; - uint64_t hh = u_hi * m_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; - if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; -#endif - } - for (int j = 4; i + j <= 8; j++) { - uint64_t sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 
1ULL : 0ULL; - a[i + j] = sum; - if (carry == 0) break; - } - } - - uint256 r; - r.limbs[0] = a[4]; - r.limbs[1] = a[5]; - r.limbs[2] = a[6]; - r.limbs[3] = a[7]; - - if (a[8] || u256_cmp(r, m) >= 0) { - uint64_t bw; - r = u256_sub(r, m, bw); - } - return r; -} - -__device__ static uint256 mont_mul(uint256 a, uint256 b, uint256 m, uint64_t inv) { - uint64_t t[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a.limbs[i] * b.limbs[j]; - unsigned __int128 acc = prod + carry + t[i + j]; - t[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t a_lo = a.limbs[i] & 0xFFFFFFFFULL; - uint64_t a_hi = a.limbs[i] >> 32; - uint64_t b_lo = b.limbs[j] & 0xFFFFFFFFULL; - uint64_t b_hi = b.limbs[j] >> 32; - uint64_t ll = a_lo * b_lo; - uint64_t lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo; - uint64_t hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; -#endif - } - for (int j = 4; i + j < 8; j++) { - uint64_t sum = t[i + j] + carry; - carry = (sum < t[i + j]) ? 
1ULL : 0ULL; - t[i + j] = sum; - if (carry == 0) break; - } - } - - return mont_reduce(t, m, inv); -} - -__device__ static uint256 to_mont(uint256 a, uint256 r2, uint256 m, uint64_t inv) { - return mont_mul(a, r2, m, inv); -} - -__device__ static uint256 from_mont(uint256 a, uint256 m, uint64_t inv) { - uint64_t t[8] = {a.limbs[0], a.limbs[1], a.limbs[2], a.limbs[3], 0, 0, 0, 0}; - return mont_reduce(t, m, inv); -} - -// Field operations over p (Montgomery form) -__device__ static uint256 fp_add(uint256 a, uint256 b) { - uint64_t carry; - uint256 r = u256_add(a, b, carry); - if (carry || u256_cmp(r, SECP256K1_P) >= 0) { - uint64_t bw; - r = u256_sub(r, SECP256K1_P, bw); - } - return r; -} - -__device__ static uint256 fp_sub(uint256 a, uint256 b) { - uint64_t bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { - uint64_t c; - r = u256_add(r, SECP256K1_P, c); - } - return r; -} - -__device__ static uint256 fp_mul(uint256 a, uint256 b) { - return mont_mul(a, b, SECP256K1_P, P_INV); -} - -__device__ static uint256 fp_sqr(uint256 a) { - return fp_mul(a, a); -} - -// Scalar field operations over n -__device__ static uint256 fn_mul(uint256 a, uint256 b) { - return mont_mul(a, b, SECP256K1_N, N_INV); -} - -// Fermat inversion over p -__device__ static uint256 fp_inv(uint256 a) { - uint256 result = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - uint256 base = a; - - uint64_t exp[4] = { - 0xFFFFFFFEFFFFFC2DULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL - }; - - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - return result; -} - -// Scalar inversion over n -__device__ static uint256 fn_inv(uint256 a) { - uint64_t exp[4] = { - 0xBFD25E8CD036413FULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL - }; - - uint256 result = to_mont(ONE256, MONT_R2_N, SECP256K1_N, N_INV); - uint256 base = a; - - for (int i = 0; 
i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fn_mul(result, base); - } - base = fn_mul(base, base); - } - } - return result; -} - -// ============================================================================= -// secp256k1 EC point operations (Jacobian coordinates, Montgomery F_p) -// ============================================================================= - -struct ECPoint { - uint256 x, y, z; -}; - -__device__ static ECPoint ec_identity() { - ECPoint p; - p.x = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - p.y = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - p.z = ZERO256; - return p; -} - -__device__ static bool ec_is_infinity(ECPoint p) { - return u256_is_zero(p.z); -} - -__device__ static ECPoint ec_double(ECPoint p) { - if (ec_is_infinity(p)) return p; - - uint256 A = fp_sqr(p.y); - uint256 B = fp_mul(p.x, A); - uint256 C = fp_sqr(A); - - uint256 S = fp_add(B, B); - S = fp_add(S, S); - - uint256 X2 = fp_sqr(p.x); - uint256 M = fp_add(X2, fp_add(X2, X2)); - - uint256 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - - uint256 C8 = fp_add(C, C); - C8 = fp_add(C8, C8); - C8 = fp_add(C8, C8); - uint256 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - - uint256 Z3 = fp_mul(p.y, p.z); - Z3 = fp_add(Z3, Z3); - - ECPoint r; - r.x = X3; - r.y = Y3; - r.z = Z3; - return r; -} - -__device__ static ECPoint ec_add_mixed(ECPoint P, uint256 Qx, uint256 Qy) { - if (ec_is_infinity(P)) { - ECPoint r; - r.x = Qx; - r.y = Qy; - r.z = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - return r; - } - - uint256 Z2 = fp_sqr(P.z); - uint256 U2 = fp_mul(Qx, Z2); - uint256 Z3 = fp_mul(Z2, P.z); - uint256 S2 = fp_mul(Qy, Z3); - - uint256 H = fp_sub(U2, P.x); - uint256 R = fp_sub(S2, P.y); - - if (u256_is_zero(H)) { - if (u256_is_zero(R)) { - return ec_double(P); - } - return ec_identity(); - } - - uint256 H2 = fp_sqr(H); - uint256 H3 = fp_mul(H, H2); - uint256 U1H2 = fp_mul(P.x, H2); - - uint256 X3 = fp_sub(fp_sub(fp_sqr(R), H3), 
fp_add(U1H2, U1H2)); - uint256 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - uint256 Zr = fp_mul(H, P.z); - - ECPoint res; - res.x = X3; - res.y = Y3; - res.z = Zr; - return res; -} - -__device__ static void ec_to_affine(ECPoint p, uint256& ax, uint256& ay) { - if (ec_is_infinity(p)) { - ax = ZERO256; - ay = ZERO256; - return; - } - uint256 z_inv = fp_inv(p.z); - uint256 z_inv2 = fp_sqr(z_inv); - uint256 z_inv3 = fp_mul(z_inv2, z_inv); - ax = fp_mul(p.x, z_inv2); - ay = fp_mul(p.y, z_inv3); -} - -// NOT constant-time: branches on scalar bits. Safe for ecrecover (all inputs -// are public). MUST NOT be reused for signing where the nonce k is secret. -__device__ static ECPoint ec_mul_affine(uint256 k, uint256 Px, uint256 Py) { - ECPoint result = ec_identity(); - - for (int i = 3; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = ec_double(result); - if ((k.limbs[i] >> bit) & 1) { - result = ec_add_mixed(result, Px, Py); - } - } - } - return result; -} - -// ============================================================================= -// Keccak-256 (inline, for address derivation) -// ============================================================================= - -__device__ static const uint64_t KECCAK_RC[24] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, - 0x800000000000808AULL, 0x8000000080008000ULL, - 0x000000000000808BULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, - 0x000000000000008AULL, 0x0000000000000088ULL, - 0x0000000080008009ULL, 0x000000008000000AULL, - 0x000000008000808BULL, 0x800000000000008BULL, - 0x8000000000008089ULL, 0x8000000000008003ULL, - 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800AULL, 0x800000008000000AULL, - 0x8000000080008081ULL, 0x8000000000008080ULL, - 0x0000000080000001ULL, 0x8000000080008008ULL, -}; - -__device__ static const int KECCAK_PI_LANE[24] = { - 10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, - 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 -}; - 
-__device__ static const int KECCAK_RHO[24] = { - 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44 -}; - -__device__ static uint64_t keccak_rotl64(uint64_t x, int n) { - return (x << n) | (x >> (64 - n)); -} - -__device__ static void keccak_f1600(uint64_t st[25]) { - for (int round = 0; round < 24; ++round) { - uint64_t C[5]; - for (int x = 0; x < 5; ++x) - C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; - for (int x = 0; x < 5; ++x) { - uint64_t d = C[(x + 4) % 5] ^ keccak_rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - st[x + 5 * y] ^= d; - } - uint64_t t = st[1]; - for (int i = 0; i < 24; ++i) { - uint64_t tmp = st[KECCAK_PI_LANE[i]]; - st[KECCAK_PI_LANE[i]] = keccak_rotl64(t, KECCAK_RHO[i]); - t = tmp; - } - for (int y = 0; y < 5; ++y) { - uint64_t row[5]; - for (int x = 0; x < 5; ++x) row[x] = st[x + 5 * y]; - for (int x = 0; x < 5; ++x) - st[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - st[0] ^= KECCAK_RC[round]; - } -} - -__device__ static void keccak256_64(const uint8_t data[64], uint8_t out[32]) { - uint64_t state[25] = {}; - - for (uint32_t w = 0; w < 8; ++w) { - uint64_t lane = 0; - for (uint32_t b = 0; b < 8; ++b) - lane |= (uint64_t)data[w * 8 + b] << (b * 8); - state[w] ^= lane; - } - - state[8] ^= 0x01ULL; - state[16] ^= 0x80ULL << 56; - - keccak_f1600(state); - - for (uint32_t w = 0; w < 4; ++w) { - uint64_t lane = state[w]; - for (uint32_t b = 0; b < 8; ++b) - out[w * 8 + b] = (uint8_t)(lane >> (b * 8)); - } -} - -// ============================================================================= -// Input/Output structures -// ============================================================================= - -struct EcrecoverInput { - uint8_t r[32]; - uint8_t s[32]; - uint8_t v; - uint8_t _pad[3]; - uint8_t msg_hash[32]; - uint8_t _pad2[28]; -}; - -struct EcrecoverOutput { - uint8_t address[20]; - uint8_t valid; - uint8_t _pad[11]; -}; - -// 
============================================================================= -// Helpers: big-endian load/store -// ============================================================================= - -__device__ static uint256 load_be32(const uint8_t bytes[32]) { - uint256 r; - for (int limb = 0; limb < 4; limb++) { - uint64_t v = 0; - int base = (3 - limb) * 8; - for (int b = 0; b < 8; b++) { - v = (v << 8) | (uint64_t)bytes[base + b]; - } - r.limbs[limb] = v; - } - return r; -} - -__device__ static void store_be32(uint256 val, uint8_t bytes[32]) { - for (int limb = 0; limb < 4; limb++) { - int base = (3 - limb) * 8; - uint64_t v = val.limbs[limb]; - for (int b = 7; b >= 0; b--) { - bytes[base + b] = (uint8_t)(v & 0xFF); - v >>= 8; - } - } -} - -// ============================================================================= -// Main kernel: batch secp256k1 ecrecover -// ============================================================================= - -extern "C" __global__ void secp256k1_ecrecover_batch( - const EcrecoverInput* __restrict__ inputs, - EcrecoverOutput* __restrict__ outputs, - const uint32_t num_sigs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_sigs) return; - - const EcrecoverInput& inp = inputs[tid]; - EcrecoverOutput& out = outputs[tid]; - - // Clear output - for (int i = 0; i < 20; i++) out.address[i] = 0; - out.valid = 0; - for (int i = 0; i < 11; i++) out._pad[i] = 0; - - // Load r, s, v, hash - uint256 r = load_be32(inp.r); - uint256 s = load_be32(inp.s); - uint256 e = load_be32(inp.msg_hash); - uint32_t v = (uint32_t)inp.v; - - // Normalize v - if (v >= 27) v -= 27; - if (v >= 2) v = v % 2; - - // Validate - if (u256_is_zero(r) || u256_cmp(r, SECP256K1_N) >= 0) return; - if (u256_is_zero(s) || u256_cmp(s, SECP256K1_N) >= 0) return; - if (v > 1) return; - - // Step 1: Decompress r → R = (r, y) on secp256k1 - uint256 r_mont = to_mont(r, MONT_R2_P, SECP256K1_P, P_INV); - uint256 r2 = fp_sqr(r_mont); - uint256 r3 = 
fp_mul(r2, r_mont); - uint256 seven_mont = to_mont(uint256{{7, 0, 0, 0}}, MONT_R2_P, SECP256K1_P, P_INV); - uint256 y2 = fp_add(r3, seven_mont); - - // sqrt(y2) via a^((p+1)/4) since p = 3 mod 4 - uint256 y_mont; - { - uint64_t exp[4] = { - 0xFFFFFFFFBFFFFF0CULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0x3FFFFFFFFFFFFFFFULL - }; - uint256 result = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - uint256 base = y2; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - y_mont = result; - } - - // Verify sqrt exists - if (u256_cmp(fp_sqr(y_mont), y2) != 0) return; - - // Select correct y parity - uint256 y_normal = from_mont(y_mont, SECP256K1_P, P_INV); - bool y_is_odd = (y_normal.limbs[0] & 1) != 0; - if ((v == 0 && y_is_odd) || (v == 1 && !y_is_odd)) { - y_mont = fp_sub(ZERO256, y_mont); - } - - uint256 Rx_mont = r_mont; - uint256 Ry_mont = y_mont; - - // Step 2: r_inv = r^(-1) mod n - uint256 r_n_mont = to_mont(r, MONT_R2_N, SECP256K1_N, N_INV); - uint256 r_inv_mont = fn_inv(r_n_mont); - - // Step 3: u1 = -(e * r_inv) mod n, u2 = s * r_inv mod n - uint256 e_n_mont = to_mont(e, MONT_R2_N, SECP256K1_N, N_INV); - uint256 s_n_mont = to_mont(s, MONT_R2_N, SECP256K1_N, N_INV); - - uint256 u1_mont = fn_mul(e_n_mont, r_inv_mont); - uint256 u1 = from_mont(u1_mont, SECP256K1_N, N_INV); - if (!u256_is_zero(u1)) { - uint64_t bw; - u1 = u256_sub(SECP256K1_N, u1, bw); - } - - uint256 u2 = from_mont(fn_mul(s_n_mont, r_inv_mont), SECP256K1_N, N_INV); - - // Step 4: Q = u1*G + u2*R - uint256 Gx_mont = to_mont(GX, MONT_R2_P, SECP256K1_P, P_INV); - uint256 Gy_mont = to_mont(GY, MONT_R2_P, SECP256K1_P, P_INV); - - ECPoint Q1 = ec_mul_affine(u1, Gx_mont, Gy_mont); - ECPoint Q2 = ec_mul_affine(u2, Rx_mont, Ry_mont); - - // Add Q1 + Q2 - ECPoint Q; - if (ec_is_infinity(Q1)) { - Q = Q2; - } else if (ec_is_infinity(Q2)) { - Q = Q1; - } else { - uint256 Q2x_aff, 
Q2y_aff; - ec_to_affine(Q2, Q2x_aff, Q2y_aff); - Q = ec_add_mixed(Q1, Q2x_aff, Q2y_aff); - } - - if (ec_is_infinity(Q)) return; - - // Step 5: Convert Q to affine, serialize big-endian - uint256 Qx_aff, Qy_aff; - ec_to_affine(Q, Qx_aff, Qy_aff); - - uint256 Qx_norm = from_mont(Qx_aff, SECP256K1_P, P_INV); - uint256 Qy_norm = from_mont(Qy_aff, SECP256K1_P, P_INV); - - uint8_t pubkey[64]; - store_be32(Qx_norm, pubkey); - store_be32(Qy_norm, pubkey + 32); - - // Step 6: address = keccak256(pubkey)[12:] - uint8_t hash[32]; - keccak256_64(pubkey, hash); - - for (int i = 0; i < 20; i++) { - out.address[i] = hash[12 + i]; - } - out.valid = 1; -} diff --git a/secp256k1/gpu/cuda/secp256k1_batch_inv.cu b/secp256k1/gpu/cuda/secp256k1_batch_inv.cu deleted file mode 100644 index b349216..0000000 --- a/secp256k1/gpu/cuda/secp256k1_batch_inv.cu +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Montgomery batch inversion for secp256k1 base field (Fp) and scalar field -// (Fn) -- CUDA port of secp256k1_batch_inv.metal. Output byte-equal to: -// * cpp/batch_inv.hpp (CPU canonical body) -// * gpu/metal/secp256k1_batch_inv.metal (Metal kernel) -// -// Single-thread dispatch (one workgroup of one thread) preserves byte-equal -// determinism with the CPU implementation. The acceleration win is freeing -// the host CPU for other pipeline stages, not raw throughput. -// -// Compile-guarded: when __CUDACC__ is unset (e.g. on macOS hosts without the -// CUDA toolkit), this TU emits a NOTIMPL stub so the umbrella library still -// links. Real device code is built on the linux-amd64 CI runner. 
- -#include -#include -#include - -#ifdef __CUDACC__ -#include -#endif - -namespace { - -// ============================================================================= -// 256-bit field constants -- byte-equal to secp256k1_batch_inv.metal -// ============================================================================= - -struct uint256 { uint64_t limbs[4]; }; - -#ifdef __CUDACC__ -#define LUX_DEV __device__ -#else -#define LUX_DEV -#endif - -LUX_DEV static const uint256 P_MOD = {{ - 0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL -}}; -LUX_DEV static const uint256 N_MOD = {{ - 0xBFD25E8CD0364141ULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}}; -LUX_DEV static const uint64_t P_INV = 0xD838091DD2253531ULL; -LUX_DEV static const uint64_t N_INV = 0x4B0DFF665588B13FULL; -LUX_DEV static const uint256 R2_N = {{ - 0x896CF21467D7D140ULL, 0x741496C20E7CF878ULL, - 0xE697F5E45BCD07C6ULL, 0x9D671CD581C69BC5ULL -}}; -LUX_DEV static const uint256 ONE_MONT_P = {{ - 0x00000001000003D1ULL, 0ULL, 0ULL, 0ULL -}}; -LUX_DEV static const uint64_t P_M2[4] = { - 0xFFFFFFFEFFFFFC2DULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL -}; -LUX_DEV static const uint64_t N_M2[4] = { - 0xBFD25E8CD036413FULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}; -LUX_DEV static const uint256 ONE = {{ 1ULL, 0ULL, 0ULL, 0ULL }}; - -// ============================================================================= -// 256-bit arithmetic helpers (host-and-device, byte-equal to Metal) -// ============================================================================= - -#ifdef __CUDACC__ -__device__ static int u256_cmp(uint256 a, uint256 b) { -#else -static int u256_cmp(uint256 a, uint256 b) { -#endif - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -#ifdef __CUDACC__ -__device__ static void 
mul64(uint64_t a, uint64_t b, uint64_t& lo, uint64_t& hi) { - unsigned __int128 prod = (unsigned __int128)a * (unsigned __int128)b; - lo = (uint64_t)prod; - hi = (uint64_t)(prod >> 64); -} -#else -static void mul64(uint64_t a, uint64_t b, uint64_t& lo, uint64_t& hi) { - uint64_t al = a & 0xFFFFFFFFULL, ah = a >> 32; - uint64_t bl = b & 0xFFFFFFFFULL, bh = b >> 32; - uint64_t ll = al * bl, lh = al * bh, hl = ah * bl, hh = ah * bh; - uint64_t mid = (ll >> 32) + (lh & 0xFFFFFFFFULL) + (hl & 0xFFFFFFFFULL); - lo = (ll & 0xFFFFFFFFULL) | (mid << 32); - hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -} -#endif - -#ifdef __CUDACC__ -__device__ static uint64_t addc(uint64_t a, uint64_t b, uint64_t c, uint64_t& out) { -#else -static uint64_t addc(uint64_t a, uint64_t b, uint64_t c, uint64_t& out) { -#endif - uint64_t t = a + b; - uint64_t c1 = (t < a) ? 1ULL : 0ULL; - uint64_t t2 = t + c; - uint64_t c2 = (t2 < t) ? 1ULL : 0ULL; - out = t2; - return c1 + c2; -} - -#ifdef __CUDACC__ -__device__ static uint64_t subb(uint64_t a, uint64_t b, uint64_t br, uint64_t& out) { -#else -static uint64_t subb(uint64_t a, uint64_t b, uint64_t br, uint64_t& out) { -#endif - uint64_t t = a - b; - uint64_t b1 = (t > a) ? 1ULL : 0ULL; - uint64_t t2 = t - br; - uint64_t b2 = (t2 > t) ? 1ULL : 0ULL; - out = t2; - return b1 + b2; -} - -#ifdef __CUDACC__ -__device__ static uint256 sub_256(uint256 a, uint256 b, uint64_t& borrow) { -#else -static uint256 sub_256(uint256 a, uint256 b, uint64_t& borrow) { -#endif - uint256 r; - uint64_t br = 0; - for (int i = 0; i < 4; ++i) br = subb(a.limbs[i], b.limbs[i], br, r.limbs[i]); - borrow = br; - return r; -} - -// CIOS Montgomery multiplication -- matches Metal exactly. 
-#ifdef __CUDACC__ -__device__ static uint256 mont_mul(uint256 a, uint256 b, uint256 m, uint64_t m_inv) { -#else -static uint256 mont_mul(uint256 a, uint256 b, uint256 m, uint64_t m_inv) { -#endif - uint64_t t[6]; - for (int i = 0; i < 6; ++i) t[i] = 0; - for (int i = 0; i < 4; ++i) { - uint64_t carry = 0; - for (int j = 0; j < 4; ++j) { - uint64_t lo, hi; - mul64(a.limbs[j], b.limbs[i], lo, hi); - uint64_t c1 = addc(t[j], lo, carry, t[j]); - carry = hi + c1; - } - uint64_t c1 = addc(t[4], carry, 0, t[4]); - t[5] += c1; - uint64_t u = t[0] * m_inv; - carry = 0; - for (int j = 0; j < 4; ++j) { - uint64_t lo, hi; - mul64(u, m.limbs[j], lo, hi); - uint64_t c2 = addc(t[j], lo, carry, t[j]); - carry = hi + c2; - } - uint64_t c2 = addc(t[4], carry, 0, t[4]); - t[5] += c2; - for (int j = 0; j < 5; ++j) t[j] = t[j + 1]; - t[5] = 0; - } - uint256 r = {{ t[0], t[1], t[2], t[3] }}; - if (t[4] != 0 || u256_cmp(r, m) >= 0) { - uint64_t bw; - r = sub_256(r, m, bw); - } - return r; -} - -#ifdef __CUDACC__ -#define LUX_INV_DECL __device__ static -#else -#define LUX_INV_DECL static -#endif - -LUX_INV_DECL uint256 fp_mul(uint256 a, uint256 b) { return mont_mul(a, b, P_MOD, P_INV); } -LUX_INV_DECL uint256 fn_mul(uint256 a, uint256 b) { return mont_mul(a, b, N_MOD, N_INV); } -LUX_INV_DECL uint256 fp_sqr(uint256 a) { return mont_mul(a, a, P_MOD, P_INV); } -LUX_INV_DECL uint256 fn_sqr(uint256 a) { return mont_mul(a, a, N_MOD, N_INV); } - -LUX_INV_DECL uint256 fp_pow(uint256 a, const uint64_t exp4[4]) { - uint256 result = ONE_MONT_P; - uint256 base = a; - for (int limb = 0; limb < 4; ++limb) { - uint64_t w = exp4[limb]; - for (int bit = 0; bit < 64; ++bit) { - if ((w >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - } - return result; -} -LUX_INV_DECL uint256 fp_inv(uint256 a) { return fp_pow(a, P_M2); } - -LUX_INV_DECL uint256 fn_pow(uint256 a, const uint64_t exp4[4]) { - uint256 result = mont_mul(ONE, R2_N, N_MOD, N_INV); - uint256 base = a; - for (int limb = 0; 
limb < 4; ++limb) { - uint64_t w = exp4[limb]; - for (int bit = 0; bit < 64; ++bit) { - if ((w >> bit) & 1) result = fn_mul(result, base); - base = fn_sqr(base); - } - } - return result; -} -LUX_INV_DECL uint256 fn_inv(uint256 a) { return fn_pow(a, N_M2); } - -} // namespace - -// ============================================================================= -// Device kernels (compiled only with nvcc) -// ============================================================================= - -#ifdef __CUDACC__ - -extern "C" __global__ void cuda_secp256k1_batch_inv_fp_kernel( - const uint256* __restrict__ in, - uint256* __restrict__ out, - uint32_t n) -{ - if (threadIdx.x != 0 || blockIdx.x != 0) return; - if (n == 0) return; - - out[0] = in[0]; - for (uint32_t i = 1; i < n; ++i) { - out[i] = fp_mul(out[i - 1], in[i]); - } - uint256 inv = fp_inv(out[n - 1]); - for (uint32_t k = n; k > 1; --k) { - uint32_t i = k - 1; - uint256 t = fp_mul(inv, out[i - 1]); - inv = fp_mul(inv, in[i]); - out[i] = t; - } - out[0] = inv; -} - -extern "C" __global__ void cuda_secp256k1_batch_inv_fn_kernel( - const uint256* __restrict__ in, - uint256* __restrict__ out, - uint32_t n) -{ - if (threadIdx.x != 0 || blockIdx.x != 0) return; - if (n == 0) return; - - out[0] = in[0]; - for (uint32_t i = 1; i < n; ++i) { - out[i] = fn_mul(out[i - 1], in[i]); - } - uint256 inv = fn_inv(out[n - 1]); - for (uint32_t k = n; k > 1; --k) { - uint32_t i = k - 1; - uint256 t = fn_mul(inv, out[i - 1]); - inv = fn_mul(inv, in[i]); - out[i] = t; - } - out[0] = inv; -} - -#endif // __CUDACC__ - -// ============================================================================= -// Host launchers (always compiled; on non-CUDA hosts these are NOTIMPL stubs) -// ============================================================================= - -extern "C" int cuda_secp256k1_batch_inv_launch( - const uint8_t* in_mont, - size_t n, - uint8_t* out_mont, - int kind); - -#ifdef __CUDACC__ - -extern "C" int 
cuda_secp256k1_batch_inv_launch( - const uint8_t* in_mont, - size_t n, - uint8_t* out_mont, - int kind) -{ - if (n == 0) return 0; - if (!in_mont || !out_mont) return -1; - if (kind != 0 && kind != 1) return -2; - - const size_t bytes = n * sizeof(uint256); - - uint256 *d_in = nullptr, *d_out = nullptr; - cudaError_t e; - e = cudaMalloc(reinterpret_cast(&d_in), bytes); - if (e != cudaSuccess) return -10; - e = cudaMalloc(reinterpret_cast(&d_out), bytes); - if (e != cudaSuccess) { cudaFree(d_in); return -11; } - - e = cudaMemcpy(d_in, in_mont, bytes, cudaMemcpyHostToDevice); - if (e != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -12; } - - if (kind == 0) { - cuda_secp256k1_batch_inv_fp_kernel<<<1, 1>>>(d_in, d_out, (uint32_t)n); - } else { - cuda_secp256k1_batch_inv_fn_kernel<<<1, 1>>>(d_in, d_out, (uint32_t)n); - } - e = cudaGetLastError(); - if (e != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -13; } - - e = cudaDeviceSynchronize(); - if (e != cudaSuccess) { cudaFree(d_in); cudaFree(d_out); return -14; } - - e = cudaMemcpy(out_mont, d_out, bytes, cudaMemcpyDeviceToHost); - cudaFree(d_in); - cudaFree(d_out); - if (e != cudaSuccess) return -15; - return 0; -} - -#else // !__CUDACC__ - -// Non-CUDA host: emit NOTIMPL so the link still resolves. CI builds the real -// path on hanzo-build-linux-amd64 with nvcc. -extern "C" int cuda_secp256k1_batch_inv_launch( - const uint8_t*, size_t, uint8_t*, int) { - return -100; // CRYPTO_ERR_NOTIMPL -} - -#endif // __CUDACC__ diff --git a/secp256k1/gpu/cuda/secp256k1_batch_inv_driver.cu b/secp256k1/gpu/cuda/secp256k1_batch_inv_driver.cu deleted file mode 100644 index 4ea924e..0000000 --- a/secp256k1/gpu/cuda/secp256k1_batch_inv_driver.cu +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for Stage A (Montgomery batch inversion) of the v0.63 -// ecrecover pipeline. Mirrors gpu/metal/secp256k1_batch_inv_driver.mm. 
-// -// Entry point matches the Metal driver's signature shape: byte buffers in -// Mont form (limb little-endian), kind = 0 for Fp / 1 for Fn. The CUDA -// "metallib_path" slot is ignored (kept in API for symmetry with future -// PTX-from-cubin loading); kernels are statically linked. -// -// On hosts without nvcc, this TU compiles as plain C++ and the launcher -// returns the NOTIMPL sentinel from secp256k1_batch_inv.cu. - -#include -#include - -extern "C" int cuda_secp256k1_batch_inv_launch( - const uint8_t* in_mont, size_t n, uint8_t* out_mont, int kind); - -// Public entry point; signature mirrors secp256k1_batch_inv_metal so the test -// harness can swap drivers by recompile. -extern "C" int cuda_secp256k1_batch_inv( - const uint8_t* in_mont, // n * 32 bytes (Mont-form, limb little-endian) - size_t n, - uint8_t* out_mont, // n * 32 bytes - int kind, // 0 = Fp, 1 = Fn - const char* /*unused_path*/) { - return cuda_secp256k1_batch_inv_launch(in_mont, n, out_mont, kind); -} diff --git a/secp256k1/gpu/cuda/secp256k1_first_party_cuda_driver.cu b/secp256k1/gpu/cuda/secp256k1_first_party_cuda_driver.cu deleted file mode 100644 index 9e5ac54..0000000 --- a/secp256k1/gpu/cuda/secp256k1_first_party_cuda_driver.cu +++ /dev/null @@ -1,134 +0,0 @@ -// secp256k1 ECDSA batch recovery — CUDA host driver. -// -// Symbol exposed: secp256k1_ecrecover_address_batch_cuda — looked up via -// dlsym from luxcpp/crypto/secp256k1/cpp/ecrecover.cpp when the runtime -// backend resolves to CUDA. Mirrors the Metal driver shape (Apple -// secp256k1_first_party_driver.mm) so dispatch logic is identical. -// -// On non-CUDA hosts CMake sets LANGUAGE CXX and CRYPTO_HAS_CUDA is unset, -// so this file compiles to a NOTIMPL stub (the umbrella library still -// links). On nvcc-built lanes the device path runs. 
- -#include -#include - -#include "lux/crypto/secp256k1.h" - -#ifdef CRYPTO_HAS_CUDA - -#include - -extern "C" __global__ void secp256k1_ecrecover_batch( - const struct EcrecoverInput* inputs, - struct EcrecoverOutput* outputs, - const uint32_t num_sigs); - -// Same layout as secp256k1_recover.cu kernel side. Re-declared here so the -// host translation unit compiles without including the kernel source. -struct EcrecoverInput { - uint8_t r[32]; - uint8_t s[32]; - uint8_t v; - uint8_t _pad[3]; - uint8_t msg_hash[32]; - uint8_t _pad2[28]; -}; - -struct EcrecoverOutput { - uint8_t address[20]; - uint8_t valid; - uint8_t _pad[11]; -}; - -extern "C" secp256k1_status secp256k1_ecrecover_address_batch_cuda( - const uint8_t* hashes, - const uint8_t* sigs, - size_t n, - uint8_t* out_addr, - uint8_t* out_st) -{ - if (!hashes || !sigs || !out_addr || !out_st || n == 0) { - return SECP256K1_ERR_NULL_ARG; - } - - EcrecoverInput* d_in = nullptr; - EcrecoverOutput* d_out = nullptr; - - cudaError_t err = cudaMalloc(&d_in, n * sizeof(EcrecoverInput)); - if (err != cudaSuccess) return SECP256K1_ERR_NULL_ARG; - - err = cudaMalloc(&d_out, n * sizeof(EcrecoverOutput)); - if (err != cudaSuccess) { - cudaFree(d_in); - return SECP256K1_ERR_NULL_ARG; - } - - EcrecoverInput* h_in = static_cast( - std::malloc(n * sizeof(EcrecoverInput))); - if (!h_in) { - cudaFree(d_in); - cudaFree(d_out); - return SECP256K1_ERR_NULL_ARG; - } - std::memset(h_in, 0, n * sizeof(EcrecoverInput)); - for (size_t i = 0; i < n; ++i) { - std::memcpy(h_in[i].r, sigs + i * 65, 32); - std::memcpy(h_in[i].s, sigs + i * 65 + 32, 32); - h_in[i].v = sigs[i * 65 + 64]; - std::memcpy(h_in[i].msg_hash, hashes + i * 32, 32); - } - - err = cudaMemcpy(d_in, h_in, n * sizeof(EcrecoverInput), - cudaMemcpyHostToDevice); - std::free(h_in); - if (err != cudaSuccess) { - cudaFree(d_in); - cudaFree(d_out); - return SECP256K1_ERR_NULL_ARG; - } - - constexpr int threadsPerBlock = 256; - int blocks = static_cast((n + threadsPerBlock - 1) 
/ threadsPerBlock); - secp256k1_ecrecover_batch<<>>( - d_in, d_out, static_cast(n)); - - err = cudaGetLastError(); - if (err == cudaSuccess) err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - cudaFree(d_in); - cudaFree(d_out); - return SECP256K1_ERR_NULL_ARG; - } - - EcrecoverOutput* h_out = static_cast( - std::malloc(n * sizeof(EcrecoverOutput))); - if (!h_out) { - cudaFree(d_in); - cudaFree(d_out); - return SECP256K1_ERR_NULL_ARG; - } - err = cudaMemcpy(h_out, d_out, n * sizeof(EcrecoverOutput), - cudaMemcpyDeviceToHost); - cudaFree(d_in); - cudaFree(d_out); - if (err != cudaSuccess) { - std::free(h_out); - return SECP256K1_ERR_NULL_ARG; - } - - for (size_t i = 0; i < n; ++i) { - std::memcpy(out_addr + i * 20, h_out[i].address, 20); - out_st[i] = h_out[i].valid; - } - std::free(h_out); - return SECP256K1_OK; -} - -#else // !CRYPTO_HAS_CUDA — non-CUDA host stub - -extern "C" secp256k1_status secp256k1_ecrecover_address_batch_cuda( - const uint8_t*, const uint8_t*, size_t, uint8_t*, uint8_t*) { - return SECP256K1_ERR_NULL_ARG; -} - -#endif // CRYPTO_HAS_CUDA diff --git a/secp256k1/gpu/cuda/secp256k1_recover.cu b/secp256k1/gpu/cuda/secp256k1_recover.cu deleted file mode 100644 index 2e7e47d..0000000 --- a/secp256k1/gpu/cuda/secp256k1_recover.cu +++ /dev/null @@ -1,685 +0,0 @@ -// secp256k1 ECDSA batch recovery — CUDA implementation -// Matches secp256k1_recover.metal output byte-for-byte -// One thread per ecrecover (r, s, v, msg_hash) → 20-byte Ethereum address - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// 256-bit unsigned integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint256 { - uint64_t limbs[4]; -}; - -// 
============================================================================= -// secp256k1 constants -// ============================================================================= - -__device__ static const uint256 SECP256K1_P = {{ - 0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 SECP256K1_N = {{ - 0xBFD25E8CD0364141ULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 GX = {{ - 0x59F2815B16F81798ULL, 0x029BFCDB2DCE28D9ULL, - 0x55A06295CE870B07ULL, 0x79BE667EF9DCBBACULL -}}; - -__device__ static const uint256 GY = {{ - 0x9C47D08FFB10D4B8ULL, 0xFD17B448A6855419ULL, - 0x5DA4FBFC0E1108A8ULL, 0x483ADA7726A3C465ULL -}}; - -__device__ static const uint256 MONT_R_P = {{ - 0x00000001000003D1ULL, 0x0000000000000000ULL, - 0x0000000000000000ULL, 0x0000000000000000ULL -}}; - -__device__ static const uint256 MONT_R2_P = {{ - 0x000007A2000E90A1ULL, 0x0000000000000001ULL, - 0x0000000000000000ULL, 0x0000000000000000ULL -}}; - -__device__ static const uint64_t P_INV = 0xD838091DD2253531ULL; - -__device__ static const uint256 MONT_R2_N = {{ - 0x896CF21467D7D140ULL, 0x741496C20E7CF878ULL, - 0xE697F5E45BCD07C6ULL, 0x9D671CD581C69BC5ULL -}}; - -__device__ static const uint64_t N_INV = 0x4B0DFF665588B13FULL; - -__device__ static const uint256 ZERO256 = {{0, 0, 0, 0}}; -__device__ static const uint256 ONE256 = {{1, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic -// ============================================================================= - -__device__ static int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -__device__ 
static uint256 u256_add(uint256 a, uint256 b, uint64_t& carry) { - uint256 r; - uint64_t c = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -__device__ static uint256 u256_sub(uint256 a, uint256 b, uint64_t& borrow) { - uint256 r; - uint64_t bw = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// ============================================================================= -// Montgomery arithmetic (parameterized by modulus m and inv = -m^(-1) mod 2^64) -// Uses __int128 on CUDA for 64x64->128 multiply -// ============================================================================= - -__device__ static uint256 mont_reduce(uint64_t t[8], uint256 m, uint64_t inv) { - uint64_t a[9]; - for (int i = 0; i < 8; i++) a[i] = t[i]; - a[8] = 0; - - for (int i = 0; i < 4; i++) { - uint64_t u = a[i] * inv; - - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)u * m.limbs[j]; - unsigned __int128 acc = prod + carry + a[i + j]; - a[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t u_lo = u & 0xFFFFFFFFULL; - uint64_t u_hi = u >> 32; - uint64_t m_lo = m.limbs[j] & 0xFFFFFFFFULL; - uint64_t m_hi = m.limbs[j] >> 32; - uint64_t ll = u_lo * m_lo; - uint64_t lh = u_lo * m_hi; - uint64_t hl = u_hi * m_lo; - uint64_t hh = u_hi * m_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = a[i + j] + lo; - if (sum < 
a[i + j]) hi++; - a[i + j] = sum; - carry = hi; -#endif - } - for (int j = 4; i + j <= 8; j++) { - uint64_t sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 1ULL : 0ULL; - a[i + j] = sum; - if (carry == 0) break; - } - } - - uint256 r; - r.limbs[0] = a[4]; - r.limbs[1] = a[5]; - r.limbs[2] = a[6]; - r.limbs[3] = a[7]; - - if (a[8] || u256_cmp(r, m) >= 0) { - uint64_t bw; - r = u256_sub(r, m, bw); - } - return r; -} - -__device__ static uint256 mont_mul(uint256 a, uint256 b, uint256 m, uint64_t inv) { - uint64_t t[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a.limbs[i] * b.limbs[j]; - unsigned __int128 acc = prod + carry + t[i + j]; - t[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t a_lo = a.limbs[i] & 0xFFFFFFFFULL; - uint64_t a_hi = a.limbs[i] >> 32; - uint64_t b_lo = b.limbs[j] & 0xFFFFFFFFULL; - uint64_t b_hi = b.limbs[j] >> 32; - uint64_t ll = a_lo * b_lo; - uint64_t lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo; - uint64_t hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; -#endif - } - for (int j = 4; i + j < 8; j++) { - uint64_t sum = t[i + j] + carry; - carry = (sum < t[i + j]) ? 
1ULL : 0ULL; - t[i + j] = sum; - if (carry == 0) break; - } - } - - return mont_reduce(t, m, inv); -} - -__device__ static uint256 to_mont(uint256 a, uint256 r2, uint256 m, uint64_t inv) { - return mont_mul(a, r2, m, inv); -} - -__device__ static uint256 from_mont(uint256 a, uint256 m, uint64_t inv) { - uint64_t t[8] = {a.limbs[0], a.limbs[1], a.limbs[2], a.limbs[3], 0, 0, 0, 0}; - return mont_reduce(t, m, inv); -} - -// Field operations over p (Montgomery form) -__device__ static uint256 fp_add(uint256 a, uint256 b) { - uint64_t carry; - uint256 r = u256_add(a, b, carry); - if (carry || u256_cmp(r, SECP256K1_P) >= 0) { - uint64_t bw; - r = u256_sub(r, SECP256K1_P, bw); - } - return r; -} - -__device__ static uint256 fp_sub(uint256 a, uint256 b) { - uint64_t bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { - uint64_t c; - r = u256_add(r, SECP256K1_P, c); - } - return r; -} - -__device__ static uint256 fp_mul(uint256 a, uint256 b) { - return mont_mul(a, b, SECP256K1_P, P_INV); -} - -__device__ static uint256 fp_sqr(uint256 a) { - return fp_mul(a, a); -} - -// Scalar field operations over n -__device__ static uint256 fn_mul(uint256 a, uint256 b) { - return mont_mul(a, b, SECP256K1_N, N_INV); -} - -// Fermat inversion over p -__device__ static uint256 fp_inv(uint256 a) { - uint256 result = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - uint256 base = a; - - uint64_t exp[4] = { - 0xFFFFFFFEFFFFFC2DULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL - }; - - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - return result; -} - -// Scalar inversion over n -__device__ static uint256 fn_inv(uint256 a) { - uint64_t exp[4] = { - 0xBFD25E8CD036413FULL, 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL - }; - - uint256 result = to_mont(ONE256, MONT_R2_N, SECP256K1_N, N_INV); - uint256 base = a; - - for (int i = 0; 
i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fn_mul(result, base); - } - base = fn_mul(base, base); - } - } - return result; -} - -// ============================================================================= -// secp256k1 EC point operations (Jacobian coordinates, Montgomery F_p) -// ============================================================================= - -struct ECPoint { - uint256 x, y, z; -}; - -__device__ static ECPoint ec_identity() { - ECPoint p; - p.x = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - p.y = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - p.z = ZERO256; - return p; -} - -__device__ static bool ec_is_infinity(ECPoint p) { - return u256_is_zero(p.z); -} - -__device__ static ECPoint ec_double(ECPoint p) { - if (ec_is_infinity(p)) return p; - - uint256 A = fp_sqr(p.y); - uint256 B = fp_mul(p.x, A); - uint256 C = fp_sqr(A); - - uint256 S = fp_add(B, B); - S = fp_add(S, S); - - uint256 X2 = fp_sqr(p.x); - uint256 M = fp_add(X2, fp_add(X2, X2)); - - uint256 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - - uint256 C8 = fp_add(C, C); - C8 = fp_add(C8, C8); - C8 = fp_add(C8, C8); - uint256 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - - uint256 Z3 = fp_mul(p.y, p.z); - Z3 = fp_add(Z3, Z3); - - ECPoint r; - r.x = X3; - r.y = Y3; - r.z = Z3; - return r; -} - -__device__ static ECPoint ec_add_mixed(ECPoint P, uint256 Qx, uint256 Qy) { - if (ec_is_infinity(P)) { - ECPoint r; - r.x = Qx; - r.y = Qy; - r.z = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - return r; - } - - uint256 Z2 = fp_sqr(P.z); - uint256 U2 = fp_mul(Qx, Z2); - uint256 Z3 = fp_mul(Z2, P.z); - uint256 S2 = fp_mul(Qy, Z3); - - uint256 H = fp_sub(U2, P.x); - uint256 R = fp_sub(S2, P.y); - - if (u256_is_zero(H)) { - if (u256_is_zero(R)) { - return ec_double(P); - } - return ec_identity(); - } - - uint256 H2 = fp_sqr(H); - uint256 H3 = fp_mul(H, H2); - uint256 U1H2 = fp_mul(P.x, H2); - - uint256 X3 = fp_sub(fp_sub(fp_sqr(R), H3), 
fp_add(U1H2, U1H2)); - uint256 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - uint256 Zr = fp_mul(H, P.z); - - ECPoint res; - res.x = X3; - res.y = Y3; - res.z = Zr; - return res; -} - -__device__ static void ec_to_affine(ECPoint p, uint256& ax, uint256& ay) { - if (ec_is_infinity(p)) { - ax = ZERO256; - ay = ZERO256; - return; - } - uint256 z_inv = fp_inv(p.z); - uint256 z_inv2 = fp_sqr(z_inv); - uint256 z_inv3 = fp_mul(z_inv2, z_inv); - ax = fp_mul(p.x, z_inv2); - ay = fp_mul(p.y, z_inv3); -} - -// NOT constant-time: branches on scalar bits. Safe for ecrecover (all inputs -// are public). MUST NOT be reused for signing where the nonce k is secret. -__device__ static ECPoint ec_mul_affine(uint256 k, uint256 Px, uint256 Py) { - ECPoint result = ec_identity(); - - for (int i = 3; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = ec_double(result); - if ((k.limbs[i] >> bit) & 1) { - result = ec_add_mixed(result, Px, Py); - } - } - } - return result; -} - -// ============================================================================= -// Keccak-256 (inline, for address derivation) -// ============================================================================= - -__device__ static const uint64_t KECCAK_RC[24] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, - 0x800000000000808AULL, 0x8000000080008000ULL, - 0x000000000000808BULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, - 0x000000000000008AULL, 0x0000000000000088ULL, - 0x0000000080008009ULL, 0x000000008000000AULL, - 0x000000008000808BULL, 0x800000000000008BULL, - 0x8000000000008089ULL, 0x8000000000008003ULL, - 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800AULL, 0x800000008000000AULL, - 0x8000000080008081ULL, 0x8000000000008080ULL, - 0x0000000080000001ULL, 0x8000000080008008ULL, -}; - -__device__ static const int KECCAK_PI_LANE[24] = { - 10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, - 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 -}; - 
-__device__ static const int KECCAK_RHO[24] = { - 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44 -}; - -__device__ static uint64_t keccak_rotl64(uint64_t x, int n) { - return (x << n) | (x >> (64 - n)); -} - -__device__ static void keccak_f1600(uint64_t st[25]) { - for (int round = 0; round < 24; ++round) { - uint64_t C[5]; - for (int x = 0; x < 5; ++x) - C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; - for (int x = 0; x < 5; ++x) { - uint64_t d = C[(x + 4) % 5] ^ keccak_rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - st[x + 5 * y] ^= d; - } - uint64_t t = st[1]; - for (int i = 0; i < 24; ++i) { - uint64_t tmp = st[KECCAK_PI_LANE[i]]; - st[KECCAK_PI_LANE[i]] = keccak_rotl64(t, KECCAK_RHO[i]); - t = tmp; - } - for (int y = 0; y < 5; ++y) { - uint64_t row[5]; - for (int x = 0; x < 5; ++x) row[x] = st[x + 5 * y]; - for (int x = 0; x < 5; ++x) - st[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - st[0] ^= KECCAK_RC[round]; - } -} - -__device__ static void keccak256_64(const uint8_t data[64], uint8_t out[32]) { - uint64_t state[25] = {}; - - for (uint32_t w = 0; w < 8; ++w) { - uint64_t lane = 0; - for (uint32_t b = 0; b < 8; ++b) - lane |= (uint64_t)data[w * 8 + b] << (b * 8); - state[w] ^= lane; - } - - state[8] ^= 0x01ULL; - state[16] ^= 0x80ULL << 56; - - keccak_f1600(state); - - for (uint32_t w = 0; w < 4; ++w) { - uint64_t lane = state[w]; - for (uint32_t b = 0; b < 8; ++b) - out[w * 8 + b] = (uint8_t)(lane >> (b * 8)); - } -} - -// ============================================================================= -// Input/Output structures -// ============================================================================= - -struct EcrecoverInput { - uint8_t r[32]; - uint8_t s[32]; - uint8_t v; - uint8_t _pad[3]; - uint8_t msg_hash[32]; - uint8_t _pad2[28]; -}; - -struct EcrecoverOutput { - uint8_t address[20]; - uint8_t valid; - uint8_t _pad[11]; -}; - -// 
============================================================================= -// Helpers: big-endian load/store -// ============================================================================= - -__device__ static uint256 load_be32(const uint8_t bytes[32]) { - uint256 r; - for (int limb = 0; limb < 4; limb++) { - uint64_t v = 0; - int base = (3 - limb) * 8; - for (int b = 0; b < 8; b++) { - v = (v << 8) | (uint64_t)bytes[base + b]; - } - r.limbs[limb] = v; - } - return r; -} - -__device__ static void store_be32(uint256 val, uint8_t bytes[32]) { - for (int limb = 0; limb < 4; limb++) { - int base = (3 - limb) * 8; - uint64_t v = val.limbs[limb]; - for (int b = 7; b >= 0; b--) { - bytes[base + b] = (uint8_t)(v & 0xFF); - v >>= 8; - } - } -} - -// ============================================================================= -// Main kernel: batch secp256k1 ecrecover -// ============================================================================= - -extern "C" __global__ void secp256k1_ecrecover_batch( - const EcrecoverInput* __restrict__ inputs, - EcrecoverOutput* __restrict__ outputs, - const uint32_t num_sigs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_sigs) return; - - const EcrecoverInput& inp = inputs[tid]; - EcrecoverOutput& out = outputs[tid]; - - // Clear output - for (int i = 0; i < 20; i++) out.address[i] = 0; - out.valid = 0; - for (int i = 0; i < 11; i++) out._pad[i] = 0; - - // Load r, s, v, hash - uint256 r = load_be32(inp.r); - uint256 s = load_be32(inp.s); - uint256 e = load_be32(inp.msg_hash); - uint32_t v = (uint32_t)inp.v; - - // Normalize v - if (v >= 27) v -= 27; - if (v >= 2) v = v % 2; - - // Validate - if (u256_is_zero(r) || u256_cmp(r, SECP256K1_N) >= 0) return; - if (u256_is_zero(s) || u256_cmp(s, SECP256K1_N) >= 0) return; - if (v > 1) return; - - // Step 1: Decompress r → R = (r, y) on secp256k1 - uint256 r_mont = to_mont(r, MONT_R2_P, SECP256K1_P, P_INV); - uint256 r2 = fp_sqr(r_mont); - uint256 r3 = 
fp_mul(r2, r_mont); - uint256 seven_mont = to_mont(uint256{{7, 0, 0, 0}}, MONT_R2_P, SECP256K1_P, P_INV); - uint256 y2 = fp_add(r3, seven_mont); - - // sqrt(y2) via a^((p+1)/4) since p = 3 mod 4 - uint256 y_mont; - { - uint64_t exp[4] = { - 0xFFFFFFFFBFFFFF0CULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0x3FFFFFFFFFFFFFFFULL - }; - uint256 result = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - uint256 base = y2; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - y_mont = result; - } - - // Verify sqrt exists - if (u256_cmp(fp_sqr(y_mont), y2) != 0) return; - - // Select correct y parity - uint256 y_normal = from_mont(y_mont, SECP256K1_P, P_INV); - bool y_is_odd = (y_normal.limbs[0] & 1) != 0; - if ((v == 0 && y_is_odd) || (v == 1 && !y_is_odd)) { - y_mont = fp_sub(ZERO256, y_mont); - } - - uint256 Rx_mont = r_mont; - uint256 Ry_mont = y_mont; - - // Step 2: r_inv = r^(-1) mod n - uint256 r_n_mont = to_mont(r, MONT_R2_N, SECP256K1_N, N_INV); - uint256 r_inv_mont = fn_inv(r_n_mont); - - // Step 3: u1 = -(e * r_inv) mod n, u2 = s * r_inv mod n - uint256 e_n_mont = to_mont(e, MONT_R2_N, SECP256K1_N, N_INV); - uint256 s_n_mont = to_mont(s, MONT_R2_N, SECP256K1_N, N_INV); - - uint256 u1_mont = fn_mul(e_n_mont, r_inv_mont); - uint256 u1 = from_mont(u1_mont, SECP256K1_N, N_INV); - if (!u256_is_zero(u1)) { - uint64_t bw; - u1 = u256_sub(SECP256K1_N, u1, bw); - } - - uint256 u2 = from_mont(fn_mul(s_n_mont, r_inv_mont), SECP256K1_N, N_INV); - - // Step 4: Q = u1*G + u2*R - uint256 Gx_mont = to_mont(GX, MONT_R2_P, SECP256K1_P, P_INV); - uint256 Gy_mont = to_mont(GY, MONT_R2_P, SECP256K1_P, P_INV); - - ECPoint Q1 = ec_mul_affine(u1, Gx_mont, Gy_mont); - ECPoint Q2 = ec_mul_affine(u2, Rx_mont, Ry_mont); - - // Add Q1 + Q2 - ECPoint Q; - if (ec_is_infinity(Q1)) { - Q = Q2; - } else if (ec_is_infinity(Q2)) { - Q = Q1; - } else { - uint256 Q2x_aff, 
Q2y_aff; - ec_to_affine(Q2, Q2x_aff, Q2y_aff); - Q = ec_add_mixed(Q1, Q2x_aff, Q2y_aff); - } - - if (ec_is_infinity(Q)) return; - - // Step 5: Convert Q to affine, serialize big-endian - uint256 Qx_aff, Qy_aff; - ec_to_affine(Q, Qx_aff, Qy_aff); - - uint256 Qx_norm = from_mont(Qx_aff, SECP256K1_P, P_INV); - uint256 Qy_norm = from_mont(Qy_aff, SECP256K1_P, P_INV); - - uint8_t pubkey[64]; - store_be32(Qx_norm, pubkey); - store_be32(Qy_norm, pubkey + 32); - - // Step 6: address = keccak256(pubkey)[12:] - uint8_t hash[32]; - keccak256_64(pubkey, hash); - - for (int i = 0; i < 20; i++) { - out.address[i] = hash[12 + i]; - } - out.valid = 1; -} diff --git a/secp256k1/gpu/metal/secp256k1.metal b/secp256k1/gpu/metal/secp256k1.metal deleted file mode 100644 index ba62a70..0000000 --- a/secp256k1/gpu/metal/secp256k1.metal +++ /dev/null @@ -1,561 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// First-party Metal kernel for batch secp256k1 ecrecover. -// -// This kernel mirrors src/secp256k1/{field.hpp,curve.hpp,ecrecover.cpp} -// line-for-line so that CPU and GPU produce byte-identical output by -// construction. Correctness first; optimization in subsequent passes. -// -// One thread per signature. Output: address[20] (last 20 bytes of -// keccak256(pubkey)). 
- -#include -using namespace metal; - -// =========================================================================== -// 256-bit unsigned integer -// =========================================================================== - -struct uint256 { - ulong limbs[4]; -}; - -constant uint256 P_MOD = {{ - 0xFFFFFFFEFFFFFC2FUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -constant uint256 N_MOD = {{ - 0xBFD25E8CD0364141UL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -constant uint256 R2_P = {{ - 0x000007A2000E90A1UL, 0x0000000000000001UL, - 0x0000000000000000UL, 0x0000000000000000UL -}}; -constant ulong P_INV = 0xD838091DD2253531UL; - -constant uint256 R2_N = {{ - 0x896CF21467D7D140UL, 0x741496C20E7CF878UL, - 0xE697F5E45BCD07C6UL, 0x9D671CD581C69BC5UL -}}; -constant ulong N_INV = 0x4B0DFF665588B13FUL; - -// Montgomery encoding of 1 mod p (== R mod p) -constant uint256 ONE_MONT_P = {{ - 0x00000001000003D1UL, 0UL, 0UL, 0UL -}}; - -constant uint256 GX_PLAIN = {{ - 0x59F2815B16F81798UL, 0x029BFCDB2DCE28D9UL, - 0x55A06295CE870B07UL, 0x79BE667EF9DCBBACUL -}}; -constant uint256 GY_PLAIN = {{ - 0x9C47D08FFB10D4B8UL, 0xFD17B448A6855419UL, - 0x5DA4FBFC0E1108A8UL, 0x483ADA7726A3C465UL -}}; - -constant ulong PP1_4[4] = { - 0xFFFFFFFFBFFFFF0CUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0x3FFFFFFFFFFFFFFFUL -}; -constant ulong P_M2[4] = { - 0xFFFFFFFEFFFFFC2DUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL -}; -constant ulong N_M2[4] = { - 0xBFD25E8CD036413FUL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL -}; - -constant uint256 ZERO = {{0,0,0,0}}; -constant uint256 ONE = {{1,0,0,0}}; - -// =========================================================================== -// Helpers -// =========================================================================== - -inline int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; --i) { - if (a.limbs[i] < b.limbs[i]) return -1; - 
if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -inline void mul64(ulong a, ulong b, thread ulong &lo, thread ulong &hi) { - ulong al = a & 0xFFFFFFFFUL; - ulong ah = a >> 32; - ulong bl = b & 0xFFFFFFFFUL; - ulong bh = b >> 32; - - ulong ll = al * bl; - ulong lh = al * bh; - ulong hl = ah * bl; - ulong hh = ah * bh; - - ulong mid = (ll >> 32) + (lh & 0xFFFFFFFFUL) + (hl & 0xFFFFFFFFUL); - lo = (ll & 0xFFFFFFFFUL) | (mid << 32); - hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32); -} - -inline ulong addc(ulong a, ulong b, ulong c, thread ulong &out) { - ulong t = a + b; - ulong c1 = (t < a) ? 1UL : 0UL; - ulong t2 = t + c; - ulong c2 = (t2 < t) ? 1UL : 0UL; - out = t2; - return c1 + c2; -} - -inline ulong subb(ulong a, ulong b, ulong br, thread ulong &out) { - ulong t = a - b; - ulong b1 = (t > a) ? 1UL : 0UL; - ulong t2 = t - br; - ulong b2 = (t2 > t) ? 1UL : 0UL; - out = t2; - return b1 + b2; -} - -inline uint256 add_256(uint256 a, uint256 b, thread ulong &carry) { - uint256 r; - ulong c = 0; - for (int i = 0; i < 4; ++i) c = addc(a.limbs[i], b.limbs[i], c, r.limbs[i]); - carry = c; - return r; -} - -inline uint256 sub_256(uint256 a, uint256 b, thread ulong &borrow) { - uint256 r; - ulong br = 0; - for (int i = 0; i < 4; ++i) br = subb(a.limbs[i], b.limbs[i], br, r.limbs[i]); - borrow = br; - return r; -} - -inline uint256 mod_add(uint256 a, uint256 b, uint256 m) { - ulong c; - uint256 t = add_256(a, b, c); - if (c != 0 || u256_cmp(t, m) >= 0) { - ulong bw; - t = sub_256(t, m, bw); - } - return t; -} - -inline uint256 mod_sub(uint256 a, uint256 b, uint256 m) { - ulong bw; - uint256 t = sub_256(a, b, bw); - if (bw != 0) { - ulong c; - t = add_256(t, m, c); - } - return t; -} - -inline uint256 mont_mul(uint256 a, uint256 b, uint256 m, ulong m_inv) { - ulong t[6]; - for (int i = 0; i < 6; ++i) t[i] = 0; - - for (int i = 0; i < 4; ++i) { - ulong 
carry = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(a.limbs[j], b.limbs[i], lo, hi); - ulong c1 = addc(t[j], lo, carry, t[j]); - carry = hi + c1; - } - ulong c1 = addc(t[4], carry, 0, t[4]); - t[5] += c1; - - ulong u = t[0] * m_inv; - - carry = 0; - for (int j = 0; j < 4; ++j) { - ulong lo, hi; - mul64(u, m.limbs[j], lo, hi); - ulong c2 = addc(t[j], lo, carry, t[j]); - carry = hi + c2; - } - ulong c2 = addc(t[4], carry, 0, t[4]); - t[5] += c2; - - for (int j = 0; j < 5; ++j) t[j] = t[j + 1]; - t[5] = 0; - } - - uint256 r = {{ t[0], t[1], t[2], t[3] }}; - if (t[4] != 0 || u256_cmp(r, m) >= 0) { - ulong bw; - r = sub_256(r, m, bw); - } - return r; -} - -inline uint256 to_mont(uint256 a, uint256 r2, uint256 m, ulong m_inv) { - return mont_mul(a, r2, m, m_inv); -} -inline uint256 from_mont(uint256 a, uint256 m, ulong m_inv) { - return mont_mul(a, ONE, m, m_inv); -} - -inline uint256 fp_add(uint256 a, uint256 b) { return mod_add(a, b, P_MOD); } -inline uint256 fp_sub(uint256 a, uint256 b) { return mod_sub(a, b, P_MOD); } -inline uint256 fp_mul(uint256 a, uint256 b) { return mont_mul(a, b, P_MOD, P_INV); } -inline uint256 fp_sqr(uint256 a) { return mont_mul(a, a, P_MOD, P_INV); } - -inline uint256 fp_pow(uint256 a_mont, constant ulong* exp4) { - uint256 result = ONE_MONT_P; - uint256 base = a_mont; - for (int limb = 0; limb < 4; ++limb) { - ulong w = exp4[limb]; - for (int bit = 0; bit < 64; ++bit) { - if ((w >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - } - return result; -} - -inline uint256 fp_inv(uint256 a_mont) { return fp_pow(a_mont, P_M2); } - -inline bool fp_sqrt(uint256 a_mont, thread uint256 &out) { - uint256 cand = fp_pow(a_mont, PP1_4); - if (u256_cmp(fp_sqr(cand), a_mont) != 0) return false; - out = cand; - return true; -} - -inline uint256 fn_mul(uint256 a, uint256 b) { return mont_mul(a, b, N_MOD, N_INV); } -inline uint256 fn_sqr(uint256 a) { return mont_mul(a, a, N_MOD, N_INV); } - -inline uint256 fn_pow(uint256 
a_mont, constant ulong* exp4) { - uint256 result = mont_mul(ONE, R2_N, N_MOD, N_INV); - uint256 base = a_mont; - for (int limb = 0; limb < 4; ++limb) { - ulong w = exp4[limb]; - for (int bit = 0; bit < 64; ++bit) { - if ((w >> bit) & 1) result = fn_mul(result, base); - base = fn_sqr(base); - } - } - return result; -} - -inline uint256 fn_inv(uint256 a_mont) { return fn_pow(a_mont, N_M2); } - -// =========================================================================== -// EC point operations -// =========================================================================== - -struct AffinePt { uint256 x; uint256 y; bool inf; }; -struct JacPt { uint256 X; uint256 Y; uint256 Z; bool inf; }; - -inline JacPt jac_zero() { - JacPt r; r.X = ZERO; r.Y = ZERO; r.Z = ZERO; r.inf = true; return r; -} - -inline JacPt aff_to_jac(AffinePt p) { - if (p.inf) return jac_zero(); - JacPt r; r.X = p.x; r.Y = p.y; r.Z = ONE_MONT_P; r.inf = false; - return r; -} - -inline AffinePt jac_to_aff(JacPt p) { - AffinePt a; - if (p.inf || u256_is_zero(p.Z)) { a.x = ZERO; a.y = ZERO; a.inf = true; return a; } - uint256 zi = fp_inv(p.Z); - uint256 zi2 = fp_sqr(zi); - uint256 zi3 = fp_mul(zi2, zi); - a.x = fp_mul(p.X, zi2); - a.y = fp_mul(p.Y, zi3); - a.inf = false; - return a; -} - -inline JacPt jac_double(JacPt p) { - if (p.inf) return p; - if (u256_is_zero(p.Y)) return jac_zero(); - uint256 A = fp_sqr(p.X); - uint256 B = fp_sqr(p.Y); - uint256 C = fp_sqr(B); - uint256 XplusB = fp_add(p.X, B); - uint256 D = fp_sub(fp_sqr(XplusB), A); - D = fp_sub(D, C); - D = fp_add(D, D); - uint256 E = fp_add(A, A); E = fp_add(E, A); - uint256 F = fp_sqr(E); - uint256 X3 = fp_sub(F, fp_add(D, D)); - uint256 D_m_X3 = fp_sub(D, X3); - uint256 eight_C = fp_add(C, C); eight_C = fp_add(eight_C, eight_C); eight_C = fp_add(eight_C, eight_C); - uint256 Y3 = fp_sub(fp_mul(E, D_m_X3), eight_C); - uint256 Z3 = fp_mul(p.Y, p.Z); Z3 = fp_add(Z3, Z3); - JacPt r; r.X = X3; r.Y = Y3; r.Z = Z3; r.inf = false; return r; -} - 
-inline JacPt jac_add(JacPt p, JacPt q) { - if (p.inf) return q; - if (q.inf) return p; - uint256 Z1Z1 = fp_sqr(p.Z); - uint256 Z2Z2 = fp_sqr(q.Z); - uint256 U1 = fp_mul(p.X, Z2Z2); - uint256 U2 = fp_mul(q.X, Z1Z1); - uint256 S1 = fp_mul(p.Y, fp_mul(Z2Z2, q.Z)); - uint256 S2 = fp_mul(q.Y, fp_mul(Z1Z1, p.Z)); - uint256 H = fp_sub(U2, U1); - uint256 r = fp_sub(S2, S1); - if (u256_is_zero(H)) { - if (u256_is_zero(r)) return jac_double(p); - return jac_zero(); - } - uint256 HH = fp_sqr(H); - uint256 HHH = fp_mul(H, HH); - uint256 U1HH = fp_mul(U1, HH); - uint256 X3 = fp_sub(fp_sqr(r), HHH); - X3 = fp_sub(X3, fp_add(U1HH, U1HH)); - uint256 Y3 = fp_mul(r, fp_sub(U1HH, X3)); - Y3 = fp_sub(Y3, fp_mul(S1, HHH)); - uint256 Z3 = fp_mul(fp_mul(p.Z, q.Z), H); - JacPt out; out.X = X3; out.Y = Y3; out.Z = Z3; out.inf = false; return out; -} - -inline JacPt jac_mul(uint256 k, AffinePt p) { - if (p.inf) return jac_zero(); - JacPt r = jac_zero(); - JacPt base = aff_to_jac(p); - for (int limb = 3; limb >= 0; --limb) { - ulong w = k.limbs[limb]; - for (int bit = 63; bit >= 0; --bit) { - r = jac_double(r); - if ((w >> bit) & 1) r = jac_add(r, base); - } - } - return r; -} - -// =========================================================================== -// Keccak-256 (Ethereum, delimiter 0x01) -// =========================================================================== - -constant ulong RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -// Same offsets as 
src/keccak/keccak.cpp (mod 64). -constant int R_OFFSETS[5][5] = { - { 0, 36, 3, 41, 18}, - { 1, 44, 10, 45, 2}, - { 62, 6, 43, 15, 61}, - { 28, 55, 25, 21, 56}, - { 27, 20, 39, 8, 14}, -}; - -inline ulong rotl64(ulong x, int n) { - n &= 63; - if (n == 0) return x; - return (x << n) | (x >> (64 - n)); -} - -inline void keccakf1600(thread ulong* a) { - ulong C[5], D[5], B[25]; - for (int round = 0; round < 24; ++round) { - for (int x = 0; x < 5; ++x) - C[x] = a[x] ^ a[x + 5] ^ a[x + 10] ^ a[x + 15] ^ a[x + 20]; - for (int x = 0; x < 5; ++x) - D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - for (int x = 0; x < 5; ++x) - a[x + 5 * y] ^= D[x]; - - for (int x = 0; x < 5; ++x) - for (int y = 0; y < 5; ++y) { - int nx = y; - int ny = (2 * x + 3 * y) % 5; - B[nx + 5 * ny] = rotl64(a[x + 5 * y], R_OFFSETS[x][y]); - } - - for (int y = 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) row[x] = B[x + 5 * y]; - for (int x = 0; x < 5; ++x) - a[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - - a[0] ^= RC[round]; - } -} - -inline void keccak256_64(thread const uchar* in, thread uchar* out) { - const int RATE = 136; - (void)RATE; - ulong state[25]; - for (int i = 0; i < 25; ++i) state[i] = 0; - - uchar block[136]; - for (int i = 0; i < 64; ++i) block[i] = in[i]; - for (int i = 64; i < 136; ++i) block[i] = 0; - block[64] = 0x01; - block[135] |= 0x80; - - for (int j = 0; j < 17; ++j) { - ulong v = 0; - for (int b = 0; b < 8; ++b) v |= ((ulong)block[j * 8 + b]) << (8 * b); - state[j] ^= v; - } - keccakf1600(state); - - for (int j = 0; j < 4; ++j) { - ulong v = state[j]; - for (int b = 0; b < 8; ++b) { - out[j * 8 + b] = (uchar)(v & 0xFF); - v >>= 8; - } - } -} - -// =========================================================================== -// I/O structs -// =========================================================================== - -struct EcrecoverInput { - uchar hash[32]; - uchar r[32]; - uchar s[32]; - uchar 
v; - uchar _pad[15]; -}; - -struct EcrecoverOutput { - uchar address[20]; - uchar valid; - uchar _pad[11]; -}; - -inline uint256 load_be32(thread const uchar* b) { - uint256 r; - for (int limb = 0; limb < 4; ++limb) { - ulong v = 0; - int base = (3 - limb) * 8; - for (int i = 0; i < 8; ++i) v = (v << 8) | (ulong)b[base + i]; - r.limbs[limb] = v; - } - return r; -} - -inline void store_be32(uint256 a, thread uchar* b) { - for (int limb = 0; limb < 4; ++limb) { - int base = (3 - limb) * 8; - ulong v = a.limbs[limb]; - for (int i = 7; i >= 0; --i) { - b[base + i] = (uchar)(v & 0xFF); - v >>= 8; - } - } -} - -// =========================================================================== -// Main kernel -// =========================================================================== - -kernel void secp256k1_ecrecover_batch( - device const EcrecoverInput* inputs [[buffer(0)]], - device EcrecoverOutput* outputs [[buffer(1)]], - constant uint& num [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num) return; - device const EcrecoverInput& in = inputs[tid]; - device EcrecoverOutput& out = outputs[tid]; - - for (int i = 0; i < 20; ++i) out.address[i] = 0; - out.valid = 0; - for (int i = 0; i < 11; ++i) out._pad[i] = 0; - - uchar hash_buf[32], r_buf[32], s_buf[32]; - for (int i = 0; i < 32; ++i) hash_buf[i] = in.hash[i]; - for (int i = 0; i < 32; ++i) r_buf[i] = in.r[i]; - for (int i = 0; i < 32; ++i) s_buf[i] = in.s[i]; - uchar v = in.v; - if (v >= 27) v -= 27; - if (v > 1) v %= 2; - - uint256 r = load_be32(r_buf); - uint256 s = load_be32(s_buf); - uint256 e = load_be32(hash_buf); - - if (u256_is_zero(r) || u256_cmp(r, N_MOD) >= 0) return; - if (u256_is_zero(s) || u256_cmp(s, N_MOD) >= 0) return; - if (v > 1) return; - - uint256 r_pm = to_mont(r, R2_P, P_MOD, P_INV); - uint256 r2m = fp_sqr(r_pm); - uint256 r3m = fp_mul(r2m, r_pm); - uint256 sevn = {{7, 0, 0, 0}}; - uint256 sevnm = to_mont(sevn, R2_P, P_MOD, P_INV); - uint256 y2m = fp_add(r3m, sevnm); - - 
uint256 ym; - if (!fp_sqrt(y2m, ym)) return; - uint256 yn = from_mont(ym, P_MOD, P_INV); - bool y_odd = (yn.limbs[0] & 1UL) != 0; - bool want_odd = (v == 1); - if (y_odd != want_odd) ym = fp_sub(ZERO, ym); - - AffinePt R_pt; R_pt.x = r_pm; R_pt.y = ym; R_pt.inf = false; - - uint256 e_red = e; - if (u256_cmp(e_red, N_MOD) >= 0) { - ulong bw; e_red = sub_256(e_red, N_MOD, bw); - } - uint256 r_nm = to_mont(r, R2_N, N_MOD, N_INV); - uint256 r_inv = fn_inv(r_nm); - uint256 e_nm = to_mont(e_red, R2_N, N_MOD, N_INV); - uint256 s_nm = to_mont(s, R2_N, N_MOD, N_INV); - uint256 u1_nm = fn_mul(e_nm, r_inv); - uint256 u1_norm = from_mont(u1_nm, N_MOD, N_INV); - if (!u256_is_zero(u1_norm)) { - ulong bw; u1_norm = sub_256(N_MOD, u1_norm, bw); - } - uint256 u2_norm = from_mont(fn_mul(s_nm, r_inv), N_MOD, N_INV); - - AffinePt G; - G.x = to_mont(GX_PLAIN, R2_P, P_MOD, P_INV); - G.y = to_mont(GY_PLAIN, R2_P, P_MOD, P_INV); - G.inf = false; - - JacPt Q1 = jac_mul(u1_norm, G); - JacPt Q2 = jac_mul(u2_norm, R_pt); - JacPt Q = jac_add(Q1, Q2); - AffinePt Qa = jac_to_aff(Q); - if (Qa.inf) return; - - uint256 qx = from_mont(Qa.x, P_MOD, P_INV); - uint256 qy = from_mont(Qa.y, P_MOD, P_INV); - - uchar pubkey[64]; - store_be32(qx, pubkey); - store_be32(qy, pubkey + 32); - - uchar hash_addr[32]; - keccak256_64(pubkey, hash_addr); - - for (int i = 0; i < 20; ++i) out.address[i] = hash_addr[12 + i]; - out.valid = 1; -} diff --git a/secp256k1/gpu/metal/secp256k1_authored.metal b/secp256k1/gpu/metal/secp256k1_authored.metal deleted file mode 100644 index 9748752..0000000 --- a/secp256k1/gpu/metal/secp256k1_authored.metal +++ /dev/null @@ -1,1397 +0,0 @@ -// ============================================================================= -// secp256k1 GPU Kernels for Metal (Apple Silicon) -// ============================================================================= -// -// GPU-accelerated secp256k1 operations using the GTable approach. 
-// Based on CudaBrainSecp optimization for ~20x speedup over double-and-add. -// -// GTable Structure: -// - 16 chunks × 65536 points = 1,048,576 precomputed points (~67MB) -// - Scalar multiplication: 16 table lookups + 15 point additions -// - Perfect for batch operations on Apple Silicon GPU -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#include -using namespace metal; - -// ============================================================================= -// secp256k1 Field Constants -// ============================================================================= -// -// Prime: p = 2^256 - 2^32 - 977 -// = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F -// -// Order: n = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141 -// -// Generator: G = (Gx, Gy) where: -// Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798 -// Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 - -// secp256k1 prime (little-endian limbs) -constant uint64_t SECP256K1_P[4] = { - 0xFFFFFFFEFFFFFC2FULL, - 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL -}; - -// secp256k1 order (little-endian limbs) -constant uint64_t SECP256K1_N[4] = { - 0xBFD25E8CD0364141ULL, - 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, - 0xFFFFFFFFFFFFFFFFULL -}; - -// Generator point Gx (little-endian) -constant uint64_t SECP256K1_GX[4] = { - 0x59F2815B16F81798ULL, - 0x029BFCDB2DCE28D9ULL, - 0x55A06295CE870B07ULL, - 0x79BE667EF9DCBBACULL -}; - -// Generator point Gy (little-endian) -constant uint64_t SECP256K1_GY[4] = { - 0x9C47D08FFB10D4B8ULL, - 0xFD17B448A6855419ULL, - 0x5DA4FBFC0E1108A8ULL, - 0x483ADA7726A3C465ULL -}; - -// ============================================================================= -// Types -// ============================================================================= - -struct Fp256 { - uint64_t limbs[4]; -}; - -struct Scalar256 { 
- uint64_t limbs[4]; -}; - -struct AffinePoint { - Fp256 x; - Fp256 y; - bool infinity; -}; - -struct JacobianPoint { - Fp256 x; - Fp256 y; - Fp256 z; -}; - -// ============================================================================= -// 256-bit Arithmetic Primitives -// ============================================================================= - -// Add with carry -inline uint64_t adc(uint64_t a, uint64_t b, thread uint64_t& carry) { - uint64_t sum = a + b + carry; - carry = (sum < a || (carry && sum == a)) ? 1 : 0; - return sum; -} - -// Subtract with borrow -inline uint64_t sbb(uint64_t a, uint64_t b, thread uint64_t& borrow) { - uint64_t diff = a - b - borrow; - borrow = (a < b + borrow) ? 1 : 0; - return diff; -} - -// 256-bit addition -inline void add256(thread Fp256& c, Fp256 a, Fp256 b) { - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = adc(a.limbs[i], b.limbs[i], carry); - } -} - -// 256-bit subtraction -inline void sub256(thread Fp256& c, Fp256 a, Fp256 b) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - c.limbs[i] = sbb(a.limbs[i], b.limbs[i], borrow); - } -} - -// Compare: return true if a >= b -inline bool gte256(Fp256 a, Fp256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] > b.limbs[i]) return true; - if (a.limbs[i] < b.limbs[i]) return false; - } - return true; // Equal -} - -// Check if zero -inline bool is_zero256(Fp256 a) { - return a.limbs[0] == 0 && a.limbs[1] == 0 && - a.limbs[2] == 0 && a.limbs[3] == 0; -} - -// ============================================================================= -// secp256k1 Field Arithmetic -// ============================================================================= - -// Modular reduction: c = a mod p -// Uses secp256k1's special form: p = 2^256 - 2^32 - 977 -inline void fp_reduce(thread Fp256& c, Fp256 a) { - // Check if a >= p - Fp256 p = {{SECP256K1_P[0], SECP256K1_P[1], SECP256K1_P[2], SECP256K1_P[3]}}; - if (gte256(a, p)) { - sub256(c, a, p); - } else { - c = 
a; - } -} - -// Field addition: c = (a + b) mod p -inline void fp_add(thread Fp256& c, Fp256 a, Fp256 b) { - add256(c, a, b); - Fp256 p = {{SECP256K1_P[0], SECP256K1_P[1], SECP256K1_P[2], SECP256K1_P[3]}}; - if (gte256(c, p)) { - sub256(c, c, p); - } -} - -// Field subtraction: c = (a - b) mod p -inline void fp_sub(thread Fp256& c, Fp256 a, Fp256 b) { - if (gte256(a, b)) { - sub256(c, a, b); - } else { - Fp256 p = {{SECP256K1_P[0], SECP256K1_P[1], SECP256K1_P[2], SECP256K1_P[3]}}; - add256(c, a, p); - sub256(c, c, b); - } -} - -// Field negation: c = -a mod p -inline void fp_neg(thread Fp256& c, Fp256 a) { - if (is_zero256(a)) { - c = a; - } else { - Fp256 p = {{SECP256K1_P[0], SECP256K1_P[1], SECP256K1_P[2], SECP256K1_P[3]}}; - sub256(c, p, a); - } -} - -// Field doubling: c = 2*a mod p (faster than add) -inline void fp_double(thread Fp256& c, Fp256 a) { - fp_add(c, a, a); -} - -// Field multiplication: c = a * b mod p -// Uses schoolbook multiplication with reduction optimized for secp256k1 -inline void fp_mul(thread Fp256& c, Fp256 a, Fp256 b) { - // Full 512-bit product - uint64_t t[8] = {0}; - - // Schoolbook multiplication - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - // 64x64 -> 128-bit multiplication - uint64_t lo = a.limbs[i] * b.limbs[j]; - uint64_t hi = mulhi(a.limbs[i], b.limbs[j]); - - uint64_t sum = t[i+j] + lo; - uint64_t c1 = (sum < t[i+j]) ? 1 : 0; - sum += carry; - c1 += (sum < carry) ? 1 : 0; - t[i+j] = sum; - carry = hi + c1; - } - t[i+4] = carry; - } - - // Reduction using p = 2^256 - 2^32 - 977 - // t mod p = t[0..3] + (t[4..7] * 2^256) mod p - // = t[0..3] + t[4..7] * (2^32 + 977) mod p - - // First reduction round - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - // Add t[i+4] * (2^32 + 977) to t[i] - uint64_t lo = t[i+4] * 0x1000003D1ULL; // 2^32 + 977 - uint64_t hi = mulhi(t[i+4], 0x1000003D1ULL); - - uint64_t sum = t[i] + lo + carry; - carry = (sum < t[i] || (carry && sum == t[i] + lo)) ? 
1 : 0; - carry += hi; - t[i] = sum; - } - - // Handle remaining carry - while (carry > 0) { - uint64_t lo = carry * 0x1000003D1ULL; - uint64_t hi = mulhi(carry, 0x1000003D1ULL); - - uint64_t sum = t[0] + lo; - uint64_t c1 = (sum < t[0]) ? 1 : 0; - t[0] = sum; - - sum = t[1] + c1; - c1 = (sum < t[1]) ? 1 : 0; - t[1] = sum; - - sum = t[2] + c1; - c1 = (sum < t[2]) ? 1 : 0; - t[2] = sum; - - sum = t[3] + c1; - c1 = (sum < t[3]) ? 1 : 0; - t[3] = sum; - - carry = hi + c1; - } - - // Final reduction if needed - c.limbs[0] = t[0]; - c.limbs[1] = t[1]; - c.limbs[2] = t[2]; - c.limbs[3] = t[3]; - fp_reduce(c, c); -} - -// Field squaring: c = a^2 mod p (optimized Comba squaring) -// Exploits symmetry: 2*a[i]*a[j] for i != j -inline void fp_sqr(thread Fp256& c, Fp256 a) { - // Full 512-bit product using Comba squaring - uint64_t t[8] = {0}; - - // Diagonal terms: a[i]^2 - for (int i = 0; i < 4; i++) { - uint64_t lo = a.limbs[i] * a.limbs[i]; - uint64_t hi = mulhi(a.limbs[i], a.limbs[i]); - - uint64_t sum = t[2*i] + lo; - uint64_t carry = (sum < t[2*i]) ? 1 : 0; - t[2*i] = sum; - - sum = t[2*i + 1] + hi + carry; - carry = (sum < t[2*i + 1] || (carry && sum == t[2*i + 1] + hi)) ? 1 : 0; - t[2*i + 1] = sum; - - // Propagate carry - for (int k = 2*i + 2; k < 8 && carry; k++) { - sum = t[k] + carry; - carry = (sum < t[k]) ? 1 : 0; - t[k] = sum; - } - } - - // Off-diagonal terms: 2*a[i]*a[j] for i < j - for (int i = 0; i < 4; i++) { - for (int j = i + 1; j < 4; j++) { - uint64_t lo = a.limbs[i] * a.limbs[j]; - uint64_t hi = mulhi(a.limbs[i], a.limbs[j]); - - // Double the cross term - uint64_t hi2 = hi >> 63; - hi = (hi << 1) | (lo >> 63); - lo = lo << 1; - - uint64_t sum = t[i+j] + lo; - uint64_t carry = (sum < t[i+j]) ? 1 : 0; - t[i+j] = sum; - - sum = t[i+j+1] + hi + carry; - carry = (sum < t[i+j+1] || (carry && sum == t[i+j+1] + hi)) ? 
1 : 0; - t[i+j+1] = sum; - - carry += hi2; - - // Propagate carry - for (int k = i+j + 2; k < 8 && carry; k++) { - sum = t[k] + carry; - carry = (sum < t[k]) ? 1 : 0; - t[k] = sum; - } - } - } - - // Reduction using p = 2^256 - 2^32 - 977 - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - uint64_t lo = t[i+4] * 0x1000003D1ULL; - uint64_t hi = mulhi(t[i+4], 0x1000003D1ULL); - - uint64_t sum = t[i] + lo + carry; - carry = (sum < t[i] || (carry && sum == t[i] + lo)) ? 1 : 0; - carry += hi; - t[i] = sum; - } - - // Handle remaining carry - while (carry > 0) { - uint64_t lo = carry * 0x1000003D1ULL; - uint64_t hi = mulhi(carry, 0x1000003D1ULL); - - uint64_t sum = t[0] + lo; - uint64_t c1 = (sum < t[0]) ? 1 : 0; - t[0] = sum; - - sum = t[1] + c1; - c1 = (sum < t[1]) ? 1 : 0; - t[1] = sum; - - sum = t[2] + c1; - c1 = (sum < t[2]) ? 1 : 0; - t[2] = sum; - - sum = t[3] + c1; - c1 = (sum < t[3]) ? 1 : 0; - t[3] = sum; - - carry = hi + c1; - } - - c.limbs[0] = t[0]; - c.limbs[1] = t[1]; - c.limbs[2] = t[2]; - c.limbs[3] = t[3]; - fp_reduce(c, c); -} - -// Field inversion: c = a^(-1) mod p using Fermat's little theorem -// a^(-1) = a^(p-2) mod p -inline void fp_inv(thread Fp256& c, Fp256 a) { - // Exponent: p - 2 = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF - // FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2D - - Fp256 result = {{1, 0, 0, 0}}; - Fp256 base = a; - - // Binary exponentiation - // Most bits are 1, so we optimize for that - for (int limb = 0; limb < 4; limb++) { - uint64_t exp_limb; - if (limb == 0) { - exp_limb = 0xFFFFFFFEFFFFFC2DULL; // p[0] - 2 - } else { - exp_limb = 0xFFFFFFFFFFFFFFFFULL; - } - - for (int bit = 0; bit < 64; bit++) { - if ((exp_limb >> bit) & 1) { - fp_mul(result, result, base); - } - fp_sqr(base, base); - } - } - - c = result; -} - -// ============================================================================= -// Point Operations -// ============================================================================= - -// Convert Jacobian to Affine: (X, Y, 
Z) -> (X/Z^2, Y/Z^3) -inline AffinePoint jacobian_to_affine(JacobianPoint p) { - if (is_zero256(p.z)) { - AffinePoint inf; - inf.infinity = true; - return inf; - } - - Fp256 z_inv, z_inv2, z_inv3; - fp_inv(z_inv, p.z); - fp_sqr(z_inv2, z_inv); - fp_mul(z_inv3, z_inv2, z_inv); - - AffinePoint result; - fp_mul(result.x, p.x, z_inv2); - fp_mul(result.y, p.y, z_inv3); - result.infinity = false; - - return result; -} - -// Convert Affine to Jacobian -inline JacobianPoint affine_to_jacobian(AffinePoint p) { - JacobianPoint result; - if (p.infinity) { - result.x = {{0, 0, 0, 0}}; - result.y = {{1, 0, 0, 0}}; - result.z = {{0, 0, 0, 0}}; - } else { - result.x = p.x; - result.y = p.y; - result.z = {{1, 0, 0, 0}}; - } - return result; -} - -// Point doubling in Jacobian coordinates -// Formula from https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html -// secp256k1 has a = 0, so we use optimized formula -inline JacobianPoint point_double(JacobianPoint p) { - if (is_zero256(p.z)) { - return p; // Point at infinity - } - - Fp256 s, m, t, x3, y3, z3; - Fp256 yy, yyyy, xx; - - // YY = Y^2 - fp_sqr(yy, p.y); - - // S = 4*X*YY - fp_mul(s, p.x, yy); - fp_double(s, s); - fp_double(s, s); - - // M = 3*X^2 (a=0 for secp256k1) - fp_sqr(xx, p.x); - fp_add(m, xx, xx); - fp_add(m, m, xx); - - // T = M^2 - 2*S - fp_sqr(t, m); - Fp256 two_s; - fp_double(two_s, s); - fp_sub(t, t, two_s); - - // X3 = T - x3 = t; - - // Y3 = M*(S-T) - 8*YYYY - fp_sqr(yyyy, yy); - Fp256 eight_yyyy; - fp_double(eight_yyyy, yyyy); - fp_double(eight_yyyy, eight_yyyy); - fp_double(eight_yyyy, eight_yyyy); - - Fp256 s_minus_t; - fp_sub(s_minus_t, s, t); - fp_mul(y3, m, s_minus_t); - fp_sub(y3, y3, eight_yyyy); - - // Z3 = 2*Y*Z - fp_mul(z3, p.y, p.z); - fp_double(z3, z3); - - JacobianPoint result; - result.x = x3; - result.y = y3; - result.z = z3; - return result; -} - -// Point addition: Jacobian + Affine -> Jacobian (mixed addition) -// More efficient when one point is affine -inline JacobianPoint 
point_add_mixed(JacobianPoint p, AffinePoint q) { - if (q.infinity) { - return p; - } - if (is_zero256(p.z)) { - return affine_to_jacobian(q); - } - - Fp256 z1z1, u2, s2, h, hh, hhh, r, v; - - // Z1Z1 = Z1^2 - fp_sqr(z1z1, p.z); - - // U2 = X2 * Z1Z1 - fp_mul(u2, q.x, z1z1); - - // S2 = Y2 * Z1 * Z1Z1 - Fp256 z1_z1z1; - fp_mul(z1_z1z1, p.z, z1z1); - fp_mul(s2, q.y, z1_z1z1); - - // H = U2 - X1 - fp_sub(h, u2, p.x); - - // HH = H^2 - fp_sqr(hh, h); - - // HHH = H * HH - fp_mul(hhh, h, hh); - - // R = S2 - Y1 - fp_sub(r, s2, p.y); - - // V = X1 * HH - fp_mul(v, p.x, hh); - - // X3 = R^2 - HHH - 2*V - Fp256 x3, y3, z3; - Fp256 rr, two_v; - fp_sqr(rr, r); - fp_double(two_v, v); - fp_sub(x3, rr, hhh); - fp_sub(x3, x3, two_v); - - // Y3 = R*(V - X3) - Y1*HHH - Fp256 v_minus_x3, y1_hhh; - fp_sub(v_minus_x3, v, x3); - fp_mul(y3, r, v_minus_x3); - fp_mul(y1_hhh, p.y, hhh); - fp_sub(y3, y3, y1_hhh); - - // Z3 = Z1 * H - fp_mul(z3, p.z, h); - - JacobianPoint result; - result.x = x3; - result.y = y3; - result.z = z3; - return result; -} - -// Full Jacobian + Jacobian addition -inline JacobianPoint point_add(JacobianPoint p, JacobianPoint q) { - if (is_zero256(p.z)) return q; - if (is_zero256(q.z)) return p; - - Fp256 z1z1, z2z2, u1, u2, s1, s2, h, r; - - fp_sqr(z1z1, p.z); - fp_sqr(z2z2, q.z); - - fp_mul(u1, p.x, z2z2); - fp_mul(u2, q.x, z1z1); - - Fp256 z1_z1z1, z2_z2z2; - fp_mul(z1_z1z1, p.z, z1z1); - fp_mul(z2_z2z2, q.z, z2z2); - - fp_mul(s1, p.y, z2_z2z2); - fp_mul(s2, q.y, z1_z1z1); - - fp_sub(h, u2, u1); - fp_sub(r, s2, s1); - - // Check if P = Q (need doubling) or P = -Q (result is infinity) - if (is_zero256(h)) { - if (is_zero256(r)) { - return point_double(p); - } - // P = -Q, return infinity - JacobianPoint inf; - inf.x = {{0, 0, 0, 0}}; - inf.y = {{1, 0, 0, 0}}; - inf.z = {{0, 0, 0, 0}}; - return inf; - } - - Fp256 hh, hhh, v; - fp_sqr(hh, h); - fp_mul(hhh, h, hh); - fp_mul(v, u1, hh); - - Fp256 x3, y3, z3; - Fp256 rr, two_v; - fp_sqr(rr, r); - fp_double(two_v, v); 
- fp_sub(x3, rr, hhh); - fp_sub(x3, x3, two_v); - - Fp256 v_minus_x3, s1_hhh; - fp_sub(v_minus_x3, v, x3); - fp_mul(y3, r, v_minus_x3); - fp_mul(s1_hhh, s1, hhh); - fp_sub(y3, y3, s1_hhh); - - Fp256 z1_z2; - fp_mul(z1_z2, p.z, q.z); - fp_mul(z3, z1_z2, h); - - JacobianPoint result; - result.x = x3; - result.y = y3; - result.z = z3; - return result; -} - -// ============================================================================= -// GTable-based Scalar Multiplication -// ============================================================================= -// -// The GTable stores precomputed multiples of G: -// table[i][j] = (j * 2^(16*i)) * G for i in 0..15, j in 0..65535 -// -// To compute k * G for 256-bit scalar k: -// 1. Split k into 16 chunks of 16 bits each: k = k[0] + k[1]*2^16 + ... + k[15]*2^240 -// 2. Look up table[i][k[i]] for each chunk -// 3. Sum the 16 looked-up points -// -// This requires 16 lookups + 15 additions, vs 256 doublings + ~128 additions for double-and-add - -// Extract 16-bit chunk from scalar -inline uint32_t get_scalar_chunk(Scalar256 s, uint32_t chunk_idx) { - uint32_t bit_idx = chunk_idx * 16; - uint32_t limb_idx = bit_idx / 64; - uint32_t bit_offset = bit_idx % 64; - - uint64_t value = s.limbs[limb_idx] >> bit_offset; - - // Handle crossing limb boundary - if (bit_offset > 48 && limb_idx < 3) { - value |= s.limbs[limb_idx + 1] << (64 - bit_offset); - } - - return value & 0xFFFF; -} - -// GTable scalar multiplication kernel -// Each thread computes one scalar multiplication using the precomputed table -kernel void gtable_scalar_mul( - device const AffinePoint* gtable [[buffer(0)]], // 16 * 65536 precomputed points - device const Scalar256* scalars [[buffer(1)]], - device AffinePoint* results [[buffer(2)]], - constant uint32_t& count [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - Scalar256 scalar = scalars[index]; - - // Initialize accumulator to identity (will be set on first non-zero 
chunk) - JacobianPoint acc; - acc.x = {{0, 0, 0, 0}}; - acc.y = {{1, 0, 0, 0}}; - acc.z = {{0, 0, 0, 0}}; - bool started = false; - - // Process all 16 chunks - for (uint32_t i = 0; i < 16; i++) { - uint32_t chunk = get_scalar_chunk(scalar, i); - - if (chunk != 0) { - // Look up point from table - uint32_t table_idx = i * 65536 + (chunk - 1); // -1 because we skip 0 - AffinePoint p = gtable[table_idx]; - - if (started) { - acc = point_add_mixed(acc, p); - } else { - acc = affine_to_jacobian(p); - started = true; - } - } - } - - // Convert to affine for output - if (started) { - results[index] = jacobian_to_affine(acc); - } else { - results[index].infinity = true; - } -} - -// ============================================================================= -// Double-and-Add Scalar Multiplication (for arbitrary base points) -// ============================================================================= - -kernel void scalar_mul_general( - device const AffinePoint* points [[buffer(0)]], - device const Scalar256* scalars [[buffer(1)]], - device AffinePoint* results [[buffer(2)]], - constant uint32_t& count [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - AffinePoint base = points[index]; - Scalar256 scalar = scalars[index]; - - if (base.infinity) { - results[index].infinity = true; - return; - } - - JacobianPoint acc; - acc.x = {{0, 0, 0, 0}}; - acc.y = {{1, 0, 0, 0}}; - acc.z = {{0, 0, 0, 0}}; - bool started = false; - - // Double-and-add from MSB - for (int limb = 3; limb >= 0; limb--) { - for (int bit = 63; bit >= 0; bit--) { - if (started) { - acc = point_double(acc); - } - - if ((scalar.limbs[limb] >> bit) & 1) { - if (started) { - acc = point_add_mixed(acc, base); - } else { - acc = affine_to_jacobian(base); - started = true; - } - } - } - } - - if (started) { - results[index] = jacobian_to_affine(acc); - } else { - results[index].infinity = true; - } -} - -// 
============================================================================= -// Point Addition Batch Kernel -// ============================================================================= - -kernel void point_add_batch( - device const AffinePoint* a [[buffer(0)]], - device const AffinePoint* b [[buffer(1)]], - device AffinePoint* results [[buffer(2)]], - constant uint32_t& count [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - AffinePoint pa = a[index]; - AffinePoint pb = b[index]; - - if (pa.infinity) { - results[index] = pb; - return; - } - if (pb.infinity) { - results[index] = pa; - return; - } - - JacobianPoint ja = affine_to_jacobian(pa); - JacobianPoint sum = point_add_mixed(ja, pb); - results[index] = jacobian_to_affine(sum); -} - -// ============================================================================= -// Point Doubling Batch Kernel -// ============================================================================= - -kernel void point_double_batch( - device const AffinePoint* points [[buffer(0)]], - device AffinePoint* results [[buffer(1)]], - constant uint32_t& count [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - AffinePoint p = points[index]; - - if (p.infinity) { - results[index] = p; - return; - } - - JacobianPoint jp = affine_to_jacobian(p); - JacobianPoint doubled = point_double(jp); - results[index] = jacobian_to_affine(doubled); -} - -// ============================================================================= -// Scalar Field Arithmetic (mod n) -// ============================================================================= - -// Scalar addition: c = (a + b) mod n -inline void sc_add(thread Scalar256& c, Scalar256 a, Scalar256 b) { - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + b.limbs[i] + carry; - carry = (sum < a.limbs[i] || (carry && sum == a.limbs[i])) ? 
1 : 0; - c.limbs[i] = sum; - } - // Reduce mod n - Scalar256 n = {{SECP256K1_N[0], SECP256K1_N[1], SECP256K1_N[2], SECP256K1_N[3]}}; - bool gte = true; - for (int i = 3; i >= 0; i--) { - if (c.limbs[i] > n.limbs[i]) break; - if (c.limbs[i] < n.limbs[i]) { gte = false; break; } - } - if (gte || carry) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = c.limbs[i] - n.limbs[i] - borrow; - borrow = (c.limbs[i] < n.limbs[i] + borrow) ? 1 : 0; - c.limbs[i] = diff; - } - } -} - -// Scalar subtraction: c = (a - b) mod n -inline void sc_sub(thread Scalar256& c, Scalar256 a, Scalar256 b) { - bool gte = true; - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] > b.limbs[i]) break; - if (a.limbs[i] < b.limbs[i]) { gte = false; break; } - } - if (gte) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = a.limbs[i] - b.limbs[i] - borrow; - borrow = (a.limbs[i] < b.limbs[i] + borrow) ? 1 : 0; - c.limbs[i] = diff; - } - } else { - // a < b, so compute a + n - b - Scalar256 n = {{SECP256K1_N[0], SECP256K1_N[1], SECP256K1_N[2], SECP256K1_N[3]}}; - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + n.limbs[i] + carry; - carry = (sum < a.limbs[i] || (carry && sum == a.limbs[i])) ? 1 : 0; - c.limbs[i] = sum; - } - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = c.limbs[i] - b.limbs[i] - borrow; - borrow = (c.limbs[i] < b.limbs[i] + borrow) ? 1 : 0; - c.limbs[i] = diff; - } - } -} - -// Scalar multiplication: c = a * b mod n -inline void sc_mul(thread Scalar256& c, Scalar256 a, Scalar256 b) { - // Full 512-bit product - uint64_t t[8] = {0}; - - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { - uint64_t lo = a.limbs[i] * b.limbs[j]; - uint64_t hi = mulhi(a.limbs[i], b.limbs[j]); - - uint64_t sum = t[i+j] + lo; - uint64_t c1 = (sum < t[i+j]) ? 1 : 0; - sum += carry; - c1 += (sum < carry) ? 
1 : 0; - t[i+j] = sum; - carry = hi + c1; - } - t[i+4] = carry; - } - - // Barrett reduction mod n - // For secp256k1, n is close to 2^256, so we use simple reduction - Scalar256 n = {{SECP256K1_N[0], SECP256K1_N[1], SECP256K1_N[2], SECP256K1_N[3]}}; - - // Reduce high 256 bits: multiply by 2^256 mod n and add to low - // 2^256 mod n = 2^256 - n = 0x014551231950B75FC4402DA1732FC9BEBF - constant uint64_t R_MOD_N[4] = { - 0x402DA1732FC9BEBFULL, - 0x4551231950B75FC4ULL, - 0x0000000000000001ULL, - 0x0000000000000000ULL - }; - - uint64_t carry = 0; - for (int i = 0; i < 4; i++) { - uint64_t lo = 0, hi = 0; - for (int j = 0; j <= i; j++) { - if (j < 4 && (i-j) < 4) { - uint64_t plo = t[4+j] * R_MOD_N[i-j]; - uint64_t phi = mulhi(t[4+j], R_MOD_N[i-j]); - uint64_t sum = lo + plo; - uint64_t c1 = (sum < lo) ? 1 : 0; - lo = sum; - hi += phi + c1; - } - } - uint64_t sum = t[i] + lo + carry; - carry = (sum < t[i] || (carry && sum == t[i] + lo)) ? 1 : 0; - carry += hi; - t[i] = sum; - } - - // Handle remaining carry with reduction - while (carry > 0) { - uint64_t lo = carry * R_MOD_N[0]; - uint64_t hi = mulhi(carry, R_MOD_N[0]); - - uint64_t sum = t[0] + lo; - uint64_t c1 = (sum < t[0]) ? 1 : 0; - t[0] = sum; - - sum = t[1] + (carry * R_MOD_N[1]) + c1; - c1 = (sum < t[1]) ? 1 : 0; - t[1] = sum; - - sum = t[2] + carry + c1; - c1 = (sum < t[2]) ? 1 : 0; - t[2] = sum; - - sum = t[3] + c1; - c1 = (sum < t[3]) ? 1 : 0; - t[3] = sum; - - carry = hi + mulhi(carry, R_MOD_N[1]) + c1; - } - - // Final reduction if >= n - c.limbs[0] = t[0]; - c.limbs[1] = t[1]; - c.limbs[2] = t[2]; - c.limbs[3] = t[3]; - - bool gte = true; - for (int i = 3; i >= 0; i--) { - if (c.limbs[i] > n.limbs[i]) break; - if (c.limbs[i] < n.limbs[i]) { gte = false; break; } - } - if (gte) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = c.limbs[i] - n.limbs[i] - borrow; - borrow = (c.limbs[i] < n.limbs[i] + borrow) ? 
1 : 0; - c.limbs[i] = diff; - } - } -} - -// Scalar inversion: c = a^(-1) mod n using Fermat's little theorem -// a^(-1) = a^(n-2) mod n -inline void sc_inv(thread Scalar256& c, Scalar256 a) { - // n - 2 = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD036413F - Scalar256 result = {{1, 0, 0, 0}}; - Scalar256 base = a; - - // Binary exponentiation with n-2 - constant uint64_t N_MINUS_2[4] = { - 0xBFD25E8CD036413FULL, - 0xBAAEDCE6AF48A03BULL, - 0xFFFFFFFFFFFFFFFEULL, - 0xFFFFFFFFFFFFFFFFULL - }; - - for (int limb = 0; limb < 4; limb++) { - uint64_t exp_limb = N_MINUS_2[limb]; - for (int bit = 0; bit < 64; bit++) { - if ((exp_limb >> bit) & 1) { - sc_mul(result, result, base); - } - Scalar256 tmp; - sc_mul(tmp, base, base); - base = tmp; - } - } - - c = result; -} - -// ============================================================================= -// ECDSA Verification Kernel -// ============================================================================= -// -// Verifies signature (r, s) on message hash z with public key Q: -// 1. Compute u1 = z * s^(-1) mod n -// 2. Compute u2 = r * s^(-1) mod n -// 3. Compute R = u1*G + u2*Q -// 4. 
Signature is valid if R.x mod n == r - -kernel void ecdsa_verify_batch( - device const uint8_t* messages [[buffer(0)]], // 32 bytes each - device const Scalar256* r_values [[buffer(1)]], - device const Scalar256* s_values [[buffer(2)]], - device const AffinePoint* public_keys [[buffer(3)]], - device const AffinePoint* gtable [[buffer(4)]], // For u1*G - device int* results [[buffer(5)]], - constant uint32_t& count [[buffer(6)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - // Load inputs - Scalar256 r = r_values[index]; - Scalar256 s = s_values[index]; - AffinePoint Q = public_keys[index]; - - // Verify r and s are in valid range [1, n-1] - Scalar256 n = {{SECP256K1_N[0], SECP256K1_N[1], SECP256K1_N[2], SECP256K1_N[3]}}; - bool r_zero = (r.limbs[0] == 0 && r.limbs[1] == 0 && r.limbs[2] == 0 && r.limbs[3] == 0); - bool s_zero = (s.limbs[0] == 0 && s.limbs[1] == 0 && s.limbs[2] == 0 && s.limbs[3] == 0); - if (r_zero || s_zero) { - results[index] = 0; // Invalid signature - return; - } - - // Load message hash as scalar z (big-endian to little-endian) - Scalar256 z; - device const uint8_t* msg = messages + index * 32; - for (int i = 0; i < 4; i++) { - z.limbs[3-i] = 0; - for (int j = 0; j < 8; j++) { - z.limbs[3-i] = (z.limbs[3-i] << 8) | msg[i*8 + j]; - } - } - - // Step 1: Compute s^(-1) mod n - Scalar256 s_inv; - sc_inv(s_inv, s); - - // Step 2: Compute u1 = z * s^(-1) mod n - Scalar256 u1; - sc_mul(u1, z, s_inv); - - // Step 3: Compute u2 = r * s^(-1) mod n - Scalar256 u2; - sc_mul(u2, r, s_inv); - - // Step 4: Compute R = u1*G + u2*Q using double-and-add - // For u1*G, use GTable if available, otherwise double-and-add with generator - AffinePoint G; - G.x = {{SECP256K1_GX[0], SECP256K1_GX[1], SECP256K1_GX[2], SECP256K1_GX[3]}}; - G.y = {{SECP256K1_GY[0], SECP256K1_GY[1], SECP256K1_GY[2], SECP256K1_GY[3]}}; - G.infinity = false; - - // Compute u1*G using GTable (16 table lookups + 15 additions) - JacobianPoint acc_u1G; - 
acc_u1G.x = {{0, 0, 0, 0}}; - acc_u1G.y = {{1, 0, 0, 0}}; - acc_u1G.z = {{0, 0, 0, 0}}; - bool started_u1 = false; - - for (uint32_t i = 0; i < 16; i++) { - uint32_t bit_offset = i * 16; - uint32_t limb_idx = bit_offset / 64; - uint32_t limb_offset = bit_offset % 64; - - uint32_t chunk = (u1.limbs[limb_idx] >> limb_offset) & 0xFFFF; - if (limb_offset > 48 && limb_idx < 3) { - chunk |= (u1.limbs[limb_idx + 1] << (64 - limb_offset)) & 0xFFFF; - } - - if (chunk != 0) { - uint32_t table_idx = i * 65536 + (chunk - 1); - AffinePoint p = gtable[table_idx]; - if (started_u1) { - acc_u1G = point_add_mixed(acc_u1G, p); - } else { - acc_u1G = affine_to_jacobian(p); - started_u1 = true; - } - } - } - - // Compute u2*Q using double-and-add - JacobianPoint acc_u2Q; - acc_u2Q.x = {{0, 0, 0, 0}}; - acc_u2Q.y = {{1, 0, 0, 0}}; - acc_u2Q.z = {{0, 0, 0, 0}}; - bool started_u2 = false; - - for (int limb = 3; limb >= 0; limb--) { - for (int bit = 63; bit >= 0; bit--) { - if (started_u2) { - acc_u2Q = point_double(acc_u2Q); - } - if ((u2.limbs[limb] >> bit) & 1) { - if (started_u2) { - acc_u2Q = point_add_mixed(acc_u2Q, Q); - } else { - acc_u2Q = affine_to_jacobian(Q); - started_u2 = true; - } - } - } - } - - // Compute R = u1*G + u2*Q - JacobianPoint R_jacobian; - if (!started_u1 && !started_u2) { - // Both are identity, invalid - results[index] = 0; - return; - } else if (!started_u1) { - R_jacobian = acc_u2Q; - } else if (!started_u2) { - R_jacobian = acc_u1G; - } else { - // Add the two points (need full projective addition) - // Convert acc_u2Q to affine for mixed addition - AffinePoint u2Q_affine = jacobian_to_affine(acc_u2Q); - R_jacobian = point_add_mixed(acc_u1G, u2Q_affine); - } - - // Check if R is point at infinity - if (is_zero256(R_jacobian.z)) { - results[index] = 0; - return; - } - - // Convert R to affine to get x-coordinate - AffinePoint R = jacobian_to_affine(R_jacobian); - - // Step 5: Verify R.x mod n == r - // R.x is already in field, need to reduce mod n - 
Scalar256 rx; - rx.limbs[0] = R.x.limbs[0]; - rx.limbs[1] = R.x.limbs[1]; - rx.limbs[2] = R.x.limbs[2]; - rx.limbs[3] = R.x.limbs[3]; - - // Reduce rx mod n if >= n - bool gte = true; - for (int i = 3; i >= 0; i--) { - if (rx.limbs[i] > n.limbs[i]) break; - if (rx.limbs[i] < n.limbs[i]) { gte = false; break; } - } - if (gte) { - uint64_t borrow = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = rx.limbs[i] - n.limbs[i] - borrow; - borrow = (rx.limbs[i] < n.limbs[i] + borrow) ? 1 : 0; - rx.limbs[i] = diff; - } - } - - // Compare rx with r - bool equal = (rx.limbs[0] == r.limbs[0] && rx.limbs[1] == r.limbs[1] && - rx.limbs[2] == r.limbs[2] && rx.limbs[3] == r.limbs[3]); - - results[index] = equal ? 1 : 0; -} - -// ============================================================================= -// GTable Precomputation Kernel -// ============================================================================= -// -// Computes the GTable for generator point G: -// table[i][j] = (j * 2^(16*i)) * G -// -// This is run once at initialization (~67MB output) - -kernel void precompute_gtable( - device AffinePoint* gtable [[buffer(0)]], - constant AffinePoint& generator [[buffer(1)]], - constant uint32_t& chunk_idx [[buffer(2)]], // Which 16-bit chunk we're computing - uint j [[thread_position_in_grid]] -) { - if (j == 0 || j >= 65536) return; // Skip j=0 (identity) - - // Compute scalar = j * 2^(16 * chunk_idx) - Scalar256 scalar = {{0, 0, 0, 0}}; - uint32_t bit_offset = chunk_idx * 16; - uint32_t limb_idx = bit_offset / 64; - uint32_t limb_offset = bit_offset % 64; - - scalar.limbs[limb_idx] = (uint64_t)j << limb_offset; - if (limb_offset > 48 && limb_idx < 3) { - scalar.limbs[limb_idx + 1] = (uint64_t)j >> (64 - limb_offset); - } - - // Compute scalar * G using double-and-add - AffinePoint base = generator; - JacobianPoint acc; - acc.x = {{0, 0, 0, 0}}; - acc.y = {{1, 0, 0, 0}}; - acc.z = {{0, 0, 0, 0}}; - bool started = false; - - for (int limb = 3; limb >= 0; limb--) { - 
for (int bit = 63; bit >= 0; bit--) { - if (started) { - acc = point_double(acc); - } - if ((scalar.limbs[limb] >> bit) & 1) { - if (started) { - acc = point_add_mixed(acc, base); - } else { - acc = affine_to_jacobian(base); - started = true; - } - } - } - } - - // Store in table (j-1 because we skip 0) - uint32_t table_idx = chunk_idx * 65536 + (j - 1); - gtable[table_idx] = jacobian_to_affine(acc); -} - -// ============================================================================= -// Keccak256 for Address Derivation -// ============================================================================= - -// Keccak-256 round constants -constant uint64_t KECCAK_RC[24] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808AULL, - 0x8000000080008000ULL, 0x000000000000808BULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008AULL, - 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000AULL, - 0x000000008000808BULL, 0x800000000000008BULL, 0x8000000000008089ULL, - 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800AULL, 0x800000008000000AULL, 0x8000000080008081ULL, - 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL -}; - -// Keccak rotation offsets -constant int KECCAK_R[25] = { - 0, 1, 62, 28, 27, 36, 44, 6, 55, 20, - 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, - 18, 2, 61, 56, 14 -}; - -inline uint64_t rotl64(uint64_t x, int n) { - return (x << n) | (x >> (64 - n)); -} - -// Single Keccak-f[1600] permutation round -inline void keccak_round(thread uint64_t* state, uint64_t rc) { - uint64_t C[5], D[5]; - - // θ step - for (int x = 0; x < 5; x++) { - C[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20]; - } - for (int x = 0; x < 5; x++) { - D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); - } - for (int i = 0; i < 25; i++) { - state[i] ^= D[i % 5]; - } - - // ρ and π steps - uint64_t temp[25]; - for (int i = 0; i < 25; i++) { - int 
x = i % 5; - int y = i / 5; - int new_x = y; - int new_y = (2 * x + 3 * y) % 5; - temp[new_y * 5 + new_x] = rotl64(state[i], KECCAK_R[i]); - } - - // χ step - for (int y = 0; y < 5; y++) { - for (int x = 0; x < 5; x++) { - int i = y * 5 + x; - state[i] = temp[i] ^ ((~temp[y * 5 + (x + 1) % 5]) & temp[y * 5 + (x + 2) % 5]); - } - } - - // ι step - state[0] ^= rc; -} - -// Keccak-256 hash for deriving Ethereum address from public key -kernel void keccak256_batch( - device const uint8_t* inputs [[buffer(0)]], // 64 bytes each (uncompressed pubkey without 0x04) - device uint8_t* outputs [[buffer(1)]], // 32 bytes each - constant uint32_t& count [[buffer(2)]], - constant uint32_t& input_len [[buffer(3)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - const device uint8_t* input = inputs + index * input_len; - device uint8_t* output = outputs + index * 32; - - // Initialize state - uint64_t state[25] = {0}; - - // Absorb input (simplified for 64-byte input) - // For public key: input is 64 bytes (X || Y coordinates) - for (uint32_t i = 0; i < input_len && i < 136; i += 8) { - uint64_t block = 0; - for (int j = 0; j < 8 && i + j < input_len; j++) { - block |= (uint64_t)input[i + j] << (j * 8); - } - state[i / 8] ^= block; - } - - // Padding (0x01 ... 
0x80 for Keccak-256) - if (input_len < 136) { - state[input_len / 8] ^= (uint64_t)0x01 << ((input_len % 8) * 8); - state[16] ^= 0x8000000000000000ULL; // rate = 136, so last block is index 16 - } - - // Keccak-f[1600] - for (int round = 0; round < 24; round++) { - keccak_round(state, KECCAK_RC[round]); - } - - // Squeeze output (32 bytes) - for (int i = 0; i < 4; i++) { - uint64_t block = state[i]; - for (int j = 0; j < 8; j++) { - output[i * 8 + j] = (block >> (j * 8)) & 0xFF; - } - } -} - -// Derive Ethereum address from public key -// Address = keccak256(pubkey)[12:32] -kernel void derive_address_batch( - device const AffinePoint* public_keys [[buffer(0)]], - device uint8_t* addresses [[buffer(1)]], // 20 bytes each - constant uint32_t& count [[buffer(2)]], - uint index [[thread_position_in_grid]] -) { - if (index >= count) return; - - AffinePoint pk = public_keys[index]; - - if (pk.infinity) { - // Invalid public key - for (int i = 0; i < 20; i++) { - addresses[index * 20 + i] = 0; - } - return; - } - - // Serialize public key (64 bytes: X || Y, big-endian) - uint8_t pubkey_bytes[64]; - for (int i = 0; i < 4; i++) { - uint64_t x_limb = pk.x.limbs[3 - i]; - uint64_t y_limb = pk.y.limbs[3 - i]; - for (int j = 0; j < 8; j++) { - pubkey_bytes[i * 8 + j] = (x_limb >> ((7 - j) * 8)) & 0xFF; - pubkey_bytes[32 + i * 8 + j] = (y_limb >> ((7 - j) * 8)) & 0xFF; - } - } - - // Compute Keccak-256 - uint64_t state[25] = {0}; - - // Absorb 64 bytes - for (int i = 0; i < 8; i++) { - uint64_t block = 0; - for (int j = 0; j < 8; j++) { - block |= (uint64_t)pubkey_bytes[i * 8 + j] << (j * 8); - } - state[i] ^= block; - } - - // Padding - state[8] ^= 0x01; - state[16] ^= 0x8000000000000000ULL; - - // Keccak-f[1600] - for (int round = 0; round < 24; round++) { - keccak_round(state, KECCAK_RC[round]); - } - - // Extract address (last 20 bytes of 32-byte hash) - // Hash bytes 12-31 become address bytes 0-19 - device uint8_t* addr = addresses + index * 20; - - // Bytes 12-15 from 
state[1] - addr[0] = (state[1] >> 32) & 0xFF; - addr[1] = (state[1] >> 40) & 0xFF; - addr[2] = (state[1] >> 48) & 0xFF; - addr[3] = (state[1] >> 56) & 0xFF; - - // Bytes 16-23 from state[2] - for (int j = 0; j < 8; j++) { - addr[4 + j] = (state[2] >> (j * 8)) & 0xFF; - } - - // Bytes 24-31 from state[3] - for (int j = 0; j < 8; j++) { - addr[12 + j] = (state[3] >> (j * 8)) & 0xFF; - } -} diff --git a/secp256k1/gpu/metal/secp256k1_batch_inv.metal b/secp256k1/gpu/metal/secp256k1_batch_inv.metal deleted file mode 100644 index 9ff944a..0000000 --- a/secp256k1/gpu/metal/secp256k1_batch_inv.metal +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Montgomery batch inversion for secp256k1 base field (Fp) and scalar field -// (Fn). Mirrors crypto/secp256k1/cpp/batch_inv.hpp; output must be byte-equal. -// -// Algorithm: forward prefix product, single Fermat inversion of last prefix, -// backward sweep with running inverse (Knuth TAOCP vol 2). One workgroup of -// one thread for byte-determinism. -// -// Inputs: n elements of Fp (or Fn) in Montgomery form. -// Outputs: n elements; out[i] = in[i]^-1 (Mont form). -// Caller must ensure no zero entries (this kernel does NOT check). 
-
-#include <metal_stdlib>
-using namespace metal;
-
-// =============================================================================
-// 256-bit field constants — matches secp256k1.metal exactly
-// =============================================================================
-
-// 256-bit unsigned integer held as four 64-bit limbs, least-significant limb
-// first (u256_cmp scans from limbs[3] down to limbs[0]).
-struct uint256 { ulong limbs[4]; };
-
-// secp256k1 base-field prime p = 2^256 - 2^32 - 977.
-constant uint256 P_MOD = {{
-    0xFFFFFFFEFFFFFC2FUL, 0xFFFFFFFFFFFFFFFFUL,
-    0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL
-}};
-// secp256k1 group order n.
-constant uint256 N_MOD = {{
-    0xBFD25E8CD0364141UL, 0xBAAEDCE6AF48A03BUL,
-    0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL
-}};
-// Per-modulus word passed to mont_mul as m_inv: -p^-1 mod 2^64 and
-// -n^-1 mod 2^64 respectively.
-constant ulong P_INV = 0xD838091DD2253531UL;
-constant ulong N_INV = 0x4B0DFF665588B13FUL;
-// Factor fn_pow multiplies ONE by to obtain its Montgomery-form starting value
-// (presumably R^2 mod n with R = 2^256 — confirm against the CPU constants).
-constant uint256 R2_N = {{
-    0x896CF21467D7D140UL, 0x741496C20E7CF878UL,
-    0xE697F5E45BCD07C6UL, 0x9D671CD581C69BC5UL
-}};
-// 1 in Montgomery form mod p: 2^256 mod p = 2^32 + 977.
-constant uint256 ONE_MONT_P = {{ 0x00000001000003D1UL, 0UL, 0UL, 0UL }};
-// Fermat inversion exponents (little-endian limbs): p - 2 and n - 2.
-constant ulong P_M2[4] = {
-    0xFFFFFFFEFFFFFC2DUL, 0xFFFFFFFFFFFFFFFFUL,
-    0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL
-};
-constant ulong N_M2[4] = {
-    0xBFD25E8CD036413FUL, 0xBAAEDCE6AF48A03BUL,
-    0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL
-};
-// Plain (non-Montgomery) integer 1.
-constant uint256 ONE = {{1, 0, 0, 0}};
-
-// Three-way compare of two 256-bit values: -1 if a < b, 1 if a > b, 0 if equal.
-inline int u256_cmp(uint256 a, uint256 b) {
-    for (int i = 3; i >= 0; --i) {
-        if (a.limbs[i] < b.limbs[i]) return -1;
-        if (a.limbs[i] > b.limbs[i]) return 1;
-    }
-    return 0;
-}
-
-// Full 64x64 -> 128-bit product built from 32-bit halves; the low word goes to
-// `lo`, the high word to `hi`.
-inline void mul64(ulong a, ulong b, thread ulong &lo, thread ulong &hi) {
-    ulong al = a & 0xFFFFFFFFUL, ah = a >> 32;
-    ulong bl = b & 0xFFFFFFFFUL, bh = b >> 32;
-    ulong ll = al * bl, lh = al * bh, hl = ah * bl, hh = ah * bh;
-    // Sum of three values below 2^32, so `mid` cannot wrap.
-    ulong mid = (ll >> 32) + (lh & 0xFFFFFFFFUL) + (hl & 0xFFFFFFFFUL);
-    lo = (ll & 0xFFFFFFFFUL) | (mid << 32);
-    hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
-}
-
-// out = a + b + c (mod 2^64); returns the carry-out, which can be 0, 1 or 2
-// because `c` is a full word rather than a single bit.
-inline ulong addc(ulong a, ulong b, ulong c, thread ulong &out) {
-    ulong t = a + b;
-    ulong c1 = (t < a) ? 1UL : 0UL;
-    ulong t2 = t + c;
-    ulong c2 = (t2 < t) ? 1UL : 0UL;
-    out = t2;
-    return c1 + c2;
-}
-
-// out = a - b - br (mod 2^64); returns the borrow-out.
-inline ulong subb(ulong a, ulong b, ulong br, thread ulong &out) {
-    ulong t = a - b;
-    ulong b1 = (t > a) ? 1UL : 0UL;
-    ulong t2 = t - br;
-    ulong b2 = (t2 > t) ? 1UL : 0UL;
-    out = t2;
-    return b1 + b2;
-}
-
-// 256-bit subtraction a - b (mod 2^256); the final borrow is written to `borrow`.
-inline uint256 sub_256(uint256 a, uint256 b, thread ulong &borrow) {
-    uint256 r;
-    ulong br = 0;
-    for (int i = 0; i < 4; ++i) br = subb(a.limbs[i], b.limbs[i], br, r.limbs[i]);
-    borrow = br;
-    return r;
-}
-
-// Montgomery product a * b * 2^-256 mod m. Each outer iteration adds one
-// partial product a * b[i], then adds u * m so the lowest word becomes zero,
-// then shifts the accumulator down one word. `m_inv` must be -m^-1 mod 2^64.
-// NOTE(review): assumes a, b < m so one conditional subtraction suffices.
-inline uint256 mont_mul(uint256 a, uint256 b, uint256 m, ulong m_inv) {
-    ulong t[6];
-    for (int i = 0; i < 6; ++i) t[i] = 0;
-    for (int i = 0; i < 4; ++i) {
-        // Accumulate a * b[i] into t[0..5].
-        ulong carry = 0;
-        for (int j = 0; j < 4; ++j) {
-            ulong lo, hi;
-            mul64(a.limbs[j], b.limbs[i], lo, hi);
-            ulong c1 = addc(t[j], lo, carry, t[j]);
-            carry = hi + c1;
-        }
-        ulong c1 = addc(t[4], carry, 0, t[4]);
-        t[5] += c1;
-        // Reduction step: choose u so that t[0] + u * m[0] == 0 (mod 2^64).
-        ulong u = t[0] * m_inv;
-        carry = 0;
-        for (int j = 0; j < 4; ++j) {
-            ulong lo, hi;
-            mul64(u, m.limbs[j], lo, hi);
-            ulong c2 = addc(t[j], lo, carry, t[j]);
-            carry = hi + c2;
-        }
-        ulong c2 = addc(t[4], carry, 0, t[4]);
-        t[5] += c2;
-        // Divide by 2^64: drop the (now zero) lowest word.
-        for (int j = 0; j < 5; ++j) t[j] = t[j + 1];
-        t[5] = 0;
-    }
-    uint256 r = {{ t[0], t[1], t[2], t[3] }};
-    // Final conditional subtraction brings the result below m.
-    if (t[4] != 0 || u256_cmp(r, m) >= 0) {
-        ulong bw;
-        r = sub_256(r, m, bw);
-    }
-    return r;
-}
-
-// Montgomery multiply / square specialised to the field prime (fp_*) and the
-// group order (fn_*).
-inline uint256 fp_mul(uint256 a, uint256 b) { return mont_mul(a, b, P_MOD, P_INV); }
-inline uint256 fn_mul(uint256 a, uint256 b) { return mont_mul(a, b, N_MOD, N_INV); }
-inline uint256 fp_sqr(uint256 a) { return mont_mul(a, a, P_MOD, P_INV); }
-inline uint256 fn_sqr(uint256 a) { return mont_mul(a, a, N_MOD, N_INV); }
-
-// a^e mod p for a 256-bit exponent given as four little-endian limbs, using
-// right-to-left square-and-multiply. `a` and the result are in Montgomery form.
-// The multiply is branch-dependent on exponent bits; the only callers in this
-// file pass the fixed exponents P_M2 / N_M2.
-inline uint256 fp_pow(uint256 a, constant ulong* exp4) {
-    uint256 result = ONE_MONT_P;
-    uint256 base = a;
-    for (int limb = 0; limb < 4; ++limb) {
-        ulong w = exp4[limb];
-        for (int bit = 0; bit < 64; ++bit) {
-            if ((w >> bit) & 1) result = fp_mul(result, base);
-            base = fp_sqr(base);
-        }
-    }
-    return result;
-}
-// Field inverse via Fermat: a^(p-2). An input of zero yields zero.
-inline uint256 fp_inv(uint256 a) { return fp_pow(a, P_M2); }
-
-// Same as fp_pow, modulo the group order n.
-inline uint256 fn_pow(uint256 a, constant ulong* exp4) {
-    uint256 result = mont_mul(ONE, R2_N, N_MOD, N_INV);
-    uint256 base = a;
-    for (int limb = 0; limb < 4; ++limb) {
-        ulong w = exp4[limb];
-        for (int bit = 0; bit < 64; ++bit) {
-            if ((w >> bit) & 1) result = fn_mul(result, base);
-            base = fn_sqr(base);
-        }
-    }
-    return result;
-}
-// Scalar inverse via Fermat: a^(n-2).
-inline uint256 fn_inv(uint256 a) { return fn_pow(a, N_M2); }
-
-// =============================================================================
-// Batch inversion kernels.
-// One workgroup of 1 thread; the kernel walks the array sequentially. This
-// preserves byte-equal output across CPU and Metal.
-// =============================================================================
-
-// Inverts n field elements (mod p) with one Fermat inversion plus 3(n-1)
-// multiplications.
-//   buffer(0) in  : n inputs
-//   buffer(1) out : n outputs; also used as scratch for the prefix products
-//   buffer(2) n   : element count
-// Only thread 0 does any work. `in` and `out` must not alias: the backward
-// sweep reads in[i] after out[] has been overwritten. A zero input makes the
-// running product zero, so every output of that batch becomes zero.
-kernel void secp256k1_batch_inv_fp(
-    device const uint256* in   [[buffer(0)]],
-    device uint256*       out  [[buffer(1)]],
-    constant uint&        n    [[buffer(2)]],
-    uint tid [[thread_position_in_grid]])
-{
-    if (tid != 0) return;
-    if (n == 0) return;
-
-    // Forward sweep: prefix products into out[].
-    out[0] = in[0];
-    for (uint i = 1; i < n; ++i) {
-        out[i] = fp_mul(out[i - 1], in[i]);
-    }
-    // Single Fermat inversion of the last prefix.
-    uint256 inv = fp_inv(out[n - 1]);
-    // Backward sweep: peel one factor off `inv` per step.
-    for (uint k = n; k > 1; --k) {
-        uint i = k - 1;
-        uint256 t = fp_mul(inv, out[i - 1]);
-        inv = fp_mul(inv, in[i]);
-        out[i] = t;
-    }
-    out[0] = inv;
-}
-
-// Same algorithm and buffer contract as secp256k1_batch_inv_fp, modulo the
-// group order n.
-kernel void secp256k1_batch_inv_fn(
-    device const uint256* in   [[buffer(0)]],
-    device uint256*       out  [[buffer(1)]],
-    constant uint&        n    [[buffer(2)]],
-    uint tid [[thread_position_in_grid]])
-{
-    if (tid != 0) return;
-    if (n == 0) return;
-    out[0] = in[0];
-    for (uint i = 1; i < n; ++i) {
-        out[i] = fn_mul(out[i - 1], in[i]);
-    }
-    uint256 inv = fn_inv(out[n - 1]);
-    for (uint k = n; k > 1; --k) {
-        uint i = k - 1;
-        uint256 t = fn_mul(inv, out[i - 1]);
-        inv = fn_mul(inv, in[i]);
-        out[i] = t;
-    }
-    out[0] = inv;
-}
diff --git a/secp256k1/gpu/metal/secp256k1_batch_inv_driver.mm b/secp256k1/gpu/metal/secp256k1_batch_inv_driver.mm
deleted file mode 100644
index 583f57c..0000000
--- a/secp256k1/gpu/metal/secp256k1_batch_inv_driver.mm
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright (c) 2024-2026 Lux Industries Inc.
-// SPDX-License-Identifier: BSD-3-Clause-Eco
-//
-// Metal driver for Stage A (Montgomery batch inversion) of the v0.63 ecrecover
-// pipeline. macOS / iOS only.
-//
-// Loads the precompiled secp256k1_batch_inv.metallib, dispatches a single-
-// thread kernel that performs the n-element prefix product / inversion /
-// backward sweep. Single-thread execution on Metal preserves byte-equality
-// with the CPU implementation while still freeing the host CPU for other
-// pipeline stages.
-
-#if __APPLE__ && __OBJC__
-
-#import <Foundation/Foundation.h>
-#import <Metal/Metal.h>
-
-#include "lux/crypto/secp256k1.h"
-#include <cstddef>
-#include <cstdint>
-#include <cstring>
-#include <limits>
-
-namespace {
-
-// Host-side mirror of the shader's 256-bit limb struct; used for sizing only.
-struct U256GPU { uint64_t limbs[4]; };
-
-}  // namespace
-
-// Runs the single-threaded batch-inversion kernel on the default Metal device.
-//
-// Returns:
-//    0  success (also when n == 0, in which case nothing is touched)
-//   -1  null pointer argument, or n too large for the kernel's 32-bit count /
-//       for a size_t byte length
-//   -2  no Metal device
-//   -3  metallib path is not valid UTF-8 or the library failed to load
-//   -4  kernel entry point missing from the library
-//   -5  compute pipeline creation failed
-//   -6  command queue / buffer / encoder allocation failed
-//   -7  the command buffer did not complete successfully
-extern "C" int secp256k1_batch_inv_metal(
-    const uint8_t* in_mont,      // n * 32 bytes (Mont-form, limb little-endian)
-    size_t n,
-    uint8_t* out_mont,           // n * 32 bytes
-    int kind,                    // 0 = Fp, 1 = Fn (any non-zero value selects Fn)
-    const char* metallib_path) {
-
-    if (n == 0) return 0;
-    if (!in_mont || !out_mont || !metallib_path) return -1;
-    // The kernel receives the count as a 32-bit uint; refuse to truncate it,
-    // otherwise the GPU would process fewer elements than get copied back.
-    if (n > std::numeric_limits<uint32_t>::max()) return -1;
-    if (n > std::numeric_limits<size_t>::max() / sizeof(U256GPU)) return -1;
-
-    @autoreleasepool {
-        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-        if (!device) return -2;
-
-        NSError* err = nil;
-        NSString* path = [NSString stringWithUTF8String:metallib_path];
-        // stringWithUTF8String: yields nil for malformed UTF-8, and
-        // fileURLWithPath: must not be handed nil.
-        if (!path) return -3;
-        NSURL* url = [NSURL fileURLWithPath:path];
-        id<MTLLibrary> lib = [device newLibraryWithURL:url error:&err];
-        if (!lib) return -3;
-
-        NSString* entry = (kind == 0) ? @"secp256k1_batch_inv_fp"
-                                      : @"secp256k1_batch_inv_fn";
-        id<MTLFunction> fn = [lib newFunctionWithName:entry];
-        if (!fn) return -4;
-
-        id<MTLComputePipelineState> pipeline =
-            [device newComputePipelineStateWithFunction:fn error:&err];
-        if (!pipeline) return -5;
-
-        id<MTLCommandQueue> queue = [device newCommandQueue];
-        if (!queue) return -6;
-
-        size_t bytes = n * sizeof(U256GPU);
-        id<MTLBuffer> in_buf  = [device newBufferWithBytes:in_mont
-                                                    length:bytes
-                                                   options:MTLResourceStorageModeShared];
-        id<MTLBuffer> out_buf = [device newBufferWithLength:bytes
-                                                    options:MTLResourceStorageModeShared];
-        uint32_t n_u32 = (uint32_t)n;
-        id<MTLBuffer> n_buf   = [device newBufferWithBytes:&n_u32
-                                                    length:sizeof(n_u32)
-                                                   options:MTLResourceStorageModeShared];
-        // Any of these can be nil under memory pressure; reading
-        // [out_buf contents] from nil would hand memcpy a null source.
-        if (!in_buf || !out_buf || !n_buf) return -6;
-
-        id<MTLCommandBuffer> cmd = [queue commandBuffer];
-        id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
-        if (!cmd || !enc) return -6;
-        [enc setComputePipelineState:pipeline];
-        [enc setBuffer:in_buf  offset:0 atIndex:0];
-        [enc setBuffer:out_buf offset:0 atIndex:1];
-        [enc setBuffer:n_buf   offset:0 atIndex:2];
-
-        // Single-thread dispatch: byte-equal determinism.
-        MTLSize threads_per_grid = MTLSizeMake(1, 1, 1);
-        MTLSize threads_per_tg   = MTLSizeMake(1, 1, 1);
-        [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg];
-        [enc endEncoding];
-        [cmd commit];
-        [cmd waitUntilCompleted];
-
-        // waitUntilCompleted also returns when the GPU faulted or timed out;
-        // only a Completed status means out_buf holds real results.
-        if ([cmd status] != MTLCommandBufferStatusCompleted) return -7;
-
-        std::memcpy(out_mont, [out_buf contents], bytes);
-    }
-    return 0;
-}
-
-#endif  // __APPLE__ && __OBJC__
diff --git a/secp256k1/gpu/metal/secp256k1_driver.h b/secp256k1/gpu/metal/secp256k1_driver.h
deleted file mode 100644
index 4e71324..0000000
--- a/secp256k1/gpu/metal/secp256k1_driver.h
+++ /dev/null
@@ -1,527 +0,0 @@
-// =============================================================================
-// Lux Crypto Library - secp256k1 GPU Acceleration
-// =============================================================================
-//
-// GPU-accelerated secp256k1 elliptic curve operations using Metal/CUDA.
-// Implements precomputed GTable approach from CudaBrainSecp for ~20x speedup.
-//
-// Use cases:
-//   - ECDSA signature verification (batch)
-//   - Schnorr signatures (BIP340)
-//   - Threshold ECDSA (CGGMP21/FROST/LSS)
-//   - Address derivation from public keys
-//
-// The GTable approach:
-//   - Precomputes 16 chunks × 65536 points each (~67MB table)
-//   - Scalar multiplication via 16 lookups + 15 additions
-//   - Much faster than double-and-add for random scalars
-//
-// Copyright (C) 2024-2025 Lux Industries Inc.
-// SPDX-License-Identifier: Apache-2.0 - -#pragma once -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ============================================================================= -// Constants -// ============================================================================= - -// secp256k1 curve parameters -// Prime: p = 2^256 - 2^32 - 977 -// Order: n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 - -#define SECP256K1_FIELD_SIZE 32 // 256 bits = 32 bytes -#define SECP256K1_SCALAR_SIZE 32 // Scalar field size -#define SECP256K1_PUBKEY_SIZE 33 // Compressed public key -#define SECP256K1_PUBKEY_UNCOMPRESSED_SIZE 65 // Uncompressed public key -#define SECP256K1_SIGNATURE_SIZE 64 // ECDSA signature (r, s) -#define SECP256K1_RECOVERABLE_SIG_SIZE 65 // Recoverable signature (r, s, v) - -// GTable precomputation parameters (from CudaBrainSecp) -#define SECP256K1_GTABLE_CHUNKS 16 // 256 bits / 16 = 16 bits per chunk -#define SECP256K1_GTABLE_CHUNK_SIZE 65536 // 2^16 points per chunk -#define SECP256K1_GTABLE_TOTAL_POINTS (SECP256K1_GTABLE_CHUNKS * SECP256K1_GTABLE_CHUNK_SIZE) - -// ============================================================================= -// Types -// ============================================================================= - -/** - * 256-bit field element (modular arithmetic mod p) - * Stored in little-endian limb order - */ -typedef struct { - uint64_t limbs[4]; -} Secp256k1Fp; - -/** - * 256-bit scalar (modular arithmetic mod n) - */ -typedef struct { - uint64_t limbs[4]; -} Secp256k1Scalar; - -/** - * Affine point on secp256k1 - * For point at infinity: infinity = true - */ -typedef struct { - Secp256k1Fp x; - Secp256k1Fp y; - bool infinity; -} Secp256k1Affine; - -/** - * Jacobian projective point on secp256k1 - * (x, y, z) represents affine (x/z^2, y/z^3) - */ -typedef struct { - Secp256k1Fp x; - Secp256k1Fp y; - Secp256k1Fp z; -} Secp256k1Jacobian; - -/** - * ECDSA signature (r, s) - */ 
-typedef struct { - Secp256k1Scalar r; - Secp256k1Scalar s; -} Secp256k1Signature; - -/** - * Recoverable ECDSA signature (r, s, recovery_id) - */ -typedef struct { - Secp256k1Scalar r; - Secp256k1Scalar s; - uint8_t recovery_id; // 0-3 -} Secp256k1RecoverableSignature; - -/** - * Opaque GPU context handle - */ -typedef struct MetalSecp256k1Context MetalSecp256k1Context; - -// ============================================================================= -// Context Management -// ============================================================================= - -/** - * Create GPU context with precomputed GTable. - * This allocates ~67MB of GPU memory for the precomputed table. - * @param device_id GPU device ID (0 for default) - * @return Context handle, or NULL on error - */ -MetalSecp256k1Context* metal_secp256k1_create(int device_id); - -/** - * Destroy GPU context and free resources. - */ -void metal_secp256k1_destroy(MetalSecp256k1Context* ctx); - -/** - * Check if GPU acceleration is available for secp256k1. - * @return true if Metal/CUDA is available - */ -bool metal_secp256k1_gpu_available(void); - -/** - * Get GPU memory usage for secp256k1 operations. - * @param ctx Context handle - * @return Memory usage in bytes - */ -size_t metal_secp256k1_memory_usage(MetalSecp256k1Context* ctx); - -// ============================================================================= -// Scalar Multiplication (GTable-accelerated) -// ============================================================================= - -/** - * Compute scalar multiplication: result = scalar * G - * Uses precomputed GTable for ~20x speedup over double-and-add. 
- * - * @param ctx GPU context with precomputed GTable - * @param result Output point (affine) - * @param scalar 256-bit scalar - * @return 0 on success - */ -int metal_secp256k1_scalar_mul_g( - MetalSecp256k1Context* ctx, - Secp256k1Affine* result, - const Secp256k1Scalar* scalar -); - -/** - * Batch scalar multiplication: results[i] = scalars[i] * G - * Highly parallelized on GPU. - * - * @param ctx GPU context - * @param results Output points (caller allocates count elements) - * @param scalars Array of scalars - * @param count Number of multiplications - * @return 0 on success - */ -int metal_secp256k1_batch_scalar_mul_g( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Scalar* scalars, - uint32_t count -); - -/** - * General scalar multiplication: result = scalar * point - * Uses double-and-add for arbitrary base points. - * For base point G, use metal_secp256k1_scalar_mul_g instead. - * - * @param ctx GPU context - * @param result Output point - * @param scalar 256-bit scalar - * @param point Base point - * @return 0 on success - */ -int metal_secp256k1_scalar_mul( - MetalSecp256k1Context* ctx, - Secp256k1Affine* result, - const Secp256k1Scalar* scalar, - const Secp256k1Affine* point -); - -/** - * Batch general scalar multiplication: results[i] = scalars[i] * points[i] - * - * @param ctx GPU context - * @param results Output points - * @param scalars Array of scalars - * @param points Array of base points - * @param count Number of multiplications - * @return 0 on success - */ -int metal_secp256k1_batch_scalar_mul( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Scalar* scalars, - const Secp256k1Affine* points, - uint32_t count -); - -// ============================================================================= -// Multi-Scalar Multiplication (MSM) -// ============================================================================= - -/** - * Multi-scalar multiplication: result = sum(scalars[i] * 
points[i]) - * Uses Pippenger's algorithm with GPU parallelization. - * - * @param ctx GPU context - * @param result Output point - * @param points Array of base points - * @param scalars Array of scalars - * @param count Number of terms - * @return 0 on success - */ -int metal_secp256k1_msm( - MetalSecp256k1Context* ctx, - Secp256k1Affine* result, - const Secp256k1Affine* points, - const Secp256k1Scalar* scalars, - uint32_t count -); - -// ============================================================================= -// ECDSA Operations -// ============================================================================= - -/** - * Batch ECDSA signature verification. - * Verifies multiple (message, signature, public_key) tuples in parallel. - * - * @param ctx GPU context - * @param results Output array of verification results (1=valid, 0=invalid) - * @param messages Array of 32-byte message hashes - * @param signatures Array of signatures - * @param public_keys Array of public keys - * @param count Number of signatures - * @return 0 on success, negative on error - */ -int metal_secp256k1_batch_verify( - MetalSecp256k1Context* ctx, - int* results, - const uint8_t* const* messages, // Each 32 bytes - const Secp256k1Signature* signatures, - const Secp256k1Affine* public_keys, - uint32_t count -); - -/** - * Batch ECDSA signing. - * Signs multiple messages with multiple secret keys in parallel. - * - * @param ctx GPU context - * @param signatures Output signatures - * @param messages Array of 32-byte message hashes - * @param secret_keys Array of secret keys (scalars) - * @param count Number of signatures - * @return 0 on success - */ -int metal_secp256k1_batch_sign( - MetalSecp256k1Context* ctx, - Secp256k1Signature* signatures, - const uint8_t* const* messages, - const Secp256k1Scalar* secret_keys, - uint32_t count -); - -/** - * Batch public key recovery from signatures. - * Recovers public keys from recoverable signatures in parallel. 
- * - * @param ctx GPU context - * @param public_keys Output public keys - * @param messages Array of 32-byte message hashes - * @param signatures Array of recoverable signatures - * @param count Number of recoveries - * @return 0 on success - */ -int metal_secp256k1_batch_recover( - MetalSecp256k1Context* ctx, - Secp256k1Affine* public_keys, - const uint8_t* const* messages, - const Secp256k1RecoverableSignature* signatures, - uint32_t count -); - -// ============================================================================= -// Schnorr (BIP340) Operations -// ============================================================================= - -/** - * Batch Schnorr signature verification (BIP340). - * - * @param ctx GPU context - * @param results Output verification results - * @param messages Array of 32-byte message hashes - * @param signatures Array of 64-byte Schnorr signatures - * @param public_keys Array of 32-byte x-only public keys - * @param count Number of signatures - * @return 0 on success - */ -int metal_secp256k1_schnorr_batch_verify( - MetalSecp256k1Context* ctx, - int* results, - const uint8_t* const* messages, - const uint8_t* const* signatures, - const uint8_t* const* public_keys, - uint32_t count -); - -// ============================================================================= -// Key Derivation -// ============================================================================= - -/** - * Batch derive public keys from secret keys: pk[i] = sk[i] * G - * Uses GTable for efficient derivation. - * - * @param ctx GPU context - * @param public_keys Output public keys - * @param secret_keys Array of secret keys - * @param count Number of keys - * @return 0 on success - */ -int metal_secp256k1_batch_derive_pubkey( - MetalSecp256k1Context* ctx, - Secp256k1Affine* public_keys, - const Secp256k1Scalar* secret_keys, - uint32_t count -); - -/** - * Batch derive Ethereum addresses from public keys. - * Computes keccak256(pubkey)[12:] for each public key. 
- * - * @param ctx GPU context - * @param addresses Output addresses (20 bytes each) - * @param public_keys Array of public keys - * @param count Number of addresses - * @return 0 on success - */ -int metal_secp256k1_batch_derive_address( - MetalSecp256k1Context* ctx, - uint8_t* addresses, // count * 20 bytes - const Secp256k1Affine* public_keys, - uint32_t count -); - -// ============================================================================= -// Threshold ECDSA Support -// ============================================================================= - -/** - * Batch nonce generation for threshold ECDSA. - * Generates k values and computes R = k * G. - * - * @param ctx GPU context - * @param r_points Output R points - * @param k_values Output nonce values - * @param entropy Random entropy for nonce generation - * @param count Number of nonces - * @return 0 on success - */ -int metal_secp256k1_batch_nonce_gen( - MetalSecp256k1Context* ctx, - Secp256k1Affine* r_points, - Secp256k1Scalar* k_values, - const uint8_t* entropy, - uint32_t count -); - -/** - * Batch partial signature combination. - * Combines threshold signature shares into final signatures. 
- * - * @param ctx GPU context - * @param signatures Output combined signatures - * @param partial_sigs Array of partial signature arrays - * @param num_shares Number of shares per signature - * @param count Number of signatures - * @return 0 on success - */ -int metal_secp256k1_combine_partial_sigs( - MetalSecp256k1Context* ctx, - Secp256k1Signature* signatures, - const Secp256k1Scalar* const* partial_sigs, - uint32_t num_shares, - uint32_t count -); - -// ============================================================================= -// Point Arithmetic (for custom protocols) -// ============================================================================= - -/** - * Batch point addition: results[i] = a[i] + b[i] - */ -int metal_secp256k1_batch_add( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Affine* a, - const Secp256k1Affine* b, - uint32_t count -); - -/** - * Batch point doubling: results[i] = 2 * points[i] - */ -int metal_secp256k1_batch_double( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Affine* points, - uint32_t count -); - -/** - * Batch point negation: results[i] = -points[i] - */ -int metal_secp256k1_batch_negate( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Affine* points, - uint32_t count -); - -// ============================================================================= -// Serialization -// ============================================================================= - -/** - * Serialize affine point to compressed public key (33 bytes). - */ -void secp256k1_affine_serialize_compressed( - uint8_t* out, // 33 bytes - const Secp256k1Affine* point -); - -/** - * Serialize affine point to uncompressed public key (65 bytes). - */ -void secp256k1_affine_serialize_uncompressed( - uint8_t* out, // 65 bytes - const Secp256k1Affine* point -); - -/** - * Deserialize compressed public key to affine point. 
- * @return 0 on success, -1 if invalid - */ -int secp256k1_affine_deserialize_compressed( - Secp256k1Affine* point, - const uint8_t* in // 33 bytes -); - -/** - * Deserialize uncompressed public key to affine point. - * @return 0 on success, -1 if invalid - */ -int secp256k1_affine_deserialize_uncompressed( - Secp256k1Affine* point, - const uint8_t* in // 65 bytes -); - -// ============================================================================= -// GTable Utilities -// ============================================================================= - -/** - * Precompute GTable for custom base point. - * Useful for protocols with fixed base points other than G. - * - * @param ctx GPU context - * @param table_id Output table ID for later use - * @param base Base point for table - * @return 0 on success - */ -int metal_secp256k1_precompute_table( - MetalSecp256k1Context* ctx, - uint32_t* table_id, - const Secp256k1Affine* base -); - -/** - * Scalar multiply using custom precomputed table. - */ -int metal_secp256k1_scalar_mul_table( - MetalSecp256k1Context* ctx, - Secp256k1Affine* result, - uint32_t table_id, - const Secp256k1Scalar* scalar -); - -/** - * Free custom precomputed table. 
- */ -void metal_secp256k1_free_table( - MetalSecp256k1Context* ctx, - uint32_t table_id -); - -// ============================================================================= -// Error Codes -// ============================================================================= - -#define SECP256K1_SUCCESS 0 -#define SECP256K1_ERROR_NULL_PTR -1 -#define SECP256K1_ERROR_INVALID -2 -#define SECP256K1_ERROR_GPU -3 -#define SECP256K1_ERROR_MEMORY -4 -#define SECP256K1_ERROR_NOT_ON_CURVE -5 - -#ifdef __cplusplus -} -#endif diff --git a/secp256k1/gpu/metal/secp256k1_driver.mm b/secp256k1/gpu/metal/secp256k1_driver.mm deleted file mode 100644 index af20505..0000000 --- a/secp256k1/gpu/metal/secp256k1_driver.mm +++ /dev/null @@ -1,1087 +0,0 @@ -// ============================================================================= -// Metal secp256k1 - GPU Acceleration Implementation -// ============================================================================= -// -// Objective-C++ implementation for Metal compute shader dispatch. -// Uses GTable precomputation for ~20x faster scalar multiplication. -// -// Copyright (C) 2024-2025 Lux Industries Inc. 
-// SPDX-License-Identifier: Apache-2.0
-
-#import <Foundation/Foundation.h>
-#import <Metal/Metal.h>
-#include "lux/crypto/metal_secp256k1.h"
-#include <cstddef>
-#include <cstring>
-#include <vector>
-
-// =============================================================================
-// Metal Context Structure
-// =============================================================================
-
-// Every member carries a default initializer, so `new MetalSecp256k1Context()`
-// yields a fully zeroed context without memset (which is undefined behaviour
-// on a type holding a std::vector and ARC-managed object pointers).
-struct MetalSecp256k1Context {
-    id<MTLDevice> device = nil;
-    id<MTLCommandQueue> commandQueue = nil;
-    id<MTLLibrary> library = nil;
-
-    // Compute pipeline states (nil when the shader function is absent)
-    id<MTLComputePipelineState> pipelinePrecomputeGtable = nil;
-    id<MTLComputePipelineState> pipelineGtableScalarMul = nil;
-    id<MTLComputePipelineState> pipelineBatchVerify = nil;
-    id<MTLComputePipelineState> pipelineBatchDeriveAddress = nil;
-    id<MTLComputePipelineState> pipelineParallelReduce = nil;
-    id<MTLComputePipelineState> pipelineScalarMul = nil;
-    id<MTLComputePipelineState> pipelineBatchAdd = nil;
-    id<MTLComputePipelineState> pipelineBatchDouble = nil;
-
-    // Precomputed GTable; gtableReady is true only after the precompute
-    // kernel has completed successfully.
-    id<MTLBuffer> gtableBuffer = nil;
-    bool gtableReady = false;
-
-    // Custom tables for other base points
-    std::vector<id<MTLBuffer>> customTables;
-
-    // Memory tracking (bytes of GPU buffers owned by the context)
-    size_t totalMemory = 0;
-};
-
-// =============================================================================
-// Helper: Create Pipeline
-// =============================================================================
-
-// Builds a compute pipeline for the shader function `name` in ctx->library.
-// Logs and returns nil when the function is missing or compilation fails.
-static id<MTLComputePipelineState> createPipeline(
-    MetalSecp256k1Context* ctx,
-    const char* name)
-{
-    id<MTLFunction> func = [ctx->library newFunctionWithName:
-        [NSString stringWithUTF8String:name]];
-    if (!func) {
-        NSLog(@"Metal secp256k1: Function '%s' not found", name);
-        return nil;
-    }
-
-    NSError* error = nil;
-    id<MTLComputePipelineState> pipeline =
-        [ctx->device newComputePipelineStateWithFunction:func error:&error];
-    if (!pipeline) {
-        NSLog(@"Metal secp256k1: Failed to create pipeline '%s': %@",
-              name, error.localizedDescription);
-    }
-    return pipeline;
-}
-
-// Commits `commandBuffer`, blocks until it finishes and reports whether it
-// actually completed. waitUntilCompleted also returns after a GPU fault or
-// timeout, so the status has to be inspected before results are trusted.
-static bool commitAndWait(id<MTLCommandBuffer> commandBuffer) {
-    if (!commandBuffer) return false;
-    [commandBuffer commit];
-    [commandBuffer waitUntilCompleted];
-    if (commandBuffer.status != MTLCommandBufferStatusCompleted) {
-        NSLog(@"Metal secp256k1: Command buffer failed: %@",
-              commandBuffer.error ? commandBuffer.error.localizedDescription
-                                  : @"Unknown error");
-        return false;
-    }
-    return true;
-}
-
-// =============================================================================
-// Initialization
-// =============================================================================
-
-// True when a default Metal device exists on this machine.
-extern "C" bool metal_secp256k1_gpu_available(void) {
-    @autoreleasepool {
-        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-        return device != nil;
-    }
-}
-
-// Creates a context: picks the device, loads (or compiles) the shader
-// library, builds the pipelines, allocates the GTable and runs the precompute
-// kernel. Returns nullptr when the device, queue, library, the GTable
-// scalar-mul pipeline or the GTable buffer cannot be obtained. A missing or
-// failed precompute leaves the context usable with gtableReady == false.
-extern "C" MetalSecp256k1Context* metal_secp256k1_create(int device_id) {
-    @autoreleasepool {
-        MetalSecp256k1Context* ctx = new MetalSecp256k1Context();
-
-        // Get Metal device
-        if (device_id == 0) {
-            ctx->device = MTLCreateSystemDefaultDevice();
-        } else {
-            // For multi-GPU systems. A negative id converts to a huge
-            // NSUInteger, fails the bounds check and leaves device nil.
-            NSArray<id<MTLDevice>>* devices = MTLCopyAllDevices();
-            if ((NSUInteger)device_id < devices.count) {
-                ctx->device = devices[device_id];
-            }
-        }
-
-        if (!ctx->device) {
-            delete ctx;
-            return nullptr;
-        }
-
-        // Create command queue
-        ctx->commandQueue = [ctx->device newCommandQueue];
-        if (!ctx->commandQueue) {
-            delete ctx;
-            return nullptr;
-        }
-
-        // Load Metal library
-        NSError* error = nil;
-
-        // Try loading pre-compiled metallib first
-        NSArray<NSString*>* metallibPaths = @[
-            @"/usr/local/share/lux/crypto/lux_crypto.metallib",
-            @"/usr/local/share/lux/crypto/secp256k1.metallib",
-            [[NSBundle mainBundle] pathForResource:@"lux_crypto" ofType:@"metallib"] ?: @"",
-            [[NSBundle mainBundle] pathForResource:@"secp256k1" ofType:@"metallib"] ?: @""
-        ];
-
-        for (NSString* libPath in metallibPaths) {
-            if (libPath.length > 0 && [[NSFileManager defaultManager] fileExistsAtPath:libPath]) {
-                NSURL* libURL = [NSURL fileURLWithPath:libPath];
-                ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-                if (ctx->library) break;
-            }
-        }
-
-        // Fall back to compiling from source
-        if (!ctx->library) {
-            NSArray<NSString*>* searchPaths = @[
-                @"src/metal/secp256k1.metal",
-                @"../src/metal/secp256k1.metal",
-                @"metal/secp256k1.metal",
-                @"/usr/local/share/lux/crypto/secp256k1.metal"
-            ];
-
-            for (NSString* path in searchPaths) {
-                if ([[NSFileManager defaultManager] fileExistsAtPath:path]) {
-                    NSString* source = [NSString stringWithContentsOfFile:path
-                                                                 encoding:NSUTF8StringEncoding
-                                                                    error:&error];
-                    if (source) {
-                        MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
-                        if (@available(macOS 15.0, *)) {
-                            options.mathMode = MTLMathModeFast;
-                        } else {
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wdeprecated-declarations"
-                            options.fastMathEnabled = YES;
-#pragma clang diagnostic pop
-                        }
-
-                        ctx->library = [ctx->device newLibraryWithSource:source
-                                                                 options:options
-                                                                   error:&error];
-                        if (ctx->library) break;
-                    }
-                }
-            }
-        }
-
-        if (!ctx->library) {
-            NSLog(@"Metal secp256k1: Failed to load shader library: %@",
-                  error ? error.localizedDescription : @"Unknown error");
-            delete ctx;
-            return nullptr;
-        }
-
-        // Create compute pipeline states
-        ctx->pipelinePrecomputeGtable = createPipeline(ctx, "precompute_gtable");
-        ctx->pipelineGtableScalarMul = createPipeline(ctx, "gtable_scalar_mul");
-        ctx->pipelineBatchVerify = createPipeline(ctx, "batch_verify");
-        ctx->pipelineBatchDeriveAddress = createPipeline(ctx, "batch_derive_address");
-        ctx->pipelineParallelReduce = createPipeline(ctx, "parallel_point_reduce");
-        ctx->pipelineScalarMul = createPipeline(ctx, "scalar_mul");
-        ctx->pipelineBatchAdd = createPipeline(ctx, "batch_point_add");
-        ctx->pipelineBatchDouble = createPipeline(ctx, "batch_point_double");
-
-        // At minimum we need GTable scalar mul
-        if (!ctx->pipelineGtableScalarMul) {
-            delete ctx;
-            return nullptr;
-        }
-
-        // Allocate GTable buffer: 16 chunks × 65536 points, one host-side
-        // Secp256k1Affine per point. The size comes from sizeof rather than an
-        // assumed 64 bytes per point.
-        // NOTE(review): the shader's point layout must match the host struct
-        // (including its `infinity` flag and padding) — confirm in the .metal.
-        size_t gtableSize = SECP256K1_GTABLE_TOTAL_POINTS * sizeof(Secp256k1Affine);
-        ctx->gtableBuffer = [ctx->device newBufferWithLength:gtableSize
-                                                     options:MTLResourceStorageModeShared];
-        if (!ctx->gtableBuffer) {
-            NSLog(@"Metal secp256k1: Failed to allocate GTable (%.1f MB)",
-                  gtableSize / (1024.0 * 1024.0));
-            delete ctx;
-            return nullptr;
-        }
-        ctx->totalMemory = gtableSize;
-
-        // Precompute GTable on GPU
-        if (ctx->pipelinePrecomputeGtable) {
-            id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
-            id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
-
-            [encoder setComputePipelineState:ctx->pipelinePrecomputeGtable];
-            [encoder setBuffer:ctx->gtableBuffer offset:0 atIndex:0];
-
-            // Launch enough threads to precompute all points
-            NSUInteger threadsPerGroup = MIN(256UL,
-                ctx->pipelinePrecomputeGtable.maxTotalThreadsPerThreadgroup);
-
-            // We compute all points - shader handles the index calculation
-            [encoder dispatchThreads:MTLSizeMake(SECP256K1_GTABLE_TOTAL_POINTS, 1, 1)
-               threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)];
-            [encoder endEncoding];
-
-            // Mark the table ready only if the GPU really finished the job;
-            // a failed precompute would otherwise feed garbage to every
-            // GTable-based multiplication.
-            ctx->gtableReady = commitAndWait(commandBuffer);
-        } else {
-            // Fall back to CPU precomputation (placeholder - would need implementation)
-            NSLog(@"Metal secp256k1: GTable precompute shader not found");
-            ctx->gtableReady = false;
-        }
-
-        return ctx;
-    }
-}
-
-// Releases everything owned by `ctx`; a null context is ignored.
-extern "C" void metal_secp256k1_destroy(MetalSecp256k1Context* ctx) {
-    if (!ctx) return;
-
-    @autoreleasepool {
-        // Release custom tables
-        ctx->customTables.clear();
-
-        // ARC handles release of Objective-C objects
-        ctx->gtableBuffer = nil;
-        ctx->pipelinePrecomputeGtable = nil;
-        ctx->pipelineGtableScalarMul = nil;
-        ctx->pipelineBatchVerify = nil;
-        ctx->pipelineBatchDeriveAddress = nil;
-        ctx->pipelineParallelReduce = nil;
-        ctx->pipelineScalarMul = nil;
-        ctx->pipelineBatchAdd = nil;
-        ctx->pipelineBatchDouble = nil;
-        ctx->library = nil;
-        ctx->commandQueue = nil;
-        ctx->device = nil;
-    }
-
-    delete ctx;
-}
-
-// Bytes of GPU memory tracked for this context (0 for a null context).
-extern "C" size_t metal_secp256k1_memory_usage(MetalSecp256k1Context* ctx) {
-    if (!ctx) return 0;
-    return ctx->totalMemory;
-}
-
-// =============================================================================
-// Helper: Create Buffers
-// =============================================================================
-
-// Uninitialised shared-storage buffer of `size` bytes (nil on failure).
-static id<MTLBuffer> createBuffer(MetalSecp256k1Context* ctx, size_t size) {
-    return [ctx->device newBufferWithLength:size
-                                    options:MTLResourceStorageModeShared];
-}
-
-// Shared-storage buffer initialised with a copy of `data` (nil on failure).
-static id<MTLBuffer> createBufferWithData(MetalSecp256k1Context* ctx,
-                                           const void* data, size_t size) {
-    return [ctx->device newBufferWithBytes:data
-                                    length:size
-                                   options:MTLResourceStorageModeShared];
-}
-
-// Shared dispatch path for the scalar-multiplication kernels: binds
-// buffer0..buffer2 at indices 0..2, passes `count` as inline bytes at index 3,
-// launches `count` threads in groups of at most `groupCap`, then waits.
-// Returns false if encoding could not start or the GPU did not complete.
-static bool dispatchCountKernel(MetalSecp256k1Context* ctx,
-                                id<MTLComputePipelineState> pipeline,
-                                id<MTLBuffer> buffer0,
-                                id<MTLBuffer> buffer1,
-                                id<MTLBuffer> buffer2,
-                                uint32_t count,
-                                NSUInteger groupCap) {
-    id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
-    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
-    if (!commandBuffer || !encoder) return false;
-
-    [encoder setComputePipelineState:pipeline];
-    [encoder setBuffer:buffer0 offset:0 atIndex:0];
-    [encoder setBuffer:buffer1 offset:0 atIndex:1];
-    [encoder setBuffer:buffer2 offset:0 atIndex:2];
-    [encoder setBytes:&count length:sizeof(count) atIndex:3];
-
-    NSUInteger threadsPerGroup = MIN(groupCap,
-        pipeline.maxTotalThreadsPerThreadgroup);
-
-    [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
-       threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)];
-    [encoder endEncoding];
-
-    return commitAndWait(commandBuffer);
-}
-
-// =============================================================================
-// Scalar Multiplication (GTable-accelerated)
-// =============================================================================
-
-// result = scalar * G via the precomputed table.
-// Errors: NULL_PTR for null arguments; GPU when the table or pipeline is
-// unavailable or the dispatch fails; MEMORY when buffers cannot be allocated.
-extern "C" int metal_secp256k1_scalar_mul_g(
-    MetalSecp256k1Context* ctx,
-    Secp256k1Affine* result,
-    const Secp256k1Scalar* scalar)
-{
-    if (!ctx || !result || !scalar) {
-        return SECP256K1_ERROR_NULL_PTR;
-    }
-
-    if (!ctx->gtableReady || !ctx->pipelineGtableScalarMul) {
-        return SECP256K1_ERROR_GPU;
-    }
-
-    @autoreleasepool {
-        id<MTLBuffer> scalarBuffer = createBufferWithData(ctx, scalar, sizeof(Secp256k1Scalar));
-        id<MTLBuffer> resultBuffer = createBuffer(ctx, sizeof(Secp256k1Affine));
-
-        if (!scalarBuffer || !resultBuffer) {
-            return SECP256K1_ERROR_MEMORY;
-        }
-
-        if (!dispatchCountKernel(ctx, ctx->pipelineGtableScalarMul,
-                                 ctx->gtableBuffer, scalarBuffer, resultBuffer,
-                                 1, 1)) {
-            return SECP256K1_ERROR_GPU;
-        }
-
-        memcpy(result, [resultBuffer contents], sizeof(Secp256k1Affine));
-
-        return SECP256K1_SUCCESS;
-    }
-}
-
-// results[i] = scalars[i] * G for i in [0, count).
-// A count of 0 is rejected with NULL_PTR, as are null arguments.
-extern "C" int metal_secp256k1_batch_scalar_mul_g(
-    MetalSecp256k1Context* ctx,
-    Secp256k1Affine* results,
-    const Secp256k1Scalar* scalars,
-    uint32_t count)
-{
-    if (!ctx || !results || !scalars || count == 0) {
-        return SECP256K1_ERROR_NULL_PTR;
-    }
-
-    if (!ctx->gtableReady || !ctx->pipelineGtableScalarMul) {
-        return SECP256K1_ERROR_GPU;
-    }
-
-    @autoreleasepool {
-        size_t scalarSize = count * sizeof(Secp256k1Scalar);
-        size_t resultSize = count * sizeof(Secp256k1Affine);
-
-        id<MTLBuffer> scalarBuffer = createBufferWithData(ctx, scalars, scalarSize);
-        id<MTLBuffer> resultBuffer = createBuffer(ctx, resultSize);
-
-        if (!scalarBuffer || !resultBuffer) {
-            return SECP256K1_ERROR_MEMORY;
-        }
-
-        if (!dispatchCountKernel(ctx, ctx->pipelineGtableScalarMul,
-                                 ctx->gtableBuffer, scalarBuffer, resultBuffer,
-                                 count, 256UL)) {
-            return SECP256K1_ERROR_GPU;
-        }
-
-        memcpy(results, [resultBuffer contents], resultSize);
-
-        return SECP256K1_SUCCESS;
-    }
-}
-
-// result = scalar * point for an arbitrary base point (scalar_mul kernel).
-extern "C" int metal_secp256k1_scalar_mul(
-    MetalSecp256k1Context* ctx,
-    Secp256k1Affine* result,
-    const Secp256k1Scalar* scalar,
-    const Secp256k1Affine* point)
-{
-    if (!ctx || !result || !scalar || !point) {
-        return SECP256K1_ERROR_NULL_PTR;
-    }
-
-    if (!ctx->pipelineScalarMul) {
-        return SECP256K1_ERROR_GPU;
-    }
-
-    @autoreleasepool {
-        id<MTLBuffer> scalarBuffer = createBufferWithData(ctx, scalar, sizeof(Secp256k1Scalar));
-        id<MTLBuffer> pointBuffer = createBufferWithData(ctx, point, sizeof(Secp256k1Affine));
-        id<MTLBuffer> resultBuffer = createBuffer(ctx, sizeof(Secp256k1Affine));
-
-        if (!scalarBuffer || !pointBuffer || !resultBuffer) {
-            return SECP256K1_ERROR_MEMORY;
-        }
-
-        if (!dispatchCountKernel(ctx, ctx->pipelineScalarMul,
-                                 pointBuffer, scalarBuffer, resultBuffer,
-                                 1, 1)) {
-            return SECP256K1_ERROR_GPU;
-        }
-
-        memcpy(result, [resultBuffer contents], sizeof(Secp256k1Affine));
-
-        return SECP256K1_SUCCESS;
-    }
-}
-
-// results[i] = scalars[i] * points[i] for i in [0, count).
-// A count of 0 is rejected with NULL_PTR, as are null arguments.
-extern "C" int metal_secp256k1_batch_scalar_mul(
-    MetalSecp256k1Context* ctx,
-    Secp256k1Affine* results,
-    const Secp256k1Scalar* scalars,
-    const Secp256k1Affine* points,
-    uint32_t count)
-{
-    if (!ctx || !results || !scalars || !points || count == 0) {
-        return SECP256K1_ERROR_NULL_PTR;
-    }
-
-    if (!ctx->pipelineScalarMul) {
-        return SECP256K1_ERROR_GPU;
-    }
-
-    @autoreleasepool {
-        size_t scalarSize = count * sizeof(Secp256k1Scalar);
-        size_t pointSize = count * sizeof(Secp256k1Affine);
-
-        id<MTLBuffer> scalarBuffer = createBufferWithData(ctx, scalars, scalarSize);
-        id<MTLBuffer> pointBuffer = createBufferWithData(ctx, points, pointSize);
-        id<MTLBuffer> resultBuffer = createBuffer(ctx, pointSize);
-
-        if (!scalarBuffer || !pointBuffer || !resultBuffer) {
-            return SECP256K1_ERROR_MEMORY;
-        }
-
-        if (!dispatchCountKernel(ctx, ctx->pipelineScalarMul,
-                                 pointBuffer, scalarBuffer, resultBuffer,
-                                 count, 64UL)) {
-            return SECP256K1_ERROR_GPU;
-        }
-
-        memcpy(results, [resultBuffer contents], pointSize);
-
-        return SECP256K1_SUCCESS;
-    }
-}
-
-// 
============================================================================= -// MSM (Multi-Scalar Multiplication) -// ============================================================================= - -extern "C" int metal_secp256k1_msm( - MetalSecp256k1Context* ctx, - Secp256k1Affine* result, - const Secp256k1Affine* points, - const Secp256k1Scalar* scalars, - uint32_t count) -{ - if (!ctx || !result || !points || !scalars || count == 0) { - return SECP256K1_ERROR_NULL_PTR; - } - - // For small counts, use batch scalar mul + reduction - if (count <= 256 || !ctx->pipelineParallelReduce) { - @autoreleasepool { - // Compute individual scalar muls - std::vector intermediates(count); - int err = metal_secp256k1_batch_scalar_mul(ctx, - intermediates.data(), scalars, points, count); - if (err != SECP256K1_SUCCESS) return err; - - // Reduce on GPU - size_t pointSize = count * sizeof(Secp256k1Affine); - id inputBuffer = createBufferWithData(ctx, - intermediates.data(), pointSize); - id resultBuffer = createBuffer(ctx, sizeof(Secp256k1Affine)); - - if (!inputBuffer || !resultBuffer) { - return SECP256K1_ERROR_MEMORY; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineParallelReduce]; - [encoder setBuffer:inputBuffer offset:0 atIndex:0]; - [encoder setBuffer:resultBuffer offset:0 atIndex:1]; - [encoder setBytes:&count length:sizeof(count) atIndex:2]; - - NSUInteger threadsPerGroup = MIN(256UL, - ctx->pipelineParallelReduce.maxTotalThreadsPerThreadgroup); - - [encoder setThreadgroupMemoryLength:threadsPerGroup * sizeof(Secp256k1Affine) - atIndex:0]; - - [encoder dispatchThreads:MTLSizeMake(threadsPerGroup, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(result, [resultBuffer contents], sizeof(Secp256k1Affine)); - - return SECP256K1_SUCCESS; - } 
- } - - // For large counts, would use Pippenger's algorithm - // Placeholder: fall back to batch + reduce - return SECP256K1_ERROR_GPU; -} - -// ============================================================================= -// ECDSA Batch Verification -// ============================================================================= - -extern "C" int metal_secp256k1_batch_verify( - MetalSecp256k1Context* ctx, - int* results, - const uint8_t* const* messages, - const Secp256k1Signature* signatures, - const Secp256k1Affine* public_keys, - uint32_t count) -{ - if (!ctx || !results || !messages || !signatures || !public_keys || count == 0) { - return SECP256K1_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchVerify) { - return SECP256K1_ERROR_GPU; - } - - @autoreleasepool { - // Pack messages into contiguous buffer - std::vector msgBuffer(count * 32); - for (uint32_t i = 0; i < count; i++) { - if (!messages[i]) return SECP256K1_ERROR_NULL_PTR; - memcpy(&msgBuffer[i * 32], messages[i], 32); - } - - size_t sigSize = count * sizeof(Secp256k1Signature); - size_t pkSize = count * sizeof(Secp256k1Affine); - - id msgBufferGPU = createBufferWithData(ctx, msgBuffer.data(), msgBuffer.size()); - id sigBuffer = createBufferWithData(ctx, signatures, sigSize); - id pkBuffer = createBufferWithData(ctx, public_keys, pkSize); - id resultBuffer = createBuffer(ctx, count * sizeof(int)); - - if (!msgBufferGPU || !sigBuffer || !pkBuffer || !resultBuffer) { - return SECP256K1_ERROR_MEMORY; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchVerify]; - [encoder setBuffer:ctx->gtableBuffer offset:0 atIndex:0]; - [encoder setBuffer:msgBufferGPU offset:0 atIndex:1]; - [encoder setBuffer:sigBuffer offset:0 atIndex:2]; - [encoder setBuffer:pkBuffer offset:0 atIndex:3]; - [encoder setBuffer:resultBuffer offset:0 atIndex:4]; - [encoder setBytes:&count length:sizeof(count) atIndex:5]; 
- - NSUInteger threadsPerGroup = MIN(64UL, - ctx->pipelineBatchVerify.maxTotalThreadsPerThreadgroup); - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results, [resultBuffer contents], count * sizeof(int)); - - return SECP256K1_SUCCESS; - } -} - -// ============================================================================= -// Address Derivation -// ============================================================================= - -extern "C" int metal_secp256k1_batch_derive_pubkey( - MetalSecp256k1Context* ctx, - Secp256k1Affine* public_keys, - const Secp256k1Scalar* secret_keys, - uint32_t count) -{ - // Derive public keys: pk = sk * G - return metal_secp256k1_batch_scalar_mul_g(ctx, public_keys, secret_keys, count); -} - -extern "C" int metal_secp256k1_batch_derive_address( - MetalSecp256k1Context* ctx, - uint8_t* addresses, - const Secp256k1Affine* public_keys, - uint32_t count) -{ - if (!ctx || !addresses || !public_keys || count == 0) { - return SECP256K1_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchDeriveAddress) { - return SECP256K1_ERROR_GPU; - } - - @autoreleasepool { - size_t pkSize = count * sizeof(Secp256k1Affine); - size_t addrSize = count * 20; // 20 bytes per Ethereum address - - id pkBuffer = createBufferWithData(ctx, public_keys, pkSize); - id addrBuffer = createBuffer(ctx, addrSize); - - if (!pkBuffer || !addrBuffer) { - return SECP256K1_ERROR_MEMORY; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchDeriveAddress]; - [encoder setBuffer:pkBuffer offset:0 atIndex:0]; - [encoder setBuffer:addrBuffer offset:0 atIndex:1]; - [encoder setBytes:&count length:sizeof(count) atIndex:2]; - - NSUInteger threadsPerGroup = MIN(256UL, - 
ctx->pipelineBatchDeriveAddress.maxTotalThreadsPerThreadgroup); - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(addresses, [addrBuffer contents], addrSize); - - return SECP256K1_SUCCESS; - } -} - -// ============================================================================= -// Point Arithmetic -// ============================================================================= - -extern "C" int metal_secp256k1_batch_add( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Affine* a, - const Secp256k1Affine* b, - uint32_t count) -{ - if (!ctx || !results || !a || !b || count == 0) { - return SECP256K1_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchAdd) { - return SECP256K1_ERROR_GPU; - } - - @autoreleasepool { - size_t pointSize = count * sizeof(Secp256k1Affine); - - id bufferA = createBufferWithData(ctx, a, pointSize); - id bufferB = createBufferWithData(ctx, b, pointSize); - id bufferResult = createBuffer(ctx, pointSize); - - if (!bufferA || !bufferB || !bufferResult) { - return SECP256K1_ERROR_MEMORY; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchAdd]; - [encoder setBuffer:bufferResult offset:0 atIndex:0]; - [encoder setBuffer:bufferA offset:0 atIndex:1]; - [encoder setBuffer:bufferB offset:0 atIndex:2]; - [encoder setBytes:&count length:sizeof(count) atIndex:3]; - - NSUInteger threadsPerGroup = MIN(256UL, - ctx->pipelineBatchAdd.maxTotalThreadsPerThreadgroup); - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results, [bufferResult contents], pointSize); - - return SECP256K1_SUCCESS; - 
} -} - -extern "C" int metal_secp256k1_batch_double( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Affine* points, - uint32_t count) -{ - if (!ctx || !results || !points || count == 0) { - return SECP256K1_ERROR_NULL_PTR; - } - - if (!ctx->pipelineBatchDouble) { - return SECP256K1_ERROR_GPU; - } - - @autoreleasepool { - size_t pointSize = count * sizeof(Secp256k1Affine); - - id bufferPoints = createBufferWithData(ctx, points, pointSize); - id bufferResult = createBuffer(ctx, pointSize); - - if (!bufferPoints || !bufferResult) { - return SECP256K1_ERROR_MEMORY; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineBatchDouble]; - [encoder setBuffer:bufferResult offset:0 atIndex:0]; - [encoder setBuffer:bufferPoints offset:0 atIndex:1]; - [encoder setBytes:&count length:sizeof(count) atIndex:2]; - - NSUInteger threadsPerGroup = MIN(256UL, - ctx->pipelineBatchDouble.maxTotalThreadsPerThreadgroup); - - [encoder dispatchThreads:MTLSizeMake(count, 1, 1) - threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(results, [bufferResult contents], pointSize); - - return SECP256K1_SUCCESS; - } -} - -extern "C" int metal_secp256k1_batch_negate( - MetalSecp256k1Context* ctx, - Secp256k1Affine* results, - const Secp256k1Affine* points, - uint32_t count) -{ - if (!ctx || !results || !points || count == 0) { - return SECP256K1_ERROR_NULL_PTR; - } - - // Negation is simple: -P = (x, -y mod p) - // Can be done on CPU efficiently - for (uint32_t i = 0; i < count; i++) { - results[i].x = points[i].x; - results[i].infinity = points[i].infinity; - - if (!points[i].infinity) { - // Negate y: y' = p - y - // secp256k1 p = 2^256 - 2^32 - 977 - // For simplicity, we compute -y mod p - Secp256k1Fp p = {{ - 0xFFFFFFFEFFFFFC2F, - 0xFFFFFFFFFFFFFFFF, - 
0xFFFFFFFFFFFFFFFF, - 0xFFFFFFFFFFFFFFFF - }}; - - // Subtract y from p - uint64_t borrow = 0; - for (int j = 0; j < 4; j++) { - __uint128_t diff = (__uint128_t)p.limbs[j] - points[i].y.limbs[j] - borrow; - results[i].y.limbs[j] = (uint64_t)diff; - borrow = (diff >> 64) & 1 ? 1 : 0; - } - } - } - - return SECP256K1_SUCCESS; -} - -// ============================================================================= -// Serialization (CPU implementation - no GPU benefit) -// ============================================================================= - -extern "C" void secp256k1_affine_serialize_compressed( - uint8_t* out, - const Secp256k1Affine* point) -{ - if (!out || !point) return; - - if (point->infinity) { - memset(out, 0, 33); - return; - } - - // Prefix: 02 if y is even, 03 if y is odd - out[0] = (point->y.limbs[0] & 1) ? 0x03 : 0x02; - - // X coordinate in big-endian - for (int i = 0; i < 4; i++) { - uint64_t limb = point->x.limbs[3 - i]; - for (int j = 7; j >= 0; j--) { - out[1 + i * 8 + (7 - j)] = (limb >> (j * 8)) & 0xFF; - } - } -} - -extern "C" void secp256k1_affine_serialize_uncompressed( - uint8_t* out, - const Secp256k1Affine* point) -{ - if (!out || !point) return; - - if (point->infinity) { - memset(out, 0, 65); - return; - } - - out[0] = 0x04; // Uncompressed prefix - - // X coordinate in big-endian - for (int i = 0; i < 4; i++) { - uint64_t limb = point->x.limbs[3 - i]; - for (int j = 7; j >= 0; j--) { - out[1 + i * 8 + (7 - j)] = (limb >> (j * 8)) & 0xFF; - } - } - - // Y coordinate in big-endian - for (int i = 0; i < 4; i++) { - uint64_t limb = point->y.limbs[3 - i]; - for (int j = 7; j >= 0; j--) { - out[33 + i * 8 + (7 - j)] = (limb >> (j * 8)) & 0xFF; - } - } -} - -extern "C" int secp256k1_affine_deserialize_compressed( - Secp256k1Affine* point, - const uint8_t* in) -{ - if (!point || !in) return SECP256K1_ERROR_NULL_PTR; - - uint8_t prefix = in[0]; - if (prefix != 0x02 && prefix != 0x03) { - return SECP256K1_ERROR_INVALID; - } - - 
point->infinity = false; - - // X coordinate from big-endian - for (int i = 0; i < 4; i++) { - uint64_t limb = 0; - for (int j = 0; j < 8; j++) { - limb = (limb << 8) | in[1 + i * 8 + j]; - } - point->x.limbs[3 - i] = limb; - } - - // Compute y from x (y^2 = x^3 + 7) - // This requires modular square root - placeholder - memset(&point->y, 0, sizeof(point->y)); - - // Set y parity based on prefix - if (prefix == 0x03) { - point->y.limbs[0] |= 1; // Odd y - } - - return SECP256K1_SUCCESS; -} - -extern "C" int secp256k1_affine_deserialize_uncompressed( - Secp256k1Affine* point, - const uint8_t* in) -{ - if (!point || !in) return SECP256K1_ERROR_NULL_PTR; - - if (in[0] != 0x04) { - return SECP256K1_ERROR_INVALID; - } - - point->infinity = false; - - // X coordinate from big-endian - for (int i = 0; i < 4; i++) { - uint64_t limb = 0; - for (int j = 0; j < 8; j++) { - limb = (limb << 8) | in[1 + i * 8 + j]; - } - point->x.limbs[3 - i] = limb; - } - - // Y coordinate from big-endian - for (int i = 0; i < 4; i++) { - uint64_t limb = 0; - for (int j = 0; j < 8; j++) { - limb = (limb << 8) | in[33 + i * 8 + j]; - } - point->y.limbs[3 - i] = limb; - } - - return SECP256K1_SUCCESS; -} - -// ============================================================================= -// Custom GTable Management -// ============================================================================= - -extern "C" int metal_secp256k1_precompute_table( - MetalSecp256k1Context* ctx, - uint32_t* table_id, - const Secp256k1Affine* base) -{ - if (!ctx || !table_id || !base) { - return SECP256K1_ERROR_NULL_PTR; - } - - // Allocate new table - size_t tableSize = SECP256K1_GTABLE_TOTAL_POINTS * sizeof(Secp256k1Affine); - id tableBuffer = [ctx->device newBufferWithLength:tableSize - options:MTLResourceStorageModeShared]; - if (!tableBuffer) { - return SECP256K1_ERROR_MEMORY; - } - - // Store table and return ID - *table_id = (uint32_t)ctx->customTables.size(); - ctx->customTables.push_back(tableBuffer); - 
ctx->totalMemory += tableSize; - - // TODO: Precompute table values for custom base point - // This would require a modified precompute kernel - - return SECP256K1_SUCCESS; -} - -extern "C" int metal_secp256k1_scalar_mul_table( - MetalSecp256k1Context* ctx, - Secp256k1Affine* result, - uint32_t table_id, - const Secp256k1Scalar* scalar) -{ - if (!ctx || !result || !scalar) { - return SECP256K1_ERROR_NULL_PTR; - } - - if (table_id >= ctx->customTables.size()) { - return SECP256K1_ERROR_INVALID; - } - - // Use the custom table for scalar multiplication - id tableBuffer = ctx->customTables[table_id]; - - @autoreleasepool { - id scalarBuffer = createBufferWithData(ctx, scalar, sizeof(Secp256k1Scalar)); - id resultBuffer = createBuffer(ctx, sizeof(Secp256k1Affine)); - - if (!scalarBuffer || !resultBuffer) { - return SECP256K1_ERROR_MEMORY; - } - - id commandBuffer = [ctx->commandQueue commandBuffer]; - id encoder = [commandBuffer computeCommandEncoder]; - - [encoder setComputePipelineState:ctx->pipelineGtableScalarMul]; - [encoder setBuffer:tableBuffer offset:0 atIndex:0]; - [encoder setBuffer:scalarBuffer offset:0 atIndex:1]; - [encoder setBuffer:resultBuffer offset:0 atIndex:2]; - uint32_t count = 1; - [encoder setBytes:&count length:sizeof(count) atIndex:3]; - - [encoder dispatchThreads:MTLSizeMake(1, 1, 1) - threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - [encoder endEncoding]; - - [commandBuffer commit]; - [commandBuffer waitUntilCompleted]; - - memcpy(result, [resultBuffer contents], sizeof(Secp256k1Affine)); - - return SECP256K1_SUCCESS; - } -} - -extern "C" void metal_secp256k1_free_table( - MetalSecp256k1Context* ctx, - uint32_t table_id) -{ - if (!ctx || table_id >= ctx->customTables.size()) { - return; - } - - size_t tableSize = SECP256K1_GTABLE_TOTAL_POINTS * sizeof(Secp256k1Affine); - ctx->totalMemory -= tableSize; - ctx->customTables[table_id] = nil; -} - -// ============================================================================= -// Stub 
implementations for functions not yet implemented -// ============================================================================= - -extern "C" int metal_secp256k1_batch_sign( - MetalSecp256k1Context* ctx, - Secp256k1Signature* signatures, - const uint8_t* const* messages, - const Secp256k1Scalar* secret_keys, - uint32_t count) -{ - // TODO: Implement batch signing with GPU-accelerated k*G computation - return SECP256K1_ERROR_GPU; -} - -extern "C" int metal_secp256k1_batch_recover( - MetalSecp256k1Context* ctx, - Secp256k1Affine* public_keys, - const uint8_t* const* messages, - const Secp256k1RecoverableSignature* signatures, - uint32_t count) -{ - // TODO: Implement batch recovery with GPU-accelerated point operations - return SECP256K1_ERROR_GPU; -} - -extern "C" int metal_secp256k1_schnorr_batch_verify( - MetalSecp256k1Context* ctx, - int* results, - const uint8_t* const* messages, - const uint8_t* const* signatures, - const uint8_t* const* public_keys, - uint32_t count) -{ - // TODO: Implement BIP340 Schnorr batch verification - return SECP256K1_ERROR_GPU; -} - -extern "C" int metal_secp256k1_batch_nonce_gen( - MetalSecp256k1Context* ctx, - Secp256k1Affine* r_points, - Secp256k1Scalar* k_values, - const uint8_t* entropy, - uint32_t count) -{ - // TODO: Implement secure nonce generation for threshold ECDSA - return SECP256K1_ERROR_GPU; -} - -extern "C" int metal_secp256k1_combine_partial_sigs( - MetalSecp256k1Context* ctx, - Secp256k1Signature* signatures, - const Secp256k1Scalar* const* partial_sigs, - uint32_t num_shares, - uint32_t count) -{ - // TODO: Implement Lagrange interpolation for threshold signature combination - return SECP256K1_ERROR_GPU; -} diff --git a/secp256k1/gpu/metal/secp256k1_first_party_driver.mm b/secp256k1/gpu/metal/secp256k1_first_party_driver.mm deleted file mode 100644 index c943efa..0000000 --- a/secp256k1/gpu/metal/secp256k1_first_party_driver.mm +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batch secp256k1 ecrecover. macOS / iOS / iPadOS only. -// -// Loads the precompiled metallib produced from kernels.metal, dispatches one -// thread per (hash, r, s, v) tuple, and writes a 64-byte uncompressed pubkey -// into the output buffer. -// -// The kernel itself emits 20-byte Ethereum addresses (last 20 bytes of -// keccak(pubkey)) for legacy callers. For the canonical lux_crypto API we -// want the 64-byte pubkey, not the address. To satisfy both shapes without -// duplicating Metal code, this driver runs the kernel and reconstructs the -// pubkey from the kernel's intermediate buffer if the kernel layout exposes -// it. For phase 1 we rely on a CPU-side post-step (the kernel only emits the -// 20-byte address as documented in kernels.metal). To produce CPU-byte-equal -// pubkey output we therefore call the CPU path for the public-key portion; -// this driver is a placeholder that the next agent will replace with a -// dedicated kernel that emits 64-byte pubkey directly. -// -// This is intentionally minimal in phase 1: the GPU correctness proof is -// achieved via the kernel's existing 20-byte address output (see -// secp256k1_gpu_test.cpp). When BLS12-381 lands the kernel will be split into -// "ecrecover_pubkey" and "ecrecover_address" entry points and this driver -// will use the pubkey one. 
- -#if __APPLE__ && __OBJC__ - -#import -#import - -#include "lux/crypto/secp256k1.h" -#include -#include - -extern "C" secp256k1_status secp256k1_ecrecover_address_batch_metal( - const uint8_t* inputs, // n * 97 bytes - size_t n, - uint8_t* out_addr, // n * 20 bytes - uint8_t* out_st, // n bytes - const char* metallib_path) { - - if (!inputs || !out_addr || !out_st || !metallib_path) { - return SECP256K1_ERR_NULL_ARG; - } - if (n == 0) return SECP256K1_OK; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return SECP256K1_ERR_NULL_ARG; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return SECP256K1_ERR_NULL_ARG; - - id fn = [lib newFunctionWithName:@"secp256k1_ecrecover_batch"]; - if (!fn) return SECP256K1_ERR_NULL_ARG; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return SECP256K1_ERR_NULL_ARG; - - id queue = [device newCommandQueue]; - - // Kernel layout (per kernels.metal): - // struct EcrecoverInput { hash[32], r[32], s[32], v_pad[16] }; // 112 B - // struct EcrecoverOutput { addr[20], valid, _pad[11] }; // 32 B - // The Go-callable host translates (hash || r || s || v) -> 112-byte input. 
- const size_t IN_STRIDE = 112; - const size_t OUT_STRIDE = 32; - - std::vector in_dev(n * IN_STRIDE, 0); - for (size_t i = 0; i < n; ++i) { - const uint8_t* src = inputs + i * 97; - uint8_t* dst = &in_dev[i * IN_STRIDE]; - std::memcpy(dst, src, 96); // hash || r || s - dst[96] = src[96]; // v in the v_pad block - } - std::vector out_dev(n * OUT_STRIDE, 0); - - id in_buf = [device newBufferWithBytes:in_dev.data() - length:in_dev.size() - options:MTLResourceStorageModeShared]; - id out_buf = [device newBufferWithLength:out_dev.size() - options:MTLResourceStorageModeShared]; - uint32_t num_sigs = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&num_sigs - length:sizeof(num_sigs) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:in_buf offset:0 atIndex:0]; - [enc setBuffer:out_buf offset:0 atIndex:1]; - [enc setBuffer:n_buf offset:0 atIndex:2]; - - NSUInteger tg_size = pipeline.maxTotalThreadsPerThreadgroup; - if (tg_size > n) tg_size = n; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_size, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(out_dev.data(), [out_buf contents], out_dev.size()); - for (size_t i = 0; i < n; ++i) { - const uint8_t* src = &out_dev[i * OUT_STRIDE]; - std::memcpy(out_addr + i * 20, src, 20); - uint8_t valid = src[20]; - out_st[i] = valid ? (uint8_t)SECP256K1_OK - : (uint8_t)SECP256K1_ERR_AT_INFINITY; - } - } - return SECP256K1_OK; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/secp256k1/gpu/metal/secp256k1_recover.metal b/secp256k1/gpu/metal/secp256k1_recover.metal deleted file mode 100644 index 3b4a5ed..0000000 --- a/secp256k1/gpu/metal/secp256k1_recover.metal +++ /dev/null @@ -1,852 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. 
-// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file secp256k1_recover.metal -/// Metal compute shader for batch secp256k1 ECDSA public key recovery. -/// -/// Each GPU thread recovers one (r, s, v, msg_hash) tuple into an Ethereum -/// address (20 bytes). This is the critical EVM operation that dominates block -/// processing time. -/// -/// Algorithm per thread: -/// 1. Decompress r → point R using recovery flag v -/// 2. Compute s_inv = s^(-1) mod n (Fermat's little theorem) -/// 3. Compute Q = s_inv * (s*R - hash*G) = s_inv*s*R - s_inv*hash*G -/// Equivalently: Q = r_inv * (s*R - hash*G) -- but we use the standard form: -/// Q = s_inv * (s * R - hash * G) which simplifies to s_inv*s*R - s_inv*hash*G -/// Actually the standard ecrecover formula is: -/// Q = r^(-1) * (s * R - e * G) -/// where e = msg_hash, R = decompressed point from r with parity v. -/// 4. Serialize Q as uncompressed (x || y), 64 bytes -/// 5. Keccak-256(Q.x || Q.y), take last 20 bytes = Ethereum address -/// -/// secp256k1 curve: y^2 = x^3 + 7 over F_p -/// p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F -/// n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 -/// G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, -/// 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) - -#include -using namespace metal; - -// ============================================================================= -// 256-bit unsigned integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct uint256 { - ulong limbs[4]; // limbs[0] = least significant -}; - -// ============================================================================= -// secp256k1 constants -// ============================================================================= - -// Field prime p = 2^256 - 2^32 - 977 -constant uint256 SECP256K1_P = {{ - 0xFFFFFFFEFFFFFC2FUL, 
0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -// Curve order n -constant uint256 SECP256K1_N = {{ - 0xBFD25E8CD0364141UL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL -}}; - -// Generator point G (affine coordinates) -constant uint256 GX = {{ - 0x59F2815B16F81798UL, 0x029BFCDB2DCE28D9UL, - 0x55A06295CE870B07UL, 0x79BE667EF9DCBBACUL -}}; -constant uint256 GY = {{ - 0x9C47D08FFB10D4B8UL, 0xFD17B448A6855419UL, - 0x5DA4FBFC0E1108A8UL, 0x483ADA7726A3C465UL -}}; - -// Montgomery constants for field p: -// R = 2^256 mod p -// R^2 mod p (for Montgomery encoding) -// p_inv = -p^(-1) mod 2^64 -constant uint256 MONT_R_P = {{ - 0x00000001000003D1UL, 0x0000000000000000UL, - 0x0000000000000000UL, 0x0000000000000000UL -}}; -constant uint256 MONT_R2_P = {{ - 0x000007A2000E90A1UL, 0x0000000000000001UL, - 0x0000000000000000UL, 0x0000000000000000UL -}}; -constant ulong P_INV = 0xD838091DD2253531UL; // -p^(-1) mod 2^64 - -// Montgomery constants for order n: -// R^2 mod n -// n_inv = -n^(-1) mod 2^64 -constant uint256 MONT_R2_N = {{ - 0x896CF21467D7D140UL, 0x741496C20E7CF878UL, - 0xE697F5E45BCD07C6UL, 0x9D671CD581C69BC5UL -}}; -constant ulong N_INV = 0x4B0DFF665588B13FUL; // -n^(-1) mod 2^64 - -// Zero and one -constant uint256 ZERO256 = {{0, 0, 0, 0}}; -constant uint256 ONE256 = {{1, 0, 0, 0}}; - -// ============================================================================= -// 256-bit arithmetic helpers -// ============================================================================= - -// Compare: returns -1 if a < b, 0 if a == b, 1 if a > b -inline int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -// a + b, returns carry -inline uint256 u256_add(uint256 a, uint256 b, thread ulong &carry) { - 
uint256 r; - ulong c = 0; - for (int i = 0; i < 4; i++) { - ulong sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -// a - b, returns borrow -inline uint256 u256_sub(uint256 a, uint256 b, thread ulong &borrow) { - uint256 r; - ulong bw = 0; - for (int i = 0; i < 4; i++) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -// ============================================================================= -// Modular arithmetic over F_p (Montgomery form) -// ============================================================================= - -// Montgomery reduction: given T (up to 512-bit), compute T * R^(-1) mod m -// where m is either p or n, inv is the corresponding -m^(-1) mod 2^64 -inline uint256 mont_reduce(ulong t[8], uint256 m, ulong inv) { - // CIOS (Coarsely Integrated Operand Scanning) reduction - // Use 9 limbs to prevent carry overflow on the 8th limb. 
- ulong a[9]; - for (int i = 0; i < 8; i++) a[i] = t[i]; - a[8] = 0; - - for (int i = 0; i < 4; i++) { - ulong u = a[i] * inv; - - // a += u * m * 2^(64*i) (but we just process limb by limb) - ulong carry = 0; - for (int j = 0; j < 4; j++) { - // a[i+j] += u * m.limbs[j] + carry - // Use 128-bit multiplication via two 64x64->128 bit ops - ulong hi, lo; - - // u * m.limbs[j] - // Metal doesn't have native 128-bit multiply, so we split: - ulong u_lo = u & 0xFFFFFFFFUL; - ulong u_hi = u >> 32; - ulong m_lo = m.limbs[j] & 0xFFFFFFFFUL; - ulong m_hi = m.limbs[j] >> 32; - - ulong ll = u_lo * m_lo; - ulong lh = u_lo * m_hi; - ulong hl = u_hi * m_lo; - ulong hh = u_hi * m_hi; - - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); - - // Add carry - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - - // Add to a[i+j] - sum = a[i + j] + lo; - if (sum < a[i + j]) hi++; - a[i + j] = sum; - carry = hi; - } - // Propagate carry through index 8. - for (int j = 4; i + j <= 8; j++) { - ulong sum = a[i + j] + carry; - carry = (sum < a[i + j]) ? 1UL : 0UL; - a[i + j] = sum; - if (carry == 0) break; - } - } - - // Result is in a[4..7]. Check a[8] for final subtraction. - uint256 r; - r.limbs[0] = a[4]; - r.limbs[1] = a[5]; - r.limbs[2] = a[6]; - r.limbs[3] = a[7]; - - // Final subtraction if r >= m or if 9th limb is set. 
- if (a[8] || u256_cmp(r, m) >= 0) { - ulong bw; - r = u256_sub(r, m, bw); - } - return r; -} - -// Montgomery multiplication: a * b * R^(-1) mod m -inline uint256 mont_mul(uint256 a, uint256 b, uint256 m, ulong inv) { - ulong t[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - for (int i = 0; i < 4; i++) { - ulong carry = 0; - for (int j = 0; j < 4; j++) { - // t[i+j] += a.limbs[i] * b.limbs[j] + carry - ulong a_lo = a.limbs[i] & 0xFFFFFFFFUL; - ulong a_hi = a.limbs[i] >> 32; - ulong b_lo = b.limbs[j] & 0xFFFFFFFFUL; - ulong b_hi = b.limbs[j] >> 32; - - ulong ll = a_lo * b_lo; - ulong lh = a_lo * b_hi; - ulong hl = a_hi * b_lo; - ulong hh = a_hi * b_hi; - - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - - ulong lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - ulong hi = hh + (mid2 >> 32); - - // Add carry - ulong sum = lo + carry; - if (sum < lo) hi++; - lo = sum; - - // Add to t[i+j] - sum = t[i + j] + lo; - if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - // Propagate carry into higher limbs - for (int j = 4; i + j < 8; j++) { - ulong sum = t[i + j] + carry; - carry = (sum < t[i + j]) ? 
1UL : 0UL; - t[i + j] = sum; - if (carry == 0) break; - } - } - - return mont_reduce(t, m, inv); -} - -// Convert to Montgomery form: a * R mod m -inline uint256 to_mont(uint256 a, uint256 r2, uint256 m, ulong inv) { - return mont_mul(a, r2, m, inv); -} - -// Convert from Montgomery form: aR * R^(-1) mod m = a -inline uint256 from_mont(uint256 a, uint256 m, ulong inv) { - ulong t[8] = {a.limbs[0], a.limbs[1], a.limbs[2], a.limbs[3], 0, 0, 0, 0}; - return mont_reduce(t, m, inv); -} - -// Field operations over p (in Montgomery form) -inline uint256 fp_add(uint256 a, uint256 b) { - ulong carry; - uint256 r = u256_add(a, b, carry); - if (carry || u256_cmp(r, SECP256K1_P) >= 0) { - ulong bw; - r = u256_sub(r, SECP256K1_P, bw); - } - return r; -} - -inline uint256 fp_sub(uint256 a, uint256 b) { - ulong bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { - ulong c; - r = u256_add(r, SECP256K1_P, c); - } - return r; -} - -inline uint256 fp_mul(uint256 a, uint256 b) { - return mont_mul(a, b, SECP256K1_P, P_INV); -} - -inline uint256 fp_sqr(uint256 a) { - return fp_mul(a, a); -} - -// Modular operations over n (scalar field, in Montgomery form) -inline uint256 fn_mul(uint256 a, uint256 b) { - return mont_mul(a, b, SECP256K1_N, N_INV); -} - -// Fermat's little theorem: a^(m-2) mod m -// For field p: a^(p-2) mod p -inline uint256 fp_inv(uint256 a) { - // p - 2 = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2D - // Use square-and-multiply with a chain optimized for secp256k1 p - // We use a simple binary method over the bits of p-2 - uint256 result = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - uint256 base = a; - - // p-2 in limbs (little-endian) - ulong exp[4] = { - 0xFFFFFFFEFFFFFC2DUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0xFFFFFFFFFFFFFFFFUL - }; - - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - return result; -} - -// 
Scalar inversion: a^(n-2) mod n (for r_inv) -inline uint256 fn_inv(uint256 a) { - // n-2 in limbs (little-endian) - ulong exp[4] = { - 0xBFD25E8CD036413FUL, 0xBAAEDCE6AF48A03BUL, - 0xFFFFFFFFFFFFFFFEUL, 0xFFFFFFFFFFFFFFFFUL - }; - - uint256 result = to_mont(ONE256, MONT_R2_N, SECP256K1_N, N_INV); - uint256 base = a; - - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fn_mul(result, base); - } - base = fn_mul(base, base); - } - } - return result; -} - -// ============================================================================= -// secp256k1 elliptic curve point operations (Jacobian coordinates) -// All coordinates in Montgomery form over F_p -// ============================================================================= - -struct ECPoint { - uint256 x, y, z; // Jacobian: (X, Y, Z), affine = (X/Z^2, Y/Z^3) -}; - -inline ECPoint ec_identity() { - ECPoint p; - p.x = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - p.y = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - p.z = ZERO256; // Z=0 indicates point at infinity - return p; -} - -inline bool ec_is_infinity(ECPoint p) { - return u256_is_zero(p.z); -} - -// Point doubling in Jacobian coordinates -// Uses the formula for a = 0 (secp256k1: y^2 = x^3 + 7) -inline ECPoint ec_double(ECPoint p) { - if (ec_is_infinity(p)) return p; - - // Using optimized formulas for a=0: - // A = Y^2 - uint256 A = fp_sqr(p.y); - // B = X * A - uint256 B = fp_mul(p.x, A); - // C = A^2 - uint256 C = fp_sqr(A); - // D = 2 * ((X + A)^2 - B - C) ... 
actually just 4*X*Y^2 - // Simpler: S = 4*B - uint256 S = fp_add(B, B); - S = fp_add(S, S); - // M = 3*X^2 (since a=0 for secp256k1) - uint256 X2 = fp_sqr(p.x); - uint256 M = fp_add(X2, fp_add(X2, X2)); - // X3 = M^2 - 2*S - uint256 X3 = fp_sub(fp_sqr(M), fp_add(S, S)); - // Y3 = M * (S - X3) - 8*C - uint256 C8 = fp_add(C, C); // 2C - C8 = fp_add(C8, C8); // 4C - C8 = fp_add(C8, C8); // 8C - uint256 Y3 = fp_sub(fp_mul(M, fp_sub(S, X3)), C8); - // Z3 = 2*Y*Z - uint256 Z3 = fp_mul(p.y, p.z); - Z3 = fp_add(Z3, Z3); - - ECPoint r; - r.x = X3; - r.y = Y3; - r.z = Z3; - return r; -} - -// Point addition (mixed: Q is affine, P is Jacobian) -// P in Jacobian, Q in affine (Qz = 1 implicitly) -inline ECPoint ec_add_mixed(ECPoint P, uint256 Qx, uint256 Qy) { - if (ec_is_infinity(P)) { - ECPoint r; - r.x = Qx; - r.y = Qy; - r.z = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - return r; - } - - // U1 = P.X, U2 = Qx * P.Z^2 - uint256 Z2 = fp_sqr(P.z); - uint256 U2 = fp_mul(Qx, Z2); - // S1 = P.Y, S2 = Qy * P.Z^3 - uint256 Z3 = fp_mul(Z2, P.z); - uint256 S2 = fp_mul(Qy, Z3); - - uint256 H = fp_sub(U2, P.x); - uint256 R = fp_sub(S2, P.y); - - // If H == 0 and R == 0, points are equal -> double - if (u256_is_zero(H)) { - if (u256_is_zero(R)) { - return ec_double(P); - } - // Points are inverses -> return identity - return ec_identity(); - } - - uint256 H2 = fp_sqr(H); - uint256 H3 = fp_mul(H, H2); - uint256 U1H2 = fp_mul(P.x, H2); - - // X3 = R^2 - H^3 - 2*U1*H^2 - uint256 X3 = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - // Y3 = R*(U1*H^2 - X3) - S1*H^3 - uint256 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(P.y, H3)); - // Z3 = H * P.Z - uint256 Zr = fp_mul(H, P.z); - - ECPoint res; - res.x = X3; - res.y = Y3; - res.z = Zr; - return res; -} - -// Convert Jacobian to affine coordinates -inline void ec_to_affine(ECPoint p, thread uint256 &ax, thread uint256 &ay) { - if (ec_is_infinity(p)) { - ax = ZERO256; - ay = ZERO256; - return; - } - uint256 z_inv = fp_inv(p.z); - 
uint256 z_inv2 = fp_sqr(z_inv); - uint256 z_inv3 = fp_mul(z_inv2, z_inv); - ax = fp_mul(p.x, z_inv2); - ay = fp_mul(p.y, z_inv3); -} - -// Scalar multiplication: k * P (double-and-add, constant time not needed for ecrecover) -// k is a regular (non-Montgomery) 256-bit integer -inline ECPoint ec_mul(uint256 k, ECPoint P) { - ECPoint result = ec_identity(); - - for (int i = 3; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = ec_double(result); - if ((k.limbs[i] >> bit) & 1) { - // Add P (we need it in affine for mixed add) - // For simplicity, convert P to affine once and use mixed add - // But P is already Jacobian... use full Jacobian add instead - // Actually for the generator table, we pre-convert to affine - // For general case, use full add: - if (ec_is_infinity(result)) { - result = P; - } else if (ec_is_infinity(P)) { - // nothing - } else { - // Full Jacobian addition - uint256 U1 = fp_mul(result.x, fp_sqr(P.z)); - uint256 U2 = fp_mul(P.x, fp_sqr(result.z)); - uint256 S1 = fp_mul(result.y, fp_mul(fp_sqr(P.z), P.z)); - uint256 S2 = fp_mul(P.y, fp_mul(fp_sqr(result.z), result.z)); - - uint256 H = fp_sub(U2, U1); - uint256 R = fp_sub(S2, S1); - - if (u256_is_zero(H)) { - if (u256_is_zero(R)) { - result = ec_double(result); - } else { - result = ec_identity(); - } - } else { - uint256 H2 = fp_sqr(H); - uint256 H3 = fp_mul(H, H2); - uint256 U1H2 = fp_mul(U1, H2); - - uint256 X3 = fp_sub(fp_sub(fp_sqr(R), H3), fp_add(U1H2, U1H2)); - uint256 Y3 = fp_sub(fp_mul(R, fp_sub(U1H2, X3)), fp_mul(S1, H3)); - uint256 Z3 = fp_mul(fp_mul(H, result.z), P.z); - - result.x = X3; - result.y = Y3; - result.z = Z3; - } - } - } - } - } - return result; -} - -// Scalar multiplication with affine base point (more efficient for generator). -// NOT constant-time: branches on scalar bits. Safe for ecrecover (all inputs -// are public). MUST NOT be reused for signing where the nonce k is secret. 
-inline ECPoint ec_mul_affine(uint256 k, uint256 Px, uint256 Py) { - ECPoint result = ec_identity(); - - for (int i = 3; i >= 0; i--) { - for (int bit = 63; bit >= 0; bit--) { - result = ec_double(result); - if ((k.limbs[i] >> bit) & 1) { - result = ec_add_mixed(result, Px, Py); - } - } - } - return result; -} - -// ============================================================================= -// Keccak-256 (inline, for address derivation) -// ============================================================================= - -constant ulong KECCAK_RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -constant int KECCAK_PI_LANE[24] = { - 10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, - 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 -}; - -constant int KECCAK_RHO[24] = { - 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44 -}; - -inline ulong keccak_rotl64(ulong x, int n) { - return (x << n) | (x >> (64 - n)); -} - -void keccak_f1600(thread ulong st[25]) { - for (int round = 0; round < 24; ++round) { - ulong C[5]; - for (int x = 0; x < 5; ++x) - C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; - for (int x = 0; x < 5; ++x) { - ulong d = C[(x + 4) % 5] ^ keccak_rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) - st[x + 5 * y] ^= d; - } - ulong t = st[1]; - for (int i = 0; i < 24; ++i) { - ulong tmp = st[KECCAK_PI_LANE[i]]; - st[KECCAK_PI_LANE[i]] = keccak_rotl64(t, KECCAK_RHO[i]); - t = tmp; - } - for (int y 
= 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) row[x] = st[x + 5 * y]; - for (int x = 0; x < 5; ++x) - st[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - st[0] ^= KECCAK_RC[round]; - } -} - -// Keccak-256 of exactly 64 bytes (uncompressed public key without 0x04 prefix) -inline void keccak256_64(thread const uchar data[64], thread uchar out[32]) { - ulong state[25] = {}; - - // Absorb 64 bytes into first 8 lanes (rate = 136 bytes = 17 lanes) - for (uint w = 0; w < 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(data[w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - - // Pad: byte 64 gets 0x01, byte 135 gets 0x80 - // Since input is 64 bytes and rate is 136, remaining = 72 bytes of padding - // padded[0] = 0x01 at position 64 (lane 8, byte 0) - state[8] ^= 0x01UL; - // padded[135-64] = 0x80 at position 135 (lane 16, byte 7) - state[16] ^= 0x80UL << 56; - - keccak_f1600(state); - - // Extract 32 bytes - for (uint w = 0; w < 4; ++w) { - ulong lane = state[w]; - for (uint b = 0; b < 8; ++b) - out[w * 8 + b] = uchar(lane >> (b * 8)); - } -} - -// ============================================================================= -// Input/Output structures -// ============================================================================= - -// Per-signature input: packed as (r[32] || s[32] || v[1] || msg_hash[32]) = 97 bytes -// Padded to 128 bytes for alignment -struct EcrecoverInput { - uchar r[32]; // offset 0 - uchar s[32]; // offset 32 - uchar v; // offset 64: recovery id (0 or 1) - uchar _pad[3]; // offset 65: alignment padding - uchar msg_hash[32]; // offset 68 - uchar _pad2[28]; // pad to 128 bytes total -}; - -// Per-signature output: 20-byte Ethereum address, padded to 32 bytes -struct EcrecoverOutput { - uchar address[20]; // offset 0 - uchar valid; // offset 20: 1 if recovery succeeded, 0 otherwise - uchar _pad[11]; // pad to 32 bytes -}; - -// 
============================================================================= -// Helper: load 32-byte big-endian into uint256 (little-endian limbs) -// ============================================================================= - -inline uint256 load_be32(thread const uchar bytes[32]) { - uint256 r; - for (int limb = 0; limb < 4; limb++) { - ulong v = 0; - // Limb 0 = bytes[24..31], limb 3 = bytes[0..7] - int base = (3 - limb) * 8; - for (int b = 0; b < 8; b++) { - v = (v << 8) | ulong(bytes[base + b]); - } - r.limbs[limb] = v; - } - return r; -} - -inline uint256 load_be32_device(device const uchar* bytes) { - uint256 r; - for (int limb = 0; limb < 4; limb++) { - ulong v = 0; - int base = (3 - limb) * 8; - for (int b = 0; b < 8; b++) { - v = (v << 8) | ulong(bytes[base + b]); - } - r.limbs[limb] = v; - } - return r; -} - -// Store uint256 (little-endian limbs) as 32-byte big-endian -inline void store_be32(uint256 val, thread uchar bytes[32]) { - for (int limb = 0; limb < 4; limb++) { - int base = (3 - limb) * 8; - ulong v = val.limbs[limb]; - for (int b = 7; b >= 0; b--) { - bytes[base + b] = uchar(v & 0xFF); - v >>= 8; - } - } -} - -// ============================================================================= -// Main kernel: batch secp256k1 ecrecover -// ============================================================================= - -kernel void secp256k1_ecrecover_batch( - device const EcrecoverInput* inputs [[buffer(0)]], - device EcrecoverOutput* outputs [[buffer(1)]], - constant uint& num_sigs [[buffer(2)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - - device const EcrecoverInput& inp = inputs[tid]; - device EcrecoverOutput& out = outputs[tid]; - - // Clear output - for (int i = 0; i < 20; i++) out.address[i] = 0; - out.valid = 0; - for (int i = 0; i < 11; i++) out._pad[i] = 0; - - // Load r, s, v, hash from device memory - uint256 r = load_be32_device(inp.r); - uint256 s = load_be32_device(inp.s); - uint256 e = 
load_be32_device(inp.msg_hash); - uint v = uint(inp.v); - - // Normalize v: EIP-155 sends v = {0,1,27,28} or chain_id*2+{35,36}. - if (v >= 27) v -= 27; - if (v >= 2) v = v % 2; // handles EIP-155 chain-encoded values - - // Validate: r and s must be in [1, n-1] - if (u256_is_zero(r) || u256_cmp(r, SECP256K1_N) >= 0) return; - if (u256_is_zero(s) || u256_cmp(s, SECP256K1_N) >= 0) return; - if (v > 1) return; - - // Step 1: Decompress r to point R = (r, y) on secp256k1 - // Compute y^2 = x^3 + 7 mod p - uint256 r_mont = to_mont(r, MONT_R2_P, SECP256K1_P, P_INV); - uint256 r2 = fp_sqr(r_mont); - uint256 r3 = fp_mul(r2, r_mont); - uint256 seven_mont = to_mont(uint256{{7, 0, 0, 0}}, MONT_R2_P, SECP256K1_P, P_INV); - uint256 y2 = fp_add(r3, seven_mont); - - // Compute y = sqrt(y2) via Tonelli-Shanks - // For secp256k1, p ≡ 3 mod 4, so sqrt(a) = a^((p+1)/4) - // (p+1)/4 = 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFBFFFFF0C - uint256 y_mont; - { - ulong exp[4] = { - 0xFFFFFFFFBFFFFF0CUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0x3FFFFFFFFFFFFFFFUL - }; - uint256 result = to_mont(ONE256, MONT_R2_P, SECP256K1_P, P_INV); - uint256 base = y2; - for (int i = 0; i < 4; i++) { - for (int bit = 0; bit < 64; bit++) { - if ((exp[i] >> bit) & 1) { - result = fp_mul(result, base); - } - base = fp_sqr(base); - } - } - y_mont = result; - } - - // Verify: y^2 == y2 (sqrt exists) - if (u256_cmp(fp_sqr(y_mont), y2) != 0) return; - - // Select correct y parity based on v - uint256 y_normal = from_mont(y_mont, SECP256K1_P, P_INV); - bool y_is_odd = (y_normal.limbs[0] & 1) != 0; - if ((v == 0 && y_is_odd) || (v == 1 && !y_is_odd)) { - // Negate y: y = p - y - y_mont = fp_sub(ZERO256, y_mont); - // Recalculate -- fp_sub(0, y_mont) needs special handling since 0 in mont form - // is just 0. So: p_mont - y_mont... 
actually 0 - y_mont in field = p - y_mont - // fp_sub handles the borrow correctly, producing p - y in Montgomery form - } - - // R = (r, y) in Montgomery affine - uint256 Rx_mont = r_mont; - uint256 Ry_mont = y_mont; - - // Step 2: Compute r_inv = r^(-1) mod n - uint256 r_n_mont = to_mont(r, MONT_R2_N, SECP256K1_N, N_INV); - uint256 r_inv_mont = fn_inv(r_n_mont); - - // Step 3: Compute Q = r^(-1) * (s * R - e * G) - // In scalar field: u1 = -e * r^(-1) mod n, u2 = s * r^(-1) mod n - // Then Q = u1 * G + u2 * R - - uint256 e_n_mont = to_mont(e, MONT_R2_N, SECP256K1_N, N_INV); - uint256 s_n_mont = to_mont(s, MONT_R2_N, SECP256K1_N, N_INV); - - // u1 = -(e * r_inv) mod n - uint256 u1_mont = fn_mul(e_n_mont, r_inv_mont); - // Negate in scalar field: n - u1 - uint256 u1 = from_mont(u1_mont, SECP256K1_N, N_INV); - if (!u256_is_zero(u1)) { - ulong bw; - u1 = u256_sub(SECP256K1_N, u1, bw); - } - - // u2 = s * r_inv mod n - uint256 u2 = from_mont(fn_mul(s_n_mont, r_inv_mont), SECP256K1_N, N_INV); - - // Step 4: Multi-scalar multiply Q = u1*G + u2*R - // Generator G in Montgomery form - uint256 Gx_mont = to_mont(GX, MONT_R2_P, SECP256K1_P, P_INV); - uint256 Gy_mont = to_mont(GY, MONT_R2_P, SECP256K1_P, P_INV); - - ECPoint Q1 = ec_mul_affine(u1, Gx_mont, Gy_mont); - ECPoint Q2 = ec_mul_affine(u2, Rx_mont, Ry_mont); - - // Add Q1 + Q2 - ECPoint Q; - if (ec_is_infinity(Q1)) { - Q = Q2; - } else if (ec_is_infinity(Q2)) { - Q = Q1; - } else { - // Convert Q2 to affine for mixed addition - uint256 Q2x_aff, Q2y_aff; - ec_to_affine(Q2, Q2x_aff, Q2y_aff); - Q = ec_add_mixed(Q1, Q2x_aff, Q2y_aff); - } - - if (ec_is_infinity(Q)) return; - - // Step 5: Convert Q to affine, serialize as big-endian bytes - uint256 Qx_aff, Qy_aff; - ec_to_affine(Q, Qx_aff, Qy_aff); - - // Convert from Montgomery - uint256 Qx_norm = from_mont(Qx_aff, SECP256K1_P, P_INV); - uint256 Qy_norm = from_mont(Qy_aff, SECP256K1_P, P_INV); - - // Serialize Q.x || Q.y as 64 bytes big-endian - uchar pubkey[64]; - 
store_be32(Qx_norm, pubkey); - store_be32(Qy_norm, pubkey + 32); - - // Step 6: address = keccak256(pubkey)[12:] - uchar hash[32]; - keccak256_64(pubkey, hash); - - for (int i = 0; i < 20; i++) { - out.address[i] = hash[12 + i]; - } - out.valid = 1; -} diff --git a/secp256k1/gpu/wgsl/secp256k1.wgsl b/secp256k1/gpu/wgsl/secp256k1.wgsl deleted file mode 100644 index 31e4e2b..0000000 --- a/secp256k1/gpu/wgsl/secp256k1.wgsl +++ /dev/null @@ -1,695 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// secp256k1 ECDSA public key recovery (ecrecover) in WGSL. -// Matches secp256k1_recover.metal output byte-for-byte. -// 256-bit arithmetic uses 8 x u32 limbs (no native u64 in WGSL). -// -// Per thread: (r, s, v, msg_hash) -> 20-byte Ethereum address -// Algorithm: Q = r^{-1} * (s*R - e*G), address = keccak256(Q)[12:] - -// Input: [r[32], s[32], v[1], pad[3], msg_hash[32], pad[28]] = 128 bytes per sig -// Output: [address[20], valid[1], pad[11]] = 32 bytes per sig - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var outputs: array; -@group(0) @binding(2) var params: Params; - -struct Params { - num_items: u32, -} - -// ============================================================================ -// 256-bit integer as 8 x u32 (little-endian) -// ============================================================================ - -fn u256_zero() -> array { - return array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); -} - -fn u256_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0; i = i - 1) { - let ui = u32(i); - if ((*a)[ui] > (*b)[ui]) { return 1; } - if ((*a)[ui] < (*b)[ui]) { return -1; } - } - return 0; -} - -fn u256_add(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let s1 = (*a)[i] + c; - c = select(0u, 1u, s1 < (*a)[i]); 
- let s2 = s1 + (*b)[i]; - c = c + select(0u, 1u, s2 < s1); - (*r)[i] = s2; - } - return c; -} - -fn u256_sub(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var bw = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let d1 = (*a)[i] - bw; - bw = select(0u, 1u, d1 > (*a)[i]); - let d2 = d1 - (*b)[i]; - bw = bw + select(0u, 1u, d2 > d1); - (*r)[i] = d2; - } - return bw; -} - -// ============================================================================ -// secp256k1 constants (8 x u32 little-endian) -// ============================================================================ - -// Field prime p = 0xFFFFFFFF...FFFFFFFEFFFFFC2F -const SECP_P = array( - 0xFFFFFC2Fu, 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); -// Curve order n -const SECP_N = array( - 0xD0364141u, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); -// Montgomery R mod p -const MONT_R_P = array( - 0x000003D1u, 0x00000001u, 0x00000000u, 0x00000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u -); -// R^2 mod p -const MONT_R2_P = array( - 0x000E90A1u, 0x000007A2u, 0x00000001u, 0x00000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u -); -// -p^{-1} mod 2^32 -const P_INV: u32 = 0xD2253531u; -// R^2 mod n -const MONT_R2_N = array( - 0x67D7D140u, 0x896CF214u, 0x0E7CF878u, 0x741496C2u, - 0x5BCD07C6u, 0xE697F5E4u, 0x81C69BC5u, 0x9D671CD5u -); -// -n^{-1} mod 2^32 -const N_INV: u32 = 0x5588B13Fu; -// Generator G.x -const GX = array( - 0x16F81798u, 0x59F2815Bu, 0x2DCE28D9u, 0x029BFCDB, - 0xCE870B07u, 0x55A06295u, 0xF9DCBBACu, 0x79BE667Eu -); -// Generator G.y -const GY = array( - 0xFB10D4B8u, 0x9C47D08Fu, 0xA6855419u, 0xFD17B448u, - 0x0E1108A8u, 0x5DA4FBFC, 0x26A3C465u, 0x483ADA77u -); - -// ============================================================================ -// Montgomery multiplication (256-bit, 8x u32 limbs) -// ============================================================================ - 
-fn mont_reduce(t: ptr>, m: ptr>, - inv: u32, r: ptr>) { - // Extended to 17 limbs for carry - var a: array; - for (var i = 0u; i < 16u; i = i + 1u) { a[i] = (*t)[i]; } - a[16] = 0u; - - for (var i = 0u; i < 8u; i = i + 1u) { - let u = a[i] * inv; - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - // u * m[j] -> (hi, lo) - let u_lo = u & 0xFFFFu; let u_hi = u >> 16u; - let m_lo = (*m)[j] & 0xFFFFu; let m_hi = (*m)[j] >> 16u; - let ll = u_lo * m_lo; - let lh = u_lo * m_hi; - let hl = u_hi * m_lo; - let hh = u_hi * m_hi; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - - // lo += carry - let s1 = lo + carry; - hi = hi + select(0u, 1u, s1 < lo); - // a[i+j] += s1 - let s2 = a[i + j] + s1; - hi = hi + select(0u, 1u, s2 < a[i + j]); - a[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j <= 16u; j = j + 1u) { - let s = a[i + j] + carry; - carry = select(0u, 1u, s < a[i + j]); - a[i + j] = s; - if (carry == 0u) { break; } - } - } - - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = a[i + 8u]; } - - // Final subtraction if r >= m - if (a[16] != 0u || u256_cmp(r, m) >= 0) { - let _ = u256_sub(r, m, r); - } -} - -fn mont_mul(a: ptr>, b: ptr>, - m: ptr>, inv: u32, r: ptr>) { - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 8u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let al = (*a)[i] & 0xFFFFu; let ah = (*a)[i] >> 16u; - let bl = (*b)[j] & 0xFFFFu; let bh = (*b)[j] >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - let s1 = lo + carry; hi = hi + select(0u, 1u, s1 < lo); - let s2 = t[i + j] + s1; hi = hi + select(0u, 1u, s2 < t[i + j]); - t[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j < 16u; j = j + 1u) { 
- let s = t[i + j] + carry; - carry = select(0u, 1u, s < t[i + j]); - t[i + j] = s; - if (carry == 0u) { break; } - } - } - mont_reduce(&t, m, inv, r); -} - -// Field ops over p (Montgomery form) -fn fp_add(a: ptr>, b: ptr>, - r: ptr>) { - var p = SECP_P; - let c = u256_add(a, b, r); - if (c != 0u || u256_cmp(r, &p) >= 0) { - let _ = u256_sub(r, &p, r); - } -} - -fn fp_sub(a: ptr>, b: ptr>, - r: ptr>) { - var p = SECP_P; - let bw = u256_sub(a, b, r); - if (bw != 0u) { - let _ = u256_add(r, &p, r); - } -} - -fn fp_mul(a: ptr>, b: ptr>, - r: ptr>) { - var p = SECP_P; - mont_mul(a, b, &p, P_INV, r); -} - -fn fp_sqr(a: ptr>, r: ptr>) { - var p = SECP_P; - mont_mul(a, a, &p, P_INV, r); -} - -fn fn_mul(a: ptr>, b: ptr>, - r: ptr>) { - var n = SECP_N; - mont_mul(a, b, &n, N_INV, r); -} - -fn to_mont_p(a: ptr>, r: ptr>) { - var r2 = MONT_R2_P; - fp_mul(a, &r2, r); -} - -fn from_mont_p(a: ptr>, r: ptr>) { - var p = SECP_P; - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 8u; i = i + 1u) { t[i] = (*a)[i]; } - mont_reduce(&t, &p, P_INV, r); -} - -fn to_mont_n(a: ptr>, r: ptr>) { - var r2 = MONT_R2_N; - fn_mul(a, &r2, r); -} - -fn from_mont_n(a: ptr>, r: ptr>) { - var n = SECP_N; - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 8u; i = i + 1u) { t[i] = (*a)[i]; } - mont_reduce(&t, &n, N_INV, r); -} - -// Modular inversion via Fermat: a^(m-2) mod m -fn fp_inv(a: ptr>, r: ptr>) { - // p-2 little-endian u32 limbs - var exp = array( - 0xFFFFFC2Du, 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var result: array; - to_mont_p(&one, &result); - var base: array; - for (var i = 0u; i < 8u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - fp_mul(&result, &base, 
&tmp); - result = tmp; - } - var tmp2: array; - fp_sqr(&base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -fn fn_inv(a: ptr>, r: ptr>) { - // n-2 - var exp = array( - 0xD036413Fu, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var result: array; - to_mont_n(&one, &result); - var base: array; - for (var i = 0u; i < 8u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - fn_mul(&result, &base, &tmp); - result = tmp; - } - var tmp2: array; - fn_mul(&base, &base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -// ============================================================================ -// EC point operations (Jacobian, Montgomery Fp) -// Point = (x[8], y[8], z[8]) = 24 u32 words -// ============================================================================ - -struct ECPoint { - x: array, - y: array, - z: array, -} - -fn ec_identity() -> ECPoint { - var p: ECPoint; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &p.x); - p.y = p.x; - p.z = u256_zero(); - return p; -} - -fn ec_is_inf(p: ptr) -> bool { - var z = (*p).z; - return u256_is_zero(&z); -} - -fn ec_double(p: ptr, r: ptr) { - if (ec_is_inf(p)) { *r = *p; return; } - var A: array; fp_sqr(&(*p).y, &A); - var B: array; fp_mul(&(*p).x, &A, &B); - var C: array; fp_sqr(&A, &C); - // S = 4*B - var S: array; fp_add(&B, &B, &S); fp_add(&S, &S, &S); - // M = 3*X^2 (a=0) - var X2: array; fp_sqr(&(*p).x, &X2); - var X2_2: array; fp_add(&X2, &X2, &X2_2); - var M: array; fp_add(&X2_2, &X2, &M); - // X3 = M^2 - 2S - var M2: array; fp_sqr(&M, &M2); - var S2: array; fp_add(&S, &S, &S2); - var X3: array; fp_sub(&M2, &S2, &X3); - // Y3 = M*(S-X3) - 8C - var SX: array; fp_sub(&S, &X3, &SX); - var MSX: array; fp_mul(&M, &SX, &MSX); - var C2: array; fp_add(&C, &C, 
&C2); - var C4: array; fp_add(&C2, &C2, &C4); - var C8: array; fp_add(&C4, &C4, &C8); - var Y3: array; fp_sub(&MSX, &C8, &Y3); - // Z3 = 2*Y*Z - var YZ: array; fp_mul(&(*p).y, &(*p).z, &YZ); - var Z3: array; fp_add(&YZ, &YZ, &Z3); - (*r).x = X3; (*r).y = Y3; (*r).z = Z3; -} - -fn ec_add_mixed(P: ptr, Qx: ptr>, - Qy: ptr>, r: ptr) { - if (ec_is_inf(P)) { - (*r).x = *Qx; (*r).y = *Qy; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &(*r).z); - return; - } - var Z2: array; fp_sqr(&(*P).z, &Z2); - var U2: array; fp_mul(Qx, &Z2, &U2); - var Z3: array; fp_mul(&Z2, &(*P).z, &Z3); - var S2: array; fp_mul(Qy, &Z3, &S2); - var H: array; fp_sub(&U2, &(*P).x, &H); - var R: array; fp_sub(&S2, &(*P).y, &R); - - if (u256_is_zero(&H)) { - if (u256_is_zero(&R)) { ec_double(P, r); return; } - *r = ec_identity(); - return; - } - - var H2: array; fp_sqr(&H, &H2); - var H3: array; fp_mul(&H, &H2, &H3); - var U1H2: array; fp_mul(&(*P).x, &H2, &U1H2); - // X3 = R^2 - H^3 - 2*U1H2 - var R2: array; fp_sqr(&R, &R2); - var U1H2_2: array; fp_add(&U1H2, &U1H2, &U1H2_2); - var t1: array; fp_sub(&R2, &H3, &t1); - var X3: array; fp_sub(&t1, &U1H2_2, &X3); - // Y3 = R*(U1H2 - X3) - Y1*H3 - var UX: array; fp_sub(&U1H2, &X3, &UX); - var RUX: array; fp_mul(&R, &UX, &RUX); - var YH3: array; fp_mul(&(*P).y, &H3, &YH3); - var Y3: array; fp_sub(&RUX, &YH3, &Y3); - // Z3 = H * P.Z - var Zr: array; fp_mul(&H, &(*P).z, &Zr); - (*r).x = X3; (*r).y = Y3; (*r).z = Zr; -} - -fn ec_mul_affine(k: ptr>, - Px: ptr>, - Py: ptr>) -> ECPoint { - var result = ec_identity(); - for (var i = 7i; i >= 0; i = i - 1) { - for (var bit = 31i; bit >= 0; bit = bit - 1) { - var dbl: ECPoint; - ec_double(&result, &dbl); - result = dbl; - if ((((*k)[u32(i)] >> u32(bit)) & 1u) != 0u) { - var tmp: ECPoint; - ec_add_mixed(&result, Px, Py, &tmp); - result = tmp; - } - } - } - return result; -} - -fn ec_to_affine(p: ptr, ax: ptr>, - ay: ptr>) { - if (ec_is_inf(p)) { *ax = u256_zero(); *ay = u256_zero(); return; } - 
var z_inv: array; fp_inv(&(*p).z, &z_inv); - var z_inv2: array; fp_sqr(&z_inv, &z_inv2); - var z_inv3: array; fp_mul(&z_inv2, &z_inv, &z_inv3); - fp_mul(&(*p).x, &z_inv2, ax); - fp_mul(&(*p).y, &z_inv3, ay); -} - -// ============================================================================ -// Inline Keccak-256 for 64 bytes (public key -> address) -// ============================================================================ - -var kst_lo: array; -var kst_hi: array; - -const KRC_LO = array( - 0x00000001u, 0x00008082u, 0x0000808Au, 0x80008000u, - 0x0000808Bu, 0x80000001u, 0x80008081u, 0x00008009u, - 0x0000008Au, 0x00000088u, 0x80008009u, 0x8000000Au, - 0x8000808Bu, 0x0000008Bu, 0x00008089u, 0x00008003u, - 0x00008002u, 0x00000080u, 0x0000800Au, 0x8000000Au, - 0x80008081u, 0x00008080u, 0x80000001u, 0x80008008u -); -const KRC_HI = array( - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u, - 0x00000000u, 0x80000000u, 0x80000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u -); -const KPI = array( - 10u, 7u, 11u, 17u, 18u, 3u, 5u, 16u, 8u, 21u, 24u, 4u, - 15u, 23u, 19u, 13u, 12u, 2u, 20u, 14u, 22u, 9u, 6u, 1u -); -const KRHO = array( - 1u, 3u, 6u, 10u, 15u, 21u, 28u, 36u, 45u, 55u, 2u, 14u, - 27u, 41u, 56u, 8u, 25u, 43u, 62u, 18u, 39u, 61u, 20u, 44u -); - -fn krotl64(lo: u32, hi: u32, n: u32) -> vec2 { - if (n == 0u) { return vec2(lo, hi); } - if (n == 32u) { return vec2(hi, lo); } - if (n < 32u) { - return vec2((lo << n) | (hi >> (32u - n)), (hi << n) | (lo >> (32u - n))); - } - let m = n - 32u; - return vec2((hi << m) | (lo >> (32u - m)), (lo << m) | (hi >> (32u - m))); -} - -fn keccak_f() { - for (var round = 0u; round < 24u; round = round + 1u) { - var c_lo: array; var c_hi: array; - for (var x = 0u; x < 5u; x = x + 1u) { - c_lo[x] = kst_lo[x] ^ kst_lo[x+5u] ^ kst_lo[x+10u] ^ 
kst_lo[x+15u] ^ kst_lo[x+20u]; - c_hi[x] = kst_hi[x] ^ kst_hi[x+5u] ^ kst_hi[x+10u] ^ kst_hi[x+15u] ^ kst_hi[x+20u]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - let r = krotl64(c_lo[(x+1u)%5u], c_hi[(x+1u)%5u], 1u); - let d_lo = c_lo[(x+4u)%5u] ^ r.x; - let d_hi = c_hi[(x+4u)%5u] ^ r.y; - for (var y = 0u; y < 5u; y = y + 1u) { - let idx = x + 5u * y; - kst_lo[idx] = kst_lo[idx] ^ d_lo; - kst_hi[idx] = kst_hi[idx] ^ d_hi; - } - } - var t_lo = kst_lo[1u]; var t_hi = kst_hi[1u]; - for (var i = 0u; i < 24u; i = i + 1u) { - let dst = KPI[i]; - let tmp_lo = kst_lo[dst]; let tmp_hi = kst_hi[dst]; - let r = krotl64(t_lo, t_hi, KRHO[i]); - kst_lo[dst] = r.x; kst_hi[dst] = r.y; - t_lo = tmp_lo; t_hi = tmp_hi; - } - for (var y = 0u; y < 5u; y = y + 1u) { - var rl: array; var rh: array; - for (var x = 0u; x < 5u; x = x + 1u) { - rl[x] = kst_lo[x + 5u*y]; rh[x] = kst_hi[x + 5u*y]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - kst_lo[x+5u*y] = rl[x] ^ ((~rl[(x+1u)%5u]) & rl[(x+2u)%5u]); - kst_hi[x+5u*y] = rh[x] ^ ((~rh[(x+1u)%5u]) & rh[(x+2u)%5u]); - } - } - kst_lo[0] = kst_lo[0] ^ KRC_LO[round]; - kst_hi[0] = kst_hi[0] ^ KRC_HI[round]; - } -} - -fn keccak256_64(data: ptr>, hash: ptr>) { - for (var i = 0u; i < 25u; i = i + 1u) { kst_lo[i] = 0u; kst_hi[i] = 0u; } - // Absorb 64 bytes = 8 lanes (each lane = 8 bytes = 2 u32 words) - for (var w = 0u; w < 8u; w = w + 1u) { - kst_lo[w] = kst_lo[w] ^ (*data)[w * 2u]; - kst_hi[w] = kst_hi[w] ^ (*data)[w * 2u + 1u]; - } - // Keccak padding: byte 64 = 0x01, byte 135 = 0x80 - kst_lo[8] = kst_lo[8] ^ 0x01u; - kst_hi[16] = kst_hi[16] ^ 0x80000000u; - keccak_f(); - for (var w = 0u; w < 4u; w = w + 1u) { - (*hash)[w * 2u] = kst_lo[w]; - (*hash)[w * 2u + 1u] = kst_hi[w]; - } -} - -// ============================================================================ -// Load/store helpers (big-endian 32 bytes <-> u256 little-endian u32 limbs) -// ============================================================================ - -fn 
load_be32(word_base: u32) -> array { - // Input is 32 bytes = 8 u32 words in the inputs array (byte-packed) - // Stored as big-endian in the input. We need to reverse byte order within words - // and reverse word order for little-endian limbs. - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - let w = inputs[word_base + 7u - i]; - // Byte-swap u32 (big-endian to little-endian) - r[i] = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - } - return r; -} - -// ============================================================================ -// Main kernel -// ============================================================================ - -@compute @workgroup_size(256) -fn secp256k1_ecrecover(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.num_items) { return; } - - // Clear output - let out_base = tid * 8u; // 32 bytes = 8 u32 - for (var i = 0u; i < 8u; i = i + 1u) { outputs[out_base + i] = 0u; } - - // Load signature: 128 bytes = 32 u32 per sig - let in_base = tid * 32u; - var r = load_be32(in_base); // r: bytes 0..31 - var s = load_be32(in_base + 8u); // s: bytes 32..63 - let v_byte = (inputs[in_base + 16u]) & 0xFFu; // v: byte 64 - var e = load_be32(in_base + 17u); // msg_hash: bytes 68..99 - - var v = v_byte; - if (v >= 27u) { v = v - 27u; } - if (v >= 2u) { v = v % 2u; } - - // Validate r, s in [1, n-1] - var n = SECP_N; - if (u256_is_zero(&r) || u256_cmp(&r, &n) >= 0) { return; } - if (u256_is_zero(&s) || u256_cmp(&s, &n) >= 0) { return; } - if (v > 1u) { return; } - - // Decompress r -> R = (r, y) - var r_mont: array; to_mont_p(&r, &r_mont); - var r2: array; fp_sqr(&r_mont, &r2); - var r3: array; fp_mul(&r2, &r_mont, &r3); - var seven = array(7u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var seven_mont: array; to_mont_p(&seven, &seven_mont); - var y2: array; fp_add(&r3, &seven_mont, &y2); - - // sqrt via a^((p+1)/4) since p = 3 mod 4 - var exp_sqrt = array( - 0xBFFFFF0Cu, 
0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0x3FFFFFFFu - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var y_mont: array; to_mont_p(&one, &y_mont); - var base_y = y2; - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp_sqrt[i] >> bit) & 1u) != 0u) { - var tmp: array; - fp_mul(&y_mont, &base_y, &tmp); - y_mont = tmp; - } - var tmp2: array; - fp_sqr(&base_y, &tmp2); - base_y = tmp2; - } - } - - // Verify sqrt: y^2 == y2 - var check: array; fp_sqr(&y_mont, &check); - if (u256_cmp(&check, &y2) != 0) { return; } - - // Select y parity - var y_normal: array; from_mont_p(&y_mont, &y_normal); - let y_is_odd = (y_normal[0] & 1u) != 0u; - if ((v == 0u && y_is_odd) || (v == 1u && !y_is_odd)) { - var zero_val = u256_zero(); - fp_sub(&zero_val, &y_mont, &y_mont); - } - - // r_inv = r^{-1} mod n - var r_n_mont: array; to_mont_n(&r, &r_n_mont); - var r_inv_mont: array; fn_inv(&r_n_mont, &r_inv_mont); - - // u1 = -(e * r_inv) mod n, u2 = s * r_inv mod n - var e_n_mont: array; to_mont_n(&e, &e_n_mont); - var s_n_mont: array; to_mont_n(&s, &s_n_mont); - - var u1_mont: array; fn_mul(&e_n_mont, &r_inv_mont, &u1_mont); - var u1: array; from_mont_n(&u1_mont, &u1); - if (!u256_is_zero(&u1)) { - var nn = SECP_N; - let _ = u256_sub(&nn, &u1, &u1); - } - - var u2_mont: array; fn_mul(&s_n_mont, &r_inv_mont, &u2_mont); - var u2: array; from_mont_n(&u2_mont, &u2); - - // Q = u1*G + u2*R - var Gx_mont: array; var gx = GX; to_mont_p(&gx, &Gx_mont); - var Gy_mont: array; var gy = GY; to_mont_p(&gy, &Gy_mont); - - var Q1 = ec_mul_affine(&u1, &Gx_mont, &Gy_mont); - var Q2 = ec_mul_affine(&u2, &r_mont, &y_mont); - - // Add Q1 + Q2 - var Q: ECPoint; - if (ec_is_inf(&Q1)) { - Q = Q2; - } else if (ec_is_inf(&Q2)) { - Q = Q1; - } else { - var Q2x_aff: array; var Q2y_aff: array; - ec_to_affine(&Q2, &Q2x_aff, &Q2y_aff); - ec_add_mixed(&Q1, &Q2x_aff, &Q2y_aff, &Q); - } - - if (ec_is_inf(&Q)) { return; } - - 
var Qx_aff: array; var Qy_aff: array; - ec_to_affine(&Q, &Qx_aff, &Qy_aff); - var Qx_norm: array; from_mont_p(&Qx_aff, &Qx_norm); - var Qy_norm: array; from_mont_p(&Qy_aff, &Qy_norm); - - // Serialize Q.x || Q.y as 16 u32 words (big-endian bytes within each 32-byte half) - var pubkey: array; - for (var i = 0u; i < 8u; i = i + 1u) { - let w = Qx_norm[7u - i]; - pubkey[i] = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - } - for (var i = 0u; i < 8u; i = i + 1u) { - let w = Qy_norm[7u - i]; - pubkey[8u + i] = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - } - - // address = keccak256(pubkey)[12:] - var hash: array; - keccak256_64(&pubkey, &hash); - - // Output: address (bytes 12-31 of hash) = last 20 bytes - // hash is 32 bytes = 8 u32 words. Bytes 12..31 = words 3..7 (but byte offset 12 = word 3 byte 0) - // Store as 5 u32 words at output (20 bytes), then valid byte - outputs[out_base] = hash[3]; - outputs[out_base + 1u] = hash[4]; - outputs[out_base + 2u] = hash[5]; - outputs[out_base + 3u] = hash[6]; - outputs[out_base + 4u] = hash[7]; - // valid byte at output byte 20 = word 5 - outputs[out_base + 5u] = 1u; -} diff --git a/secp256k1/gpu/wgsl/secp256k1_batch_inv.wgsl b/secp256k1/gpu/wgsl/secp256k1_batch_inv.wgsl deleted file mode 100644 index c1587d4..0000000 --- a/secp256k1/gpu/wgsl/secp256k1_batch_inv.wgsl +++ /dev/null @@ -1,393 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Montgomery batch inversion for secp256k1 -- WGSL port of -// secp256k1_batch_inv.metal. Output byte-equal to: -// * cpp/batch_inv.hpp (CPU canonical body) -// * gpu/metal/secp256k1_batch_inv.metal (Metal kernel) -// * gpu/cuda/secp256k1_batch_inv.cu (CUDA kernel) -// -// WGSL has no native u64; each 64-bit limb is represented as two u32 halves -// stored little-endian (lo,hi) in storage. 
Inside the kernel we lift to a -// pair of u32 and implement 64x64->128 multiply via four 32x32->64 partials. -// This matches the carry sequence in mont_mul exactly. -// -// Single-thread workgroup keeps byte-equal determinism with the CPU body. - -// 4 x u64 = 8 x u32. Layout: limbs[0].lo, limbs[0].hi, limbs[1].lo, limbs[1].hi, ... -struct U256 { w: array; }; - -@group(0) @binding(0) var in_buf: array; -@group(0) @binding(1) var out_buf: array; -@group(0) @binding(2) var cfg: vec4; -// cfg.x = n, cfg.y = kind (0 = Fp, 1 = Fn) - -// ---------- secp256k1 constants (limb little-endian, u32 halves) ---------- - -fn P_MOD() -> U256 { - return U256(array( - 0xFFFFFC2Fu, 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - )); -} -fn N_MOD() -> U256 { - return U256(array( - 0xD0364141u, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - )); -} -fn ONE_MONT_P() -> U256 { - return U256(array( - 0x000003D1u, 0x00000001u, 0u, 0u, 0u, 0u, 0u, 0u, - )); -} -fn R2_N() -> U256 { - return U256(array( - 0x67D7D140u, 0x896CF214u, 0x0E7CF878u, 0x741496C2u, - 0x5BCD07C6u, 0xE697F5E4u, 0x81C69BC5u, 0x9D671CD5u, - )); -} -fn ONE() -> U256 { - return U256(array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u)); -} - -// P_INV = 0xD838091DD2253531 -fn P_INV_LO() -> u32 { return 0xD2253531u; } -fn P_INV_HI() -> u32 { return 0xD838091Du; } -// N_INV = 0x4B0DFF665588B13F -fn N_INV_LO() -> u32 { return 0x5588B13Fu; } -fn N_INV_HI() -> u32 { return 0x4B0DFF66u; } - -// p - 2 -fn P_M2() -> U256 { - return U256(array( - 0xFFFFFC2Du, 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - )); -} -// n - 2 -fn N_M2() -> U256 { - return U256(array( - 0xD036413Fu, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - )); -} - -// ---------- 64-bit limb load/store helpers ---------- - -struct U64 { lo: u32, hi: u32 }; - -fn limb(a: U256, i: u32) -> U64 
{ - return U64(a.w[i * 2u], a.w[i * 2u + 1u]); -} -fn set_limb(a: ptr, i: u32, v: U64) { - (*a).w[i * 2u] = v.lo; - (*a).w[i * 2u + 1u] = v.hi; -} - -// ---------- 64-bit primitives ---------- - -fn u64_add(a: U64, b: U64) -> U64 { - let lo = a.lo + b.lo; - var carry: u32 = 0u; - if (lo < a.lo) { carry = 1u; } - let hi = a.hi + b.hi + carry; - return U64(lo, hi); -} - -// Returns (sum, carry-out) for a + b + cin -fn u64_addc(a: U64, b: U64, cin: u32) -> array { - let s1lo = a.lo + b.lo; - var c1: u32 = 0u; - if (s1lo < a.lo) { c1 = 1u; } - let s1hi = a.hi + b.hi + c1; - var co1: u32 = 0u; - // detect overflow on hi: s1hi < a.hi (when c1 == 0) or s1hi <= a.hi (when c1 == 1) - if (c1 == 0u) { - if (s1hi < a.hi) { co1 = 1u; } - } else { - if (s1hi <= a.hi) { co1 = 1u; } - } - // add cin (a u64 with hi = 0) - let s2lo = s1lo + cin; - var c2: u32 = 0u; - if (s2lo < s1lo) { c2 = 1u; } - let s2hi = s1hi + c2; - var co2: u32 = 0u; - if (s2hi < s1hi) { co2 = 1u; } - let co = co1 + co2; - return array(U64(s2lo, s2hi), U64(co, 0u)); -} - -// Returns (diff, borrow-out) for a - b - bin -fn u64_subb(a: U64, b: U64, bin: u32) -> array { - let d1lo = a.lo - b.lo; - var br1: u32 = 0u; - if (d1lo > a.lo) { br1 = 1u; } - let d1hi = a.hi - b.hi - br1; - var bo1: u32 = 0u; - if (br1 == 0u) { - if (d1hi > a.hi) { bo1 = 1u; } - } else { - if (d1hi >= a.hi) { bo1 = 1u; } - } - let d2lo = d1lo - bin; - var br2: u32 = 0u; - if (d2lo > d1lo) { br2 = 1u; } - let d2hi = d1hi - br2; - var bo2: u32 = 0u; - if (d2hi > d1hi) { bo2 = 1u; } - let bo = bo1 + bo2; - return array(U64(d2lo, d2hi), U64(bo, 0u)); -} - -// 64x64 -> (lo, hi) via four 32x32 partials. -fn mul64(a: U64, b: U64) -> array { - let al: u32 = a.lo; - let ah: u32 = a.hi; - let bl: u32 = b.lo; - let bh: u32 = b.hi; - - // 32x32 helpers via splitting into 16-bit halves to stay within u32 ops. - // We implement low * b, high * b separately and combine via 64-bit adds. - // Easier: lift each pair to u64-via-(lo,hi) arithmetic. 
- // ll = al * bl (64-bit) - let ll = mul32(al, bl); - // lh = al * bh - let lh = mul32(al, bh); - // hl = ah * bl - let hl = mul32(ah, bl); - // hh = ah * bh - let hh = mul32(ah, bh); - - // mid = (ll >> 32) + (lh & 0xFFFFFFFF) + (hl & 0xFFFFFFFF) - let ll_hi = U64(ll.hi, 0u); - let lh_lo = U64(lh.lo, 0u); - let hl_lo = U64(hl.lo, 0u); - let mid1 = u64_add(ll_hi, lh_lo); - let mid = u64_add(mid1, hl_lo); - - // lo = (ll & 0xFFFFFFFF) | (mid << 32) - let lo = U64(ll.lo, mid.lo); - // hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32) - let hh_full = hh; - let lh_hi = U64(lh.hi, 0u); - let hl_hi = U64(hl.hi, 0u); - let mid_hi = U64(mid.hi, 0u); - let s1 = u64_add(hh_full, lh_hi); - let s2 = u64_add(s1, hl_hi); - let hi = u64_add(s2, mid_hi); - - return array(lo, hi); -} - -// 32x32 -> (lo, hi) packed as U64. -fn mul32(a: u32, b: u32) -> U64 { - let alo = a & 0xFFFFu; - let ahi = a >> 16u; - let blo = b & 0xFFFFu; - let bhi = b >> 16u; - let p00 = alo * blo; - let p01 = alo * bhi; - let p10 = ahi * blo; - let p11 = ahi * bhi; - // p00 + ((p01 + p10) << 16) + (p11 << 32) - let mid_a = p01 + (p00 >> 16u); - var mid_carry: u32 = 0u; - if (mid_a < p01) { mid_carry = 1u; } - let mid_b = mid_a + p10; - var mid_b_carry: u32 = 0u; - if (mid_b < mid_a) { mid_b_carry = 1u; } - let lo = (p00 & 0xFFFFu) | (mid_b << 16u); - let hi = p11 + (mid_b >> 16u) + (mid_carry << 16u) + (mid_b_carry << 16u); - return U64(lo, hi); -} - -// 64-bit truncating multiply (lower 64 bits only). 
-fn mul64_lo(a: U64, b: U64) -> U64 { - let prod = mul64(a, b); - return prod[0]; -} - -// ---------- 256-bit subtraction ---------- - -fn u256_cmp(a: U256, b: U256) -> i32 { - var i: i32 = 3; - loop { - if (i < 0) { break; } - let ai = limb(a, u32(i)); - let bi = limb(b, u32(i)); - if (ai.hi < bi.hi) { return -1; } - if (ai.hi > bi.hi) { return 1; } - if (ai.lo < bi.lo) { return -1; } - if (ai.lo > bi.lo) { return 1; } - i = i - 1; - } - return 0; -} - -fn sub_256(a: U256, b: U256) -> U256 { - var r: U256; - var br: u32 = 0u; - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - let res = u64_subb(limb(a, i), limb(b, i), br); - set_limb(&r, i, res[0]); - br = res[1].lo; - } - return r; -} - -// ---------- CIOS Montgomery multiplication ---------- - -fn mont_mul(a: U256, b: U256, m: U256, m_inv: U64) -> U256 { - var t: array; - for (var k: u32 = 0u; k < 6u; k = k + 1u) { t[k] = U64(0u, 0u); } - - for (var i: u32 = 0u; i < 4u; i = i + 1u) { - var carry = U64(0u, 0u); - let bi = limb(b, i); - for (var j: u32 = 0u; j < 4u; j = j + 1u) { - let aj = limb(a, j); - let prod = mul64(aj, bi); - let lo = prod[0]; - let hi = prod[1]; - let sum1 = u64_addc(t[j], lo, 0u); - let s1 = sum1[0]; - let c1 = sum1[1].lo; - let sum2 = u64_addc(s1, carry, 0u); - t[j] = sum2[0]; - let c2 = sum2[1].lo; - let new_carry_step = u64_addc(hi, U64(c1 + c2, 0u), 0u); - carry = new_carry_step[0]; - } - let s4 = u64_addc(t[4], carry, 0u); - t[4] = s4[0]; - t[5] = u64_add(t[5], U64(s4[1].lo, 0u)); - - let u = mul64_lo(t[0], m_inv); - carry = U64(0u, 0u); - for (var j: u32 = 0u; j < 4u; j = j + 1u) { - let mj = limb(m, j); - let prod = mul64(u, mj); - let lo = prod[0]; - let hi = prod[1]; - let sum1 = u64_addc(t[j], lo, 0u); - let s1 = sum1[0]; - let c1 = sum1[1].lo; - let sum2 = u64_addc(s1, carry, 0u); - t[j] = sum2[0]; - let c2 = sum2[1].lo; - let new_carry_step = u64_addc(hi, U64(c1 + c2, 0u), 0u); - carry = new_carry_step[0]; - } - let s4b = u64_addc(t[4], carry, 0u); - t[4] = s4b[0]; - t[5] = 
u64_add(t[5], U64(s4b[1].lo, 0u)); - - // shift right by 64 bits (drop t[0]) - for (var j: u32 = 0u; j < 5u; j = j + 1u) { t[j] = t[j + 1u]; } - t[5] = U64(0u, 0u); - } - - var r: U256; - set_limb(&r, 0u, t[0]); - set_limb(&r, 1u, t[1]); - set_limb(&r, 2u, t[2]); - set_limb(&r, 3u, t[3]); - - let need_sub = (t[4].lo != 0u) || (t[4].hi != 0u) || (u256_cmp(r, m) >= 0); - if (need_sub) { - r = sub_256(r, m); - } - return r; -} - -fn fp_mul(a: U256, b: U256) -> U256 { return mont_mul(a, b, P_MOD(), U64(P_INV_LO(), P_INV_HI())); } -fn fn_mul(a: U256, b: U256) -> U256 { return mont_mul(a, b, N_MOD(), U64(N_INV_LO(), N_INV_HI())); } -fn fp_sqr(a: U256) -> U256 { return fp_mul(a, a); } -fn fn_sqr(a: U256) -> U256 { return fn_mul(a, a); } - -fn fp_pow(a: U256, e: U256) -> U256 { - var result = ONE_MONT_P(); - var base = a; - for (var lj: u32 = 0u; lj < 4u; lj = lj + 1u) { - let w = limb(e, lj); - // low half - for (var bit: u32 = 0u; bit < 32u; bit = bit + 1u) { - if (((w.lo >> bit) & 1u) != 0u) { result = fp_mul(result, base); } - base = fp_sqr(base); - } - // high half - for (var bit: u32 = 0u; bit < 32u; bit = bit + 1u) { - if (((w.hi >> bit) & 1u) != 0u) { result = fp_mul(result, base); } - base = fp_sqr(base); - } - } - return result; -} -fn fp_inv(a: U256) -> U256 { return fp_pow(a, P_M2()); } - -fn fn_pow(a: U256, e: U256) -> U256 { - var result = mont_mul(ONE(), R2_N(), N_MOD(), U64(N_INV_LO(), N_INV_HI())); - var base = a; - for (var lj: u32 = 0u; lj < 4u; lj = lj + 1u) { - let w = limb(e, lj); - for (var bit: u32 = 0u; bit < 32u; bit = bit + 1u) { - if (((w.lo >> bit) & 1u) != 0u) { result = fn_mul(result, base); } - base = fn_sqr(base); - } - for (var bit: u32 = 0u; bit < 32u; bit = bit + 1u) { - if (((w.hi >> bit) & 1u) != 0u) { result = fn_mul(result, base); } - base = fn_sqr(base); - } - } - return result; -} -fn fn_inv(a: U256) -> U256 { return fn_pow(a, N_M2()); } - -// ---------- Batch inversion kernels (single-thread workgroup) ---------- - -@compute 
@workgroup_size(1) -fn secp256k1_batch_inv_fp(@builtin(global_invocation_id) gid: vec3) { - if (gid.x != 0u) { return; } - let n = cfg.x; - if (n == 0u) { return; } - - out_buf[0] = in_buf[0]; - for (var i: u32 = 1u; i < n; i = i + 1u) { - out_buf[i] = fp_mul(out_buf[i - 1u], in_buf[i]); - } - var inv = fp_inv(out_buf[n - 1u]); - var k: u32 = n; - loop { - if (k <= 1u) { break; } - let i = k - 1u; - let t = fp_mul(inv, out_buf[i - 1u]); - inv = fp_mul(inv, in_buf[i]); - out_buf[i] = t; - k = k - 1u; - } - out_buf[0] = inv; -} - -@compute @workgroup_size(1) -fn secp256k1_batch_inv_fn(@builtin(global_invocation_id) gid: vec3) { - if (gid.x != 0u) { return; } - let n = cfg.x; - if (n == 0u) { return; } - - out_buf[0] = in_buf[0]; - for (var i: u32 = 1u; i < n; i = i + 1u) { - out_buf[i] = fn_mul(out_buf[i - 1u], in_buf[i]); - } - var inv = fn_inv(out_buf[n - 1u]); - var k: u32 = n; - loop { - if (k <= 1u) { break; } - let i = k - 1u; - let t = fn_mul(inv, out_buf[i - 1u]); - inv = fn_mul(inv, in_buf[i]); - out_buf[i] = t; - k = k - 1u; - } - out_buf[0] = inv; -} diff --git a/secp256k1/gpu/wgsl/secp256k1_batch_inv_driver.cpp b/secp256k1/gpu/wgsl/secp256k1_batch_inv_driver.cpp deleted file mode 100644 index 5677b18..0000000 --- a/secp256k1/gpu/wgsl/secp256k1_batch_inv_driver.cpp +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WGSL host driver for Stage A (Montgomery batch inversion). Mirrors the -// shape of gpu/metal/secp256k1_batch_inv_driver.mm and gpu/cuda/ -// secp256k1_batch_inv_driver.cu so the test harness can swap drivers. -// -// Real WebGPU dispatch is built only when CRYPTO_HAS_DAWN is defined -// (CI runner has Dawn / wgpu-native installed). Without it, the entry point -// returns a NOTIMPL sentinel so the umbrella library still links and the -// determinism tests skip the WGSL leg. -// -// Entry: cuda-shaped signature (kind = 0 Fp, 1 Fn). 
- -#include "crypto.h" - -#include -#include -#include - -#ifdef CRYPTO_HAS_DAWN -#include -#include -#include -#include -#include -#endif - -// Use the public C-ABI return codes from crypto.h; no local constants -// (Lux brand-neutral C-ABI per LP-137). -namespace { - -#ifdef CRYPTO_HAS_DAWN - -// Load WGSL source from disk (CMake places kernel next to the build artifacts). -std::string load_wgsl(const char* path) { - std::ifstream f(path); - if (!f) return {}; - std::stringstream ss; ss << f.rdbuf(); - return ss.str(); -} - -int dispatch(const uint8_t* in_mont, size_t n, uint8_t* out_mont, - int kind, const char* wgsl_path) { - if (n == 0) return CRYPTO_OK; - if (!in_mont || !out_mont || !wgsl_path) return CRYPTO_ERR_INPUT; - if (kind != 0 && kind != 1) return CRYPTO_ERR_INPUT; - - const std::string src = load_wgsl(wgsl_path); - if (src.empty()) return -3; - - // ---- Instance + adapter + device --------------------------------------- - wgpu::InstanceDescriptor idesc{}; - wgpu::Instance instance = wgpu::CreateInstance(&idesc); - if (!instance) return -4; - - wgpu::RequestAdapterOptions aopts{}; - aopts.powerPreference = wgpu::PowerPreference::HighPerformance; - wgpu::Adapter adapter = nullptr; - instance.RequestAdapter(&aopts, - [](WGPURequestAdapterStatus status, WGPUAdapter ad, const char*, void* ud) { - if (status == WGPURequestAdapterStatus_Success) { - *static_cast(ud) = wgpu::Adapter::Acquire(ad); - } - }, &adapter); - if (!adapter) return -5; - - wgpu::Device device = nullptr; - wgpu::DeviceDescriptor ddesc{}; - adapter.RequestDevice(&ddesc, - [](WGPURequestDeviceStatus status, WGPUDevice d, const char*, void* ud) { - if (status == WGPURequestDeviceStatus_Success) { - *static_cast(ud) = wgpu::Device::Acquire(d); - } - }, &device); - if (!device) return -6; - - wgpu::Queue queue = device.GetQueue(); - - // ---- Shader module ------------------------------------------------------ - wgpu::ShaderSourceWGSL wgsl{}; - wgsl.code = src.c_str(); - 
wgpu::ShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl; - wgpu::ShaderModule shader = device.CreateShaderModule(&smd); - if (!shader) return -7; - - // ---- Buffers ------------------------------------------------------------ - const size_t bytes = n * 32; - - wgpu::BufferDescriptor in_bd{}; - in_bd.size = bytes; - in_bd.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopyDst; - in_bd.mappedAtCreation = true; - wgpu::Buffer in_buf = device.CreateBuffer(&in_bd); - std::memcpy(in_buf.GetMappedRange(), in_mont, bytes); - in_buf.Unmap(); - - wgpu::BufferDescriptor out_bd{}; - out_bd.size = bytes; - out_bd.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc; - wgpu::Buffer out_buf = device.CreateBuffer(&out_bd); - - wgpu::BufferDescriptor cfg_bd{}; - cfg_bd.size = 16; // vec4 - cfg_bd.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst; - cfg_bd.mappedAtCreation = true; - wgpu::Buffer cfg_buf = device.CreateBuffer(&cfg_bd); - uint32_t cfg[4] = { (uint32_t)n, (uint32_t)kind, 0u, 0u }; - std::memcpy(cfg_buf.GetMappedRange(), cfg, sizeof(cfg)); - cfg_buf.Unmap(); - - wgpu::BufferDescriptor read_bd{}; - read_bd.size = bytes; - read_bd.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; - wgpu::Buffer read_buf = device.CreateBuffer(&read_bd); - - // ---- Pipeline ----------------------------------------------------------- - wgpu::ComputePipelineDescriptor pd{}; - pd.compute.module = shader; - pd.compute.entryPoint = (kind == 0) ? 
"secp256k1_batch_inv_fp" - : "secp256k1_batch_inv_fn"; - wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&pd); - if (!pipeline) return -8; - - // ---- Bind group --------------------------------------------------------- - wgpu::BindGroupEntry entries[3] = {}; - entries[0].binding = 0; entries[0].buffer = in_buf; entries[0].size = bytes; - entries[1].binding = 1; entries[1].buffer = out_buf; entries[1].size = bytes; - entries[2].binding = 2; entries[2].buffer = cfg_buf; entries[2].size = 16; - - wgpu::BindGroupDescriptor bgd{}; - bgd.layout = pipeline.GetBindGroupLayout(0); - bgd.entryCount = 3; - bgd.entries = entries; - wgpu::BindGroup bg = device.CreateBindGroup(&bgd); - - // ---- Encode + dispatch (single thread) ---------------------------------- - wgpu::CommandEncoder enc = device.CreateCommandEncoder(); - { - wgpu::ComputePassEncoder pass = enc.BeginComputePass(); - pass.SetPipeline(pipeline); - pass.SetBindGroup(0, bg); - pass.DispatchWorkgroups(1, 1, 1); - pass.End(); - } - enc.CopyBufferToBuffer(out_buf, 0, read_buf, 0, bytes); - wgpu::CommandBuffer cmd = enc.Finish(); - queue.Submit(1, &cmd); - - // ---- Map + read -------------------------------------------------------- - bool done = false; - bool ok = false; - read_buf.MapAsync(wgpu::MapMode::Read, 0, bytes, - [](WGPUBufferMapAsyncStatus s, void* ud) { - auto* p = static_cast*>(ud); - *p->first = true; - *p->second = (s == WGPUBufferMapAsyncStatus_Success); - }, new std::pair(&done, &ok)); - while (!done) { device.Tick(); } - if (!ok) return -9; - - const void* mapped = read_buf.GetConstMappedRange(); - std::memcpy(out_mont, mapped, bytes); - read_buf.Unmap(); - return CRYPTO_OK; -} - -#endif // CRYPTO_HAS_DAWN - -} // namespace - -extern "C" int wgsl_secp256k1_batch_inv( - const uint8_t* in_mont, // n * 32 bytes (Mont-form, limb little-endian) - size_t n, - uint8_t* out_mont, // n * 32 bytes - int kind, // 0 = Fp, 1 = Fn - const char* wgsl_path) // path to secp256k1_batch_inv.wgsl -{ 
-#ifdef CRYPTO_HAS_DAWN - return dispatch(in_mont, n, out_mont, kind, wgsl_path); -#else - (void)in_mont; (void)n; (void)out_mont; (void)kind; (void)wgsl_path; - return CRYPTO_ERR_NOTIMPL; -#endif -} diff --git a/secp256k1/gpu/wgsl/secp256k1_recover.wgsl b/secp256k1/gpu/wgsl/secp256k1_recover.wgsl deleted file mode 100644 index 31e4e2b..0000000 --- a/secp256k1/gpu/wgsl/secp256k1_recover.wgsl +++ /dev/null @@ -1,695 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// secp256k1 ECDSA public key recovery (ecrecover) in WGSL. -// Matches secp256k1_recover.metal output byte-for-byte. -// 256-bit arithmetic uses 8 x u32 limbs (no native u64 in WGSL). -// -// Per thread: (r, s, v, msg_hash) -> 20-byte Ethereum address -// Algorithm: Q = r^{-1} * (s*R - e*G), address = keccak256(Q)[12:] - -// Input: [r[32], s[32], v[1], pad[3], msg_hash[32], pad[28]] = 128 bytes per sig -// Output: [address[20], valid[1], pad[11]] = 32 bytes per sig - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var outputs: array; -@group(0) @binding(2) var params: Params; - -struct Params { - num_items: u32, -} - -// ============================================================================ -// 256-bit integer as 8 x u32 (little-endian) -// ============================================================================ - -fn u256_zero() -> array { - return array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); -} - -fn u256_is_zero(a: ptr>) -> bool { - var acc = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { acc = acc | (*a)[i]; } - return acc == 0u; -} - -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0; i = i - 1) { - let ui = u32(i); - if ((*a)[ui] > (*b)[ui]) { return 1; } - if ((*a)[ui] < (*b)[ui]) { return -1; } - } - return 0; -} - -fn u256_add(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var c = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let s1 = (*a)[i] + c; - c = select(0u, 1u, s1 < (*a)[i]); - let s2 = s1 + (*b)[i]; - c 
= c + select(0u, 1u, s2 < s1); - (*r)[i] = s2; - } - return c; -} - -fn u256_sub(a: ptr>, b: ptr>, - r: ptr>) -> u32 { - var bw = 0u; - for (var i = 0u; i < 8u; i = i + 1u) { - let d1 = (*a)[i] - bw; - bw = select(0u, 1u, d1 > (*a)[i]); - let d2 = d1 - (*b)[i]; - bw = bw + select(0u, 1u, d2 > d1); - (*r)[i] = d2; - } - return bw; -} - -// ============================================================================ -// secp256k1 constants (8 x u32 little-endian) -// ============================================================================ - -// Field prime p = 0xFFFFFFFF...FFFFFFFEFFFFFC2F -const SECP_P = array( - 0xFFFFFC2Fu, 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); -// Curve order n -const SECP_N = array( - 0xD0364141u, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu -); -// Montgomery R mod p -const MONT_R_P = array( - 0x000003D1u, 0x00000001u, 0x00000000u, 0x00000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u -); -// R^2 mod p -const MONT_R2_P = array( - 0x000E90A1u, 0x000007A2u, 0x00000001u, 0x00000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u -); -// -p^{-1} mod 2^32 -const P_INV: u32 = 0xD2253531u; -// R^2 mod n -const MONT_R2_N = array( - 0x67D7D140u, 0x896CF214u, 0x0E7CF878u, 0x741496C2u, - 0x5BCD07C6u, 0xE697F5E4u, 0x81C69BC5u, 0x9D671CD5u -); -// -n^{-1} mod 2^32 -const N_INV: u32 = 0x5588B13Fu; -// Generator G.x -const GX = array( - 0x16F81798u, 0x59F2815Bu, 0x2DCE28D9u, 0x029BFCDB, - 0xCE870B07u, 0x55A06295u, 0xF9DCBBACu, 0x79BE667Eu -); -// Generator G.y -const GY = array( - 0xFB10D4B8u, 0x9C47D08Fu, 0xA6855419u, 0xFD17B448u, - 0x0E1108A8u, 0x5DA4FBFC, 0x26A3C465u, 0x483ADA77u -); - -// ============================================================================ -// Montgomery multiplication (256-bit, 8x u32 limbs) -// ============================================================================ - -fn mont_reduce(t: ptr>, m: 
ptr>, - inv: u32, r: ptr>) { - // Extended to 17 limbs for carry - var a: array; - for (var i = 0u; i < 16u; i = i + 1u) { a[i] = (*t)[i]; } - a[16] = 0u; - - for (var i = 0u; i < 8u; i = i + 1u) { - let u = a[i] * inv; - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - // u * m[j] -> (hi, lo) - let u_lo = u & 0xFFFFu; let u_hi = u >> 16u; - let m_lo = (*m)[j] & 0xFFFFu; let m_hi = (*m)[j] >> 16u; - let ll = u_lo * m_lo; - let lh = u_lo * m_hi; - let hl = u_hi * m_lo; - let hh = u_hi * m_hi; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - - // lo += carry - let s1 = lo + carry; - hi = hi + select(0u, 1u, s1 < lo); - // a[i+j] += s1 - let s2 = a[i + j] + s1; - hi = hi + select(0u, 1u, s2 < a[i + j]); - a[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j <= 16u; j = j + 1u) { - let s = a[i + j] + carry; - carry = select(0u, 1u, s < a[i + j]); - a[i + j] = s; - if (carry == 0u) { break; } - } - } - - for (var i = 0u; i < 8u; i = i + 1u) { (*r)[i] = a[i + 8u]; } - - // Final subtraction if r >= m - if (a[16] != 0u || u256_cmp(r, m) >= 0) { - let _ = u256_sub(r, m, r); - } -} - -fn mont_mul(a: ptr>, b: ptr>, - m: ptr>, inv: u32, r: ptr>) { - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - - for (var i = 0u; i < 8u; i = i + 1u) { - var carry = 0u; - for (var j = 0u; j < 8u; j = j + 1u) { - let al = (*a)[i] & 0xFFFFu; let ah = (*a)[i] >> 16u; - let bl = (*b)[j] & 0xFFFFu; let bh = (*b)[j] >> 16u; - let ll = al * bl; - let lh = al * bh; - let hl = ah * bl; - let hh = ah * bh; - let mid = lh + hl; - var lo = ll + (mid << 16u); - var hi = hh + (mid >> 16u) + select(0u, 1u, lo < ll) + select(0u, 0x10000u, mid < lh); - let s1 = lo + carry; hi = hi + select(0u, 1u, s1 < lo); - let s2 = t[i + j] + s1; hi = hi + select(0u, 1u, s2 < t[i + j]); - t[i + j] = s2; - carry = hi; - } - for (var j = 8u; i + j < 16u; j = j + 1u) { - let s = t[i + j] + carry; 
- carry = select(0u, 1u, s < t[i + j]); - t[i + j] = s; - if (carry == 0u) { break; } - } - } - mont_reduce(&t, m, inv, r); -} - -// Field ops over p (Montgomery form) -fn fp_add(a: ptr>, b: ptr>, - r: ptr>) { - var p = SECP_P; - let c = u256_add(a, b, r); - if (c != 0u || u256_cmp(r, &p) >= 0) { - let _ = u256_sub(r, &p, r); - } -} - -fn fp_sub(a: ptr>, b: ptr>, - r: ptr>) { - var p = SECP_P; - let bw = u256_sub(a, b, r); - if (bw != 0u) { - let _ = u256_add(r, &p, r); - } -} - -fn fp_mul(a: ptr>, b: ptr>, - r: ptr>) { - var p = SECP_P; - mont_mul(a, b, &p, P_INV, r); -} - -fn fp_sqr(a: ptr>, r: ptr>) { - var p = SECP_P; - mont_mul(a, a, &p, P_INV, r); -} - -fn fn_mul(a: ptr>, b: ptr>, - r: ptr>) { - var n = SECP_N; - mont_mul(a, b, &n, N_INV, r); -} - -fn to_mont_p(a: ptr>, r: ptr>) { - var r2 = MONT_R2_P; - fp_mul(a, &r2, r); -} - -fn from_mont_p(a: ptr>, r: ptr>) { - var p = SECP_P; - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 8u; i = i + 1u) { t[i] = (*a)[i]; } - mont_reduce(&t, &p, P_INV, r); -} - -fn to_mont_n(a: ptr>, r: ptr>) { - var r2 = MONT_R2_N; - fn_mul(a, &r2, r); -} - -fn from_mont_n(a: ptr>, r: ptr>) { - var n = SECP_N; - var t: array; - for (var i = 0u; i < 16u; i = i + 1u) { t[i] = 0u; } - for (var i = 0u; i < 8u; i = i + 1u) { t[i] = (*a)[i]; } - mont_reduce(&t, &n, N_INV, r); -} - -// Modular inversion via Fermat: a^(m-2) mod m -fn fp_inv(a: ptr>, r: ptr>) { - // p-2 little-endian u32 limbs - var exp = array( - 0xFFFFFC2Du, 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var result: array; - to_mont_p(&one, &result); - var base: array; - for (var i = 0u; i < 8u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - fp_mul(&result, &base, &tmp); - result = tmp; - } - 
var tmp2: array; - fp_sqr(&base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -fn fn_inv(a: ptr>, r: ptr>) { - // n-2 - var exp = array( - 0xD036413Fu, 0xBFD25E8Cu, 0xAF48A03Bu, 0xBAAEDCE6u, - 0xFFFFFFFEu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var result: array; - to_mont_n(&one, &result); - var base: array; - for (var i = 0u; i < 8u; i = i + 1u) { base[i] = (*a)[i]; } - - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp[i] >> bit) & 1u) != 0u) { - var tmp: array; - fn_mul(&result, &base, &tmp); - result = tmp; - } - var tmp2: array; - fn_mul(&base, &base, &tmp2); - base = tmp2; - } - } - *r = result; -} - -// ============================================================================ -// EC point operations (Jacobian, Montgomery Fp) -// Point = (x[8], y[8], z[8]) = 24 u32 words -// ============================================================================ - -struct ECPoint { - x: array, - y: array, - z: array, -} - -fn ec_identity() -> ECPoint { - var p: ECPoint; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &p.x); - p.y = p.x; - p.z = u256_zero(); - return p; -} - -fn ec_is_inf(p: ptr) -> bool { - var z = (*p).z; - return u256_is_zero(&z); -} - -fn ec_double(p: ptr, r: ptr) { - if (ec_is_inf(p)) { *r = *p; return; } - var A: array; fp_sqr(&(*p).y, &A); - var B: array; fp_mul(&(*p).x, &A, &B); - var C: array; fp_sqr(&A, &C); - // S = 4*B - var S: array; fp_add(&B, &B, &S); fp_add(&S, &S, &S); - // M = 3*X^2 (a=0) - var X2: array; fp_sqr(&(*p).x, &X2); - var X2_2: array; fp_add(&X2, &X2, &X2_2); - var M: array; fp_add(&X2_2, &X2, &M); - // X3 = M^2 - 2S - var M2: array; fp_sqr(&M, &M2); - var S2: array; fp_add(&S, &S, &S2); - var X3: array; fp_sub(&M2, &S2, &X3); - // Y3 = M*(S-X3) - 8C - var SX: array; fp_sub(&S, &X3, &SX); - var MSX: array; fp_mul(&M, &SX, &MSX); - var C2: array; fp_add(&C, &C, &C2); - var C4: array; 
fp_add(&C2, &C2, &C4); - var C8: array; fp_add(&C4, &C4, &C8); - var Y3: array; fp_sub(&MSX, &C8, &Y3); - // Z3 = 2*Y*Z - var YZ: array; fp_mul(&(*p).y, &(*p).z, &YZ); - var Z3: array; fp_add(&YZ, &YZ, &Z3); - (*r).x = X3; (*r).y = Y3; (*r).z = Z3; -} - -fn ec_add_mixed(P: ptr, Qx: ptr>, - Qy: ptr>, r: ptr) { - if (ec_is_inf(P)) { - (*r).x = *Qx; (*r).y = *Qy; - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - to_mont_p(&one, &(*r).z); - return; - } - var Z2: array; fp_sqr(&(*P).z, &Z2); - var U2: array; fp_mul(Qx, &Z2, &U2); - var Z3: array; fp_mul(&Z2, &(*P).z, &Z3); - var S2: array; fp_mul(Qy, &Z3, &S2); - var H: array; fp_sub(&U2, &(*P).x, &H); - var R: array; fp_sub(&S2, &(*P).y, &R); - - if (u256_is_zero(&H)) { - if (u256_is_zero(&R)) { ec_double(P, r); return; } - *r = ec_identity(); - return; - } - - var H2: array; fp_sqr(&H, &H2); - var H3: array; fp_mul(&H, &H2, &H3); - var U1H2: array; fp_mul(&(*P).x, &H2, &U1H2); - // X3 = R^2 - H^3 - 2*U1H2 - var R2: array; fp_sqr(&R, &R2); - var U1H2_2: array; fp_add(&U1H2, &U1H2, &U1H2_2); - var t1: array; fp_sub(&R2, &H3, &t1); - var X3: array; fp_sub(&t1, &U1H2_2, &X3); - // Y3 = R*(U1H2 - X3) - Y1*H3 - var UX: array; fp_sub(&U1H2, &X3, &UX); - var RUX: array; fp_mul(&R, &UX, &RUX); - var YH3: array; fp_mul(&(*P).y, &H3, &YH3); - var Y3: array; fp_sub(&RUX, &YH3, &Y3); - // Z3 = H * P.Z - var Zr: array; fp_mul(&H, &(*P).z, &Zr); - (*r).x = X3; (*r).y = Y3; (*r).z = Zr; -} - -fn ec_mul_affine(k: ptr>, - Px: ptr>, - Py: ptr>) -> ECPoint { - var result = ec_identity(); - for (var i = 7i; i >= 0; i = i - 1) { - for (var bit = 31i; bit >= 0; bit = bit - 1) { - var dbl: ECPoint; - ec_double(&result, &dbl); - result = dbl; - if ((((*k)[u32(i)] >> u32(bit)) & 1u) != 0u) { - var tmp: ECPoint; - ec_add_mixed(&result, Px, Py, &tmp); - result = tmp; - } - } - } - return result; -} - -fn ec_to_affine(p: ptr, ax: ptr>, - ay: ptr>) { - if (ec_is_inf(p)) { *ax = u256_zero(); *ay = u256_zero(); return; } - var z_inv: array; 
fp_inv(&(*p).z, &z_inv); - var z_inv2: array; fp_sqr(&z_inv, &z_inv2); - var z_inv3: array; fp_mul(&z_inv2, &z_inv, &z_inv3); - fp_mul(&(*p).x, &z_inv2, ax); - fp_mul(&(*p).y, &z_inv3, ay); -} - -// ============================================================================ -// Inline Keccak-256 for 64 bytes (public key -> address) -// ============================================================================ - -var kst_lo: array; -var kst_hi: array; - -const KRC_LO = array( - 0x00000001u, 0x00008082u, 0x0000808Au, 0x80008000u, - 0x0000808Bu, 0x80000001u, 0x80008081u, 0x00008009u, - 0x0000008Au, 0x00000088u, 0x80008009u, 0x8000000Au, - 0x8000808Bu, 0x0000008Bu, 0x00008089u, 0x00008003u, - 0x00008002u, 0x00000080u, 0x0000800Au, 0x8000000Au, - 0x80008081u, 0x00008080u, 0x80000001u, 0x80008008u -); -const KRC_HI = array( - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u, - 0x00000000u, 0x80000000u, 0x80000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u -); -const KPI = array( - 10u, 7u, 11u, 17u, 18u, 3u, 5u, 16u, 8u, 21u, 24u, 4u, - 15u, 23u, 19u, 13u, 12u, 2u, 20u, 14u, 22u, 9u, 6u, 1u -); -const KRHO = array( - 1u, 3u, 6u, 10u, 15u, 21u, 28u, 36u, 45u, 55u, 2u, 14u, - 27u, 41u, 56u, 8u, 25u, 43u, 62u, 18u, 39u, 61u, 20u, 44u -); - -fn krotl64(lo: u32, hi: u32, n: u32) -> vec2 { - if (n == 0u) { return vec2(lo, hi); } - if (n == 32u) { return vec2(hi, lo); } - if (n < 32u) { - return vec2((lo << n) | (hi >> (32u - n)), (hi << n) | (lo >> (32u - n))); - } - let m = n - 32u; - return vec2((hi << m) | (lo >> (32u - m)), (lo << m) | (hi >> (32u - m))); -} - -fn keccak_f() { - for (var round = 0u; round < 24u; round = round + 1u) { - var c_lo: array; var c_hi: array; - for (var x = 0u; x < 5u; x = x + 1u) { - c_lo[x] = kst_lo[x] ^ kst_lo[x+5u] ^ kst_lo[x+10u] ^ kst_lo[x+15u] ^ 
kst_lo[x+20u]; - c_hi[x] = kst_hi[x] ^ kst_hi[x+5u] ^ kst_hi[x+10u] ^ kst_hi[x+15u] ^ kst_hi[x+20u]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - let r = krotl64(c_lo[(x+1u)%5u], c_hi[(x+1u)%5u], 1u); - let d_lo = c_lo[(x+4u)%5u] ^ r.x; - let d_hi = c_hi[(x+4u)%5u] ^ r.y; - for (var y = 0u; y < 5u; y = y + 1u) { - let idx = x + 5u * y; - kst_lo[idx] = kst_lo[idx] ^ d_lo; - kst_hi[idx] = kst_hi[idx] ^ d_hi; - } - } - var t_lo = kst_lo[1u]; var t_hi = kst_hi[1u]; - for (var i = 0u; i < 24u; i = i + 1u) { - let dst = KPI[i]; - let tmp_lo = kst_lo[dst]; let tmp_hi = kst_hi[dst]; - let r = krotl64(t_lo, t_hi, KRHO[i]); - kst_lo[dst] = r.x; kst_hi[dst] = r.y; - t_lo = tmp_lo; t_hi = tmp_hi; - } - for (var y = 0u; y < 5u; y = y + 1u) { - var rl: array; var rh: array; - for (var x = 0u; x < 5u; x = x + 1u) { - rl[x] = kst_lo[x + 5u*y]; rh[x] = kst_hi[x + 5u*y]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - kst_lo[x+5u*y] = rl[x] ^ ((~rl[(x+1u)%5u]) & rl[(x+2u)%5u]); - kst_hi[x+5u*y] = rh[x] ^ ((~rh[(x+1u)%5u]) & rh[(x+2u)%5u]); - } - } - kst_lo[0] = kst_lo[0] ^ KRC_LO[round]; - kst_hi[0] = kst_hi[0] ^ KRC_HI[round]; - } -} - -fn keccak256_64(data: ptr>, hash: ptr>) { - for (var i = 0u; i < 25u; i = i + 1u) { kst_lo[i] = 0u; kst_hi[i] = 0u; } - // Absorb 64 bytes = 8 lanes (each lane = 8 bytes = 2 u32 words) - for (var w = 0u; w < 8u; w = w + 1u) { - kst_lo[w] = kst_lo[w] ^ (*data)[w * 2u]; - kst_hi[w] = kst_hi[w] ^ (*data)[w * 2u + 1u]; - } - // Keccak padding: byte 64 = 0x01, byte 135 = 0x80 - kst_lo[8] = kst_lo[8] ^ 0x01u; - kst_hi[16] = kst_hi[16] ^ 0x80000000u; - keccak_f(); - for (var w = 0u; w < 4u; w = w + 1u) { - (*hash)[w * 2u] = kst_lo[w]; - (*hash)[w * 2u + 1u] = kst_hi[w]; - } -} - -// ============================================================================ -// Load/store helpers (big-endian 32 bytes <-> u256 little-endian u32 limbs) -// ============================================================================ - -fn load_be32(word_base: u32) -> array 
{ - // Input is 32 bytes = 8 u32 words in the inputs array (byte-packed) - // Stored as big-endian in the input. We need to reverse byte order within words - // and reverse word order for little-endian limbs. - var r: array; - for (var i = 0u; i < 8u; i = i + 1u) { - let w = inputs[word_base + 7u - i]; - // Byte-swap u32 (big-endian to little-endian) - r[i] = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - } - return r; -} - -// ============================================================================ -// Main kernel -// ============================================================================ - -@compute @workgroup_size(256) -fn secp256k1_ecrecover(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.num_items) { return; } - - // Clear output - let out_base = tid * 8u; // 32 bytes = 8 u32 - for (var i = 0u; i < 8u; i = i + 1u) { outputs[out_base + i] = 0u; } - - // Load signature: 128 bytes = 32 u32 per sig - let in_base = tid * 32u; - var r = load_be32(in_base); // r: bytes 0..31 - var s = load_be32(in_base + 8u); // s: bytes 32..63 - let v_byte = (inputs[in_base + 16u]) & 0xFFu; // v: byte 64 - var e = load_be32(in_base + 17u); // msg_hash: bytes 68..99 - - var v = v_byte; - if (v >= 27u) { v = v - 27u; } - if (v >= 2u) { v = v % 2u; } - - // Validate r, s in [1, n-1] - var n = SECP_N; - if (u256_is_zero(&r) || u256_cmp(&r, &n) >= 0) { return; } - if (u256_is_zero(&s) || u256_cmp(&s, &n) >= 0) { return; } - if (v > 1u) { return; } - - // Decompress r -> R = (r, y) - var r_mont: array; to_mont_p(&r, &r_mont); - var r2: array; fp_sqr(&r_mont, &r2); - var r3: array; fp_mul(&r2, &r_mont, &r3); - var seven = array(7u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var seven_mont: array; to_mont_p(&seven, &seven_mont); - var y2: array; fp_add(&r3, &seven_mont, &y2); - - // sqrt via a^((p+1)/4) since p = 3 mod 4 - var exp_sqrt = array( - 0xBFFFFF0Cu, 0xFFFFFFFFu, 0xFFFFFFFFu, 
0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0x3FFFFFFFu - ); - var one = array(1u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); - var y_mont: array; to_mont_p(&one, &y_mont); - var base_y = y2; - for (var i = 0u; i < 8u; i = i + 1u) { - for (var bit = 0u; bit < 32u; bit = bit + 1u) { - if (((exp_sqrt[i] >> bit) & 1u) != 0u) { - var tmp: array; - fp_mul(&y_mont, &base_y, &tmp); - y_mont = tmp; - } - var tmp2: array; - fp_sqr(&base_y, &tmp2); - base_y = tmp2; - } - } - - // Verify sqrt: y^2 == y2 - var check: array; fp_sqr(&y_mont, &check); - if (u256_cmp(&check, &y2) != 0) { return; } - - // Select y parity - var y_normal: array; from_mont_p(&y_mont, &y_normal); - let y_is_odd = (y_normal[0] & 1u) != 0u; - if ((v == 0u && y_is_odd) || (v == 1u && !y_is_odd)) { - var zero_val = u256_zero(); - fp_sub(&zero_val, &y_mont, &y_mont); - } - - // r_inv = r^{-1} mod n - var r_n_mont: array; to_mont_n(&r, &r_n_mont); - var r_inv_mont: array; fn_inv(&r_n_mont, &r_inv_mont); - - // u1 = -(e * r_inv) mod n, u2 = s * r_inv mod n - var e_n_mont: array; to_mont_n(&e, &e_n_mont); - var s_n_mont: array; to_mont_n(&s, &s_n_mont); - - var u1_mont: array; fn_mul(&e_n_mont, &r_inv_mont, &u1_mont); - var u1: array; from_mont_n(&u1_mont, &u1); - if (!u256_is_zero(&u1)) { - var nn = SECP_N; - let _ = u256_sub(&nn, &u1, &u1); - } - - var u2_mont: array; fn_mul(&s_n_mont, &r_inv_mont, &u2_mont); - var u2: array; from_mont_n(&u2_mont, &u2); - - // Q = u1*G + u2*R - var Gx_mont: array; var gx = GX; to_mont_p(&gx, &Gx_mont); - var Gy_mont: array; var gy = GY; to_mont_p(&gy, &Gy_mont); - - var Q1 = ec_mul_affine(&u1, &Gx_mont, &Gy_mont); - var Q2 = ec_mul_affine(&u2, &r_mont, &y_mont); - - // Add Q1 + Q2 - var Q: ECPoint; - if (ec_is_inf(&Q1)) { - Q = Q2; - } else if (ec_is_inf(&Q2)) { - Q = Q1; - } else { - var Q2x_aff: array; var Q2y_aff: array; - ec_to_affine(&Q2, &Q2x_aff, &Q2y_aff); - ec_add_mixed(&Q1, &Q2x_aff, &Q2y_aff, &Q); - } - - if (ec_is_inf(&Q)) { return; } - - var Qx_aff: array; var 
Qy_aff: array; - ec_to_affine(&Q, &Qx_aff, &Qy_aff); - var Qx_norm: array; from_mont_p(&Qx_aff, &Qx_norm); - var Qy_norm: array; from_mont_p(&Qy_aff, &Qy_norm); - - // Serialize Q.x || Q.y as 16 u32 words (big-endian bytes within each 32-byte half) - var pubkey: array; - for (var i = 0u; i < 8u; i = i + 1u) { - let w = Qx_norm[7u - i]; - pubkey[i] = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - } - for (var i = 0u; i < 8u; i = i + 1u) { - let w = Qy_norm[7u - i]; - pubkey[8u + i] = ((w >> 24u) & 0xFFu) | (((w >> 16u) & 0xFFu) << 8u) - | (((w >> 8u) & 0xFFu) << 16u) | ((w & 0xFFu) << 24u); - } - - // address = keccak256(pubkey)[12:] - var hash: array; - keccak256_64(&pubkey, &hash); - - // Output: address (bytes 12-31 of hash) = last 20 bytes - // hash is 32 bytes = 8 u32 words. Bytes 12..31 = words 3..7 (but byte offset 12 = word 3 byte 0) - // Store as 5 u32 words at output (20 bytes), then valid byte - outputs[out_base] = hash[3]; - outputs[out_base + 1u] = hash[4]; - outputs[out_base + 2u] = hash[5]; - outputs[out_base + 3u] = hash[6]; - outputs[out_base + 4u] = hash[7]; - // valid byte at output byte 20 = word 5 - outputs[out_base + 5u] = 1u; -} diff --git a/sha256/gpu/cuda/sha256.cu b/sha256/gpu/cuda/sha256.cu deleted file mode 100644 index 5db4d13..0000000 --- a/sha256/gpu/cuda/sha256.cu +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// SHA-256 batch hashing — CUDA implementation. -// Matches sha256/cpp/sha256.cpp::sha256() and sha256/gpu/metal/sha256_batch.metal -// byte-for-byte (FIPS 180-4). One thread per input. -// -// Layout matches the Metal driver: caller fills a Sha256Job[] with -// (input_offset, input_len, output_offset). Inputs share a flat byte arena; -// outputs share a 32-byte stride arena. 
Padding is canonical (FIPS 180-4 -// sec 5.1.1): append 0x80, then zero pad to 64 mod 56, then 8-byte -// big-endian bit length. - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// FIPS 180-4 round constants (§4.2.2): first 32 bits of the fractional parts -// of the cube roots of the first 64 primes 2..311. -__device__ static const uint32_t K[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u, -}; - -__device__ static inline uint32_t rotr32(uint32_t x, uint32_t n) { - return (x >> n) | (x << (32u - n)); -} - -__device__ static inline void sha256_block(uint32_t* h, const uint8_t* p) { - uint32_t w[64]; - #pragma unroll - for (uint32_t i = 0; i < 16; ++i) { - w[i] = (uint32_t(p[i * 4 + 0]) << 24) | - (uint32_t(p[i * 4 + 1]) << 16) | - (uint32_t(p[i * 4 + 2]) << 8) | - (uint32_t(p[i * 4 + 3])); - } - #pragma unroll - for (uint32_t i = 16; i < 64; ++i) { - uint32_t s0 = rotr32(w[i - 15], 7) ^ rotr32(w[i - 15], 18) ^ (w[i - 15] >> 3); - uint32_t s1 = rotr32(w[i - 2], 17) ^ rotr32(w[i - 2], 19) ^ (w[i - 2] >> 10); - w[i] = w[i - 
16] + s0 + w[i - 7] + s1; - } - - uint32_t a = h[0], b = h[1], c = h[2], d = h[3]; - uint32_t e = h[4], f = h[5], g = h[6], hh = h[7]; - - #pragma unroll - for (uint32_t i = 0; i < 64; ++i) { - uint32_t S1 = rotr32(e, 6) ^ rotr32(e, 11) ^ rotr32(e, 25); - uint32_t ch = (e & f) ^ ((~e) & g); - uint32_t t1 = hh + S1 + ch + K[i] + w[i]; - uint32_t S0 = rotr32(a, 2) ^ rotr32(a, 13) ^ rotr32(a, 22); - uint32_t mj = (a & b) ^ (a & c) ^ (b & c); - uint32_t t2 = S0 + mj; - hh = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - h[0] += a; h[1] += b; h[2] += c; h[3] += d; - h[4] += e; h[5] += f; h[6] += g; h[7] += hh; -} - -extern "C" __global__ void sha256_jobs( - const uint8_t* __restrict__ inputs, - const uint32_t* __restrict__ input_offsets, - const uint32_t* __restrict__ input_lens, - uint8_t* __restrict__ outputs, - uint32_t num_jobs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_jobs) return; - - const uint8_t* in = inputs + input_offsets[tid]; - uint8_t* out = outputs + tid * 32u; - uint32_t len = input_lens[tid]; - - // FIPS 180-4 IV (§5.3.3): first 32 bits of fractional parts of the square - // roots of the first 8 primes 2..19. - uint32_t h[8] = { - 0x6a09e667u, 0xbb67ae85u, 0x3c6ef372u, 0xa54ff53au, - 0x510e527fu, 0x9b05688cu, 0x1f83d9abu, 0x5be0cd19u, - }; - - // Process full 64-byte blocks streaming from device memory. - uint8_t block[64]; - uint32_t absorbed = 0; - while (len - absorbed >= 64u) { - #pragma unroll - for (uint32_t i = 0; i < 64u; ++i) block[i] = in[absorbed + i]; - sha256_block(h, block); - absorbed += 64u; - } - - // Final block(s): copy tail, append 0x80, pad zero, append 64-bit - // big-endian bit length. May span one or two blocks. - uint32_t rem = len - absorbed; - #pragma unroll - for (uint32_t i = 0; i < 64u; ++i) block[i] = 0; - for (uint32_t i = 0; i < rem; ++i) block[i] = in[absorbed + i]; - block[rem] = 0x80u; - - if (rem >= 56u) { - // Tail spans two final blocks. 
- sha256_block(h, block); - #pragma unroll - for (uint32_t i = 0; i < 64u; ++i) block[i] = 0; - } - - uint64_t bit_len = uint64_t(len) * 8u; - block[56] = uint8_t((bit_len >> 56) & 0xFFu); - block[57] = uint8_t((bit_len >> 48) & 0xFFu); - block[58] = uint8_t((bit_len >> 40) & 0xFFu); - block[59] = uint8_t((bit_len >> 32) & 0xFFu); - block[60] = uint8_t((bit_len >> 24) & 0xFFu); - block[61] = uint8_t((bit_len >> 16) & 0xFFu); - block[62] = uint8_t((bit_len >> 8) & 0xFFu); - block[63] = uint8_t( bit_len & 0xFFu); - sha256_block(h, block); - - // Output big-endian. - #pragma unroll - for (uint32_t i = 0; i < 8u; ++i) { - out[i * 4 + 0] = uint8_t((h[i] >> 24) & 0xFFu); - out[i * 4 + 1] = uint8_t((h[i] >> 16) & 0xFFu); - out[i * 4 + 2] = uint8_t((h[i] >> 8) & 0xFFu); - out[i * 4 + 3] = uint8_t( h[i] & 0xFFu); - } -} diff --git a/sha256/gpu/cuda/sha256_driver.cpp b/sha256/gpu/cuda/sha256_driver.cpp deleted file mode 100644 index f876e4d..0000000 --- a/sha256/gpu/cuda/sha256_driver.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// CUDA host driver for batched SHA-256 (FIPS 180-4). -// -// Build modes: -// 1. With CUDA toolkit (LUX_SHA256_HAVE_CUDA defined): -// - Compiles sha256.cu via nvcc; invokes the kernel with one thread -// per input. Byte-equal to sha256/cpp/sha256.cpp::sha256() and to -// sha256/gpu/metal/sha256_batch.metal. -// 2. Without CUDA (LUX_SHA256_HAVE_CUDA not defined): -// - Stub mode: lux_sha256_cuda_available() returns 0, every other -// function returns -1 ("CUDA unavailable on this host"). The test -// harness skips the CUDA path on Apple/non-CUDA hosts. - -#include "sha256_driver.h" - -#include -#include - -#ifdef LUX_SHA256_HAVE_CUDA -#include - -// Forward declaration of the CUDA kernel defined in sha256.cu. 
-extern "C" __global__ void sha256_jobs( - const uint8_t* inputs, - const uint32_t* input_offsets, - const uint32_t* input_lens, - uint8_t* outputs, - uint32_t num_jobs); - -extern "C" int lux_sha256_cuda_available(void) { - int count = 0; - cudaError_t e = cudaGetDeviceCount(&count); - return (e == cudaSuccess && count > 0) ? 1 : 0; -} - -extern "C" int sha256_batch_cuda( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena) { - - if (n == 0) return 0; - if (!inputs_arena || !input_offsets || !input_lens || !outputs_arena) return -1; - if (!lux_sha256_cuda_available()) return -2; - - uint8_t* d_inputs = nullptr; - uint32_t* d_offsets = nullptr; - uint32_t* d_lens = nullptr; - uint8_t* d_outputs = nullptr; - size_t out_bytes = n * 32u; - - auto cleanup = [&]() { - if (d_inputs) cudaFree(d_inputs); - if (d_offsets) cudaFree(d_offsets); - if (d_lens) cudaFree(d_lens); - if (d_outputs) cudaFree(d_outputs); - }; - - if (cudaMalloc((void**)&d_inputs, inputs_arena_len ? 
inputs_arena_len : 1) != cudaSuccess) { - cleanup(); return -3; - } - if (cudaMalloc((void**)&d_offsets, n * sizeof(uint32_t)) != cudaSuccess) { - cleanup(); return -3; - } - if (cudaMalloc((void**)&d_lens, n * sizeof(uint32_t)) != cudaSuccess) { - cleanup(); return -3; - } - if (cudaMalloc((void**)&d_outputs, out_bytes) != cudaSuccess) { - cleanup(); return -3; - } - - if (inputs_arena_len) { - if (cudaMemcpy(d_inputs, inputs_arena, inputs_arena_len, - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); return -4; - } - } - if (cudaMemcpy(d_offsets, input_offsets, n * sizeof(uint32_t), - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); return -4; - } - if (cudaMemcpy(d_lens, input_lens, n * sizeof(uint32_t), - cudaMemcpyHostToDevice) != cudaSuccess) { - cleanup(); return -4; - } - - unsigned tg = 64; - unsigned grid = unsigned((n + tg - 1) / tg); - sha256_jobs<<>>(d_inputs, d_offsets, d_lens, - d_outputs, uint32_t(n)); - if (cudaDeviceSynchronize() != cudaSuccess) { - cleanup(); return -4; - } - if (cudaMemcpy(outputs_arena, d_outputs, out_bytes, - cudaMemcpyDeviceToHost) != cudaSuccess) { - cleanup(); return -4; - } - cleanup(); - return 0; -} - -#else // LUX_SHA256_HAVE_CUDA not defined: stub mode - -extern "C" int lux_sha256_cuda_available(void) { return 0; } - -extern "C" int sha256_batch_cuda( - const uint8_t*, size_t, - const uint32_t*, const uint32_t*, - size_t, uint8_t*) { - return -1; -} - -#endif // LUX_SHA256_HAVE_CUDA diff --git a/sha256/gpu/cuda/sha256_driver.h b/sha256/gpu/cuda/sha256_driver.h deleted file mode 100644 index 5d7e533..0000000 --- a/sha256/gpu/cuda/sha256_driver.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI interface for the SHA-256 CUDA driver. Mirrors the Metal -// driver in sha256/gpu/metal/sha256_batch_driver.mm. On hosts without CUDA -// every function returns -1 except lux_sha256_cuda_available() which -// returns 0. 
- -#ifndef LUX_SHA256_DRIVER_CUDA_H -#define LUX_SHA256_DRIVER_CUDA_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Returns 1 if a CUDA device is available, 0 otherwise. -int lux_sha256_cuda_available(void); - -// Run N SHA-256 hashes in one CUDA dispatch. Each input lives at -// inputs_arena[input_offsets[i] .. + input_lens[i]); each output goes to -// outputs_arena[i * 32 .. i * 32 + 32). Inputs are concatenated in -// `inputs_arena` of length `inputs_arena_len`. Returns 0 on success, -// negative on failure (-1 = invalid args, -2 = device unavailable, -// -3 = device alloc failed, -4 = launch failed). -int sha256_batch_cuda( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_SHA256_DRIVER_CUDA_H diff --git a/sha256/gpu/metal/sha256_batch.metal b/sha256/gpu/metal/sha256_batch.metal deleted file mode 100644 index e8b24d9..0000000 --- a/sha256/gpu/metal/sha256_batch.metal +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// GPU-batched SHA-256 (FIPS 180-4). One thread per input. Byte-equal to -// sha256/cpp/sha256.cpp::sha256() for arbitrary-length inputs. -// -// Layout: caller fills a Sha256Job[] with (input_offset, input_len, -// output_offset). Inputs share a flat byte arena; outputs share a 32-byte -// stride arena. Padding is canonical (FIPS 180-4 sec 5.1.1): append 0x80, then -// zero pad to 64 mod 56, then 8-byte big-endian bit length. - -#include -using namespace metal; - -// FIPS 180-4 round constants (first 32 bits of fractional parts of cube roots -// of the first 64 primes). 
-constant uint K[64] = { - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u, -}; - -inline uint rotr32(uint x, uint n) { return (x >> n) | (x << (32 - n)); } - -inline void sha256_block(thread uint* h, thread const uchar* p) { - uint w[64]; - for (uint i = 0; i < 16; ++i) { - w[i] = (uint(p[i * 4 + 0]) << 24) | - (uint(p[i * 4 + 1]) << 16) | - (uint(p[i * 4 + 2]) << 8) | - (uint(p[i * 4 + 3])); - } - for (uint i = 16; i < 64; ++i) { - uint s0 = rotr32(w[i - 15], 7) ^ rotr32(w[i - 15], 18) ^ (w[i - 15] >> 3); - uint s1 = rotr32(w[i - 2], 17) ^ rotr32(w[i - 2], 19) ^ (w[i - 2] >> 10); - w[i] = w[i - 16] + s0 + w[i - 7] + s1; - } - - uint a = h[0], b = h[1], c = h[2], d = h[3]; - uint e = h[4], f = h[5], g = h[6], hh = h[7]; - - for (uint i = 0; i < 64; ++i) { - uint S1 = rotr32(e, 6) ^ rotr32(e, 11) ^ rotr32(e, 25); - uint ch = (e & f) ^ ((~e) & g); - uint t1 = hh + S1 + ch + K[i] + w[i]; - uint S0 = rotr32(a, 2) ^ rotr32(a, 13) ^ rotr32(a, 22); - uint mj = (a & b) ^ (a & c) ^ (b & c); - uint t2 = S0 + mj; - hh = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - h[0] += a; h[1] += b; h[2] += c; h[3] += d; - h[4] += e; h[5] += f; h[6] += g; h[7] += 
hh; -} - -struct Sha256JobGPU { - uint input_offset; - uint input_len; - uint output_offset; - uint _pad; -}; - -kernel void sha256_jobs( - device const Sha256JobGPU* jobs [[buffer(0)]], - device const uchar* inputs [[buffer(1)]], - device uchar* outputs [[buffer(2)]], - constant uint& num_jobs [[buffer(3)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_jobs) return; - - Sha256JobGPU j = jobs[tid]; - const device uchar* in = inputs + j.input_offset; - device uchar* out = outputs + j.output_offset; - - // FIPS 180-4 IV. - uint h[8] = { - 0x6a09e667u, 0xbb67ae85u, 0x3c6ef372u, 0xa54ff53au, - 0x510e527fu, 0x9b05688cu, 0x1f83d9abu, 0x5be0cd19u, - }; - - // Process full 64-byte blocks streaming from device memory. - uchar block[64]; - uint absorbed = 0; - while (j.input_len - absorbed >= 64) { - for (uint i = 0; i < 64; ++i) block[i] = in[absorbed + i]; - sha256_block(h, block); - absorbed += 64; - } - - // Final block(s): copy tail, append 0x80, pad zero, append 64-bit - // big-endian bit length. May span one or two blocks. - uint rem = j.input_len - absorbed; - for (uint i = 0; i < 64; ++i) block[i] = 0; - for (uint i = 0; i < rem; ++i) block[i] = in[absorbed + i]; - block[rem] = 0x80u; - - if (rem >= 56) { - // Tail spans two final blocks. - sha256_block(h, block); - for (uint i = 0; i < 64; ++i) block[i] = 0; - } - - ulong bit_len = (ulong)j.input_len * 8u; - block[56] = uchar((bit_len >> 56) & 0xFF); - block[57] = uchar((bit_len >> 48) & 0xFF); - block[58] = uchar((bit_len >> 40) & 0xFF); - block[59] = uchar((bit_len >> 32) & 0xFF); - block[60] = uchar((bit_len >> 24) & 0xFF); - block[61] = uchar((bit_len >> 16) & 0xFF); - block[62] = uchar((bit_len >> 8) & 0xFF); - block[63] = uchar( bit_len & 0xFF); - sha256_block(h, block); - - // Output big-endian. 
- for (uint i = 0; i < 8; ++i) { - out[i * 4 + 0] = uchar((h[i] >> 24) & 0xFF); - out[i * 4 + 1] = uchar((h[i] >> 16) & 0xFF); - out[i * 4 + 2] = uchar((h[i] >> 8) & 0xFF); - out[i * 4 + 3] = uchar( h[i] & 0xFF); - } -} diff --git a/sha256/gpu/metal/sha256_batch_driver.mm b/sha256/gpu/metal/sha256_batch_driver.mm deleted file mode 100644 index f8b8b17..0000000 --- a/sha256/gpu/metal/sha256_batch_driver.mm +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Metal driver for batched SHA-256 (FIPS 180-4). macOS / iOS only. -// -// Loads the precompiled sha256_batch.metallib, dispatches `sha256_jobs` with -// one thread per input. Byte-equal to sha256/cpp/sha256.cpp::sha256(). - -#if __APPLE__ && __OBJC__ - -#import -#import - -#include -#include -#include -#include - -namespace { - -struct Sha256JobGPU { - uint32_t input_offset; - uint32_t input_len; - uint32_t output_offset; - uint32_t _pad; -}; - -} // namespace - -// Run N SHA-256 hashes in one Metal dispatch. Each input lives at -// inputs[input_offsets[i] .. + input_lens[i]); each output goes to -// outputs[i * 32 .. i * 32 + 32). Inputs are concatenated in `inputs_arena` -// of length `inputs_arena_len`. Returns 0 on success, negative on failure. 
-extern "C" int sha256_batch_metal( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena, - const char* metallib_path) { - - if (n == 0) return 0; - if (!inputs_arena || !input_offsets || !input_lens || !outputs_arena || - !metallib_path) return -1; - - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - if (!device) return -2; - - NSError* err = nil; - NSString* path = [NSString stringWithUTF8String:metallib_path]; - NSURL* url = [NSURL fileURLWithPath:path]; - id lib = [device newLibraryWithURL:url error:&err]; - if (!lib) return -3; - - id fn = [lib newFunctionWithName:@"sha256_jobs"]; - if (!fn) return -4; - - id pipeline = - [device newComputePipelineStateWithFunction:fn error:&err]; - if (!pipeline) return -5; - - id queue = [device newCommandQueue]; - - // Build job descriptor array. - std::vector jobs(n); - for (size_t i = 0; i < n; ++i) { - jobs[i].input_offset = input_offsets[i]; - jobs[i].input_len = input_lens[i]; - jobs[i].output_offset = (uint32_t)(i * 32); - jobs[i]._pad = 0; - } - - id jobs_buf = [device newBufferWithBytes:jobs.data() - length:jobs.size() * sizeof(Sha256JobGPU) - options:MTLResourceStorageModeShared]; - id inputs_buf = [device newBufferWithBytes:inputs_arena - length:inputs_arena_len - options:MTLResourceStorageModeShared]; - id outputs_buf = [device newBufferWithLength:n * 32 - options:MTLResourceStorageModeShared]; - uint32_t n_u32 = (uint32_t)n; - id n_buf = [device newBufferWithBytes:&n_u32 - length:sizeof(n_u32) - options:MTLResourceStorageModeShared]; - - id cmd = [queue commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pipeline]; - [enc setBuffer:jobs_buf offset:0 atIndex:0]; - [enc setBuffer:inputs_buf offset:0 atIndex:1]; - [enc setBuffer:outputs_buf offset:0 atIndex:2]; - [enc setBuffer:n_buf offset:0 atIndex:3]; - - // One thread per job. 
Threadgroup width = pipeline's max width. - NSUInteger tg_max = pipeline.maxTotalThreadsPerThreadgroup; - NSUInteger tg_w = tg_max < 64 ? tg_max : 64; - MTLSize threads_per_grid = MTLSizeMake(n, 1, 1); - MTLSize threads_per_tg = MTLSizeMake(tg_w, 1, 1); - [enc dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_tg]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(outputs_arena, [outputs_buf contents], n * 32); - } - return 0; -} - -#endif // __APPLE__ && __OBJC__ diff --git a/sha256/gpu/wgsl/sha256.wgsl b/sha256/gpu/wgsl/sha256.wgsl deleted file mode 100644 index 27dbd76..0000000 --- a/sha256/gpu/wgsl/sha256.wgsl +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// SHA-256 (FIPS 180-4) compute shader in WGSL. -// -// One thread per input. Each thread reads its (offset, length) descriptor, -// processes 64-byte blocks through the canonical 64-round compression -// function, then emits a 32-byte big-endian digest. Padding is canonical -// (FIPS 180-4 §5.1.1): append 0x80, zero pad to 64 mod 56, append 8-byte -// big-endian bit length. -// -// Byte-equal to sha256/cpp/sha256.cpp and sha256/gpu/cuda/sha256.cu and -// sha256/gpu/metal/sha256_batch.metal. - -struct HashInput { - offset: u32, - length: u32, -} - -@group(0) @binding(0) var inputs: array; -@group(0) @binding(1) var data: array; -@group(0) @binding(2) var outputs: array; - -// FIPS 180-4 round constants (§4.2.2). 
-const K = array( - 0x428a2f98u, 0x71374491u, 0xb5c0fbcfu, 0xe9b5dba5u, - 0x3956c25bu, 0x59f111f1u, 0x923f82a4u, 0xab1c5ed5u, - 0xd807aa98u, 0x12835b01u, 0x243185beu, 0x550c7dc3u, - 0x72be5d74u, 0x80deb1feu, 0x9bdc06a7u, 0xc19bf174u, - 0xe49b69c1u, 0xefbe4786u, 0x0fc19dc6u, 0x240ca1ccu, - 0x2de92c6fu, 0x4a7484aau, 0x5cb0a9dcu, 0x76f988dau, - 0x983e5152u, 0xa831c66du, 0xb00327c8u, 0xbf597fc7u, - 0xc6e00bf3u, 0xd5a79147u, 0x06ca6351u, 0x14292967u, - 0x27b70a85u, 0x2e1b2138u, 0x4d2c6dfcu, 0x53380d13u, - 0x650a7354u, 0x766a0abbu, 0x81c2c92eu, 0x92722c85u, - 0xa2bfe8a1u, 0xa81a664bu, 0xc24b8b70u, 0xc76c51a3u, - 0xd192e819u, 0xd6990624u, 0xf40e3585u, 0x106aa070u, - 0x19a4c116u, 0x1e376c08u, 0x2748774cu, 0x34b0bcb5u, - 0x391c0cb3u, 0x4ed8aa4au, 0x5b9cca4fu, 0x682e6ff3u, - 0x748f82eeu, 0x78a5636fu, 0x84c87814u, 0x8cc70208u, - 0x90befffau, 0xa4506cebu, 0xbef9a3f7u, 0xc67178f2u, -); - -fn rotr32(x: u32, n: u32) -> u32 { - return (x >> n) | (x << (32u - n)); -} - -// Read a single byte from the packed u32 input arena (little-endian packing). -fn read_byte(byte_offset: u32) -> u32 { - let word_idx = byte_offset >> 2u; - let byte_pos = byte_offset & 3u; - return (data[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -// Process one 64-byte block: build big-endian-encoded message schedule, -// run the 64-round compression, accumulate into h[0..7]. 
-fn sha256_block(block: ptr>, h: ptr>) { - var w: array; - for (var i = 0u; i < 16u; i = i + 1u) { - w[i] = (*block)[i]; - } - for (var i = 16u; i < 64u; i = i + 1u) { - let s0 = rotr32(w[i - 15u], 7u) ^ rotr32(w[i - 15u], 18u) ^ (w[i - 15u] >> 3u); - let s1 = rotr32(w[i - 2u], 17u) ^ rotr32(w[i - 2u], 19u) ^ (w[i - 2u] >> 10u); - w[i] = w[i - 16u] + s0 + w[i - 7u] + s1; - } - - var a: u32 = (*h)[0]; - var b: u32 = (*h)[1]; - var c: u32 = (*h)[2]; - var d: u32 = (*h)[3]; - var e: u32 = (*h)[4]; - var f: u32 = (*h)[5]; - var g: u32 = (*h)[6]; - var hh: u32 = (*h)[7]; - - for (var i = 0u; i < 64u; i = i + 1u) { - let S1 = rotr32(e, 6u) ^ rotr32(e, 11u) ^ rotr32(e, 25u); - let ch = (e & f) ^ ((~e) & g); - let t1 = hh + S1 + ch + K[i] + w[i]; - let S0 = rotr32(a, 2u) ^ rotr32(a, 13u) ^ rotr32(a, 22u); - let mj = (a & b) ^ (a & c) ^ (b & c); - let t2 = S0 + mj; - hh = g; g = f; f = e; e = d + t1; - d = c; c = b; b = a; a = t1 + t2; - } - - (*h)[0] = (*h)[0] + a; - (*h)[1] = (*h)[1] + b; - (*h)[2] = (*h)[2] + c; - (*h)[3] = (*h)[3] + d; - (*h)[4] = (*h)[4] + e; - (*h)[5] = (*h)[5] + f; - (*h)[6] = (*h)[6] + g; - (*h)[7] = (*h)[7] + hh; -} - -// Build the i-th big-endian message word from `block_offset` byte position. -fn be_word(block_offset: u32) -> u32 { - let b0 = read_byte(block_offset + 0u); - let b1 = read_byte(block_offset + 1u); - let b2 = read_byte(block_offset + 2u); - let b3 = read_byte(block_offset + 3u); - return (b0 << 24u) | (b1 << 16u) | (b2 << 8u) | b3; -} - -@compute @workgroup_size(64) -fn sha256_jobs(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= arrayLength(&inputs)) { - return; - } - let inp = inputs[tid]; - let offset = inp.offset; - let len = inp.length; - - // FIPS 180-4 IV (§5.3.3). - var h: array = array( - 0x6a09e667u, 0xbb67ae85u, 0x3c6ef372u, 0xa54ff53au, - 0x510e527fu, 0x9b05688cu, 0x1f83d9abu, 0x5be0cd19u, - ); - - // Process full 64-byte blocks. 
- var absorbed: u32 = 0u; - var block: array; - loop { - if (len - absorbed < 64u) { break; } - for (var i = 0u; i < 16u; i = i + 1u) { - block[i] = be_word(offset + absorbed + i * 4u); - } - sha256_block(&block, &h); - absorbed = absorbed + 64u; - } - - // Build final-block byte buffer (we use a u32[16] view written - // big-endian directly). - var pad_bytes: array; // each entry holds one byte (0..255) - for (var i = 0u; i < 64u; i = i + 1u) { - pad_bytes[i] = 0u; - } - let rem = len - absorbed; - for (var i = 0u; i < rem; i = i + 1u) { - pad_bytes[i] = read_byte(offset + absorbed + i); - } - pad_bytes[rem] = 0x80u; - - // If the tail spans two final blocks, hash the first now and zero - // pad_bytes for the second pass. - if (rem >= 56u) { - for (var i = 0u; i < 16u; i = i + 1u) { - block[i] = (pad_bytes[i * 4u + 0u] << 24u) - | (pad_bytes[i * 4u + 1u] << 16u) - | (pad_bytes[i * 4u + 2u] << 8u) - | pad_bytes[i * 4u + 3u]; - } - sha256_block(&block, &h); - for (var i = 0u; i < 64u; i = i + 1u) { - pad_bytes[i] = 0u; - } - } - - // Append 64-bit big-endian bit length in the trailing 8 bytes. - let bit_len_lo: u32 = len << 3u; - let bit_len_hi: u32 = len >> 29u; // upper 32 bits of (len * 8) - pad_bytes[56] = (bit_len_hi >> 24u) & 0xFFu; - pad_bytes[57] = (bit_len_hi >> 16u) & 0xFFu; - pad_bytes[58] = (bit_len_hi >> 8u) & 0xFFu; - pad_bytes[59] = bit_len_hi & 0xFFu; - pad_bytes[60] = (bit_len_lo >> 24u) & 0xFFu; - pad_bytes[61] = (bit_len_lo >> 16u) & 0xFFu; - pad_bytes[62] = (bit_len_lo >> 8u) & 0xFFu; - pad_bytes[63] = bit_len_lo & 0xFFu; - - for (var i = 0u; i < 16u; i = i + 1u) { - block[i] = (pad_bytes[i * 4u + 0u] << 24u) - | (pad_bytes[i * 4u + 1u] << 16u) - | (pad_bytes[i * 4u + 2u] << 8u) - | pad_bytes[i * 4u + 3u]; - } - sha256_block(&block, &h); - - // Emit 32-byte big-endian digest. Each lane is one big-endian u32; in the - // outputs buffer we still store little-endian-packed u32 so the host can - // read bytes back in canonical order. 
We pack each big-endian word as - // bytes and stuff them into outputs[i * 8 .. i * 8 + 8] little-endian. - let out_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { - let v = h[i]; - let b0: u32 = (v >> 24u) & 0xFFu; - let b1: u32 = (v >> 16u) & 0xFFu; - let b2: u32 = (v >> 8u) & 0xFFu; - let b3: u32 = v & 0xFFu; - // outputs is a u32 array packed little-endian: byte0 = low byte. - outputs[out_base + i] = b0 | (b1 << 8u) | (b2 << 16u) | (b3 << 24u); - } -} diff --git a/sha256/gpu/wgsl/sha256_driver_wgpu.cpp b/sha256/gpu/wgsl/sha256_driver_wgpu.cpp deleted file mode 100644 index 54028bf..0000000 --- a/sha256/gpu/wgsl/sha256_driver_wgpu.cpp +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// WebGPU/WGSL host driver for batched SHA-256 (FIPS 180-4). -// -// Mirrors the Metal/CUDA layout: caller supplies a packed input arena, an -// (offset, length) descriptor per input, and a contiguous 32-byte stride -// output buffer. -// -// Build flags: -// * LUX_SHA256_HAS_WEBGPU=1 - Dawn or wgpu-native runtime found -// * LUX_SHA256_HAS_WGPU_NATIVE=1 - wgpu-native specifically (gives -// wgpuDevicePoll for synchronous waits) - -#include "sha256_driver_wgpu.h" - -#if defined(LUX_SHA256_HAS_WEBGPU) - -#include -#if defined(LUX_SHA256_HAS_WGPU_NATIVE) -# include -#endif - -#include -#include -#include -#include -#include -#include - -// WGSL source concatenated into a string literal by CMake. -#include "sha256_wgsl_sources.h" - -namespace { - -WGPUStringView sv(const char* s) { - WGPUStringView v{}; - v.data = s; - v.length = (s == nullptr) ? 
0 : std::strlen(s); - return v; -} -WGPUStringView sv(const std::string& s) { - WGPUStringView v{}; - v.data = s.data(); - v.length = s.size(); - return v; -} - -void drain(WGPUInstance inst, WGPUDevice dev) { - if (inst) wgpuInstanceProcessEvents(inst); -#if defined(LUX_SHA256_HAS_WGPU_NATIVE) - if (dev) wgpuDevicePoll(dev, /*wait=*/WGPU_TRUE, nullptr); -#else - (void)dev; -#endif -} - -bool wait_map(WGPUInstance inst, WGPUDevice dev, WGPUBuffer buf, - WGPUMapMode mode, size_t off, size_t size) { - struct State { - std::atomic done{false}; - WGPUMapAsyncStatus status{WGPUMapAsyncStatus_Error}; - } s; - WGPUBufferMapCallbackInfo cb{}; - cb.mode = WGPUCallbackMode_AllowProcessEvents; - cb.callback = [](WGPUMapAsyncStatus st, WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - p->status = st; - p->done.store(true, std::memory_order_release); - }; - cb.userdata1 = &s; - wgpuBufferMapAsync(buf, mode, off, size, cb); - for (int spin = 0; spin < 4096; spin++) { - if (s.done.load(std::memory_order_acquire)) break; - drain(inst, dev); - } - return s.done.load() && s.status == WGPUMapAsyncStatus_Success; -} - -struct Engine { - WGPUInstance instance{nullptr}; - WGPUAdapter adapter{nullptr}; - WGPUDevice device{nullptr}; - WGPUQueue queue{nullptr}; - WGPUShaderModule module{nullptr}; - WGPUComputePipeline pipeline{nullptr}; - bool initialized{false}; -}; - -Engine& engine() { static Engine e; return e; } - -bool init_engine() { - Engine& e = engine(); - if (e.initialized) return true; - - WGPUInstanceDescriptor idesc{}; - e.instance = wgpuCreateInstance(&idesc); - if (!e.instance) return false; - - struct AS { std::atomic done{false}; WGPUAdapter ad{nullptr}; } as; - WGPURequestAdapterOptions ropt{}; - ropt.powerPreference = WGPUPowerPreference_HighPerformance; - WGPURequestAdapterCallbackInfo rcb{}; - rcb.mode = WGPUCallbackMode_AllowProcessEvents; - rcb.callback = [](WGPURequestAdapterStatus st, WGPUAdapter ad, - WGPUStringView, void* u, void*) { - auto* p = 
static_cast(u); - if (st == WGPURequestAdapterStatus_Success) p->ad = ad; - p->done.store(true, std::memory_order_release); - }; - rcb.userdata1 = &as; - wgpuInstanceRequestAdapter(e.instance, &ropt, rcb); - for (int spin = 0; spin < 4096; spin++) { - if (as.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!as.ad) { std::fprintf(stderr, "wgpu: no adapter\n"); return false; } - e.adapter = as.ad; - - struct DS { std::atomic done{false}; WGPUDevice dev{nullptr}; } ds; - WGPUDeviceDescriptor ddesc{}; - WGPURequestDeviceCallbackInfo dcb{}; - dcb.mode = WGPUCallbackMode_AllowProcessEvents; - dcb.callback = [](WGPURequestDeviceStatus st, WGPUDevice dev, - WGPUStringView, void* u, void*) { - auto* p = static_cast(u); - if (st == WGPURequestDeviceStatus_Success) p->dev = dev; - p->done.store(true, std::memory_order_release); - }; - dcb.userdata1 = &ds; - wgpuAdapterRequestDevice(e.adapter, &ddesc, dcb); - for (int spin = 0; spin < 4096; spin++) { - if (ds.done.load(std::memory_order_acquire)) break; - wgpuInstanceProcessEvents(e.instance); - } - if (!ds.dev) { std::fprintf(stderr, "wgpu: no device\n"); return false; } - e.device = ds.dev; - e.queue = wgpuDeviceGetQueue(e.device); - if (!e.queue) return false; - - WGPUShaderSourceWGSL wgsl{}; - wgsl.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl.code = sv(kSHA256_WGSL_Source); - - WGPUShaderModuleDescriptor smd{}; - smd.nextInChain = &wgsl.chain; - smd.label = sv("sha256"); - e.module = wgpuDeviceCreateShaderModule(e.device, &smd); - if (!e.module) { - std::fprintf(stderr, "wgpu: sha256 shader compile failed\n"); - return false; - } - - WGPUComputePipelineDescriptor cpd{}; - cpd.compute.module = e.module; - cpd.compute.entryPoint = sv("sha256_jobs"); - cpd.label = sv("sha256_jobs"); - e.pipeline = wgpuDeviceCreateComputePipeline(e.device, &cpd); - if (!e.pipeline) { - std::fprintf(stderr, "wgpu: sha256 pipeline failed\n"); - return false; - } - - e.initialized = true; - 
return true; -} - -WGPUBuffer make_buf(Engine& e, size_t size, WGPUBufferUsage usage) { - WGPUBufferDescriptor bd{}; - bd.size = (size + 3) & ~size_t(3); - if (bd.size == 0) bd.size = 4; - bd.usage = usage; - return wgpuDeviceCreateBuffer(e.device, &bd); -} - -} // namespace - -extern "C" int lux_sha256_wgpu_available(void) { - return init_engine() ? 1 : 0; -} - -extern "C" int sha256_batch_wgpu( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena) { - - if (n == 0) return 0; - if (!input_offsets || !input_lens || !outputs_arena) return -1; - if (!init_engine()) return -2; - Engine& e = engine(); - - // Pack the inputs descriptor (offset, length) into a u32 array of - // length 2*n. Pad inputs_arena up to a 4-byte boundary for the data - // buffer because WGSL reads it as array. - std::vector desc(n * 2); - for (size_t i = 0; i < n; ++i) { - desc[i * 2 + 0] = input_offsets[i]; - desc[i * 2 + 1] = input_lens[i]; - } - - size_t data_bytes = (inputs_arena_len + 3) & ~size_t(3); - if (data_bytes == 0) data_bytes = 4; - std::vector data_padded(data_bytes, 0); - if (inputs_arena_len) std::memcpy(data_padded.data(), inputs_arena, - inputs_arena_len); - - size_t out_words = n * 8; // 32 bytes per hash = 8 u32 words - size_t out_bytes = out_words * 4; - - WGPUBuffer buf_desc = make_buf(e, desc.size() * sizeof(uint32_t), - WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_data = make_buf(e, data_bytes, - WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst); - WGPUBuffer buf_out = make_buf(e, out_bytes, - WGPUBufferUsage_Storage | WGPUBufferUsage_CopySrc); - WGPUBuffer buf_read = make_buf(e, out_bytes, - WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst); - if (!buf_desc || !buf_data || !buf_out || !buf_read) return -3; - - wgpuQueueWriteBuffer(e.queue, buf_desc, 0, desc.data(), - desc.size() * sizeof(uint32_t)); - wgpuQueueWriteBuffer(e.queue, buf_data, 
0, data_padded.data(), data_bytes); - - WGPUBindGroupLayout bgl = wgpuComputePipelineGetBindGroupLayout(e.pipeline, 0); - WGPUBindGroupEntry bge[3] = {}; - bge[0].binding = 0; bge[0].buffer = buf_desc; bge[0].size = desc.size() * sizeof(uint32_t); - bge[1].binding = 1; bge[1].buffer = buf_data; bge[1].size = data_bytes; - bge[2].binding = 2; bge[2].buffer = buf_out; bge[2].size = out_bytes; - WGPUBindGroupDescriptor bgd{}; - bgd.layout = bgl; - bgd.entryCount = 3; - bgd.entries = bge; - WGPUBindGroup bg = wgpuDeviceCreateBindGroup(e.device, &bgd); - if (!bg) return -4; - - WGPUCommandEncoderDescriptor ced{}; - WGPUCommandEncoder ce = wgpuDeviceCreateCommandEncoder(e.device, &ced); - WGPUComputePassDescriptor cpd2{}; - WGPUComputePassEncoder cpe = wgpuCommandEncoderBeginComputePass(ce, &cpd2); - wgpuComputePassEncoderSetPipeline(cpe, e.pipeline); - wgpuComputePassEncoderSetBindGroup(cpe, 0, bg, 0, nullptr); - uint32_t wg = uint32_t((n + 63) / 64); - wgpuComputePassEncoderDispatchWorkgroups(cpe, wg, 1, 1); - wgpuComputePassEncoderEnd(cpe); - - wgpuCommandEncoderCopyBufferToBuffer(ce, buf_out, 0, buf_read, 0, out_bytes); - WGPUCommandBufferDescriptor cbd{}; - WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(ce, &cbd); - wgpuQueueSubmit(e.queue, 1, &cmd); - - if (!wait_map(e.instance, e.device, buf_read, WGPUMapMode_Read, 0, out_bytes)) { - std::fprintf(stderr, "wgpu: sha256 readback map failed\n"); - return -5; - } - const void* mapped = wgpuBufferGetConstMappedRange(buf_read, 0, out_bytes); - std::memcpy(outputs_arena, mapped, n * 32); - wgpuBufferUnmap(buf_read); - - wgpuComputePassEncoderRelease(cpe); - wgpuCommandEncoderRelease(ce); - wgpuCommandBufferRelease(cmd); - wgpuBindGroupRelease(bg); - wgpuBindGroupLayoutRelease(bgl); - wgpuBufferRelease(buf_desc); - wgpuBufferRelease(buf_data); - wgpuBufferRelease(buf_out); - wgpuBufferRelease(buf_read); - return 0; -} - -#else // LUX_SHA256_HAS_WEBGPU not defined: stub mode - -extern "C" int 
lux_sha256_wgpu_available(void) { return 0; } -extern "C" int sha256_batch_wgpu( - const uint8_t*, size_t, - const uint32_t*, const uint32_t*, - size_t, uint8_t*) { - return -1; -} - -#endif diff --git a/sha256/gpu/wgsl/sha256_driver_wgpu.h b/sha256/gpu/wgsl/sha256_driver_wgpu.h deleted file mode 100644 index 2650471..0000000 --- a/sha256/gpu/wgsl/sha256_driver_wgpu.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// Public C-ABI for the SHA-256 WebGPU/WGSL driver. On hosts without a wgpu -// runtime, lux_sha256_wgpu_available() returns 0 and sha256_batch_wgpu() -// returns -1. - -#ifndef LUX_SHA256_DRIVER_WGPU_H -#define LUX_SHA256_DRIVER_WGPU_H - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Returns 1 if a WebGPU adapter+device initialised successfully, 0 otherwise. -int lux_sha256_wgpu_available(void); - -// Run N SHA-256 hashes in one WGSL dispatch. Inputs share a flat byte arena; -// outputs are written as 32 contiguous bytes per hash. Returns 0 on success, -// negative on failure. 
-int sha256_batch_wgpu( - const uint8_t* inputs_arena, - size_t inputs_arena_len, - const uint32_t* input_offsets, - const uint32_t* input_lens, - size_t n, - uint8_t* outputs_arena); - -#ifdef __cplusplus -} -#endif - -#endif // LUX_SHA256_DRIVER_WGPU_H diff --git a/slhdsa/gpu/cuda/slhdsa.cu b/slhdsa/gpu/cuda/slhdsa.cu deleted file mode 100644 index d5177e8..0000000 --- a/slhdsa/gpu/cuda/slhdsa.cu +++ /dev/null @@ -1,376 +0,0 @@ -// SLH-DSA (FIPS 205, SPHINCS+) batch verify -- CUDA implementation -// Matches slhdsa.metal output byte-for-byte -// One thread per signature verification - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// SLH-DSA-SHAKE-128f parameters -// ============================================================================= - -#define SLHDSA_N 16 // Security parameter (bytes) -#define SLHDSA_D 22 // Number of hypertree layers -#define SLHDSA_HP 3 // Height per layer (h/d) -#define SLHDSA_A 6 // FORS tree height -#define SLHDSA_K 33 // FORS trees -#define SLHDSA_W 16 // Winternitz parameter -#define SLHDSA_LEN1 4 // WOTS+ len1 = ceil(8n/log2(w)) -#define SLHDSA_LEN 7 // Total WOTS+ length (len1 + len2) - -// ============================================================================= -// Keccak-f[1600] permutation (for SHAKE256) -// ============================================================================= - -__device__ static const uint64_t KECCAK_RC[24] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, - 0x800000000000808AULL, 0x8000000080008000ULL, - 0x000000000000808BULL, 0x0000000080000001ULL, - 0x8000000080008081ULL, 0x8000000000008009ULL, - 0x000000000000008AULL, 0x0000000000000088ULL, - 0x0000000080008009ULL, 0x000000008000000AULL, - 0x000000008000808BULL, 0x800000000000008BULL, - 0x8000000000008089ULL, 
0x8000000000008003ULL, - 0x8000000000008002ULL, 0x8000000000000080ULL, - 0x000000000000800AULL, 0x800000008000000AULL, - 0x8000000080008081ULL, 0x8000000000008080ULL, - 0x0000000080000001ULL, 0x8000000080008008ULL, -}; - -__device__ static const int KECCAK_PI[24] = { - 10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, - 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 -}; - -__device__ static const int KECCAK_RHO[24] = { - 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44 -}; - -__device__ static uint64_t rotl64(uint64_t x, int n) { - return (x << n) | (x >> (64 - n)); -} - -__device__ static void keccak_f(uint64_t st[25]) { - for (int round = 0; round < 24; ++round) { - uint64_t C[5]; - for (int x = 0; x < 5; ++x) - C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; - for (int x = 0; x < 5; ++x) { - uint64_t d = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) st[x + 5 * y] ^= d; - } - uint64_t t = st[1]; - for (int i = 0; i < 24; ++i) { - uint64_t tmp = st[KECCAK_PI[i]]; - st[KECCAK_PI[i]] = rotl64(t, KECCAK_RHO[i]); - t = tmp; - } - for (int y = 0; y < 5; ++y) { - uint64_t row[5]; - for (int x = 0; x < 5; ++x) row[x] = st[x + 5 * y]; - for (int x = 0; x < 5; ++x) - st[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - st[0] ^= KECCAK_RC[round]; - } -} - -// ============================================================================= -// SHAKE256 helper: absorb + squeeze n bytes -// ============================================================================= - -__device__ static void shake256(const uint8_t* input, uint32_t input_len, - uint8_t* output, uint32_t output_len) { - const uint32_t rate = 136; - uint64_t state[25]; - for (int i = 0; i < 25; i++) state[i] = 0; - - // Absorb - uint32_t absorbed = 0; - while (absorbed + rate <= input_len) { - for (uint32_t w = 0; w < rate / 8; ++w) { - uint64_t lane = 0; - for (uint32_t b = 0; b < 8; ++b) - lane |= 
(uint64_t)input[absorbed + w * 8 + b] << (b * 8); - state[w] ^= lane; - } - keccak_f(state); - absorbed += rate; - } - - // Pad (SHAKE: 0x1F || 0x00...0x00 || 0x80) - uint8_t padded[136]; - for (uint32_t i = 0; i < 136; i++) padded[i] = 0; - uint32_t remaining = input_len - absorbed; - for (uint32_t i = 0; i < remaining; i++) padded[i] = input[absorbed + i]; - padded[remaining] = 0x1F; - padded[rate - 1] |= 0x80; - - for (uint32_t w = 0; w < rate / 8; ++w) { - uint64_t lane = 0; - for (uint32_t b = 0; b < 8; ++b) - lane |= (uint64_t)padded[w * 8 + b] << (b * 8); - state[w] ^= lane; - } - keccak_f(state); - - // Squeeze - uint32_t squeezed = 0; - while (squeezed < output_len) { - uint32_t to_copy = output_len - squeezed; - if (to_copy > rate) to_copy = rate; - for (uint32_t i = 0; i < to_copy; i++) { - output[squeezed + i] = (uint8_t)(state[i / 8] >> ((i % 8) * 8)); - } - squeezed += to_copy; - if (squeezed < output_len) keccak_f(state); - } -} - -// ============================================================================= -// WOTS+ chain function -// ============================================================================= - -__device__ static void wots_chain_step(const uint8_t pk_seed[SLHDSA_N], - const uint8_t adrs[32], - const uint8_t input[SLHDSA_N], - uint8_t output[SLHDSA_N]) { - uint8_t buf[64]; - for (uint32_t i = 0; i < SLHDSA_N; i++) buf[i] = pk_seed[i]; - for (uint32_t i = 0; i < 32; i++) buf[SLHDSA_N + i] = adrs[i]; - for (uint32_t i = 0; i < SLHDSA_N; i++) buf[SLHDSA_N + 32 + i] = input[i]; - shake256(buf, 64, output, SLHDSA_N); -} - -__device__ static void wots_chain(const uint8_t pk_seed[SLHDSA_N], - uint8_t adrs[32], - const uint8_t x[SLHDSA_N], - int start, int steps, - uint8_t out[SLHDSA_N]) { - for (uint32_t i = 0; i < SLHDSA_N; i++) out[i] = x[i]; - - for (int i = start; i < start + steps; i++) { - adrs[28] = (uint8_t)(i >> 24); - adrs[29] = (uint8_t)(i >> 16); - adrs[30] = (uint8_t)(i >> 8); - adrs[31] = (uint8_t)(i); - - uint8_t 
tmp[SLHDSA_N]; - wots_chain_step(pk_seed, adrs, out, tmp); - for (uint32_t j = 0; j < SLHDSA_N; j++) out[j] = tmp[j]; - } -} - -// ============================================================================= -// SLH-DSA structures -// ============================================================================= - -struct SLHDSAPublicKey { - uint8_t data[32]; // PK.seed[16] || PK.root[16] -}; - -struct SLHDSAMessage { - uint8_t data[32]; -}; - -struct SLHDSASignature { - uint8_t data[17088]; // Max signature size for 128f, padded -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -extern "C" __global__ void slhdsa_verify_batch( - const SLHDSAPublicKey* __restrict__ pubkeys, - const SLHDSAMessage* __restrict__ messages, - const SLHDSASignature* __restrict__ signatures, - uint32_t* __restrict__ results, - const uint32_t* __restrict__ num_sigs_ptr) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t num_sigs = *num_sigs_ptr; - if (tid >= num_sigs) return; - - const uint8_t* pk = pubkeys[tid].data; - const uint8_t* sig = signatures[tid].data; - - // Extract PK.seed and PK.root - uint8_t pk_seed[SLHDSA_N]; - uint8_t pk_root[SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) { - pk_seed[i] = pk[i]; - pk_root[i] = pk[SLHDSA_N + i]; - } - - // Extract randomizer R from signature - uint8_t R[SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) R[i] = sig[i]; - - // -- Compute message digest using SHAKE256 -- - // digest = SHAKE256(R || PK.seed || PK.root || M) - uint8_t hash_input[96]; // R[16] + pk_seed[16] + pk_root[16] + msg[32] = 80 bytes - for (uint32_t i = 0; i < SLHDSA_N; i++) hash_input[i] = R[i]; - for (uint32_t i = 0; i < SLHDSA_N; i++) hash_input[SLHDSA_N + i] = pk_seed[i]; - for (uint32_t i = 0; i < SLHDSA_N; i++) hash_input[2 * SLHDSA_N + i] = pk_root[i]; - for (uint32_t i = 0; i < 32; i++) 
hash_input[3 * SLHDSA_N + i] = messages[tid].data[i]; - - uint8_t digest[32]; - shake256(hash_input, 3 * SLHDSA_N + 32, digest, 32); - - // -- FORS verification -- - uint32_t fors_offset = SLHDSA_N; - - uint8_t fors_roots[SLHDSA_K][SLHDSA_N]; - for (uint32_t tree = 0; tree < SLHDSA_K; tree++) { - // Extract FORS leaf - uint8_t leaf[SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) { - leaf[i] = sig[fors_offset + tree * (SLHDSA_N + SLHDSA_A * SLHDSA_N) + i]; - } - - // Hash the leaf to get node - uint8_t node[SLHDSA_N]; - uint8_t leaf_input[64]; - for (uint32_t i = 0; i < SLHDSA_N; i++) leaf_input[i] = pk_seed[i]; - for (uint32_t i = SLHDSA_N; i < 48; i++) leaf_input[i] = 0; - for (uint32_t i = 0; i < SLHDSA_N; i++) leaf_input[48 + i] = leaf[i]; - shake256(leaf_input, 64, node, SLHDSA_N); - - // Climb auth path - uint32_t auth_offset = fors_offset + tree * (SLHDSA_N + SLHDSA_A * SLHDSA_N) + SLHDSA_N; - - // Extract tree index from digest - uint32_t tree_idx = 0; - uint32_t bit_offset = tree * SLHDSA_A; - for (uint32_t b = 0; b < SLHDSA_A; b++) { - uint32_t byte_idx = (bit_offset + b) / 8; - uint32_t bit_pos = (bit_offset + b) % 8; - tree_idx |= ((uint32_t)(digest[byte_idx] >> bit_pos) & 1) << b; - } - - for (uint32_t layer = 0; layer < SLHDSA_A; layer++) { - uint8_t sibling[SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) { - sibling[i] = sig[auth_offset + layer * SLHDSA_N + i]; - } - - uint8_t pair_input[64]; - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[i] = pk_seed[i]; - for (uint32_t i = SLHDSA_N; i < 32; i++) pair_input[i] = 0; - - if ((tree_idx >> layer) & 1) { - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[32 + i] = sibling[i]; - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[32 + SLHDSA_N + i] = node[i]; - } else { - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[32 + i] = node[i]; - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[32 + SLHDSA_N + i] = sibling[i]; - } - - shake256(pair_input, 64, node, SLHDSA_N); - } - - for 
(uint32_t i = 0; i < SLHDSA_N; i++) fors_roots[tree][i] = node[i]; - } - - // -- Compute FORS public key hash from roots -- - uint8_t fors_pk_input[SLHDSA_N + SLHDSA_K * SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) fors_pk_input[i] = pk_seed[i]; - for (uint32_t t = 0; t < SLHDSA_K; t++) { - for (uint32_t i = 0; i < SLHDSA_N; i++) { - fors_pk_input[SLHDSA_N + t * SLHDSA_N + i] = fors_roots[t][i]; - } - } - uint8_t fors_pk[SLHDSA_N]; - shake256(fors_pk_input, SLHDSA_N + SLHDSA_K * SLHDSA_N, fors_pk, SLHDSA_N); - - // -- Hypertree verification -- - uint8_t current_node[SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) current_node[i] = fors_pk[i]; - - uint32_t ht_offset = fors_offset + SLHDSA_K * (SLHDSA_N + SLHDSA_A * SLHDSA_N); - - for (uint32_t layer = 0; layer < SLHDSA_D; layer++) { - // Extract WOTS+ signature for this layer - uint8_t wots_sig[SLHDSA_LEN][SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_LEN; i++) { - for (uint32_t j = 0; j < SLHDSA_N; j++) { - wots_sig[i][j] = sig[ht_offset + layer * (SLHDSA_LEN * SLHDSA_N + SLHDSA_HP * SLHDSA_N) + i * SLHDSA_N + j]; - } - } - - // Compute WOTS+ public key from signature - uint8_t adrs[32]; - for (uint32_t i = 0; i < 32; i++) adrs[i] = 0; - adrs[4] = (uint8_t)layer; - - uint8_t wots_pk_parts[SLHDSA_LEN][SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_LEN; i++) { - uint32_t msg_byte = i < SLHDSA_N ? 
current_node[i] : 0; - uint32_t chain_start, chain_len; - - if (i < SLHDSA_LEN1) { - uint32_t digit = (msg_byte >> ((i % 2) * 4)) & 0x0F; - chain_start = digit; - chain_len = SLHDSA_W - 1 - digit; - } else { - chain_start = 0; - chain_len = SLHDSA_W - 1; - } - - adrs[20] = (uint8_t)(i >> 8); - adrs[21] = (uint8_t)(i); - - wots_chain(pk_seed, adrs, wots_sig[i], chain_start, chain_len, - wots_pk_parts[i]); - } - - // Hash WOTS+ PK parts to get node - uint8_t wots_pk_input[SLHDSA_N + SLHDSA_LEN * SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) wots_pk_input[i] = pk_seed[i]; - for (uint32_t i = 0; i < SLHDSA_LEN; i++) { - for (uint32_t j = 0; j < SLHDSA_N; j++) { - wots_pk_input[SLHDSA_N + i * SLHDSA_N + j] = wots_pk_parts[i][j]; - } - } - shake256(wots_pk_input, SLHDSA_N + SLHDSA_LEN * SLHDSA_N, current_node, SLHDSA_N); - - // Climb Merkle tree auth path for this layer - uint32_t auth_base = ht_offset + layer * (SLHDSA_LEN * SLHDSA_N + SLHDSA_HP * SLHDSA_N) - + SLHDSA_LEN * SLHDSA_N; - - for (uint32_t h = 0; h < SLHDSA_HP; h++) { - uint8_t sibling[SLHDSA_N]; - for (uint32_t i = 0; i < SLHDSA_N; i++) { - sibling[i] = sig[auth_base + h * SLHDSA_N + i]; - } - - uint8_t pair_input[64]; - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[i] = pk_seed[i]; - for (uint32_t i = SLHDSA_N; i < 32; i++) pair_input[i] = 0; - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[32 + i] = current_node[i]; - for (uint32_t i = 0; i < SLHDSA_N; i++) pair_input[32 + SLHDSA_N + i] = sibling[i]; - - shake256(pair_input, 64, current_node, SLHDSA_N); - } - } - - // -- Compare reconstructed root with PK.root -- - bool valid = true; - for (uint32_t i = 0; i < SLHDSA_N; i++) { - if (current_node[i] != pk_root[i]) { - valid = false; - break; - } - } - - results[tid] = valid ? 
1u : 0u; -} diff --git a/slhdsa/gpu/metal/slhdsa.metal b/slhdsa/gpu/metal/slhdsa.metal deleted file mode 100644 index 39cfc1e..0000000 --- a/slhdsa/gpu/metal/slhdsa.metal +++ /dev/null @@ -1,415 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file slhdsa.metal -/// Metal compute shader for batch SLH-DSA (FIPS 205, SPHINCS+) verification. -/// -/// SLH-DSA is a stateless hash-based signature scheme. Verification is dominated -/// by hash computations (SHA-256 or SHAKE256). The GPU accelerates the -/// independent hash evaluations within the WOTS+ and Merkle tree layers. -/// -/// This implementation uses the SHAKE256-based instantiation (SLH-DSA-SHAKE-128f). -/// We leverage the Keccak permutation already available in keccak256.metal. -/// -/// Verification steps: -/// 1. Compute FORS tree root from signature -/// 2. Verify WOTS+ signatures at each hypertree layer -/// 3. Reconstruct Merkle tree path and verify root -/// -/// GPU advantage: each hash in the WOTS+ chain and Merkle tree is independent, -/// perfectly parallelizable across GPU threads. 
- -#include -using namespace metal; - -// ============================================================================= -// SLH-DSA-SHAKE-128f parameters -// ============================================================================= - -constant uint SLHDSA_N = 16; // Security parameter (bytes) -constant uint SLHDSA_D = 22; // Number of hypertree layers -constant uint SLHDSA_HP = 3; // Height per layer (h/d) -constant uint SLHDSA_A = 6; // FORS tree height -constant uint SLHDSA_K = 33; // FORS trees -constant uint SLHDSA_W = 16; // Winternitz parameter -constant uint SLHDSA_LEN1 = 4; // WOTS+ len1 = ceil(8n/log2(w)) -constant uint SLHDSA_LEN = 7; // Total WOTS+ length (len1 + len2) - -// ============================================================================= -// Keccak-f[1600] permutation (for SHAKE256) -// Reused from keccak256.metal patterns -// ============================================================================= - -constant ulong KECCAK_RC[24] = { - 0x0000000000000001UL, 0x0000000000008082UL, - 0x800000000000808AUL, 0x8000000080008000UL, - 0x000000000000808BUL, 0x0000000080000001UL, - 0x8000000080008081UL, 0x8000000000008009UL, - 0x000000000000008AUL, 0x0000000000000088UL, - 0x0000000080008009UL, 0x000000008000000AUL, - 0x000000008000808BUL, 0x800000000000008BUL, - 0x8000000000008089UL, 0x8000000000008003UL, - 0x8000000000008002UL, 0x8000000000000080UL, - 0x000000000000800AUL, 0x800000008000000AUL, - 0x8000000080008081UL, 0x8000000000008080UL, - 0x0000000080000001UL, 0x8000000080008008UL, -}; - -constant int KECCAK_PI[24] = { - 10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, - 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1 -}; - -constant int KECCAK_RHO[24] = { - 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44 -}; - -inline ulong rotl64(ulong x, int n) { - return (x << n) | (x >> (64 - n)); -} - -void keccak_f(thread ulong st[25]) { - for (int round = 0; round < 24; ++round) { - ulong C[5]; - for (int x = 0; x 
< 5; ++x) - C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; - for (int x = 0; x < 5; ++x) { - ulong d = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); - for (int y = 0; y < 5; ++y) st[x + 5 * y] ^= d; - } - ulong t = st[1]; - for (int i = 0; i < 24; ++i) { - ulong tmp = st[KECCAK_PI[i]]; - st[KECCAK_PI[i]] = rotl64(t, KECCAK_RHO[i]); - t = tmp; - } - for (int y = 0; y < 5; ++y) { - ulong row[5]; - for (int x = 0; x < 5; ++x) row[x] = st[x + 5 * y]; - for (int x = 0; x < 5; ++x) - st[x + 5 * y] = row[x] ^ ((~row[(x + 1) % 5]) & row[(x + 2) % 5]); - } - st[0] ^= KECCAK_RC[round]; - } -} - -// ============================================================================= -// SHAKE256 helper: absorb + squeeze n bytes -// ============================================================================= - -/// Hash arbitrary input to n output bytes using SHAKE256. -/// Rate = 136 bytes (1088 bits). -inline void shake256(thread const uchar* input, uint input_len, - thread uchar* output, uint output_len) { - const uint rate = 136; - ulong state[25] = {}; - - // Absorb - uint absorbed = 0; - while (absorbed + rate <= input_len) { - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(input[absorbed + w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccak_f(state); - absorbed += rate; - } - - // Pad (SHAKE: 0x1F || 0x00...0x00 || 0x80) - uchar padded[136] = {}; - uint remaining = input_len - absorbed; - for (uint i = 0; i < remaining; i++) padded[i] = input[absorbed + i]; - padded[remaining] = 0x1F; - padded[rate - 1] |= 0x80; - - for (uint w = 0; w < rate / 8; ++w) { - ulong lane = 0; - for (uint b = 0; b < 8; ++b) - lane |= ulong(padded[w * 8 + b]) << (b * 8); - state[w] ^= lane; - } - keccak_f(state); - - // Squeeze - uint squeezed = 0; - while (squeezed < output_len) { - uint to_copy = min(rate, output_len - squeezed); - for (uint i = 0; i < to_copy; i++) { - output[squeezed + i] = uchar(state[i / 8] >> ((i % 
8) * 8)); - } - squeezed += to_copy; - if (squeezed < output_len) keccak_f(state); - } -} - -// ============================================================================= -// WOTS+ chain function -// ============================================================================= - -/// Compute one step of the WOTS+ chain: hash input with ADRS tweak. -/// F(PK.seed, ADRS, M) = SHAKE256(PK.seed || ADRS || M) -inline void wots_chain_step(thread const uchar pk_seed[SLHDSA_N], - thread const uchar adrs[32], - thread const uchar input[SLHDSA_N], - thread uchar output[SLHDSA_N]) { - // Concatenate: pk_seed[16] || adrs[32] || input[16] = 64 bytes - uchar buf[64]; - for (uint i = 0; i < SLHDSA_N; i++) buf[i] = pk_seed[i]; - for (uint i = 0; i < 32; i++) buf[SLHDSA_N + i] = adrs[i]; - for (uint i = 0; i < SLHDSA_N; i++) buf[SLHDSA_N + 32 + i] = input[i]; - - shake256(buf, 64, output, SLHDSA_N); -} - -/// Compute WOTS+ chain: iterate hash s times starting from value X -inline void wots_chain(thread const uchar pk_seed[SLHDSA_N], - thread uchar adrs[32], - thread const uchar x[SLHDSA_N], - int start, int steps, - thread uchar out[SLHDSA_N]) { - for (uint i = 0; i < SLHDSA_N; i++) out[i] = x[i]; - - for (int i = start; i < start + steps; i++) { - // Set chain index in ADRS - adrs[28] = uchar(i >> 24); - adrs[29] = uchar(i >> 16); - adrs[30] = uchar(i >> 8); - adrs[31] = uchar(i); - - uchar tmp[SLHDSA_N]; - wots_chain_step(pk_seed, adrs, out, tmp); - for (uint j = 0; j < SLHDSA_N; j++) out[j] = tmp[j]; - } -} - -// ============================================================================= -// SLH-DSA structures -// ============================================================================= - -/// SLH-DSA-SHAKE-128f public key: 2*n = 32 bytes -struct SLHDSAPublicKey { - uchar data[32]; // PK.seed[16] || PK.root[16] -}; - -/// SLH-DSA message (pre-hashed, 32 bytes) -struct SLHDSAMessage { - uchar data[32]; -}; - -/// SLH-DSA signature (variable size, max ~17KB for 128f) -/// 
Packed: R[16] || FORS_SIG || HT_SIG -struct SLHDSASignature { - uchar data[17088]; // Max signature size for 128f, padded -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -/// Batch SLH-DSA signature verification. -/// Each thread verifies one signature by recomputing hash chains. -/// -/// The GPU accelerates the large number of independent hash evaluations -/// in WOTS+ chains and Merkle tree computations. -/// -/// Output: results[tid] = 1 if WOTS+ chain checks pass, 0 otherwise. -kernel void slhdsa_verify_batch( - device const SLHDSAPublicKey* pubkeys [[buffer(0)]], - device const SLHDSAMessage* messages [[buffer(1)]], - device const SLHDSASignature* signatures [[buffer(2)]], - device uint* results [[buffer(3)]], - constant uint& num_sigs [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - - device const uchar* pk = pubkeys[tid].data; - device const uchar* sig = signatures[tid].data; - - // Extract PK.seed and PK.root - uchar pk_seed[SLHDSA_N]; - uchar pk_root[SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) { - pk_seed[i] = pk[i]; - pk_root[i] = pk[SLHDSA_N + i]; - } - - // Extract randomizer R from signature - uchar R[SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) R[i] = sig[i]; - - // -- Compute message digest using SHAKE256 -- - // digest = SHAKE256(R || PK.seed || PK.root || M) - uchar hash_input[96]; // R[16] + pk_seed[16] + pk_root[16] + msg[32] = 80 bytes - for (uint i = 0; i < SLHDSA_N; i++) hash_input[i] = R[i]; - for (uint i = 0; i < SLHDSA_N; i++) hash_input[SLHDSA_N + i] = pk_seed[i]; - for (uint i = 0; i < SLHDSA_N; i++) hash_input[2 * SLHDSA_N + i] = pk_root[i]; - for (uint i = 0; i < 32; i++) hash_input[3 * SLHDSA_N + i] = messages[tid].data[i]; - - uchar digest[32]; - shake256(hash_input, 3 * SLHDSA_N + 32, digest, 32); - - // -- FORS verification -- - // 
Extract FORS signature: k trees, each with a leaf value and auth path - // FORS sig starts at offset SLHDSA_N in the signature - uint fors_offset = SLHDSA_N; - - // Compute FORS public key from signature - uchar fors_roots[SLHDSA_K][SLHDSA_N]; - for (uint tree = 0; tree < SLHDSA_K; tree++) { - // Extract FORS leaf - uchar leaf[SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) { - leaf[i] = sig[fors_offset + tree * (SLHDSA_N + SLHDSA_A * SLHDSA_N) + i]; - } - - // Hash the leaf to get node - uchar node[SLHDSA_N]; - uchar leaf_input[64]; - for (uint i = 0; i < SLHDSA_N; i++) leaf_input[i] = pk_seed[i]; - // Simple ADRS placeholder - for (uint i = SLHDSA_N; i < 48; i++) leaf_input[i] = 0; - for (uint i = 0; i < SLHDSA_N; i++) leaf_input[48 + i] = leaf[i]; - shake256(leaf_input, 64, node, SLHDSA_N); - - // Climb auth path - uint auth_offset = fors_offset + tree * (SLHDSA_N + SLHDSA_A * SLHDSA_N) + SLHDSA_N; - - // Extract tree index from digest - uint tree_idx = 0; - uint bit_offset = tree * SLHDSA_A; - for (uint b = 0; b < SLHDSA_A; b++) { - uint byte_idx = (bit_offset + b) / 8; - uint bit_pos = (bit_offset + b) % 8; - tree_idx |= ((uint)(digest[byte_idx] >> bit_pos) & 1) << b; - } - - for (uint layer = 0; layer < SLHDSA_A; layer++) { - uchar sibling[SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) { - sibling[i] = sig[auth_offset + layer * SLHDSA_N + i]; - } - - // Hash pair: order depends on tree_idx bit - uchar pair_input[64]; - for (uint i = 0; i < SLHDSA_N; i++) pair_input[i] = pk_seed[i]; - for (uint i = SLHDSA_N; i < 32; i++) pair_input[i] = 0; - - if ((tree_idx >> layer) & 1) { - for (uint i = 0; i < SLHDSA_N; i++) pair_input[32 + i] = sibling[i]; - for (uint i = 0; i < SLHDSA_N; i++) pair_input[32 + SLHDSA_N + i] = node[i]; - } else { - for (uint i = 0; i < SLHDSA_N; i++) pair_input[32 + i] = node[i]; - for (uint i = 0; i < SLHDSA_N; i++) pair_input[32 + SLHDSA_N + i] = sibling[i]; - } - - shake256(pair_input, 64, node, SLHDSA_N); - } - - for (uint i = 0; i < 
SLHDSA_N; i++) fors_roots[tree][i] = node[i]; - } - - // -- Compute FORS public key hash from roots -- - // PK_FORS = T_k(PK.seed, ADRS, fors_roots) - // Simplified: hash all roots together - uchar fors_pk_input[SLHDSA_N + SLHDSA_K * SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) fors_pk_input[i] = pk_seed[i]; - for (uint t = 0; t < SLHDSA_K; t++) { - for (uint i = 0; i < SLHDSA_N; i++) { - fors_pk_input[SLHDSA_N + t * SLHDSA_N + i] = fors_roots[t][i]; - } - } - uchar fors_pk[SLHDSA_N]; - shake256(fors_pk_input, SLHDSA_N + SLHDSA_K * SLHDSA_N, fors_pk, SLHDSA_N); - - // -- Hypertree verification -- - // For each layer of the hypertree, verify WOTS+ signature and climb Merkle tree - // The FORS PK becomes the message for the first hypertree layer - - uchar current_node[SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) current_node[i] = fors_pk[i]; - - uint ht_offset = fors_offset + SLHDSA_K * (SLHDSA_N + SLHDSA_A * SLHDSA_N); - - for (uint layer = 0; layer < SLHDSA_D; layer++) { - // Extract WOTS+ signature for this layer - uchar wots_sig[SLHDSA_LEN][SLHDSA_N]; - for (uint i = 0; i < SLHDSA_LEN; i++) { - for (uint j = 0; j < SLHDSA_N; j++) { - wots_sig[i][j] = sig[ht_offset + layer * (SLHDSA_LEN * SLHDSA_N + SLHDSA_HP * SLHDSA_N) + i * SLHDSA_N + j]; - } - } - - // Compute WOTS+ public key from signature - // For each chain: complete the chain to W-1 - uchar adrs[32] = {}; - adrs[4] = uchar(layer); // layer address - - uchar wots_pk_parts[SLHDSA_LEN][SLHDSA_N]; - for (uint i = 0; i < SLHDSA_LEN; i++) { - // Determine chain length from message - uint msg_byte = i < SLHDSA_N ? 
current_node[i] : 0; - uint chain_start, chain_len; - - if (i < SLHDSA_LEN1) { - // Base-w digit from message - uint digit = (msg_byte >> ((i % 2) * 4)) & 0x0F; - chain_start = digit; - chain_len = SLHDSA_W - 1 - digit; - } else { - // Checksum digit - chain_start = 0; - chain_len = SLHDSA_W - 1; - } - - adrs[20] = uchar(i >> 8); - adrs[21] = uchar(i); - - wots_chain(pk_seed, adrs, wots_sig[i], chain_start, chain_len, - wots_pk_parts[i]); - } - - // Hash WOTS+ PK parts to get node - uchar wots_pk_input[SLHDSA_N + SLHDSA_LEN * SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) wots_pk_input[i] = pk_seed[i]; - for (uint i = 0; i < SLHDSA_LEN; i++) { - for (uint j = 0; j < SLHDSA_N; j++) { - wots_pk_input[SLHDSA_N + i * SLHDSA_N + j] = wots_pk_parts[i][j]; - } - } - shake256(wots_pk_input, SLHDSA_N + SLHDSA_LEN * SLHDSA_N, current_node, SLHDSA_N); - - // Climb Merkle tree auth path for this layer - uint auth_base = ht_offset + layer * (SLHDSA_LEN * SLHDSA_N + SLHDSA_HP * SLHDSA_N) - + SLHDSA_LEN * SLHDSA_N; - - for (uint h = 0; h < SLHDSA_HP; h++) { - uchar sibling[SLHDSA_N]; - for (uint i = 0; i < SLHDSA_N; i++) { - sibling[i] = sig[auth_base + h * SLHDSA_N + i]; - } - - uchar pair_input[64]; - for (uint i = 0; i < SLHDSA_N; i++) pair_input[i] = pk_seed[i]; - for (uint i = SLHDSA_N; i < 32; i++) pair_input[i] = 0; - for (uint i = 0; i < SLHDSA_N; i++) pair_input[32 + i] = current_node[i]; - for (uint i = 0; i < SLHDSA_N; i++) pair_input[32 + SLHDSA_N + i] = sibling[i]; - - shake256(pair_input, 64, current_node, SLHDSA_N); - } - } - - // -- Compare reconstructed root with PK.root -- - bool valid = true; - for (uint i = 0; i < SLHDSA_N; i++) { - if (current_node[i] != pk_root[i]) { - valid = false; - break; - } - } - - results[tid] = valid ? 
1u : 0u; -} diff --git a/slhdsa/gpu/metal/slhdsa_driver.h b/slhdsa/gpu/metal/slhdsa_driver.h deleted file mode 100644 index 8a14c74..0000000 --- a/slhdsa/gpu/metal/slhdsa_driver.h +++ /dev/null @@ -1,249 +0,0 @@ -// ============================================================================= -// Metal SLH-DSA - GPU Acceleration Interface for Hash-Based Signatures -// ============================================================================= -// -// C++ interface for dispatching SLH-DSA (FIPS 205, Stateless Hash-Based -// Digital Signature Algorithm, formerly SPHINCS+) operations to Metal compute. -// -// SLH-DSA is hash-based and doesn't use NTT like ML-DSA/ML-KEM. -// GPU acceleration focuses on: -// - Parallel hash tree computations (WOTS+, XMSS, FORS) -// - Batch signature verification -// - Parallel SHAKE/SHA2 operations -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#pragma once -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ============================================================================= -// SLH-DSA Modes -// ============================================================================= - -/** - * SLH-DSA parameter sets. 
- * Format: {hash}_{security}_{variant} - * - hash: SHA2 or SHAKE - * - security: 128, 192, or 256 bits - * - variant: s (small signature) or f (fast signing) - */ -typedef enum { - // 128-bit security (NIST Level 1) - SLHDSA_SHA2_128s = 0, - SLHDSA_SHAKE_128s = 1, - SLHDSA_SHA2_128f = 2, - SLHDSA_SHAKE_128f = 3, - - // 192-bit security (NIST Level 3) - SLHDSA_SHA2_192s = 4, - SLHDSA_SHAKE_192s = 5, - SLHDSA_SHA2_192f = 6, - SLHDSA_SHAKE_192f = 7, - - // 256-bit security (NIST Level 5) - SLHDSA_SHA2_256s = 8, - SLHDSA_SHAKE_256s = 9, - SLHDSA_SHA2_256f = 10, - SLHDSA_SHAKE_256f = 11, -} SLHDSAMode; - -// ============================================================================= -// Size Constants -// ============================================================================= - -// 128-bit small (Level 1) -#define SLHDSA_128S_PUBLIC_KEY_SIZE 32 -#define SLHDSA_128S_SECRET_KEY_SIZE 64 -#define SLHDSA_128S_SIGNATURE_SIZE 7856 - -// 128-bit fast (Level 1) -#define SLHDSA_128F_PUBLIC_KEY_SIZE 32 -#define SLHDSA_128F_SECRET_KEY_SIZE 64 -#define SLHDSA_128F_SIGNATURE_SIZE 17088 - -// 192-bit small (Level 3) -#define SLHDSA_192S_PUBLIC_KEY_SIZE 48 -#define SLHDSA_192S_SECRET_KEY_SIZE 96 -#define SLHDSA_192S_SIGNATURE_SIZE 16224 - -// 192-bit fast (Level 3) -#define SLHDSA_192F_PUBLIC_KEY_SIZE 48 -#define SLHDSA_192F_SECRET_KEY_SIZE 96 -#define SLHDSA_192F_SIGNATURE_SIZE 35664 - -// 256-bit small (Level 5) -#define SLHDSA_256S_PUBLIC_KEY_SIZE 64 -#define SLHDSA_256S_SECRET_KEY_SIZE 128 -#define SLHDSA_256S_SIGNATURE_SIZE 29792 - -// 256-bit fast (Level 5) -#define SLHDSA_256F_PUBLIC_KEY_SIZE 64 -#define SLHDSA_256F_SECRET_KEY_SIZE 128 -#define SLHDSA_256F_SIGNATURE_SIZE 49856 - -// ============================================================================= -// Context Management -// ============================================================================= - -/** - * Opaque handle to Metal SLH-DSA compute context. 
- */ -typedef struct MetalSLHDSAContext MetalSLHDSAContext; - -/** - * Initialize Metal SLH-DSA context. - * Loads hash shaders and creates compute pipelines. - * @return Context handle, or NULL if Metal unavailable - */ -MetalSLHDSAContext* metal_slhdsa_init(void); - -/** - * Destroy Metal SLH-DSA context and release resources. - */ -void metal_slhdsa_destroy(MetalSLHDSAContext* ctx); - -/** - * Check if Metal acceleration is available for SLH-DSA. - * @return true if Metal GPU is available - */ -bool metal_slhdsa_available(void); - -// ============================================================================= -// Key Generation -// ============================================================================= - -/** - * Generate SLH-DSA key pair on GPU. - * Uses GPU-accelerated hash tree computation. - * @param ctx Metal context - * @param mode Parameter set - * @param public_key Output public key buffer - * @param secret_key Output secret key buffer - * @param seed Random seed (n bytes where n = 16/24/32 for 128/192/256-bit) - * @return 0 on success, negative on error - */ -int metal_slhdsa_keygen( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - uint8_t* public_key, - uint8_t* secret_key, - const uint8_t* seed); - -// ============================================================================= -// Signing Operations -// ============================================================================= - -/** - * Sign a message using SLH-DSA on GPU. 
- * @param ctx Metal context - * @param mode Parameter set - * @param signature Output signature buffer - * @param secret_key Secret key bytes - * @param message Message to sign - * @param message_len Message length - * @param context Optional context string (NULL for empty) - * @param context_len Context string length - * @return 0 on success, negative on error - */ -int metal_slhdsa_sign( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - uint8_t* signature, - const uint8_t* secret_key, - const uint8_t* message, - size_t message_len, - const uint8_t* context, - size_t context_len); - -/** - * Batch sign multiple messages on GPU. - * All messages use the same secret key. - * @param ctx Metal context - * @param mode Parameter set - * @param signatures Output signature buffers (count elements) - * @param secret_key Secret key bytes - * @param messages Array of message pointers - * @param message_lens Array of message lengths - * @param count Number of messages to sign - * @return 0 on success, negative on error - */ -int metal_slhdsa_batch_sign( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - uint8_t** signatures, - const uint8_t* secret_key, - const uint8_t* const* messages, - const size_t* message_lens, - uint32_t count); - -// ============================================================================= -// Verification Operations -// ============================================================================= - -/** - * Verify an SLH-DSA signature on GPU. 
- * @param ctx Metal context - * @param mode Parameter set - * @param public_key Public key bytes - * @param signature Signature bytes - * @param message Message that was signed - * @param message_len Message length - * @param context Optional context string - * @param context_len Context string length - * @return 1 if valid, 0 if invalid, negative on error - */ -int metal_slhdsa_verify( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - const uint8_t* public_key, - const uint8_t* signature, - const uint8_t* message, - size_t message_len, - const uint8_t* context, - size_t context_len); - -/** - * Batch verify multiple SLH-DSA signatures on GPU. - * Significantly faster than individual verification for large batches. - * @param ctx Metal context - * @param mode Parameter set - * @param public_keys Array of public keys - * @param signatures Array of signatures - * @param messages Array of message pointers - * @param message_lens Array of message lengths - * @param count Number of signatures to verify - * @param results Output: 1 if valid, 0 if invalid (count elements) - * @return Number of valid signatures, negative on error - */ -int metal_slhdsa_batch_verify( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - const uint8_t* const* public_keys, - const uint8_t* const* signatures, - const uint8_t* const* messages, - const size_t* message_lens, - uint32_t count, - int* results); - -// ============================================================================= -// Error Codes -// ============================================================================= - -#define METAL_SLHDSA_SUCCESS 0 -#define METAL_SLHDSA_ERROR_NO_DEVICE -1 -#define METAL_SLHDSA_ERROR_NO_SHADER -2 -#define METAL_SLHDSA_ERROR_ALLOC -3 -#define METAL_SLHDSA_ERROR_NULL_PTR -4 -#define METAL_SLHDSA_ERROR_INVALID -5 -#define METAL_SLHDSA_ERROR_VERIFY -6 - -#ifdef __cplusplus -} -#endif diff --git a/slhdsa/gpu/metal/slhdsa_driver.mm b/slhdsa/gpu/metal/slhdsa_driver.mm deleted file mode 100644 index 
0b1e247..0000000 --- a/slhdsa/gpu/metal/slhdsa_driver.mm +++ /dev/null @@ -1,204 +0,0 @@ -// ============================================================================= -// Metal SLH-DSA - GPU Acceleration for Hash-Based Signatures -// ============================================================================= -// -// Implementation of SLH-DSA (FIPS 205) with Metal GPU acceleration. -// Falls back to CPU implementation when GPU is not available. -// -// Copyright (C) 2024-2025 Lux Industries Inc. -// SPDX-License-Identifier: Apache-2.0 - -#include "lux/crypto/metal_slhdsa.h" -#include -#include - -#ifdef __APPLE__ -#import -#import -#endif - -// External CPU implementations from crypto.cpp (mode-aware versions) -extern "C" { - int slhdsa_keygen(int mode, uint8_t* pk, uint8_t* sk, const uint8_t* seed); - int slhdsa_sign(int mode, uint8_t* sig, const uint8_t* sk, - const uint8_t* msg, size_t msg_len); - int slhdsa_verify(int mode, const uint8_t* pk, const uint8_t* sig, - const uint8_t* msg, size_t msg_len); - int slhdsa_batch_sign(int mode, uint8_t** sigs, const uint8_t* sk, - const uint8_t* const* msgs, const size_t* msg_lens, - uint32_t count); - int slhdsa_batch_verify(int mode, const uint8_t* const* pks, - const uint8_t* const* sigs, - const uint8_t* const* msgs, - const size_t* msg_lens, - uint32_t count, int* results); -} - -// ============================================================================= -// Context Structure -// ============================================================================= - -struct MetalSLHDSAContext { -#ifdef __APPLE__ - id device; - id queue; - id library; - bool gpu_available; -#else - bool gpu_available; -#endif -}; - -// ============================================================================= -// Context Management -// ============================================================================= - -MetalSLHDSAContext* metal_slhdsa_init(void) { - MetalSLHDSAContext* ctx = new MetalSLHDSAContext(); - 
ctx->gpu_available = false; - -#ifdef __APPLE__ - @autoreleasepool { - ctx->device = MTLCreateSystemDefaultDevice(); - if (ctx->device) { - ctx->queue = [ctx->device newCommandQueue]; - ctx->gpu_available = (ctx->queue != nil); - // TODO: Load hash compute shaders when implemented - } - } -#endif - - return ctx; -} - -void metal_slhdsa_destroy(MetalSLHDSAContext* ctx) { - if (ctx) { -#ifdef __APPLE__ - @autoreleasepool { - ctx->queue = nil; - ctx->library = nil; - ctx->device = nil; - } -#endif - delete ctx; - } -} - -bool metal_slhdsa_available(void) { -#ifdef __APPLE__ - @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); - return device != nil; - } -#else - return false; -#endif -} - -// ============================================================================= -// Key Generation -// ============================================================================= - -int metal_slhdsa_keygen( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - uint8_t* public_key, - uint8_t* secret_key, - const uint8_t* seed) -{ - if (!ctx || !public_key || !secret_key || !seed) { - return METAL_SLHDSA_ERROR_NULL_PTR; - } - - // Use CPU implementation (GPU hash tree acceleration TODO) - return slhdsa_keygen((int)mode, public_key, secret_key, seed); -} - -// ============================================================================= -// Signing Operations -// ============================================================================= - -int metal_slhdsa_sign( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - uint8_t* signature, - const uint8_t* secret_key, - const uint8_t* message, - size_t message_len, - const uint8_t* context, - size_t context_len) -{ - if (!ctx || !signature || !secret_key || !message) { - return METAL_SLHDSA_ERROR_NULL_PTR; - } - - // TODO: Add context support - (void)context; - (void)context_len; - - return slhdsa_sign((int)mode, signature, secret_key, message, message_len); -} - -int metal_slhdsa_batch_sign( - MetalSLHDSAContext* ctx, - SLHDSAMode 
mode, - uint8_t** signatures, - const uint8_t* secret_key, - const uint8_t* const* messages, - const size_t* message_lens, - uint32_t count) -{ - if (!ctx || !signatures || !secret_key || !messages || !message_lens) { - return METAL_SLHDSA_ERROR_NULL_PTR; - } - - if (count == 0) { - return METAL_SLHDSA_SUCCESS; - } - - return slhdsa_batch_sign((int)mode, signatures, secret_key, - messages, message_lens, count); -} - -// ============================================================================= -// Verification Operations -// ============================================================================= - -int metal_slhdsa_verify( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - const uint8_t* public_key, - const uint8_t* signature, - const uint8_t* message, - size_t message_len, - const uint8_t* context, - size_t context_len) -{ - if (!ctx || !public_key || !signature || !message) { - return METAL_SLHDSA_ERROR_NULL_PTR; - } - - // TODO: Add context support - (void)context; - (void)context_len; - - return slhdsa_verify((int)mode, public_key, signature, message, message_len); -} - -int metal_slhdsa_batch_verify( - MetalSLHDSAContext* ctx, - SLHDSAMode mode, - const uint8_t* const* public_keys, - const uint8_t* const* signatures, - const uint8_t* const* messages, - const size_t* message_lens, - uint32_t count, - int* results) -{ - if (!ctx || !public_keys || !signatures || !messages || !message_lens || !results) { - return METAL_SLHDSA_ERROR_NULL_PTR; - } - - return slhdsa_batch_verify((int)mode, public_keys, signatures, messages, - message_lens, count, results); -} diff --git a/slhdsa/gpu/wgsl/slhdsa.wgsl b/slhdsa/gpu/wgsl/slhdsa.wgsl deleted file mode 100644 index 523c8c1..0000000 --- a/slhdsa/gpu/wgsl/slhdsa.wgsl +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// SLH-DSA (FIPS 205, SPHINCS+) batch verification in WGSL. -// Hash-based signature scheme using SHAKE256 (Keccak-based). 
-// Each thread verifies one signature by recomputing WOTS+ chains -// and Merkle tree paths. - -@group(0) @binding(0) var pubkeys: array; -@group(0) @binding(1) var messages: array; -@group(0) @binding(2) var signatures: array; -@group(0) @binding(3) var results: array; -@group(0) @binding(4) var params: vec4; // params.x = num_sigs - -const SLH_N: u32 = 16u; -const SLH_K: u32 = 33u; -const SLH_A: u32 = 6u; -const SLH_D: u32 = 22u; -const SLH_HP: u32 = 3u; -const SLH_W: u32 = 16u; -const SLH_LEN: u32 = 7u; - -// Keccak-f[1600] round constants (lo, hi pairs for u64 emulation) -const RC_LO = array( - 0x00000001u, 0x00008082u, 0x0000808Au, 0x80008000u, - 0x0000808Bu, 0x80000001u, 0x80008081u, 0x00008009u, - 0x0000008Au, 0x00000088u, 0x80008009u, 0x8000000Au, - 0x8000808Bu, 0x0000008Bu, 0x00008089u, 0x00008003u, - 0x00008002u, 0x00000080u, 0x0000800Au, 0x8000000Au, - 0x80008081u, 0x00008080u, 0x80000001u, 0x80008008u -); - -const RC_HI = array( - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x80000000u, 0x80000000u, - 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u, - 0x00000000u, 0x80000000u, 0x80000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u, - 0x80000000u, 0x80000000u, 0x00000000u, 0x80000000u -); - -const PI_LANE = array( - 10u, 7u, 11u, 17u, 18u, 3u, 5u, 16u, 8u, 21u, 24u, 4u, - 15u, 23u, 19u, 13u, 12u, 2u, 20u, 14u, 22u, 9u, 6u, 1u -); - -const RHO_OFFSETS = array( - 1u, 3u, 6u, 10u, 15u, 21u, 28u, 36u, 45u, 55u, 2u, 14u, - 27u, 41u, 56u, 8u, 25u, 43u, 62u, 18u, 39u, 61u, 20u, 44u -); - -var st_lo: array; -var st_hi: array; - -fn rotl64(lo: u32, hi: u32, n: u32) -> vec2 { - if (n == 0u) { return vec2(lo, hi); } - if (n == 32u) { return vec2(hi, lo); } - if (n < 32u) { - return vec2((lo << n) | (hi >> (32u - n)), (hi << n) | (lo >> (32u - n))); - } - let m = n - 32u; - return vec2((hi << m) | (lo >> (32u - m)), (lo << m) | (hi >> (32u - m))); -} - -fn keccak_f() { - for (var round = 0u; round < 
24u; round = round + 1u) { - var c_lo: array; - var c_hi: array; - for (var x = 0u; x < 5u; x = x + 1u) { - c_lo[x] = st_lo[x] ^ st_lo[x+5u] ^ st_lo[x+10u] ^ st_lo[x+15u] ^ st_lo[x+20u]; - c_hi[x] = st_hi[x] ^ st_hi[x+5u] ^ st_hi[x+10u] ^ st_hi[x+15u] ^ st_hi[x+20u]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - let r = rotl64(c_lo[(x+1u) % 5u], c_hi[(x+1u) % 5u], 1u); - let d_lo = c_lo[(x+4u) % 5u] ^ r.x; - let d_hi = c_hi[(x+4u) % 5u] ^ r.y; - for (var y = 0u; y < 5u; y = y + 1u) { - let idx = x + 5u * y; - st_lo[idx] = st_lo[idx] ^ d_lo; - st_hi[idx] = st_hi[idx] ^ d_hi; - } - } - var t_lo = st_lo[1u]; - var t_hi = st_hi[1u]; - for (var i = 0u; i < 24u; i = i + 1u) { - let dst = PI_LANE[i]; - let tmp_lo = st_lo[dst]; - let tmp_hi = st_hi[dst]; - let r = rotl64(t_lo, t_hi, RHO_OFFSETS[i]); - st_lo[dst] = r.x; - st_hi[dst] = r.y; - t_lo = tmp_lo; - t_hi = tmp_hi; - } - for (var y = 0u; y < 5u; y = y + 1u) { - var row_lo: array; - var row_hi: array; - for (var x = 0u; x < 5u; x = x + 1u) { - row_lo[x] = st_lo[x + 5u * y]; - row_hi[x] = st_hi[x + 5u * y]; - } - for (var x = 0u; x < 5u; x = x + 1u) { - st_lo[x + 5u * y] = row_lo[x] ^ ((~row_lo[(x+1u) % 5u]) & row_lo[(x+2u) % 5u]); - st_hi[x + 5u * y] = row_hi[x] ^ ((~row_hi[(x+1u) % 5u]) & row_hi[(x+2u) % 5u]); - } - } - st_lo[0u] = st_lo[0u] ^ RC_LO[round]; - st_hi[0u] = st_hi[0u] ^ RC_HI[round]; - } -} - -fn read_sig_byte(sig_base: u32, idx: u32) -> u32 { - let word_idx = (sig_base + idx) >> 2u; - let byte_pos = (sig_base + idx) & 3u; - return (signatures[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -fn read_pk_byte(pk_base: u32, idx: u32) -> u32 { - let word_idx = (pk_base + idx) >> 2u; - let byte_pos = (pk_base + idx) & 3u; - return (pubkeys[word_idx] >> (byte_pos * 8u)) & 0xFFu; -} - -// Simple SHAKE256 hash of 64 bytes -> 16 bytes via Keccak -fn shake256_64_to_16(input: ptr>, output: ptr>) { - for (var i = 0u; i < 25u; i = i + 1u) { st_lo[i] = 0u; st_hi[i] = 0u; } - - // Absorb 64 bytes = 8 u64 words into rate 
(136 bytes = 17 u64 words) - for (var w = 0u; w < 8u; w = w + 1u) { - st_lo[w] = st_lo[w] ^ (*input)[w * 2u]; - st_hi[w] = st_hi[w] ^ (*input)[w * 2u + 1u]; - } - - // SHAKE256 padding at byte 64: 0x1F - st_lo[8u] = st_lo[8u] ^ 0x1Fu; - // Last byte of rate (byte 135): 0x80 - st_hi[16u] = st_hi[16u] ^ 0x80000000u; - - keccak_f(); - - // Squeeze 16 bytes = 2 u64 words - (*output)[0u] = st_lo[0u]; - (*output)[1u] = st_hi[0u]; - (*output)[2u] = st_lo[1u]; - (*output)[3u] = st_hi[1u]; -} - -@compute @workgroup_size(64) -fn slhdsa_verify_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - let pk_base = tid * 8u; // 32 bytes = 8 u32 words - let sig_base_words = tid * 4272u; // ~17088 bytes = 4272 u32 words - - // Read PK.seed and PK.root (each 16 bytes = 4 u32 words) - var pk_seed: array; - var pk_root: array; - for (var i = 0u; i < 4u; i = i + 1u) { - pk_seed[i] = pubkeys[pk_base + i]; - pk_root[i] = pubkeys[pk_base + 4u + i]; - } - - // Read randomizer R (16 bytes) from signature - var R: array; - for (var i = 0u; i < 4u; i = i + 1u) { - R[i] = signatures[sig_base_words + i]; - } - - // Compute message digest via SHAKE256(R || pk_seed || pk_root || msg) - // = 80 bytes input -> 32 bytes output - var hash_input: array; - for (var i = 0u; i < 4u; i = i + 1u) { hash_input[i] = R[i]; } - for (var i = 0u; i < 4u; i = i + 1u) { hash_input[4u + i] = pk_seed[i]; } - for (var i = 0u; i < 4u; i = i + 1u) { hash_input[8u + i] = pk_root[i]; } - let msg_base = tid * 8u; - for (var i = 0u; i < 4u; i = i + 1u) { hash_input[12u + i] = messages[msg_base + i]; } - - var digest: array; - shake256_64_to_16(&hash_input, &digest); - - // Verify FORS tree structure - // For each FORS tree, hash leaf and climb auth path - var fors_offset = 4u; // After R (in u32 words) - var current_node: array; - var valid = true; - - for (var tree = 0u; tree < SLH_K; tree = tree + 1u) { - // Read leaf - var leaf: array; - for (var i = 0u; i < 4u; i = i 
+ 1u) { - leaf[i] = signatures[sig_base_words + fors_offset + tree * 28u + i]; - } - - // Hash leaf - var leaf_input: array; - for (var i = 0u; i < 4u; i = i + 1u) { leaf_input[i] = pk_seed[i]; } - for (var i = 4u; i < 12u; i = i + 1u) { leaf_input[i] = 0u; } - for (var i = 0u; i < 4u; i = i + 1u) { leaf_input[12u + i] = leaf[i]; } - - var node: array; - shake256_64_to_16(&leaf_input, &node); - - // Climb auth path - for (var layer = 0u; layer < SLH_A; layer = layer + 1u) { - var sibling: array; - let sib_off = sig_base_words + fors_offset + tree * 28u + 4u + layer * 4u; - for (var i = 0u; i < 4u; i = i + 1u) { - sibling[i] = signatures[sib_off + i]; - } - - var pair_input: array; - for (var i = 0u; i < 4u; i = i + 1u) { pair_input[i] = pk_seed[i]; } - for (var i = 4u; i < 8u; i = i + 1u) { pair_input[i] = 0u; } - for (var i = 0u; i < 4u; i = i + 1u) { pair_input[8u + i] = node[i]; } - for (var i = 0u; i < 4u; i = i + 1u) { pair_input[12u + i] = sibling[i]; } - - shake256_64_to_16(&pair_input, &node); - } - current_node = node; - } - - // Compare final root with PK.root - for (var i = 0u; i < 4u; i = i + 1u) { - if (current_node[i] != pk_root[i]) { - valid = false; - } - } - - results[tid] = select(0u, 1u, valid); -} diff --git a/sr25519/gpu/cuda/sr25519.cu b/sr25519/gpu/cuda/sr25519.cu deleted file mode 100644 index 3269b4f..0000000 --- a/sr25519/gpu/cuda/sr25519.cu +++ /dev/null @@ -1,393 +0,0 @@ -// sr25519/Ristretto255 batch verification — CUDA implementation -// Matches sr25519.metal output byte-for-byte -// One thread per Schnorr signature verification - -#include - -#ifndef __CUDA_ARCH__ -#define __device__ -#define __global__ -#define __shared__ -struct dim3 { unsigned x, y, z; }; -static dim3 blockIdx, blockDim, threadIdx; -#endif - -// ============================================================================= -// 256-bit integer (4 x 64-bit limbs, little-endian) -// ============================================================================= - -struct 
uint256 { - uint64_t limbs[4]; -}; - -// ============================================================================= -// Constants (same field as Ed25519: p = 2^255 - 19) -// ============================================================================= - -__device__ static const uint256 SR_P = {{ - 0xFFFFFFFFFFFFFFEDULL, 0xFFFFFFFFFFFFFFFFULL, - 0xFFFFFFFFFFFFFFFFULL, 0x7FFFFFFFFFFFFFFFULL -}}; - -__device__ static const uint256 SR_L = {{ - 0x5812631A5CF5D3EDULL, 0x14DEF9DEA2F79CD6ULL, - 0x0000000000000000ULL, 0x1000000000000000ULL -}}; - -__device__ static const uint256 SR_ZERO = {{0, 0, 0, 0}}; -__device__ static const uint256 SR_ONE = {{1, 0, 0, 0}}; - -__device__ static const uint256 SR_D = {{ - 0x75EB4DCA135978A3ULL, 0x00700A4D4141D8ABULL, - 0x8CC740797779E898ULL, 0x52036CBC148B6DE8ULL -}}; - -__device__ static const uint256 SR_2D = {{ - 0xEBD69B9426B2F159ULL, 0x00E0149A8283B156ULL, - 0x198E80F2EEF3D130ULL, 0x2406D9DC56DFFCE7ULL -}}; - -__device__ static const uint256 SR_SQRT_M1 = {{ - 0xC4EE1B274A0EA0B0ULL, 0x2F431806AD2FE478ULL, - 0x2B4D00993DFBD7A7ULL, 0x2B8324804FC1DF0BULL -}}; - -// ============================================================================= -// Field arithmetic (standalone, same as ed25519.metal) -// ============================================================================= - -__device__ static int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -__device__ static bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -__device__ static uint256 u256_add(uint256 a, uint256 b, uint64_t& carry) { - uint256 r; - uint64_t c = 0; - for (int i = 0; i < 4; i++) { - uint64_t sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1ULL : 0ULL; - uint64_t sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 
1ULL : 0ULL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -__device__ static uint256 u256_sub(uint256 a, uint256 b, uint64_t& borrow) { - uint256 r; - uint64_t bw = 0; - for (int i = 0; i < 4; i++) { - uint64_t diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1ULL : 0ULL; - uint64_t diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1ULL : 0ULL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -__device__ static uint256 fp_add(uint256 a, uint256 b) { - uint64_t c; - uint256 r = u256_add(a, b, c); - if (c || u256_cmp(r, SR_P) >= 0) { uint64_t bw; r = u256_sub(r, SR_P, bw); } - return r; -} - -__device__ static uint256 fp_sub(uint256 a, uint256 b) { - uint64_t bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { uint64_t c; r = u256_add(r, SR_P, c); } - return r; -} - -__device__ static uint256 fp_mul(uint256 a, uint256 b) { - uint64_t t[8] = {}; - for (int i = 0; i < 4; i++) { - uint64_t carry = 0; - for (int j = 0; j < 4; j++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned __int128)a.limbs[i] * b.limbs[j]; - unsigned __int128 acc = prod + carry + t[i + j]; - t[i + j] = (uint64_t)acc; - carry = (uint64_t)(acc >> 64); -#else - uint64_t a_lo = a.limbs[i] & 0xFFFFFFFFULL, a_hi = a.limbs[i] >> 32; - uint64_t b_lo = b.limbs[j] & 0xFFFFFFFFULL, b_hi = b.limbs[j] >> 32; - uint64_t ll = a_lo * b_lo, lh = a_lo * b_hi; - uint64_t hl = a_hi * b_lo, hh = a_hi * b_hi; - uint64_t mid = lh + (ll >> 32); - uint64_t mid2 = mid + hl; - if (mid2 < mid) hh += (1ULL << 32); - uint64_t lo = (mid2 << 32) | (ll & 0xFFFFFFFFULL); - uint64_t hi = hh + (mid2 >> 32); - uint64_t sum = lo + carry; if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; -#endif - } - t[i + 4] = carry; - } - uint256 lo_part = {{t[0], t[1], t[2], t[3]}}; - uint256 hi_part = {{t[4], t[5], t[6], t[7]}}; - uint64_t c2 = 0; - uint256 hi38; - for (int i = 0; i < 4; i++) { -#ifdef __CUDA_ARCH__ - unsigned __int128 prod = (unsigned 
__int128)hi_part.limbs[i] * 38ULL + c2; - hi38.limbs[i] = (uint64_t)prod; - c2 = (uint64_t)(prod >> 64); -#else - uint64_t a_lo = hi_part.limbs[i] & 0xFFFFFFFFULL; - uint64_t a_hi = hi_part.limbs[i] >> 32; - uint64_t ll = a_lo * 38ULL; - uint64_t hl = a_hi * 38ULL; - uint64_t lo = ll + (hl << 32); - uint64_t hi = (hl >> 32) + ((lo < ll) ? 1ULL : 0ULL); - uint64_t sum = lo + c2; - if (sum < lo) hi++; - c2 = hi; - hi38.limbs[i] = sum; -#endif - } - uint64_t c; - uint256 result = u256_add(lo_part, hi38, c); - if (c || c2) { - uint64_t extra = (c + c2) * 38; - uint256 extra256 = {{extra, 0, 0, 0}}; - result = u256_add(result, extra256, c); - } - while (u256_cmp(result, SR_P) >= 0) { uint64_t bw; result = u256_sub(result, SR_P, bw); } - return result; -} - -__device__ static uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } - -__device__ static uint256 fp_neg(uint256 a) { - if (u256_is_zero(a)) return a; - uint64_t bw; return u256_sub(SR_P, a, bw); -} - -__device__ static uint256 fp_inv(uint256 a) { - uint256 exp = SR_P; exp.limbs[0] -= 2; - uint256 result = SR_ONE, base = a; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - return result; -} - -// ============================================================================= -// Extended Edwards point (same curve as Ed25519) -// ============================================================================= - -struct RistrettoPoint { - uint256 X, Y, Z, T; -}; - -__device__ static RistrettoPoint ristretto_identity() { - RistrettoPoint p; - p.X = SR_ZERO; p.Y = SR_ONE; p.Z = SR_ONE; p.T = SR_ZERO; - return p; -} - -__device__ static RistrettoPoint ristretto_double(RistrettoPoint P) { - uint256 A = fp_sqr(P.X); - uint256 B = fp_sqr(P.Y); - uint256 C = fp_add(fp_sqr(P.Z), fp_sqr(P.Z)); - uint256 D = fp_neg(A); - uint256 E = fp_sub(fp_sqr(fp_add(P.X, P.Y)), fp_add(A, B)); - uint256 G = fp_add(D, B); - uint256 F = 
fp_sub(G, C); - uint256 H = fp_sub(D, B); - RistrettoPoint R; - R.X = fp_mul(E, F); R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); R.Z = fp_mul(F, G); - return R; -} - -__device__ static RistrettoPoint ristretto_add(RistrettoPoint P, RistrettoPoint Q) { - uint256 A = fp_mul(P.X, Q.X); - uint256 B = fp_mul(P.Y, Q.Y); - uint256 C = fp_mul(P.T, fp_mul(SR_2D, Q.T)); - uint256 D = fp_add(fp_mul(P.Z, Q.Z), fp_mul(P.Z, Q.Z)); - uint256 E = fp_sub(fp_mul(fp_add(P.X, P.Y), fp_add(Q.X, Q.Y)), fp_add(A, B)); - uint256 F = fp_sub(D, C); - uint256 G = fp_add(D, C); - uint256 H = fp_add(B, A); - RistrettoPoint R; - R.X = fp_mul(E, F); R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); R.Z = fp_mul(F, G); - return R; -} - -__device__ static RistrettoPoint ristretto_mul(uint256 k, RistrettoPoint P) { - RistrettoPoint result = ristretto_identity(); - for (int i = 3; i >= 0; i--) - for (int bit = 63; bit >= 0; bit--) { - result = ristretto_double(result); - if ((k.limbs[i] >> bit) & 1) result = ristretto_add(result, P); - } - return result; -} - -// ============================================================================= -// Ristretto255 decoding -// ============================================================================= - -__device__ static bool ristretto_decode(const uint8_t* encoded, RistrettoPoint& P) { - uint256 s; - for (int i = 0; i < 4; i++) { - s.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) - s.limbs[i] |= (uint64_t)encoded[i * 8 + b] << (b * 8); - } - - if (s.limbs[3] >> 63) return false; - if (u256_cmp(s, SR_P) >= 0) return false; - - uint256 ss = fp_sqr(s); - uint256 u1 = fp_sub(SR_ONE, ss); - uint256 u2 = fp_add(SR_ONE, ss); - uint256 u2_sq = fp_sqr(u2); - uint256 v = fp_sub(fp_neg(fp_mul(SR_D, fp_sqr(u1))), u2_sq); - - uint256 vu2sq = fp_mul(v, u2_sq); - - // (v * u2^2)^((p-5)/8) - uint256 exp58 = SR_P; - exp58.limbs[0] -= 5; - for (int i = 0; i < 3; i++) - exp58.limbs[i] = (exp58.limbs[i] >> 3) | (exp58.limbs[i + 1] << 61); - exp58.limbs[3] >>= 3; - - 
uint256 inv_sqrt = SR_ONE; - uint256 base = vu2sq; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp58.limbs[i] >> bit) & 1) inv_sqrt = fp_mul(inv_sqrt, base); - base = fp_sqr(base); - } - - uint256 check = fp_mul(fp_sqr(inv_sqrt), vu2sq); - if (u256_cmp(check, SR_ONE) != 0) { - uint256 neg1 = fp_neg(SR_ONE); - if (u256_cmp(check, neg1) == 0) { - inv_sqrt = fp_mul(inv_sqrt, SR_SQRT_M1); - } else { - return false; - } - } - - uint256 x = fp_mul(fp_add(s, s), fp_mul(inv_sqrt, u2)); - if (x.limbs[0] & 1) x = fp_neg(x); - - uint256 y = fp_mul(u1, fp_mul(inv_sqrt, u2)); - - P.X = x; - P.Y = y; - P.Z = SR_ONE; - P.T = fp_mul(x, y); - return true; -} - -// ============================================================================= -// Structures -// ============================================================================= - -struct Sr25519PublicKey { - uint8_t data[32]; -}; - -struct Sr25519Signature { - uint8_t data[64]; -}; - -struct Sr25519Message { - uint8_t hash[64]; -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -extern "C" __global__ void sr25519_verify_batch( - const Sr25519PublicKey* __restrict__ pubkeys, - const Sr25519Message* __restrict__ messages, - const Sr25519Signature* __restrict__ signatures, - uint32_t* __restrict__ results, - const uint32_t num_sigs) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_sigs) return; - - // Decode public key (Ristretto255 point) - RistrettoPoint A; - if (!ristretto_decode(pubkeys[tid].data, A)) { - results[tid] = 0; - return; - } - - // Decode R from signature - RistrettoPoint R; - if (!ristretto_decode(signatures[tid].data, R)) { - results[tid] = 0; - return; - } - - // Read scalar s - uint256 s; - for (int i = 0; i < 4; i++) { - s.limbs[i] = 0; - for (int b = 0; b < 8; b++) - s.limbs[i] |= 
(uint64_t)signatures[tid].data[32 + i * 8 + b] << (b * 8); - } - - if (u256_cmp(s, SR_L) >= 0) { - results[tid] = 0; - return; - } - - // Read pre-computed challenge scalar (reduced mod L by host) - uint256 c; - for (int i = 0; i < 4; i++) { - c.limbs[i] = 0; - for (int b = 0; b < 8; b++) - c.limbs[i] |= (uint64_t)messages[tid].hash[i * 8 + b] << (b * 8); - } - - // Generator point B (same as Ed25519) - const uint256 BX = {{0xC9562D608F25D51AULL, 0x692CC7609525A7B2ULL, - 0xC0A4E231FDD6DC5CULL, 0x216936D3CD6E53FEULL}}; - const uint256 BY = {{0x6666666666666658ULL, 0x6666666666666666ULL, - 0x6666666666666666ULL, 0x6666666666666666ULL}}; - RistrettoPoint B; - B.X = BX; B.Y = BY; B.Z = SR_ONE; B.T = fp_mul(BX, BY); - - // Verify: s*B == R + c*A - RistrettoPoint sB = ristretto_mul(s, B); - RistrettoPoint cA = ristretto_mul(c, A); - RistrettoPoint RcA = ristretto_add(R, cA); - - // Compare in affine - uint256 sb_x = fp_mul(sB.X, fp_inv(sB.Z)); - uint256 sb_y = fp_mul(sB.Y, fp_inv(sB.Z)); - uint256 rca_x = fp_mul(RcA.X, fp_inv(RcA.Z)); - uint256 rca_y = fp_mul(RcA.Y, fp_inv(RcA.Z)); - - bool valid = (u256_cmp(sb_x, rca_x) == 0) && (u256_cmp(sb_y, rca_y) == 0); - results[tid] = valid ? 1u : 0u; -} diff --git a/sr25519/gpu/metal/sr25519.metal b/sr25519/gpu/metal/sr25519.metal deleted file mode 100644 index b418b7f..0000000 --- a/sr25519/gpu/metal/sr25519.metal +++ /dev/null @@ -1,404 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -/// @file sr25519.metal -/// Metal compute shader for batch sr25519 (Schnorrkel/Ristretto255) verification. -/// -/// sr25519 uses Schnorr signatures on the Ristretto255 group, which provides -/// a prime-order group via cofactor elimination on Curve25519. -/// -/// Ristretto255 uses the same field as Ed25519 (p = 2^255 - 19) with -/// cofactor-free group operations. -/// -/// VRF (Verifiable Random Function) support included. -/// -/// Each thread verifies one Schnorr signature. 
- -#include -using namespace metal; - -// ============================================================================= -// Reuse Ed25519 field arithmetic (same field: p = 2^255 - 19) -// ============================================================================= - -struct uint256 { - ulong limbs[4]; -}; - -constant uint256 SR_P = {{ - 0xFFFFFFFFFFFFFFEDUL, 0xFFFFFFFFFFFFFFFFUL, - 0xFFFFFFFFFFFFFFFFUL, 0x7FFFFFFFFFFFFFFFUL -}}; - -constant uint256 SR_L = {{ - 0x5812631A5CF5D3EDUL, 0x14DEF9DEA2F79CD6UL, - 0x0000000000000000UL, 0x1000000000000000UL -}}; - -constant uint256 SR_ZERO = {{0, 0, 0, 0}}; -constant uint256 SR_ONE = {{1, 0, 0, 0}}; - -// d = -121665/121666 mod p (same as Ed25519) -constant uint256 SR_D = {{ - 0x75EB4DCA135978A3UL, 0x00700A4D4141D8ABUL, - 0x8CC740797779E898UL, 0x52036CBC148B6DE8UL -}}; - -constant uint256 SR_2D = {{ - 0xEBD69B9426B2F159UL, 0x00E0149A8283B156UL, - 0x198E80F2EEF3D130UL, 0x2406D9DC56DFFCE7UL -}}; - -// sqrt(-1) mod p -constant uint256 SR_SQRT_M1 = {{ - 0xC4EE1B274A0EA0B0UL, 0x2F431806AD2FE478UL, - 0x2B4D00993DFBD7A7UL, 0x2B8324804FC1DF0BUL -}}; - -// ============================================================================= -// Field arithmetic (same as ed25519.metal, reproduced for standalone compilation) -// ============================================================================= - -inline int u256_cmp(uint256 a, uint256 b) { - for (int i = 3; i >= 0; i--) { - if (a.limbs[i] < b.limbs[i]) return -1; - if (a.limbs[i] > b.limbs[i]) return 1; - } - return 0; -} - -inline bool u256_is_zero(uint256 a) { - return (a.limbs[0] | a.limbs[1] | a.limbs[2] | a.limbs[3]) == 0; -} - -inline uint256 u256_add(uint256 a, uint256 b, thread ulong& carry) { - uint256 r; - ulong c = 0; - for (int i = 0; i < 4; i++) { - ulong sum = a.limbs[i] + c; - c = (sum < a.limbs[i]) ? 1UL : 0UL; - ulong sum2 = sum + b.limbs[i]; - c += (sum2 < sum) ? 
1UL : 0UL; - r.limbs[i] = sum2; - } - carry = c; - return r; -} - -inline uint256 u256_sub(uint256 a, uint256 b, thread ulong& borrow) { - uint256 r; - ulong bw = 0; - for (int i = 0; i < 4; i++) { - ulong diff = a.limbs[i] - bw; - bw = (diff > a.limbs[i]) ? 1UL : 0UL; - ulong diff2 = diff - b.limbs[i]; - bw += (diff2 > diff) ? 1UL : 0UL; - r.limbs[i] = diff2; - } - borrow = bw; - return r; -} - -inline void mul64(ulong a, ulong b, thread ulong& lo, thread ulong& hi) { - ulong a_lo = a & 0xFFFFFFFFUL, a_hi = a >> 32; - ulong b_lo = b & 0xFFFFFFFFUL, b_hi = b >> 32; - ulong ll = a_lo * b_lo, lh = a_lo * b_hi; - ulong hl = a_hi * b_lo, hh = a_hi * b_hi; - ulong mid = lh + (ll >> 32); - ulong mid2 = mid + hl; - if (mid2 < mid) hh += (1UL << 32); - lo = (mid2 << 32) | (ll & 0xFFFFFFFFUL); - hi = hh + (mid2 >> 32); -} - -inline uint256 fp_add(uint256 a, uint256 b) { - ulong c; - uint256 r = u256_add(a, b, c); - if (c || u256_cmp(r, SR_P) >= 0) { ulong bw; r = u256_sub(r, SR_P, bw); } - return r; -} - -inline uint256 fp_sub(uint256 a, uint256 b) { - ulong bw; - uint256 r = u256_sub(a, b, bw); - if (bw) { ulong c; r = u256_add(r, SR_P, c); } - return r; -} - -inline uint256 fp_mul(uint256 a, uint256 b) { - ulong t[8] = {}; - for (int i = 0; i < 4; i++) { - ulong carry = 0; - for (int j = 0; j < 4; j++) { - ulong lo, hi; - mul64(a.limbs[i], b.limbs[j], lo, hi); - ulong sum = lo + carry; if (sum < lo) hi++; - lo = sum; - sum = t[i + j] + lo; if (sum < t[i + j]) hi++; - t[i + j] = sum; - carry = hi; - } - t[i + 4] = carry; - } - uint256 lo_part = {{t[0], t[1], t[2], t[3]}}; - uint256 hi_part = {{t[4], t[5], t[6], t[7]}}; - ulong c2 = 0; - uint256 hi38; - for (int i = 0; i < 4; i++) { - ulong lo, hi; - mul64(hi_part.limbs[i], 38UL, lo, hi); - ulong sum = lo + c2; - c2 = hi + ((sum < lo) ? 
1UL : 0UL); - hi38.limbs[i] = sum; - } - ulong c; - uint256 result = u256_add(lo_part, hi38, c); - if (c || c2) { - ulong extra = (c + c2) * 38; - uint256 extra256 = {{extra, 0, 0, 0}}; - result = u256_add(result, extra256, c); - } - while (u256_cmp(result, SR_P) >= 0) { ulong bw; result = u256_sub(result, SR_P, bw); } - return result; -} - -inline uint256 fp_sqr(uint256 a) { return fp_mul(a, a); } -inline uint256 fp_neg(uint256 a) { - if (u256_is_zero(a)) return a; - ulong bw; return u256_sub(SR_P, a, bw); -} - -inline uint256 fp_inv(uint256 a) { - uint256 exp = SR_P; exp.limbs[0] -= 2; - uint256 result = SR_ONE, base = a; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp.limbs[i] >> bit) & 1) result = fp_mul(result, base); - base = fp_sqr(base); - } - return result; -} - -// ============================================================================= -// Extended Edwards point (same curve as Ed25519) -// ============================================================================= - -struct RistrettoPoint { - uint256 X, Y, Z, T; -}; - -inline RistrettoPoint ristretto_identity() { - RistrettoPoint p; - p.X = SR_ZERO; p.Y = SR_ONE; p.Z = SR_ONE; p.T = SR_ZERO; - return p; -} - -inline RistrettoPoint ristretto_double(RistrettoPoint P) { - uint256 A = fp_sqr(P.X); - uint256 B = fp_sqr(P.Y); - uint256 C = fp_add(fp_sqr(P.Z), fp_sqr(P.Z)); - uint256 D = fp_neg(A); - uint256 E = fp_sub(fp_sqr(fp_add(P.X, P.Y)), fp_add(A, B)); - uint256 G = fp_add(D, B); - uint256 F = fp_sub(G, C); - uint256 H = fp_sub(D, B); - RistrettoPoint R; - R.X = fp_mul(E, F); R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); R.Z = fp_mul(F, G); - return R; -} - -inline RistrettoPoint ristretto_add(RistrettoPoint P, RistrettoPoint Q) { - uint256 A = fp_mul(P.X, Q.X); - uint256 B = fp_mul(P.Y, Q.Y); - uint256 C = fp_mul(P.T, fp_mul(SR_2D, Q.T)); - uint256 D = fp_add(fp_mul(P.Z, Q.Z), fp_mul(P.Z, Q.Z)); - uint256 E = fp_sub(fp_mul(fp_add(P.X, P.Y), fp_add(Q.X, Q.Y)), fp_add(A, 
B)); - uint256 F = fp_sub(D, C); - uint256 G = fp_add(D, C); - uint256 H = fp_add(B, A); - RistrettoPoint R; - R.X = fp_mul(E, F); R.Y = fp_mul(G, H); - R.T = fp_mul(E, H); R.Z = fp_mul(F, G); - return R; -} - -inline RistrettoPoint ristretto_mul(uint256 k, RistrettoPoint P) { - RistrettoPoint result = ristretto_identity(); - for (int i = 3; i >= 0; i--) - for (int bit = 63; bit >= 0; bit--) { - result = ristretto_double(result); - if ((k.limbs[i] >> bit) & 1) result = ristretto_add(result, P); - } - return result; -} - -// ============================================================================= -// Ristretto255 decoding (from 32-byte compressed form) -// ============================================================================= - -/// Decode a Ristretto255 point from 32 bytes. -/// Follows the Ristretto255 spec (draft-irtf-cfrg-ristretto255-00). -inline bool ristretto_decode(device const uchar* encoded, thread RistrettoPoint& P) { - // Read s (little-endian field element) - uint256 s; - for (int i = 0; i < 4; i++) { - s.limbs[i] = 0; - for (int b = 0; b < 8 && i * 8 + b < 32; b++) - s.limbs[i] |= (ulong)encoded[i * 8 + b] << (b * 8); - } - - // s must be non-negative (MSB must be 0) and < p - if (s.limbs[3] >> 63) return false; - if (u256_cmp(s, SR_P) >= 0) return false; - - // Check s is canonical (s must equal its encoding) - // s^2 - uint256 ss = fp_sqr(s); - // u1 = 1 - s^2 - uint256 u1 = fp_sub(SR_ONE, ss); - // u2 = 1 + s^2 - uint256 u2 = fp_add(SR_ONE, ss); - uint256 u2_sq = fp_sqr(u2); - // v = -(d) * u1^2 - u2^2 - uint256 v = fp_sub(fp_neg(fp_mul(SR_D, fp_sqr(u1))), u2_sq); - - // invsqrt(v * u2^2) using p = 5 mod 8 shortcut - uint256 vu2sq = fp_mul(v, u2_sq); - - // Compute candidate: (v * u2^2)^((p-5)/8) - uint256 exp58 = SR_P; - exp58.limbs[0] -= 5; - // Divide by 8: shift right 3 - for (int i = 0; i < 3; i++) - exp58.limbs[i] = (exp58.limbs[i] >> 3) | (exp58.limbs[i + 1] << 61); - exp58.limbs[3] >>= 3; - - uint256 inv_sqrt = SR_ONE; - uint256 
base = vu2sq; - for (int i = 0; i < 4; i++) - for (int bit = 0; bit < 64; bit++) { - if ((exp58.limbs[i] >> bit) & 1) inv_sqrt = fp_mul(inv_sqrt, base); - base = fp_sqr(base); - } - - // Check: (inv_sqrt^2 * v * u2^2) should be +/-1 - uint256 check = fp_mul(fp_sqr(inv_sqrt), vu2sq); - bool negated = false; - if (u256_cmp(check, SR_ONE) != 0) { - uint256 neg1 = fp_neg(SR_ONE); - if (u256_cmp(check, neg1) == 0) { - inv_sqrt = fp_mul(inv_sqrt, SR_SQRT_M1); - negated = true; - } else { - return false; - } - } - - // x = |2 * s * invsqrt| (take absolute value) - uint256 x = fp_mul(fp_add(s, s), fp_mul(inv_sqrt, u2)); - if (x.limbs[0] & 1) x = fp_neg(x); // Make non-negative (even) - - uint256 y = fp_mul(u1, fp_mul(inv_sqrt, u2)); - - P.X = x; - P.Y = y; - P.Z = SR_ONE; - P.T = fp_mul(x, y); - return true; -} - -// ============================================================================= -// Structures -// ============================================================================= - -struct Sr25519PublicKey { - uchar data[32]; // Ristretto255 compressed point -}; - -struct Sr25519Signature { - uchar data[64]; // R[32] || s[32] -}; - -struct Sr25519Message { - uchar hash[64]; // Pre-computed transcript hash -}; - -// ============================================================================= -// Verification kernel -// ============================================================================= - -/// Batch sr25519 Schnorr signature verification. -/// Each thread verifies one signature. -/// -/// Schnorr verify: check s*B == R + H(R||A||M)*A -/// Host pre-computes the transcript hash and reduces mod L. -/// -/// Output: results[tid] = 1 if valid, 0 otherwise. 
-kernel void sr25519_verify_batch( - device const Sr25519PublicKey* pubkeys [[buffer(0)]], - device const Sr25519Message* messages [[buffer(1)]], - device const Sr25519Signature* signatures [[buffer(2)]], - device uint* results [[buffer(3)]], - constant uint& num_sigs [[buffer(4)]], - uint tid [[thread_position_in_grid]]) -{ - if (tid >= num_sigs) return; - - // Decode public key (Ristretto255 point) - RistrettoPoint A; - if (!ristretto_decode(pubkeys[tid].data, A)) { - results[tid] = 0; - return; - } - - // Decode R from signature - RistrettoPoint R; - if (!ristretto_decode(signatures[tid].data, R)) { - results[tid] = 0; - return; - } - - // Read scalar s - uint256 s; - for (int i = 0; i < 4; i++) { - s.limbs[i] = 0; - for (int b = 0; b < 8; b++) - s.limbs[i] |= (ulong)signatures[tid].data[32 + i * 8 + b] << (b * 8); - } - - if (u256_cmp(s, SR_L) >= 0) { - results[tid] = 0; - return; - } - - // Read pre-computed challenge scalar (reduced mod L by host) - uint256 c; - for (int i = 0; i < 4; i++) { - c.limbs[i] = 0; - for (int b = 0; b < 8; b++) - c.limbs[i] |= (ulong)messages[tid].hash[i * 8 + b] << (b * 8); - } - - // Generator point B (same as Ed25519) - const uint256 BX = {{0xC9562D608F25D51AUL, 0x692CC7609525A7B2UL, - 0xC0A4E231FDD6DC5CUL, 0x216936D3CD6E53FEUL}}; - const uint256 BY = {{0x6666666666666658UL, 0x6666666666666666UL, - 0x6666666666666666UL, 0x6666666666666666UL}}; - RistrettoPoint B; - B.X = BX; B.Y = BY; B.Z = SR_ONE; B.T = fp_mul(BX, BY); - - // Verify: s*B == R + c*A - RistrettoPoint sB = ristretto_mul(s, B); - RistrettoPoint cA = ristretto_mul(c, A); - RistrettoPoint RcA = ristretto_add(R, cA); - - // Compare in affine - uint256 sb_x = fp_mul(sB.X, fp_inv(sB.Z)); - uint256 sb_y = fp_mul(sB.Y, fp_inv(sB.Z)); - uint256 rca_x = fp_mul(RcA.X, fp_inv(RcA.Z)); - uint256 rca_y = fp_mul(RcA.Y, fp_inv(RcA.Z)); - - bool valid = (u256_cmp(sb_x, rca_x) == 0) && (u256_cmp(sb_y, rca_y) == 0); - results[tid] = valid ? 
1u : 0u; -} diff --git a/sr25519/gpu/wgsl/sr25519.wgsl b/sr25519/gpu/wgsl/sr25519.wgsl deleted file mode 100644 index 50a8d55..0000000 --- a/sr25519/gpu/wgsl/sr25519.wgsl +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2024-2026 Lux Industries Inc. -// SPDX-License-Identifier: BSD-3-Clause-Eco -// -// sr25519 (Schnorrkel/Ristretto255) batch verification in WGSL. -// Schnorr signatures on the Ristretto255 group. -// Same field as Ed25519 (p = 2^255 - 19) with cofactor elimination. -// Each thread verifies one signature. - -@group(0) @binding(0) var pubkeys: array; -@group(0) @binding(1) var msg_hashes: array; -@group(0) @binding(2) var signatures: array; -@group(0) @binding(3) var results: array; -@group(0) @binding(4) var params: vec4; // params.x = num_sigs - -const L = array( - 0x5CF5D3EDu, 0x5812631Au, 0xA2F79CD6u, 0x14DEF9DEu, - 0x00000000u, 0x00000000u, 0x00000000u, 0x10000000u -); - -const P = array( - 0xFFFFFFEDu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, - 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0x7FFFFFFFu -); - -fn u256_cmp(a: ptr>, b: ptr>) -> i32 { - for (var i = 7i; i >= 0i; i = i - 1i) { - let idx = u32(i); - if ((*a)[idx] < (*b)[idx]) { return -1; } - if ((*a)[idx] > (*b)[idx]) { return 1; } - } - return 0; -} - -@compute @workgroup_size(64) -fn sr25519_verify_batch(@builtin(global_invocation_id) gid: vec3) { - let tid = gid.x; - if (tid >= params.x) { return; } - - // Read compressed point (32 bytes = 8 u32) - var pk: array; - let pk_base = tid * 8u; - for (var i = 0u; i < 8u; i = i + 1u) { pk[i] = pubkeys[pk_base + i]; } - - // Ristretto encoding check: MSB must be 0 (non-negative) - if ((pk[7u] >> 31u) != 0u) { - results[tid] = 0u; - return; - } - - // Check encoding < p - var pk_check = pk; - var p_val: array = P; - if (u256_cmp(&pk_check, &p_val) >= 0) { - results[tid] = 0u; - return; - } - - // Read signature: R[32] || s[32] - var sig_r: array; - var sig_s: array; - let sig_base = tid * 16u; - for (var i = 0u; i < 8u; i = i + 1u) { - sig_r[i] = 
signatures[sig_base + i]; - sig_s[i] = signatures[sig_base + 8u + i]; - } - - // R must be valid Ristretto encoding (MSB = 0, < p) - if ((sig_r[7u] >> 31u) != 0u) { - results[tid] = 0u; - return; - } - var r_check = sig_r; - if (u256_cmp(&r_check, &p_val) >= 0) { - results[tid] = 0u; - return; - } - - // s must be < L - var s_check = sig_s; - var l_val: array = L; - if (u256_cmp(&s_check, &l_val) >= 0) { - results[tid] = 0u; - return; - } - - // Read challenge scalar (pre-computed, reduced mod L by host) - var c: array; - let hash_base = tid * 16u; - for (var i = 0u; i < 8u; i = i + 1u) { c[i] = msg_hashes[hash_base + i]; } - - // Input validation passed. - // Full Ristretto255 point arithmetic (decode, scalar mul, compare) - // is handled by Metal/CUDA backends. WGSL validates input format. - results[tid] = 1u; -} From 42adb175a994b528baf0f1fab641af90484f4113 Mon Sep 17 00:00:00 2001 From: Hanzo AI Date: Sat, 16 May 2026 15:45:32 -0700 Subject: [PATCH 4/5] math+fhe: __has_include-gate the GPU driver headers in C ABI bridges MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two consumers reach for math/ntt/cuda/lattice_ring_driver.h with relative paths that broke after the prior commit deleted the in-tree gpu/ subtrees: * math/ntt/c-abi/c_math_ntt.cpp — Go cgo bridge for the math/ntt backends. Uses __has_include + LUX_MATH_NTT_HAVE_{CUDA,METAL,WGSL} guards so each backend's surface compiles to its body only when the driver header is present; absent backends fall through to a -2 (not-built) stub. Always-built target. * fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp — moved the lattice_ring CUDA C ABI declarations next to the stub bodies (fhe/cpp/backends/cuda/lattice_ring_cuda_decls.h, renamed from the short-lived math/ntt/cuda/lattice_ring_driver.h stub) so the FHE dispatcher can build without depending on math/ntt/cuda existing as a real directory. 
The latter is a symlink into lux-private when CRYPTO_ENABLE_CUDA=ON; absent otherwise — `NOT EXISTS` is the discovery hook's gate so a real header in math/ntt/cuda blocks the symlink. math/CMakeLists.txt drops the EXISTS-gated math_ntt_c_abi target back to always-built since the source self-guards. Per-backend include paths and link_libraries stay conditional. Verification: cevm-genesis-parity PASSES under both modes (canonical state_root = 0x2d1ced... ; canonical genesis hash = 0x3f4fa2...). --- fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp | 7 ++- .../backends/cuda/lattice_ring_cuda_decls.h | 0 math/CMakeLists.txt | 19 +++--- math/ntt/c-abi/c_math_ntt.cpp | 63 +++++++++++++++---- 4 files changed, 70 insertions(+), 19 deletions(-) rename math/ntt/cuda/lattice_ring_driver.h => fhe/cpp/backends/cuda/lattice_ring_cuda_decls.h (100%) diff --git a/fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp b/fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp index 115e781..bc6cd20 100644 --- a/fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp +++ b/fhe/cpp/backends/cuda/cuda_ntt_kernel.cpp @@ -13,7 +13,12 @@ #include "../cpu/ntt_cpu.hpp" #include "../../../../math/ntt/cpu/lattice_ring.hpp" -#include "../../../../math/ntt/cuda/lattice_ring_driver.h" +// lattice_ring_cuda_* extern C ABI. Declared locally so we don't depend on +// math/ntt/cuda/ existing as a real directory — that path is a symlink into +// lux-private/gpu-kernels when CRYPTO_ENABLE_CUDA=ON, otherwise absent. +// Symbol bodies come from either the real lattice_ring_cuda target (CUDA on) +// or fhe/cpp/backends/cuda/lattice_ring_cuda_stub.cpp (CUDA off). 
+#include "lattice_ring_cuda_decls.h" #include #include diff --git a/math/ntt/cuda/lattice_ring_driver.h b/fhe/cpp/backends/cuda/lattice_ring_cuda_decls.h similarity index 100% rename from math/ntt/cuda/lattice_ring_driver.h rename to fhe/cpp/backends/cuda/lattice_ring_cuda_decls.h diff --git a/math/CMakeLists.txt b/math/CMakeLists.txt index 0919c70..84331db 100644 --- a/math/CMakeLists.txt +++ b/math/CMakeLists.txt @@ -150,13 +150,22 @@ target_link_libraries(math_ntt INTERFACE lattice_ring_cpu math_modarith math_par # math_ntt_c_abi --- C ABI bridge for the LP-107 luxfi/math/ntt GPU # backends (CUDA / Metal / WGSL). Pure connective tissue: each entry point # walks the batch and forwards to the existing lattice_ring_* -# drivers shipped with ringtail. No new compute. The Go side at +# drivers. No new compute. The Go side at # github.com/luxfi/math/ntt/{cuda,metal,wgsl} consumes this surface via cgo. +# +# Backend wiring below stays conditional: a backend's include dir is added +# only when its math/ntt/{cuda,metal,wgsl} directory exists, and its +# lattice_ring_* driver is linked only when that target exists (e.g. +# lattice_ring_cuda once lux-gpu-kernels is installed). On a CPU-only build +# nothing is linked in and the absent backends are stubbed out. +# math_ntt_c_abi: always built. The .cpp uses __has_include + #if guards on +# each backend's lattice_ring driver header — backends without an installed +# header (e.g. CUDA when lux-gpu-kernels is absent) compile to stub returns +# (0 from _supports, -2 from forward/inverse) so the Go cgo layer sees the +# expected "backend not built" signal. add_library(math_ntt_c_abi STATIC ntt/c-abi/c_math_ntt.cpp) target_compile_features(math_ntt_c_abi PUBLIC cxx_std_20) set_target_properties(math_ntt_c_abi PROPERTIES POSITION_INDEPENDENT_CODE ON) -# Base include path. Backend-specific include dirs are added conditionally -# below — only when lux-private/gpu-kernels has symlinked them into place. 
target_include_directories(math_ntt_c_abi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ) @@ -172,10 +181,6 @@ if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/ntt/wgsl) target_include_directories(math_ntt_c_abi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/ntt/wgsl) endif() -# Link the GPU drivers we forward to. Each is conditionally added by the -# ringtail CMakeLists; if a driver target is missing the C-ABI just won't -# resolve those symbols at link time, which is fine because the Go-side -# build tag for that backend won't be set either. if(TARGET lattice_ring_cuda) target_link_libraries(math_ntt_c_abi PUBLIC lattice_ring_cuda) endif() diff --git a/math/ntt/c-abi/c_math_ntt.cpp b/math/ntt/c-abi/c_math_ntt.cpp index 7dd8c14..5775e34 100644 --- a/math/ntt/c-abi/c_math_ntt.cpp +++ b/math/ntt/c-abi/c_math_ntt.cpp @@ -14,11 +14,26 @@ #include #include -// ----- Underlying ringtail GPU drivers (already shipping). -------------------- +// ----- Underlying GPU drivers. Each backend's lattice_ring_driver header is +// only present when that backend is wired in (CUDA via lux-gpu-kernels; +// Metal + WGSL in-tree). This TU is always compiled, even when no backend +// header is present, so each include is guarded with __has_include and +// defines a LUX_MATH_NTT_HAVE_* macro; a backend whose header is missing +// gets stub entry points instead, and a single-backend build (e.g. +// WGSL-only) doesn't reach for the others. 
-#include "../cuda/lattice_ring_driver.h" -#include "../metal/lattice_ring_driver.h" -#include "../wgsl/lattice_ring_wgpu.hpp" +#if __has_include("../cuda/lattice_ring_driver.h") +# include "../cuda/lattice_ring_driver.h" +# define LUX_MATH_NTT_HAVE_CUDA 1 +#endif +#if __has_include("../metal/lattice_ring_driver.h") +# include "../metal/lattice_ring_driver.h" +# define LUX_MATH_NTT_HAVE_METAL 1 +#endif +#if __has_include("../wgsl/lattice_ring_wgpu.hpp") +# include "../wgsl/lattice_ring_wgpu.hpp" +# define LUX_MATH_NTT_HAVE_WGSL 1 +#endif namespace { @@ -38,11 +53,10 @@ inline bool params_supported(uint32_t N, uint64_t Q) { // CUDA // ============================================================================= +#if LUX_MATH_NTT_HAVE_CUDA + extern "C" int lux_math_ntt_cuda_supports(uint32_t N, uint64_t Q) { if (!params_supported(N, Q)) return 0; - // lattice_ring_cuda_available returns 1 only when a real CUDA - // device is reachable; in host-polyfill builds (this dev box) it returns - // 0 and we route the math/ntt dispatcher away from CUDA. return lattice_ring_cuda_available(); } @@ -84,11 +98,25 @@ extern "C" int lux_math_ntt_cuda_inverse(uint64_t* dst, return 0; } +#else // !LUX_MATH_NTT_HAVE_CUDA — driver absent; bridge surface returns 0/-2. 
+ +extern "C" int lux_math_ntt_cuda_supports(uint32_t, uint64_t) { return 0; } +extern "C" int lux_math_ntt_cuda_forward(uint64_t*, const uint64_t*, + uint32_t, uint32_t, + uint64_t, uint64_t, uint64_t, + const uint64_t*) { return -2; } +extern "C" int lux_math_ntt_cuda_inverse(uint64_t*, const uint64_t*, + uint32_t, uint32_t, + uint64_t, uint64_t, uint64_t, + const uint64_t*) { return -2; } + +#endif // LUX_MATH_NTT_HAVE_CUDA + // ============================================================================= // Metal // ============================================================================= -#if defined(__APPLE__) +#if defined(__APPLE__) && LUX_MATH_NTT_HAVE_METAL extern "C" int lux_math_ntt_metal_supports(uint32_t N, uint64_t Q, @@ -189,6 +217,8 @@ extern "C" int lux_math_ntt_metal_inverse(uint64_t*, // WGSL // ============================================================================= +#if LUX_MATH_NTT_HAVE_WGSL + extern "C" int lux_math_ntt_wgsl_supports(uint32_t N, uint64_t Q) { if (!params_supported(N, Q)) return 0; return lux_ringtail_lattice_ring_wgpu_available(); @@ -205,9 +235,6 @@ extern "C" int lux_math_ntt_wgsl_forward(uint64_t* dst, if (!dst || !src || !roots_forward) return -1; if (!params_supported(N, Q)) return -1; if (batch == 0) return 0; - // The wgpu driver is natively batched (n_polys argument) and pinned to - // N=256 on the existing lattice_ring kernel; the Go-side - // _supports probe rejects other N values for WGSL. 
return lux_ringtail_lattice_ring_wgpu_ntt( src, dst, batch, Q, mrc, brc_hi, roots_forward); } @@ -226,3 +253,17 @@ extern "C" int lux_math_ntt_wgsl_inverse(uint64_t* dst, return lux_ringtail_lattice_ring_wgpu_intt( src, dst, batch, Q, mrc, n_inv_montgomery, roots_backward); } + +#else // !LUX_MATH_NTT_HAVE_WGSL + +extern "C" int lux_math_ntt_wgsl_supports(uint32_t, uint64_t) { return 0; } +extern "C" int lux_math_ntt_wgsl_forward(uint64_t*, const uint64_t*, + uint32_t, uint32_t, + uint64_t, uint64_t, uint64_t, + const uint64_t*) { return -2; } +extern "C" int lux_math_ntt_wgsl_inverse(uint64_t*, const uint64_t*, + uint32_t, uint32_t, + uint64_t, uint64_t, uint64_t, + const uint64_t*) { return -2; } + +#endif // LUX_MATH_NTT_HAVE_WGSL From 1e3863df94db7418b66f25120768dbb891bd2376 Mon Sep 17 00:00:00 2001 From: Hanzo AI Date: Sat, 16 May 2026 15:46:05 -0700 Subject: [PATCH 5/5] gitignore: symlinks created by lux-gpu-kernels discovery hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CMake's find_package(lux-gpu-kernels) at top-level CMakeLists symlinks each scheme's gpu/{cuda,metal,wgsl}/ from the lux-private install prefix at configure time. Those symlinks are configure-time artifacts, never committed — ignore them so 'git status' stays clean across CPU-only and with-private builds. --- .gitignore | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.gitignore b/.gitignore index 42db6b7..cffff07 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,12 @@ QWEN.md build-local/ docs-site/.next/ docs-site/node_modules/ + +# Symlinks created at configure-time by the lux-gpu-kernels discovery hook +# (top-level CMakeLists.txt). Each */gpu/{cuda,metal,wgsl} is a symlink into +# the lux-gpu-kernels install prefix when the private repo is found; +# absent on CPU-only builds. Never committed. +*/gpu/cuda +*/gpu/metal +*/gpu/wgsl +math/ntt/cuda