Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "llama.h"

#include <ctime>
#include <cstdio>
#include <algorithm>

#if defined(_MSC_VER)
Expand Down Expand Up @@ -70,6 +71,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
}

// plain, pipe-friendly output: one embedding per line
static void print_raw_embeddings(const float * emb,
                                 int n_embd_count,
                                 int n_embd,
                                 const llama_model * model,
                                 enum llama_pooling_type pooling_type,
                                 int embd_normalize) {
    // for rank pooling only the classifier outputs are meaningful,
    // so clamp the printed width to the model's classifier-output count
    const uint32_t n_cls_out = llama_model_n_cls_out(model);
    const int cols = (pooling_type == LLAMA_POOLING_TYPE_RANK)
        ? std::min<int>(n_embd, (int) n_cls_out)
        : n_embd;

    // embd_normalize == 0 means raw (integer-like) values; otherwise print full precision
    const char * fmt = (embd_normalize == 0) ? "%1.0f" : "%1.7f";

    for (int row = 0; row < n_embd_count; ++row) {
        const float * vals = emb + row * n_embd;
        for (int col = 0; col < cols; ++col) {
            if (col > 0) {
                printf(" ");
            }
            printf(fmt, vals[col]);
        }
        printf("\n");
    }
}

int main(int argc, char ** argv) {
common_params params;

Expand Down Expand Up @@ -374,6 +398,10 @@ int main(int argc, char ** argv) {
if (notArray) LOG("\n}\n");
}

if (params.embd_out == "raw") {
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
}

LOG("\n");
llama_perf_context_print(ctx);

Expand Down