Hi,
I compiled llama.cpp from git at today's master HEAD (commit 8030da7afea2d89f997aeadbd14183d399a017b9) on Fedora Rawhide (ROCm 6.0.x) like this:
CC=/usr/bin/clang CXX=/usr/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx900 -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="--rocm-device-lib-path=/usr/lib/clang/17/amdgcn/bitcode"
make -j 16
Then I tried to run a prompt using the codebooga-34b-v0.1.Q5_K_M.gguf model, downloaded from here: https://huggingface.co/TheBloke/CodeBooga-34B-v0.1-GGUF
I kept the prompt simple and used the following command:
./main -t 10 -ngl 16 -m ~/models/codebooga-34b-v0.1.Q5_K_M.gguf --color -c 2048 --temp 0.7 --repeat_penalty 1.1 -n -1 -p "### Instruction: How do I get the length of a Vec in Rust?\n### Response:"
I have an AMD Instinct MI25 card with 16 GB of VRAM. With -ngl 16, nvtop shows about half of it in use (8.219Gi/15.984Gi), so this does not seem to be an OOM issue.
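As a cross-check of the "not OOM" reasoning, here is a small sketch that only does arithmetic on numbers already in this report: the nvtop reading and the ROCm0 buffer sizes llama.cpp prints in the load log (weights, KV cache, compute buffer). It assumes both tools use binary units (GiB = 2**30 bytes, MiB = 2**20 bytes).

```python
# Cross-check of the "not an OOM" reasoning for the 34B run (-ngl 16).
# All figures are copied from the nvtop reading and the llama.cpp load log.
# Assumption: nvtop "Gi" and llama.cpp "MiB" are binary units (2**30 / 2**20 bytes).

MIB_PER_GIB = 1024.0

# Device-side (ROCm0) buffers reported by llama.cpp during load
rocm0_buffers_mib = {
    "weights, 16 offloaded layers": 7500.06,  # llm_load_tensors: ROCm0 buffer size
    "KV cache": 128.00,                       # llama_kv_cache_init: ROCm0 KV buffer size
    "compute buffer": 324.00,                 # ROCm0 compute buffer size
}

card_gib = 15.984      # total VRAM according to nvtop
observed_gib = 8.219   # VRAM in use according to nvtop

# Total that llama.cpp itself allocates on the device, in GiB
requested_gib = sum(rocm0_buffers_mib.values()) / MIB_PER_GIB
# Average weight size of one offloaded transformer layer
per_layer_mib = rocm0_buffers_mib["weights, 16 offloaded layers"] / 16

print(f"llama.cpp ROCm0 buffers: {requested_gib:.2f} GiB")
print(f"weights per offloaded layer: {per_layer_mib:.1f} MiB")
print(f"nvtop usage: {observed_gib / card_gib:.0%} of {card_gib} GiB")
```

nvtop's figure is somewhat above the sum of llama.cpp's buffers, which I would expect from runtime overhead those buffers don't include; either way, usage stays at roughly half of the card.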
The console output looks like this:
Log start
main: build = 2408 (8030da7a)
main: built with clang version 18.1.0 (Fedora 18.1.0~rc4-2.fc41) for x86_64-redhat-linux-gnu
main: seed = 1710292844
[New Thread 0x7fff074006c0 (LWP 11038)]
[New Thread 0x7ffe068006c0 (LWP 11039)]
[Thread 0x7ffe068006c0 (LWP 11039) exited]
ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 ROCm devices:
Device 0: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
llama_model_loader: loaded meta data with 21 key-value pairs and 435 tensors from /home/jin/Work/text-generation-webui/models/codebooga-34b-v0.1.Q5_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama
llama_model_loader: - kv 1: general.name str = oobabooga_codebooga-34b-v0.1
llama_model_loader: - kv 2: llama.context_length u32 = 16384
llama_model_loader: - kv 3: llama.embedding_length u32 = 8192
llama_model_loader: - kv 4: llama.block_count u32 = 48
llama_model_loader: - kv 5: llama.feed_forward_length u32 = 22016
llama_model_loader: - kv 6: llama.rope.dimension_count u32 = 128
llama_model_loader: - kv 7: llama.attention.head_count u32 = 64
llama_model_loader: - kv 8: llama.attention.head_count_kv u32 = 8
llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 10: llama.rope.freq_base f32 = 1000000.000000
llama_model_loader: - kv 11: general.file_type u32 = 17
llama_model_loader: - kv 12: tokenizer.ggml.model str = llama
llama_model_loader: - kv 13: tokenizer.ggml.tokens arr[str,32000] = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv 14: tokenizer.ggml.scores arr[f32,32000] = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv 15: tokenizer.ggml.token_type arr[i32,32000] = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv 16: tokenizer.ggml.bos_token_id u32 = 1
llama_model_loader: - kv 17: tokenizer.ggml.eos_token_id u32 = 2
llama_model_loader: - kv 18: tokenizer.ggml.unknown_token_id u32 = 0
llama_model_loader: - kv 19: tokenizer.ggml.padding_token_id u32 = 2
llama_model_loader: - kv 20: general.quantization_version u32 = 2
llama_model_loader: - type f32: 97 tensors
llama_model_loader: - type q5_K: 289 tensors
llama_model_loader: - type q6_K: 49 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = llama
llm_load_print_meta: vocab type = SPM
llm_load_print_meta: n_vocab = 32000
llm_load_print_meta: n_merges = 0
llm_load_print_meta: n_ctx_train = 16384
llm_load_print_meta: n_embd = 8192
llm_load_print_meta: n_head = 64
llm_load_print_meta: n_head_kv = 8
llm_load_print_meta: n_layer = 48
llm_load_print_meta: n_rot = 128
llm_load_print_meta: n_embd_head_k = 128
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 8
llm_load_print_meta: n_embd_k_gqa = 1024
llm_load_print_meta: n_embd_v_gqa = 1024
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-05
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff = 22016
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: causal attm = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 0
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx = 16384
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = 34B
llm_load_print_meta: model ftype = Q5_K - Medium
llm_load_print_meta: model params = 33.74 B
llm_load_print_meta: model size = 22.20 GiB (5.65 BPW)
llm_load_print_meta: general.name = oobabooga_codebooga-34b-v0.1
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: PAD token = 2 '</s>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.33 MiB
llm_load_tensors: offloading 16 repeating layers to GPU
llm_load_tensors: offloaded 16/49 layers to GPU
llm_load_tensors: ROCm0 buffer size = 7500.06 MiB
llm_load_tensors: CPU buffer size = 22733.73 MiB
....................................................................................................
llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: ROCm0 KV buffer size = 128.00 MiB
llama_kv_cache_init: ROCm_Host KV buffer size = 256.00 MiB
llama_new_context_with_model: KV self size = 384.00 MiB, K (f16): 192.00 MiB, V (f16): 192.00 MiB
llama_new_context_with_model: ROCm_Host input buffer size = 21.02 MiB
ggml_gallocr_reserve_n: reallocating ROCm0 buffer from size 0.00 MiB to 324.00 MiB
ggml_gallocr_reserve_n: reallocating ROCm_Host buffer from size 0.00 MiB to 336.00 MiB
llama_new_context_with_model: ROCm0 compute buffer size = 324.00 MiB
llama_new_context_with_model: ROCm_Host compute buffer size = 336.00 MiB
llama_new_context_with_model: graph splits (measure): 3
Shortly after that, I get a segfault, although sometimes it starts responding and crashes a few seconds into the response:
(gdb) bt
#0 amd::KernelParameters::set (this=0x1d9cb10, index=11, size=4,
value=0x100000020, svmBound=false)
at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/rocclr/platform/kernel.cpp:127
#1 0x00007fffb9822b7c in ihipLaunchKernel_validate (f=f@entry=0x3281e20,
globalWorkSizeX=globalWorkSizeX@entry=4096,
globalWorkSizeY=globalWorkSizeY@entry=1,
globalWorkSizeZ=globalWorkSizeZ@entry=1, blockDimX=blockDimX@entry=32,
blockDimY=blockDimY@entry=1, blockDimZ=1, sharedMemBytes=256,
kernelParams=0x7fffffff7430, extra=0x0, deviceId=0, params=0)
at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:301
#2 0x00007fffb98273fd in ihipModuleLaunchKernel (f=0x3281e20,
globalWorkSizeX=4096, globalWorkSizeY=1, globalWorkSizeZ=1, blockDimX=32,
blockDimY=1, blockDimZ=1, sharedMemBytes=256, hStream=0x195d320,
kernelParams=0x7fffffff7430, extra=0x0, startEvent=0x0, stopEvent=0x0,
flags=0, params=0, gridId=0, numGrids=0, prevGridSum=0, allGridSum=0,
firstDevice=0)
at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:371
#3 0x00007fffb98492a2 in ihipLaunchKernel (
hostFunction=0x679308 <void soft_max_f32<true, 32, 32>(float const*, float const*, float const*, float*, int, int, float, float, float, float, unsigned int)>,
gridDim=..., blockDim=..., args=0x7fffffff7430, sharedMemBytes=256,
stream=0x195d320, startEvent=0x0, stopEvent=0x0, flags=0)
at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_platform.cpp:584
#4 0x00007fffb9822519 in hipLaunchKernel_common (
hostFunction=hostFunction@entry=0x679308 <void soft_max_f32<true, 32, 32>(float const*, float const*, float const*, float*, int, int, float, float, float, float, unsigned int)>, gridDim=..., blockDim=...,
args=args@entry=0x7fffffff7430, sharedMemBytes=256, stream=<optimized out>)
at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:662
#5 0x00007fffb9824b83 in hipLaunchKernel (hostFunction=<optimized out>,
gridDim=..., blockDim=..., args=0x7fffffff7430,
sharedMemBytes=<optimized out>, stream=<optimized out>)
at /usr/src/debug/rocclr-6.0.2-1.fc41.x86_64/hipamd/src/hip_module.cpp:669
#6 0x000000000062ea50 in void __device_stub__soft_max_f32<true, 32, 32>(float const*, float const*, float const*, float*, int, int, float, float, float, float, unsigned int) ()
#7 0x000000000062e1f9 in soft_max_f32_cuda (x=0x7ff65e400800,
mask=0x7ff65c000800, pos=0x0, dst=0x7ff65e400800, ncols_x=32, nrows_x=128,
nrows_y=2, scale=0.0883883461, max_bias=0, stream=0x195d320)
at /llama.cpp/ggml-cuda.cu:7505
#8 0x000000000062ded6 in ggml_cuda_op_soft_max (src0=0x7ff66eca9450,
src1=0x7ff66e80ee50, dst=0x7ff66eca95e0, src0_dd=0x7ff65e400800,
src1_dd=0x7ff65c000800, dst_dd=0x7ff65e400800, main_stream=0x195d320)
at /llama.cpp/ggml-cuda.cu:9053
#9 0x00000000005f98f7 in ggml_cuda_op_flatten (src0=0x7ff66eca9450,
src1=0x7ff66e80ee50, dst=0x7ff66eca95e0,
op=0x62db50 <ggml_cuda_op_soft_max(ggml_tensor const*, ggml_tensor const*, ggml_tensor*, float const*, float const*, float*, ihipStream_t*)>)
at /llama.cpp/ggml-cuda.cu:9145
#10 0x00000000005f856f in ggml_cuda_soft_max (src0=0x7ff66eca9450,
src1=0x7ff66e80ee50, dst=0x7ff66eca95e0)
at /llama.cpp/ggml-cuda.cu:10393
#11 0x00000000005f5cb8 in ggml_cuda_compute_forward (params=0x7fffffff7b78,
tensor=0x7ff66eca95e0) at /llama.cpp/ggml-cuda.cu:10619
#12 0x0000000000635106 in ggml_backend_cuda_graph_compute (backend=0x19e1420,
cgraph=0x7ff66e8002d8) at /llama.cpp/ggml-cuda.cu:11310
#13 0x00000000005c1d42 in ggml_backend_graph_compute (backend=0x19e1420,
cgraph=0x7ff66e8002d8) at /llama.cpp/ggml-backend.c:270
#14 0x00000000005c55c3 in ggml_backend_sched_compute_splits (
sched=0x7ff66e800010) at /llama.cpp/ggml-backend.c:1474
#15 0x00000000005c5237 in ggml_backend_sched_graph_compute (
sched=0x7ff66e800010, graph=0x7ff66ec00030)
at /llama.cpp/ggml-backend.c:1597
#16 0x00000000004f85e9 in llama_graph_compute (lctx=..., gf=0x7ff66ec00030,
n_threads=10) at /llama.cpp/llama.cpp:8733
#17 0x00000000004b7926 in llama_decode_internal (lctx=..., batch=...)
at /llama.cpp/llama.cpp:8887
#18 0x00000000004b6fc3 in llama_decode (ctx=0x19f7b60, batch=...)
at /llama.cpp/llama.cpp:13837
#19 0x0000000000452e95 in llama_init_from_gpt_params (params=...)
at /llama.cpp/common/common.cpp:1380
#20 0x000000000042c0a5 in main (argc=18, argv=0x7fffffffdac8)
at /llama.cpp/examples/main/main.cpp:199
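For whoever looks into this: the launch parameters in frames #1 and #7 look self-consistent, so the crash does not seem to come from an obviously bogus grid or shared-memory size. The sketch below only checks arithmetic between values in the backtrace and the model metadata. The warp size of 32, the shared-memory formula, and the "rows = tokens × heads" reading are my assumptions about ggml-cuda.cu, not something I verified in the source. Frame #19 also shows this particular crash happening inside llama_init_from_gpt_params, i.e. before my prompt is evaluated.

```python
import math

# Values copied from the backtrace (frames #1 and #7) and from llm_load_print_meta.
ncols_x = 32            # frame #7
nrows_x = 128           # frame #7
nrows_y = 2             # frame #7
block_dim_x = 32        # frame #1, blockDimX
global_work_x = 4096    # frame #1, globalWorkSizeX
shared_mem_bytes = 256  # frame #1, sharedMemBytes
scale = 0.0883883461    # frame #7 (a float32 printed by gdb)
n_head = 64             # n_head
n_embd_head_k = 128     # n_embd_head_k

WARP_SIZE = 32     # assumption: warp size compiled into this build
SIZEOF_FLOAT = 4

def pad(x: int, n: int) -> int:
    """Round x up to a multiple of n (the idea behind ggml's GGML_PAD macro)."""
    return (x + n - 1) // n * n

# HIP reports the global work size in threads (grid blocks * block size),
# which matches one block per softmax row.
assert global_work_x == nrows_x * block_dim_x

# Assumption: shared memory = one padded row of floats plus one warp of scratch floats.
assert shared_mem_bytes == (pad(ncols_x, WARP_SIZE) + WARP_SIZE) * SIZEOF_FLOAT

# The softmax scale matches 1/sqrt(head dimension); tolerance covers float32 printing.
assert math.isclose(scale, 1.0 / math.sqrt(n_embd_head_k), rel_tol=1e-6)

# Assumption: nrows_x = tokens in the batch * heads, with nrows_y = tokens in the batch.
assert nrows_x == nrows_y * n_head

print("soft_max_f32 launch parameters are self-consistent")
```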
I saw some existing issues about partial offloading, so I also tried a smaller model that should fit completely on my GPU, but the segfault persisted. The smaller model is this one:
llm_load_print_meta: model type = 13B
llm_load_print_meta: model ftype = Q8_0
llm_load_print_meta: model params = 13.02 B
llm_load_print_meta: model size = 12.88 GiB (8.50 BPW)
llm_load_print_meta: general.name = newhope.ggmlv3.q8_0.bin
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.28 MiB
llm_load_tensors: offloading 40 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 41/41 layers to GPU
llm_load_tensors: ROCm0 buffer size = 13023.85 MiB
llm_load_tensors: CPU buffer size = 166.02 MiB
It crashed as well, with a very similar backtrace.
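To back up the claim that the 13B model fits, here is the same kind of arithmetic for this second run, again using only figures from the log excerpt and nvtop (binary units assumed). The KV cache and compute buffer sizes are not in this excerpt, so the result is just the space left for them after the weights.

```python
# Sanity check that the fully offloaded 13B Q8_0 model fits on the 16 GB MI25.
# Figures are copied from the log excerpt above and the nvtop total.
# Assumption: GiB/MiB are binary units (2**30 / 2**20 bytes).

card_mib = 15.984 * 1024   # total VRAM according to nvtop, in MiB
rocm0_mib = 13023.85       # llm_load_tensors: ROCm0 buffer size
cpu_mib = 166.02           # llm_load_tensors: CPU buffer size

# Both load buffers together should match the reported model size (12.88 GiB) ...
loaded_gib = (rocm0_mib + cpu_mib) / 1024
# ... which should also equal params * bits-per-weight / 8 bytes (BPW is rounded in the log).
params = 13.02e9
bpw = 8.50
expected_gib = params * bpw / 8 / 2**30

# VRAM left for the KV cache and compute buffers, which this excerpt does not show
headroom_mib = card_mib - rocm0_mib

print(f"loaded: {loaded_gib:.2f} GiB, params*BPW/8: {expected_gib:.2f} GiB")
print(f"VRAM left after weights: {headroom_mib:.0f} MiB")
```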
Since this is easily reproducible, I can provide more info or add debug logs as needed; please let me know what you need.