Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gemma model #5631

Merged
merged 1 commit into from
Feb 21, 2024
Merged

Add gemma model #5631

merged 1 commit into from
Feb 21, 2024

Conversation

postmasters
Copy link
Contributor

There are couple things in this architecture:

  1. Shared input and output embedding parameters.
  2. Key length and value length are not derived from n_embd.

More information about the models can be found at
https://ai.google.dev/gemma. GGUFs can be downloaded from https://huggingface.co/google.

There are couple things in this architecture:

1. Shared input and output embedding parameters.
2. Key length and value length are not derived from `n_embd`.

More information about the models can be found at
https://ai.google.dev/gemma. GGUFs can be downloaded from
https://huggingface.co/google.
@ggerganov ggerganov merged commit 580111d into ggerganov:master Feb 21, 2024
53 checks passed
@pablodz
Copy link

pablodz commented Feb 21, 2024

that was fast

@girmay
Copy link

girmay commented Feb 21, 2024

Holly Moses. This was fast. Thank you

@akx
Copy link
Contributor

akx commented Feb 21, 2024

A model converted and quantized from the safetensors weights still fails with

llama_model_load: error loading model: create_tensor: tensor 'output.weight' not found

for me.
There's a

[ 254/ 254]                   output_norm.weight - [ 3072,     1,     1,     1], type =    f32, size =    0.012 MB

tensor visible in the conversion and quantization output though.

@sroecker
Copy link
Contributor

sroecker commented Feb 21, 2024

More information about the models can be found at https://ai.google.dev/gemma. GGUFs can be downloaded from https://huggingface.co/google.

Interesting, is there a reason why the GGUF file is twice as large as the safetensors?

@postmasters
Copy link
Contributor Author

A model converted and quantized from the safetensors weights still fails with

llama_model_load: error loading model: create_tensor: tensor 'output.weight' not found

This depends on how your conversion is done. Two things to make sure: 1) the arch must be gemma and 2) there is no output weight in this arch because it shares the same embedding weights as the input layer. The error you see suggests that the arch is likely not set / copied correctly by the converter.

Interesting, is there a reason why the GGUF file is twice as large as the safetensors?

The weights here are as close to the internal checkpoints as you can get. They are in float32. We are leaning into the community to experiment with other quantized versions ;). For example, you could use the quantize tool included in this repository to produce an F16 version.

@cebtenzzre
Copy link
Collaborator

This PR doesn't make any changes to the convert scripts. How do I convert a Gemma model to GGUF?

@postmasters
Copy link
Contributor Author

This PR doesn't make any changes to the convert scripts. How do I convert a Gemma model to GGUF?

You could simply download the models released on HuggingFace, for example https://huggingface.co/google/gemma-2b/blob/main/gemma-2b.gguf.

@cebtenzzre
Copy link
Collaborator

cebtenzzre commented Feb 21, 2024

You could simply download the models released on HuggingFace

Are there plans to open-source the conversion scripts used, or will the community have to implement them? The safetensors checkpoint is a smaller download (presumably because of BF16 being converted to F32?) and one would imagine that people would like to be able to manipulate the Transformers weights (merge, finetune, etc.) before converting to GGUF, just as they do with other model architectures.

@postmasters
Copy link
Contributor Author

I don't work with SafeTensors so I can't promise I will take this up personally. I'm sure folks will contribute later though 🤞 .

@ggerganov
Copy link
Owner

Yup, hope we get some insights. I tried updating convert-hf-to-gguf.py to support the conversion, but something is missing because the inference produces garbage:

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 9771fccf..d328e524 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -218,6 +218,8 @@ class Model:
             return BertModel
         if model_architecture == "NomicBertModel":
             return NomicBertModel
+        if model_architecture in "GemmaForCausalLM":
+            return GemmaModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -277,6 +279,8 @@ class Model:
             return gguf.MODEL_ARCH.BERT
         if arch == "NomicBertModel":
             return gguf.MODEL_ARCH.NOMIC_BERT
+        if arch in "GemmaForCausalLM":
+            return gguf.MODEL_ARCH.GEMMA
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1785,6 +1789,24 @@ class NomicBertModel(BertModel):
             yield name, data
 
 
+class GemmaModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
+
+

One thing that is strange is the vocab size in config.json specified as 256000, but the token_embd.weight tensor in the F32 GGUF files that are already provided has dimensions of [2048, 256128]. I tried padding with 128 but still producing garbage.

The F32 GGUF files work as expected

I'm currently testing just with the 2B model

@akx
Copy link
Contributor

akx commented Feb 21, 2024

if model_architecture in "GemmaForCausalLM":

That won't do the right thing...

@postmasters
Copy link
Contributor Author

I would not be surprised if the Gemma implementation in HF Transformers requires different transposes of the weight tensors than the implementation in this PR.

@ggerganov
Copy link
Owner

ggerganov commented Feb 21, 2024

Huh very weird. I've been dumping the tensors from the locally converted models and comparing the values with the provided F32 GGUF models. The values are not transposed.

However, all norm tensors have values that are ~1.0f less compared to the F32 GGUF data.
So if I apply this change, the conversion starts working:

# Huh? Why is this needed?
if name.endswith(("norm.weight")):
    data_torch = data_torch + 1

Here is the full convert-hf-to-gguf.py patch that produces working models:

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 9771fccf..e88308dc 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -218,6 +218,8 @@ class Model:
             return BertModel
         if model_architecture == "NomicBertModel":
             return NomicBertModel
+        if model_architecture in "GemmaForCausalLM":
+            return GemmaModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -277,6 +279,8 @@ class Model:
             return gguf.MODEL_ARCH.BERT
         if arch == "NomicBertModel":
             return gguf.MODEL_ARCH.NOMIC_BERT
+        if arch in "GemmaForCausalLM":
+            return gguf.MODEL_ARCH.GEMMA
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1785,6 +1789,64 @@ class NomicBertModel(BertModel):
             yield name, data
 
 
+class GemmaModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")):
+                continue
+
+            # Huh? Why is this needed?
+            if name.endswith(("norm.weight")):
+                data_torch = data_torch + 1
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+

Edit: Ah, there it is:

https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89

Edit2: here is a PR with the conversion script:

#5647

@alfred-liu96
Copy link

so fast!!!

@yiakwy-xpu-ml-framework-team
Copy link

yiakwy-xpu-ml-framework-team commented Feb 22, 2024

Gemma model to GGUF?

Google has already provided GGUF (float32) in hf repo.

@yiakwy-xpu-ml-framework-team

@ggerganov could we add check against bfloat16 dtype ?

@ggerganov
Copy link
Owner

What checks do you have in mind specifically?

@yiakwy-xpu-ml-framework-team

Google uses bfloat16 for inference, while llama.cpp does not.

But I personally believe not too much degression after converting from bf16 to float16. I guess in the converter, if no float16, then we convert them to float32 (or even better if no overflow happens).

If we check if tensor.dtype is bf16 and keep it as fp16, we will have a 17 GB GGUF file instead of a 34 GB GGUF file.

@fmichaelobrien
Copy link

Team, thank you for integrating Gemma support into llama.cpp yesterday - this was an extremely fast and efficient alignment with a model that just came out a couple hours before.
I personally am very grateful to your efforts.
A wider community thank you is in order
thank you for
#5631

@cebtenzzre
Copy link
Collaborator

cebtenzzre commented Feb 22, 2024

Trying out the F32 ggml-7b-it.gguf provided by Google, I'm getting a perplexity of "nan" at 2048 context - around 20 for the first few chunks. Also around 25 for the first chunk at 8192 context. For reference, llama-2-7b Q4_0 perplexity is about 5.16 at 4096 context.

@postmasters Are you sure the implementation is correct?

@postmasters
Copy link
Contributor Author

DANtm in #5635 (comment) suggested that setting repeat-penalty would give better inference outputs.

@slaren
Copy link
Collaborator

slaren commented Feb 22, 2024

With the gemma-7b.gguf base model (without instruction tuning), converted to f16, I get 6.5376 PPL at 2048 context, 6.2240 at 8912 context.

@cebtenzzre
Copy link
Collaborator

cebtenzzre commented Feb 23, 2024

With the gemma-7b.gguf base model (without instruction tuning), converted to f16, I get 6.5376 PPL at 2048 context, 6.2240 at 8912 context.

Weird. Here is what I just tried:

  1. Download safetensors model from https://huggingface.co/google/gemma-7b
  2. Checkout latest llama.cpp master (commit 15499eb)
  3. ./convert-hf-to-gguf.py gemma-7b --outfile gemma-7b.f16.gguf --outtype f16
  4. cmake -B build -DCMAKE_BUILD_TYPE=RelWithDebInfo -DLLAMA_CUBLAS=ON
  5. make -C build perplexity
  6. Run perplexity on my Tesla P40:
$ CUDA_VISIBLE_DEVICES=0 build/bin/perplexity -f wiki.test.raw -c 2048 -m gemma-7b.f16.gguf -ngl 99
<snip>
perplexity: tokenizing the input ..
perplexity: tokenization took 974.102 ms
perplexity: calculating perplexity over 142 chunks, batch_size=512
perplexity: 6.52 seconds per pass - ETA 15.43 minutes
[1]nan,

And there's no point in running it longer than that because the running average will stay NaN.

@slaren
Copy link
Collaborator

slaren commented Feb 23, 2024

I didn't convert from the HF model, I downloaded the fp32 gguf and converted it to fp16 with the quantize tool.

@ggerganov
Copy link
Owner

ggerganov commented Feb 23, 2024

It works with Metal and CPU using the convert model from HF data:

llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M2 Ultra
ggml_metal_init: picking default device: Apple M2 Ultra
ggml_metal_init: default.metallib not found, loading from source
ggml_metal_init: GGML_METAL_PATH_RESOURCES = nil
ggml_metal_init: loading '/Users/ggerganov/development/github/llama.cpp/ggml-metal.metal'
ggml_metal_init: GPU name:   Apple M2 Ultra
ggml_metal_init: GPU family: MTLGPUFamilyApple8  (1008)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 154618.82 MB
ggml_backend_metal_buffer_type_alloc_buffer: allocated buffer, size =   896.00 MiB, (17182.56 / 147456.00)
llama_kv_cache_init:      Metal KV buffer size =   896.00 MiB
llama_new_context_with_model: KV self size  =  896.00 MiB, K (f16):  448.00 MiB, V (f16):  448.00 MiB
llama_new_context_with_model:        CPU input buffer size   =    11.02 MiB
ggml_backend_metal_buffer_type_alloc_buffer: allocated buffer, size =   506.00 MiB, (17688.56 / 147456.00)
llama_new_context_with_model:      Metal compute buffer size =   506.00 MiB
llama_new_context_with_model:        CPU compute buffer size =     6.00 MiB
llama_new_context_with_model: graph splits (measure): 3

system_info: n_threads = 16 / 24 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 680.214 ms
perplexity: calculating perplexity over 142 chunks, batch_size=512
perplexity: 2.20 seconds per pass - ETA 5.20 minutes
[1]5.6440,[2]6.8762,[3]7.2915,[4]6.5856,[5]6.1074,^C

I haven't tried CUDA

@slaren
Copy link
Collaborator

slaren commented Feb 23, 2024

With the 7B it model, both fp32 and fp16, I get PPL around ~20, but I didn't get any nan. Using CUDA with a 3090 Ti.

./perplexity -m models/gemma-7b-it.gguf -ngl 16 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048

[1]18.7209,[2]24.9808,[3]27.5477,[4]24.3773,[5]23.0816,[6]18.8618,[7]17.1438,[8]16.7933,[9]17.6730,[10]17.7396,[11]17.5431,[12]18.4702,[13]18.4808,[14]18.8670,[15]19.0879,[16]19.8607,[17]19.9045,[18]20.0591,[19]19.5361,[20]19.5234,[21]19.4770,[22]19.6030,[23]19.7331,[24]19.7435,[25]20.3328,[26]20.3213,[27]21.0581,[28]21.2913,[29]21.3306,[30]21.3418,[31]21.2371,[32]20.8257,[33]20.9147,[34]20.7695,[35]20.3417,[36]19.8174,[37]19.3932,[38]19.1021,[39]18.6685,[40]18.3326,[41]18.4187,[42]18.7619,[43]19.1801,[44]19.2661,[45]19.5229,[46]19.7044,[47]19.8079,[48]19.8868,[49]19.7144,[50]19.8063,[51]19.6146,[52]19.4874,[53]19.3591,[54]19.1463,[55]19.0401,[56]18.7300,[57]18.7026,[58]18.6677,[59]18.7592,[60]18.9321,[61]19.0706,[62]19.3321,[63]19.3754,[64]19.1777,[65]19.1974,[66]19.0348,[67]18.9782,[68]18.9456,[69]18.7255,[70]18.6829,[71]18.9174,[72]19.0924,[73]19.0105,[74]19.0003,[75]19.0218,[76]19.0347,[77]19.0468,[78]19.0965,[79]19.2280,[80]19.1060,[81]19.0114,[82]18.8994,[83]18.8548,[84]18.8064,[85]18.6885,[86]18.6932,[87]18.7991,[88]18.8360,[89]18.9913,[90]19.2310,[91]19.3942,[92]19.4716,[93]19.6063,[94]19.7548,[95]19.8119,[96]19.8152,[97]19.8260,[98]19.8841,[99]19.8545,[100]19.8863,[101]19.9525,[102]20.0558,[103]20.1068,[104]20.1235,[105]20.1545,[106]20.1087,[107]20.1187,[108]20.1356,[109]20.0118,[110]19.9838,[111]19.9104,[112]19.9309,[113]19.9644,[114]20.0094,[115]19.9969,[116]19.9885,[117]19.9382,[118]19.9754,[119]19.9458,[120]19.8893,[121]19.8274,[122]19.8208,[123]19.8898,[124]19.8708,[125]19.8470,[126]19.7959,[127]19.7669,[128]19.8034,[129]19.7301,[130]19.7389,[131]19.7456,[132]19.8056,[133]19.8714,[134]19.7818,[135]19.5798,[136]19.6104,[137]19.6654,[138]19.7277,[139]19.7196,[140]19.7987,[141]19.7966,[142]19.8815,
Final estimate: PPL = 19.8815 +/- 0.21118

./perplexity -m models/gemma-7b-it-f16.gguf -ngl 99 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048

[1]20.1795,[2]26.9640,[3]29.7014,[4]26.0533,[5]24.6819,[6]20.1910,[7]18.3000,[8]17.9538,[9]18.9682,[10]19.0602,[11]18.8272,[12]19.8369,[13]19.8639,[14]20.2951,[15]20.5834,[16]21.4171,[17]21.4557,[18]21.5971,[19]21.0191,[20]21.0192,[21]20.9496,[22]21.0968,[23]21.2142,[24]21.2150,[25]21.8738,[26]21.8679,[27]22.6807,[28]22.9258,[29]22.9705,[30]22.9568,[31]22.8573,[32]22.4253,[33]22.5072,[34]22.3343,[35]21.8594,[36]21.2822,[37]20.8135,[38]20.5065,[39]20.0285,[40]19.6789,[41]19.7705,[42]20.1286,[43]20.5842,[44]20.6894,[45]20.9686,[46]21.1656,[47]21.2816,[48]21.3709,[49]21.1856,[50]21.2821,[51]21.0609,[52]20.9323,[53]20.8012,[54]20.5745,[55]20.4565,[56]20.1180,[57]20.0956,[58]20.0513,[59]20.1503,[60]20.3293,[61]20.4832,[62]20.7717,[63]20.8173,[64]20.6016,[65]20.6253,[66]20.4452,[67]20.3804,[68]20.3380,[69]20.0944,[70]20.0470,[71]20.3014,[72]20.4836,[73]20.3948,[74]20.3817,[75]20.4047,[76]20.4281,[77]20.4366,[78]20.4864,[79]20.6221,[80]20.4927,[81]20.3864,[82]20.2693,[83]20.2190,[84]20.1664,[85]20.0367,[86]20.0338,[87]20.1434,[88]20.1821,[89]20.3557,[90]20.6163,[91]20.7865,[92]20.8807,[93]21.0238,[94]21.1747,[95]21.2370,[96]21.2429,[97]21.2521,[98]21.3145,[99]21.2806,[100]21.3140,[101]21.3920,[102]21.5062,[103]21.5592,[104]21.5844,[105]21.6126,[106]21.5651,[107]21.5795,[108]21.5964,[109]21.4639,[110]21.4302,[111]21.3496,[112]21.3736,[113]21.4072,[114]21.4540,[115]21.4443,[116]21.4348,[117]21.3798,[118]21.4199,[119]21.3902,[120]21.3283,[121]21.2562,[122]21.2507,[123]21.3275,[124]21.3050,[125]21.2840,[126]21.2304,[127]21.2018,[128]21.2447,[129]21.1617,[130]21.1693,[131]21.1789,[132]21.2450,[133]21.3154,[134]21.2170,[135]20.9955,[136]21.0295,[137]21.0900,[138]21.1598,[139]21.1552,[140]21.2385,[141]21.2370,[142]21.3300,
PPL = 21.3300 +/- 0.22891

@cebtenzzre
Copy link
Collaborator

cebtenzzre commented Feb 26, 2024

I didn't convert from the HF model, I downloaded the fp32 gguf and converted it to fp16 with the quantize tool.

Fun fact: This will leave you with a Q6_K output tensor unless you pass --pure. @ikawrakow This probably isn't intended, right?

6. Run perplexity on my Tesla P40:

$ CUDA_VISIBLE_DEVICES=0 build/bin/perplexity -f wiki.test.raw -c 2048 -m gemma-7b.f16.gguf -ngl 99
<snip>
perplexity: tokenizing the input ..
perplexity: tokenization took 974.102 ms
perplexity: calculating perplexity over 142 chunks, batch_size=512
perplexity: 6.52 seconds per pass - ETA 15.43 minutes
[1]nan,

And there's no point in running it longer than that because the running average will stay NaN.

I've discovered that these NaNs occur with -ngl 2 and above, but not with -ngl 1 or with --no-kv-offload. I can reproduce them on my P40 with either the FP16 converted from safetensors, or the FP16 quantized from Google's provided GGUF.

@slaren I wonder if you can reproduce if you build with -DLLAMA_CUDA_FORCE_MMQ=ON? That's effectively always enabled on my P40.

@slaren
Copy link
Collaborator

slaren commented Feb 27, 2024

I tried with LLAMA_CUDA_FORCE_MMQ. With the FP32 model the results are identical, which is expected since MMQ is only used with quants. With FP16 they are slightly differently, probably due to the Q6_K output tensor, but still no nan.

$ ./perplexity -m models/gemma-7b-it-f16.gguf -ngl 160 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048
main: build = 2276 (b11a93df)
main: built with cc (Ubuntu 12.3.0-9ubuntu2) 12.3.0 for x86_64-linux-gnu
main: seed  = 1708998783
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   yes
ggml_init_cublas: CUDA_USE_TENSOR_CORES: no
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 21 key-value pairs and 254 tensors from models/gemma-7b-it-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = gemma
llama_model_loader: - kv   1:                               general.name str              = gemma-7b-it
llama_model_loader: - kv   2:                       gemma.context_length u32              = 8192
llama_model_loader: - kv   3:                          gemma.block_count u32              = 28
llama_model_loader: - kv   4:                     gemma.embedding_length u32              = 3072
llama_model_loader: - kv   5:                  gemma.feed_forward_length u32              = 24576
llama_model_loader: - kv   6:                 gemma.attention.head_count u32              = 16
llama_model_loader: - kv   7:              gemma.attention.head_count_kv u32              = 16
llama_model_loader: - kv   8:                 gemma.attention.key_length u32              = 256
llama_model_loader: - kv   9:               gemma.attention.value_length u32              = 256
llama_model_loader: - kv  10:     gemma.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  12:                tokenizer.ggml.bos_token_id u32              = 2
llama_model_loader: - kv  13:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  14:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  15:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  16:                      tokenizer.ggml.tokens arr[str,256128]  = ["<pad>", "<eos>", "<bos>", "<unk>", ...
llama_model_loader: - kv  17:                      tokenizer.ggml.scores arr[f32,256128]  = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  18:                  tokenizer.ggml.token_type arr[i32,256128]  = [3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  19:               general.quantization_version u32              = 2
llama_model_loader: - kv  20:                          general.file_type u32              = 1
llama_model_loader: - type  f32:   57 tensors
llama_model_loader: - type  f16:  196 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: mismatch in special tokens definition ( 544/256128 vs 388/256128 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = gemma
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 256128
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_head           = 16
llm_load_print_meta: n_head_kv        = 16
llm_load_print_meta: n_layer          = 28
llm_load_print_meta: n_rot            = 192
llm_load_print_meta: n_embd_head_k    = 256
llm_load_print_meta: n_embd_head_v    = 256
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 4096
llm_load_print_meta: n_embd_v_gqa     = 4096
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 24576
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 2
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 8.54 B
llm_load_print_meta: model size       = 15.04 GiB (15.13 BPW)
llm_load_print_meta: general.name     = gemma-7b-it
llm_load_print_meta: BOS token        = 2 '<bos>'
llm_load_print_meta: EOS token        = 1 '<eos>'
llm_load_print_meta: UNK token        = 3 '<unk>'
llm_load_print_meta: PAD token        = 0 '<pad>'
llm_load_print_meta: LF token         = 227 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.19 MiB
llm_load_tensors: offloading 28 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 29/29 layers to GPU
llm_load_tensors:        CPU buffer size =   615.54 MiB
llm_load_tensors:      CUDA0 buffer size = 15400.21 MiB
...............................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   896.00 MiB
llama_new_context_with_model: KV self size  =  896.00 MiB, K (f16):  448.00 MiB, V (f16):  448.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =    56.04 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  2025.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    24.00 MiB
llama_new_context_with_model: graph splits (measure): 2

system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 699.902 ms
perplexity: calculating perplexity over 142 chunks, batch_size=2048
perplexity: 0.74 seconds per pass - ETA 1.73 minutes
[1]19.2199,[2]25.5085,[3]28.2863,[4]24.9815,[5]23.6883,[6]19.3050,[7]17.5079,[8]17.1901,[9]18.1072,[10]18.1731,[11]17.9541,[12]18.8803,[13]18.8942,[14]19.3041,[15]19.5662,[16]20.3611,[17]20.3960,[18]20.5681,[19]20.0183,[20]20.0213,[21]19.9828,[22]20.1108,[23]20.2374,[24]20.2426,[25]20.8742,[26]20.8718,[27]21.6406,[28]21.8823,[29]21.9235,[30]21.9278,[31]21.8361,[32]21.4179,[33]21.5008,[34]21.3394,[35]20.8930,[36]20.3415,[37]19.8962,[38]19.5981,[39]19.1508,[40]18.8111,[41]18.8966,[42]19.2478,[43]19.6744,[44]19.7701,[45]20.0372,[46]20.2223,[47]20.3328,[48]20.4142,[49]20.2323,[50]20.3273,[51]20.1208,[52]19.9948,[53]19.8622,[54]19.6426,[55]19.5273,[56]19.2086,[57]19.1838,[58]19.1437,[59]19.2485,[60]19.4229,[61]19.5662,[62]19.8336,[63]19.8759,[64]19.6712,[65]19.6941,[66]19.5257,[67]19.4613,[68]19.4237,[69]19.1952,[70]19.1513,[71]19.3910,[72]19.5703,[73]19.4842,[74]19.4706,[75]19.4894,[76]19.5041,[77]19.5165,[78]19.5623,[79]19.6986,[80]19.5719,[81]19.4743,[82]19.3572,[83]19.3111,[84]19.2598,[85]19.1350,[86]19.1354,[87]19.2384,[88]19.2723,[89]19.4347,[90]19.6804,[91]19.8457,[92]19.9296,[93]20.0646,[94]20.2134,[95]20.2746,[96]20.2791,[97]20.2875,[98]20.3473,[99]20.3163,[100]20.3460,[101]20.4164,[102]20.5213,[103]20.5729,[104]20.5908,[105]20.6238,[106]20.5783,[107]20.5890,[108]20.6058,[109]20.4771,[110]20.4473,[111]20.3707,[112]20.3964,[113]20.4287,[114]20.4769,[115]20.4692,[116]20.4649,[117]20.4149,[118]20.4545,[119]20.4266,[120]20.3705,[121]20.3041,[122]20.2998,[123]20.3681,[124]20.3486,[125]20.3271,[126]20.2752,[127]20.2481,[128]20.2866,[129]20.2116,[130]20.2200,[131]20.2285,[132]20.2901,[133]20.3579,[134]20.2678,[135]20.0590,[136]20.0929,[137]20.1507,[138]20.2150,[139]20.2084,[140]20.2878,[141]20.2857,[142]20.3748,
Final estimate: PPL = 20.3748 +/- 0.21724
$ sha1sum models/gemma-7b*                                                                                                                                       
2cf27aa925ef6bb98255232f85f0df9a43278f4a  models/gemma-7b-f16.gguf
dffe52093bed13608d55387cbebbb3861bd072ff  models/gemma-7b-it-f16.gguf
0476921538163089c7564854a7d89417fcdc3b21  models/gemma-7b-it.gguf
5fc6c1bddc756971a56672fe26638eed9bc30a67  models/gemma-7b.gguf

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
There are couple things in this architecture:

1. Shared input and output embedding parameters.
2. Key length and value length are not derived from `n_embd`.

More information about the models can be found at
https://ai.google.dev/gemma. GGUFs can be downloaded from
https://huggingface.co/google.
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
There are couple things in this architecture:

1. Shared input and output embedding parameters.
2. Key length and value length are not derived from `n_embd`.

More information about the models can be found at
https://ai.google.dev/gemma. GGUFs can be downloaded from
https://huggingface.co/google.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet