
Segmentation fault during inference on AMD gfx900 with codebooga-34b-v0.1.Q5_K_M.gguf #6031

@jin-eld

Hi,

I compiled llama.cpp from git, using today's master HEAD (commit 8030da7afea2d89f997aeadbd14183d399a017b9), on Fedora Rawhide (ROCm 6.0.x) like this:

CC=/usr/bin/clang CXX=/usr/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx900 -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="--rocm-device-lib-path=/usr/lib/clang/17/amdgcn/bitcode"
make -j 16

Then I tried to run a prompt using the codebooga-34b-v0.1.Q5_K_M.gguf model which I got from here: https://huggingface.co/TheBloke/CodeBooga-34B-v0.1-GGUF

I kept the prompt simple and used the following command:
./main -t 10 -ngl 16 -m ~/models/codebooga-34b-v0.1.Q5_K_M.gguf --color -c 2048 --temp 0.7 --repeat_penalty 1.1 -n -1 -p "### Instruction: How do I get the length of a Vec in Rust?\n### Response:"

I have an AMD Radeon Instinct MI25 card with 16 GB of VRAM. According to nvtop, with -ngl 16 about half of it is in use (8.219 GiB of 15.984 GiB), so this does not seem to be an out-of-memory issue.

The console output looks like this:

Log start                                                                       
main: build = 2408 (8030da7a)
main: built with clang version 18.1.0 (Fedora 18.1.0~rc4-2.fc41) for x86_64-redhat-linux-gnu
main: seed  = 1710292844
[New Thread 0x7fff074006c0 (LWP 11038)]
[New Thread 0x7ffe068006c0 (LWP 11039)]
[Thread 0x7ffe068006c0 (LWP 11039) exited]
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
llama_model_loader: loaded meta data with 21 key-value pairs and 435 tensors from /home/jin/Work/text-generation-webui/models/codebooga-34b-v0.1.Q5_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = oobabooga_codebooga-34b-v0.1
llama_model_loader: - kv   2:                       llama.context_length u32              = 16384
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 8192
llama_model_loader: - kv   4:                          llama.block_count u32              = 48
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 22016
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 64
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  11:                          general.file_type u32              = 17
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  17:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  18:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  19:            tokenizer.ggml.padding_token_id u32              = 2
llama_model_loader: - kv  20:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   97 tensors
llama_model_loader: - type q5_K:  289 tensors
llama_model_loader: - type q6_K:   49 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 16384
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 48
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 22016
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attm      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 16384
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 34B
llm_load_print_meta: model ftype      = Q5_K - Medium
llm_load_print_meta: model params     = 33.74 B
llm_load_print_meta: model size       = 22.20 GiB (5.65 BPW) 
llm_load_print_meta: general.name     = oobabooga_codebooga-34b-v0.1
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 2 '</s>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.33 MiB
llm_load_tensors: offloading 16 repeating layers to GPU
llm_load_tensors: offloaded 16/49 layers to GPU
llm_load_tensors:      ROCm0 buffer size =  7500.06 MiB
llm_load_tensors:        CPU buffer size = 22733.73 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      ROCm0 KV buffer size =   128.00 MiB
llama_kv_cache_init:  ROCm_Host KV buffer size =   256.00 MiB
llama_new_context_with_model: KV self size  =  384.00 MiB, K (f16):  192.00 MiB, V (f16):  192.00 MiB
llama_new_context_with_model:  ROCm_Host input buffer size   =    21.02 MiB
ggml_gallocr_reserve_n: reallocating ROCm0 buffer from size 0.00 MiB to 324.00 MiB
ggml_gallocr_reserve_n: reallocating ROCm_Host buffer from size 0.00 MiB to 336.00 MiB
llama_new_context_with_model:      ROCm0 compute buffer size =   324.00 MiB
llama_new_context_with_model:  ROCm_Host compute buffer size =   336.00 MiB
llama_new_context_with_model: graph splits (measure): 3

Shortly after this output I get a segfault. Sometimes it starts responding first and then crashes a few seconds into the response. The backtrace:

(gdb) bt
#0  amd::KernelParameters::set (this=0x1d9cb10, index=11, size=4, 
    value=0x100000020, svmBound=false)
    at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/rocclr/platform/kernel.cpp:127
#1  0x00007fffb9822b7c in ihipLaunchKernel_validate (f=f@entry=0x3281e20, 
    globalWorkSizeX=globalWorkSizeX@entry=4096, 
    globalWorkSizeY=globalWorkSizeY@entry=1, 
    globalWorkSizeZ=globalWorkSizeZ@entry=1, blockDimX=blockDimX@entry=32, 
    blockDimY=blockDimY@entry=1, blockDimZ=1, sharedMemBytes=256, 
    kernelParams=0x7fffffff7430, extra=0x0, deviceId=0, params=0)
    at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:301
#2  0x00007fffb98273fd in ihipModuleLaunchKernel (f=0x3281e20, 
    globalWorkSizeX=4096, globalWorkSizeY=1, globalWorkSizeZ=1, blockDimX=32, 
    blockDimY=1, blockDimZ=1, sharedMemBytes=256, hStream=0x195d320, 
    kernelParams=0x7fffffff7430, extra=0x0, startEvent=0x0, stopEvent=0x0, 
    flags=0, params=0, gridId=0, numGrids=0, prevGridSum=0, allGridSum=0, 
    firstDevice=0)
    at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:371
#3  0x00007fffb98492a2 in ihipLaunchKernel (
    hostFunction=0x679308 <void soft_max_f32<true, 32, 32>(float const*, float const*, float const*, float*, int, int, float, float, float, float, unsigned int)>, gridDim=..., blockDim=..., args=0x7fffffff7430, sharedMemBytes=256, 
    stream=0x195d320, startEvent=0x0, stopEvent=0x0, flags=0)
    at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_platform.cpp:584
#4  0x00007fffb9822519 in hipLaunchKernel_common (
    hostFunction=hostFunction@entry=0x679308 <void soft_max_f32<true, 32, 32>(float const*, float const*, float const*, float*, int, int, float, float, float, float, unsigned int)>, gridDim=..., blockDim=..., 
    args=args@entry=0x7fffffff7430, sharedMemBytes=256, stream=<optimized out>)
    at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:662
#5  0x00007fffb9824b83 in hipLaunchKernel (hostFunction=<optimized out>, 
    gridDim=..., blockDim=..., args=0x7fffffff7430, 
    sharedMemBytes=<optimized out>, stream=<optimized out>)
    at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:669
#6  0x000000000062ea50 in void __device_stub__soft_max_f32<true, 32, 32>(float const*, float const*, float const*, float*, int, int, float, float, float, float, unsigned int) ()
#7  0x000000000062e1f9 in soft_max_f32_cuda (x=0x7ff65e400800, 
    mask=0x7ff65c000800, pos=0x0, dst=0x7ff65e400800, ncols_x=32, nrows_x=128, 
    nrows_y=2, scale=0.0883883461, max_bias=0, stream=0x195d320)
    at /llama.cpp/ggml-cuda.cu:7505
#8  0x000000000062ded6 in ggml_cuda_op_soft_max (src0=0x7ff66eca9450, 
    src1=0x7ff66e80ee50, dst=0x7ff66eca95e0, src0_dd=0x7ff65e400800, 
    src1_dd=0x7ff65c000800, dst_dd=0x7ff65e400800, main_stream=0x195d320)
    at /llama.cpp/ggml-cuda.cu:9053
#9  0x00000000005f98f7 in ggml_cuda_op_flatten (src0=0x7ff66eca9450, 
    src1=0x7ff66e80ee50, dst=0x7ff66eca95e0, 
    op=0x62db50 <ggml_cuda_op_soft_max(ggml_tensor const*, ggml_tensor const*, ggml_tensor*, float const*, float const*, float*, ihipStream_t*)>)
    at /llama.cpp/ggml-cuda.cu:9145
#10 0x00000000005f856f in ggml_cuda_soft_max (src0=0x7ff66eca9450, 
    src1=0x7ff66e80ee50, dst=0x7ff66eca95e0)
    at /llama.cpp/ggml-cuda.cu:10393
#11 0x00000000005f5cb8 in ggml_cuda_compute_forward (params=0x7fffffff7b78, 
    tensor=0x7ff66eca95e0) at /llama.cpp/ggml-cuda.cu:10619
#12 0x0000000000635106 in ggml_backend_cuda_graph_compute (backend=0x19e1420, 
    cgraph=0x7ff66e8002d8) at /llama.cpp/ggml-cuda.cu:11310
#13 0x00000000005c1d42 in ggml_backend_graph_compute (backend=0x19e1420, 
    cgraph=0x7ff66e8002d8) at /llama.cpp/ggml-backend.c:270
#14 0x00000000005c55c3 in ggml_backend_sched_compute_splits (
    sched=0x7ff66e800010) at /llama.cpp/ggml-backend.c:1474
#15 0x00000000005c5237 in ggml_backend_sched_graph_compute (
    sched=0x7ff66e800010, graph=0x7ff66ec00030)
    at /llama.cpp/ggml-backend.c:1597
#16 0x00000000004f85e9 in llama_graph_compute (lctx=..., gf=0x7ff66ec00030, 
    n_threads=10) at /llama.cpp/llama.cpp:8733
#17 0x00000000004b7926 in llama_decode_internal (lctx=..., batch=...)
    at /llama.cpp/llama.cpp:8887
#18 0x00000000004b6fc3 in llama_decode (ctx=0x19f7b60, batch=...)
    at /llama.cpp/llama.cpp:13837
#19 0x0000000000452e95 in llama_init_from_gpt_params (params=...)
    at /llama.cpp/common/common.cpp:1380
#20 0x000000000042c0a5 in main (argc=18, argv=0x7fffffffdac8)
    at /llama.cpp/examples/main/main.cpp:199

I saw some existing issues about partial offloading, so I also tried a smaller model that should fit completely on my GPU. The segfault was still there. This is the smaller model:

llm_load_print_meta: model type       = 13B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 13.02 B
llm_load_print_meta: model size       = 12.88 GiB (8.50 BPW) 
llm_load_print_meta: general.name     = newhope.ggmlv3.q8_0.bin
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.28 MiB
llm_load_tensors: offloading 40 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 41/41 layers to GPU
llm_load_tensors:      ROCm0 buffer size = 13023.85 MiB
llm_load_tensors:        CPU buffer size =   166.02 MiB

It crashed as well, with a very similar backtrace.

Since this is nicely reproducible, I can provide more info or add debug logs as needed. Please let me know what you need.
