Allow pooled embeddings on any model #7477

Merged — 6 commits, merged Jun 21, 2024

Changes from 5 commits
2 changes: 2 additions & 0 deletions common/common.cpp
@@ -542,6 +542,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else { invalid_param = true; }
return true;
}
@@ -1820,6 +1821,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param

options.push_back({ "backend" });
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });

if (llama_supports_mlock()) {
options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
}
38 changes: 27 additions & 11 deletions examples/embedding/embedding.cpp
@@ -17,9 +17,25 @@ static std::vector<std::string> split_lines(const std::string & s) {
return lines;
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_NONE:
return true;
case LLAMA_POOLING_TYPE_CLS:
return pos == 0;
case LLAMA_POOLING_TYPE_LAST:
return pos == n_tokens - 1;
default:
GGML_ASSERT(false && "unsupported pooling type");
}
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
bool logit = needs_logit(pooling_type, i, n_tokens);
llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
}
}

@@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu

// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
}
Comment on lines 43 to -49

Collaborator:

Why remove support for LLAMA_POOLING_TYPE_NONE in the embedding example?

Collaborator Author:

Mostly because we're not actually printing out the entire token-level embeddings anyway. The way it was implemented before was essentially doing last-token pooling (not necessarily the last position in the sequence, just the last token in the order the batch was loaded), but now that last-token pooling is an official option, we may as well encourage the user to make that choice consciously.

GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");

float * out = output + batch.seq_id[i][0] * n_embd;
//TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +107,12 @@ int main(int argc, char ** argv) {
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);

const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
return 1;
}

if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, n_ctx);
@@ -176,7 +192,7 @@ int main(int argc, char ** argv) {
}

// add to batch
batch_add_seq(batch, inp, s);
batch_add_seq(batch, inp, s, pooling_type);
s += 1;
}

6 changes: 4 additions & 2 deletions examples/gritlm/gritlm.cpp
@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, true);
Collaborator @ngxson, Jun 20, 2024:

I have a small question here: in the case when both embeddings and causal_attn are enabled, will it still be correct?

Collaborator Author:

In general, it's possible to run with embeddings=true and causal_attn=true, as long as the underlying model supports causal attention. For the GritLM case, I just checked, and it will run but give incorrect results, since it expects the embeddings to be computed non-causally.
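As an illustration of the pattern discussed in this thread, here is a minimal sketch of switching a single context between embedding and generation modes with the new llama_set_embeddings call together with llama_set_causal_attn, mirroring what the gritlm example does in this diff; the helper name and the non-causal assumption are illustrative, not part of the PR:

// Sketch: toggle one llama_context between embedding and generation modes.
// Assumes ctx came from llama_new_context_with_model and that the model,
// like GritLM, expects its embeddings to be computed non-causally.
static void set_embedding_mode(llama_context * ctx, bool embed) {
    llama_kv_cache_clear(ctx);          // cached KV entries are irrelevant across mode switches
    llama_set_embeddings(ctx, embed);   // compute embeddings instead of logits
    llama_set_causal_attn(ctx, !embed); // non-causal for embedding, causal for generation
}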

llama_set_causal_attn(ctx, false);

// run model
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_token eos_token = llama_token_eos(mdl);

llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) {

llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

// create new context - set to embedding mode
cparams.embeddings = true;
// create generation context
llama_context * ctx = llama_new_context_with_model(mdl, cparams);

// ### Embedding/Representation ###
27 changes: 22 additions & 5 deletions examples/retrieval/retrieval.cpp
@@ -73,9 +73,25 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
return chunks;
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_NONE:
return true;
case LLAMA_POOLING_TYPE_CLS:
return pos == 0;
case LLAMA_POOLING_TYPE_LAST:
return pos == n_tokens - 1;
default:
GGML_ASSERT(false && "unsupported pooling type");
}
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
bool logit = needs_logit(pooling_type, i, n_tokens);
llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
}
}

@@ -159,6 +175,7 @@ int main(int argc, char ** argv) {

const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -230,7 +247,7 @@ int main(int argc, char ** argv) {
}

// add to batch
batch_add_seq(batch, inp, s);
batch_add_seq(batch, inp, s, pooling_type);
s += 1;
}

@@ -253,7 +270,7 @@ int main(int argc, char ** argv) {
std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);

struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
batch_add_seq(query_batch, query_tokens, 0);
batch_add_seq(query_batch, query_tokens, 0, pooling_type);

std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
148 changes: 95 additions & 53 deletions llama.cpp
@@ -7435,6 +7435,50 @@ struct llm_build_context {
return lctx.inp_s_seq;
}

struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
// find result_norm tensor for input
struct ggml_tensor * inp = nullptr;
for (int i = gf->n_nodes - 1; i >= 0; --i) {
inp = gf->nodes[i];
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
break;
} else {
inp = nullptr;
}
}
GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");

struct ggml_tensor * cur;

switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
{
struct ggml_tensor * inp_mean = build_inp_mean();
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
} break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
} break;
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
default:
{
GGML_ASSERT(false && "unknown pooling type");
} break;
}

cb(cur, "result_embd_pooled", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}
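For reference, the MEAN branch of append_pooling above realizes sequence-mean pooling as a single matrix product: inp_mean holds the per-sequence averaging weights (1/|T_s| for tokens belonging to sequence s, 0 otherwise, as set up in llama_set_inputs), so multiplying it against the transposed token embeddings computes, for every sequence s in the batch,

e_s = \frac{1}{\lvert T_s \rvert} \sum_{t \in T_s} h_t

where T_s is the set of token rows belonging to sequence s and h_t is the hidden state taken from result_norm/result_embd. The CLS and LAST branches instead use ggml_get_rows with inp_cls holding one row index per sequence.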

struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -8415,8 +8459,6 @@ struct llm_build_context {
if (model.arch != LLM_ARCH_JINA_BERT_V2) {
inp_pos = build_inp_pos();
}
struct ggml_tensor * inp_mean = build_inp_mean();
struct ggml_tensor * inp_cls = build_inp_cls();

// construct input embeddings (token, type, position)
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8591,28 +8633,6 @@ struct llm_build_context {
cur = inpL;
cb(cur, "result_embd", -1);

// pooling layer
switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// nop
} break;
case LLAMA_POOLING_TYPE_MEAN:
{
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_CLS:
{
cur = ggml_get_rows(ctx0, cur, inp_cls);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ASSERT(false && "Invalid pooling type");
} break;
}

ggml_build_forward_expand(gf, cur);

return gf;
@@ -11697,6 +11717,11 @@ static struct ggml_cgraph * llama_build_graph(
GGML_ASSERT(false);
}

// add on pooling layer
if (lctx.cparams.embeddings) {
result = llm.append_pooling(result);
}

llm.free();

return result;
@@ -11754,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}

if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
Collaborator @compilade, Jun 14, 2024:

Hmm. Some outputs should still be skipped when embedding for some of the pooling types, no?

This will cause use of uninitialized lctx.inp_out_ids when embedding with non-BERT models with pooling types other than NONE.

This condition was there originally for how BERT managed output skipping:

if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {

Since batch.logits is likely correctly set when using pooled embeddings (at least, how you wrote them seems correct), should this condition instead always be true?

And if that is done, then inp_cls would be redundant, since the correct rows would already be the only thing left.

Might be out of scope for this PR. What do you think?

Collaborator Author:

Yeah, that makes sense. I'm guessing we want to avoid putting a pooling_type == LLAMA_POOLING_TYPE_NONE check in every single other model? In that case, I guess we have to actually require all logits to be set when getting non-NONE embeddings from non-BERT models. The downside is that it results in a needless get_rows on all the outputs.

In fact, it seems like batch.logits isn't really used when pooling_type is not NONE, since we use all the outputs and the results are stored in embd_seq_out. Or actually, all that's currently required is that at least one logit is requested, so you go down the right branch when we check if lctx.n_outputs == 0 in llama_decode_internal. It seems like in this case we might want to officially ignore batch.logits and give priority to cparams.embeddings.

Collaborator @compilade, Jun 15, 2024:

I think the simpler way to fix this in the meantime is to make n_outputs == n_tokens_all in llama_decode_internal for all non-NONE pooling types when cparams.embeddings is true, even when batch.logits is set. This would then re-use the same logic as logits_all in the other places that use n_outputs.

But I think the CLS and LAST pooling types could eventually skip computing the embeddings they don't need (it's not necessary to do this in this PR).

Collaborator Author:

Ok, I think this should do it. Basically, bypass logits when doing non-NONE embeddings. Note that I'm using hparams.causal_attn to decide whether we're in a BERT model or not in llama_set_inputs.
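A hedged restatement of the rule this thread converges on, as a small sketch (the helper name is illustrative; the condition is the one used in llama_set_inputs in this diff):

// Illustrative helper, not part of the PR: when do token-level outputs
// (driven by batch.logits via inp_out_ids) apply, versus pooled embeddings?
static bool uses_token_level_outputs(const llama_cparams & cparams) {
    // pooled embeddings (MEAN/CLS/LAST) keep every row of the batch and reduce
    // them per sequence in append_pooling, ignoring batch.logits;
    // everything else honors batch.logits and skips unused outputs
    return !cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE;
}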

GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
const int64_t n_tokens = batch.n_tokens;

@@ -11786,7 +11811,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// (!a || b) is a logical implication (a -> b)
// !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
"causal attention with embedding models is not supported"
"causal attention is not supported by this model"
);

if (lctx.inp_KQ_mask) {
@@ -11918,6 +11943,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(lctx.inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));

uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));

std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);

for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];

GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");

if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = i;
}
}

for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
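To make the index selection above concrete, here is a standalone sketch with a hypothetical batch layout (the values are invented for illustration and not taken from the PR):

#include <cstdio>
#include <vector>

int main() {
    // two sequences interleaved in one batch:
    // row:    0  1  2  3  4
    // seq_id: 0  0  1  0  1
    // pos:    0  1  0  2  1
    std::vector<int> seq_id = {0, 0, 1, 0, 1};
    std::vector<int> pos    = {0, 1, 0, 2, 1};
    const int n_tokens = (int) seq_id.size();

    std::vector<int> last_pos(n_tokens, -1);
    std::vector<int> last_row(n_tokens, -1);

    for (int i = 0; i < n_tokens; ++i) {
        if (pos[i] >= last_pos[seq_id[i]]) {
            last_pos[seq_id[i]] = pos[i];
            last_row[seq_id[i]] = i; // remember the row holding the highest position seen so far
        }
    }

    // mirrors what inp_cls receives: {3, 4, 0, 0, 0} -- row 3 is the last token of
    // sequence 0, row 4 the last token of sequence 1, unused slots stay 0 from the memset
    for (int s = 0; s < n_tokens; ++s) {
        if (last_row[s] >= 0) {
            printf("seq %d -> row %d\n", s, last_row[s]);
        }
    }
    return 0;
}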

if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n;

@@ -11979,8 +12035,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_embd = hparams.n_embd;

// TODO: use a per-batch flag for logits presence instead
const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@@ -12245,30 +12301,13 @@ static int llama_decode_internal(
// no output
res = nullptr;
embd = nullptr;
} else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT

// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];

GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
} else if (cparams.embeddings) {
// the embeddings could be in the second to last tensor, or any of the previous tensors
int i_embd = gf->n_nodes - 2;
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
i_embd = gf->n_nodes - i;
if (i_embd < 0) { break; }
embd = gf->nodes[i_embd];
}
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");

// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
if (!cparams.causal_attn) {
res = nullptr; // do not extract logits when not needed
// skip computing logits
// TODO: is this safe?
gf->n_nodes = i_embd + 1;
res = nullptr; // do not extract logits for embedding case
embd = gf->nodes[gf->n_nodes - 1];
if (strcmp(embd->name, "result_embd_pooled") != 0) {
embd = gf->nodes[gf->n_nodes - 2];
}
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
Comment on lines -12248 to 12315
Collaborator:

So an embeddings model will crash on the first decode when cparams.embeddings is set to false?

Collaborator Author:

Yeah, though I can't think of any case where you'd use an embedding model without cparams.embeddings. I guess there's nothing really indicating that something is an embedding model other than the lack of a result_output tensor, so it's hard to intercept this earlier and give an error.

Collaborator @compilade, May 23, 2024:

Well, hparams.causal_attn is false for BERT at least, and it's the only embedding-only model architecture currently in llama.cpp. All BERT-like architectures also set this key to false when converted to GGUF. It's true by default, and by extension, for all other models.

There might be a need for a dedicated metadata key-value pair for embedding-only models if non-causal text-generation models are a thing. (T5? Or is it causal?) Anyway, cparams.causal_attn can be used to get non-causal attention with any model, I think (I did not test this), except for recurrent models (Mamba).

I think there should at least be some abstraction (exported in llama.h) to know whether or not a model can provide embeddings and/or logits. This would make things like #7448 easier, even if it initially relies on hparams.causal_attn.

Collaborator Author:

I see, so at least for now it looks like hparams.causal_attn is a good indicator of whether a model is embedding-only. And I can't imagine a generative model with non-causal attention. I think T5 is causal, at least for the decoder part.

Then I guess we want to assert hparams.causal_attn || cparams.embeddings at some point. That way we don't have to worry about divergence, and the error is caught earlier.
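A minimal sketch of the guard proposed here (placement and wording are illustrative, not part of the merged diff):

// reject the unsupported combination early: a non-causal (embedding-only)
// model decoded without cparams.embeddings has no result_output to extract
GGML_ASSERT((hparams.causal_attn || cparams.embeddings) &&
        "non-causal models require embeddings to be enabled");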

Comment on lines 12303 to 12315
Collaborator:

There are places that need to know when embeddings or logits will be output, like llama_output_reserve:

llama.cpp, lines 11064 to 11065 in cd93a28:

const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

This will need to be updated to reflect exactly how this affects what happens later in this function, near the comments // extract logits and // extract embeddings.

Collaborator Author:

So can we get away with saying you're either getting logits or embeddings but never both, and that the behavior is exclusively controlled by cparams.embeddings? In that case we could just have

const bool has_logits = !cparams.embeddings;
const bool has_embd   =  cparams.embeddings;

Collaborator:

Yeah, I can't really think of a use case where both would be needed at the same time. Except maybe a server serving both completions and embeddings out of the same model. So that's something to consider.

Collaborator Author:

Agreed, but for a given call to llama_decode you would presumably never want both. For the gritlm example, I actually just made two contexts, one for generation and one for embeddings. Another option would be to add a llama_set_embeddings function.

@@ -12337,11 +12376,10 @@ static int llama_decode_internal(
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);

// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
@@ -17870,6 +17908,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data;
}

void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
ctx->cparams.embeddings = embeddings;
}

void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn;
}