
Support glm3 and glm4. #8031

Open · wants to merge 29 commits into master
Conversation

youth123

I have fixed the issues mentioned in #6999. This code fully supports the GLM3 and GLM4 model architectures and can be embedded in the Ollama server. This PR is based on https://github.com/mnlife/llama.cpp/tree/glm4 and https://github.com/mnlife/llama.cpp/tree/chatglm3, by @mnlife and @xingxingqiao.

xingxingqiao and others added 8 commits May 29, 2024 13:30
@github-actions bot added the "testing" (Everything test related) and "python" (python script changes) labels on Jun 20, 2024
@xunkai55

Thanks for the great work!

@youth123 youth123 marked this pull request as draft June 20, 2024 10:07
@youth123 youth123 marked this pull request as ready for review June 20, 2024 10:15
@arch-btw
Contributor

This is so great! Thank you 👍 !

There are only a couple of things that I ran into:

During compilation, a small warning:

llama.cpp: In function ‘int32_t llama_tokenize(const llama_model*, const char*, int32_t, llama_token*, int32_t, bool, bool)’:
llama.cpp:18603:28: warning: moving a temporary object prevents copy elision [-Wpessimizing-move]
18603 |     auto prompt = std::move(std::string(text, text_len));
      |                   ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
llama.cpp:18603:28: note: remove ‘std::move’ call
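
For reference, a minimal sketch of the fix the compiler suggests (construct the string in place so copy elision applies; this is an illustrative snippet, not the PR's exact code):

#include <cstdint>
#include <string>

// Initializing or returning the temporary directly lets the compiler elide the copy;
// wrapping it in std::move is what triggers -Wpessimizing-move.
static std::string make_prompt(const char * text, int32_t text_len) {
    return std::string(text, text_len);
}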

And convert-hf-to-gguf.py doesn't work with bf16.

But it does work with f32.

bf16 log:

INFO:hf-to-gguf:output.weight,             torch.bfloat16 --> BF16, shape = {4096, 151552}
Writing:   0%| | 0.00/18.8G [00:00<?, ?byte/s]
Traceback (most recent call last):
  File "/home/glm4/convert-hf-to-gguf.py", line 3072, in <module>
    main()
  File "/home/glm4/convert-hf-to-gguf.py", line 3066, in main
    model_instance.write()
  File "/home/glm4/convert-hf-to-gguf.py", line 331, in write
    self.gguf_writer.write_tensors_to_file(progress=True)
  File "/home/glm4/gguf-py/gguf/gguf_writer.py", line 312, in write_tensors_to_file
    ti.tensor.tofile(self.fout)
  File "/home/glm4/gguf-py/gguf/lazy.py", line 233, in tofile
    eager = LazyNumpyTensor.to_eager(self)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 193, in to_eager
    return cls._recurse_apply(t, simple_to_eager)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 109, in _recurse_apply
    return fn(o)
           ^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 185, in simple_to_eager
    lt._data = lt._func(lt._args)
               ^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/lazy.py", line 158, in <lambda>
    return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
                                                                                        ^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/quants.py", line 52, in __quantize_bf16_array
    return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.int16, oshape=n.shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/quants.py", line 47, in __apply_over_grouped_rows
    np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
                    ^^^^^^^^^^^
  File "/home/glm4/gguf-py/gguf/quants.py", line 30, in __compute_fp32_to_bf16
    n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
                                                 ~~^~~~~~~~~~~~
OverflowError: Python integer 4294901760 out of bounds for int32

Other than that it works great.

Prompt:

./llama-cli -m glm4-Q6_K.gguf --color -p "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\n"

Output:

You are a helpful assistant
Hello

Hello! How can I assist you today? [end of text]

@youth123
Author

youth123 commented Jun 21, 2024

(Quoted @arch-btw's comment above: the -Wpessimizing-move compile warning and the bf16 OverflowError from convert-hf-to-gguf.py.)

I reran the conversion for GLM3 and GLM4, but did not encounter the issue you mentioned.
Here are my run commands and model weight links.

python convert-hf-to-gguf.py   /root/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/75792d7ee58a335df6943c5d719cc559b64f8e2a/ --outtype bf16 --outfile test.gguf

https://huggingface.co/THUDM/glm-4-9b-chat

@arch-btw
Contributor

Thanks, I think it's related to my pip environment and not a problem with the code.

@mofosyne added the "Review Complexity : Medium" label (Generally require more time to grok but manageable by beginner to medium expertise level) on Jun 21, 2024
llama.cpp Outdated
@@ -18324,6 +18550,19 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to
}

bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
    auto arch_name = llama_model_arch_name(model->arch);
    auto vocab_type = model->vocab.type;
    if (strcmp(arch_name, "chatglm") == 0) {
Collaborator

llama_token_is_eog is called quite often; doing a string compare here may have an impact on performance

@ngxson (Collaborator) Jun 22, 2024

Looking at tokenizer_config.json, I think that it's safe to stop at EOS (<|endoftext|>), so no need to hard-code token IDs here

Collaborator

Edit: looking at the chat template, it seems the model does not have the notion of an end-of-turn token (strange!). Maybe we need to introduce the EOT token as a list instead of a single value. This will require adding metadata to gguf (CC @ggerganov)

Author

Alright, I will add an EOT list to the gguf metadata. Then, during vocab initialization, I will put all the EOT entries into this variable, so the check will only require traversing this list.
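
For illustration, a rough sketch of such a list-based check (the names here are hypothetical, not this PR's actual code):

#include <cstdint>
#include <vector>

// Sketch: the vocab carries a small list of end-of-generation token ids loaded
// from the gguf metadata, and the EOG check is a linear scan over that list.
struct eog_list_sketch {
    std::vector<int32_t> eog_ids; // e.g. <|endoftext|>, <|user|> for GLM4

    bool is_eog(int32_t token) const {
        for (const int32_t id : eog_ids) {
            if (id == token) {
                return true;
            }
        }
        return false;
    }
};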

Author

I have already added the eod_id_list variable to the gguf metadata and ensured compatibility with previous versions. Could you please check whether any other modifications are needed?

https://github.com/ggerganov/llama.cpp/pull/8031/files#diff-4f653096980bd7d10518aa909cb648452cd3aa380ff93cb9fb642dca48536526R110

https://github.com/ggerganov/llama.cpp/pull/8031/files#diff-150dc86746a90bad4fc2c3334aeb9b5887b3adad3cc1459446717638605348efR5090

@arch-btw
Contributor

@dranger003 thank you, that fixed it for me too! I think it's related to NumPy 2.0 being released last week.

It's odd that upgrading by itself didn't work. I had to completely remove numpy and torch, and then it worked using your command. 👍

@gabrielpondc left a comment

Sorry, which gguf environment did you use? After the revision, the issue still shows up:

python convert-hf-to-gguf.py ../THUDM/glm-4-9b-chat --outtype q8_0 --outfile glm4.gguf
Traceback (most recent call last):
  File "/2T/Langchain-Ch/glm4cpp/convert-hf-to-gguf.py", line 842, in <module>
    class OrionModel(Model):
  File "/2T/Langchain-Ch/glm4cpp/convert-hf-to-gguf.py", line 843, in OrionModel
    model_arch = gguf.MODEL_ARCH.ORION
  File "/root/anaconda3/envs/chatglmcpp/lib/python3.9/enum.py", line 429, in __getattr__
    raise AttributeError(name) from None
AttributeError: ORION

@youth123
Author

(Quoted @gabrielpondc's comment above: AttributeError: ORION when running convert-hf-to-gguf.py.)

This issue seems to be due to an incorrect architectures attribute in your model's config.json file. You can re-download the model configuration from Hugging Face:
https://huggingface.co/THUDM/glm-4-9b-chat

@ngxson (Collaborator) left a comment

I couldn't test this atm, but the code looks good to me. Thanks for the contribution @youth123

FYI, there is also a fix, #8198, that may address the incorrect-response issue that @matteoserva pointed out.

Let's also wait for approval from @ggerganov before merging

@touale

touale commented Jun 30, 2024

Thank you again for your work. I found that when running a perplexity check on the model, it returns nan.

$ CUDA_VISIBLE_DEVICES=0 ./llama-perplexity -m ./output.guff -f ./wiki.test.raw
main: build = 3283 (bbe1926f)
main: built with cc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0 for x86_64-linux-gnu
main: seed  = 1719723373
llama_model_loader: loaded meta data with 24 key-value pairs and 283 tensors from ./output.guff (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = chatglm
llama_model_loader: - kv   1:                               general.name str              = glm-4-9b-chat
llama_model_loader: - kv   2:                     chatglm.context_length u32              = 131072
llama_model_loader: - kv   3:                   chatglm.embedding_length u32              = 4096
llama_model_loader: - kv   4:                chatglm.feed_forward_length u32              = 13696
llama_model_loader: - kv   5:                        chatglm.block_count u32              = 40
llama_model_loader: - kv   6:               chatglm.attention.head_count u32              = 32
llama_model_loader: - kv   7:            chatglm.attention.head_count_kv u32              = 2
llama_model_loader: - kv   8:   chatglm.attention.layer_norm_rms_epsilon f32              = 0.000000
llama_model_loader: - kv   9:                          general.file_type u32              = 1
llama_model_loader: - kv  10:               chatglm.rope.dimension_count u32              = 64
llama_model_loader: - kv  11:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  12:                     chatglm.rope.freq_base f32              = 500.000000
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  14:                         tokenizer.ggml.pre str              = chatglm-bpe
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,151552]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  16:                  tokenizer.ggml.token_type arr[i32,151552]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  17:                      tokenizer.ggml.merges arr[str,151073]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  18:            tokenizer.ggml.padding_token_id u32              = 151329
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 151329
llama_model_loader: - kv  20:                tokenizer.ggml.eot_token_id u32              = 151336
llama_model_loader: - kv  21:            tokenizer.ggml.unknown_token_id u32              = 151329
llama_model_loader: - kv  22:                    tokenizer.chat_template str              = chatglm4
llama_model_loader: - kv  23:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  121 tensors
llama_model_loader: - type  f16:  162 tensors
merges_keyidx: 17
n_merges: 151073
llm_load_vocab: special tokens cache size = 223
llm_load_vocab: token to piece cache size = 0.9732 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = chatglm
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 151552
llm_load_print_meta: n_merges         = 151073
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 2
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 16
llm_load_print_meta: n_embd_k_gqa     = 256
llm_load_print_meta: n_embd_v_gqa     = 256
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.6e-07
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 13696
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 9B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 9.40 B
llm_load_print_meta: model size       = 17.51 GiB (16.00 BPW) 
llm_load_print_meta: general.name     = glm-4-9b-chat
llm_load_print_meta: EOS token        = 151329 '<|endoftext|>'
llm_load_print_meta: UNK token        = 151329 '<|endoftext|>'
llm_load_print_meta: PAD token        = 151329 '<|endoftext|>'
llm_load_print_meta: LF token         = 151331 '[gMASK]'
llm_load_print_meta: EOT token        = 151336 '<|user|>'
llm_load_print_meta: max token length = 1024
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
llm_load_tensors: ggml ctx size =    0.14 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/41 layers to GPU
llm_load_tensors:        CPU buffer size = 17929.97 MiB
.................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    80.00 MiB
llama_new_context_with_model: KV self size  =   80.00 MiB, K (f16):   40.00 MiB, V (f16):   40.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     2.31 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1488.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    12.01 MiB
llama_new_context_with_model: graph nodes  = 1606
llama_new_context_with_model: graph splits = 364

system_info: n_threads = 52 / 104 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 1 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 1336.87 ms
perplexity: calculating perplexity over 565 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 16.21 seconds per pass - ETA 38.13 minutes
[1]8.5085,[2]11.3082,[3]11.8973,[4]12.5024,[5]13.1712,[6]14.4324,[7]15.3985,[8]16.1918,[9]17.7532,
[10]18.5259,[11]18.8466,[12]19.8125,[13]21.1505,[14]20.0578,[15]19.3602,[16]19.1104,[17]nan,[18]nan,[19]nan,
[20]nan,[21]nan,[22]nan,[23]nan,[24]nan,[25]nan,[26]nan,[27]nan,[28]nan,[29]nan,
[30]nan...

Also, when I run the same check on a Llama model, it is normal.

./llama-perplexity -m //root/.cache/huggingface/big_models/llama-3-chinese-8b-instruct-v3/ggml-model-f16.gguf -f ./wiki.test.raw
main: build = 3283 (bbe1926f)
main: built with cc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0 for x86_64-linux-gnu
main: seed  = 1719572900
llama_model_loader: loaded meta data with 22 key-value pairs and 291 tensors from //root/.cache/huggingface/big_models/llama-3-chinese-8b-instruct-v3/ggml-model-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = llama-3-chinese-8b-instruct-v3
llama_model_loader: - kv   2:                          llama.block_count u32              = 32
llama_model_loader: - kv   3:                       llama.context_length u32              = 8192
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   7:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   8:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 1
llama_model_loader: - kv  11:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  12:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  14:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  16:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  17:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  18:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  20:                    tokenizer.chat_template str              = {% set loop_messages = messages %}{% ...
llama_model_loader: - kv  21:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:  226 tensors
merges_keyidx: 17
n_merges: 280147
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.8000 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 8B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 8.03 B
llm_load_print_meta: model size       = 14.96 GiB (16.00 BPW) 
llm_load_print_meta: general.name     = llama-3-chinese-8b-instruct-v3
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
  Device 1: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
  Device 2: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
  Device 3: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
llm_load_tensors: ggml ctx size =    0.14 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/33 layers to GPU
llm_load_tensors:        CPU buffer size = 15317.02 MiB
.........................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =   256.00 MiB
llama_new_context_with_model: KV self size  =  256.00 MiB, K (f16):  128.00 MiB, V (f16):  128.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     1.96 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1260.50 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    12.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 356

system_info: n_threads = 52 / 104 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 1 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 418.448 ms
perplexity: calculating perplexity over 564 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 16.67 seconds per pass - ETA 39.17 minutes
[1]4.7840,[2]5.8377,[3]6.2029,[4]6.7270,[5]7.0772,[6]7.2855,[7]7.6473,[8]8.2439,[9]8.8646,
[10]9.1446,[11]9.2682,[12]9.3960,[13]9.7859,[14]9.3636,[15]9.3063,[16]9.0332,[17]8.9551,[18]9.0666,[19]8.7837,
[20]8.6513,[21]8.6646,[22]8.3099,[23]8.0029,[24]7.8407,[25]7.6153,[26]7.5095,[27]7.4035,[28]7.3077,[29]7.3900,[30]7.3924....

@Forevery1

Is there any progress?

@youth123
Author

youth123 commented Jul 1, 2024

Is there any progress?

I am re-modifying the code according to the review comments and merging it into the latest branch. There have been many changes to llama.cpp recently.

@youth123
Author

youth123 commented Jul 1, 2024

@dranger003

I reran the perplexity check for GLM3 and GLM4, but did not encounter the issue you mentioned.
Here are my run commands and model weight links.

python convert-hf-to-gguf.py   /root/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/75792d7ee58a335df6943c5d719cc559b64f8e2a/ --outfile test.gguf
./build/bin/llama-perplexity -m test.gguf  -f /root/tmp/wikitext-2-raw/wiki.test.raw

https://huggingface.co/THUDM/glm-4-9b-chat
It seems to have the same cause as the issue raised by @dranger003 and @arch-btw. You can try reinstalling numpy and torch to solve it.

re-installing torch and numpy using --index-url https://download.pytorch.org/whl/cu121 worked... no idea why though.

Comment on lines +7452 to +7464
        case LLM_FFN_SWIGLU:
            {
                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
                int64_t split_point = cur->ne[0] / 2;
                struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
                struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));

                x0 = ggml_silu(ctx, x0);
                cb(cur, "ffn_silu", il);

                cur = ggml_mul(ctx, x0, x1);
                cb(cur, "ffn_mul", il);
            } break;
Owner

Instead of introducing this new activation path, you can use LLM_FFN_SILU + LLM_FFN_PAR. Just need to split the ffn_up tensor into ffn_up + ffn_gate during conversion

@youth123 (Author) Jul 1, 2024

The order of operations on the split tensors is reversed. Assume the split tensors are x0 and x1. The original implementation is silu(x0) * x1. Now, if x1 is assigned to the gate, the executed formula becomes silu(x0 * x1). So it should still be necessary to add a swiglu implementation.

Owner

It's not reversed, LLM_FFN_PAR computes exactly silu(x0) * x1. But I just noticed that we already have the same code for Phi3, so in that case we can introduce LLM_FFN_SWIGLU and reuse it in both models
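
For reference, a sketch of why the two formulations coincide, assuming the fused ffn_up weight is the row-wise concatenation of two halves $W_0$ and $W_1$ (the notation here is illustrative):

$$x_0 = x W_0, \qquad x_1 = x W_1$$

$$\text{fused SWIGLU: } \operatorname{silu}(x_0) \odot x_1, \qquad \text{SILU + PAR: } \operatorname{silu}(\mathrm{ffn\_gate}(x)) \odot \mathrm{ffn\_up}(x) = \operatorname{silu}(x W_0) \odot (x W_1)$$

So splitting the fused tensor into ffn_gate = W_0 and ffn_up = W_1 during conversion reproduces silu(x0) * x1 exactly.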

Comment on lines +14987 to +14994
        // add prefix to chatglm3
        if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) {
            output.push_back(64790); // [gMask]
            output.push_back(64792); // sop
            output.push_back(64795); // <|user|>
            output.push_back(30910); // \n
            output.push_back(13);
        }
Owner

These prefixes and suffixes should not be applied here. Instead, the user code is responsible for that (e.g. chat templates)
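
As one illustration of that route, a hedged sketch that formats the conversation with llama_chat_apply_template and the model's built-in template (the signature follows the upstream API as of mid-2024 and may differ from this PR's branch):

#include <string>
#include <vector>
#include "llama.h"

// Sketch: let the model's chat template (tokenizer.chat_template) build the formatted
// prompt, instead of hard-coding prefix token ids inside tokenization.
static std::string format_chat(const llama_model * model) {
    std::vector<llama_chat_message> msgs = {
        { "system", "You are a helpful assistant" },
        { "user",   "Hello" },
    };
    std::vector<char> buf(4096);
    int32_t n = llama_chat_apply_template(model, nullptr /* use the model's template */,
                                          msgs.data(), msgs.size(), /* add_ass */ true,
                                          buf.data(), (int32_t) buf.size());
    if (n < 0) {
        return ""; // template not recognized
    }
    if (n > (int32_t) buf.size()) {
        buf.resize(n); // buffer was too small: grow and format again
        n = llama_chat_apply_template(model, nullptr, msgs.data(), msgs.size(), true,
                                      buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), n);
}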

Collaborator

@youth123 I didn't notice this part; maybe this explains why the model's answer is incorrect: the [gMask]sop<|user|> prefix is added twice, once here and once in the chat template. We should remove it here.

Author

Incorrect answers have nothing to do with the prefix being added twice. I have compared against the HF model's outputs before, and the output of each transformer layer is closely aligned, but the error becomes larger in the last layer; I'm still looking into this strange problem. And regarding the chat template, will it run in conversation mode in the future?

@Forevery1

Is there any progress?

I am re-modifying the code according to the review comments and merging it into the latest branch. There have been many changes to llama.cpp recently.

Thank you
