
Allow pooled embeddings on any model #7477

Merged: 6 commits, Jun 21, 2024

Conversation

iamlemec (Collaborator):
This allows one to compute pooled embeddings on any model, not just classical embedding models. This is increasingly useful due to the rise of generative-type models in embedding benchmarks (most recently, gte-Qwen1.5-7B-instruct). The main changes are:

  • Add an append_pooling function to llm_build_context that grafts a pooling layer onto the last tensor of an existing graph. This makes some assumptions about how the underlying graph is laid out, but we're already doing that in a couple of places, and there are tensor name checks too.
  • Add a LLAMA_POOLING_TYPE_LAST pooling type since this is a common type of pooling used with generative models. Works very similarly to CLS pooling.
  • Allow the user to specify the attention type on context creation (causal, non-causal, or default for the model).
  • Update the embedding/retrieval examples to request the correct logits depending on pooling_type.
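For illustration, requesting last-token pooled embeddings from a generative model would look roughly like this (a minimal sketch against the public llama.h API of this period; batch setup and error handling omitted, and exact field names may differ slightly):

    // sketch: last-token pooled embeddings on an arbitrary (e.g. generative) model
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // extract embeddings instead of logits
    cparams.pooling_type = LLAMA_POOLING_TYPE_LAST;  // pool on the last token of each sequence
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... tokenize, build a batch, call llama_decode(ctx, batch) ...

    // with a pooled type, read one embedding per sequence rather than per token
    const float * emb = llama_get_embeddings_seq(ctx, /*seq_id=*/0);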

github-actions bot (Contributor) commented May 22, 2024:

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 545 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8607.14ms p(95)=20155.29ms fails=, finish reason: stop=492 truncated=53
  • Prompt processing (pp): avg=103.99tk/s p(95)=449.3tk/s
  • Token generation (tg): avg=32.11tk/s p(95)=47.93tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=append-pooling commit=eb35a6ca4dc1f423071a2374910886407bc3ef61

[charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing for the bench-server-baseline run on Standard_NC4as_T4_v3, duration=10m, 545 iterations]

compilade (Collaborator) left a comment:

Interesting, and it seems like this is going to be useful. Previously, only BERT was affected by pooling types.

llama.cpp Outdated
Comment on lines 7045 to 7048
struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1];
if (strcmp(inp->name, "result_embd") != 0) {
inp = gf->nodes[gf->n_nodes - 2];
GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found");
Collaborator:

This probably won't work for Grok, Phi 2, MiniCPM, and Command R, as their "result_norm" is the third-to-last tensor (or sometimes the fourth-to-last, for Command R).

Collaborator (author):

Indeed, bringing back the backwards search for result_norm.
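For reference, the backward search looks roughly like this (an illustrative sketch, not the exact merged code):

    // walk the graph backwards until an embeddings/norm output tensor is found
    struct ggml_tensor * inp = nullptr;
    for (int i = gf->n_nodes - 1; i >= 0; --i) {
        struct ggml_tensor * t = gf->nodes[i];
        if (strcmp(t->name, "result_embd") == 0 || strcmp(t->name, "result_norm") == 0) {
            inp = t;
            break;
        }
    }
    GGML_ASSERT(inp != nullptr && "embeddings tensor not found");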

Comment on lines -11325 to 12313
} else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT

// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];

GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
} else if (cparams.embeddings) {
// the embeddings could be in the second to last tensor, or any of the previous tensors
int i_embd = gf->n_nodes - 2;
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
i_embd = gf->n_nodes - i;
if (i_embd < 0) { break; }
embd = gf->nodes[i_embd];
}
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");

// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
if (!cparams.causal_attn) {
res = nullptr; // do not extract logits when not needed
// skip computing logits
// TODO: is this safe?
gf->n_nodes = i_embd + 1;
res = nullptr; // do not extract logits for embedding case
embd = gf->nodes[gf->n_nodes - 1];
if (strcmp(embd->name, "result_embd_pooled") != 0) {
embd = gf->nodes[gf->n_nodes - 2];
}
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
Collaborator:

So an embeddings model will crash on the first decode when cparams.embeddings is set to false?

Collaborator (author):

Yeah, though I can't think of any case where you'd use an embedding model without cparams.embeddings. I guess there's nothing really indicating something is an embedding model other than the lack of a result_output tensor, so it's hard to intercept this earlier and give an error.

compilade (Collaborator), May 23, 2024:

Well, hparams.causal_attn is false for BERT at least, and it's the only embedding-only model architecture currently in llama.cpp. All BERT-like architectures also set this key to false when converted to GGUF. It's true by default, and by extension, for all other models.

There might be a need for a dedicated metadata key-value pair for embedding-only models if non-causal text generation models are a thing. (T5? Or is it causal?) Anyway, cparams.causal_attn can be used to get non-causal attention with any model, I think (I did not test this), except for recurrent models (Mamba).

I think there should at least be some abstraction (exported in llama.h) to know whether or not a model can provide embeddings and/or logits. This would make things like #7448 easier, even if it initially relies on hparams.causal_attn.

Collaborator (author):

I see, so at least for now, it looks like hparams.causal_attn is a good indicator of whether a model is embedding-only. And I can't imagine a generative model with non-causal attention. I think T5 is causal, at least for the decoder part.

Then I guess we want to assert hparams.causal_attn || cparams.embeddings at some point. That way we don't have to worry about divergence and the error is caught earlier.
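The check could be as small as the following (a hypothetical placement, shown only to illustrate the idea):

    // reject the unsupported combination up front instead of crashing mid-graph
    GGML_ASSERT((hparams.causal_attn || cparams.embeddings) &&
                "embedding-only (non-causal) models require cparams.embeddings");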

Comment on lines 58 to -49
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
}
Collaborator:

Why remove support for LLAMA_POOLING_TYPE_NONE in the embedding example?

Collaborator (author):

Mostly because we're not actually printing out the entire token-level embeddings anyway. The way it was implemented before was essentially doing last-token pooling (not necessarily the last position in the sequence, though, just the last one in the order the batch was loaded), but now that last-token pooling is an official option, we may as well encourage the user to make that choice consciously.
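A simplified sketch of how the example now reads results depending on the pooling type (i and batch come from the surrounding example loop; llama_pooling_type(ctx) is assumed to return the context's pooling setting):

    // token-level rows for NONE, one pooled row per sequence otherwise
    const enum llama_pooling_type ptype = llama_pooling_type(ctx);
    const float * emb = (ptype == LLAMA_POOLING_TYPE_NONE)
        ? llama_get_embeddings_ith(ctx, i)                    // one row per token, in batch order
        : llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);  // one pooled row per sequence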

examples/retrieval/retrieval.cpp (outdated, resolved)
examples/embedding/embedding.cpp (outdated, resolved)
Comment on lines 11372 to 12313
// no output
res = nullptr;
embd = nullptr;
} else if (!hparams.causal_attn) {
(… the rest of this hunk is identical to the code quoted in the earlier thread above …)
Collaborator:

There are places that need to know when embeddings or logits will be output, like llama_output_reserve

llama.cpp/llama.cpp

Lines 11064 to 11065 in cd93a28

const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

This will need to be updated to reflect exactly how this affects what happens later in this function near the comments // extract logits and // extract embeddings.

Collaborator (author):

So can we get away with saying you're either getting logits or embeddings but never both, and that behavior is exclusively controlled by cparams.embeddings? In that case we could just have

const bool has_logits = !cparams.embeddings;
const bool has_embd   =  cparams.embeddings;

Collaborator:

Yeah, I can't really think of a use-case where both would be needed at the same time. Except maybe for a server serving both completions and embeddings out of the same model. So that's something to consider.

Collaborator (author):

Agreed, but for a given call to llama_decode you would presumably never want both. For the gritlm example, I actually just made two contexts, one for generation and one for embeddings. Another option would be to add a llama_set_embeddings function.
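The llama_set_embeddings route would look roughly like this (a sketch; the real call site appears in the gritlm hunk quoted further down):

    // one context, toggled between embedding and generation mode per pass
    llama_set_embeddings(ctx, true);    // subsequent llama_decode() calls produce embeddings
    // ... decode the batch to embed, read llama_get_embeddings_seq(...) ...
    llama_set_embeddings(ctx, false);   // back to logits for generation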

compilade added labels on May 22, 2024: enhancement (New feature or request), embeddings (embedding related topics), Review Complexity: Medium (generally requires more time to grok but manageable by beginner to medium expertise level)
llama.cpp Outdated
Comment on lines 11064 to 11121
const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings;
Collaborator:

Note that lctx.embd is not used by all pooled embedding types; it's really only used with LLAMA_POOLING_TYPE_NONE.

(This is all done near the end of llama_decode_internal in a switch statement)

So maybe the condition for has_embd could be cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE?

(See also the other places where hparams.causal_attn was used to understand the assumptions that stem from it, to check if they need to be modified)

Collaborator (author):

Ah, that makes sense. I put in the pooling_type check you suggested. I also changed the inp_out_ids calculation to rely on cparams.embeddings rather than hparams.causal_attn.
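Concretely, the reservation logic then reads something like this (paraphrased from the discussion, not copied from the final diff):

    // llama_output_reserve: logits and embeddings are mutually exclusive here, and the
    // token-level embd buffer is only needed for LLAMA_POOLING_TYPE_NONE
    const bool has_logits = !cparams.embeddings;
    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);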

llama.h Outdated
@@ -275,7 +282,7 @@ extern "C" {

enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
// (ignored if no pooling layer)
enum llama_attention_type attention_type; // causal, non-causal, or unspecified
Owner:

I could be missing something, but it seems that attention_type does not bring any value over the existing (h/c)params.causal_attn + llama_set_causal_attn(). Do we need both?

Collaborator (author):

Yeah, actually I don't think it gives you any new capabilities. Perhaps it's best to keep it the way it is and avoid breaking changes. Will switch back!

iamlemec (Collaborator, author), Jun 4, 2024:

The only inconvenience is that it makes it slightly awkward to specify as a CLI flag in the examples.

One option would be to have both attention_type in the constructor and llama_set_causal_attn?

Collaborator (author):

Ok, took out the attention_type stuff, should work now. If server tests are failing, does that mean I have to rebase on master?

llama.h (resolved)
ggerganov requested a review from compilade on June 7, 2024
ngxson (Collaborator) commented Jun 14, 2024:

@iamlemec This PR looks good and would be very useful. Could you update it to the latest master to see if it compiles on CI?

Regarding the model crash when warming up the model with an empty run (see this CI run), do you think it's related to one of the changes in this PR?

iamlemec (Collaborator, author):
@ngxson Just rebased onto master. Let's see if that server error persists. I'm on a weak laptop and internet connection right now, but will double check things later today.

llama.cpp Outdated
@@ -11754,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}

if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
compilade (Collaborator), Jun 14, 2024:

Hmm. Some outputs should still be skipped when embedding for some of the pooling types, no?

This will cause use of uninitialized lctx.inp_out_ids when embedding with non-Bert models with pooling types other than NONE.

This condition was there originally for how BERT managed output skipping.

if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {

Since batch.logits is likely correctly set when using pooled embeddings (at least, the way you wrote them seems correct), should this condition instead always be true?

And if that is done, then inp_cls would be redundant, since the correct rows would already be the only thing left.

Might be out of scope for this PR. What do you think?

Collaborator (author):

Yeah, that makes sense. I'm guessing we want to avoid putting a pooling_type == LLAMA_POOLING_TYPE_NONE check in every single other model? In that case, I guess we have to actually require that all logits be set when getting non-NONE embeddings from non-BERT models. The downside is that it results in a needless get_rows on all the outputs.

In fact, it seems like batch.logits isn't really used when pooling_type is not NONE, since we use all the outputs and the results are stored in embd_seq_out. Or actually, all that's currently required is that at least one logit is requested so you go down the right branch when we check if lctx.n_outputs == 0 in llama_decode_internal. It seems like in this case we might want to officially ignore batch.logits and give priority to cparams.embeddings.

compilade (Collaborator), Jun 15, 2024:

I think the simpler way to fix this in the meantime is to make n_outputs == n_tokens_all in llama_decode_internal for all non-NONE pooling types when cparams.embeddings is true, even when batch.logits is set. This would then re-use the same logic as logits_all in the other places that use n_outputs.

But I think the CLS and LAST pooling types could eventually skip computing the embeddings they don't need (but it's not necessary to do this in this PR).
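In code, the suggestion amounts to something like the following (an illustrative sketch; n_outputs, n_tokens_all and batch are the locals of llama_decode_internal):

    // pooled embeddings need every token's row, so behave like logits_all
    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
        n_outputs = n_tokens_all;                   // ignore batch.logits in this mode
    } else if (batch.logits) {
        n_outputs = 0;                              // count only the rows the caller asked for
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            n_outputs += batch.logits[i] != 0;
        }
    }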

Collaborator (author):

Ok, I think this should do it. Basically bypass logits when doing non-NONE embeddings. Note that I'm using hparams.causal_attn to decide if we're in a BERT model or not in llama_set_inputs.

iamlemec (Collaborator, author):
@compilade sorry to ping you again, but I think this is ready to go. I believe the issues with n_outputs are sorted, and I've tested with a few different models, both embedding and non-embedding.

compilade (Collaborator) left a comment:

sorry to ping you again

Pinging is the right thing, because otherwise I tend to forget to go back and re-review, unless recent activity catches my attention enough to have another look at the changes.

I believe the issues with n_outputs are sorted

I believe this as well.

@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, true);
ngxson (Collaborator), Jun 20, 2024:

I have a small question here: in the case when both embeddings and causal_attn are enabled, will it still be correct?

Collaborator (author):

In general, it's possible to run with embeddings=true and causal_attn=true, as long as the underlying model supports causal attention. For the GritLM case, I just checked here, and it will run but give incorrect results since it expects the embeddings to be run non-causally.
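For a GritLM-style setup, the embedding pass would therefore be toggled roughly like this (a sketch, not the exact example code):

    // embedding pass: non-causal attention + embedding outputs
    llama_set_embeddings(ctx, true);
    llama_set_causal_attn(ctx, false);
    // ... llama_decode() the documents/queries, read the pooled embeddings ...

    // generation pass: back to causal attention and logits
    llama_set_causal_attn(ctx, true);
    llama_set_embeddings(ctx, false);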

ggerganov merged commit 80ea089 into ggerganov:master on Jun 21, 2024 (66 checks passed)
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jun 30, 2024
* create append_pooling operation; allow to specify attention_type; add last token pooling; update examples

* find result_norm/result_embd tensors properly; update output allocation logic

* only use embd output for pooling_type NONE

* get rid of old causal_attn accessor

* take out attention_type; add in llama_set_embeddings

* bypass logits when doing non-NONE pooling
Labels: embeddings, enhancement, examples, Review Complexity: Medium