
Compile bug: RocmWMMA doesn't work #14193

@martinalderson

Description

Git commit

3cb203c

Operating systems

Linux

GGML backends

HIP

Problem description & steps to reproduce

Related to #13110, which I can't reopen, so I'm filing a new issue here. rocwmma-dev indicates that gfx1201 is now supported:

rocWMMA currently supports the following AMD GPU architectures:

CDNA class GPU featuring matrix core support: gfx908, gfx90a, gfx942, gfx950 as 'gfx9'
RDNA class GPU featuring AI acceleration support: gfx1100, gfx1101, gfx1102 as 'gfx11'; gfx1200, gfx1201 as 'gfx12'
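To check a build target against the list above, here is a small bash sketch that transcribes it into a lookup. `rocwmma_family` is a hypothetical helper, not part of rocWMMA or llama.cpp. The table only mirrors the quoted rocWMMA text; it does not tell you whether every llama.cpp kernel compiles for that target, which is exactly what fails in this report.

```bash
#!/usr/bin/env bash
# Hypothetical helper: map a gfx target to its rocWMMA architecture family,
# transcribed from the supported-architecture list quoted from rocWMMA.
# Prints the family, or "unsupported" (exit status 1) if the target is not listed.
rocwmma_family() {
  case "$1" in
    gfx908|gfx90a|gfx942|gfx950) echo gfx9 ;;   # CDNA, matrix cores
    gfx1100|gfx1101|gfx1102)     echo gfx11 ;;  # RDNA3, AI acceleration
    gfx1200|gfx1201)             echo gfx12 ;;  # RDNA4, AI acceleration
    *) echo unsupported; return 1 ;;
  esac
}
```

For example, `rocwmma_family gfx1201` prints `gfx12`, so the target used in this report is on the list.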

Building for gfx1201 with rocWMMA enabled fails (full log below) with this command:

HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1201 -DCMAKE_BUILD_TYPE=Release -DGGML_HIP_ROCWMMA_FATTN=ON && cmake --build build --config Release -- -j 16

In fact, the only way I could get it to build was to disable both rocWMMA FlashAttention and WMMA:

HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1201 -DCMAKE_BUILD_TYPE=Release -DGGML_HIP_ROCWMMA_FATTN=OFF -DGGML_HIP_WMMA=OFF && cmake --build build --config Release -- -j 16
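The failing and working commands differ only in the FlashAttention/WMMA flags. The bash sketch below builds each variant in its own directory, so the two can be compared without overwriting `build/`. `hip_build`, the `build-<target>-rocwmma-<ON|OFF>` directory naming, and the `DRY_RUN` switch are all hypothetical. Any extra arguments pass through to the configure step unchanged.

```bash
#!/usr/bin/env bash
# Hypothetical helper: configure + build llama.cpp's HIP backend for one target
# with GGML_HIP_ROCWMMA_FATTN set to ON or OFF, in a separate build directory.
# Usage: hip_build <gfx-target> <ON|OFF> [extra cmake -D options...]
# With DRY_RUN=1 the commands are printed instead of executed.
hip_build() {
  local target="$1" rocwmma_fattn="$2"
  shift 2
  local dir="build-${target}-rocwmma-${rocwmma_fattn}"
  # AMDGPU_TARGETS is kept as in the report; HIP's CMake config warns that it
  # is deprecated in favour of GPU_TARGETS.
  local -a cfg=(cmake -S . -B "$dir" -DGGML_HIP=ON "-DAMDGPU_TARGETS=${target}"
                -DCMAKE_BUILD_TYPE=Release "-DGGML_HIP_ROCWMMA_FATTN=${rocwmma_fattn}" "$@")
  local -a bld=(cmake --build "$dir" --config Release -- -j 16)

  if [[ "${DRY_RUN:-0}" == 1 ]]; then
    echo "${cfg[*]}"
    echo "${bld[*]}"
    return 0
  fi
  # Same compiler / HIP path setup as the commands in this report.
  HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" "${cfg[@]}" && "${bld[@]}"
}
```

For example, `hip_build gfx1201 ON` reproduces the failing configuration, and `hip_build gfx1201 OFF -DGGML_HIP_WMMA=OFF` reproduces the one reported to build.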

Environment: Ubuntu 22.04 with ROCm 6.4.1.
apt reports: rocwmma-dev is already the newest version (1.7.0.60401-83~22.04).
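ROCm packages usually follow the component version with a numeric code for the ROCm release (here 60401 for 6.4.1). If that convention holds, the version string above corresponds to rocWMMA 1.7.0 as shipped with ROCm 6.4.1. The minimal bash sketch below decodes it. `parse_rocm_pkg_version` is a hypothetical helper, and the encoding is an assumption based on the package naming seen here.

```bash
#!/usr/bin/env bash
# Hypothetical helper: split a ROCm Debian package version such as
# "1.7.0.60401-83~22.04" into the component version and the ROCm release.
# Assumes the last dotted field before the "-" encodes the ROCm release as
# major*10000 + minor*100 + patch.
parse_rocm_pkg_version() {
  local upstream="${1%%-*}"            # drop Debian revision: 1.7.0.60401
  local component="${upstream%.*}"     # component version:    1.7.0
  local code=$((10#${upstream##*.}))   # ROCm release code:    60401
  local major=$(( code / 10000 ))
  local minor=$(( (code / 100) % 100 ))
  local patch=$(( code % 100 ))
  echo "component=${component} rocm=${major}.${minor}.${patch}"
}
```

On a Debian/Ubuntu system, you can feed it the installed version with `parse_rocm_pkg_version "$(dpkg-query -W -f='${Version}' rocwmma-dev)"`. For the version reported above, it prints `component=1.7.0 rocm=6.4.1`.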

First Bad Commit

No response

Compile command

HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1201 -DCMAKE_BUILD_TYPE=Release -DGGML_HIP_ROCWMMA_FATTN=ON && cmake --build build --config Release -- -j 16

Relevant log output

martin@DESKTOP-LEAK61N:~/llama.cpp$ HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1201 -DCMAKE_BUILD_TYPE=Release -DGGML_HIP_ROCWMMA_FATTN=ON && cmake --build build --config Release -- -j 16 > build.log
CMake Warning at CMakeLists.txt:116 (message):
  LLAMA_NATIVE is deprecated and will be removed in the future.

  Use GGML_NATIVE instead

Call Stack (most recent call first):
  CMakeLists.txt:126 (llama_option_depr)


-- ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.
-- CMAKE_SYSTEM_PROCESSOR: x86_64
-- GGML_SYSTEM_ARCH: x86
-- Including CPU backend
-- x86 detected
-- Adding CPU backend variant ggml-cpu: -march=native
CMake Warning (dev) at /opt/rocm/lib/cmake/hip/hip-config-amd.cmake:70 (message):
  AMDGPU_TARGETS is deprecated.  Please use GPU_TARGETS instead.
Call Stack (most recent call first):
  /opt/rocm/lib/cmake/hip/hip-config.cmake:149 (include)
  ggml/src/ggml-hip/CMakeLists.txt:39 (find_package)
This warning is for project developers.  Use -Wno-dev to suppress it.

-- HIP and hipBLAS found
-- Including HIP backend
-- Configuring done (5.4s)
-- Generating done (0.1s)
-- Build files have been written to: /home/martin/llama.cpp/build
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:5:
/home/martin/llama.cpp/ggml/src/ggml-cuda/common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:891:139: note: macro marked 'deprecated' here
  891 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:15:
In file included from /opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma.hpp:321:
In file included from /opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:56:
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/internal/wmma.hpp:124:23: error: static assertion failed due to requirement 'VecTraits<HIP_vector_type<float, 4>>::size() == IOTraits<16, 16, __half, 1>::UnpackedSize': WMMA backend input size mismatch
  124 |         static_assert(VecTraitsC::size() == IOTraitsAcc::UnpackedSize,
      |                       ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:377:37: note: in instantiation of template class 'rocwmma::Wmma<__half, __half, 16, 16, 16>' requested here
  377 |                     PackAcc::unpack(Mma::exec(PackA::pack(PreMmaA::exec(a.mAccess)),
      |                                     ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:340:27: note: in instantiation of function template specialization 'rocwmma::mma_sync<16U, 16U, 16U, __half, __half, rocwmma::col_major, rocwmma::col_major, void, void>' requested here
  340 |                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
      |                           ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<64, 16, 4, 64, float, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:508:21: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<64, 16, float>' requested here
  508 |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
      |                     ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/internal/wmma.hpp:124:42: note: expression evaluates to '4 == 8'
  124 |         static_assert(VecTraitsC::size() == IOTraitsAcc::UnpackedSize,
      |                       ~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<64, 16, 4, 64, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:586:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<64, 16, __half>' requested here
  586 |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<80, 16, 4, 16, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:589:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<80, 16, __half>' requested here
  589 |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<80, 16, 4, 16, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:589:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<80, 16, __half>' requested here
  589 |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<96, 16, 4, 32, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:592:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<96, 16, __half>' requested here
  592 |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<96, 16, 4, 32, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:592:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<96, 16, __half>' requested here
  592 |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<112, 16, 4, 16, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:595:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<112, 16, __half>' requested here
  595 |                 ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<112, 16, 4, 16, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:595:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<112, 16, __half>' requested here
  595 |                 ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<128, 16, 4, 64, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:598:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<128, 16, __half>' requested here
  598 |                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<128, 16, 4, 64, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:598:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<128, 16, __half>' requested here
  598 |                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<256, 16, 4, 64, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:601:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<256, 16, __half>' requested here
  601 |                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<256, 16, 4, 64, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:601:17: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<256, 16, __half>' requested here
  601 |                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
      |                 ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<64, 32, 4, 64, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:613:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<64, 32, __half>' requested here
  613 |             ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<64, 32, 4, 64, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:613:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<64, 32, __half>' requested here
  613 |             ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<80, 32, 4, 16, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:616:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<80, 32, __half>' requested here
  616 |             ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<80, 32, 4, 16, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:616:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<80, 32, __half>' requested here
  616 |             ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<96, 32, 4, 32, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:619:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<96, 32, __half>' requested here
  619 |             ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:490:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<96, 32, 4, 32, __half, true>' requested here
  490 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:619:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<96, 32, __half>' requested here
  619 |             ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:193:21: error: no matching function for call to 'mma_sync'
  193 |                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
      |                     ^~~~~~~~~~~~~~
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:486:24: note: in instantiation of function template specialization 'flash_attn_ext_f16<112, 32, 4, 16, __half, false>' requested here
  486 |         fattn_kernel = flash_attn_ext_f16<
      |                        ^
/home/martin/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu:622:13: note: in instantiation of function template specialization 'ggml_cuda_flash_attn_ext_wmma_f16_case<112, 32, __half>' requested here
  622 |             ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
      |             ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/rocwmma/rocwmma_impl.hpp:331:9: note: candidate template ignored: substitution failure [with BlockM = 16, BlockN = 16, BlockK = 16, InputT = __half, ComputeT = __half, LayoutA = rocwmma::row_major, LayoutB = rocwmma::col_major, LayoutC = void, LayoutD = void]
  331 |         mma_sync(fragment<accumulator, BlockM, BlockN, BlockK, ComputeT, LayoutD>&       d,
      |         ^
fatal error: too many errors emitted, stopping now [-ferror-limit=]
1 warning and 20 errors generated when compiling for gfx1201.
gmake[2]: *** [ggml/src/ggml-hip/CMakeFiles/ggml-hip.dir/build.make:273: ggml/src/ggml-hip/CMakeFiles/ggml-hip.dir/__/ggml-cuda/fattn-wmma-f16.cu.o] Error 1
gmake[2]: *** Waiting for unfinished jobs....
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f32.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:891:139: note: macro marked 'deprecated' here
  891 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for gfx1201.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f32.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:430:139: note: macro marked 'deprecated' here
  430 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for host.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f16.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:891:139: note: macro marked 'deprecated' here
  891 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for gfx1201.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f16.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:430:139: note: macro marked 'deprecated' here
  430 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for host.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f32.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:891:139: note: macro marked 'deprecated' here
  891 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for gfx1201.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f32.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:430:139: note: macro marked 'deprecated' here
  430 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for host.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f16.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:891:139: note: macro marked 'deprecated' here
  891 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for gfx1201.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f16.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:430:139: note: macro marked 'deprecated' here
  430 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for host.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f32.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:891:139: note: macro marked 'deprecated' here
  891 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for gfx1201.
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu:3:
In file included from /home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-vec-f32.cuh:1:
/home/martin/llama.cpp/ggml/src/ggml-cuda/template-instances/../common.cuh:266:12: warning: macro '__AMDGCN_WAVEFRONT_SIZE' has been marked as deprecated: compile-time-constant access to the wavefront size will be removed in a future release [-Wdeprecated-pragma]
  266 |     return __AMDGCN_WAVEFRONT_SIZE;
      |            ^
<built-in>:430:139: note: macro marked 'deprecated' here
  430 | #pragma clang deprecated(__AMDGCN_WAVEFRONT_SIZE, "compile-time-constant access to the wavefront size will be removed in a future release")
      |                                                                                                                                           ^
1 warning generated when compiling for host.
gmake[1]: *** [CMakeFiles/Makefile2:2223: ggml/src/ggml-hip/CMakeFiles/ggml-hip.dir/all] Error 2
gmake: *** [Makefile:146: all] Error 2
