
Add qwen2moe #6074

Merged: 11 commits merged on Apr 16, 2024

Conversation

simonJJJ (Contributor)

This PR adds support for the upcoming Qwen2 MoE models on Hugging Face. I changed several macro values to support the 60-expert setting. @ggerganov

ggml-backend.c Outdated
@@ -1003,13 +1004,14 @@ static bool ggml_is_view_op(enum ggml_op op) {
 #endif

 #ifndef GGML_SCHED_MAX_SPLIT_INPUTS
-#define GGML_SCHED_MAX_SPLIT_INPUTS 16
+#define GGML_SCHED_MAX_SPLIT_INPUTS 256
Owner

What was the reason to increase this?

Contributor Author

It's for the graph hash size.

Contributor Author

I increased and tested step by step, from 16 -> 32 -> 64 -> 128 -> 256, and 256 works.

Collaborator

What is the problem that you are trying to solve with this change? This seems like a workaround for a different issue. What would be the correct fix? This constant will be removed in the short term in favor of dynamically allocating the array of splits, so whatever issue this is trying to workaround, we need to deal with that directly.
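
For illustration only, a C sketch of that direction (the struct and function names below are made up, not the actual ggml-backend code): grow the array of splits on demand instead of relying on a compile-time maximum.

#include <stdlib.h>

// hypothetical growable array of scheduler splits; error handling omitted
struct sched_split_sketch {
    int i_start;
    int i_end;
};

struct split_array {
    struct sched_split_sketch * data;
    int n;
    int capacity;
};

// returns a pointer to a fresh slot, doubling the backing storage when it is full
static struct sched_split_sketch * split_array_push(struct split_array * a) {
    if (a->n == a->capacity) {
        a->capacity = a->capacity > 0 ? a->capacity*2 : 16;
        a->data = realloc(a->data, a->capacity*sizeof(*a->data));
    }
    return &a->data[a->n++];
}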

Contributor Author

The problem is that when ggml_gallocr_reserve_n is called from ggml_backend_sched_reserve, the hash values collide.

ggml.h Outdated
@@ -227,7 +227,7 @@
 #define GGML_MAX_DIMS 4
 #define GGML_MAX_PARAMS 2048
 #define GGML_MAX_CONTEXTS 64
-#define GGML_MAX_SRC 10
+#define GGML_MAX_SRC 62
Owner

Hm, we are not well prepared for 60-expert MoE models

This change will more than double the size of ggml_tensor and I still don't know how to support many source buffers in Metal:

llama.cpp/ggml-metal.m

Lines 1750 to 1759 in 131b058

// TODO: how to make this an array? read Metal docs
for (int j = 0; j < 8; ++j) {
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}

What are our options here? cc @slaren

Contributor Author

The size of each expert is very small; the ffn_moe_intermediate size is the dense intermediate size divided by 4.
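
As a rough worked check (the concrete sizes come from the Qwen1.5-MoE-A2.7B numbers that appear later in this thread, so treat them as an assumption here rather than something this comment states):

moe intermediate size = dense intermediate size / 4 = 5632 / 4 = 1408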

Collaborator

We could store all the experts in the same tensor; I think it would be easy to adapt the implementations.
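
A minimal sketch of that layout using the ggml C API (illustrative variable names, F32 weights for simplicity; not the actual llama.cpp change): all experts of one projection live in a single 3D tensor, and expert i is addressed as a 2D view into it.

// illustrative: one 3D tensor holding n_expert weight matrices of shape [n_embd, n_ff_exp]
struct ggml_tensor * ffn_up_exps = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_ff_exp, n_expert);

// 2D view selecting expert i: same ne0/ne1 and row stride, offset by i full slices
struct ggml_tensor * ffn_up_i = ggml_view_2d(ctx, ffn_up_exps,
        ffn_up_exps->ne[0], ffn_up_exps->ne[1],
        ffn_up_exps->nb[1],
        i*ffn_up_exps->nb[2]);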

Owner

Should we merge this PR first and then try to resolve this?

Collaborator

I am not sure if we need to merge this right now. As it is, this will only work with the CPU backend, the sort operators in the GPU backends also require the number of experts to be a power of two, so assuming that it is a large model, this implementation may not be very useful. When we have a model to test with, we can improve the implementation.


I remember Qwen2 has 1.8B × 16; that makes sense.

Contributor Author

@slaren do you have an implementation in mind to deal with this extra-large number of experts?

@simonJJJ (Contributor Author)

@ggerganov @slaren is there any update?

@sorasoras

@simonJJJ @ggerganov
https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B has been released, so is there any update? We can start testing it.

@andy-zhangtao

This PR looks awesome, can't wait for it to get merged into the main branch. Could you quickly take a look at the failed test case and patch it up?

@ggerganov (Owner)

We have to resolve the issues outlined earlier in the discussion before we can merge this. Sorry for the delay, but without proper CUDA and Metal support it does not make sense to merge. Now that there is a model available, it will be easier to complete.

@jpmottin

jpmottin commented Mar 30, 2024

I just tested the "Add qwen2moe #6074" PR on Metal (M3).

CONVERT

I tried to convert the HF model:
python3 convert.py ../../Models/Qwen1.5-MoE-A2.7B --outfile ../../models/Qwen1.5-MoE-A2.7B_1.gguf --vocab-type bpe --pad-vocab

I got this error:

Traceback (most recent call last):
  File "convert.py", line 1486, in <module>
    main()
  File "convert.py", line 1472, in main
    model   = convert_model_names(model, params, args.skip_unknown)
  File "convert.py", line 1217, in convert_model_names
    raise Exception("Unexpected tensor name: [tensor_name]. Use --skip-unknown to ignore it")
Exception: Unexpected tensor name: [tensor_name]. Use --skip-unknown to ignore it

To make it work, I added --skip-unknown.

USAGE

I then had another issue (the number of experts was not correct: 0), so I modified/added properties in the model's config.json.

  "num_local_experts":60,
  "num_local_experts_per_tok":60,
  "num_experts_per_tok": 4,
  "num_experts": 60,
  "moe": {
    "num_experts_per_tok": 4,
    "num_experts": 60
  },

Now I got this error:

llm_load_tensors: ggml ctx size =    7.10 MiB
llama_model_load: error loading model: create_tensor: tensor 'blk.0.ffn_gate.0.weight' has wrong shape; expected  2048,  5632, got  2048,  1408,     1,     1
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '../../Models/Qwen1.5-MoE-A2.7B_1.gguf'
{"function":"load_model","level":"ERR","line":678,"model":"../../Models/Qwen1.5-MoE-A2.7B_1.gguf","msg":"unable to load model","tid":"0x1dc4cdc40","timestamp":1711809559}

So far I don't know how to fix this 😞. I hope this comment helps others.

@foldl (Contributor)

foldl commented Mar 31, 2024

> We have to resolve the issues outlined earlier in the discussion before we can merge this. Sorry for the delay, but without proper CUDA and Metal support it does not make sense to merge. Now that there is a model available, it will be easier to complete.

I hope the update below can be merged into ggml.

#define GGML_MAX_SRC 62

I have added this model to ChatLLM.cpp and it works on the CPU. Maybe add a warning for now when GGML_MAX_SRC is larger and building for CUDA or Metal? Once CUDA and Metal are supported, the warning can be removed.
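
For illustration, a compile-time guard along these lines could express that warning (the backend macros GGML_USE_CUBLAS / GGML_USE_METAL and the threshold of 10 are my assumptions, not something agreed on in this thread):

// hypothetical guard: warn when GGML_MAX_SRC was raised while a GPU backend
// that does not yet handle that many sources is enabled
#if (defined(GGML_USE_CUBLAS) || defined(GGML_USE_METAL)) && (GGML_MAX_SRC > 10)
#warning "GGML_MAX_SRC > 10: ops with many sources (e.g. 60-expert MoE) are not yet supported by the CUDA/Metal backends"
#endif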

@slaren (Collaborator)

slaren commented Mar 31, 2024

> I hope the update below can be merged into ggml.
>
> #define GGML_MAX_SRC 62

It can't. Instead, we are moving all the experts to the same tensor. The work is being done in #6387.

@ggerganov (Owner)

@simonJJJ With #6387 now merged, this PR should be adapted and we can merge. Let us know if you would be able to take a look

@simonJJJ (Contributor Author)

@ggerganov Just updated! Hoping to merge ASAP.

@simonJJJ simonJJJ requested a review from ggerganov April 15, 2024 18:20
ggml-backend.c Outdated (resolved)
ggml.h Outdated (resolved)
llama.cpp Outdated (resolved)
llama.cpp Outdated
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il);

cur_gate = ggml_silu(ctx0, cur_gate);
Collaborator

There is now a build_moe_ffn helper function used by Mixtral, Grok, and DBRX.

Contributor Author

The architecture is not the same: qwen2moe has shared experts in each block.

Collaborator

The MoE FFN still looks the same to me, and could be implemented using build_moe_ffn.

Contributor Author

Another difference is that the "weights" are not normalized in qwen2moe.
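
To make the difference concrete, here is a small standalone C sketch (made-up logits and top-k choice; not llama.cpp code): Mixtral-style routing rescales the selected top-k weights so they sum to 1, whereas the qwen2moe behaviour described here keeps the raw softmax probabilities of the selected experts.

#include <math.h>
#include <stdio.h>

int main(void) {
    const int n_expert = 4, k = 2;
    const float logits[4] = {2.0f, 1.0f, 0.5f, -1.0f};

    // softmax over all experts
    float probs[4], sum = 0.0f;
    for (int i = 0; i < n_expert; ++i) { probs[i] = expf(logits[i]); sum += probs[i]; }
    for (int i = 0; i < n_expert; ++i) { probs[i] /= sum; }

    // for these logits the top-k experts are indices 0 and 1
    const float topk_sum = probs[0] + probs[1];
    for (int i = 0; i < k; ++i) {
        printf("expert %d: raw weight %.3f, renormalized weight %.3f\n",
               i, probs[i], probs[i]/topk_sum);
    }
    return 0;
}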

llama.cpp Outdated (resolved)
@simonJJJ simonJJJ requested a review from ggerganov April 15, 2024 19:36
@ggerganov ggerganov requested a review from slaren April 15, 2024 19:43
llama.cpp Outdated
Comment on lines 8634 to 8650
// sigmoid
ggml_tensor * logits_shared_exp = ggml_silu(ctx0, gate_shared_exp);
cb(logits_shared_exp, "ffn_moe_logits_shared_exp", il);

ggml_tensor * probs_shared_exp = ggml_div(ctx0, logits_shared_exp, gate_shared_exp);
cb(probs_shared_exp, "ffn_moe_probs_shared_exp", il);

ggml_tensor * ffn_shared_exp = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up_shared_exp, NULL,
model.layers[il].ffn_gate_shared_exp, NULL,
model.layers[il].ffn_down_shared_exp, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shared_exp, "ffn_moe_shared_exp", il);

ggml_tensor * ffn_shared_exp_out = ggml_mul(ctx0, ffn_shared_exp, probs_shared_exp);
cb(ffn_shared_exp_out, "ffn_moe_shared_exp_out", il);
Collaborator

I understand now that the purpose of the division is to transform the silu into a sigmoid, but the names of the variables could be more descriptive. Is calling the silu the "logits" and the sigmoid the "probabilities" really accurate here?
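
For context, the identity behind that division (standard definitions, not specific to this PR): silu(x) = x * sigmoid(x), so for x != 0, silu(x) / x = sigmoid(x). The ggml_div of the silu output by its input therefore recovers the sigmoid of the shared-expert gate.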

Contributor Author

"logits" ∈(−∞,∞) and "probs" ∈[0,1] look good to me. I can not find better options here.

llama.cpp Outdated (resolved)
github-actions bot (Contributor) commented Apr 16, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 432 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=10918.22ms p(95)=28187.57ms fails=, finish reason: stop=379 truncated=53
  • Prompt processing (pp): avg=124.15tk/s p(95)=570.07tk/s
  • Token generation (tg): avg=22.58tk/s p(95)=33.33tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=add_qwen2moe commit=245565fc6dcc3158018f0e43b687f9e9f1541afb

[Benchmark charts: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing for "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 432 iterations".]

@ggerganov (Owner)

I changed the "shared_exp" suffix to the shorter "shexp": f88e684

Seems more in line with the naming convention. This requires reconverting the models.
@simonJJJ Any objections to this change?

@simonJJJ (Contributor Author)

> I changed the "shared_exp" suffix to the shorter "shexp": f88e684
>
> Seems more in line with the naming convention. This requires reconverting the models. @simonJJJ Any objections to this change?

no

@simonJJJ (Contributor Author)

@ggerganov is there any check that hinders the merge?

@ggerganov (Owner)

> @ggerganov is there any check that hinders the merge?

Should be ready soon - slaren will merge when done with the review

@@ -1723,6 +1751,7 @@ enum e_model {
     MODEL_MEDIUM,
     MODEL_LARGE,
     MODEL_XL,
+    MODEL_A2_7B,
Collaborator

This also needs an entry in llama_model_type_name for its string representation.

@ggerganov ggerganov merged commit f4dea7d into ggerganov:master Apr 16, 2024
54 of 59 checks passed
tybalex pushed a commit to tybalex/function.cpp that referenced this pull request Apr 17, 2024
* support qwen2moe

* fix-review

* metal : support unary ops for nelements % 4 != 0

* metal : require contiguousness for float4 unary kernels

* metal : require contiguousness for float4 unary kernels (cont)

* fix-review

* names : for brevity "SHARED_EXP" -> "SHEXP"

* llama : reuse build_moe_ffn()

* llama : add model type name

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

8 participants